vllm.model_executor.layers.attention

Modules:

Name                      Description
attention
chunked_local_attention
cross_attention
encoder_only_attention
kv_transfer_utils
mla_attention             MLA Common Components
mm_encoder_attention
static_sink_attention

__all__ module-attribute

__all__ = [
    "Attention",
    "ChunkedLocalAttention",
    "CrossAttention",
    "EncoderOnlyAttention",
    "MLAAttention",
    "MMEncoderAttention",
    "StaticSinkAttention",
]

Attention

Bases: Module, AttentionLayerBase

Attention layer.

This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. The class does the following:

  1. Store the input key and value tensors in the KV cache.
  2. Perform (multi-head/multi-query/grouped-query) attention.
  3. Return the output tensor.
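
For orientation, the following is a minimal, self-contained sketch of the grouped-query attention pattern this layer performs, written in plain PyTorch. It only illustrates the head-grouping rule enforced below (num_heads must be divisible by num_kv_heads); it is not vLLM's kernel or KV-cache path, and all shapes are made up for the example.

# Illustrative sketch only: grouped-query attention where num_heads query
# heads share num_kv_heads key/value heads. Not vLLM's actual backend code.
import torch
import torch.nn.functional as F

num_tokens, num_heads, num_kv_heads, head_size = 8, 32, 8, 128
assert num_heads % num_kv_heads == 0
group = num_heads // num_kv_heads

q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_kv_heads, head_size)
v = torch.randn(num_tokens, num_kv_heads, head_size)

# Broadcast each KV head across its group of query heads, then run standard
# causal scaled-dot-product attention over the token dimension.
k = k.repeat_interleave(group, dim=1)
v = v.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True
).transpose(0, 1)
print(out.shape)  # torch.Size([8, 32, 128])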
Source code in vllm/model_executor/layers/attention/attention.py
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # llm-compressor models need to set cache_dtype to "fp8" manually.
        if getattr(quant_config, "kv_cache_scheme", None) is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend
        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
        if use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        self.use_alibi_sqrt = bool(use_alibi_sqrt)
        if backend_supports_alibi_sqrt:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            block_size = self.head_size * self.num_heads // self.num_kv_heads
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, block_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            if self.use_direct_call:
                kv_cache_dummy_dep = None
                if not self.attn_backend.forward_includes_kv_cache_update:
                    kv_cache_dummy_dep = unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            else:
                kv_cache_dummy_dep = None
                if not self.attn_backend.forward_includes_kv_cache_update and (
                    # torch can only dispatch custom op if a tensor is passed
                    key is not None or value is not None
                ):
                    kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            return output.view(-1, hidden_size)
        else:
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
            )
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
            )

attn_backend instance-attribute

attn_backend = get_attn_backend(
    head_size,
    dtype,
    kv_cache_dtype,
    block_size,
    use_mla=False,
    has_sink=has_sink,
    use_mm_prefix=use_mm_prefix,
    attn_type=attn_type,
)

attn_type instance-attribute

attn_type = attn_type

backend instance-attribute

backend = AttentionBackendEnum[get_name()]

calculate_kv_scales instance-attribute

calculate_kv_scales = calculate_kv_scales

dtype instance-attribute

dtype = dtype

has_sink instance-attribute

has_sink = get('sinks') is not None

head_size instance-attribute

head_size = head_size

head_size_v instance-attribute

head_size_v = (
    head_size if head_size_v is None else head_size_v
)

impl instance-attribute

impl = impl_cls(
    num_heads,
    head_size,
    scale,
    num_kv_heads,
    alibi_slopes,
    sliding_window,
    kv_cache_dtype,
    logits_soft_cap,
    attn_type,
    kv_sharing_target_layer_name,
    **extra_impl_args,
)

kv_cache instance-attribute

kv_cache = [
    (tensor([])) for _ in (range(pipeline_parallel_size))
]

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_cache_torch_dtype instance-attribute

kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
    kv_cache_dtype, model_config
)

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

quant_config instance-attribute

quant_config = quant_config

query_quant instance-attribute

query_quant = None

sliding_window instance-attribute

sliding_window = sliding_window

use_alibi_sqrt instance-attribute

use_alibi_sqrt = bool(use_alibi_sqrt)

use_direct_call instance-attribute

use_direct_call = not opaque_attention_op()

use_mm_prefix instance-attribute

use_mm_prefix = model_config is not None and is_mm_prefix_lm

use_output instance-attribute

use_output = accept_output_buffer

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    use_alibi_sqrt: bool | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    logits_soft_cap: float | None = None,
    per_layer_sliding_window: int | None = None,
    prefix: str = "",
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: str | None = None,
    attn_backend: type[AttentionBackend] | None = None,
    head_size_v: int | None = None,
    **extra_impl_args,
) -> None

The KV cache is stored inside this class and is accessed via self.kv_cache.

Source code in vllm/model_executor/layers/attention/attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    use_alibi_sqrt: bool | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    logits_soft_cap: float | None = None,
    per_layer_sliding_window: int | None = None,
    prefix: str = "",
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: str | None = None,
    attn_backend: type[AttentionBackend] | None = None,
    head_size_v: int | None = None,
    **extra_impl_args,
) -> None:
    """
    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.
    """
    super().__init__()
    if per_layer_sliding_window is not None:
        # per-layer sliding window
        sliding_window = per_layer_sliding_window
    elif cache_config is not None:
        # model-level sliding window
        sliding_window = cache_config.sliding_window
    else:
        sliding_window = None

    vllm_config = get_current_vllm_config()
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
        calculate_kv_scales = cache_config.calculate_kv_scales
    else:
        kv_cache_dtype = "auto"
        block_size = 16
        calculate_kv_scales = False

    # llm-compressor models need to set cache_dtype to "fp8" manually.
    if getattr(quant_config, "kv_cache_scheme", None) is not None:
        kv_cache_dtype = "fp8"
        calculate_kv_scales = False
        if cache_config is not None:
            cache_config.cache_dtype = "fp8"
            cache_config.calculate_kv_scales = False

    self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
        kv_cache_dtype, vllm_config.model_config
    )
    self.kv_cache_dtype = kv_cache_dtype
    self.calculate_kv_scales = calculate_kv_scales
    if num_kv_heads is None:
        num_kv_heads = num_heads
    assert num_heads % num_kv_heads == 0, (
        f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
    )
    self.quant_config = quant_config
    self.layer_name = prefix

    self.num_heads = num_heads
    self.head_size = head_size
    self.head_size_v = self.head_size if head_size_v is None else head_size_v
    self.num_kv_heads = num_kv_heads
    self.sliding_window = sliding_window
    self.has_sink = extra_impl_args.get("sinks") is not None

    # NOTE: model_config may be None during certain tests
    model_config = vllm_config.model_config
    self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

    # During model initialization, the default dtype is set as the model
    # weight and activation dtype.
    dtype = torch.get_default_dtype()
    if attn_backend is None:
        self.attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=False,
            has_sink=self.has_sink,
            use_mm_prefix=self.use_mm_prefix,
            attn_type=attn_type,
        )
    else:
        self.attn_backend = attn_backend
    backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
    use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
    if use_alibi_sqrt and not backend_supports_alibi_sqrt:
        raise ValueError(
            f"use_alibi_sqrt is not supported by backend "
            f"{self.attn_backend.get_name()}."
        )
    self.use_alibi_sqrt = bool(use_alibi_sqrt)
    if backend_supports_alibi_sqrt:
        extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
    # prefix caching + batch invariance is currently not supported for
    # FLASHINFER and TRITON_MLA.
    if (
        cache_config is not None
        and cache_config.enable_prefix_caching
        and vllm_is_batch_invariant()
        and (
            self.attn_backend.get_name() == "FLASHINFER"
            or self.attn_backend.get_name() == "TRITON_MLA"
        )
    ):
        logger.warning_once(
            "Disabling prefix caching for FLASHINFER/TRITON_MLA "
            "with batch invariance, as it is not yet supported.",
            scope="local",
        )
        cache_config.enable_prefix_caching = False

    impl_cls = self.attn_backend.get_impl_cls()
    self.impl = impl_cls(
        num_heads,
        head_size,
        scale,
        num_kv_heads,
        alibi_slopes,
        sliding_window,
        kv_cache_dtype,
        logits_soft_cap,
        attn_type,
        kv_sharing_target_layer_name,
        **extra_impl_args,
    )
    self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
    self.dtype = dtype

    # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
    # torch.compile works by registering the attention as one giant
    # opaque custom op. For other platforms, we directly call them
    # and let torch.compile handle them.
    self.use_direct_call = not current_platform.opaque_attention_op()

    self.use_output = self.attn_backend.accept_output_buffer
    compilation_config = vllm_config.compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
    self.attn_type = attn_type

    if kv_sharing_target_layer_name is not None:
        validate_kv_sharing_target(
            prefix,
            kv_sharing_target_layer_name,
            compilation_config.static_forward_context,
        )
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    # use a placeholder kv cache tensor during init, which will be replaced
    # by bind_kv_cache
    # this variable will not be accessed if use_direct_call is True
    self.kv_cache = [
        torch.tensor([])
        for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
    ]

    # Initialize KV cache quantization attributes
    _init_kv_cache_quant(self, quant_config, prefix)

    # for attn backends supporting query quantization
    self.query_quant = None
    if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
        "fp8"
    ):
        is_per_head = (
            hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
        )
        block_size = self.head_size * self.num_heads // self.num_kv_heads
        self.query_quant = QuantFP8(
            static=True,
            group_shape=GroupShape(-1, block_size)
            if is_per_head
            else GroupShape.PER_TENSOR,
        )

calc_kv_scales

calc_kv_scales(query, key, value)
Source code in vllm/model_executor/layers/attention/attention.py
def calc_kv_scales(self, query, key, value):
    self._q_scale.copy_(torch.abs(query).max() / self.q_range)
    self._k_scale.copy_(torch.abs(key).max() / self.k_range)
    self._v_scale.copy_(torch.abs(value).max() / self.v_range)
    self._q_scale_float = self._q_scale.item()
    self._k_scale_float = self._k_scale.item()
    self._v_scale_float = self._v_scale.item()
    # We only calculate the scales once
    self.calculate_kv_scales = False

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/attention/attention.py
def extra_repr(self) -> str:
    s = f"head_size={self.impl.head_size}"  # type: ignore
    s += f", num_heads={self.impl.num_heads}"  # type: ignore
    s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
    s += f", scale={self.impl.scale}"  # type: ignore
    s += f", backend={self.impl.__class__.__name__}"
    return s

forward

forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor

The KV cache is stored inside this class and is accessed via self.kv_cache.

Attention metadata (attn_metadata) is set using a context manager in the model runner's execute_model method. It is accessed via forward context using vllm.forward_context.get_forward_context().attn_metadata.
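
As a hedged usage sketch (not taken verbatim from any model file), a layer running inside the model runner's forward context could read its metadata as below; the layer-name key is hypothetical and would normally be the prefix assigned at construction time.

# Sketch only: must run while a forward context is active (i.e. inside the
# model's forward pass launched by execute_model).
from vllm.forward_context import get_forward_context

attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
    # Per-layer metadata is keyed by the layer's prefix; this key is a
    # hypothetical example.
    attn_metadata = attn_metadata["model.layers.0.self_attn.attn"]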

Source code in vllm/model_executor/layers/attention/attention.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    # For some alternate attention backends like MLA the attention output
    # shape does not match the query shape, so we optionally let the model
    # definition specify the output tensor shape.
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    """
    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.

    Attention metadata (`attn_metadata`) is set using a context manager in
    the model runner's `execute_model` method. It is accessed via forward
    context using
    `vllm.forward_context.get_forward_context().attn_metadata`.
    """
    if self.calculate_kv_scales:
        torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
    output_dtype = query.dtype
    if self.query_quant is not None:
        # quantizing with a simple torch operation enables
        # torch.compile to fuse this into previous ops
        # which reduces overheads during decoding.
        # Otherwise queries are quantized using custom ops
        # which causes decoding overheads
        assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

        # check if query quantization is supported
        if self.impl.supports_quant_query_input:
            query, _ = self.query_quant(query, self._q_scale)

    if self.use_output:
        if output_shape is None:
            # Handle both 2D [num_tokens, hidden] and
            # 3D [num_tokens, heads, head_dim] query
            num_tokens = query.shape[0]
            output_shape = torch.Size(
                (num_tokens, self.num_heads * self.head_size_v)
            )
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]
        # Reshape the query, key, and value tensors.
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size_v)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size_v)
        if self.use_direct_call:
            kv_cache_dummy_dep = None
            if not self.attn_backend.forward_includes_kv_cache_update:
                kv_cache_dummy_dep = unified_kv_cache_update(
                    key, value, self.layer_name
                )
            unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        else:
            kv_cache_dummy_dep = None
            if not self.attn_backend.forward_includes_kv_cache_update and (
                # torch can only dispatch custom op if a tensor is passed
                key is not None or value is not None
            ):
                kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                    key, value, self.layer_name
                )
            torch.ops.vllm.unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        return output.view(-1, hidden_size)
    else:
        assert self.attn_backend.forward_includes_kv_cache_update, (
            "Split KV cache update not supported when output tensor not provided."
        )
        if self.use_direct_call:
            return unified_attention(query, key, value, self.layer_name)
        else:
            return torch.ops.vllm.unified_attention(
                query, key, value, self.layer_name
            )

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/attention/attention.py
def get_attn_backend(self) -> type[AttentionBackend]:
    return self.attn_backend

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # Block size may get updated after model loading, refresh it
    block_size = vllm_config.cache_config.block_size
    # Should not be called for enc-dec or encoder-only attention.
    assert self.attn_type == AttentionType.DECODER
    if self.sliding_window is not None:
        assert not vllm_config.model_config.use_mla, (
            "MLA is not supported for slidingwindow"
        )
        return SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            sliding_window=self.sliding_window,
        )
    else:
        return FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            dtype=self.kv_cache_torch_dtype,
        )

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/model_executor/layers/attention/attention.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    self.impl.process_weights_after_loading(act_dtype)

    # If we should not load quant weights, we initialize the scales to 1.0
    # as the default value. See [Note: Register q/k/v/prob scales in state dict]
    # for more details.
    quant_method = (
        self.quant_config.get_quant_method(self, prefix=self.layer_name)
        if self.quant_config
        else None
    )
    if not should_load_quant_weights(quant_method):
        set_default_quant_scales(self, register_buffer=False)

ChunkedLocalAttention

Bases: Attention

Source code in vllm/model_executor/layers/attention/chunked_local_attention.py
class ChunkedLocalAttention(Attention):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        attention_chunk_size: int,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        kv_sharing_target_layer_name: str | None = None,
        prefix: str = "",
    ):
        self.attention_chunk_size = attention_chunk_size
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size, dtype, kv_cache_dtype, block_size
        )
        attn_backend = create_chunked_local_attention_backend(
            underlying_attn_backend, attention_chunk_size, block_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        assert self.attention_chunk_size
        return ChunkedLocalAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            attention_chunk_size=self.attention_chunk_size,
        )
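
To make the name concrete, here is a small illustrative sketch of the masking idea behind chunked local attention, under the assumption that each token attends causally only within its own chunk of attention_chunk_size positions. This is an interpretation for illustration, not the backend produced by create_chunked_local_attention_backend.

# Assumed semantics: causal attention restricted to non-overlapping chunks of
# attention_chunk_size tokens. Illustrative only.
import torch

def chunked_local_mask(num_tokens: int, attention_chunk_size: int) -> torch.Tensor:
    pos = torch.arange(num_tokens)
    same_chunk = (pos[:, None] // attention_chunk_size) == (
        pos[None, :] // attention_chunk_size
    )
    causal = pos[:, None] >= pos[None, :]
    return same_chunk & causal  # True where attention is allowed

print(chunked_local_mask(6, attention_chunk_size=3).int())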

attention_chunk_size instance-attribute

attention_chunk_size = attention_chunk_size

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    attention_chunk_size: int,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    kv_sharing_target_layer_name: str | None = None,
    prefix: str = "",
)
Source code in vllm/model_executor/layers/attention/chunked_local_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    attention_chunk_size: int,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    kv_sharing_target_layer_name: str | None = None,
    prefix: str = "",
):
    self.attention_chunk_size = attention_chunk_size
    dtype = torch.get_default_dtype()
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16

    underlying_attn_backend = get_attn_backend(
        head_size, dtype, kv_cache_dtype, block_size
    )
    attn_backend = create_chunked_local_attention_backend(
        underlying_attn_backend, attention_chunk_size, block_size
    )

    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=alibi_slopes,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=prefix,
        kv_sharing_target_layer_name=kv_sharing_target_layer_name,
        attn_backend=attn_backend,
    )

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/chunked_local_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    assert self.attention_chunk_size
    return ChunkedLocalAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        dtype=self.kv_cache_torch_dtype,
        attention_chunk_size=self.attention_chunk_size,
    )

CrossAttention

Bases: Attention

Cross-attention for encoder-decoder models. Handles attention between decoder queries and encoder keys/values.
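
As a minimal sketch of the shape relationship described above, decoder queries can attend to an encoder sequence of a different length; this is plain PyTorch for illustration, not this class's backend path, and the sizes are arbitrary.

# Illustrative sketch: num_decoder_tokens queries attend to num_encoder_tokens
# keys/values with no causal mask.
import torch
import torch.nn.functional as F

num_decoder_tokens, num_encoder_tokens, num_heads, head_size = 4, 10, 8, 64
q = torch.randn(num_heads, num_decoder_tokens, head_size)
k = torch.randn(num_heads, num_encoder_tokens, head_size)
v = torch.randn(num_heads, num_encoder_tokens, head_size)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([8, 4, 64])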

Source code in vllm/model_executor/layers/attention/cross_attention.py
class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.
    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER"
            )

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_DECODER,
        )
        attn_backend = create_cross_attention_backend(underlying_attn_backend)

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_DECODER,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        return CrossAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/attention/cross_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
):
    dtype = torch.get_default_dtype()

    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16

    if attn_type is not None:
        assert attn_type == AttentionType.ENCODER_DECODER, (
            "CrossAttention only supports AttentionType.ENCODER_DECODER"
        )

    underlying_attn_backend = get_attn_backend(
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        attn_type=AttentionType.ENCODER_DECODER,
    )
    attn_backend = create_cross_attention_backend(underlying_attn_backend)

    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=attn_backend,
        attn_type=AttentionType.ENCODER_DECODER,
        **kwargs,
    )

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/cross_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    return CrossAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        dtype=self.kv_cache_torch_dtype,
    )

EncoderOnlyAttention

Bases: Attention

Encoder attention is a special case that doesn't need a KV Cache.

Source code in vllm/model_executor/layers/attention/encoder_only_attention.py
class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV Cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_ONLY,
        )

        attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, (
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
            )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Does not need KV cache
        return None

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/attention/encoder_only_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
):
    dtype = torch.get_default_dtype()

    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16

    underlying_attn_backend = get_attn_backend(
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        attn_type=AttentionType.ENCODER_ONLY,
    )

    attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

    if attn_type is not None:
        assert attn_type == AttentionType.ENCODER_ONLY, (
            "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
        )

    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=attn_backend,
        attn_type=AttentionType.ENCODER_ONLY,
        **kwargs,
    )

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/encoder_only_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # Does not need KV cache
    return None

MLAAttention

Bases: Module, AttentionLayerBase

Multi-Head Latent Attention layer.

NOTE: Please read the comment at the top of the file before trying to understand this class.

This class takes query and compressed key/value tensors as input. The class does the following:

  1. Store the input key and value tensors in the KV cache.
  2. Perform (multi-head/multi-query/grouped-query) attention.
  3. Return the output tensor.
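
Because the cached key/value is the compressed latent plus the rotary part, the per-token cache width is head_size = kv_lora_rank + qk_rope_head_dim with a single KV head (see __init__ below). The numbers in this small arithmetic sketch are illustrative DeepSeek-style values, not read from any particular model config.

# Arithmetic sketch of the per-token KV-cache width; example dimensions are
# assumptions for illustration.
kv_lora_rank = 512
qk_rope_head_dim = 64
head_size = kv_lora_rank + qk_rope_head_dim
print(head_size)  # 576 elements cached per token (num_kv_heads == 1)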
Source code in vllm/model_executor/layers/attention/mla_attention.py
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        q_pad_num_heads: int | None = None,
        **extra_impl_args,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.kv_b_proj = kv_b_proj
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix
        self.indexer = indexer
        self.q_pad_num_heads = q_pad_num_heads

        self.num_kv_heads = 1
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.quant_config = quant_config

        # Initialize KV cache quantization attributes
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        _init_kv_cache_quant(self, quant_config, prefix)

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "TRITON_MLA"
                or self.attn_backend.get_name() == "FLASHINFER"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            q_pad_num_heads=q_pad_num_heads,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        self.kv_cache = [
            torch.tensor([])
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

        # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
        self.is_aiter_triton_fp4_bmm_enabled = (
            rocm_aiter_ops.is_fp4bmm_enabled()
            and self.kv_b_proj.weight.dtype == torch.bfloat16
        )

        # Attributes for forward_impl method
        self.chunked_prefill_workspace_size = (
            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
                get_current_vllm_config()
            )
        )
        self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
            static=True,
            group_shape=GroupShape.PER_TENSOR,
            compile_native=True,
        )

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.forward_impl(
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            else:
                return self.forward_impl(
                    q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )
        else:
            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
                return output
            else:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )

    def forward_impl(
        self,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: "MLACommonMetadata",
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for MLA"
            )

        if attn_metadata is None:
            # During the profile run try to simulate to worse case output size
            # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
            # since this can be large
            _ = torch.empty(
                (
                    self.chunked_prefill_workspace_size,
                    self.num_heads,
                    self.qk_nope_head_dim + self.v_head_dim,
                ),
                device=k_c_normed.device,
                dtype=k_c_normed.dtype,
            )

            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        if self.impl.dcp_world_size == -1:
            self.impl.dcp_world_size = get_dcp_group().world_size

        fp8_attention = self.kv_cache_dtype.startswith("fp8")

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_toks, ...]
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        assert (
            attn_metadata.num_decodes is not None
            and attn_metadata.num_prefills is not None
            and attn_metadata.num_decode_tokens is not None
        )

        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens

        decode_q = q[:num_decode_tokens]

        prefill_q = q[num_decode_tokens:]
        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=self._k_scale,
            )

        if fp8_attention:
            kv_cache = kv_cache.view(current_platform.fp8_dtype())

        # Sparse MLA impls only support forward_mqa (decode-style attention)
        is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)

        if has_prefill and not is_sparse_impl:
            self.impl.forward_mha(
                prefill_q,
                prefill_k_c_normed,
                prefill_k_pe,
                kv_cache,
                attn_metadata,
                self._k_scale,
                output=output[num_decode_tokens:],
            )

        if has_decode or (has_prefill and is_sparse_impl):
            # For sparse impl, we always use forward_mqa for all tokens
            # For non-sparse impl, we only use forward_mqa for decode tokens
            if is_sparse_impl:
                mqa_q = q
                mqa_output_slice = output
            else:
                assert attn_metadata.decode is not None
                mqa_q = decode_q
                mqa_output_slice = output[:num_decode_tokens]

            mqa_q_nope, mqa_q_pe = mqa_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )

            # Convert from (B, N, P) to (N, B, P)
            mqa_q_nope = mqa_q_nope.transpose(0, 1)

            if self.q_pad_num_heads is not None:
                B, N, L = mqa_q_pe.shape
                mqa_pe_padded = mqa_q_pe.new_empty((B, self.q_pad_num_heads, L))
                mqa_pe_padded.resize_((B, N, L))
                mqa_pe_padded.copy_(mqa_q_pe)
                mqa_q_pe = mqa_pe_padded

            if self.is_aiter_triton_fp4_bmm_enabled:
                from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

                mqa_ql_nope = batched_gemm_a16wfp4(
                    mqa_q_nope,
                    self.W_K,
                    self.W_K_scale,
                    transpose_bm=True,
                    prequant=True,
                    y_scale=self._q_scale if fp8_attention else None,
                )
            elif self.is_aiter_triton_fp8_bmm_enabled:
                # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                mqa_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                    mqa_q_nope,
                    self.W_K,
                    self.W_K_scale,
                    group_size=128,
                    transpose_bm=True,
                )
            else:
                # Pads the head_dim if necessary (for the underlying kernel)
                N, B, P = mqa_q_nope.shape
                _, _, L = self.W_UK_T.shape

                if self.q_pad_num_heads is not None:
                    mqa_ql_nope = mqa_q_nope.new_empty((self.q_pad_num_heads, B, L))
                    mqa_ql_nope.resize_((N, B, L))
                else:
                    mqa_ql_nope = mqa_q_nope.new_empty((N, B, L))

                # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
                torch.bmm(mqa_q_nope, self.W_UK_T, out=mqa_ql_nope)

                # Convert from (N, B, L) to (B, N, L)
                mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

            if fp8_attention:
                assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
                assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
                mqa_q = self._decode_concat_quant_fp8_op(
                    mqa_ql_nope, mqa_q_pe, self._q_scale
                )
            else:
                mqa_q = (mqa_ql_nope, mqa_q_pe)
            if self.impl.dcp_world_size > 1:
                assert not fp8_attention, "DCP does not support fp8 kvcache yet."
                # concatenate mqa_ql_nope and mqa_q_pe -> (B, N, L + P)
                mqa_q = torch.cat(mqa_q, dim=-1)
                # all-gather mqa_q along the head dim.
                mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

            # call decode attn
            attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

            # correct dcp attn_out with lse.
            if self.impl.dcp_world_size > 1:
                attn_out = cp_lse_ag_out_rs(
                    attn_out,
                    lse,
                    get_dcp_group(),
                    is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
                )

            # v_up projection
            self._v_up_proj(attn_out, out=mqa_output_slice)
        return output_padded

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        # We currently do not have the quantized bmm's needed for `W_UV` and
        # `W_UK_T`, so we just store fp16/bf16 copies and perform the bmm's in
        # 16-bit; the extra memory overhead of this is fairly low.
        kv_b_proj_weight = get_and_maybe_dequant_weights(
            self.kv_b_proj, out_dtype=act_dtype
        ).T

        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
        if self.is_aiter_triton_fp4_bmm_enabled:
            from vllm.model_executor.layers.quantization.quark.utils import (
                quark_quantize_weight_to_mxfp4,
            )

            self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
            # Convert from (L, N, P) to (N, L, P)
            self.W_K = self.W_K.transpose(0, 1)
            self.W_K_scale = self.W_K_scale.transpose(0, 1)

            self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(
                W_UV.permute(1, 2, 0)
            )
        elif self.is_aiter_triton_fp8_bmm_enabled:
            W_K = W_UK.transpose(0, 1)  # (L, N, P) -> (N, L, P), e.g. (16, 512, 128)
            W_V = W_UV.permute(1, 2, 0)  # (L, N, V) -> (N, V, L), e.g. (16, 128, 512)
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype()
            )
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype()
            )

            # The kernel operates on non-padded inputs, so pre-compile the Triton
            # kernel to avoid runtime compilation for unseen batch sizes.
            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
            # On DS-R1, this step adds roughly 50s to the model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty(
                    (self.W_K.shape[0], m, self.W_K.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_K.device,
                )
                rocm_aiter_ops.triton_fp8_bmm(
                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
                )

                x = torch.empty(
                    (self.W_V.shape[0], m, self.W_V.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_V.device,
                )
                rocm_aiter_ops.triton_fp8_bmm(
                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
                )
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
            self.W_UK_T = W_UK.permute(1, 2, 0)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        out = out.view(-1, self.num_heads, self.v_head_dim)
        if self.is_aiter_triton_fp4_bmm_enabled:
            out = rocm_aiter_ops.batched_gemm_a16wfp4(
                x,
                self.W_V,
                self.W_V_scale,
                out,
                transpose_bm=True,
                prequant=True,
                y_scale=None,
            )
            x = out.view(-1, self.num_heads * self.v_head_dim)
        elif self.is_aiter_triton_fp8_bmm_enabled:
            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
            x = rocm_aiter_ops.triton_fp8_bmm(
                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
            )
        else:
            # Convert from (B, N * V) to (N, B, V)
            out = out.transpose(0, 1)

            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

            # Convert from (N, B, V) to (B, N * V)
            out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

            # Adjust output buffer shape back to the original (B, N * V)
            N, B, V = out.shape
            out.resize_((B, N * V))
            out.copy_(out_new)  # Copy result
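
To make the decode path above easier to follow, here is a minimal, self-contained sketch of the weight-absorption shape algebra that `forward_impl` and `_v_up_proj` implement: `W_UK` is folded into the query so attention can run directly over the compressed `kv_lora_rank` latent, and `W_UV` is applied to the attention output afterwards. All sizes below are illustrative placeholders, not values taken from any particular model.

import torch

# Illustrative sizes only (not tied to a specific model/config):
B, N = 4, 16              # decode tokens, attention heads
P, L, V = 128, 512, 128   # qk_nope_head_dim, kv_lora_rank, v_head_dim

q_nope = torch.randn(B, N, P)   # "nope" part of the query
W_UK_T = torch.randn(N, P, L)   # per-head key up-projection, pre-transposed
W_UV = torch.randn(N, L, V)     # per-head value up-projection

# Absorb W_UK into the query: (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L)
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)

# MQA over the latent KV cache would produce attn_out of shape (B, N, L)
attn_out = torch.randn(B, N, L)

# v_up projection: (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N * V)
out = torch.bmm(attn_out.transpose(0, 1), W_UV).transpose(0, 1).reshape(B, N * V)
print(out.shape)  # torch.Size([4, 2048])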

_decode_concat_quant_fp8_op instance-attribute

_decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
    static=True, group_shape=PER_TENSOR, compile_native=True
)

attn_backend instance-attribute

attn_backend = get_attn_backend(
    head_size,
    dtype,
    kv_cache_dtype,
    block_size,
    use_mla=True,
    use_sparse=use_sparse,
)

calculate_kv_scales instance-attribute

calculate_kv_scales = calculate_kv_scales

chunked_prefill_workspace_size instance-attribute

chunked_prefill_workspace_size = (
    determine_chunked_prefill_workspace_size(
        get_current_vllm_config()
    )
)

head_size instance-attribute

head_size = kv_lora_rank + qk_rope_head_dim

impl instance-attribute

impl = impl_cls(
    num_heads=num_heads,
    head_size=head_size,
    scale=scale,
    num_kv_heads=1,
    alibi_slopes=None,
    sliding_window=None,
    kv_cache_dtype=kv_cache_dtype,
    logits_soft_cap=None,
    attn_type=DECODER,
    kv_sharing_target_layer_name=None,
    q_lora_rank=q_lora_rank,
    kv_lora_rank=kv_lora_rank,
    qk_nope_head_dim=qk_nope_head_dim,
    qk_rope_head_dim=qk_rope_head_dim,
    qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
    v_head_dim=v_head_dim,
    kv_b_proj=kv_b_proj,
    indexer=indexer,
    q_pad_num_heads=q_pad_num_heads,
    **extra_impl_args,
)

indexer instance-attribute

indexer = indexer

is_aiter_triton_fp4_bmm_enabled instance-attribute

is_aiter_triton_fp4_bmm_enabled = (
    is_fp4bmm_enabled() and dtype == bfloat16
)

is_aiter_triton_fp8_bmm_enabled instance-attribute

is_aiter_triton_fp8_bmm_enabled = is_fp8bmm_enabled()

k_range instance-attribute

k_range = tensor(K_SCALE_CONSTANT, dtype=float32)

kv_b_proj instance-attribute

kv_b_proj = kv_b_proj

kv_cache instance-attribute

kv_cache = [
    (tensor([])) for _ in (range(pipeline_parallel_size))
]

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = 1

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

q_pad_num_heads instance-attribute

q_pad_num_heads = q_pad_num_heads

q_range instance-attribute

q_range = tensor(Q_SCALE_CONSTANT, dtype=float32)

qk_head_dim instance-attribute

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

quant_config instance-attribute

quant_config = quant_config

scale instance-attribute

scale = scale

use_direct_call instance-attribute

use_direct_call = not opaque_attention_op()

use_sparse instance-attribute

use_sparse = use_sparse

v_head_dim instance-attribute

v_head_dim = v_head_dim

v_range instance-attribute

v_range = tensor(V_SCALE_CONSTANT, dtype=float32)

__init__

__init__(
    num_heads: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    kv_b_proj: ColumnParallelLinear,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    use_sparse: bool = False,
    indexer: object | None = None,
    q_pad_num_heads: int | None = None,
    **extra_impl_args,
)
Source code in vllm/model_executor/layers/attention/mla_attention.py
def __init__(
    self,
    num_heads: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    kv_b_proj: ColumnParallelLinear,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    use_sparse: bool = False,
    indexer: object | None = None,
    q_pad_num_heads: int | None = None,
    **extra_impl_args,
):
    super().__init__()
    self.num_heads = num_heads
    self.scale = scale
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.kv_b_proj = kv_b_proj
    self.head_size = kv_lora_rank + qk_rope_head_dim
    self.layer_name = prefix
    self.indexer = indexer
    self.q_pad_num_heads = q_pad_num_heads

    self.num_kv_heads = 1
    self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
        calculate_kv_scales = cache_config.calculate_kv_scales
    else:
        kv_cache_dtype = "auto"
        block_size = 16
        calculate_kv_scales = False
    self.quant_config = quant_config

    # Initialize KV cache quantization attributes
    self.kv_cache_dtype = kv_cache_dtype
    self.calculate_kv_scales = calculate_kv_scales
    _init_kv_cache_quant(self, quant_config, prefix)

    dtype = torch.get_default_dtype()
    self.attn_backend = get_attn_backend(
        self.head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla=True,
        use_sparse=use_sparse,
    )

    if (
        cache_config is not None
        and cache_config.enable_prefix_caching
        and vllm_is_batch_invariant()
        and (
            self.attn_backend.get_name() == "TRITON_MLA"
            or self.attn_backend.get_name() == "FLASHINFER"
        )
    ):
        logger.warning_once(
            "Disabling prefix caching for TRITON_MLA / FLASHINFER "
            "with batch invariance, as it is not yet supported.",
            scope="local",
        )
        cache_config.enable_prefix_caching = False

    impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
    self.impl = impl_cls(
        num_heads=self.num_heads,
        head_size=self.head_size,
        scale=self.scale,
        num_kv_heads=1,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype=self.kv_cache_dtype,
        logits_soft_cap=None,
        attn_type=AttentionType.DECODER,
        kv_sharing_target_layer_name=None,
        # MLA Args
        q_lora_rank=self.q_lora_rank,
        kv_lora_rank=self.kv_lora_rank,
        qk_nope_head_dim=self.qk_nope_head_dim,
        qk_rope_head_dim=self.qk_rope_head_dim,
        qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
        v_head_dim=self.v_head_dim,
        kv_b_proj=kv_b_proj,
        indexer=indexer,
        q_pad_num_heads=q_pad_num_heads,
        **extra_impl_args,
    )

    self.use_direct_call = not current_platform.opaque_attention_op()

    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self

    self.kv_cache = [
        torch.tensor([])
        for _ in range(
            get_current_vllm_config().parallel_config.pipeline_parallel_size
        )
    ]

    self.use_sparse = use_sparse

    # Initialize q/k/v range constants.
    self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

    # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
    self.is_aiter_triton_fp4_bmm_enabled = (
        rocm_aiter_ops.is_fp4bmm_enabled()
        and self.kv_b_proj.weight.dtype == torch.bfloat16
    )

    # Attributes for forward_impl method
    self.chunked_prefill_workspace_size = (
        MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
            get_current_vllm_config()
        )
    )
    self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
        static=True,
        group_shape=GroupShape.PER_TENSOR,
        compile_native=True,
    )

_v_up_proj

_v_up_proj(x: Tensor, out: Tensor)
Source code in vllm/model_executor/layers/attention/mla_attention.py
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
    # Convert from (B, N, L) to (N, B, L)
    x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
    out = out.view(-1, self.num_heads, self.v_head_dim)
    if self.is_aiter_triton_fp4_bmm_enabled:
        out = rocm_aiter_ops.batched_gemm_a16wfp4(
            x,
            self.W_V,
            self.W_V_scale,
            out,
            transpose_bm=True,
            prequant=True,
            y_scale=None,
        )
        x = out.view(-1, self.num_heads * self.v_head_dim)
    elif self.is_aiter_triton_fp8_bmm_enabled:
        # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
        x = rocm_aiter_ops.triton_fp8_bmm(
            x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
        )
    else:
        # Convert from (B, N * V) to (N, B, V)
        out = out.transpose(0, 1)

        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

        # Convert from (N, B, V) to (B, N * V)
        out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

        # Adjust output buffer shape back to the original (B, N * V)
        N, B, V = out.shape
        out.resize_((B, N * V))
        out.copy_(out_new)  # Copy result

calc_kv_scales

calc_kv_scales(
    q: Tensor, kv_c_normed: Tensor, k_pe: Tensor
) -> None

Optional scale calculation for MLA inputs.

Mirrors Attention.calc_kv_scales. Not all MLA backends require this

Source code in vllm/model_executor/layers/attention/mla_attention.py
def calc_kv_scales(
    self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
    """Optional scale calculation for MLA inputs.

    Mirrors Attention.calc_kv_scales. Not all MLA backends require this
    """
    # Use safe defaults if ranges are not present
    q_range = getattr(self, "q_range", torch.tensor(1.0))
    k_range = getattr(self, "k_range", torch.tensor(1.0))
    v_range = getattr(self, "v_range", torch.tensor(1.0))

    self._q_scale.copy_(torch.abs(q).max() / q_range)
    # kv_c_normed is the compressed KV representation; use it for k/v
    kv_abs_max = torch.abs(kv_c_normed).max()
    self._k_scale.copy_(kv_abs_max / k_range)
    self._v_scale.copy_(kv_abs_max / v_range)
    self._q_scale_float = self._q_scale.item()
    self._k_scale_float = self._k_scale.item()
    self._v_scale_float = self._v_scale.item()
    self.calculate_kv_scales = False
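
The scales computed here follow a plain abs-max recipe: each scale is the tensor's maximum absolute value divided by the corresponding range constant, with `kv_c_normed` serving as both the K and V source since it is the shared compressed representation. A minimal sketch of that arithmetic (tensor shapes and range constants are made up):

import torch

q = torch.randn(8, 16, 192)
kv_c_normed = torch.randn(8, 512)
q_range, k_range, v_range = 448.0, 448.0, 448.0  # illustrative range constants

q_scale = q.abs().max() / q_range
kv_abs_max = kv_c_normed.abs().max()
k_scale, v_scale = kv_abs_max / k_range, kv_abs_max / v_range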

forward

forward(
    q: Tensor,
    kv_c_normed: Tensor,
    k_pe: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mla_attention.py
def forward(
    self,
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    if self.calculate_kv_scales:
        torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

    if self.use_direct_call:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]

        if self.attn_backend.accept_output_buffer:
            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
            self.forward_impl(
                q,
                kv_c_normed,
                k_pe,
                self_kv_cache,
                attn_metadata,
                output=output,
            )
            return output
        else:
            return self.forward_impl(
                q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
            )
    else:
        if self.attn_backend.accept_output_buffer:
            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
            torch.ops.vllm.unified_mla_attention_with_output(
                q,
                kv_c_normed,
                k_pe,
                output,
                self.layer_name,
            )
            return output
        else:
            return torch.ops.vllm.unified_mla_attention(
                q,
                kv_c_normed,
                k_pe,
                self.layer_name,
            )
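
For reference, the tensor layout this method expects can be read off from `forward_impl` below: `q` carries the per-head "nope" and "rope" components concatenated on the last dim, `kv_c_normed` is the compressed latent, and `k_pe` is a single shared rope key per token. A hedged shape sketch with illustrative dimensions (not the values of any specific model):

import torch

# Illustrative dimensions only:
num_tokens, num_heads = 32, 16
qk_nope_head_dim, qk_rope_head_dim = 128, 64
kv_lora_rank, v_head_dim = 512, 128

# q packs the nope and rope parts per head; k_pe is the single shared rope key.
q = torch.randn(num_tokens, num_heads, qk_nope_head_dim + qk_rope_head_dim)
kv_c_normed = torch.randn(num_tokens, kv_lora_rank)
k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)

# The output buffer allocated by forward() is (num_tokens, num_heads * v_head_dim).
output_shape = torch.Size((num_tokens, num_heads * v_head_dim))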

forward_impl

forward_impl(
    q: Tensor,
    k_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: MLACommonMetadata,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mla_attention.py
def forward_impl(
    self,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: "MLACommonMetadata",
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for MLA"
        )

    if attn_metadata is None:
        # During the profile run, try to simulate the worst-case output size
        # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`,
        # since this can be large
        _ = torch.empty(
            (
                self.chunked_prefill_workspace_size,
                self.num_heads,
                self.qk_nope_head_dim + self.v_head_dim,
            ),
            device=k_c_normed.device,
            dtype=k_c_normed.dtype,
        )

        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    if self.impl.dcp_world_size == -1:
        self.impl.dcp_world_size = get_dcp_group().world_size

    fp8_attention = self.kv_cache_dtype.startswith("fp8")

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs
    output_padded = output
    output = output[:num_actual_toks, ...]
    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]

    assert (
        attn_metadata.num_decodes is not None
        and attn_metadata.num_prefills is not None
        and attn_metadata.num_decode_tokens is not None
    )

    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    decode_q = q[:num_decode_tokens]

    prefill_q = q[num_decode_tokens:]
    prefill_k_pe = k_pe[num_decode_tokens:]
    prefill_k_c_normed = k_c_normed[num_decode_tokens:]

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=self._k_scale,
        )

    if fp8_attention:
        kv_cache = kv_cache.view(current_platform.fp8_dtype())

    # Sparse MLA impls only support forward_mqa (decode-style attention)
    is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)

    if has_prefill and not is_sparse_impl:
        self.impl.forward_mha(
            prefill_q,
            prefill_k_c_normed,
            prefill_k_pe,
            kv_cache,
            attn_metadata,
            self._k_scale,
            output=output[num_decode_tokens:],
        )

    if has_decode or (has_prefill and is_sparse_impl):
        # For sparse impl, we always use forward_mqa for all tokens
        # For non-sparse impl, we only use forward_mqa for decode tokens
        if is_sparse_impl:
            mqa_q = q
            mqa_output_slice = output
        else:
            assert attn_metadata.decode is not None
            mqa_q = decode_q
            mqa_output_slice = output[:num_decode_tokens]

        mqa_q_nope, mqa_q_pe = mqa_q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Convert from (B, N, P) to (N, B, P)
        mqa_q_nope = mqa_q_nope.transpose(0, 1)

        if self.q_pad_num_heads is not None:
            B, N, L = mqa_q_pe.shape
            mqa_pe_padded = mqa_q_pe.new_empty((B, self.q_pad_num_heads, L))
            mqa_pe_padded.resize_((B, N, L))
            mqa_pe_padded.copy_(mqa_q_pe)
            mqa_q_pe = mqa_pe_padded

        if self.is_aiter_triton_fp4_bmm_enabled:
            from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

            mqa_ql_nope = batched_gemm_a16wfp4(
                mqa_q_nope,
                self.W_K,
                self.W_K_scale,
                transpose_bm=True,
                prequant=True,
                y_scale=self._q_scale if fp8_attention else None,
            )
        elif self.is_aiter_triton_fp8_bmm_enabled:
            # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
            mqa_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                mqa_q_nope,
                self.W_K,
                self.W_K_scale,
                group_size=128,
                transpose_bm=True,
            )
        else:
            # Pads the head_dim if necessary (for the underlying kernel)
            N, B, P = mqa_q_nope.shape
            _, _, L = self.W_UK_T.shape

            if self.q_pad_num_heads is not None:
                mqa_ql_nope = mqa_q_nope.new_empty((self.q_pad_num_heads, B, L))
                mqa_ql_nope.resize_((N, B, L))
            else:
                mqa_ql_nope = mqa_q_nope.new_empty((N, B, L))

            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            torch.bmm(mqa_q_nope, self.W_UK_T, out=mqa_ql_nope)

            # Convert from (N, B, L) to (B, N, L)
            mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

        if fp8_attention:
            assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
            assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
            mqa_q = self._decode_concat_quant_fp8_op(
                mqa_ql_nope, mqa_q_pe, self._q_scale
            )
        else:
            mqa_q = (mqa_ql_nope, mqa_q_pe)
        if self.impl.dcp_world_size > 1:
            assert not fp8_attention, "DCP does not support fp8 kvcache yet."
            # concatenate mqa_ql_nope and mqa_q_pe -> (B, N, L + P)
            mqa_q = torch.cat(mqa_q, dim=-1)
            # all-gather mqa_q along the head dim.
            mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

        # call decode attn
        attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

        # correct dcp attn_out with lse.
        if self.impl.dcp_world_size > 1:
            attn_out = cp_lse_ag_out_rs(
                attn_out,
                lse,
                get_dcp_group(),
                is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
            )

        # v_up projection
        self._v_up_proj(attn_out, out=mqa_output_slice)
    return output_padded

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/attention/mla_attention.py
def get_attn_backend(self) -> type[AttentionBackend]:
    return self.attn_backend

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/mla_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    kv_cache_dtype = kv_cache_dtype_str_to_dtype(
        self.kv_cache_dtype, vllm_config.model_config
    )
    return MLAAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=1,
        head_size=self.head_size,
        dtype=kv_cache_dtype,
        cache_dtype_str=vllm_config.cache_config.cache_dtype,
    )

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/model_executor/layers/attention/mla_attention.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    # We currently do not have the quantized bmm's needed for `W_UV` and
    # `W_UK_T`, so we just store fp16/bf16 copies and perform the bmm's in
    # 16-bit; the extra memory overhead of this is fairly low.
    kv_b_proj_weight = get_and_maybe_dequant_weights(
        self.kv_b_proj, out_dtype=act_dtype
    ).T

    assert kv_b_proj_weight.shape == (
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
    ), (
        f"{kv_b_proj_weight.shape=}, "
        f"{self.kv_lora_rank=}, "
        f"{self.num_heads=}, "
        f"{self.qk_nope_head_dim=}, "
        f"{self.v_head_dim=}"
    )
    kv_b_proj_weight = kv_b_proj_weight.view(
        self.kv_lora_rank,
        self.num_heads,
        self.qk_nope_head_dim + self.v_head_dim,
    )

    W_UK, W_UV = kv_b_proj_weight.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1
    )

    # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
    if self.is_aiter_triton_fp4_bmm_enabled:
        from vllm.model_executor.layers.quantization.quark.utils import (
            quark_quantize_weight_to_mxfp4,
        )

        self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
        # Convert from (L, N, P) to (N, L, P)
        self.W_K = self.W_K.transpose(0, 1)
        self.W_K_scale = self.W_K_scale.transpose(0, 1)

        self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(
            W_UV.permute(1, 2, 0)
        )
    elif self.is_aiter_triton_fp8_bmm_enabled:
        W_K = W_UK.transpose(0, 1)  # (L, N, P) -> (N, L, P), e.g. (16, 512, 128)
        W_V = W_UV.permute(1, 2, 0)  # (L, N, V) -> (N, V, L), e.g. (16, 128, 512)
        self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
            W_K, dtype=current_platform.fp8_dtype()
        )
        self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
            W_V, dtype=current_platform.fp8_dtype()
        )

        # The kernel operates on non-padded inputs, so pre-compile the Triton
        # kernel to avoid runtime compilation for unseen batch sizes.
        # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
        # On DS-R1, this step adds roughly 50s to the model loading time.
        max_batch_size = 1024  # [ToDo] Find the optimal upper limit
        pre_compilation_list = list(range(1, max_batch_size + 1))
        if is_global_first_rank():
            pre_compilation_list = tqdm(
                pre_compilation_list,
                desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                total=max_batch_size,
            )

        for m in pre_compilation_list:
            x = torch.empty(
                (self.W_K.shape[0], m, self.W_K.shape[2]),
                dtype=torch.bfloat16,
                device=self.W_K.device,
            )
            rocm_aiter_ops.triton_fp8_bmm(
                x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
            )

            x = torch.empty(
                (self.W_V.shape[0], m, self.W_V.shape[2]),
                dtype=torch.bfloat16,
                device=self.W_V.device,
            )
            rocm_aiter_ops.triton_fp8_bmm(
                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
            )
    else:
        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)

    # If we should not load quant weights, we initialize the scales to 1.0
    # as the default value. See [Note: Register q/k/v/prob scales in state dict]
    # for more details.
    quant_method = (
        self.quant_config.get_quant_method(self, prefix=self.layer_name)
        if self.quant_config
        else None
    )
    if not should_load_quant_weights(quant_method):
        set_default_quant_scales(self, register_buffer=False)
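
As a quick illustration of the unquantized branch above: the dequantized `kv_b_proj` weight is viewed as `(kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)`, split into `W_UK` and `W_UV`, and then permuted so the batched matmuls in `forward_impl` and `_v_up_proj` can run per head. A minimal sketch with placeholder sizes:

import torch

# Placeholder sizes, purely illustrative:
kv_lora_rank, num_heads = 512, 16
qk_nope_head_dim, v_head_dim = 128, 128

# Stand-in for the dequantized, already-transposed kv_b_proj weight.
kv_b_proj_weight = torch.randn(
    kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
).view(kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)

W_UK, W_UV = kv_b_proj_weight.split([qk_nope_head_dim, v_head_dim], dim=-1)

W_UV_per_head = W_UV.transpose(0, 1)  # (L, N, V) -> (N, L, V)
W_UK_T = W_UK.permute(1, 2, 0)        # (L, N, P) -> (N, P, L)
print(W_UK_T.shape, W_UV_per_head.shape)
# torch.Size([16, 128, 512]) torch.Size([16, 512, 128])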

MMEncoderAttention

Bases: CustomOp

Multi-headed attention without any cache, used for the multimodal encoder.

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder."""

    # --8<-- [end:mm_encoder_attn]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor.
            num_kv_heads: number of kv heads.
            prefix: This has no effect; it is only here to make it easier to
                    swap between Attention and MultiHeadAttention.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self._fa_version = (
            get_flash_attn_version() if self.is_flash_attn_backend else None
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        return True

    def maybe_reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.is_flash_attn_backend, (
            "XPU only supports FLASH_ATTN for vision attention."
        )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
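
The `cu_seqlens` argument used by the flash-attention and SDPA paths above follows the usual varlen convention: cumulative token counts with a leading zero, so sequence `i` occupies `cu_seqlens[i]:cu_seqlens[i+1]` of the flattened batch. A small sketch of how a caller might build it (the per-image token counts here are made up):

import itertools

import torch

# Hypothetical: three images producing 196, 64, and 256 patch tokens each.
seq_lens = [196, 64, 256]
cu_seqlens = torch.tensor([0, *itertools.accumulate(seq_lens)], dtype=torch.int32)
max_seqlen = torch.tensor(max(seq_lens))
print(cu_seqlens)  # tensor([0, 196, 260, 516], dtype=torch.int32)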

_fa_version instance-attribute

_fa_version = (
    get_flash_attn_version()
    if is_flash_attn_backend
    else None
)

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=head_size, dtype=dtype
)

head_size instance-attribute

head_size = head_size

is_flash_attn_backend instance-attribute

is_flash_attn_backend = attn_backend in {
    FLASH_ATTN,
    ROCM_AITER_FA,
}

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = (
    num_heads if num_kv_heads is None else num_kv_heads
)

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = scale

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None

Parameters:

Name Type Description Default
num_heads int

number of attention heads per partition.

required
head_size int

hidden_size per attention head.

required
scale float | None

scale factor.

None
num_kv_heads int | None

number of kv heads.

None
prefix str

This has no effect; it is only here to make it easier to swap between Attention and MultiHeadAttention.

''
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None:
    """
    Args:
        num_heads: number of attention heads per partition.
        head_size: hidden_size per attention head.
        scale: scale factor.
        num_kv_heads: number of kv heads.
        prefix: This has no effect; it is only here to make it easier to
                swap between Attention and MultiHeadAttention.
    """
    super().__init__()

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = scale
    self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
    self.layer_name = prefix

    assert self.num_heads % self.num_kv_heads == 0, (
        f"num_heads ({self.num_heads}) is not "
        f"divisible by num_kv_heads ({self.num_kv_heads})"
    )
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    # During model initialization, the default dtype is set as the model
    # weight and activation dtype.
    dtype = torch.get_default_dtype()

    # Get device-specific vision attention backend.
    self.attn_backend = get_vit_attn_backend(
        head_size=head_size,
        dtype=dtype,
    )

    self.is_flash_attn_backend = self.attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }

    self._fa_version = (
        get_flash_attn_version() if self.is_flash_attn_backend else None
    )

    logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

_forward_fa

_forward_fa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_fa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    """Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)
    """
    assert (cu_seqlens is not None and max_seqlen is not None) or (
        cu_seqlens is None and max_seqlen is None
    ), "cu_seqlens and max_seqlen should be both set or both None."

    bsz, q_len = query.size()[:2]
    kv_len = key.size(1)
    is_reshaped = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, bsz, q_len, kv_len
    )

    output = vit_flash_attn_wrapper(
        q=query,
        k=key,
        v=value,
        batch_size=bsz,
        is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        fa_version=self._fa_version,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    if is_reshaped:
        output = output.reshape(bsz, q_len, -1)
    return output

_forward_sdpa

_forward_sdpa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)
    """
    bsz, q_len = query.size()[:2]
    kv_len = key.size(1)
    is_reshaped = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, bsz, q_len, kv_len
    )

    output = vit_torch_sdpa_wrapper(
        q=query,
        k=key,
        v=value,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
    )
    if is_reshaped:
        output = output.reshape(bsz, q_len, -1)
    return output

enabled classmethod

enabled() -> bool
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@classmethod
def enabled(cls) -> bool:
    return True

forward_cpu

forward_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cuda(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    if self.is_flash_attn_backend:
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
    elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
        return self._forward_sdpa(query, key, value, cu_seqlens)
    else:
        raise ValueError(
            f"Unsupported multi-modal encoder attention backend for CUDA: "
            f"{self.attn_backend}."
        )

forward_native

forward_native(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_native(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_xpu

forward_xpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_xpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    assert self.is_flash_attn_backend, (
        "XPU only supports FLASH_ATTN for vision attention."
    )
    return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

maybe_reshape_qkv_to_4d

maybe_reshape_qkv_to_4d(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[Tensor, Tensor, Tensor]

Reshape query, key, value to 4D tensors: (batch_size, seq_len, num_heads, head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def maybe_reshape_qkv_to_4d(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reshape query, key, value to 4D tensors:
    (batch_size, seq_len, num_heads, head_size)
    """
    query = query.view(bsz, q_len, self.num_heads, self.head_size)
    key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
    value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

    if (num_repeat := self.num_queries_per_kv) > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_repeat, dim=2)
        value = torch.repeat_interleave(value, num_repeat, dim=2)

    return query, key, value
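
The MQA/GQA handling above simply replicates each KV head so that keys and values line up with the query heads. A small standalone example of that expansion (head counts are illustrative):

import torch

bsz, kv_len, head_size = 2, 10, 64
num_heads, num_kv_heads = 8, 2
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

key = torch.randn(bsz, kv_len, num_kv_heads, head_size)
key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
print(key.shape)  # torch.Size([2, 10, 8, 64])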

StaticSinkAttention

Bases: Attention, CustomOp

Attention with static sink tokens

Source code in vllm/model_executor/layers/attention/static_sink_attention.py
@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
    """
    Attention with static sink tokens
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,  # type: ignore[arg-type]
            sink_len=sink_len,
        )
        Attention.__init__(
            self=self,
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )
        CustomOp.__init__(self)

        self.sink_len = sink_len
        self.block_size = block_size
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

        return super().forward(query, key, value, output_shape)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        return self.forward_native(query, key, value, output_shape)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def populate_sink_kv(self, self_kv_cache):
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            device=torch.cuda.current_device(),
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )
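
One detail worth noting from `populate_sink_kv` above: the sink K/V entries are written exactly once into a fixed slot range that starts at `block_size`, i.e. at the beginning of the second cache block. A tiny numeric sketch (the sizes are illustrative):

import torch

block_size, sink_len = 16, 4  # illustrative values
sink_kv_slot_mapping = torch.arange(
    block_size, sink_len + block_size, dtype=torch.long
)
print(sink_kv_slot_mapping)  # tensor([16, 17, 18, 19])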

block_size instance-attribute

block_size = block_size

sink_key instance-attribute

sink_key = None

sink_len instance-attribute

sink_len = sink_len

sink_populated instance-attribute

sink_populated = False

sink_value instance-attribute

sink_value = None

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    sink_len: int,
    attn_backend: type[AttentionBackend] | None = None,
    cache_config: CacheConfig | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    sink_len: int,
    attn_backend: type[AttentionBackend] | None = None,
    cache_config: CacheConfig | None = None,
    **kwargs,
):
    dtype = torch.get_default_dtype()

    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16

    if attn_backend is not None:
        underlying_attn_backend = attn_backend
    else:
        underlying_attn_backend = get_attn_backend(
            head_size, dtype, kv_cache_dtype, block_size
        )
    attn_backend = create_static_sink_attention_backend(
        underlying_attn_backend,  # type: ignore[arg-type]
        sink_len=sink_len,
    )
    Attention.__init__(
        self=self,
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=attn_backend,
        **kwargs,
    )
    CustomOp.__init__(self)

    self.sink_len = sink_len
    self.block_size = block_size
    self.sink_populated = False
    self.sink_key = None
    self.sink_value = None
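
A hedged sketch of how a model layer might construct this module. It assumes the code runs under vLLM's usual model-building context (Attention.__init__ reads the current vLLM config); the class name, the sink parameter layout, and the forwarded keyword arguments such as num_kv_heads and prefix are illustrative, not prescribed by the API.

import torch
from torch import nn

from vllm.config import CacheConfig
from vllm.model_executor.layers.attention import StaticSinkAttention

class SinkDecoderAttention(nn.Module):
    """Illustrative wrapper, not part of vLLM."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        sink_len: int,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        head_size = hidden_size // num_heads
        self.attn = StaticSinkAttention(
            num_heads=num_heads,
            head_size=head_size,
            scale=head_size**-0.5,
            sink_len=sink_len,
            cache_config=cache_config,
            # Extra kwargs are forwarded to Attention.__init__.
            num_kv_heads=num_kv_heads,
            prefix=f"{prefix}.attn",
        )
        # Learned per-layer sink K/V; the exact shape expected by
        # update_sink_kv depends on the cache kernel and is assumed here.
        self.sink_key = nn.Parameter(torch.zeros(sink_len, num_kv_heads, head_size))
        self.sink_value = nn.Parameter(torch.zeros(sink_len, num_kv_heads, head_size))

At runtime the model would pass these parameters to update_sink_kv before invoking the attention layer (see the sketch at the end of this page).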

forward

forward(*args, **kwargs)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward(self, *args, **kwargs):
    return self._forward_method(*args, **kwargs)

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward_cuda(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    return self.forward_native(query, key, value, output_shape)

forward_native

forward_native(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward_native(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    assert self.sink_key is not None and self.sink_value is not None, (
        "sink_key and sink_value have not been prepared"
    )
    if not self.sink_populated:
        forward_context: ForwardContext = get_forward_context()
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

    return super().forward(query, key, value, output_shape)

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # The block size may be updated after model loading, so refresh it here
    block_size = vllm_config.cache_config.block_size
    # Should not be called for enc-dec or encoder-only attention.
    assert self.attn_type == AttentionType.DECODER

    return SinkFullAttentionSpec(
        block_size=block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        head_size_v=self.head_size_v,
        sink_len=self.sink_len,
        dtype=self.kv_cache_torch_dtype,
    )
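
For intuition about the extra cache footprint the sink implies, a back-of-envelope calculation: it assumes a KV block holds block_size * num_kv_heads * (head_size + head_size_v) * dtype_size bytes, mirroring vLLM's regular full-attention specs; the concrete numbers are made up, and SinkFullAttentionSpec may account for the sink differently.

import math
import torch

block_size = 16
num_kv_heads = 8
head_size = 128      # key head size
head_size_v = 128    # value head size (may differ from head_size)
sink_len = 32
dtype = torch.bfloat16

bytes_per_token = num_kv_heads * (head_size + head_size_v) * dtype.itemsize
bytes_per_block = block_size * bytes_per_token          # 65536
sink_blocks = math.ceil(sink_len / block_size)          # 2
sink_bytes_per_layer = sink_blocks * bytes_per_block    # 131072 (128 KiB)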

populate_sink_kv

populate_sink_kv(self_kv_cache)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def populate_sink_kv(self, self_kv_cache):
    sink_kv_slot_mapping = torch.arange(
        self.block_size,
        self.sink_len + self.block_size,
        device=torch.cuda.current_device(),
        dtype=torch.long,
    )
    triton_reshape_and_cache_flash_diffkv(
        self.sink_key,
        self.sink_value,
        self_kv_cache,
        sink_kv_slot_mapping,
        self.kv_cache_dtype,
        self._k_scale,
        self._v_scale,
    )
    # We only populate the sink_key and sink_value once
    self.sink_populated = True
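
Note that the slot mapping starts at block_size rather than 0, so the sink entries occupy slots [block_size, block_size + sink_len), i.e. they begin at the second physical block; block 0 is typically reserved as the null/padding block in vLLM's paged KV cache. A small standalone illustration (pure PyTorch, no cache or CUDA device involved) of which blocks and in-block offsets those slots correspond to:

import torch

block_size = 16
sink_len = 32

# Same expression as above, minus the CUDA device placement.
sink_kv_slot_mapping = torch.arange(block_size, sink_len + block_size, dtype=torch.long)

block_ids = sink_kv_slot_mapping // block_size  # physical block of each slot
offsets = sink_kv_slot_mapping % block_size     # position within that block

print(block_ids.unique().tolist())  # [1, 2] -> the sink spans blocks 1 and 2
print(offsets[:4].tolist())         # [0, 1, 2, 3]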

update_sink_kv

update_sink_kv(sink_key, sink_value) -> None
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def update_sink_kv(self, sink_key, sink_value) -> None:
    # Store the sink key/value tensors provided by the model; they are
    # written into the KV cache lazily, on the first forward pass.
    self.sink_key = sink_key
    self.sink_value = sink_value
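
Putting the pieces together, the expected call pattern from model code is roughly: hand the sink K/V to the layer, then run the usual attention forward; the actual cache write happens lazily, once, inside forward_native. A hedged sketch (the helper name and tensor shapes are illustrative, and the call must happen inside a live vLLM forward context):

import torch

from vllm.model_executor.layers.attention import StaticSinkAttention

def attend_with_sink(
    layer: StaticSinkAttention,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sink_k: torch.Tensor,
    sink_v: torch.Tensor,
) -> torch.Tensor:
    # Make the sink K/V available to the layer; they are written into the
    # KV cache at most once, on the first forward pass.
    layer.update_sink_kv(sink_k, sink_v)
    # CustomOp dispatch routes this call to forward_cuda/forward_native,
    # which then defer to the wrapped Attention.forward.
    return layer(q, k, v)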