`return splitk_sparse_gemv(hidden_states, weights, threshold, sparsity_bin) if hidden_states.shape[1] == 1 else torch.matmul(hidden_states, weights.T)`
Hi, I noticed that the SparseGEMV kernel only handles the case batch_size=1 and seqlen=1. For any other input shape, the kernel returns incorrect results.
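One way to see the mismatch is to compare the kernel against a plain dense reference. The sketch below is only an assumption about the intended semantics: it treats `threshold` as a magnitude cutoff (activations with `|x| <= threshold` are zeroed), ignores `sparsity_bin`, and uses a hypothetical helper name, `reference_sparse_gemv`, that is not part of the repository.

```python
import torch


def reference_sparse_gemv(hidden_states: torch.Tensor,
                          weights: torch.Tensor,
                          threshold: float) -> torch.Tensor:
    """Hypothetical dense reference for a thresholded GEMV/GEMM.

    Assumption: entries of hidden_states whose magnitude is <= threshold are
    treated as zero; the rest are multiplied by weights.T as usual. Works for
    any leading shape (batch, seqlen, ...), unlike a decode-only kernel.
    """
    # Keep only activations whose magnitude exceeds the threshold.
    mask = hidden_states.abs() > threshold
    sparse_x = torch.where(mask, hidden_states, torch.zeros_like(hidden_states))
    # Same contraction as the dense fallback: (..., in) @ (in, out) -> (..., out).
    return torch.matmul(sparse_x, weights.T)


# Example with batch=2, seqlen=3, hidden=8, out_features=4.
torch.manual_seed(0)
x = torch.randn(2, 3, 8)
w = torch.randn(4, 8)
ref = reference_sparse_gemv(x, w, threshold=0.5)
print(ref.shape)  # torch.Size([2, 3, 4])
```

Assuming the kernel accepts these shapes (on whatever device and dtype it requires), checking `torch.allclose(splitk_sparse_gemv(x, w, threshold, sparsity_bin), reference_sparse_gemv(x, w, threshold), atol=...)` with batch > 1 or seqlen > 1 is where the discrepancy would show up.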
Is this kernel expected to work only for the decoding stage? If so, where is the implementation described in Appendix A4?