vllm.v1.attention.backends.mla.indexer ¶
kv_spans_from_batches ¶
```python
kv_spans_from_batches(
    start_seq_loc: Tensor,
    seq_len_per_batch: Tensor,
    device: device,
) -> tuple[Tensor, Tensor]
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start_seq_loc | Tensor | 1D long tensor of shape [B+1]; cumulative counts of selected tokens per batch. Example: [0, 2, 4, 7] -> selected counts per batch [2, 2, 3], N = 7 tokens total. | required |
| seq_len_per_batch | Tensor | 1D long tensor of shape [B]; full sequence length (KV length) of each batch. Example: [5, 9, 4]. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| start_tensor | Tensor | 1D long tensor of shape [N]; start offset in the concatenated KV cache for each token's batch. |
| end_location | Tensor | 1D long tensor of shape [N]; exclusive end = start + the token's 1-indexed local position, so the attended KV slice is kv[start:end]. |
Assumes each batch contributes its full seq_len_per_batch[i] keys to the KV cache, and that the selected tokens within a batch are the last counts[i] positions of that sequence.
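The span computation follows directly from the description above. Below is a minimal plain-PyTorch sketch of the documented semantics, not the actual vLLM implementation; the name `kv_spans_reference` is hypothetical:

```python
import torch

def kv_spans_reference(
    start_seq_loc: torch.Tensor,      # [B+1] cumulative selected-token counts
    seq_len_per_batch: torch.Tensor,  # [B] full KV length of each batch
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Selected tokens per batch: counts[i] = start_seq_loc[i+1] - start_seq_loc[i].
    counts = start_seq_loc[1:] - start_seq_loc[:-1]
    # Start offset of each batch in the concatenated KV cache
    # (exclusive cumulative sum of the full sequence lengths).
    kv_starts = torch.cumsum(seq_len_per_batch, dim=0) - seq_len_per_batch
    # Broadcast each batch's KV start to all of its selected tokens.
    start_tensor = torch.repeat_interleave(kv_starts, counts)
    # 0-based index of each token within its batch's selection.
    total = int(start_seq_loc[-1])
    local_idx = torch.arange(total, device=start_seq_loc.device) - \
        torch.repeat_interleave(start_seq_loc[:-1], counts)
    # The selected tokens are the *last* counts[i] positions of the sequence,
    # so the 1-indexed position is seq_len - counts + 1 + local_idx.
    first_pos = seq_len_per_batch - counts + 1
    local_pos = torch.repeat_interleave(first_pos, counts) + local_idx
    # Exclusive end: the attended KV slice for token t is kv[start:end].
    end_location = start_tensor + local_pos
    return start_tensor.to(device), end_location.to(device)
```

For the example inputs from the tables, start_seq_loc = [0, 2, 4, 7] and seq_len_per_batch = [5, 9, 4], this yields start_tensor = [0, 0, 5, 5, 14, 14, 14] and end_location = [4, 5, 13, 14, 16, 17, 18].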