vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm

ROCmFP8ScaledMMLinearKernel

Bases: FP8ScaledMMLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return False, "requires MI3xx."

        if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
            return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."

        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return False, "requires per tensor activation and weight scales."

        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
            A, B, out_dtype, As, Bs, bias
        )
        return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)

apply_scaled_mm

apply_scaled_mm(
    *,
    A: Tensor,
    B: Tensor,
    out_dtype: dtype,
    As: Tensor,
    Bs: Tensor,
    bias: Tensor | None,
    output_shape: list,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
def apply_scaled_mm(
    self,
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
        A, B, out_dtype, As, Bs, bias
    )
    return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
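apply_scaled_mm invokes the registered custom op and then narrows the result to the caller's row count (A.shape[0]) before viewing it as output_shape; the narrow is a no-op when the op returns exactly that many rows. A small self-contained sketch of that trimming and reshaping step, with a hypothetical padded output used purely for illustration:

# Sketch of the post-processing in apply_scaled_mm, using dummy data.
import torch

M, N = 4, 8
output = torch.randn(M + 2, N)           # hypothetical padded result rows
output_shape = [2, 2, N]                 # e.g. (batch, seq, hidden) with batch * seq == M
trimmed = torch.narrow(output, 0, 0, M)  # keep only the first M rows
result = trimmed.view(*output_shape)
assert result.shape == (2, 2, N)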

can_implement classmethod

can_implement(
    c: FP8ScaledMMLinearLayerConfig,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
    per_tensor_activation_scales = (
        c.activation_quant_key.scale.group_shape.is_per_tensor()
    )
    per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

    if not (per_tensor_activation_scales and per_tensor_weight_scales):
        return False, "requires per tensor activation and weight scales."

    return True, None
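can_implement only inspects the scale group shapes on the config, so its behavior can be shown with a duck-typed stand-in; the real FP8ScaledMMLinearLayerConfig is constructed by vLLM's quantization layer and carries more fields than modeled here. A hedged sketch (assumes vLLM is installed so the class can be imported):

# Duck-typed stand-in for FP8ScaledMMLinearLayerConfig: only the attribute
# paths that can_implement() reads are modeled (illustrative, not the real config).
from types import SimpleNamespace

from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)

def quant_key(per_tensor: bool) -> SimpleNamespace:
    group_shape = SimpleNamespace(is_per_tensor=lambda: per_tensor)
    return SimpleNamespace(scale=SimpleNamespace(group_shape=group_shape))

ok_cfg = SimpleNamespace(
    activation_quant_key=quant_key(True),
    weight_quant_key=quant_key(True),
)
bad_cfg = SimpleNamespace(
    activation_quant_key=quant_key(False),  # e.g. per-token activation scales
    weight_quant_key=quant_key(True),
)

print(ROCmFP8ScaledMMLinearKernel.can_implement(ok_cfg))   # (True, None)
print(ROCmFP8ScaledMMLinearKernel.can_implement(bad_cfg))  # (False, "requires per tensor ...")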

is_supported classmethod

is_supported(
    compute_capability: int | None = None,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
@classmethod
def is_supported(
    cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
    if not current_platform.is_rocm():
        return False, "requires ROCm."

    from vllm.platforms.rocm import on_mi3xx

    if not on_mi3xx():
        return False, "requires MI3xx."

    if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
        return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."

    return True, None
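is_supported gates the kernel on the three conditions visible above: a ROCm platform, an MI3xx GPU, and the VLLM_ROCM_USE_SKINNY_GEMM flag from vllm.envs (which mirrors the environment variable of the same name). A hedged usage sketch; it assumes vLLM is installed, and on other platforms the call simply reports why the kernel is unavailable:

# Sketch: ask whether this kernel can be used and surface the reason if not.
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)

supported, reason = ROCmFP8ScaledMMLinearKernel.is_supported()
if not supported:
    # One of: "requires ROCm.", "requires MI3xx.",
    # "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."
    print(f"ROCmFP8ScaledMMLinearKernel unavailable: {reason}")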

rocm_per_tensor_float_w8a8_scaled_mm_fake

rocm_per_tensor_float_w8a8_scaled_mm_fake(
    A: Tensor,
    B: Tensor,
    out_dtype: dtype,
    As: Tensor,
    Bs: Tensor,
    bias: Tensor | None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    # Shape-only fake implementation: allocate an empty tensor with the
    # GEMM output shape so tracing/compilation can infer shape and dtype.
    return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)
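rocm_per_tensor_float_w8a8_scaled_mm_fake is the shape-only counterpart of the real op: it returns an empty tensor of shape (*A.shape[:-1], B.shape[1]) with the requested out_dtype so that tracing and compilation can infer output metadata without launching the ROCm kernel. A small sketch of that shape rule using meta tensors (no device memory needed; float8_e4m3fnuz is shown as the MI300-style FP8 dtype, purely illustrative):

# Sketch: the fake op's output-shape rule, demonstrated on meta tensors.
import torch

M, K, N = 7, 4096, 11008
A = torch.empty((M, K), dtype=torch.float8_e4m3fnuz, device="meta")
B = torch.empty((K, N), dtype=torch.float8_e4m3fnuz, device="meta")

out = A.new_empty((*A.shape[:-1], B.shape[1]), dtype=torch.bfloat16)
assert out.shape == (M, N) and out.dtype == torch.bfloat16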

rocm_per_tensor_float_w8a8_scaled_mm_impl

rocm_per_tensor_float_w8a8_scaled_mm_impl(
    A: Tensor,
    B: Tensor,
    out_dtype: dtype,
    As: Tensor,
    Bs: Tensor,
    bias: Tensor | None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    # Route to the wvSplitKQ skinny GEMM when the activation has a single
    # row, the output width is a multiple of 16, and the bias (if any)
    # already matches the output dtype.
    if (
        A.shape[0] == 1
        and B.shape[1] % 16 == 0
        and ((bias is None) or (bias.dtype == out_dtype))
    ):
        output = ops.wvSplitKQ(
            B.t(),
            A,
            out_dtype,
            As,
            Bs,
            get_cu_count(),
            bias,
        )
    # Fallback: generic per-tensor scaled matmul.
    else:
        output = torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
    return output
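rocm_per_tensor_float_w8a8_scaled_mm_impl dispatches to ops.wvSplitKQ (the ROCm skinny GEMM, given the transposed weight and the device's CU count) only for single-row inputs whose output width is divisible by 16; everything else falls back to torch._scaled_mm. A hedged sketch of the fallback path alone, since it needs only stock PyTorch; the FP8 dtype and device requirements depend on the hardware, so treat this as illustrative rather than vLLM's exact call:

# Sketch of the fallback path: per-tensor scaled FP8 matmul via torch._scaled_mm.
# Requires a GPU build with FP8 matmul support; e4m3fnuz is the MI300-style
# dtype, use torch.float8_e4m3fn on hardware that expects the OCP format.
import torch

device = "cuda"  # ROCm builds of PyTorch also expose the "cuda" device string
M, K, N = 8, 256, 512

A = torch.randn(M, K, device=device).to(torch.float8_e4m3fnuz)
B = torch.randn(N, K, device=device).to(torch.float8_e4m3fnuz).t()  # column-major (K, N)
As = torch.tensor(1.0, device=device)  # per-tensor activation scale (float32)
Bs = torch.tensor(1.0, device=device)  # per-tensor weight scale (float32)

out = torch._scaled_mm(A, B, out_dtype=torch.bfloat16, scale_a=As, scale_b=Bs, bias=None)
print(out.shape)  # torch.Size([8, 512])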