vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse

logger module-attribute

logger = init_logger(__name__)

ROCMAiterMLASparseBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
class ROCMAiterMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_MLA_SPARSE"

    @staticmethod
    def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
        return ROCMAiterMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
        return ROCMAiterMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
        return ROCMAiterMLASparseImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[
    ROCMAiterMLASparseMetadataBuilder
]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
    return ROCMAiterMLASparseMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[ROCMAiterMLASparseImpl]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
    return ROCMAiterMLASparseImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,  # assumed to be 1 for MLA
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    return (num_blocks, block_size, head_size)
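
For illustration, a minimal sketch of the resulting cache shape, using hypothetical dimensions (1024 single-token blocks and a 576-wide MLA latent-plus-rope head):

shape = ROCMAiterMLASparseBackend.get_kv_cache_shape(
    num_blocks=1024, block_size=1, num_kv_heads=1, head_size=576
)
assert shape == (1024, 1, 576)  # num_kv_heads is folded away (assumed 1 for MLA)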

get_metadata_cls staticmethod

get_metadata_cls() -> type[ROCMAiterMLASparseMetadata]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
    return ROCMAiterMLASparseMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@staticmethod
def get_name() -> str:
    return "ROCM_AITER_MLA_SPARSE"

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16]

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    return [576]
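
As a quick capability check, a minimal sketch (the dtype and head size are hypothetical inputs, e.g. taken from a model config):

import torch

dtype, head_size = torch.bfloat16, 576
assert dtype in ROCMAiterMLASparseBackend.get_supported_dtypes()
assert head_size in ROCMAiterMLASparseBackend.get_supported_head_sizes()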

ROCMAiterMLASparseImpl

Bases: MLACommonBaseImpl[ROCMAiterMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
        self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,  # [sq, heads, d_qk]
        kv_c_and_k_pe_cache: torch.Tensor,  # [blocks, heads, d_qk]
        topk_indices: torch.Tensor,  # [sq, topk]
        attn_metadata: ROCMAiterMLASparseMetadata,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        output = torch.empty(
            [num_tokens, self.num_heads, self.kv_lora_rank],
            dtype=q.dtype,
            device=q.device,
        )
        seq_len = (topk_indices != -1).sum(dim=-1)
        torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
        attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
        fetch_id_to_ragged_triton(
            topk_indices,
            attn_metadata.paged_kv_indptr,
            attn_metadata.paged_kv_indices,
            attn_metadata.topk_tokens,
        )

        rocm_aiter_ops.mla_decode_fwd(
            q,
            kv_c_and_k_pe_cache,
            output,
            self.scale,
            attn_metadata.qo_indptr,
            1,
            attn_metadata.paged_kv_indptr,
            attn_metadata.paged_kv_indices,
            attn_metadata.paged_kv_last_page_len,
        )

        return output[:, : self.num_heads, :]

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: ROCMAiterMLASparseMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # NOTE(lucas): the sparse FlashMLA kernels use the MQA 576/512 approach
        # for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for ROCMAiterMLASparse"
            )

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs

        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        if self.is_fp8bmm_enabled:
            # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
            ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
            )
        else:
            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            ql_nope = torch.bmm(q_nope, self.W_UK_T)
            # Convert from (N, B, L) to (B, N, L)
            ql_nope = ql_nope.transpose(0, 1)

        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        topk_indices_global = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        attn_out = self._forward_bf16_kv(
            q, kv_cache, topk_indices_global, attn_metadata
        )

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output

is_fp8bmm_enabled instance-attribute

is_fp8bmm_enabled = is_fp8bmm_enabled()

softmax_scale instance-attribute

softmax_scale = scale

topk_indices_buffer instance-attribute

topk_indices_buffer: Tensor | None = topk_indices_buffer

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    topk_indice_buffer: Tensor | None = None,
    indexer: Optional[Indexer] = None,
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    # MLA Specific Arguments
    topk_indice_buffer: torch.Tensor | None = None,
    indexer: Optional["Indexer"] = None,
    **mla_args,
) -> None:
    super().__init__(
        num_heads,
        head_size,
        scale,
        num_kv_heads,
        alibi_slopes,
        sliding_window,
        kv_cache_dtype,
        logits_soft_cap,
        attn_type,
        kv_sharing_target_layer_name,
        **mla_args,
    )
    self.softmax_scale = scale
    assert indexer is not None
    self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
    self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

_forward_bf16_kv

_forward_bf16_kv(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: ROCMAiterMLASparseMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def _forward_bf16_kv(
    self,
    q: torch.Tensor,  # [sq, heads, d_qk]
    kv_c_and_k_pe_cache: torch.Tensor,  # [blocks, heads, d_qk]
    topk_indices: torch.Tensor,  # [sq, topk]
    attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor:
    num_tokens = q.shape[0]
    output = torch.empty(
        [num_tokens, self.num_heads, self.kv_lora_rank],
        dtype=q.dtype,
        device=q.device,
    )
    seq_len = (topk_indices != -1).sum(dim=-1)
    torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
    attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
    fetch_id_to_ragged_triton(
        topk_indices,
        attn_metadata.paged_kv_indptr,
        attn_metadata.paged_kv_indices,
        attn_metadata.topk_tokens,
    )

    rocm_aiter_ops.mla_decode_fwd(
        q,
        kv_c_and_k_pe_cache,
        output,
        self.scale,
        attn_metadata.qo_indptr,
        1,
        attn_metadata.paged_kv_indptr,
        attn_metadata.paged_kv_indices,
        attn_metadata.paged_kv_last_page_len,
    )

    return output[:, : self.num_heads, :]
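
A minimal sketch of how the ragged metadata is derived above, assuming three query tokens whose top-k indices are padded with -1 (values are hypothetical):

import torch

topk_indices = torch.tensor([[5, 9, -1, -1],
                             [2, 7,  8, -1],
                             [1, -1, -1, -1]])

seq_len = (topk_indices != -1).sum(dim=-1)               # tensor([2, 3, 1])
paged_kv_indptr = torch.zeros(len(seq_len) + 1, dtype=seq_len.dtype)
torch.cumsum(seq_len, dim=0, out=paged_kv_indptr[1:])    # indptr -> [0, 2, 5, 6]
# fetch_id_to_ragged_triton then packs the valid indices contiguously,
# so paged_kv_indices[:6] would hold [5, 9, 2, 7, 8, 1].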

forward

forward(
    layer: AttentionLayer,
    q: Tensor,
    k_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: ROCMAiterMLASparseMetadata,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: ROCMAiterMLASparseMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    # NOTE(lucas): the sparse FlashMLA kernels use the MQA 576/512 approach
    # for both prefill and decode

    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for ROCMAiterMLASparse"
        )

    if attn_metadata is None:
        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs

    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]

    q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    # Convert from (B, N, P) to (N, B, P)
    q_nope = q_nope.transpose(0, 1)
    if self.is_fp8bmm_enabled:
        # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
        ql_nope = rocm_aiter_ops.triton_fp8_bmm(
            q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
        )
    else:
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

    assert self.topk_indices_buffer is not None
    topk_indices = self.topk_indices_buffer[:num_actual_toks]

    topk_indices_global = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
    )

    q = torch.cat([ql_nope, q_pe], dim=-1)

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

    attn_out = self._forward_bf16_kv(
        q, kv_cache, topk_indices_global, attn_metadata
    )

    self._v_up_proj(attn_out, out=output[:num_actual_toks])
    return output
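
A minimal sketch of the shape flow in the non-FP8 up-projection above, with hypothetical DeepSeek-style dimensions (B tokens, N heads, P = qk_nope_head_dim, L = kv_lora_rank):

import torch

B, N, P, L = 4, 16, 128, 512
q_nope = torch.randn(B, N, P)
W_UK_T = torch.randn(N, P, L)

# (B, N, P) -> (N, B, P), one matmul per head, then back to (B, N, L)
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)
assert ql_nope.shape == (B, N, L)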

ROCMAiterMLASparseMetadata dataclass

Bases: AttentionMetadata

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata):
    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor

    qo_indptr: torch.Tensor
    paged_kv_last_page_len: torch.Tensor
    paged_kv_indices: torch.Tensor
    paged_kv_indptr: torch.Tensor
    paged_kv_indptr_rest: torch.Tensor

    block_size: int = 1
    topk_tokens: int = 2048

block_size class-attribute instance-attribute

block_size: int = 1

block_table instance-attribute

block_table: Tensor

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_reqs instance-attribute

num_reqs: int

paged_kv_indices instance-attribute

paged_kv_indices: Tensor

paged_kv_indptr instance-attribute

paged_kv_indptr: Tensor

paged_kv_indptr_rest instance-attribute

paged_kv_indptr_rest: Tensor

paged_kv_last_page_len instance-attribute

paged_kv_last_page_len: Tensor

qo_indptr instance-attribute

qo_indptr: Tensor

query_start_loc instance-attribute

query_start_loc: Tensor

req_id_per_token instance-attribute

req_id_per_token: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

topk_tokens class-attribute instance-attribute

topk_tokens: int = 2048

__init__

__init__(
    num_reqs: int,
    max_query_len: int,
    max_seq_len: int,
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    block_table: Tensor,
    req_id_per_token: Tensor,
    qo_indptr: Tensor,
    paged_kv_last_page_len: Tensor,
    paged_kv_indices: Tensor,
    paged_kv_indptr: Tensor,
    paged_kv_indptr_rest: Tensor,
    block_size: int = 1,
    topk_tokens: int = 2048,
) -> None

ROCMAiterMLASparseMetadataBuilder dataclass

Bases: AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@dataclass
class ROCMAiterMLASparseMetadataBuilder(
    AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.topk_tokens_tensor = torch.tensor(
            [self.topk_tokens], device=device, dtype=torch.int32
        )
        self.max_model_len_tensor = torch.tensor(
            [self.model_config.max_model_len], device=device, dtype=torch.int32
        )
        # this is ignored by `flash_mla_with_kvcache` if indices is not None
        self.dummy_block_table = torch.empty(
            (1, 1), dtype=torch.int32, device=self.device
        )

        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )
        self.qo_indptr = torch.arange(
            0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
        )
        self.paged_kv_last_page_len = torch.ones(
            max_num_batched_tokens, dtype=torch.int32, device=device
        )

        # These two need to be computed at runtime,
        # but we still need to preallocate the buffers
        self.paged_kv_indices = torch.zeros(
            [max_num_batched_tokens * self.topk_tokens],
            dtype=torch.int32,
            device=device,
        )
        self.paged_kv_indptr = torch.zeros(
            [max_num_batched_tokens + 1], dtype=torch.int32, device=device
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> ROCMAiterMLASparseMetadata:
        num_tokens = common_attn_metadata.num_actual_tokens
        starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        self.paged_kv_indices.fill_(0)
        self.paged_kv_indptr.fill_(0)

        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
        qo_indptr = self.qo_indptr[: num_tokens + 1]
        paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
        paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
        paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
        paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]

        metadata = ROCMAiterMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            qo_indptr=qo_indptr,
            paged_kv_last_page_len=paged_kv_last_page_len,
            paged_kv_indices=paged_kv_indices,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indptr_rest=paged_kv_indptr_rest,
        )
        return metadata

_cudagraph_support class-attribute

_cudagraph_support: AttentionCGSupport = NEVER

device instance-attribute

device = device

dummy_block_table instance-attribute

dummy_block_table = empty(
    (1, 1), dtype=int32, device=device
)

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

max_model_len_tensor instance-attribute

max_model_len_tensor = tensor(
    [max_model_len], device=device, dtype=int32
)

mla_dims instance-attribute

mla_dims = get_mla_dims(model_config)

model_config instance-attribute

model_config = model_config

num_heads instance-attribute

num_heads = get_num_attention_heads(parallel_config)

paged_kv_indices instance-attribute

paged_kv_indices = zeros(
    [max_num_batched_tokens * topk_tokens],
    dtype=int32,
    device=device,
)

paged_kv_indptr instance-attribute

paged_kv_indptr = zeros(
    [max_num_batched_tokens + 1], dtype=int32, device=device
)

paged_kv_last_page_len instance-attribute

paged_kv_last_page_len = ones(
    max_num_batched_tokens, dtype=int32, device=device
)

qo_indptr instance-attribute

qo_indptr = arange(
    0,
    max_num_batched_tokens + 1,
    dtype=int32,
    device=device,
)

req_id_per_token_buffer instance-attribute

req_id_per_token_buffer = empty(
    (max_num_batched_tokens,), dtype=int32, device=device
)

topk_tokens instance-attribute

topk_tokens = index_topk

topk_tokens_tensor instance-attribute

topk_tokens_tensor = tensor(
    [topk_tokens], device=device, dtype=int32
)

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    self.kv_cache_spec = kv_cache_spec
    self.model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    self.device = device
    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

    self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
    self.mla_dims = get_mla_dims(self.model_config)
    self.topk_tokens = vllm_config.model_config.hf_config.index_topk
    self.topk_tokens_tensor = torch.tensor(
        [self.topk_tokens], device=device, dtype=torch.int32
    )
    self.max_model_len_tensor = torch.tensor(
        [self.model_config.max_model_len], device=device, dtype=torch.int32
    )
    # this is ignored by `flash_mla_with_kvcache` if indices is not None
    self.dummy_block_table = torch.empty(
        (1, 1), dtype=torch.int32, device=self.device
    )

    self.req_id_per_token_buffer = torch.empty(
        (vllm_config.scheduler_config.max_num_batched_tokens,),
        dtype=torch.int32,
        device=device,
    )
    self.qo_indptr = torch.arange(
        0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
    )
    self.paged_kv_last_page_len = torch.ones(
        max_num_batched_tokens, dtype=torch.int32, device=device
    )

    # These two need to be computed at runtime,
    # but we still need to preallocate the buffers
    self.paged_kv_indices = torch.zeros(
        [max_num_batched_tokens * self.topk_tokens],
        dtype=torch.int32,
        device=device,
    )
    self.paged_kv_indptr = torch.zeros(
        [max_num_batched_tokens + 1], dtype=torch.int32, device=device
    )
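
A rough sizing note for the buffers preallocated above, under hypothetical settings; the dominant one is paged_kv_indices, sized max_num_batched_tokens * index_topk:

max_num_batched_tokens = 8192   # hypothetical scheduler setting
topk_tokens = 2048              # hypothetical hf_config.index_topk
num_slots = max_num_batched_tokens * topk_tokens   # 16_777_216 int32 entries
size_mib = num_slots * 4 / (1 << 20)               # ~64 MiB for paged_kv_indices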

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> ROCMAiterMLASparseMetadata
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> ROCMAiterMLASparseMetadata:
    num_tokens = common_attn_metadata.num_actual_tokens
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
    seg_lengths = np.diff(starts)
    req_id_per_token = np.repeat(
        np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
    )
    # Zero-fill for cudagraphs
    self.req_id_per_token_buffer.fill_(0)
    self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
        torch.from_numpy(req_id_per_token), non_blocking=True
    )
    self.paged_kv_indices.fill_(0)
    self.paged_kv_indptr.fill_(0)

    req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
    qo_indptr = self.qo_indptr[: num_tokens + 1]
    paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
    paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
    paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
    paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]

    metadata = ROCMAiterMLASparseMetadata(
        num_reqs=common_attn_metadata.num_reqs,
        max_query_len=common_attn_metadata.max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        query_start_loc=common_attn_metadata.query_start_loc,
        slot_mapping=common_attn_metadata.slot_mapping,
        block_table=common_attn_metadata.block_table_tensor,
        req_id_per_token=req_id_per_token,
        block_size=self.kv_cache_spec.block_size,
        topk_tokens=self.topk_tokens,
        qo_indptr=qo_indptr,
        paged_kv_last_page_len=paged_kv_last_page_len,
        paged_kv_indices=paged_kv_indices,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indptr_rest=paged_kv_indptr_rest,
    )
    return metadata
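
A minimal sketch of the req_id_per_token construction above, assuming three requests with 3, 2, and 4 query tokens (the query_start_loc values are hypothetical):

import numpy as np

starts = np.asarray([0, 3, 5, 9], dtype=np.int32)   # query_start_loc_cpu
seg_lengths = np.diff(starts)                       # [3, 2, 4]
req_id_per_token = np.repeat(
    np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# -> [0, 0, 0, 1, 1, 2, 2, 2, 2]: each token is tagged with its request index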

fetch_id_to_ragged_kernel

fetch_id_to_ragged_kernel(
    in_tensor_ptr,
    cumsum_ptr,
    out_tensor_ptr,
    in_tensor_ptr_stride,
    TOPK: constexpr,
    TOKEN_NUM: constexpr,
    BLOCK_SIZE: constexpr,
)
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@triton.jit
def fetch_id_to_ragged_kernel(
    in_tensor_ptr,  # [num_seq, topk]
    cumsum_ptr,  # [num_seq + 1]
    out_tensor_ptr,  # [max_num_seq * topk]
    in_tensor_ptr_stride,
    TOPK: tl.constexpr,
    TOKEN_NUM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    seq_id = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = tl.arange(0, BLOCK_SIZE)
    token_start = tl.load(cumsum_ptr + seq_id)
    token_end = tl.load(cumsum_ptr + seq_id + 1)
    token_num = token_end - token_start
    row_offset = block_id * BLOCK_SIZE
    if row_offset >= token_num:
        return
    in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset
    in_tensor_mask = (row_offset + offset) < TOPK
    in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask)
    out_tensor_offset = token_start + row_offset + offset
    out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask
    tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask)

fetch_id_to_ragged_triton

fetch_id_to_ragged_triton(
    in_tensor: Tensor,
    cumsum: Tensor,
    out_tensor: Tensor,
    topk,
)
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def fetch_id_to_ragged_triton(
    in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk
):
    num_tokens = in_tensor.size(0)
    block_size = 64
    num_block_per_row = triton.cdiv(topk, block_size)
    grid = (
        num_tokens,
        num_block_per_row,
    )
    fetch_id_to_ragged_kernel[grid](
        in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size
    )
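
A minimal usage sketch on a ROCm device, with hypothetical values: two query tokens, a top-k of 8, -1 marking padding, and cumsum holding the per-token valid counts as computed in _forward_bf16_kv:

import torch

in_tensor = torch.tensor([[3, 8, 11, -1, -1, -1, -1, -1],
                          [0, 5, -1, -1, -1, -1, -1, -1]],
                         dtype=torch.int32, device="cuda")
cumsum = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
out_tensor = torch.zeros(in_tensor.numel(), dtype=torch.int32, device="cuda")

fetch_id_to_ragged_triton(in_tensor, cumsum, out_tensor, topk=8)
# out_tensor[:5] -> [3, 8, 11, 0, 5]  (valid indices packed back-to-back)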

reference_mla_sparse_prefill

reference_mla_sparse_prefill(
    q: Tensor,
    kv: Tensor,
    indices: Tensor,
    sm_scale: float,
    d_v: int,
) -> tuple[Tensor, Tensor]
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
def reference_mla_sparse_prefill(
    q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
) -> tuple[torch.Tensor, torch.Tensor]:
    import math

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    skv = kv.shape[0]
    sq = q.shape[0]
    topk = indices.shape[-1]
    dqk = q.shape[-1]
    indices = indices[:, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= skv)
    indices[invalid_indices_mask] = 0
    qs = q  # [s_q, h_q, d_qk]
    kvs = kv[:, 0, :][indices].view(sq, topk, dqk)  # [s_q, topk, d_qk]

    attn_score = (qs @ kvs.transpose(1, 2)).float()  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score.to(q.dtype) @ kvs[:, :, :d_v]
    return (result, lse)
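
A minimal call sketch with hypothetical shapes (s_q=4 queries, 16 heads, d_qk=576, d_v=512, 128 cached tokens, top-k of 8; negative entries in indices are treated as invalid):

import torch

s_q, h_q, d_qk, d_v, s_kv, topk = 4, 16, 576, 512, 128, 8
q = torch.randn(s_q, h_q, d_qk)
kv = torch.randn(s_kv, 1, d_qk)
indices = torch.randint(-1, s_kv, (s_q, 1, topk))

out, lse = reference_mla_sparse_prefill(q, kv, indices, sm_scale=d_qk**-0.5, d_v=d_v)
assert out.shape == (s_q, h_q, d_v) and lse.shape == (s_q, h_q)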