vllm.model_executor.layers.fused_moe.shared_fused_moe

SharedFusedMoE

Bases: FusedMoE

A FusedMoE operation that also computes the results of shared experts. If an all2all communicator is being used, the shared expert computation can be interleaved with the fused all2all dispatch communication step.

Source code in vllm/model_executor/layers/fused_moe/shared_fused_moe.py
class SharedFusedMoE(FusedMoE):
    """
    A FusedMoE operation that also computes the results of shared experts.
    If an all2all communicator is being used, the shared expert computation
    can be interleaved with the fused all2all dispatch communication step.
    """

    def __init__(
        self,
        shared_experts: torch.nn.Module | None,
        gate: torch.nn.Module | None = None,
        use_overlapped: bool = True,
        routed_input_transform: torch.nn.Module | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        self._routed_input_transform = routed_input_transform

        # Disable shared expert overlap if:
        #   - we are using eplb with a non-default backend, because of correctness issues
        #   - we are using flashinfer with DP, since there is nothing to gain
        #   - we are using marlin kernels
        backend = self.moe_parallel_config.all2all_backend
        self.use_overlapped = (
            use_overlapped
            and not (
                (self.enable_eplb and backend != "allgather_reducescatter")
                or self.moe_parallel_config.use_fi_all2allv_kernels
            )
            and self._shared_experts is not None
        )

        self._gate = gate

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return self._shared_experts if self.use_overlapped else None

    @property
    def gate(self) -> torch.nn.Module | None:
        return self._gate if self.use_overlapped else None

    @property
    def is_internal_router(self) -> bool:
        return self.gate is not None

    def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply transform for routed experts (e.g., latent projection).

        This is called by FusedMoE.forward_native. The original hidden_states
        is saved separately so shared experts get [S, hidden_size] while
        routed experts get the transformed [S, moe_latent_size].

        TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
        moved inside SharedFusedMoE to all-reduce on the smaller latent
        dimension.
        """
        if self._routed_input_transform is not None:
            result = self._routed_input_transform(hidden_states)
            # ReplicatedLinear returns (output, extra_bias) tuple.
            # We only need the output tensor; extra_bias is not used here.
            if isinstance(result, tuple):
                return result[0]
            return result
        return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if not self.use_overlapped:
            if self._shared_experts is not None:
                shared_out = self._shared_experts(hidden_states)

                # Reduce shared expert outputs if necessary, since the MLP
                # should have been created with reduce_results=False.
                if (
                    self.reduce_results
                    and get_tensor_model_parallel_world_size() > 1
                    and self.must_reduce_shared_expert_outputs()
                ):
                    shared_out = tensor_model_parallel_all_reduce(shared_out)
            else:
                shared_out = None

            fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
        else:
            shared_out, fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
            # ensure early TP reduction of shared expert outputs when required
            if (
                shared_out is not None
                and self.reduce_results
                and get_tensor_model_parallel_world_size() > 1
                and self.must_reduce_shared_expert_outputs()
            ):
                shared_out = tensor_model_parallel_all_reduce(shared_out)
        return shared_out, fused_out
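
For orientation, here is a minimal, hedged sketch of the calling contract: forward returns a (shared_out, fused_out) tuple and the caller combines the two. The toy module below is an illustrative assumption only; it does not reproduce the real routing, all2all dispatch, quantization, or tensor-parallel reduction logic.

import torch


# Hedged stand-in for SharedFusedMoE (illustrative only): forward() returns
# (shared_out, fused_out) and the caller adds the two contributions.
class ToySharedFusedMoE(torch.nn.Module):
    def __init__(self, hidden_size: int, shared_experts: torch.nn.Module | None):
        super().__init__()
        self._shared_experts = shared_experts
        # A single linear stands in for the fused routed-expert computation.
        self._routed = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, router_logits):
        shared_out = (
            self._shared_experts(hidden_states)
            if self._shared_experts is not None
            else None
        )
        return shared_out, self._routed(hidden_states)


layer = ToySharedFusedMoE(16, shared_experts=torch.nn.Linear(16, 16))
x = torch.randn(4, 16)
shared_out, fused_out = layer(x, router_logits=torch.randn(4, 8))
# Callers typically sum the shared and routed contributions.
output = fused_out if shared_out is None else shared_out + fused_out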

_gate instance-attribute

_gate = gate

_routed_input_transform instance-attribute

_routed_input_transform = routed_input_transform

_shared_experts instance-attribute

_shared_experts = shared_experts

gate property

gate: Module | None

is_internal_router property

is_internal_router: bool

shared_experts property

shared_experts: Module | None

use_overlapped instance-attribute

use_overlapped = (
    use_overlapped
    and not (
        (enable_eplb and backend != "allgather_reducescatter")
        or use_fi_all2allv_kernels
    )
    and _shared_experts is not None
)
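
Written as a plain predicate, the gating above reads as follows; the standalone function and its parameter names are an assumption for illustration, mirroring the attributes used in __init__.

def overlap_enabled(
    requested: bool,
    enable_eplb: bool,
    backend: str,
    use_fi_all2allv_kernels: bool,
    has_shared_experts: bool,
) -> bool:
    # Overlap survives only when it was requested, shared experts exist, and
    # neither disable condition applies: eplb with a non-default backend, or
    # the flashinfer all2allv path.
    disabled = (
        enable_eplb and backend != "allgather_reducescatter"
    ) or use_fi_all2allv_kernels
    return requested and not disabled and has_shared_experts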

__init__

__init__(
    shared_experts: Module | None,
    gate: Module | None = None,
    use_overlapped: bool = True,
    routed_input_transform: Module | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/fused_moe/shared_fused_moe.py
def __init__(
    self,
    shared_experts: torch.nn.Module | None,
    gate: torch.nn.Module | None = None,
    use_overlapped: bool = True,
    routed_input_transform: torch.nn.Module | None = None,
    **kwargs,
):
    super().__init__(**kwargs)
    self._shared_experts = shared_experts
    self._routed_input_transform = routed_input_transform

    # Disable shared expert overlap if:
    #   - we are using eplb with a non-default backend, because of correctness issues
    #   - we are using flashinfer with DP, since there is nothing to gain
    #   - we are using marlin kernels
    backend = self.moe_parallel_config.all2all_backend
    self.use_overlapped = (
        use_overlapped
        and not (
            (self.enable_eplb and backend != "allgather_reducescatter")
            or self.moe_parallel_config.use_fi_all2allv_kernels
        )
        and self._shared_experts is not None
    )

    self._gate = gate

apply_routed_input_transform

apply_routed_input_transform(
    hidden_states: Tensor,
) -> Tensor

Apply transform for routed experts (e.g., latent projection).

This is called by FusedMoE.forward_native. The original hidden_states is saved separately so shared experts get [S, hidden_size] while routed experts get the transformed [S, moe_latent_size].

TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be moved inside SharedFusedMoE to all-reduce on the smaller latent dimension.

Source code in vllm/model_executor/layers/fused_moe/shared_fused_moe.py
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Apply transform for routed experts (e.g., latent projection).

    This is called by FusedMoE.forward_native. The original hidden_states
    is saved separately so shared experts get [S, hidden_size] while
    routed experts get the transformed [S, moe_latent_size].

    TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
    moved inside SharedFusedMoE to all-reduce on the smaller latent
    dimension.
    """
    if self._routed_input_transform is not None:
        result = self._routed_input_transform(hidden_states)
        # ReplicatedLinear returns (output, extra_bias) tuple.
        # We only need the output tensor; extra_bias is not used here.
        if isinstance(result, tuple):
            return result[0]
        return result
    return hidden_states
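
A small sketch of the shape contract described above; hidden_size, latent_size, and the tuple-returning projection module are illustrative assumptions that mimic the ReplicatedLinear (output, extra_bias) convention.

import torch

hidden_size, latent_size, num_tokens = 32, 8, 4


class TupleProjection(torch.nn.Module):
    """Toy stand-in for a latent projection that returns (output, extra_bias)."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        return self.proj(x), None


transform = TupleProjection()
hidden_states = torch.randn(num_tokens, hidden_size)

result = transform(hidden_states)
# Unwrap the tuple exactly as apply_routed_input_transform does.
routed_input = result[0] if isinstance(result, tuple) else result

assert routed_input.shape == (num_tokens, latent_size)   # routed experts see [S, moe_latent_size]
assert hidden_states.shape == (num_tokens, hidden_size)  # shared experts keep [S, hidden_size]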

forward

forward(
    hidden_states: Tensor, router_logits: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/shared_fused_moe.py
def forward(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    if not self.use_overlapped:
        if self._shared_experts is not None:
            shared_out = self._shared_experts(hidden_states)

            # Reduce shared expert outputs if necessary, since the MLP
            # should have been created with reduce_results=False.
            if (
                self.reduce_results
                and get_tensor_model_parallel_world_size() > 1
                and self.must_reduce_shared_expert_outputs()
            ):
                shared_out = tensor_model_parallel_all_reduce(shared_out)
        else:
            shared_out = None

        fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
    else:
        shared_out, fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        # ensure early TP reduction of shared expert outputs when required
        if (
            shared_out is not None
            and self.reduce_results
            and get_tensor_model_parallel_world_size() > 1
            and self.must_reduce_shared_expert_outputs()
        ):
            shared_out = tensor_model_parallel_all_reduce(shared_out)
    return shared_out, fused_out