def context_attention_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
b_start_loc: torch.Tensor,
b_seq_len: torch.Tensor,
max_input_len: int,
is_causal: bool = True,
softmax_scale: float | None = None,
sliding_window_q: int | None = None,
sliding_window_k: int | None = None,
):
"""
q, k, v: [b * s, head, head_dim]
b_start_loc: [b]
b_seq_len: [b]
out: [b * s, head, head_dim]
"""
BLOCK = get_block_size(q.dtype)
    Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]  # head_dim of q and k; v's head_dim is not used here
sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale
    # Pre-scale by 1/ln(2) so the Triton kernel can use exp2 instead of exp.
sm_scale *= RCP_LN2
batch, head = b_seq_len.shape[0], q.shape[1]
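    # Number of query heads per KV head (grouped-query / multi-query attention).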
kv_group_num = q.shape[1] // k.shape[1]
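    # One program instance per (sequence, query head, block of BLOCK query rows).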
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
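    # Wider head dims get more warps per program.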
num_warps = 4 if Lk <= 64 else 8
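    # The kernel expects integer window sizes; None (no window) is mapped to 0.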
sliding_window_q = sliding_window_q if sliding_window_q is not None else 0
sliding_window_k = sliding_window_k if sliding_window_k is not None else 0
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
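        # Token and head strides for each tensor; the head_dim stride is not passed,
        # so the last dim is assumed contiguous.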
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=triton.next_power_of_2(Lk),
BLOCK_N=BLOCK,
IS_CAUSAL=is_causal,
SLIDING_WINDOW_Q=sliding_window_q,
SLIDING_WINDOW_K=sliding_window_k,
num_warps=num_warps,
num_stages=1,
Lk=Lk,
)
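

# Illustrative usage sketch, not part of the kernel: builds a packed two-sequence
# batch and runs the forward. The function name, shapes, dtype, and "cuda" device
# below are example choices only; real call sites construct these tensors from
# batch metadata elsewhere in the codebase.
def _example_context_attention_call():
    head, head_dim = 8, 64
    seq_lens = [5, 3]  # two sequences packed back-to-back along dim 0
    total_tokens = sum(seq_lens)
    q = torch.randn(total_tokens, head, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    o = torch.empty_like(q)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    # Exclusive prefix sum: start offset of each sequence in the packed dim.
    b_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len
    context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=max(seq_lens))
    return o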