
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method

logger module-attribute

logger = init_logger(__name__)

UnquantizedFusedMoEMethod

Bases: FusedMoEMethodBase, CustomOp

MoE method without quantization.

Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)
        self.unquantized_backend = select_unquantized_moe_backend(
            use_ep=self.moe.moe_parallel_config.use_ep,
            use_dp=self.moe.moe_parallel_config.dp_size > 1,
        )

        # AITER only supports gated activations (silu/gelu), so disable it
        # for non-gated MoE (is_act_and_mul=False)
        self.rocm_aiter_moe_enabled = (
            rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
        )
        self.kernel: mk.FusedMoEModularKernel | None = None
        self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu()

    @property
    def is_monolithic(self) -> bool:
        return self._is_monolithic

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> FusedMoEPrepareAndFinalize | None:
        if self.unquantized_backend == UnquantizedMoeBackend.AITER:
            return None
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            logger.debug("BatchedTritonExperts %s", self.moe)
            return BatchedTritonExperts(
                moe_config=self.moe,
                quant_config=self.moe_quant_config,
                max_num_tokens=self.moe.max_num_tokens,
                num_dispatchers=prepare_finalize.num_dispatchers(),
            )
        else:
            logger.debug("TritonExperts %s", self.moe)
            return TritonExperts(
                moe_config=self.moe,
                quant_config=self.moe_quant_config,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        if self.moe.has_bias:
            w2_bias = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on the ROCm platform,
        # which benefits from tensors being spaced sufficiently far apart in memory.
        if (
            envs.VLLM_ROCM_MOE_PADDING
            and current_platform.is_rocm()
            and weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0
        ):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()

        return weight

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
    ) -> None:
        # Shuffle weights to runtime format.
        w13, w2 = convert_to_unquantized_kernel_format(
            self.unquantized_backend,
            layer=layer,
            w13_weight=w13,
            w2_weight=w2,
        )
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)

        # Set up the modular kernel for the TP case
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None

        self.kernel, self.use_inplace = make_unquantized_moe_kernel(
            backend=self.unquantized_backend,
            quant_config=self.moe_quant_config,
            moe_config=self.moe,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

        # Pad the weights for better performance on ROCm
        layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
        layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

        if self.unquantized_backend == UnquantizedMoeBackend.XPU:
            import intel_extension_for_pytorch as ipex

            ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
            self.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
                layer.w13_weight,
                layer.w2_weight,
                use_prepack=True,
                experts_start_id=ep_rank_start,
            )
        elif self.unquantized_backend == UnquantizedMoeBackend.CPU:
            from vllm.model_executor.layers.fused_moe import cpu_fused_moe

            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                from vllm.model_executor.layers.utils import check_cpu_sgl_kernel

                dtype_w13 = layer.w13_weight.dtype
                _, n_w13, k_w13 = layer.w13_weight.size()
                dtype_w2 = layer.w2_weight.dtype
                _, n_w2, k_w2 = layer.w2_weight.size()
                if (
                    envs.VLLM_CPU_SGL_KERNEL
                    and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
                    and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
                ):
                    packed_w13_weight = torch.ops._C.convert_weight_packed(
                        layer.w13_weight
                    )
                    assert packed_w13_weight.size() == layer.w13_weight.size()
                    layer.w13_weight.copy_(packed_w13_weight)
                    del packed_w13_weight
                    packed_w2_weight = torch.ops._C.convert_weight_packed(
                        layer.w2_weight
                    )
                    assert packed_w2_weight.size() == layer.w2_weight.size()
                    layer.w2_weight.copy_(packed_w2_weight)
                    self.cpu_fused_moe: Callable = cpu_fused_moe.SGLFusedMOE(layer)
                else:
                    self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
            else:
                self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
        elif current_platform.is_cuda_alike():
            self._setup_kernel(
                layer=layer,
                w13=layer.w13_weight,
                w2=layer.w2_weight,
            )

    def apply(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        return self.forward(
            layer=layer,
            x=x,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
        if self.moe.has_bias:
            return biased_moe_quant_config(
                layer.w13_bias,
                layer.w2_bias,
            )
        else:
            return FUSED_MOE_UNQUANTIZED_CONFIG

    def forward_cuda(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.kernel is not None
        return self.kernel(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
        )

    def forward_monolithic_cpu(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        return self.cpu_fused_moe(
            layer,
            x,
            layer.use_grouped_topk,
            layer.top_k,
            router_logits,
            layer.renormalize,
            layer.topk_group,
            layer.num_expert_group,
            layer.global_num_experts,
            layer.expert_map,
            layer.custom_routing_function,
            layer.scoring_func,
            layer.routed_scaling_factor,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.activation,
        )

    def forward_monolithic_xpu(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        return self.ipex_fusion(
            x,
            layer.use_grouped_topk,
            layer.top_k,
            router_logits,
            layer.renormalize,
            layer.topk_group,
            layer.num_expert_group,
            custom_routing_function=layer.custom_routing_function,
        )

    if current_platform.is_cpu():
        forward_native: Callable = forward_monolithic_cpu
        apply_monolithic = forward_monolithic_cpu
    elif current_platform.is_xpu():
        forward_native = forward_monolithic_xpu
        apply_monolithic = forward_monolithic_xpu
    else:
        forward_native = forward_cuda

_is_monolithic instance-attribute

_is_monolithic = is_cpu() or is_xpu()

allow_inplace property

allow_inplace: bool

apply_monolithic class-attribute instance-attribute

apply_monolithic = forward_monolithic_cpu

forward_native class-attribute instance-attribute

forward_native: Callable = forward_monolithic_cpu

is_monolithic property

is_monolithic: bool

kernel instance-attribute

kernel: FusedMoEModularKernel | None = None

rocm_aiter_moe_enabled instance-attribute

rocm_aiter_moe_enabled = (
    is_fused_moe_enabled() and is_act_and_mul
)

supports_eplb property

supports_eplb: bool

unquantized_backend instance-attribute

unquantized_backend = select_unquantized_moe_backend(
    use_ep=use_ep, use_dp=dp_size > 1
)

__init__

__init__(moe: FusedMoEConfig)
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def __init__(self, moe: FusedMoEConfig):
    super().__init__(moe)
    self.unquantized_backend = select_unquantized_moe_backend(
        use_ep=self.moe.moe_parallel_config.use_ep,
        use_dp=self.moe.moe_parallel_config.dp_size > 1,
    )

    # AITER only supports gated activations (silu/gelu), so disable it
    # for non-gated MoE (is_act_and_mul=False)
    self.rocm_aiter_moe_enabled = (
        rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
    )
    self.kernel: mk.FusedMoEModularKernel | None = None
    self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu()

_maybe_pad_weight

_maybe_pad_weight(weight: Tensor) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
    # Pad the weight tensor. This is an optimization on the ROCm platform,
    # which benefits from tensors being spaced sufficiently far apart in memory.
    if (
        envs.VLLM_ROCM_MOE_PADDING
        and current_platform.is_rocm()
        and weight.stride(-1) == 1
        and (weight.stride(-2) * weight.element_size()) % 512 == 0
    ):
        num_pad = 256 // weight.element_size()
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
        torch.cuda.empty_cache()

    return weight
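
The pad-then-slice expression above keeps the logical shape intact while increasing the row stride, so consecutive rows end up 256 bytes further apart in memory. A minimal standalone sketch of the effect, using plain torch and an arbitrary bf16 tensor as an assumed example:

import torch
import torch.nn.functional as F

w = torch.empty(4, 128, 64, dtype=torch.bfloat16)
num_pad = 256 // w.element_size()  # 128 bf16 elements == 256 bytes

# Pad the last dim, then slice the padding back off: same shape, larger stride.
padded = F.pad(w, (0, num_pad), "constant", 0)[..., :-num_pad]
assert padded.shape == w.shape
assert padded.stride(-2) == w.shape[-1] + num_pad  # rows are now 192 elements apart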

_setup_kernel

_setup_kernel(
    layer: Module, w13: Tensor, w2: Tensor
) -> None
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def _setup_kernel(
    self,
    layer: Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
) -> None:
    # Shuffle weights to runtime format.
    w13, w2 = convert_to_unquantized_kernel_format(
        self.unquantized_backend,
        layer=layer,
        w13_weight=w13,
        w2_weight=w2,
    )
    replace_parameter(layer, "w13_weight", w13)
    replace_parameter(layer, "w2_weight", w2)

    # Set up the modular kernel for the TP case
    self.moe_quant_config = self.get_fused_moe_quant_config(layer)
    assert self.moe_quant_config is not None

    self.kernel, self.use_inplace = make_unquantized_moe_kernel(
        backend=self.unquantized_backend,
        quant_config=self.moe_quant_config,
        moe_config=self.moe,
    )

apply

apply(
    layer: FusedMoE,
    x: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def apply(
    self,
    layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    return self.forward(
        layer=layer,
        x=x,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
    )
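
apply expects the router to have already produced the top-k weights and expert ids. A minimal sketch of one common recipe (softmax scores, top-k selection, renormalization); the shapes and the renormalization step are illustrative assumptions, not part of this method's contract:

import torch

num_tokens, num_experts, top_k = 8, 16, 2
router_logits = torch.randn(num_tokens, num_experts)

scores = router_logits.softmax(dim=-1)
topk_weights, topk_ids = scores.topk(top_k, dim=-1)               # (num_tokens, top_k)
topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)  # renormalize over selected experts

These two tensors, together with the hidden states x, are what apply forwards to the platform-specific implementation.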

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    if self.moe.is_act_and_mul:
        w13_up_dim = 2 * intermediate_size_per_partition
    else:
        w13_up_dim = intermediate_size_per_partition
    # Fused gate_up_proj (column parallel)
    w13_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            w13_up_dim,
            hidden_size,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)
    if self.moe.has_bias:
        w13_bias = torch.nn.Parameter(
            torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_bias", w13_bias)
        set_weight_attrs(w13_bias, extra_weight_attrs)
    # down_proj (row parallel)
    w2_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)
    if self.moe.has_bias:
        w2_bias = torch.nn.Parameter(
            torch.zeros(num_experts, hidden_size, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_bias", w2_bias)
        set_weight_attrs(w2_bias, extra_weight_attrs)
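
As a shape reference, a minimal sketch of the parameters registered above, assuming a gated MoE (is_act_and_mul=True) with has_bias=False and arbitrary illustrative sizes:

import torch

num_experts, hidden_size, intermediate = 4, 64, 128
w13_up_dim = 2 * intermediate  # fused gate and up projections

layer = torch.nn.Module()
layer.register_parameter(
    "w13_weight",
    torch.nn.Parameter(
        torch.empty(num_experts, w13_up_dim, hidden_size, dtype=torch.bfloat16),
        requires_grad=False,
    ),
)
layer.register_parameter(
    "w2_weight",
    torch.nn.Parameter(
        torch.empty(num_experts, hidden_size, intermediate, dtype=torch.bfloat16),
        requires_grad=False,
    ),
)
assert layer.w13_weight.shape == (num_experts, 2 * intermediate, hidden_size)
assert layer.w2_weight.shape == (num_experts, hidden_size, intermediate)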

forward_cuda

forward_cuda(
    layer: FusedMoE,
    x: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def forward_cuda(
    self,
    layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    assert self.kernel is not None
    return self.kernel(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=self.use_inplace,
        activation=layer.activation,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
        global_num_experts=layer.global_num_experts,
        expert_map=layer.expert_map,
    )
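
For reference, a naive, unfused PyTorch sketch of the math the modular kernel computes for a gated (silu-and-mul) MoE, assuming no expert parallelism, no bias, and matching dtypes throughout. It illustrates the computation, not the kernel's implementation:

import torch
import torch.nn.functional as F

def naive_gated_moe(
    x: torch.Tensor,             # (num_tokens, hidden_size)
    w13: torch.Tensor,           # (num_experts, 2 * intermediate, hidden_size)
    w2: torch.Tensor,            # (num_experts, hidden_size, intermediate)
    topk_weights: torch.Tensor,  # (num_tokens, top_k)
    topk_ids: torch.Tensor,      # (num_tokens, top_k)
) -> torch.Tensor:
    inter = w2.shape[-1]
    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        # Tokens (and which of their top-k slots) routed to expert e.
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].t()            # fused gate/up projection
        h = F.silu(h[:, :inter]) * h[:, inter:]  # silu_and_mul
        h = h @ w2[e].t()                        # down projection
        out.index_add_(0, token_idx, h * topk_weights[token_idx, slot_idx, None])
    return out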

forward_monolithic_cpu

forward_monolithic_cpu(
    layer: FusedMoE, x: Tensor, router_logits: Tensor
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def forward_monolithic_cpu(
    self,
    layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
    x: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    return self.cpu_fused_moe(
        layer,
        x,
        layer.use_grouped_topk,
        layer.top_k,
        router_logits,
        layer.renormalize,
        layer.topk_group,
        layer.num_expert_group,
        layer.global_num_experts,
        layer.expert_map,
        layer.custom_routing_function,
        layer.scoring_func,
        layer.routed_scaling_factor,
        layer.e_score_correction_bias,
        layer.apply_router_weight_on_input,
        layer.activation,
    )

forward_monolithic_xpu

forward_monolithic_xpu(
    layer: FusedMoE, x: Tensor, router_logits: Tensor
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def forward_monolithic_xpu(
    self,
    layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
    x: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    return self.ipex_fusion(
        x,
        layer.use_grouped_topk,
        layer.top_k,
        router_logits,
        layer.renormalize,
        layer.topk_group,
        layer.num_expert_group,
        custom_routing_function=layer.custom_routing_function,
    )

get_fused_moe_quant_config

get_fused_moe_quant_config(
    layer: Module,
) -> FusedMoEQuantConfig
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
    if self.moe.has_bias:
        return biased_moe_quant_config(
            layer.w13_bias,
            layer.w2_bias,
        )
    else:
        return FUSED_MOE_UNQUANTIZED_CONFIG

maybe_make_prepare_finalize

maybe_make_prepare_finalize(
    routing_tables: tuple[Tensor, Tensor, Tensor]
    | None = None,
) -> FusedMoEPrepareAndFinalize | None
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def maybe_make_prepare_finalize(
    self,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
    if self.unquantized_backend == UnquantizedMoeBackend.AITER:
        return None
    else:
        return super().maybe_make_prepare_finalize(routing_tables)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    super().process_weights_after_loading(layer)

    # Pad the weights for better performance on ROCm
    layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
    layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

    if self.unquantized_backend == UnquantizedMoeBackend.XPU:
        import intel_extension_for_pytorch as ipex

        ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
        self.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            use_prepack=True,
            experts_start_id=ep_rank_start,
        )
    elif self.unquantized_backend == UnquantizedMoeBackend.CPU:
        from vllm.model_executor.layers.fused_moe import cpu_fused_moe

        if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
            from vllm.model_executor.layers.utils import check_cpu_sgl_kernel

            dtype_w13 = layer.w13_weight.dtype
            _, n_w13, k_w13 = layer.w13_weight.size()
            dtype_w2 = layer.w2_weight.dtype
            _, n_w2, k_w2 = layer.w2_weight.size()
            if (
                envs.VLLM_CPU_SGL_KERNEL
                and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
                and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
            ):
                packed_w13_weight = torch.ops._C.convert_weight_packed(
                    layer.w13_weight
                )
                assert packed_w13_weight.size() == layer.w13_weight.size()
                layer.w13_weight.copy_(packed_w13_weight)
                del packed_w13_weight
                packed_w2_weight = torch.ops._C.convert_weight_packed(
                    layer.w2_weight
                )
                assert packed_w2_weight.size() == layer.w2_weight.size()
                layer.w2_weight.copy_(packed_w2_weight)
                self.cpu_fused_moe: Callable = cpu_fused_moe.SGLFusedMOE(layer)
            else:
                self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
        else:
            self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
    elif current_platform.is_cuda_alike():
        self._setup_kernel(
            layer=layer,
            w13=layer.w13_weight,
            w2=layer.w2_weight,
        )
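
On the CPU SGL path above, the packed weights are written back with copy_ rather than by registering a new tensor, which preserves the existing Parameter object and any attributes already set on it. A sketch of that pattern with a stand-in pack_fn (hypothetical; the real conversion is torch.ops._C.convert_weight_packed):

import torch

def pack_fn(w: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real layout conversion; it must keep the logical shape.
    return w.contiguous()

weight = torch.nn.Parameter(torch.randn(8, 256, 128), requires_grad=False)
packed = pack_fn(weight)
assert packed.size() == weight.size()
weight.copy_(packed)  # in-place: the Parameter object itself is unchanged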

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: Module,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    assert self.moe_quant_config is not None
    if (
        prepare_finalize.activation_format
        == FusedMoEActivationFormat.BatchedExperts
    ):
        logger.debug("BatchedTritonExperts %s", self.moe)
        return BatchedTritonExperts(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            max_num_tokens=self.moe.max_num_tokens,
            num_dispatchers=prepare_finalize.num_dispatchers(),
        )
    else:
        logger.debug("TritonExperts %s", self.moe)
        return TritonExperts(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
        )