Skip to content

vllm.model_executor.layers.fla.ops.kda

chunk_kda_scaled_dot_kkt_fwd

chunk_kda_scaled_dot_kkt_fwd(
    q: Tensor,
    k: Tensor,
    gk: Tensor | None = None,
    beta: Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: dtype = float32,
) -> tuple[Tensor, Tensor]

Compute beta * K * K^T (returned as A) together with the intra-chunk query–key matrix Aqk.

Parameters:

Name Type Description Default
q Tensor

The query tensor; expected to share the key tensor's shape [B, T, H, K].

required
k Tensor

The key tensor of shape [B, T, H, K].

required
beta Tensor

The beta tensor of shape [B, T, H].

None
gk Tensor

The cumulative sum of the gate tensor of shape [B, T, H, K] applied to the key tensor. Default: None.

None
scale float

Scaling factor forwarded to both Triton kernels. Default: None.

None
cu_seqlens LongTensor

The cumulative sequence lengths of the input tensor. Default: None

None
chunk_size int

The chunk size. Default: 64.

64
output_dtype dtype

The dtype of the output tensor. Default: torch.float32

float32

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple (A, Aqk): A holds beta * K * K^T and Aqk holds the query–key matrix. Both have shape [B, T, H, BT], where BT is the chunk size, and dtype output_dtype.

Source code in vllm/model_executor/layers/fla/ops/kda.py
def chunk_kda_scaled_dot_kkt_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    gk: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compute the intra-chunk matrices ``A = beta * K * K^T`` and ``Aqk``.

    Two Triton kernels fill the outputs. The first is launched once per pair
    of sub-chunks inside each chunk. The second is launched once per
    individual sub-chunk.

    Args:
        q (torch.Tensor):
            The query tensor; expected to share ``k``'s shape ``[B, T, H, K]``
            (only ``k.shape`` is read here).
        k (torch.Tensor):
            The key tensor of shape ``[B, T, H, K]``, with ``K <= 256``.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape ``[B, T, H, K]``
            applied to the key tensor. Default: ``None``. Forwarded to the
            kernels unchanged; whether they accept ``None`` is not checked here.
        beta (torch.Tensor):
            The beta tensor of shape ``[B, T, H]``. Default: ``None``
            (forwarded unchanged, same caveat as ``gk``).
        scale (float):
            Scaling factor forwarded to both kernels. Default: ``None``, which
            falls back to ``K ** -0.5``.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths for variable-length inputs. When
            given, chunks are enumerated via ``prepare_chunk_indices``.
            Default: ``None``.
        chunk_size (int):
            The chunk size ``BT``. Default: ``64``.
        output_dtype (torch.dtype):
            The dtype of both output tensors. Default: ``torch.float32``.

    Returns:
        A tuple ``(A, Aqk)``. ``A`` holds ``beta * K * K^T`` and ``Aqk`` holds
        the query-key matrix. Both have shape ``[B, T, H, BT]`` (``BT`` is the
        chunk size) and dtype ``output_dtype``.

    Raises:
        AssertionError: if the head dimension ``K`` exceeds 256.
    """
    B, T, H, K = k.shape
    assert K <= 256, f"head dim K must be <= 256, got {K}"
    if scale is None:
        # Fall back to the conventional 1/sqrt(K) attention scale instead of
        # forwarding None into the Triton kernels.
        scale = K**-0.5
    BT = chunk_size
    # Variable-length mode: chunks are enumerated per sequence from cu_seqlens
    # rather than by a fixed stride over T.
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # Each chunk of BT tokens is split into NC sub-chunks of BC tokens.
    BC = min(16, BT)
    NC = cdiv(BT, BC)
    # Power-of-two feature block covering K, never below 16 (presumably a
    # Triton block-size minimum; only the second kernel uses it).
    BK = max(next_power_of_2(K), 16)
    # Zero-initialised outputs: presumably the kernels write only part of each
    # [BT, BT] chunk block. TODO confirm against the kernels.
    A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)

    # Pass 1: one program per (chunk, sub-chunk pair, batch*head).
    grid = (NT, NC * NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )

    # Pass 2: one program per (chunk, sub-chunk, batch*head).
    grid = (NT, NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
    )
    return A, Aqk

fused_kda_gate

fused_kda_gate(
    g: Tensor,
    A: Tensor,
    head_k_dim: int,
    g_bias: Tensor | None = None,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> Tensor
Forward pass for KDA gate

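Input g: [..., H*D], where D is head_k_dim. Parameter A: [H] or [1, 1, H, 1]. g_bias: optional bias tensor (enables the kernel's HAS_BIAS path). beta: softplus beta parameter. threshold: softplus threshold parameter. Returns a float32 tensor of shape [..., H, D].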

Source code in vllm/model_executor/layers/fla/ops/kda.py
def fused_kda_gate(
    g: torch.Tensor,
    A: torch.Tensor,
    head_k_dim: int,
    g_bias: torch.Tensor | None = None,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """
    Forward pass for KDA gate.

    Args:
        g: Gate input of shape ``[..., H*D]`` where ``D == head_k_dim``. All
            leading dims are flattened into a single token axis for the kernel.
        A: Per-head parameter of shape ``[H]`` or ``[1, 1, H, 1]``. ``H`` is
            taken from ``A.numel()``.
        head_k_dim: Per-head size ``D``.
        g_bias: Optional bias tensor, forwarded to the kernel (sets
            ``HAS_BIAS``).
        beta: Softplus beta parameter.
        threshold: Softplus threshold parameter.

    Returns:
        A float32 tensor of shape ``[..., H, D]``.

    Raises:
        AssertionError: if ``A.numel() * head_k_dim`` does not equal the last
            dimension of ``g``.
    """
    orig_shape = g.shape[:-1]
    HD = g.shape[-1]
    H = A.numel()
    assert H * head_k_dim == HD, (
        f"A.numel() * head_k_dim ({H} * {head_k_dim}) must equal g.shape[-1] ({HD})"
    )

    # The kernel is launched without any stride arguments (only T, H and
    # head_k_dim), so it relies on dense row-major buffers. A bare `view`
    # raises on some non-contiguous inputs and, for others (e.g. a column
    # slice of a wider tensor), silently yields rows whose stride is not H*D.
    # reshape + contiguous handles both and is a no-op for contiguous inputs.
    g = g.reshape(-1, HD).contiguous()
    if g_bias is not None:
        g_bias = g_bias.contiguous()
    T = g.shape[0]

    # Output is always float32, regardless of g's dtype.
    y = torch.empty_like(g, dtype=torch.float32)

    def grid(meta):
        # Token tiles of size BT (chosen by the kernel's launch config) x heads.
        return (cdiv(T, meta["BT"]), H)

    kda_gate_fwd_kernel[grid](
        g,
        A,
        y,
        g_bias,
        beta,
        threshold,
        T,
        H,
        head_k_dim,
        BD=next_power_of_2(head_k_dim),
        HAS_BIAS=g_bias is not None,
    )

    y = y.view(*orig_shape, H, head_k_dim)
    return y