Skip to content

vllm.model_executor.layers.attention.mm_encoder_attention

logger module-attribute

logger = init_logger(__name__)

MMEncoderAttention

Bases: CustomOp

Multi-headed attention without any cache, used for multimodal encoder.

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder.

    The attention backend is selected once in ``__init__`` via
    ``get_vit_attn_backend``. The ``forward_*`` methods then route to the
    flash-attention path (``_forward_fa``) or the SDPA path
    (``_forward_sdpa``).
    """

    # --8<-- [end:mm_encoder_attn]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
    ) -> None:
        """Record the head layout and pick the vision attention backend.

        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor, later passed to the attention wrappers.
            num_kv_heads: number of kv heads. Falls back to ``num_heads``
                    when ``None``; must evenly divide ``num_heads``.
            prefix: Stored as ``self.layer_name`` and not used otherwise in
                    this class. It is accepted to make it easier to swap
                    between Attention and MultiHeadAttention.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        # Without an explicit kv head count, use one kv head per query head.
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.layer_name = prefix

        heads_are_divisible = self.num_heads % self.num_kv_heads == 0
        assert heads_are_divisible, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Get device-specific vision attention backend. During model
        # initialization, the default dtype is set as the model weight and
        # activation dtype, so it is what the backend lookup is keyed on.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=torch.get_default_dtype(),
        )

        flash_backends = {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }
        self.is_flash_attn_backend = self.attn_backend in flash_backends

        # The flash-attention version is only looked up for flash backends.
        if self.is_flash_attn_backend:
            self._fa_version = get_flash_attn_version()
        else:
            self._fa_version = None

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        """Report this op as enabled; always returns ``True``."""
        return True

    def maybe_reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """View query, key and value as 4D tensors.

        The query is viewed as (bsz, q_len, num_heads, head_size). Key and
        value are viewed as (bsz, kv_len, num_kv_heads, head_size) and, when
        ``num_queries_per_kv`` is greater than one, repeated along the head
        dimension so that they end up with ``num_heads`` heads as well.

        Returns:
            The ``(query, key, value)`` tuple of 4D tensors.
        """
        head_dim = self.head_size
        kv_shape = (bsz, kv_len, self.num_kv_heads, head_dim)

        query = query.view(bsz, q_len, self.num_heads, head_dim)
        key = key.view(kv_shape)
        value = value.view(kv_shape)

        group_size = self.num_queries_per_kv
        if group_size > 1:
            # Handle MQA and GQA: every kv head is shared by `group_size`
            # query heads, so duplicate it that many times.
            key = key.repeat_interleave(group_size, dim=2)
            value = value.repeat_interleave(group_size, dim=2)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention through ``vit_torch_sdpa_wrapper``.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        ``cu_seqlens`` is passed through to the wrapper unchanged. When the
        query was not 4D on entry, the result is reshaped to
        (batch_size, q_len, -1); otherwise the wrapper output is returned
        as-is.
        """
        batch, q_tokens = query.shape[:2]
        kv_tokens = key.shape[1]
        # Must be captured before `query` is rebound to its 4D view.
        needs_flatten = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, batch, q_tokens, kv_tokens
        )

        attn_out = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
        )
        if not needs_flatten:
            return attn_out
        return attn_out.reshape(batch, q_tokens, -1)

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention through ``vit_flash_attn_wrapper``.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        ``cu_seqlens`` and ``max_seqlen`` must be given together or omitted
        together; both are passed through to the wrapper. When the query was
        not 4D on entry, the result is reshaped to (batch_size, q_len, -1);
        otherwise the wrapper output is returned as-is.
        """
        has_cu_seqlens = cu_seqlens is not None
        has_max_seqlen = max_seqlen is not None
        assert has_cu_seqlens == has_max_seqlen, (
            "cu_seqlens and max_seqlen should be both set or both None."
        )

        batch, q_tokens = query.shape[:2]
        kv_tokens = key.shape[1]
        # Must be captured before `query` is rebound to its 4D view.
        needs_flatten = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, batch, q_tokens, kv_tokens
        )

        uses_rocm_aiter = self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        attn_out = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=batch,
            is_rocm_aiter=uses_rocm_aiter,
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if not needs_flatten:
            return attn_out
        return attn_out.reshape(batch, q_tokens, -1)

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Native forward: always the SDPA path; ``max_seqlen`` is ignored."""
        attn_out = self._forward_sdpa(query, key, value, cu_seqlens)
        return attn_out

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CUDA forward: flash-attention path if selected, else SDPA.

        Raises:
            ValueError: if the selected backend is neither a flash-attention
                backend nor ``TORCH_SDPA``.
        """
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        if self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        raise ValueError(
            f"Unsupported multi-modal encoder attention backend for CUDA: "
            f"{self.attn_backend}."
        )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CPU forward: always the SDPA path; ``max_seqlen`` is ignored."""
        attn_out = self._forward_sdpa(query, key, value, cu_seqlens)
        return attn_out

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """XPU forward: requires a flash-attention backend to be selected."""
        unsupported_msg = "XPU only supports FLASH_ATTN for vision attention."
        assert self.is_flash_attn_backend, unsupported_msg
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

_fa_version instance-attribute

_fa_version = (
    get_flash_attn_version()
    if is_flash_attn_backend
    else None
)

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=head_size, dtype=dtype
)

head_size instance-attribute

head_size = head_size

is_flash_attn_backend instance-attribute

is_flash_attn_backend = attn_backend in {
    FLASH_ATTN,
    ROCM_AITER_FA,
}

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = (
    num_heads if num_kv_heads is None else num_kv_heads
)

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = scale

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None

Parameters:

Name Type Description Default
num_heads int

number of attention heads per partition.

required
head_size int

hidden_size per attention head.

required
scale float | None

scale factor.

None
num_kv_heads int | None

number of kv heads.

None
prefix str

This is stored as `layer_name` but otherwise has no effect here; it is accepted only to make it easier to swap between Attention and MultiHeadAttention

''
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None:
    """Record the head layout and pick the vision attention backend.

    Args:
        num_heads: number of attention heads per partition.
        head_size: hidden_size per attention head.
        scale: scale factor; stored unchanged as ``self.scale``.
        num_kv_heads: number of kv heads. Falls back to ``num_heads``
                when ``None``; must evenly divide ``num_heads``.
        prefix: Stored as ``self.layer_name`` and not used otherwise in
                this method. It is accepted to make it easier to swap
                between Attention and MultiHeadAttention.
    """
    super().__init__()

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = scale
    # Without an explicit kv head count, use one kv head per query head.
    self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
    self.layer_name = prefix

    heads_are_divisible = self.num_heads % self.num_kv_heads == 0
    assert heads_are_divisible, (
        f"num_heads ({self.num_heads}) is not "
        f"divisible by num_kv_heads ({self.num_kv_heads})"
    )
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    # Get device-specific vision attention backend. During model
    # initialization, the default dtype is set as the model weight and
    # activation dtype, so it is what the backend lookup is keyed on.
    self.attn_backend = get_vit_attn_backend(
        head_size=head_size,
        dtype=torch.get_default_dtype(),
    )

    flash_backends = {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }
    self.is_flash_attn_backend = self.attn_backend in flash_backends

    # The flash-attention version is only looked up for flash backends.
    if self.is_flash_attn_backend:
        self._fa_version = get_flash_attn_version()
    else:
        self._fa_version = None

    logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

_forward_fa

_forward_fa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_fa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    """Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)
    """
    assert (cu_seqlens is not None and max_seqlen is not None) or (
        cu_seqlens is None and max_seqlen is None
    ), "cu_seqlens and max_seqlen should be both set or both None."

    bsz, q_len = query.size()[:2]
    kv_len = key.size(1)
    is_reshaped = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, bsz, q_len, kv_len
    )

    output = vit_flash_attn_wrapper(
        q=query,
        k=key,
        v=value,
        batch_size=bsz,
        is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        fa_version=self._fa_version,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    if is_reshaped:
        output = output.reshape(bsz, q_len, -1)
    return output

_forward_sdpa

_forward_sdpa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run attention through ``vit_torch_sdpa_wrapper``.

    Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    ``cu_seqlens`` is passed through to the wrapper unchanged. When the
    query was not 4D on entry, the result is reshaped to
    (batch_size, q_len, -1); otherwise the wrapper output is returned as-is.
    """
    batch, q_tokens = query.shape[:2]
    kv_tokens = key.shape[1]
    # Must be captured before `query` is rebound to its 4D view.
    needs_flatten = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, batch, q_tokens, kv_tokens
    )

    attn_out = vit_torch_sdpa_wrapper(
        q=query,
        k=key,
        v=value,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
    )
    if not needs_flatten:
        return attn_out
    return attn_out.reshape(batch, q_tokens, -1)

