vllm.model_executor.models.step3p5

Inference-only Step3p5 model.

logger module-attribute

logger = init_logger(__name__)

FP32ReplicatedLinear

Bases: ReplicatedLinear

Use FP32 for higher precision.

Source code in vllm/model_executor/models/step3p5.py
class FP32ReplicatedLinear(ReplicatedLinear):
    """
    Use FP32 for higher precision.
    """

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        assert self.params_dtype == torch.float32
        return super().forward(x.to(torch.float32))

forward

forward(
    x: Tensor,
) -> Tensor | tuple[Tensor, Parameter | None]
Source code in vllm/model_executor/models/step3p5.py
def forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    assert self.params_dtype == torch.float32
    return super().forward(x.to(torch.float32))
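
The gate below relies on this FP32 path. A minimal, self-contained sketch of the same up-cast, using a plain nn.Linear as a stand-in for vLLM's ReplicatedLinear (an assumption made for illustration only):

import torch
import torch.nn as nn

# A plain nn.Linear stands in for ReplicatedLinear in this sketch.
gate = nn.Linear(1024, 64, bias=False)  # FP32 weights by default

hidden_states = torch.randn(8, 1024, dtype=torch.bfloat16)
# Router logits are computed in FP32 to reduce rounding error in expert scores.
router_logits = gate(hidden_states.to(torch.float32))
assert router_logits.dtype == torch.float32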

FusedMoEBlock

Bases: Module

Source code in vllm/model_executor/models/step3p5.py
class FusedMoEBlock(nn.Module):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_idx = extract_layer_index(prefix)

        self.ep_size = get_ep_group().device_group.size()
        self.ep_rank = get_ep_group().device_group.rank()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        self.enable_eplb = parallel_config.enable_eplb
        self.n_routed_experts = config.moe_num_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        if self.tp_size > config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}."
            )

        self.gate = FP32ReplicatedLinear(
            config.hidden_size,
            config.moe_num_experts,
            bias=False,
            quant_config=None,
            params_dtype=torch.float32,  # Use FP32 for higher precision.
            prefix=f"{prefix}.gate",
        )
        self.use_moe_router_bias = config.use_moe_router_bias
        assert self.use_moe_router_bias, "Only use_moe_router_bias=True is supported."
        self.routed_scaling_factor = config.moe_router_scaling_factor
        self.router_bias = nn.Parameter(
            torch.zeros(config.moe_num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        self.need_fp32_gate = config.need_fp32_gate
        assert self.need_fp32_gate, (
            "Router logits must use FP32 precision for numerical stability."
        )

        activation = "silu"
        swiglu_limits = config.swiglu_limits or []
        swiglu_limit = (
            swiglu_limits[self.layer_idx]
            if self.layer_idx < len(swiglu_limits)
            else None
        )
        if swiglu_limit not in (None, 0):
            swiglu_limit = float(swiglu_limit)
            assert swiglu_limit == 7.0, (
                "Swiglu limit in fused moe block only suport 7.0 now."
            )
            activation = "swiglustep"
            logger.debug(
                "step3p5 layer_idx: %s, activation: %s, limit: %s",
                self.layer_idx,
                activation,
                swiglu_limit,
            )

        self.share_expert = Step3p5MLP(
            config=config,
            hidden_size=self.hidden_size,
            intermediate_size=config.share_expert_dim,
            hidden_act="silu",
            reduce_results=False,
            quant_config=quant_config,
            prefix=f"{prefix}.share_expert",
        )
        self.experts = SharedFusedMoE(
            shared_experts=self.share_expert,
            gate=self.gate,
            num_experts=config.moe_num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            activation=activation,
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "moe_router_activation", "sigmoid"),
            e_score_correction_bias=self.router_bias,
            routed_scaling_factor=config.moe_router_scaling_factor,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        shared_output, final_hidden_states = fused_moe_out
        if self.share_expert is None:
            assert shared_output is None

        if self.share_expert is not None:
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)
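
As a worked example of the expert-parallel bookkeeping in __init__ above (all values hypothetical), each EP rank owns a contiguous slice of the physical experts:

# Hypothetical configuration: 64 routed experts, 4 redundant copies, EP size 4.
n_routed_experts = 64
n_redundant_experts = 4
ep_size = 4

n_physical_experts = n_routed_experts + n_redundant_experts  # 68
n_local_physical_experts = n_physical_experts // ep_size     # 17

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    print(f"ep_rank {ep_rank}: physical experts [{start}, {end})")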

enable_eplb instance-attribute

enable_eplb = enable_eplb

ep_rank instance-attribute

ep_rank = rank()

ep_size instance-attribute

ep_size = size()

experts instance-attribute

experts = SharedFusedMoE(
    shared_experts=share_expert,
    gate=gate,
    num_experts=moe_num_experts,
    top_k=moe_top_k,
    hidden_size=hidden_size,
    intermediate_size=moe_intermediate_size,
    reduce_results=False,
    renormalize=norm_expert_weight,
    quant_config=quant_config,
    activation=activation,
    prefix=f"{prefix}.experts",
    scoring_func=getattr(
        config, "moe_router_activation", "sigmoid"
    ),
    e_score_correction_bias=router_bias,
    routed_scaling_factor=moe_router_scaling_factor,
    enable_eplb=enable_eplb,
    num_redundant_experts=n_redundant_experts,
)

gate instance-attribute

gate = FP32ReplicatedLinear(
    hidden_size,
    moe_num_experts,
    bias=False,
    quant_config=None,
    params_dtype=float32,
    prefix=f"{prefix}.gate",
)

hidden_size instance-attribute

hidden_size = hidden_size

layer_idx instance-attribute

layer_idx = extract_layer_index(prefix)

n_local_physical_experts instance-attribute

n_local_physical_experts = n_physical_experts // ep_size

n_logical_experts instance-attribute

n_logical_experts = n_routed_experts

n_physical_experts instance-attribute

n_physical_experts = n_logical_experts + n_redundant_experts

n_redundant_experts instance-attribute

n_redundant_experts = num_redundant_experts

n_routed_experts instance-attribute

n_routed_experts = moe_num_experts

need_fp32_gate instance-attribute

need_fp32_gate = need_fp32_gate

physical_expert_end instance-attribute

physical_expert_end = (
    physical_expert_start + n_local_physical_experts
)

physical_expert_start instance-attribute

physical_expert_start = ep_rank * n_local_physical_experts

routed_scaling_factor instance-attribute

routed_scaling_factor = moe_router_scaling_factor

router_bias instance-attribute

router_bias = Parameter(
    zeros(moe_num_experts, dtype=float32),
    requires_grad=False,
)

share_expert instance-attribute

share_expert = Step3p5MLP(
    config=config,
    hidden_size=hidden_size,
    intermediate_size=share_expert_dim,
    hidden_act="silu",
    reduce_results=False,
    quant_config=quant_config,
    prefix=f"{prefix}.share_expert",
)

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

use_moe_router_bias instance-attribute

use_moe_router_bias = use_moe_router_bias

__init__

__init__(vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/step3p5.py
def __init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    super().__init__()

    self.tp_size = get_tensor_model_parallel_world_size()
    self.layer_idx = extract_layer_index(prefix)

    self.ep_size = get_ep_group().device_group.size()
    self.ep_rank = get_ep_group().device_group.rank()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    parallel_config = vllm_config.parallel_config

    self.hidden_size = config.hidden_size
    self.enable_eplb = parallel_config.enable_eplb
    self.n_routed_experts = config.moe_num_experts
    self.n_logical_experts = self.n_routed_experts
    self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
    self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
    self.n_local_physical_experts = self.n_physical_experts // self.ep_size

    self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
    self.physical_expert_end = (
        self.physical_expert_start + self.n_local_physical_experts
    )

    if self.tp_size > config.moe_num_experts:
        raise ValueError(
            f"Tensor parallel size {self.tp_size} is greater than "
            f"the number of experts {config.moe_num_experts}."
        )

    self.gate = FP32ReplicatedLinear(
        config.hidden_size,
        config.moe_num_experts,
        bias=False,
        quant_config=None,
        params_dtype=torch.float32,  # Use FP32 for higher precision.
        prefix=f"{prefix}.gate",
    )
    self.use_moe_router_bias = config.use_moe_router_bias
    assert self.use_moe_router_bias, "Only use_moe_router_bias=True is supported."
    self.routed_scaling_factor = config.moe_router_scaling_factor
    self.router_bias = nn.Parameter(
        torch.zeros(config.moe_num_experts, dtype=torch.float32),
        requires_grad=False,
    )
    self.need_fp32_gate = config.need_fp32_gate
    assert self.need_fp32_gate, (
        "Router logits must use FP32 precision for numerical stability."
    )

    activation = "silu"
    swiglu_limits = config.swiglu_limits or []
    swiglu_limit = (
        swiglu_limits[self.layer_idx]
        if self.layer_idx < len(swiglu_limits)
        else None
    )
    if swiglu_limit not in (None, 0):
        swiglu_limit = float(swiglu_limit)
        assert swiglu_limit == 7.0, (
            "Swiglu limit in fused moe block only suport 7.0 now."
        )
        activation = "swiglustep"
        logger.debug(
            "step3p5 layer_idx: %s, activation: %s, limit: %s",
            self.layer_idx,
            activation,
            swiglu_limit,
        )

    self.share_expert = Step3p5MLP(
        config=config,
        hidden_size=self.hidden_size,
        intermediate_size=config.share_expert_dim,
        hidden_act="silu",
        reduce_results=False,
        quant_config=quant_config,
        prefix=f"{prefix}.share_expert",
    )
    self.experts = SharedFusedMoE(
        shared_experts=self.share_expert,
        gate=self.gate,
        num_experts=config.moe_num_experts,
        top_k=config.moe_top_k,
        hidden_size=config.hidden_size,
        intermediate_size=config.moe_intermediate_size,
        reduce_results=False,
        renormalize=config.norm_expert_weight,
        quant_config=quant_config,
        activation=activation,
        prefix=f"{prefix}.experts",
        scoring_func=getattr(config, "moe_router_activation", "sigmoid"),
        e_score_correction_bias=self.router_bias,
        routed_scaling_factor=config.moe_router_scaling_factor,
        enable_eplb=self.enable_eplb,
        num_redundant_experts=self.n_redundant_experts,
    )

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    if self.experts.is_internal_router:
        # In this case, the gate/router runs inside the FusedMoE class
        fused_moe_out = self.experts(
            hidden_states=hidden_states, router_logits=hidden_states
        )
    else:
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        fused_moe_out = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

    shared_output, final_hidden_states = fused_moe_out
    if self.share_expert is None:
        assert shared_output is None

    if self.share_expert is not None:
        assert shared_output is not None
        final_hidden_states += shared_output

    if self.tp_size > 1:
        final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
            final_hidden_states
        )

    return final_hidden_states.view(num_tokens, hidden_dim)
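
For intuition on how the routing parameters passed to SharedFusedMoE interact, here is a heavily simplified sketch of bias-corrected sigmoid routing. The exact behaviour lives inside vLLM's fused MoE kernels; treating the correction bias as selection-only and applying the scaling factor to the final weights are assumptions of this sketch, not a specification:

import torch

def route_tokens(router_logits, bias, top_k, scaling=1.0, renormalize=False):
    # scoring_func="sigmoid": expert scores are elementwise sigmoids of the logits.
    scores = router_logits.float().sigmoid()
    # e_score_correction_bias nudges which experts are selected,
    # while the applied weights come from the uncorrected scores.
    _, expert_ids = (scores + bias).topk(top_k, dim=-1)
    weights = scores.gather(-1, expert_ids)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return expert_ids, weights * scaling

logits = torch.randn(4, 64)   # (num_tokens, moe_num_experts)
bias = torch.zeros(64)        # router_bias / e_score_correction_bias
expert_ids, weights = route_tokens(logits, bias, top_k=6)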

Step3p5Attention

Bases: Module

Source code in vllm/model_executor/models/step3p5.py
class Step3p5Attention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float | list[float] | None = 10000,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        rope_scaling: dict[str, Any] | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        # Step3p5 specific args
        sliding_window: int | None = None,
        use_head_wise_attn_gate: bool = False,
        layer_types: list = None,
        use_rope_layers: list = None,
        yarn_only_types: list = None,
        swa_num_attention_heads: int | None = None,
        partial_rotary_factor: float = 1.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_idx = extract_layer_index(prefix)
        if layer_types:
            enable_sliding_window = layer_types[self.layer_idx] == "sliding_attention"
        else:
            enable_sliding_window = self.layer_idx % 2 == 0
        if yarn_only_types and layer_types[self.layer_idx] not in yarn_only_types:
            rope_scaling = None

        if sliding_window is not None and enable_sliding_window:
            sliding_window = sliding_window
            if swa_num_attention_heads is not None:
                num_heads = swa_num_attention_heads
                self.total_num_heads = swa_num_attention_heads
        else:
            sliding_window = None

        if isinstance(rope_theta, list):
            rope_theta = rope_theta[self.layer_idx]

        self.rank = get_tensor_model_parallel_rank()
        self.partial_rotary_factor = partial_rotary_factor
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if rope_scaling is not None and not isinstance(rope_scaling, dict):
            raise ValueError("rope_scaling must be a dict for Step3p5Attention.")

        rope_parameters: dict[str, Any] = (
            dict(rope_scaling) if rope_scaling is not None else {}
        )
        rope_parameters.setdefault("rope_type", "default")
        rope_parameters["rope_theta"] = self.rope_theta
        rope_parameters["partial_rotary_factor"] = partial_rotary_factor

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )

        self.q_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
        self.use_head_wise_attn_gate = use_head_wise_attn_gate
        if use_head_wise_attn_gate:
            self.g_proj = ColumnParallelLinear(
                hidden_size,
                self.total_num_heads,
                bias=False,
                prefix=f"{prefix}.g_proj",
            )

        self.use_rope = True
        if use_rope_layers:
            self.use_rope = use_rope_layers[self.layer_idx]

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
        )

        self.max_position_embeddings = max_position
        assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
        self.rotary_dim = (
            self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Add qk-norm inline similar to Qwen3 MOE attention
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
        q_by_head = self.q_norm(q_by_head.contiguous())
        q = q_by_head.view(q.shape)

        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
        k_by_head = self.k_norm(k_by_head.contiguous())
        k = k_by_head.view(k.shape)
        if self.use_rope:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        if self.use_head_wise_attn_gate:
            extra_dims, _ = self.g_proj(hidden_states)
            output = (
                attn_output.view(*attn_output.shape[:-1], self.num_heads, self.head_dim)
                * extra_dims.unsqueeze(-1).sigmoid()
            )
            attn_output = output.view(*attn_output.shape)
        output, _ = self.o_proj(attn_output)
        return output
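
The sliding-window bookkeeping at the top of __init__ reduces to a small predicate; the sketch below mirrors it under the assumption that layer_types, when present, describes every layer:

def uses_sliding_window(layer_idx, layer_types=None, sliding_window=None):
    # An explicit layer_types list wins; otherwise even-indexed layers
    # default to sliding-window attention.
    if sliding_window is None:
        return False
    if layer_types:
        return layer_types[layer_idx] == "sliding_attention"
    return layer_idx % 2 == 0

layer_types = ["sliding_attention", "full_attention"] * 2
print([uses_sliding_window(i, layer_types, sliding_window=4096) for i in range(4)])
# [True, False, True, False]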

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
    per_layer_sliding_window=sliding_window,
    attn_type=attn_type,
)

g_proj instance-attribute

g_proj = ColumnParallelLinear(
    hidden_size,
    total_num_heads,
    bias=False,
    prefix=f"{prefix}.g_proj",
)

head_dim instance-attribute

head_dim = head_dim or hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

k_norm instance-attribute

k_norm = GemmaRMSNorm(head_dim, rms_norm_eps)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

layer_idx instance-attribute

layer_idx = extract_layer_index(prefix)

max_position_embeddings instance-attribute

max_position_embeddings = max_position

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

partial_rotary_factor instance-attribute

partial_rotary_factor = partial_rotary_factor

q_norm instance-attribute

q_norm = GemmaRMSNorm(head_dim, rms_norm_eps)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=qkv_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rank instance-attribute

rank = get_tensor_model_parallel_rank()

rope_theta instance-attribute

rope_theta = rope_theta

rotary_dim instance-attribute

rotary_dim = (
    head_dim
    if partial_rotary_factor == 1
    else head_dim // 2
)

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_size=head_dim,
    max_position=max_position,
    rope_parameters=rope_parameters,
)

scaling instance-attribute

scaling = head_dim ** -0.5

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

use_head_wise_attn_gate instance-attribute

use_head_wise_attn_gate = use_head_wise_attn_gate

use_rope instance-attribute

use_rope = True

__init__

__init__(
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 4096 * 32,
    head_dim: int | None = None,
    rms_norm_eps: float = 1e-06,
    qkv_bias: bool = False,
    rope_theta: float | list[float] | None = 10000,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    rope_scaling: dict[str, Any] | None = None,
    prefix: str = "",
    attn_type: str = DECODER,
    sliding_window: int | None = None,
    use_head_wise_attn_gate: bool = False,
    layer_types: list = None,
    use_rope_layers: list = None,
    yarn_only_types: list = None,
    swa_num_attention_heads: int | None = None,
    partial_rotary_factor: float = 1.0,
)
Source code in vllm/model_executor/models/step3p5.py
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 4096 * 32,
    head_dim: int | None = None,
    rms_norm_eps: float = 1e-06,
    qkv_bias: bool = False,
    rope_theta: float | list[float] | None = 10000,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    rope_scaling: dict[str, Any] | None = None,
    prefix: str = "",
    attn_type: str = AttentionType.DECODER,
    # Step3p5 specific args
    sliding_window: int | None = None,
    use_head_wise_attn_gate: bool = False,
    layer_types: list = None,
    use_rope_layers: list = None,
    yarn_only_types: list = None,
    swa_num_attention_heads: int | None = None,
    partial_rotary_factor: float = 1.0,
):
    super().__init__()
    self.hidden_size = hidden_size
    self.total_num_heads = num_heads
    tp_size = get_tensor_model_parallel_world_size()
    self.layer_idx = extract_layer_index(prefix)
    if layer_types:
        enable_sliding_window = layer_types[self.layer_idx] == "sliding_attention"
    else:
        enable_sliding_window = self.layer_idx % 2 == 0
    if yarn_only_types and layer_types[self.layer_idx] not in yarn_only_types:
        rope_scaling = None

    if sliding_window is not None and enable_sliding_window:
        sliding_window = sliding_window
        if swa_num_attention_heads is not None:
            num_heads = swa_num_attention_heads
            self.total_num_heads = swa_num_attention_heads
    else:
        sliding_window = None

    if isinstance(rope_theta, list):
        rope_theta = rope_theta[self.layer_idx]

    self.rank = get_tensor_model_parallel_rank()
    self.partial_rotary_factor = partial_rotary_factor
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # Number of KV heads is greater than or equal to TP size, so we
        # partition the KV heads across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Number of KV heads is less than TP size, so we replicate
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    self.head_dim = head_dim or hidden_size // self.total_num_heads
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=qkv_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        self.total_num_heads * self.head_dim,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    if rope_scaling is not None and not isinstance(rope_scaling, dict):
        raise ValueError("rope_scaling must be a dict for Step3p5Attention.")

    rope_parameters: dict[str, Any] = (
        dict(rope_scaling) if rope_scaling is not None else {}
    )
    rope_parameters.setdefault("rope_type", "default")
    rope_parameters["rope_theta"] = self.rope_theta
    rope_parameters["partial_rotary_factor"] = partial_rotary_factor

    self.rotary_emb = get_rope(
        head_size=self.head_dim,
        max_position=max_position,
        rope_parameters=rope_parameters,
    )

    self.q_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
    self.k_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
    self.use_head_wise_attn_gate = use_head_wise_attn_gate
    if use_head_wise_attn_gate:
        self.g_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads,
            bias=False,
            prefix=f"{prefix}.g_proj",
        )

    self.use_rope = True
    if use_rope_layers:
        self.use_rope = use_rope_layers[self.layer_idx]

    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
        per_layer_sliding_window=sliding_window,
        attn_type=attn_type,
    )

    self.max_position_embeddings = max_position
    assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
    self.rotary_dim = (
        self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2
    )

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    # Add qk-norm inline similar to Qwen3 MOE attention
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
    q_by_head = self.q_norm(q_by_head.contiguous())
    q = q_by_head.view(q.shape)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
    k_by_head = self.k_norm(k_by_head.contiguous())
    k = k_by_head.view(k.shape)
    if self.use_rope:
        q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    if self.use_head_wise_attn_gate:
        extra_dims, _ = self.g_proj(hidden_states)
        output = (
            attn_output.view(*attn_output.shape[:-1], self.num_heads, self.head_dim)
            * extra_dims.unsqueeze(-1).sigmoid()
        )
        attn_output = output.view(*attn_output.shape)
    output, _ = self.o_proj(attn_output)
    return output
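
The head-wise gate in forward scales each head's output by an independent sigmoid. A standalone sketch with hypothetical shapes:

import torch

num_tokens, num_heads, head_dim = 2, 4, 8
attn_output = torch.randn(num_tokens, num_heads * head_dim)
gate_logits = torch.randn(num_tokens, num_heads)  # what g_proj would produce

# One sigmoid gate per head scales that head's whole output vector.
gated = attn_output.view(num_tokens, num_heads, head_dim)
gated = gated * gate_logits.unsqueeze(-1).sigmoid()
attn_output = gated.view(num_tokens, num_heads * head_dim)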

Step3p5DecoderLayer

Bases: Module

Source code in vllm/model_executor/models/step3p5.py
class Step3p5DecoderLayer(nn.Module):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        if cache_config is not None:
            cache_config.sliding_window = None
        if config.att_impl_type == "GQA":
            num_attention_heads = None
            num_attention_groups = None
            head_dim = None
            if (
                getattr(config, "attention_other_setting", None)
                and getattr(config, "layer_types", [])
                and config.layer_types[layer_idx]
                == config.attention_other_setting["attention_type"]
            ):
                num_attention_heads = config.attention_other_setting[
                    "num_attention_heads"
                ]
                num_attention_groups = config.attention_other_setting[
                    "num_attention_groups"
                ]
                head_dim = config.attention_other_setting["head_dim"]
            partial_rotary_factors = getattr(config, "partial_rotary_factors", [])
            self.self_attn = Step3p5Attention(
                hidden_size=self.hidden_size,
                num_heads=num_attention_heads
                if num_attention_heads
                else config.num_attention_heads,
                max_position=config.max_position_embeddings,
                num_kv_heads=num_attention_groups
                if num_attention_groups
                else config.num_attention_groups,
                rope_theta=config.rope_theta,
                rms_norm_eps=config.rms_norm_eps,
                qkv_bias=getattr(config, "attention_bias", False),
                head_dim=head_dim if head_dim else getattr(config, "head_dim", None),
                cache_config=cache_config,
                quant_config=quant_config,
                rope_scaling=getattr(config, "rope_scaling", None),
                sliding_window=getattr(config, "sliding_window", None),
                use_head_wise_attn_gate=getattr(
                    config, "use_head_wise_attn_gate", False
                ),
                layer_types=getattr(config, "layer_types", []),
                use_rope_layers=getattr(config, "use_rope_layers", []),
                yarn_only_types=getattr(config, "yarn_only_types", []),
                partial_rotary_factor=partial_rotary_factors[layer_idx]
                if partial_rotary_factors
                else 1.0,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(
                f"Unsupported attention implementation: {config.att_impl_type}"
            )
        self.use_moe = False
        self.tp_group = get_tp_group()
        self.use_fused_all_reduce = (
            get_tensor_model_parallel_world_size() > 1
            and get_dp_group().world_size == 1
        )
        if self.use_fused_all_reduce:
            logger.warning_once("Enable custom fused all reduce...")
        else:
            logger.warning_once("Disable custom fused all reduce...")

        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is not None:
            moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
        else:
            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
        if layer_idx in moe_layers_idx:
            self.moe = FusedMoEBlock(
                vllm_config,
                prefix=f"{prefix}.moe",
            )
            self.use_moe = True
        else:
            self.mlp = Step3p5MLP(
                config=config,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                reduce_results=True,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            config.hidden_size, config.rms_norm_eps
        )
        self.prefix = prefix

    def add_and_maybe_inplace_all_reduce(
        self, in1: torch.Tensor, in2: torch.Tensor
    ) -> torch.Tensor:
        if not self.use_fused_all_reduce:
            return in1 + in2
        return self.tp_group.all_reduce(in1 + in2)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states += residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        if self.use_moe:
            ffn_output = self.moe(hidden_states)
        else:
            ffn_output = self.mlp(hidden_states)
        hidden_states = ffn_output + residual
        return hidden_states
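
The layer placement logic above can be summarized as a small helper (a sketch of the same rule, not part of the module):

def moe_layer_indices(num_hidden_layers, moe_layers_enum=None):
    # An explicit comma-separated list wins; otherwise every layer
    # except layer 0 gets the FusedMoEBlock.
    if moe_layers_enum is not None:
        return [int(i) for i in moe_layers_enum.strip().split(",")]
    return list(range(1, num_hidden_layers))

print(moe_layer_indices(6))            # [1, 2, 3, 4, 5]
print(moe_layer_indices(6, "2,4,5"))   # [2, 4, 5]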

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = GemmaRMSNorm(hidden_size, rms_norm_eps)

layer_idx instance-attribute

layer_idx = layer_idx

mlp instance-attribute

mlp = Step3p5MLP(
    config=config,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    hidden_act="silu",
    quant_config=quant_config,
    reduce_results=True,
    prefix=f"{prefix}.mlp",
)

moe instance-attribute

moe = FusedMoEBlock(vllm_config, prefix=f'{prefix}.moe')

post_attention_layernorm instance-attribute

post_attention_layernorm = GemmaRMSNorm(
    hidden_size, rms_norm_eps
)

prefix instance-attribute

prefix = prefix

self_attn instance-attribute

self_attn = Step3p5Attention(
    hidden_size=hidden_size,
    num_heads=num_attention_heads
    if num_attention_heads
    else config.num_attention_heads,
    max_position=max_position_embeddings,
    num_kv_heads=num_attention_groups
    if num_attention_groups
    else config.num_attention_groups,
    rope_theta=rope_theta,
    rms_norm_eps=rms_norm_eps,
    qkv_bias=getattr(config, "attention_bias", False),
    head_dim=head_dim
    if head_dim
    else getattr(config, "head_dim", None),
    cache_config=cache_config,
    quant_config=quant_config,
    rope_scaling=getattr(config, "rope_scaling", None),
    sliding_window=getattr(config, "sliding_window", None),
    use_head_wise_attn_gate=getattr(
        config, "use_head_wise_attn_gate", False
    ),
    layer_types=getattr(config, "layer_types", []),
    use_rope_layers=getattr(config, "use_rope_layers", []),
    yarn_only_types=getattr(config, "yarn_only_types", []),
    partial_rotary_factor=partial_rotary_factors[layer_idx]
    if partial_rotary_factors
    else 1.0,
    prefix=f"{prefix}.self_attn",
)

tp_group instance-attribute

tp_group = get_tp_group()

use_fused_all_reduce instance-attribute

use_fused_all_reduce = (
    get_tensor_model_parallel_world_size() > 1
    and world_size == 1
)

use_moe instance-attribute

use_moe = False

__init__

__init__(vllm_config: VllmConfig, prefix: str = '') -> None
Source code in vllm/model_executor/models/step3p5.py
def __init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
) -> None:
    super().__init__()
    config = vllm_config.model_config.hf_config
    self.hidden_size = config.hidden_size
    layer_idx = extract_layer_index(prefix)
    self.layer_idx = layer_idx
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    if cache_config is not None:
        cache_config.sliding_window = None
    if config.att_impl_type == "GQA":
        num_attention_heads = None
        num_attention_groups = None
        head_dim = None
        if (
            getattr(config, "attention_other_setting", None)
            and getattr(config, "layer_types", [])
            and config.layer_types[layer_idx]
            == config.attention_other_setting["attention_type"]
        ):
            num_attention_heads = config.attention_other_setting[
                "num_attention_heads"
            ]
            num_attention_groups = config.attention_other_setting[
                "num_attention_groups"
            ]
            head_dim = config.attention_other_setting["head_dim"]
        partial_rotary_factors = getattr(config, "partial_rotary_factors", [])
        self.self_attn = Step3p5Attention(
            hidden_size=self.hidden_size,
            num_heads=num_attention_heads
            if num_attention_heads
            else config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=num_attention_groups
            if num_attention_groups
            else config.num_attention_groups,
            rope_theta=config.rope_theta,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=head_dim if head_dim else getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=getattr(config, "rope_scaling", None),
            sliding_window=getattr(config, "sliding_window", None),
            use_head_wise_attn_gate=getattr(
                config, "use_head_wise_attn_gate", False
            ),
            layer_types=getattr(config, "layer_types", []),
            use_rope_layers=getattr(config, "use_rope_layers", []),
            yarn_only_types=getattr(config, "yarn_only_types", []),
            partial_rotary_factor=partial_rotary_factors[layer_idx]
            if partial_rotary_factors
            else 1.0,
            prefix=f"{prefix}.self_attn",
        )
    else:
        raise ValueError(
            f"Unsupported attention implementation: {config.att_impl_type}"
        )
    self.use_moe = False
    self.tp_group = get_tp_group()
    self.use_fused_all_reduce = (
        get_tensor_model_parallel_world_size() > 1
        and get_dp_group().world_size == 1
    )
    if self.use_fused_all_reduce:
        logger.warning_once("Enable custom fused all reduce...")
    else:
        logger.warning_once("Disable custom fused all reduce...")

    moe_layers_enum = getattr(config, "moe_layers_enum", None)
    if moe_layers_enum is not None:
        moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
    else:
        moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
    if layer_idx in moe_layers_idx:
        self.moe = FusedMoEBlock(
            vllm_config,
            prefix=f"{prefix}.moe",
        )
        self.use_moe = True
    else:
        self.mlp = Step3p5MLP(
            config=config,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act="silu",
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.mlp",
        )
    self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
    self.post_attention_layernorm = GemmaRMSNorm(
        config.hidden_size, config.rms_norm_eps
    )
    self.prefix = prefix

add_and_maybe_inplace_all_reduce

add_and_maybe_inplace_all_reduce(
    in1: Tensor, in2: Tensor
) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def add_and_maybe_inplace_all_reduce(
    self, in1: torch.Tensor, in2: torch.Tensor
) -> torch.Tensor:
    if not self.use_fused_all_reduce:
        return in1 + in2
    return self.tp_group.all_reduce(in1 + in2)

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def forward(
    self, positions: torch.Tensor, hidden_states: torch.Tensor
) -> torch.Tensor:
    residual = hidden_states
    hidden_states = self.input_layernorm(hidden_states)

    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )
    hidden_states += residual
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)

    if self.use_moe:
        ffn_output = self.moe(hidden_states)
    else:
        ffn_output = self.mlp(hidden_states)
    hidden_states = ffn_output + residual
    return hidden_states

Step3p5ForCausalLM

Bases: Module, SupportsPP, MixtureOfExperts

Source code in vllm/model_executor/models/step3p5.py
class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".share_expert.": ".moe.share_expert."}
    )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.vllm_config = vllm_config

        self.model = Step3p5Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.moe_layers: list[FusedMoEBlock] = []
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue
            assert isinstance(layer, Step3p5DecoderLayer)
            if hasattr(layer, "moe") and isinstance(layer.moe, FusedMoEBlock):
                self.moe_layers.append(layer.moe)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                if not lora_config
                else lora_config.lora_vocab_padding_size,
            )
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, config.vocab_size
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.expert_weights = []
        assert len(self.moe_layers) > 0, "No MoE layers found in the model."
        example_layer = self.moe_layers[0]
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ):
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.model.norm(hidden_states)
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_tokens(input_ids)

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        for layer_idx, layer in enumerate(self.moe_layers):
            experts = layer.experts
            assert isinstance(experts, FusedMoE)
            # Register the expert weights.
            self.expert_weights.append(experts.get_expert_weights())
            experts.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.moe_layers:
            assert isinstance(layer, FusedMoEBlock)
            layer.n_local_physical_experts = num_local_physical_experts
            layer.n_physical_experts = num_physical_experts
            layer.n_redundant_experts = self.num_redundant_experts
            layer.experts.update_expert_map()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
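
The hf_to_vllm_mapper declared at the top of the class amounts to a substring rewrite on checkpoint parameter names, redirecting shared-expert weights into the FusedMoEBlock submodule. Illustration with a hypothetical weight name:

orig_to_new_substr = {".share_expert.": ".moe.share_expert."}

name = "model.layers.3.share_expert.gate_up_proj.weight"  # hypothetical key
for old, new in orig_to_new_substr.items():
    name = name.replace(old, new)
print(name)  # model.layers.3.moe.share_expert.gate_up_proj.weight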

config instance-attribute

config = config

expert_weights instance-attribute

expert_weights = []

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_substr={
        ".share_expert.": ".moe.share_expert."
    }
)

lm_head instance-attribute

lm_head = ParallelLMHead(
    unpadded_vocab_size,
    hidden_size,
    org_num_embeddings=vocab_size,
    padding_size=DEFAULT_VOCAB_PADDING_SIZE
    if not lora_config
    else lora_vocab_padding_size,
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    unpadded_vocab_size, vocab_size
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = Step3p5Model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

moe_layers instance-attribute

moe_layers: list[FusedMoEBlock] = []

num_expert_groups instance-attribute

num_expert_groups = 1

num_local_physical_experts instance-attribute

num_local_physical_experts = n_local_physical_experts

num_logical_experts instance-attribute

num_logical_experts = n_logical_experts

num_moe_layers instance-attribute

num_moe_layers = len(moe_layers)

num_physical_experts instance-attribute

num_physical_experts = n_physical_experts

num_redundant_experts instance-attribute

num_redundant_experts = n_redundant_experts

num_routed_experts instance-attribute

num_routed_experts = n_routed_experts

num_shared_experts instance-attribute

num_shared_experts = 0

unpadded_vocab_size instance-attribute

unpadded_vocab_size = vocab_size

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/step3p5.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    super().__init__()
    config = vllm_config.model_config.hf_config
    lora_config = vllm_config.lora_config
    self.config = config
    self.vllm_config = vllm_config

    self.model = Step3p5Model(
        vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
    )

    self.moe_layers: list[FusedMoEBlock] = []
    for layer in self.model.layers:
        if isinstance(layer, PPMissingLayer):
            continue
        assert isinstance(layer, Step3p5DecoderLayer)
        if hasattr(layer, "moe") and isinstance(layer.moe, FusedMoEBlock):
            self.moe_layers.append(layer.moe)

    if get_pp_group().is_last_rank:
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            if not lora_config
            else lora_config.lora_vocab_padding_size,
        )
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
    else:
        self.lm_head = PPMissingLayer()

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )

    # Set MoE hyperparameters
    self.expert_weights = []
    assert len(self.moe_layers) > 0, "No MoE layers found in the model."
    example_layer = self.moe_layers[0]
    self.num_moe_layers = len(self.moe_layers)
    self.num_expert_groups = 1
    self.num_shared_experts = 0
    self.num_logical_experts = example_layer.n_logical_experts
    self.num_physical_experts = example_layer.n_physical_experts
    self.num_local_physical_experts = example_layer.n_local_physical_experts
    self.num_routed_experts = example_layer.n_routed_experts
    self.num_redundant_experts = example_layer.n_redundant_experts

compute_logits

compute_logits(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states = self.model.norm(hidden_states)
    logits = self.logits_processor(self.lm_head, hidden_states)
    return logits

embed_input_ids

embed_input_ids(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.model.embed_tokens(input_ids)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
)
Source code in vllm/model_executor/models/step3p5.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
):
    hidden_states = self.model(
        input_ids, positions, intermediate_tensors, inputs_embeds
    )
    return hidden_states

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/step3p5.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(self)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

set_eplb_state

set_eplb_state(
    expert_load_view: Tensor,
    logical_to_physical_map: Tensor,
    logical_replica_count: Tensor,
) -> None
Source code in vllm/model_executor/models/step3p5.py
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    for layer_idx, layer in enumerate(self.moe_layers):
        experts = layer.experts
        assert isinstance(experts, FusedMoE)
        # Register the expert weights.
        self.expert_weights.append(experts.get_expert_weights())
        experts.set_eplb_state(
            moe_layer_idx=layer_idx,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

update_physical_experts_metadata

update_physical_experts_metadata(
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None
Source code in vllm/model_executor/models/step3p5.py
def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    assert self.num_local_physical_experts == num_local_physical_experts
    self.num_physical_experts = num_physical_experts
    self.num_local_physical_experts = num_local_physical_experts
    self.num_redundant_experts = num_physical_experts - self.num_logical_experts
    for layer in self.moe_layers:
        assert isinstance(layer, FusedMoEBlock)
        layer.n_local_physical_experts = num_local_physical_experts
        layer.n_physical_experts = num_physical_experts
        layer.n_redundant_experts = self.num_redundant_experts
        layer.experts.update_expert_map()

Step3p5MLP

Bases: Module

Source code in vllm/model_executor/models/step3p5.py
class Step3p5MLP(nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()
        self.prefix = prefix
        self.hidden_size = hidden_size
        self.limit = None
        layer_idx = extract_layer_index(prefix)
        if (
            config.swiglu_limits_shared
            and config.swiglu_limits_shared[layer_idx] is not None
            and config.swiglu_limits_shared[layer_idx] != 0
        ):
            self.limit = config.swiglu_limits_shared[layer_idx]
            self.act_fn = SwigluStepAndMul(limit=self.limit)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(hidden_states)
        intermediate_act = self.act_fn(gate_up)
        output, _ = self.down_proj(intermediate_act)
        return output
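
The default activation path splits the fused gate_up projection in half and applies SiLU to the gate. A functional sketch equivalent to SiluAndMul (the clamp-limited SwigluStepAndMul variant is not reproduced here):

import torch
import torch.nn.functional as F

def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up_proj concatenates the gate and up projections on the last dim.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

gate_up = torch.randn(4, 2 * 512)  # (num_tokens, 2 * intermediate_size)
out = silu_and_mul(gate_up)        # (num_tokens, intermediate_size)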

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    reduce_results=reduce_results,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

hidden_size instance-attribute

hidden_size = hidden_size

limit instance-attribute

limit = None

prefix instance-attribute

prefix = prefix

__init__

__init__(
    config: ModelConfig,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: QuantizationConfig | None = None,
    reduce_results: bool = True,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/step3p5.py
def __init__(
    self,
    config: ModelConfig,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: QuantizationConfig | None = None,
    reduce_results: bool = True,
    prefix: str = "",
) -> None:
    super().__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        reduce_results=reduce_results,
        prefix=f"{prefix}.down_proj",
    )

    if hidden_act != "silu":
        raise ValueError(
            f"Unsupported activation: {hidden_act}. Only silu is supported for now."
        )
    self.act_fn = SiluAndMul()
    self.prefix = prefix
    self.hidden_size = hidden_size
    self.limit = None
    layer_idx = extract_layer_index(prefix)
    if (
        config.swiglu_limits_shared
        and config.swiglu_limits_shared[layer_idx] is not None
        and config.swiglu_limits_shared[layer_idx] != 0
    ):
        self.limit = config.swiglu_limits_shared[layer_idx]
        self.act_fn = SwigluStepAndMul(limit=self.limit)

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    gate_up, _ = self.gate_up_proj(hidden_states)
    intermediate_act = self.act_fn(gate_up)
    output, _ = self.down_proj(intermediate_act)
    return output

Step3p5Model

Bases: Module

Source code in vllm/model_executor/models/step3p5.py
@support_torch_compile
class Step3p5Model(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        self.vllm_config = vllm_config
        config = vllm_config.model_config.hf_config
        self.vocab_size = config.vocab_size
        self.config = config

        self.moe_num_experts = config.moe_num_experts

        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3p5DecoderLayer(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        config = self.config
        assert config.num_attention_groups > 1, "Only support GQA"
        qkv_params_mapping = []
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        expert_params_mapping = [
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
        ]

        disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

        for name, loaded_weight in weights:
            if name.startswith("model."):
                local_name = name[len("model.") :]
                full_name = name
            else:
                local_name = name
                full_name = f"model.{name}" if name else "model"

            spec_layer = get_spec_layer_idx_from_weight_name(config, full_name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            # Skip any layers beyond the main model's depth (e.g., MTP layers)
            if full_name.startswith("model.layers."):
                parts = full_name.split(".")
                if len(parts) > 2 and parts[2].isdigit():
                    layer_idx = int(parts[2])
                    if layer_idx >= config.num_hidden_layers:
                        continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in local_name:
                    continue
                if any(
                    disable_moe_stacked_param in local_name
                    for disable_moe_stacked_param in disable_moe_stacked_params
                ):
                    continue
                replaced_name = local_name.replace(weight_name, param_name)
                if is_pp_missing_parameter(replaced_name, self):
                    continue
                if replaced_name not in params_dict:
                    continue
                param = params_dict[replaced_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(replaced_name)
                break
            else:
                for param_name, weight_name, shard_id in expert_params_mapping:
                    if weight_name not in local_name:
                        continue
                    replaced_name = local_name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(replaced_name, self):
                        continue
                    if (
                        replaced_name.endswith(".bias")
                        or replaced_name.endswith("_bias")
                    ) and replaced_name not in params_dict:
                        continue
                    if replaced_name not in params_dict:
                        continue
                    param = params_dict[replaced_name]
                    weight_loader = param.weight_loader
                    moe_expert_num = self.moe_num_experts
                    assert loaded_weight.shape[0] == moe_expert_num
                    for expert_id in range(moe_expert_num):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(
                            param,
                            loaded_weight_expert,
                            replaced_name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    loaded_params.add(replaced_name)
                    break
                else:
                    for (
                        param_name,
                        weight_name,
                        start_idx,
                        end_idx,
                    ) in qkv_params_mapping:
                        if weight_name not in local_name:
                            continue
                        replaced_name = local_name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(replaced_name, self):
                            continue
                        if replaced_name not in params_dict:
                            continue
                        param = params_dict[replaced_name]
                        dim = param.shape[param.output_dim]
                        begin_idx = int(start_idx * dim)
                        end_idx = int(end_idx * dim)
                        param_slice = param.narrow(
                            param.output_dim, begin_idx, end_idx - begin_idx
                        )
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(replaced_name)
                        break
                    else:
                        if is_pp_missing_parameter(local_name, self):
                            continue
                        if "expert_bias" in local_name:
                            logger.warning_once("ignore expert_bias")
                            continue
                        if local_name not in params_dict:
                            continue
                        param = params_dict[local_name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                        loaded_params.add(local_name)
        return loaded_params
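
One detail worth calling out in the constructor: the embedding table is instantiated not only on the first pipeline rank but also on the last rank when tie_word_embeddings is set, presumably so the tied LM head can share its weight; every other rank gets a PPMissingLayer placeholder. A small sketch of which ranks end up with a real table, using a hypothetical 4-stage pipeline:

# Hypothetical pipeline size; mirrors the rank condition in __init__ above.
pp_size = 4
tie_word_embeddings = True

for pp_rank in range(pp_size):
    is_first = pp_rank == 0
    is_last = pp_rank == pp_size - 1
    has_embed = is_first or (tie_word_embeddings and is_last)
    kind = "VocabParallelEmbedding" if has_embed else "PPMissingLayer"
    print(f"pp_rank {pp_rank}: {kind}")
# pp_rank 0: VocabParallelEmbedding
# pp_rank 1: PPMissingLayer
# pp_rank 2: PPMissingLayer
# pp_rank 3: VocabParallelEmbedding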

config instance-attribute

config = config

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], hidden_size
    )
)

moe_num_experts instance-attribute

moe_num_experts = moe_num_experts

norm instance-attribute

norm = GemmaRMSNorm(hidden_size, rms_norm_eps)

vllm_config instance-attribute

vllm_config = vllm_config

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(vllm_config: VllmConfig, prefix: str = '') -> None
Source code in vllm/model_executor/models/step3p5.py
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
    super().__init__()

    self.vllm_config = vllm_config
    config = vllm_config.model_config.hf_config
    self.vocab_size = config.vocab_size
    self.config = config

    self.moe_num_experts = config.moe_num_experts

    if get_pp_group().is_first_rank or (
        config.tie_word_embeddings and get_pp_group().is_last_rank
    ):
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
    else:
        self.embed_tokens = PPMissingLayer()

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: Step3p5DecoderLayer(
            vllm_config,
            prefix=prefix,
        ),
        prefix=f"{prefix}.layers",
    )
    if get_pp_group().is_last_rank:
        self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
    else:
        self.norm = PPMissingLayer()

    self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
        ["hidden_states"], config.hidden_size
    )

embed_input_ids

embed_input_ids(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.embed_tokens(input_ids)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/models/step3p5.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        hidden_states = layer(positions, hidden_states)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors(
            {
                "hidden_states": hidden_states,
            }
        )

    return hidden_states
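
The layer loop in forward only runs this pipeline stage's slice of the decoder stack, between start_layer and end_layer as assigned by make_layers. The sketch below shows what that slice could look like under a hypothetical even split across four stages; vLLM's actual partitioning strategy may assign uneven slices.

# Hypothetical sizes, for illustration only.
num_hidden_layers = 48
pp_size = 4
per_stage = num_hidden_layers // pp_size  # 12 layers per stage

for pp_rank in range(pp_size):
    start_layer = pp_rank * per_stage
    end_layer = start_layer + per_stage
    print(f"pp_rank {pp_rank}: runs layers [{start_layer}, {end_layer})")
# pp_rank 0: runs layers [0, 12)
# ...
# pp_rank 3: runs layers [36, 48)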

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/step3p5.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    config = self.config
    assert config.num_attention_groups > 1, "Only support GQA"
    qkv_params_mapping = []
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    expert_params_mapping = [
        (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
        (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
        (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
    ]

    disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

    for name, loaded_weight in weights:
        if name.startswith("model."):
            local_name = name[len("model.") :]
            full_name = name
        else:
            local_name = name
            full_name = f"model.{name}" if name else "model"

        spec_layer = get_spec_layer_idx_from_weight_name(config, full_name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model

        # Skip any layers beyond the main model's depth (e.g., MTP layers)
        if full_name.startswith("model.layers."):
            parts = full_name.split(".")
            if len(parts) > 2 and parts[2].isdigit():
                layer_idx = int(parts[2])
                if layer_idx >= config.num_hidden_layers:
                    continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in local_name:
                continue
            if any(
                disable_moe_stacked_param in local_name
                for disable_moe_stacked_param in disable_moe_stacked_params
            ):
                continue
            replaced_name = local_name.replace(weight_name, param_name)
            if is_pp_missing_parameter(replaced_name, self):
                continue
            if replaced_name not in params_dict:
                continue
            param = params_dict[replaced_name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(replaced_name)
            break
        else:
            for param_name, weight_name, shard_id in expert_params_mapping:
                if weight_name not in local_name:
                    continue
                replaced_name = local_name.replace(weight_name, param_name)
                if is_pp_missing_parameter(replaced_name, self):
                    continue
                if (
                    replaced_name.endswith(".bias")
                    or replaced_name.endswith("_bias")
                ) and replaced_name not in params_dict:
                    continue
                if replaced_name not in params_dict:
                    continue
                param = params_dict[replaced_name]
                weight_loader = param.weight_loader
                moe_expert_num = self.moe_num_experts
                assert loaded_weight.shape[0] == moe_expert_num
                for expert_id in range(moe_expert_num):
                    loaded_weight_expert = loaded_weight[expert_id]
                    weight_loader(
                        param,
                        loaded_weight_expert,
                        replaced_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                loaded_params.add(replaced_name)
                break
            else:
                for (
                    param_name,
                    weight_name,
                    start_idx,
                    end_idx,
                ) in qkv_params_mapping:
                    if weight_name not in local_name:
                        continue
                    replaced_name = local_name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(replaced_name, self):
                        continue
                    if replaced_name not in params_dict:
                        continue
                    param = params_dict[replaced_name]
                    dim = param.shape[param.output_dim]
                    begin_idx = int(start_idx * dim)
                    end_idx = int(end_idx * dim)
                    param_slice = param.narrow(
                        param.output_dim, begin_idx, end_idx - begin_idx
                    )
                    param_slice.copy_(loaded_weight)
                    loaded_params.add(replaced_name)
                    break
                else:
                    if is_pp_missing_parameter(local_name, self):
                        continue
                    if "expert_bias" in local_name:
                        logger.warning_once("ignore expert_bias")
                        continue
                    if local_name not in params_dict:
                        continue
                    param = params_dict[local_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(local_name)
    return loaded_params
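
Most of load_weights is name rewriting: separate q/k/v and gate/up tensors from the checkpoint are funneled into the fused qkv_proj and gate_up_proj parameters, and per-expert MoE tensors are loaded expert by expert into the stacked w13/w2 weights. The stand-alone sketch below replays just the stacked-parameter rewrite with hypothetical weight names; it omits the MoE exclusion list, the PP-missing-parameter checks, and the actual weight_loader calls.

# Mapping tuples copied from load_weights above; the example names are hypothetical.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name: str):
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(remap("layers.0.self_attn.k_proj.weight"))
# -> ('layers.0.self_attn.qkv_proj.weight', 'k')
print(remap("layers.0.mlp.up_proj.weight"))
# -> ('layers.0.mlp.gate_up_proj.weight', 1)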

get_spec_layer_idx_from_weight_name

get_spec_layer_idx_from_weight_name(
    config: ModelConfig, weight_name: str
) -> int | None
Source code in vllm/model_executor/models/step3p5.py
def get_spec_layer_idx_from_weight_name(
    config: ModelConfig, weight_name: str
) -> int | None:
    if hasattr(config, "num_nextn_predict_layers") and (
        config.num_nextn_predict_layers > 0
    ):
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
            if weight_name.startswith(
                f"layers.{layer_idx + i}."  # Step3p5Model
            ) or weight_name.startswith(f"model.layers.{layer_idx + i}."):  # Step3p5MTP
                return layer_idx + i
    return None
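
For a sense of what this helper returns, the example below uses a hypothetical duck-typed config object (the real argument is the model's HF config): with 45 main layers and one next-token-prediction layer, weights named after layer 45 are treated as spec-decode (MTP) weights and skipped by Step3p5Model.load_weights, while ordinary layer names fall through with None.

# Hypothetical config values, for illustration only.
class _Cfg:
    num_hidden_layers = 45
    num_nextn_predict_layers = 1

cfg = _Cfg()
assert get_spec_layer_idx_from_weight_name(cfg, "model.layers.45.mlp.down_proj.weight") == 45
assert get_spec_layer_idx_from_weight_name(cfg, "layers.45.self_attn.q_proj.weight") == 45
assert get_spec_layer_idx_from_weight_name(cfg, "model.layers.10.mlp.down_proj.weight") is None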