enabled classmethod

enabled() -> bool
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@classmethod
def enabled(cls) -> bool:
    """Report this op as enabled; always returns ``True``."""
    return True

forward_cpu

forward_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cuda(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    if self.is_flash_attn_backend:
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
    elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
        return self._forward_sdpa(query, key, value, cu_seqlens)
    else:
        raise ValueError(
            f"Unsupported multi-modal encoder attention backend for CUDA: "
            f"{self.attn_backend}."
        )

forward_native

forward_native(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_native(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_xpu

forward_xpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_xpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    assert self.is_flash_attn_backend, (
        "XPU only supports FLASH_ATTN for vision attention."
    )
    return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

maybe_reshape_qkv_to_4d

maybe_reshape_qkv_to_4d(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[Tensor, Tensor, Tensor]

Reshape query, key, value to 4D tensors: (batch_size, seq_len, num_heads, head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def maybe_reshape_qkv_to_4d(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """View query, key and value as 4D tensors.

    The query is viewed as (bsz, q_len, num_heads, head_size). Key and
    value are viewed as (bsz, kv_len, num_kv_heads, head_size) and, when
    ``num_queries_per_kv`` is greater than one, repeated along the head
    dimension so that they end up with ``num_heads`` heads as well.

    Returns:
        The ``(query, key, value)`` tuple of 4D tensors.
    """
    head_dim = self.head_size
    kv_shape = (bsz, kv_len, self.num_kv_heads, head_dim)

    query = query.view(bsz, q_len, self.num_heads, head_dim)
    key = key.view(kv_shape)
    value = value.view(kv_shape)

    group_size = self.num_queries_per_kv
    if group_size > 1:
        # Handle MQA and GQA: every kv head is shared by `group_size`
        # query heads, so duplicate it that many times.
        key = key.repeat_interleave(group_size, dim=2)
        value = value.repeat_interleave(group_size, dim=2)

    return query, key, value