Skip to content

vllm.model_executor.models.funaudiochat

Inference-only FunAudioChat model compatible with HuggingFace weights.

FunAudioChat is a Qwen3 text model augmented with
  • a continuous audio encoder (Whisper-mel frontend + transformer)
  • a discrete audio encoder (speech tokenizer + projector)

In the HF implementation, audio features are scattered into the `<|AUDIO|>` token positions via `inputs_embeds`, while `position_ids` (used for RoPE) remain standard 1D positions.

FunAudioChatAudioAttention

Bases: Module

Multi-headed attention used inside the continuous audio tower.

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatAudioAttention(nn.Module):
    """Multi-headed attention used inside the continuous audio tower."""

    def __init__(self, config: Any):
        super().__init__()
        self.embed_dim = int(config.d_model)
        self.total_num_heads = int(config.encoder_attention_heads)
        self.dropout = float(getattr(config, "attention_dropout", 0.0))
        self.head_dim = self.embed_dim // self.total_num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.config = config

        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got embed_dim={self.embed_dim}, "
                f"num_heads={self.total_num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_decoder = False
        self.is_causal = False

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
        )
        self.num_heads = self.qkv_proj.num_heads
        self.num_kv_heads = self.qkv_proj.num_kv_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = MMEncoderAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            prefix="funaudiochat_audio_tower.attn",
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        with torch.no_grad():
            if self.qkv_proj.bias is not None:
                # HF FunAudioChat uses bias=False for k_proj. Ensure the missing
                # shard starts as zeros, while allowing q/v shards to load.
                self.qkv_proj.bias.zero_()

        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        del kwargs
        del attention_mask
        seq_length, _ = hidden_states.size()

        qkv, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        max_seqlen: torch.Tensor | None = None
        if cu_seqlens is not None:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        attn_output = self.attn(
            query_states.reshape(1, seq_length, self.q_size),
            key_states.reshape(1, seq_length, self.kv_size),
            value_states.reshape(1, seq_length, self.kv_size),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        ).reshape(seq_length, -1)

        output, _ = self.out_proj(attn_output)
        return output

attention_dropout instance-attribute

attention_dropout = 0.0

attn instance-attribute

attn = MMEncoderAttention(
    num_heads=num_heads,
    head_size=head_dim,
    scale=scaling,
    num_kv_heads=num_kv_heads,
    prefix="funaudiochat_audio_tower.attn",
)

config instance-attribute

config = config

dropout instance-attribute

dropout = float(getattr(config, 'attention_dropout', 0.0))

embed_dim instance-attribute

embed_dim = int(d_model)

head_dim instance-attribute

head_dim = embed_dim // total_num_heads

is_causal instance-attribute

is_causal = False

is_decoder instance-attribute

is_decoder = False

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

num_heads instance-attribute

num_heads = num_heads

num_key_value_groups instance-attribute

num_key_value_groups = 1

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

out_proj instance-attribute

out_proj = RowParallelLinear(
    embed_dim, embed_dim, bias=True
)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    embed_dim, head_dim, total_num_heads, bias=True
)

scaling instance-attribute

scaling = head_dim ** -0.5

total_num_heads instance-attribute

total_num_heads = int(encoder_attention_heads)

__init__

__init__(config: Any)
Source code in vllm/model_executor/models/funaudiochat.py
def __init__(self, config: Any):
    """Build the fused QKV projection, attention backend and output projection.

    Raises:
        ValueError: If ``config.d_model`` is not divisible by
            ``config.encoder_attention_heads``.
    """
    super().__init__()
    # Hyper-parameters read from the audio config.
    self.config = config
    self.embed_dim = int(config.d_model)
    self.total_num_heads = int(config.encoder_attention_heads)
    self.dropout = float(getattr(config, "attention_dropout", 0.0))
    self.num_key_value_groups = 1  # needed for eager attention

    head_dim, leftover = divmod(self.embed_dim, self.total_num_heads)
    self.head_dim = head_dim
    if leftover:
        raise ValueError(
            "embed_dim must be divisible by num_heads "
            f"(got embed_dim={self.embed_dim}, "
            f"num_heads={self.total_num_heads})."
        )
    self.scaling = self.head_dim**-0.5

    # Bidirectional encoder attention with no attention dropout.
    self.attention_dropout = 0.0
    self.is_decoder = False
    self.is_causal = False

    # Head counts are read back from the fused layer, presumably the
    # per-rank (tensor-parallel) values -- TODO confirm.
    self.qkv_proj = QKVParallelLinear(
        self.embed_dim,
        self.head_dim,
        self.total_num_heads,
        bias=True,
    )
    self.num_heads = self.qkv_proj.num_heads
    self.num_kv_heads = self.qkv_proj.num_kv_heads
    self.q_size = self.head_dim * self.num_heads
    self.kv_size = self.head_dim * self.num_kv_heads

    self.attn = MMEncoderAttention(
        num_heads=self.num_heads,
        head_size=self.head_dim,
        scale=self.scaling,
        num_kv_heads=self.num_kv_heads,
        prefix="funaudiochat_audio_tower.attn",
    )
    self.out_proj = RowParallelLinear(
        self.embed_dim,
        self.embed_dim,
        bias=True,
    )

forward

forward(
    hidden_states: Tensor,
    cu_seqlens: Tensor | None = None,
    attention_mask: Tensor | None = None,
    **kwargs: object,
) -> Tensor
Source code in vllm/model_executor/models/funaudiochat.py
def forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor:
    del kwargs
    del attention_mask
    seq_length, _ = hidden_states.size()

    qkv, _ = self.qkv_proj(hidden_states)
    query_states, key_states, value_states = qkv.split(
        [self.q_size, self.kv_size, self.kv_size], dim=-1
    )

    max_seqlen: torch.Tensor | None = None
    if cu_seqlens is not None:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

    attn_output = self.attn(
        query_states.reshape(1, seq_length, self.q_size),
        key_states.reshape(1, seq_length, self.kv_size),
        value_states.reshape(1, seq_length, self.kv_size),
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    ).reshape(seq_length, -1)

    output, _ = self.out_proj(attn_output)
    return output

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/funaudiochat.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module.

    Tensors whose name contains ``q_proj``/``k_proj``/``v_proj`` go into the
    matching shard of the fused ``qkv_proj``. Everything else is loaded by
    name, except ``.bias`` entries with no matching parameter, which are
    skipped.

    Returns:
        Names (after renaming) of the parameters that received data.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    params_dict = dict(self.named_parameters())

    # HF FunAudioChat uses bias=False for k_proj, so the k shard of the fused
    # bias must start (and stay) zero; q/v shards are loaded below.
    qkv_bias = self.qkv_proj.bias
    if qkv_bias is not None:
        with torch.no_grad():
            qkv_bias.zero_()

    loaded_params: set[str] = set()
    for ckpt_name, tensor in weights:
        target = ckpt_name
        shard_id = None
        for fused_name, shard_name, sid in stacked_params_mapping:
            if shard_name in ckpt_name:
                target = ckpt_name.replace(shard_name, fused_name)
                shard_id = sid
                break

        if shard_id is None:
            # Skip loading extra bias for GPTQ models.
            if target.endswith(".bias") and target not in params_dict:
                continue
            param = params_dict[target]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)
        else:
            param = params_dict[target]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor, shard_id)

        loaded_params.add(target)

    return loaded_params

FunAudioChatAudioEncoder

Bases: Module

Continuous audio tower.

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatAudioEncoder(nn.Module):
    """Continuous audio tower.

    The tower is a conv frontend over mel features, a stack of
    ``FunAudioChatAudioEncoderLayer`` blocks, 2x average pooling, a
    LayerNorm and a linear projection to ``config.output_dim``. Each item is
    split into chunks of at most ``2 * config.n_window`` frames before the
    conv frontend.
    """

    def __init__(self, config: Any):
        super().__init__()
        self.config = config

        embed_dim = int(config.d_model)
        self.num_mel_bins = int(config.num_mel_bins)
        self.max_source_positions = int(config.max_source_positions)
        self.embed_scale = (embed_dim**0.5) if bool(config.scale_embedding) else 1.0
        self.n_window = int(config.n_window)

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.layers = nn.ModuleList(
            [
                FunAudioChatAudioEncoderLayer(config)
                for _ in range(int(config.encoder_layers))
            ]
        )
        self.ln_post = nn.LayerNorm(embed_dim)
        self.avg_pooler = nn.AvgPool1d(2, stride=2)
        self.proj = nn.Linear(embed_dim, int(config.output_dim))
        self.positional_embedding = _SinusoidsPositionEmbedding(
            self.max_source_positions, embed_dim
        )

        # Present in HF weights even if unused during S2T.
        self.audio_bos_eos_token = nn.Embedding(2, int(config.output_dim))

    @property
    def dtype(self) -> torch.dtype:
        """Parameter dtype of the tower, taken from ``conv1``."""
        return self.conv1.weight.dtype

    def _prepare_attention_mask(
        self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor
    ) -> torch.Tensor | None:
        """Build an additive block-diagonal mask for a packed sequence.

        Returns ``None`` when ``config._attn_implementation`` is
        ``"flash_attention_2"``. Otherwise returns a
        ``(1, 1, seq, seq)`` tensor that is 0 inside each
        ``[cu_seqlens[i-1], cu_seqlens[i])`` block and the dtype's minimum
        everywhere else.
        """
        if getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            (1, 1, seq_length, seq_length),
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            start = int(cu_seqlens[i - 1].item())
            end = int(cu_seqlens[i].item())
            attention_mask[..., start:end, start:end] = 0
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: torch.Tensor,
        aftercnn_lens: torch.Tensor,
        speech_maxlen: int,
        **kwargs: object,
    ) -> BaseModelOutput:
        """Encode a batch of mel features with the continuous audio tower.

        Args:
            input_features: Features of all items concatenated along dim 1
                and split back by ``feature_lens``. Presumably
                ``(num_mel_bins, total_frames)`` -- TODO confirm with caller.
            feature_lens: Per-item frame counts; items with length 0 are
                treated as empty and yield all-zero outputs.
            aftercnn_lens: Per-item frame counts after the conv frontend,
                used to split the encoder output back into items.
            speech_maxlen: Time dimension of the returned tensor; longer
                outputs are truncated and shorter ones zero-padded.
            **kwargs: Forwarded to every encoder layer.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` has shape
            ``(batch, speech_maxlen, output_dim)``.

        Raises:
            RuntimeError: If ``speech_maxlen >= 7500`` and FlashAttention is
                not installed or not the selected encoder attention backend.
        """
        # For max-length audio (300s => ~7500 speech frames at 25Hz), the
        # Torch SDPA path can be prohibitively memory hungry (~O(n^2) inside the
        # longest chunks). Require FlashAttention for such inputs to avoid OOM
        # and performance cliffs.
        if int(speech_maxlen) >= 7500:
            if not _has_module("flash_attn"):
                raise RuntimeError(
                    "FunAudioChat long audio (~300s) requires FlashAttention-2 "
                    "for the continuous audio tower, but `flash_attn` is not "
                    "installed in the runtime environment."
                )
            if not getattr(
                self.layers[0].self_attn.attn, "is_flash_attn_backend", False
            ):
                raise RuntimeError(
                    "FunAudioChat long audio (~300s) requires FlashAttention for the "
                    "continuous audio tower, but the selected MM encoder attention "
                    "backend is not FlashAttention."
                )

        # Handle empty / invalid items (feature_lens == 0) without crashing.
        original_batch_size = int(feature_lens.size(0))
        device = input_features.device

        valid_mask = feature_lens > 0
        valid_indices = torch.where(valid_mask)[0]

        if valid_indices.numel() == 0:
            output_dim = int(self.proj.out_features)
            return BaseModelOutput(
                last_hidden_state=torch.zeros(
                    (original_batch_size, speech_maxlen, output_dim),
                    device=device,
                    dtype=self.proj.weight.dtype,
                )
            )

        # Copy lengths/indices to the host once; per-element ``.item()`` on a
        # device tensor would force a synchronization on every iteration.
        feature_lens_list = feature_lens.tolist()
        valid_index_list = valid_indices.tolist()

        input_features_list = input_features.split(feature_lens_list, dim=1)
        valid_input_features = torch.cat(
            [input_features_list[i] for i in valid_index_list], dim=1
        )
        valid_aftercnn_lens = aftercnn_lens[valid_mask]

        # Split each item into full chunks of ``2 * n_window`` frames plus a
        # trailing partial chunk when the length does not divide evenly.
        full_chunk_len = self.n_window * 2
        chunk_lengths_list: list[int] = []
        for i in valid_index_list:
            num_full, remainder = divmod(int(feature_lens_list[i]), full_chunk_len)
            chunk_lengths_list.extend([full_chunk_len] * num_full)
            if remainder:
                chunk_lengths_list.append(remainder)

        chunk_lengths = torch.tensor(
            chunk_lengths_list, dtype=torch.long, device=device
        )

        chunk_list = valid_input_features.split(chunk_lengths_list, dim=1)
        padded_feature, padded_mask, padded_mask_after_cnn = (
            self.padded_and_mask_function(
                chunk_list, chunk_lengths, padding_value=0, padding_side="right"
            )
        )

        padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
        padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)

        padded_embed = padded_embed + self.positional_embedding.positional_embedding[
            : padded_embed.shape[1], :
        ].unsqueeze(0).to(padded_embed.dtype)

        # Drop padded positions and pack all chunks into one token sequence;
        # cu_seqlens holds the cumulative chunk boundaries.
        hidden_states = padded_embed[padded_mask_after_cnn]
        cu_seqlens = torch.cat(
            (
                torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
                padded_mask_after_cnn.sum(1).cumsum(0),
            )
        ).to(torch.int32)

        for encoder_layer in self.layers:
            (hidden_states,) = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                **kwargs,
            )

        hidden_states_list = hidden_states.split(valid_aftercnn_lens.tolist(), dim=0)

        # Halve the time axis per item (items shorter than 2 are kept as is).
        pooled_list: list[torch.Tensor] = []
        pooled_lengths: list[int] = []
        for each_audio_states in hidden_states_list:
            seq_len = int(each_audio_states.shape[0])
            if seq_len >= 2:
                pooled = nn.functional.avg_pool1d(
                    each_audio_states.transpose(0, 1), kernel_size=2, stride=2
                ).transpose(0, 1)
            else:
                pooled = each_audio_states
            pooled_list.append(pooled)
            pooled_lengths.append(int(pooled.shape[0]))

        pooled_concat = torch.cat(pooled_list, dim=0)
        processed_concat = self.proj(self.ln_post(pooled_concat))
        processed_audio_list = processed_concat.split(pooled_lengths, dim=0)

        # At least one item is valid here, so processed_concat is non-empty.
        output_hidden_states = torch.zeros(
            (original_batch_size, speech_maxlen, int(processed_concat.shape[-1])),
            dtype=processed_concat.dtype,
            device=device,
        )

        for batch_idx, processed in zip(valid_index_list, processed_audio_list):
            seq_len = min(int(processed.shape[0]), int(speech_maxlen))
            output_hidden_states[batch_idx, :seq_len] = processed[:seq_len]

        return BaseModelOutput(last_hidden_state=output_hidden_states)

    def padded_and_mask_function(
        self,
        tensor_list: Sequence[torch.Tensor],
        tensor_len: torch.Tensor,
        padding_value: float = 0.0,
        padding_side: str = "right",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Right-pad variable-length chunks into a batch and build masks.

        Args:
            tensor_list: Chunks whose time axis is dim 1; all must share
                dim 0.
            tensor_len: 1-D tensor with the time length of each chunk.
            padding_value: Fill value for padded positions.
            padding_side: Only ``"right"`` is supported.

        Returns:
            A tuple ``(padded, mask, mask_after_cnn)``:
            - ``padded``: ``(num_chunks, dim, max_len)`` in ``self.dtype``.
            - ``mask``: ``(num_chunks, 1, max_len)`` holding 1/0 in that
              same dtype.
            - ``mask_after_cnn``: bool ``(num_chunks, max_len_after_cnn)``
              mask for lengths ``(len - 1) // 2 + 1``.

        Raises:
            NotImplementedError: If ``padding_side`` is not ``"right"``.
        """
        # Reject unsupported modes before allocating anything.
        if padding_side != "right":
            raise NotImplementedError("Only right padding is supported.")

        max_len = int(tensor_len.max().item())
        dim = int(tensor_list[0].shape[0])
        device = tensor_list[0].device
        padded_tensor = torch.full(
            size=(len(tensor_list), dim, max_len),
            fill_value=padding_value,
            dtype=self.dtype,
            device=device,
        )
        for i, length in enumerate(tensor_len.tolist()):
            padded_tensor[i, :, : int(length)] = tensor_list[i]

        # Build both masks by broadcasting positions against lengths instead
        # of filling them row by row.
        lengths = tensor_len.to(device)
        batch_mask = torch.arange(max_len, device=device)[None, :] < lengths[:, None]

        lengths_after_cnn = (lengths - 1) // 2 + 1
        max_len_after_cnn = int(lengths_after_cnn.max().item())
        batch_mask_after_cnn = (
            torch.arange(max_len_after_cnn, device=device)[None, :]
            < lengths_after_cnn[:, None]
        )

        return (
            padded_tensor,
            batch_mask.unsqueeze(1).to(padded_tensor.dtype),
            batch_mask_after_cnn,
        )

    # From the HF FunAudioChat implementation.
    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> tuple[torch.LongTensor, torch.LongTensor]:
        """Return ``(lengths after the conv frontend, lengths after a second
        halving step)`` for the given input lengths."""
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        return input_lengths, output_lengths

audio_bos_eos_token instance-attribute

audio_bos_eos_token = Embedding(2, int(output_dim))

avg_pooler instance-attribute

avg_pooler = AvgPool1d(2, stride=2)

config instance-attribute

config = config

conv1 instance-attribute

conv1 = Conv1d(
    num_mel_bins, embed_dim, kernel_size=3, padding=1
)

conv2 instance-attribute

conv2 = Conv1d(
    embed_dim, embed_dim, kernel_size=3, stride=2, padding=1
)

dtype property

dtype: dtype

embed_scale instance-attribute

embed_scale = (
    embed_dim**0.5 if bool(scale_embedding) else 1.0
)

layers instance-attribute

layers = ModuleList(
    [
        (FunAudioChatAudioEncoderLayer(config))
        for _ in (range(int(encoder_layers)))
    ]
)

ln_post instance-attribute

ln_post = LayerNorm(embed_dim)

max_source_positions instance-attribute

max_source_positions = int(max_source_positions)

n_window instance-attribute

n_window = int(n_window)

num_mel_bins instance-attribute

num_mel_bins = int(num_mel_bins)

positional_embedding instance-attribute

positional_embedding = _SinusoidsPositionEmbedding(
    max_source_positions, embed_dim
)

proj instance-attribute

proj = Linear(embed_dim, int(output_dim))

__init__

__init__(config: Any)
Source code in vllm/model_executor/models/funaudiochat.py
def __init__(self, config: Any):
    """Create the conv frontend, encoder layers and output head from ``config``."""
    super().__init__()
    self.config = config

    embed_dim = int(config.d_model)
    self.num_mel_bins = int(config.num_mel_bins)
    self.max_source_positions = int(config.max_source_positions)
    if bool(config.scale_embedding):
        self.embed_scale = embed_dim**0.5
    else:
        self.embed_scale = 1.0
    self.n_window = int(config.n_window)

    # Conv frontend; conv2 uses stride 2 and halves the time axis.
    self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
    self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

    num_layers = int(config.encoder_layers)
    self.layers = nn.ModuleList(
        FunAudioChatAudioEncoderLayer(config) for _ in range(num_layers)
    )

    # Output head.
    self.ln_post = nn.LayerNorm(embed_dim)
    self.avg_pooler = nn.AvgPool1d(2, stride=2)
    self.proj = nn.Linear(embed_dim, int(config.output_dim))
    self.positional_embedding = _SinusoidsPositionEmbedding(
        self.max_source_positions, embed_dim
    )

    # Present in HF weights even if unused during S2T.
    self.audio_bos_eos_token = nn.Embedding(2, int(config.output_dim))

_get_feat_extract_output_lengths

_get_feat_extract_output_lengths(
    input_lengths: LongTensor,
) -> tuple[LongTensor, LongTensor]
Source code in vllm/model_executor/models/funaudiochat.py
def _get_feat_extract_output_lengths(
    self, input_lengths: torch.LongTensor
) -> tuple[torch.LongTensor, torch.LongTensor]:
    input_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (input_lengths - 2) // 2 + 1
    return input_lengths, output_lengths

_prepare_attention_mask

_prepare_attention_mask(
    inputs_tensor: Tensor, cu_seqlens: Tensor
) -> Tensor | None
Source code in vllm/model_executor/models/funaudiochat.py
def _prepare_attention_mask(
    self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor
) -> torch.Tensor | None:
    if getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2":
        return None

    seq_length = inputs_tensor.shape[0]
    attention_mask = torch.full(
        (1, 1, seq_length, seq_length),
        torch.finfo(inputs_tensor.dtype).min,
        device=inputs_tensor.device,
        dtype=inputs_tensor.dtype,
    )
    for i in range(1, len(cu_seqlens)):
        start = int(cu_seqlens[i - 1].item())
        end = int(cu_seqlens[i].item())
        attention_mask[..., start:end, start:end] = 0
    return attention_mask

forward

forward(
    input_features: Tensor,
    feature_lens: Tensor,
    aftercnn_lens: Tensor,
    speech_maxlen: int,
    **kwargs: object,
) -> BaseModelOutput
Source code in vllm/model_executor/models/funaudiochat.py
def forward(
    self,
    input_features: torch.Tensor,
    feature_lens: torch.Tensor,
    aftercnn_lens: torch.Tensor,
    speech_maxlen: int,
    **kwargs: object,
) -> BaseModelOutput:
    """Encode a batch of mel features with the continuous audio tower.

    Args:
        input_features: Features of all items concatenated along dim 1 and
            split back by ``feature_lens``. Presumably
            ``(num_mel_bins, total_frames)`` -- TODO confirm with caller.
        feature_lens: Per-item frame counts; items with length 0 are treated
            as empty and yield all-zero outputs.
        aftercnn_lens: Per-item frame counts after the conv frontend, used to
            split the encoder output back into items.
        speech_maxlen: Time dimension of the returned tensor; longer outputs
            are truncated and shorter ones zero-padded.
        **kwargs: Forwarded to every encoder layer.

    Returns:
        ``BaseModelOutput`` whose ``last_hidden_state`` has shape
        ``(batch, speech_maxlen, output_dim)``.

    Raises:
        RuntimeError: If ``speech_maxlen >= 7500`` and FlashAttention is not
            installed or not the selected encoder attention backend.
    """
    # For max-length audio (300s => ~7500 speech frames at 25Hz), the
    # Torch SDPA path can be prohibitively memory hungry (~O(n^2) inside the
    # longest chunks). Require FlashAttention for such inputs to avoid OOM
    # and performance cliffs.
    if int(speech_maxlen) >= 7500:
        if not _has_module("flash_attn"):
            raise RuntimeError(
                "FunAudioChat long audio (~300s) requires FlashAttention-2 "
                "for the continuous audio tower, but `flash_attn` is not "
                "installed in the runtime environment."
            )
        if not getattr(
            self.layers[0].self_attn.attn, "is_flash_attn_backend", False
        ):
            raise RuntimeError(
                "FunAudioChat long audio (~300s) requires FlashAttention for the "
                "continuous audio tower, but the selected MM encoder attention "
                "backend is not FlashAttention."
            )

    # Handle empty / invalid items (feature_lens == 0) without crashing.
    original_batch_size = int(feature_lens.size(0))
    device = input_features.device

    valid_mask = feature_lens > 0
    valid_indices = torch.where(valid_mask)[0]

    if valid_indices.numel() == 0:
        output_dim = int(self.proj.out_features)
        return BaseModelOutput(
            last_hidden_state=torch.zeros(
                (original_batch_size, speech_maxlen, output_dim),
                device=device,
                dtype=self.proj.weight.dtype,
            )
        )

    # Copy lengths/indices to the host once; per-element ``.item()`` on a
    # device tensor would force a synchronization on every iteration.
    feature_lens_list = feature_lens.tolist()
    valid_index_list = valid_indices.tolist()

    input_features_list = input_features.split(feature_lens_list, dim=1)
    valid_input_features = torch.cat(
        [input_features_list[i] for i in valid_index_list], dim=1
    )
    valid_aftercnn_lens = aftercnn_lens[valid_mask]

    # Split each item into full chunks of ``2 * n_window`` frames plus a
    # trailing partial chunk when the length does not divide evenly.
    full_chunk_len = self.n_window * 2
    chunk_lengths_list: list[int] = []
    for i in valid_index_list:
        num_full, remainder = divmod(int(feature_lens_list[i]), full_chunk_len)
        chunk_lengths_list.extend([full_chunk_len] * num_full)
        if remainder:
            chunk_lengths_list.append(remainder)

    chunk_lengths = torch.tensor(
        chunk_lengths_list, dtype=torch.long, device=device
    )

    chunk_list = valid_input_features.split(chunk_lengths_list, dim=1)
    padded_feature, padded_mask, padded_mask_after_cnn = (
        self.padded_and_mask_function(
            chunk_list, chunk_lengths, padding_value=0, padding_side="right"
        )
    )

    padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
    padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)

    padded_embed = padded_embed + self.positional_embedding.positional_embedding[
        : padded_embed.shape[1], :
    ].unsqueeze(0).to(padded_embed.dtype)

    # Drop padded positions and pack all chunks into one token sequence;
    # cu_seqlens holds the cumulative chunk boundaries.
    hidden_states = padded_embed[padded_mask_after_cnn]
    cu_seqlens = torch.cat(
        (
            torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
            padded_mask_after_cnn.sum(1).cumsum(0),
        )
    ).to(torch.int32)

    for encoder_layer in self.layers:
        (hidden_states,) = encoder_layer(
            hidden_states,
            cu_seqlens=cu_seqlens,
            **kwargs,
        )

    hidden_states_list = hidden_states.split(valid_aftercnn_lens.tolist(), dim=0)

    # Halve the time axis per item (items shorter than 2 are kept as is).
    pooled_list: list[torch.Tensor] = []
    pooled_lengths: list[int] = []
    for each_audio_states in hidden_states_list:
        seq_len = int(each_audio_states.shape[0])
        if seq_len >= 2:
            pooled = nn.functional.avg_pool1d(
                each_audio_states.transpose(0, 1), kernel_size=2, stride=2
            ).transpose(0, 1)
        else:
            pooled = each_audio_states
        pooled_list.append(pooled)
        pooled_lengths.append(int(pooled.shape[0]))

    pooled_concat = torch.cat(pooled_list, dim=0)
    processed_concat = self.proj(self.ln_post(pooled_concat))
    processed_audio_list = processed_concat.split(pooled_lengths, dim=0)

    # At least one item is valid here, so processed_concat is non-empty.
    output_hidden_states = torch.zeros(
        (original_batch_size, speech_maxlen, int(processed_concat.shape[-1])),
        dtype=processed_concat.dtype,
        device=device,
    )

    for batch_idx, processed in zip(valid_index_list, processed_audio_list):
        seq_len = min(int(processed.shape[0]), int(speech_maxlen))
        output_hidden_states[batch_idx, :seq_len] = processed[:seq_len]

    return BaseModelOutput(last_hidden_state=output_hidden_states)

padded_and_mask_function

padded_and_mask_function(
    tensor_list: Sequence[Tensor],
    tensor_len: Tensor,
    padding_value: float = 0.0,
    padding_side: str = "right",
) -> tuple[Tensor, Tensor, Tensor]
Source code in vllm/model_executor/models/funaudiochat.py
def padded_and_mask_function(
    self,
    tensor_list: Sequence[torch.Tensor],
    tensor_len: torch.Tensor,
    padding_value: float = 0.0,
    padding_side: str = "right",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Right-pad variable-length chunks into a batch and build masks.

    Args:
        tensor_list: Chunks whose time axis is dim 1; all must share dim 0.
        tensor_len: 1-D tensor with the time length of each chunk.
        padding_value: Fill value for padded positions.
        padding_side: Only ``"right"`` is supported.

    Returns:
        A tuple ``(padded, mask, mask_after_cnn)``:
        - ``padded``: ``(num_chunks, dim, max_len)`` in ``self.dtype``.
        - ``mask``: ``(num_chunks, 1, max_len)`` holding 1/0 in that same
          dtype.
        - ``mask_after_cnn``: bool ``(num_chunks, max_len_after_cnn)`` mask
          for lengths ``(len - 1) // 2 + 1``.

    Raises:
        NotImplementedError: If ``padding_side`` is not ``"right"``.
    """
    # Reject unsupported modes before allocating anything.
    if padding_side != "right":
        raise NotImplementedError("Only right padding is supported.")

    max_len = int(tensor_len.max().item())
    dim = int(tensor_list[0].shape[0])
    device = tensor_list[0].device
    padded_tensor = torch.full(
        size=(len(tensor_list), dim, max_len),
        fill_value=padding_value,
        dtype=self.dtype,
        device=device,
    )
    for i, length in enumerate(tensor_len.tolist()):
        padded_tensor[i, :, : int(length)] = tensor_list[i]

    # Build both masks by broadcasting positions against lengths instead of
    # filling them row by row.
    lengths = tensor_len.to(device)
    batch_mask = torch.arange(max_len, device=device)[None, :] < lengths[:, None]

    lengths_after_cnn = (lengths - 1) // 2 + 1
    max_len_after_cnn = int(lengths_after_cnn.max().item())
    batch_mask_after_cnn = (
        torch.arange(max_len_after_cnn, device=device)[None, :]
        < lengths_after_cnn[:, None]
    )

    return (
        padded_tensor,
        batch_mask.unsqueeze(1).to(padded_tensor.dtype),
        batch_mask_after_cnn,
    )

FunAudioChatAudioEncoderLayer

Bases: Module

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatAudioEncoderLayer(nn.Module):
    def __init__(self, config: Any):
        super().__init__()
        self.embed_dim = int(config.d_model)
        self.self_attn = FunAudioChatAudioAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = float(config.dropout)
        self.activation_fn = get_activation(str(config.activation_function))
        self.activation_dropout = float(config.activation_dropout)
        self.fc1 = nn.Linear(self.embed_dim, int(config.encoder_ffn_dim))
        self.fc2 = nn.Linear(int(config.encoder_ffn_dim), self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: object,
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.activation_dropout, training=self.training
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = residual + hidden_states

        return (hidden_states,)

activation_dropout instance-attribute

activation_dropout = float(activation_dropout)

activation_fn instance-attribute

activation_fn = get_activation(str(activation_function))

dropout instance-attribute

dropout = float(dropout)

embed_dim instance-attribute

embed_dim = int(d_model)

fc1 instance-attribute

fc1 = Linear(embed_dim, int(encoder_ffn_dim))

fc2 instance-attribute

fc2 = Linear(int(encoder_ffn_dim), embed_dim)

final_layer_norm instance-attribute

final_layer_norm = LayerNorm(embed_dim)

self_attn instance-attribute

self_attn = FunAudioChatAudioAttention(config)

self_attn_layer_norm instance-attribute

self_attn_layer_norm = LayerNorm(embed_dim)

__init__

__init__(config: Any)
Source code in vllm/model_executor/models/funaudiochat.py
def __init__(self, config: Any):
    """Build the attention and feed-forward submodules from ``config``."""
    super().__init__()
    d_model = int(config.d_model)
    self.embed_dim = d_model
    # Keep submodule registration order identical to the original.
    self.self_attn = FunAudioChatAudioAttention(config)
    self.self_attn_layer_norm = nn.LayerNorm(d_model)
    self.dropout = float(config.dropout)
    self.activation_fn = get_activation(str(config.activation_function))
    self.activation_dropout = float(config.activation_dropout)
    ffn_dim = int(config.encoder_ffn_dim)
    self.fc1 = nn.Linear(d_model, ffn_dim)
    self.fc2 = nn.Linear(ffn_dim, d_model)
    self.final_layer_norm = nn.LayerNorm(d_model)

forward

forward(
    hidden_states: Tensor,
    cu_seqlens: Tensor,
    attention_mask: Tensor | None = None,
    **kwargs: object,
) -> tuple[Tensor]
Source code in vllm/model_executor/models/funaudiochat.py
def forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    **kwargs: object,
) -> tuple[torch.Tensor]:
    residual = hidden_states
    hidden_states = self.self_attn_layer_norm(hidden_states)
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        cu_seqlens=cu_seqlens,
        attention_mask=attention_mask,
        **kwargs,
    )
    hidden_states = residual + hidden_states

    residual = hidden_states
    hidden_states = self.final_layer_norm(hidden_states)
    hidden_states = self.activation_fn(self.fc1(hidden_states))
    hidden_states = nn.functional.dropout(
        hidden_states, p=self.activation_dropout, training=self.training
    )
    hidden_states = self.fc2(hidden_states)
    hidden_states = nn.functional.dropout(
        hidden_states, p=self.dropout, training=self.training
    )
    hidden_states = residual + hidden_states

    return (hidden_states,)

FunAudioChatDiscreteEncoder

Bases: Module

Discrete audio encoder (speech tokenizer -> grouped embeddings).

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatDiscreteEncoder(nn.Module):
    """Discrete audio encoder (speech tokenizer -> grouped embeddings)."""

    def __init__(self, config: Any):
        super().__init__()
        self.padding_idx = int(config.pad_token_id)
        self.group_size = int(config.group_size)
        self.hidden_size = int(config.output_dim)
        self.continuous_features_mode = getattr(
            config, "continuous_features_mode", "add"
        )
        self.embed_tokens = nn.Embedding(
            int(config.codebook_size), self.hidden_size, self.padding_idx
        )
        self.output_matching = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.continual_output_matching = nn.Linear(
            self.hidden_size, self.hidden_size, bias=False
        )

    def forward(
        self,
        audio_ids: torch.Tensor,
        continuous_audio_features: torch.Tensor | None = None,
        continuous_audio_output_lengths: torch.Tensor | None = None,
        feature_exist_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        del continuous_audio_output_lengths

        inputs_embeds = self.embed_tokens(audio_ids)
        hidden_states = inputs_embeds.reshape(
            inputs_embeds.shape[0], -1, self.group_size * self.hidden_size
        )
        hidden_states = hidden_states.reshape(
            hidden_states.shape[0], -1, self.group_size, self.hidden_size
        ).mean(dim=2)
        hidden_states = self.output_matching(hidden_states)

        if continuous_audio_features is not None:
            continuous_audio_features = continuous_audio_features.reshape(
                continuous_audio_features.shape[0],
                -1,
                self.group_size,
                self.hidden_size,
            ).mean(dim=2)
            continuous_audio_hidden_states = self.continual_output_matching(
                continuous_audio_features
            )

            if feature_exist_mask is None:
                feature_exist_mask = torch.ones(
                    (hidden_states.shape[0],),
                    dtype=torch.bool,
                    device=hidden_states.device,
                )
            if self.continuous_features_mode == "add":
                hidden_states[feature_exist_mask] += continuous_audio_hidden_states
            else:
                hidden_states[feature_exist_mask] = continuous_audio_hidden_states

        return hidden_states

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> tuple[torch.LongTensor, torch.LongTensor]:
        output_lengths = (input_lengths + self.group_size - 1) // self.group_size
        return input_lengths, output_lengths

continual_output_matching instance-attribute

continual_output_matching = Linear(
    hidden_size, hidden_size, bias=False
)

continuous_features_mode instance-attribute

continuous_features_mode = getattr(
    config, "continuous_features_mode", "add"
)

embed_tokens instance-attribute

embed_tokens = Embedding(
    int(codebook_size), hidden_size, padding_idx
)

group_size instance-attribute

group_size = int(group_size)

hidden_size instance-attribute

hidden_size = int(output_dim)

output_matching instance-attribute

output_matching = Linear(
    hidden_size, hidden_size, bias=False
)

padding_idx instance-attribute

padding_idx = int(pad_token_id)

__init__

__init__(config: Any)
Source code in vllm/model_executor/models/funaudiochat.py
def __init__(self, config: Any):
    """Create the token embedding and the two output projections."""
    super().__init__()
    self.padding_idx = int(config.pad_token_id)
    self.group_size = int(config.group_size)
    hidden = int(config.output_dim)
    self.hidden_size = hidden
    self.continuous_features_mode = getattr(
        config, "continuous_features_mode", "add"
    )
    # Registration order (embed_tokens, output_matching,
    # continual_output_matching) is kept as in the original.
    self.embed_tokens = nn.Embedding(
        int(config.codebook_size), hidden, self.padding_idx
    )
    self.output_matching = nn.Linear(hidden, hidden, bias=False)
    self.continual_output_matching = nn.Linear(hidden, hidden, bias=False)

_get_feat_extract_output_lengths

_get_feat_extract_output_lengths(
    input_lengths: LongTensor,
) -> tuple[LongTensor, LongTensor]
Source code in vllm/model_executor/models/funaudiochat.py
def _get_feat_extract_output_lengths(
    self, input_lengths: torch.LongTensor
) -> tuple[torch.LongTensor, torch.LongTensor]:
    output_lengths = (input_lengths + self.group_size - 1) // self.group_size
    return input_lengths, output_lengths

forward

forward(
    audio_ids: Tensor,
    continuous_audio_features: Tensor | None = None,
    continuous_audio_output_lengths: Tensor | None = None,
    feature_exist_mask: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/models/funaudiochat.py
def forward(
    self,
    audio_ids: torch.Tensor,
    continuous_audio_features: torch.Tensor | None = None,
    continuous_audio_output_lengths: torch.Tensor | None = None,
    feature_exist_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    del continuous_audio_output_lengths

    inputs_embeds = self.embed_tokens(audio_ids)
    hidden_states = inputs_embeds.reshape(
        inputs_embeds.shape[0], -1, self.group_size * self.hidden_size
    )
    hidden_states = hidden_states.reshape(
        hidden_states.shape[0], -1, self.group_size, self.hidden_size
    ).mean(dim=2)
    hidden_states = self.output_matching(hidden_states)

    if continuous_audio_features is not None:
        continuous_audio_features = continuous_audio_features.reshape(
            continuous_audio_features.shape[0],
            -1,
            self.group_size,
            self.hidden_size,
        ).mean(dim=2)
        continuous_audio_hidden_states = self.continual_output_matching(
            continuous_audio_features
        )

        if feature_exist_mask is None:
            feature_exist_mask = torch.ones(
                (hidden_states.shape[0],),
                dtype=torch.bool,
                device=hidden_states.device,
            )
        if self.continuous_features_mode == "add":
            hidden_states[feature_exist_mask] += continuous_audio_hidden_states
        else:
            hidden_states[feature_exist_mask] = continuous_audio_hidden_states

    return hidden_states

FunAudioChatDummyInputsBuilder

Bases: BaseDummyInputsBuilder[FunAudioChatProcessingInfo]

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatDummyInputsBuilder(
    BaseDummyInputsBuilder[FunAudioChatProcessingInfo]
):
    """Builds placeholder prompt text and worst-case audio for profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one audio placeholder span per requested audio item."""
        count = int(mm_counts.get("audio", 0))
        return count * "<|audio_bos|><|AUDIO|><|audio_eos|>"

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audio sized to produce the most encoder tokens.

        The length in samples is ``ceil(max_source_positions * group_size *
        sampling_rate / token_fps)`` (at least 1), with
        ``max_source_positions`` defaulting to 1500 and ``token_fps`` to 25.
        """
        sampling_rate = int(self.info.get_feature_extractor().sampling_rate)

        hf_config = self.info.get_hf_config()
        audio_config = getattr(hf_config, "audio_config", None)
        token_cap = max(1, int(getattr(audio_config, "max_source_positions", 1500)))
        group_size = self.info.get_audio_group_size()
        token_fps = int(getattr(self.info, "token_fps", 25))

        frames = token_cap * max(1, group_size)
        numerator = frames * sampling_rate
        # Ceil division, floored at one sample.
        audio_len = max(1, (numerator + token_fps - 1) // token_fps)

        overrides = None if not mm_options else mm_options.get("audio")
        dummy_audios = self._get_dummy_audios(
            length=audio_len,
            num_audios=int(mm_counts.get("audio", 0)),
            overrides=overrides,
        )
        return {"audio": dummy_audios}

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions]
    | None = None,
) -> MultiModalDataDict
Source code in vllm/model_executor/models/funaudiochat.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
    """Return dummy audio sized to produce the most encoder tokens.

    Length in samples is ``ceil(max_source_positions * group_size *
    sampling_rate / token_fps)`` (at least 1); ``max_source_positions``
    defaults to 1500 and ``token_fps`` to 25. ``seq_len`` is unused.
    """
    sampling_rate = int(self.info.get_feature_extractor().sampling_rate)

    hf_config = self.info.get_hf_config()
    audio_config = getattr(hf_config, "audio_config", None)
    token_cap = max(1, int(getattr(audio_config, "max_source_positions", 1500)))
    group_size = self.info.get_audio_group_size()
    token_fps = int(getattr(self.info, "token_fps", 25))

    frames = token_cap * max(1, group_size)
    numerator = frames * sampling_rate
    # Ceil division, floored at one sample.
    audio_len = max(1, (numerator + token_fps - 1) // token_fps)

    overrides = None if not mm_options else mm_options.get("audio")
    dummy_audios = self._get_dummy_audios(
        length=audio_len,
        num_audios=int(mm_counts.get("audio", 0)),
        overrides=overrides,
    )
    return {"audio": dummy_audios}

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/funaudiochat.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Return one ``<|audio_bos|><|AUDIO|><|audio_eos|>`` span per audio item."""
    count = int(mm_counts.get("audio", 0))
    return count * "<|audio_bos|><|AUDIO|><|audio_eos|>"

FunAudioChatForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP

Source code in vllm/model_executor/models/funaudiochat.py
@MULTIMODAL_REGISTRY.register_processor(
    FunAudioChatMultiModalProcessor,
    info=FunAudioChatProcessingInfo,
    dummy_inputs=FunAudioChatDummyInputsBuilder,
)
class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Qwen3 language model combined with continuous and discrete audio towers.

    ``embed_multimodal`` turns speech-token ids (plus optional continuous
    features) into one embedding sequence per audio item; ``forward`` and
    ``compute_logits`` delegate to the wrapped language model.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return "<|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.continuous_audio_tower = FunAudioChatAudioEncoder(config.audio_config)
            self.audio_tower = FunAudioChatDiscreteEncoder(config.audio_config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen3ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def _get_continuous_audio_features(
        self,
        input_features: torch.Tensor,
        feature_attention_mask: torch.Tensor,
        speech_maxlen: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the continuous tower on the valid frames of every item.

        Returns the tower's ``last_hidden_state`` and the per-item output
        lengths reported by ``_get_feat_extract_output_lengths``.
        """
        # Align mask and features to avoid indexing errors when padding differs.
        if (
            input_features.dim() == 3
            and feature_attention_mask.shape[1] != input_features.shape[-1]
        ):
            min_len = min(
                int(feature_attention_mask.shape[1]), int(input_features.shape[-1])
            )
            feature_attention_mask = feature_attention_mask[:, :min_len]
            input_features = input_features[:, :, :min_len]

        feature_lens = torch.sum(feature_attention_mask, dim=1)

        flat_features = input_features.permute(0, 2, 1)[
            feature_attention_mask.bool()
        ].permute(1, 0)

        audio_feat_lengths, audio_output_lengths = (
            self.continuous_audio_tower._get_feat_extract_output_lengths(feature_lens)
        )

        audio_outputs = self.continuous_audio_tower(
            flat_features,
            feature_lens=feature_lens,
            aftercnn_lens=audio_feat_lengths,
            speech_maxlen=speech_maxlen,
        )
        return audio_outputs.last_hidden_state, audio_output_lengths

    @staticmethod
    def _pad_to_batch(
        value: object, name: str, padding_value: int
    ) -> torch.Tensor:
        """Return ``value`` as a padded batch tensor.

        Tensors are returned unchanged. A non-empty list/tuple of tensors has
        ``(1, L)`` items squeezed to ``(L,)`` and is right-padded with
        ``padding_value``.

        Raises:
            TypeError: For any other input, or for items that are not 1D.
        """
        if isinstance(value, torch.Tensor):
            return value
        if not (
            isinstance(value, (list, tuple))
            and len(value) > 0
            and all(isinstance(t, torch.Tensor) for t in value)
        ):
            raise TypeError(
                f"FunAudioChat {name} must be a Tensor or a sequence of Tensors "
                f"(got {type(value)})"
            )
        items = []
        for t in value:
            if t.dim() == 2 and t.shape[0] == 1:
                t = t.squeeze(0)
            if t.dim() != 1:
                raise TypeError(
                    f"FunAudioChat {name} must be a 1D tensor per item "
                    f"(got shape={tuple(t.shape)})"
                )
            items.append(t)
        return nn.utils.rnn.pad_sequence(
            items,
            batch_first=True,
            padding_value=padding_value,
        )

    def _debug_log_inputs(
        self,
        speech_ids: torch.Tensor,
        speech_attention_mask: torch.Tensor,
        input_features: object,
        feature_attention_mask: object,
    ) -> None:
        """Print input shapes and audio-tower weight norms (debug only)."""
        print(
            f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} "
            f"speech_attention_mask={tuple(speech_attention_mask.shape)}",
            flush=True,
        )
        attn_impl = getattr(
            self.continuous_audio_tower.config, "_attn_implementation", None
        )
        print(
            f"[FunAudioChat] audio_attn_impl={attn_impl}",
            flush=True,
        )
        if hasattr(self.continuous_audio_tower, "conv1"):
            conv1_w = self.continuous_audio_tower.conv1.weight
            print(
                f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}",
                flush=True,
            )
        # The audio attention fuses q/k/v into ``qkv_proj``; looking up
        # q_proj/k_proj/v_proj always failed and the diagnostic was silently
        # skipped. Read whichever projections actually exist.
        try:
            attn0 = self.continuous_audio_tower.layers[0].self_attn
        except (AttributeError, IndexError, TypeError):
            attn0 = None
        if attn0 is not None:
            norms = []
            for proj_name in ("qkv_proj", "out_proj"):
                weight = getattr(getattr(attn0, proj_name, None), "weight", None)
                if isinstance(weight, torch.Tensor):
                    norm = float(weight.float().norm().item())
                    norms.append(f"{proj_name}_norm={norm:.6g}")
            if norms:
                print(f"[FunAudioChat] attn0 {' '.join(norms)}", flush=True)
        if isinstance(input_features, torch.Tensor):
            print(
                f"[FunAudioChat] input_features={tuple(input_features.shape)}",
                flush=True,
            )
        if isinstance(feature_attention_mask, torch.Tensor):
            print(
                "[FunAudioChat] feature_attention_mask="
                f"{tuple(feature_attention_mask.shape)}",
                flush=True,
            )

    @staticmethod
    def _debug_dump_embeds(
        embeds: tuple[torch.Tensor, ...],
        speech_ids: torch.Tensor,
        continuous_audio_features: torch.Tensor | None,
    ) -> None:
        """Print output stats and optionally save arrays to disk (debug only).

        Saving is enabled by ``VLLM_FUN_AUDIOCHAT_DUMP_PATH``. It only happens
        for a single-item batch with more than 10 output tokens, and existing
        files are never overwritten.
        """
        embed_lens = [int(t.shape[0]) for t in embeds]
        print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
        if embeds:
            t0 = embeds[0]
            print(
                f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
                f"nan={bool(torch.isnan(t0).any())} "
                f"norm={float(t0.norm().item()):.6g}",
                flush=True,
            )
        dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
        if not (
            dump_path
            and speech_ids.shape[0] == 1
            and len(embeds) == 1
            and embed_lens[0] > 10
        ):
            return
        if not os.path.exists(dump_path):
            np.save(dump_path, embeds[0].detach().float().cpu().numpy())
            print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
        cont_path = dump_path.replace(".npy", "_cont.npy")
        if continuous_audio_features is not None and not os.path.exists(cont_path):
            np.save(
                cont_path,
                continuous_audio_features.detach().float().cpu().numpy(),
            )
            print(f"[FunAudioChat] dumped continuous to {cont_path}", flush=True)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Embed audio items for placement at ``<|AUDIO|>`` token positions.

        Reads ``speech_ids`` (required: a padded tensor or a sequence of 1D
        tensors) and optionally ``speech_attention_mask``, ``input_features``
        with ``feature_attention_mask`` (continuous tower), and
        ``feature_exist_mask``.

        Returns:
            One ``(ceil(valid_ids_i / group_size), hidden)`` tensor per item,
            or an empty list when ``speech_ids`` is missing.

        Raises:
            TypeError: If ``speech_ids`` or ``speech_attention_mask`` is neither
                a tensor nor a non-empty sequence of 1D tensors.
        """
        speech_ids = kwargs.get("speech_ids")
        speech_attention_mask = kwargs.get("speech_attention_mask")
        input_features = kwargs.get("input_features")
        feature_attention_mask = kwargs.get("feature_attention_mask")
        feature_exist_mask = kwargs.get("feature_exist_mask")

        if speech_ids is None:
            return []

        pad_id = int(getattr(self.audio_tower, "padding_idx", 0))
        speech_ids = self._pad_to_batch(speech_ids, "speech_ids", pad_id)

        if speech_attention_mask is None:
            speech_attention_mask = speech_ids.ne(pad_id).to(dtype=torch.int64)
        speech_attention_mask = self._pad_to_batch(
            speech_attention_mask, "speech_attention_mask", 0
        )

        debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1"
        if debug:
            self._debug_log_inputs(
                speech_ids, speech_attention_mask, input_features, feature_attention_mask
            )

        group_size = int(self.audio_tower.group_size)
        speech_maxlen = int(speech_ids.shape[-1])

        # The discrete tower pools groups of ``group_size`` tokens, so the token
        # axis must be a multiple of it.
        pad_len = -speech_maxlen % group_size
        if pad_len:
            speech_ids = nn.functional.pad(speech_ids, (0, pad_len), value=pad_id)
            speech_attention_mask = nn.functional.pad(
                speech_attention_mask, (0, pad_len), value=0
            )
            speech_maxlen = int(speech_ids.shape[-1])

        continuous_audio_features = None
        continuous_audio_output_lengths = None
        if input_features is not None and feature_attention_mask is not None:
            assert isinstance(input_features, torch.Tensor)
            assert isinstance(feature_attention_mask, torch.Tensor)
            continuous_audio_features, continuous_audio_output_lengths = (
                self._get_continuous_audio_features(
                    input_features=input_features,
                    feature_attention_mask=feature_attention_mask,
                    speech_maxlen=speech_maxlen,
                )
            )

        if feature_exist_mask is None:
            feature_exist_mask = torch.ones(
                (speech_ids.shape[0],), dtype=torch.bool, device=speech_ids.device
            )
        assert isinstance(feature_exist_mask, torch.Tensor)

        audio_features = self.audio_tower(
            speech_ids,
            continuous_audio_features=continuous_audio_features,
            continuous_audio_output_lengths=continuous_audio_output_lengths,
            feature_exist_mask=feature_exist_mask,
        )

        _, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
            speech_attention_mask.sum(-1)
        )
        lengths = audio_output_lengths.tolist()

        embeds = tuple(
            audio_features[i, : int(length)] for i, length in enumerate(lengths)
        )
        if debug:
            self._debug_dump_embeds(embeds, speech_ids, continuous_audio_features)
        return embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        del kwargs
        # Non-first pipeline stages consume intermediate tensors, not embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Checkpoint weights under "audio_invert_tower." have no module here.
        loader = AutoWeightsLoader(self, skip_prefixes=["audio_invert_tower."])
        return loader.load_weights(weights)

audio_tower instance-attribute

audio_tower = FunAudioChatDiscreteEncoder(audio_config)

config instance-attribute

config = config

continuous_audio_tower instance-attribute

continuous_audio_tower = FunAudioChatAudioEncoder(
    audio_config
)

language_model instance-attribute

language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    hf_config=text_config,
    prefix=maybe_prefix(prefix, "language_model"),
    architectures=["Qwen3ForCausalLM"],
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/funaudiochat.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the two audio towers and the Qwen3 language model."""
    super().__init__()
    model_config = vllm_config.model_config
    hf_config = model_config.hf_config
    self.config = hf_config
    self.multimodal_config = model_config.multimodal_config
    self.quant_config = vllm_config.quant_config

    audio_config = hf_config.audio_config
    with self._mark_tower_model(vllm_config, "audio"):
        self.continuous_audio_tower = FunAudioChatAudioEncoder(audio_config)
        self.audio_tower = FunAudioChatDiscreteEncoder(audio_config)

    lm_prefix = maybe_prefix(prefix, "language_model")
    with self._mark_language_model(vllm_config):
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=hf_config.text_config,
            prefix=lm_prefix,
            architectures=["Qwen3ForCausalLM"],
        )

    # Pipeline parallelism reuses the language model's factory.
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )

_get_continuous_audio_features

_get_continuous_audio_features(
    input_features: Tensor,
    feature_attention_mask: Tensor,
    speech_maxlen: int,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/funaudiochat.py
def _get_continuous_audio_features(
    self,
    input_features: torch.Tensor,
    feature_attention_mask: torch.Tensor,
    speech_maxlen: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Align mask and features to avoid indexing errors when padding differs.
    if (
        input_features.dim() == 3
        and feature_attention_mask.shape[1] != input_features.shape[-1]
    ):
        min_len = min(
            int(feature_attention_mask.shape[1]), int(input_features.shape[-1])
        )
        feature_attention_mask = feature_attention_mask[:, :min_len]
        input_features = input_features[:, :, :min_len]

    feature_lens = torch.sum(feature_attention_mask, dim=1)

    flat_features = input_features.permute(0, 2, 1)[
        feature_attention_mask.bool()
    ].permute(1, 0)

    audio_feat_lengths, audio_output_lengths = (
        self.continuous_audio_tower._get_feat_extract_output_lengths(feature_lens)
    )

    audio_outputs = self.continuous_audio_tower(
        flat_features,
        feature_lens=feature_lens,
        aftercnn_lens=audio_feat_lengths,
        speech_maxlen=speech_maxlen,
    )
    return audio_outputs.last_hidden_state, audio_output_lengths

compute_logits

compute_logits(hidden_states: Tensor) -> Tensor | None
Source code in vllm/model_executor/models/funaudiochat.py
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
    return self.language_model.compute_logits(hidden_states)

embed_multimodal

embed_multimodal(**kwargs: object) -> MultiModalEmbeddings
Source code in vllm/model_executor/models/funaudiochat.py
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """Encode the batch's audio items into per-item embedding tensors.

    Recognized keyword arguments (any others are ignored):

    * ``speech_ids`` -- discrete speech tokens, either a tensor or a non-empty
      sequence of 1D tensors (``(1, L)`` items are squeezed). Sequences are
      right-padded with the audio tower's ``padding_idx`` (default 0). When
      absent, no embeddings are produced.
    * ``speech_attention_mask`` -- same layout as ``speech_ids`` (sequences
      are padded with 0); defaults to ``speech_ids != padding_idx``.
    * ``input_features`` / ``feature_attention_mask`` -- optional continuous
      features; when both are given (asserted to be tensors) they are encoded
      by the continuous tower and passed to ``audio_tower``.
    * ``feature_exist_mask`` -- optional tensor; defaults to all ``True``.

    Returns:
        An empty list when ``speech_ids`` is missing. Otherwise a tuple with
        one tensor per item: row ``i`` of the ``audio_tower`` output,
        truncated to the length that ``audio_tower._get_feat_extract_output_lengths``
        derives from the sum of that item's attention mask.

    Raises:
        TypeError: if ``speech_ids`` or ``speech_attention_mask`` is neither a
            tensor nor a non-empty sequence of tensors, or if an item is not
            1D after squeezing.

    Debugging: ``VLLM_FUN_AUDIOCHAT_DEBUG=1`` prints shapes and weight norms
    to stdout; ``VLLM_FUN_AUDIOCHAT_DUMP_PATH`` additionally saves the first
    item's embeddings (and the continuous features) as ``.npy`` files when the
    batch holds a single item longer than 10 frames.
    """
    speech_ids = kwargs.get("speech_ids")
    speech_attention_mask = kwargs.get("speech_attention_mask")
    input_features = kwargs.get("input_features")
    feature_attention_mask = kwargs.get("feature_attention_mask")
    feature_exist_mask = kwargs.get("feature_exist_mask")

    if speech_ids is None:
        return []

    # Single padding id for the whole call: it pads ragged sequences, derives
    # the default attention mask, and extends speech_ids to a group multiple.
    pad_id = int(getattr(self.audio_tower, "padding_idx", 0))

    def _as_padded_batch(
        value: object, name: str, padding_value: int
    ) -> torch.Tensor:
        """Return ``value`` as a tensor, right-padding a sequence of 1D tensors."""
        if isinstance(value, torch.Tensor):
            return value
        if not (
            isinstance(value, (list, tuple))
            and len(value) > 0
            and all(isinstance(t, torch.Tensor) for t in value)
        ):
            raise TypeError(
                f"FunAudioChat {name} must be a Tensor or a "
                f"sequence of Tensors (got {type(value)})"
            )
        items = []
        for t in value:
            if t.dim() == 2 and t.shape[0] == 1:
                t = t.squeeze(0)
            if t.dim() != 1:
                raise TypeError(
                    f"FunAudioChat {name} must be a 1D tensor per item "
                    f"(got shape={tuple(t.shape)})"
                )
            items.append(t)
        return nn.utils.rnn.pad_sequence(
            items,
            batch_first=True,
            padding_value=padding_value,
        )

    speech_ids = _as_padded_batch(speech_ids, "speech_ids", pad_id)
    if speech_attention_mask is None:
        speech_attention_mask = speech_ids.ne(pad_id).to(dtype=torch.int64)
    speech_attention_mask = _as_padded_batch(
        speech_attention_mask, "speech_attention_mask", 0
    )

    debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1"
    if debug:
        print(
            f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} "
            f"speech_attention_mask={tuple(speech_attention_mask.shape)}",
            flush=True,
        )
        attn_impl = getattr(
            self.continuous_audio_tower.config, "_attn_implementation", None
        )
        print(
            f"[FunAudioChat] audio_attn_impl={attn_impl}",
            flush=True,
        )
        if hasattr(self.continuous_audio_tower, "conv1"):
            conv1_w = self.continuous_audio_tower.conv1.weight
            print(
                f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}",
                flush=True,
            )
        # The tower attention fuses Q/K/V into `qkv_proj`, so separate
        # q/k/v projections may not exist; report whichever floating-point
        # projection weights are present instead of silently printing nothing.
        try:
            attn0 = self.continuous_audio_tower.layers[0].self_attn
        except (AttributeError, IndexError, TypeError):
            attn0 = None
        proj_norms = []
        for proj_name in ("qkv_proj", "q_proj", "k_proj", "v_proj", "out_proj"):
            weight = getattr(getattr(attn0, proj_name, None), "weight", None)
            if isinstance(weight, torch.Tensor) and weight.is_floating_point():
                proj_norms.append(
                    f"{proj_name}_norm={float(weight.norm().item()):.6g}"
                )
        if proj_norms:
            print("[FunAudioChat] attn0 " + " ".join(proj_norms), flush=True)
        if isinstance(input_features, torch.Tensor):
            print(
                f"[FunAudioChat] input_features={tuple(input_features.shape)}",
                flush=True,
            )
        if isinstance(feature_attention_mask, torch.Tensor):
            print(
                "[FunAudioChat] feature_attention_mask="
                f"{tuple(feature_attention_mask.shape)}",
                flush=True,
            )

    group_size = int(self.audio_tower.group_size)
    speech_maxlen = int(speech_ids.shape[-1])

    # Ensure token length is divisible by group_size.
    target_len = ((speech_maxlen + group_size - 1) // group_size) * group_size
    if target_len > speech_maxlen:
        pad_len = target_len - speech_maxlen
        speech_ids = nn.functional.pad(speech_ids, (0, pad_len), value=pad_id)
        speech_attention_mask = nn.functional.pad(
            speech_attention_mask, (0, pad_len), value=0
        )
        speech_maxlen = int(speech_ids.shape[-1])

    continuous_audio_features = None
    continuous_audio_output_lengths = None
    if input_features is not None and feature_attention_mask is not None:
        assert isinstance(input_features, torch.Tensor)
        assert isinstance(feature_attention_mask, torch.Tensor)
        continuous_audio_features, continuous_audio_output_lengths = (
            self._get_continuous_audio_features(
                input_features=input_features,
                feature_attention_mask=feature_attention_mask,
                speech_maxlen=speech_maxlen,
            )
        )

    if feature_exist_mask is None:
        feature_exist_mask = torch.ones(
            (speech_ids.shape[0],), dtype=torch.bool, device=speech_ids.device
        )
    assert isinstance(feature_exist_mask, torch.Tensor)

    audio_features = self.audio_tower(
        speech_ids,
        continuous_audio_features=continuous_audio_features,
        continuous_audio_output_lengths=continuous_audio_output_lengths,
        feature_exist_mask=feature_exist_mask,
    )

    _, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
        speech_attention_mask.sum(-1)
    )
    lengths = audio_output_lengths.tolist()

    embeds = tuple(
        audio_features[i, : int(length)] for i, length in enumerate(lengths)
    )
    if debug:
        embed_lens = [int(t.shape[0]) for t in embeds]
        print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
        if embeds:
            t0 = embeds[0]
            print(
                f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
                f"nan={bool(torch.isnan(t0).any())} "
                f"norm={float(t0.norm().item()):.6g}",
                flush=True,
            )
        dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
        if (
            dump_path
            and speech_ids.shape[0] == 1
            and len(embeds) == 1
            and embed_lens[0] > 10
        ):
            # np.save appends ".npy" when it is missing; normalize first so
            # the existence checks look at the files actually written and the
            # continuous dump never reuses (and overwrites) the embeds path.
            if not dump_path.endswith(".npy"):
                dump_path += ".npy"
            cont_path = dump_path[: -len(".npy")] + "_cont.npy"
            if not os.path.exists(dump_path):
                np.save(dump_path, embeds[0].detach().float().cpu().numpy())
                print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
            if continuous_audio_features is not None and not os.path.exists(
                cont_path
            ):
                np.save(
                    cont_path,
                    continuous_audio_features.detach().float().cpu().numpy(),
                )
                print(
                    f"[FunAudioChat] dumped continuous to {cont_path}", flush=True
                )
    return embeds

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
) -> Tensor | IntermediateTensors
Source code in vllm/model_executor/models/funaudiochat.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run the text backbone (``language_model.model``).

    Extra keyword arguments are accepted and ignored. When
    ``intermediate_tensors`` is provided (presumably a later pipeline-parallel
    stage — confirm), ``inputs_embeds`` is discarded and ``None`` is passed on.
    """
    del kwargs  # accepted for interface compatibility only
    embeds = inputs_embeds if intermediate_tensors is None else None
    backbone = self.language_model.model
    return backbone(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=embeds,
    )

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/funaudiochat.py
def get_language_model(self) -> torch.nn.Module:
    """Return the wrapped text model (the ``language_model`` attribute)."""
    text_model = self.language_model
    return text_model

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> str | None
Source code in vllm/model_executor/models/funaudiochat.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    if modality.startswith("audio"):
        return "<|audio_bos|><|AUDIO|><|audio_eos|>"

    raise ValueError("Only audio modality is supported")

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/funaudiochat.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights via ``AutoWeightsLoader``.

    Weights whose names start with ``audio_invert_tower.`` are skipped. Returns
    whatever set the loader reports.
    """
    skipped_prefixes = ["audio_invert_tower."]
    return AutoWeightsLoader(self, skip_prefixes=skipped_prefixes).load_weights(
        weights
    )

FunAudioChatMultiModalProcessor

Bases: BaseMultiModalProcessor[FunAudioChatProcessingInfo]

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatMultiModalProcessor(
    BaseMultiModalProcessor[FunAudioChatProcessingInfo]
):
    """Prepare FunAudioChat prompts and audio inputs for the model.

    For every audio clip the processor builds a run of speech pad tokens
    (``token_fps`` per second of audio, at least one), encoded by the speech
    tokenizer. It also builds Whisper features from the feature extractor.
    Each ``<|AUDIO|>`` placeholder in the prompt is then expanded to one
    token per ``group_size`` speech tokens.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize ``prompt`` and build the per-clip audio tensors.

        ``mm_kwargs`` is accepted for interface compatibility but not used.
        When ``mm_data`` has no ``"audios"``, only ``input_ids`` is returned.
        """
        tokenizer = self.info.get_tokenizer()
        input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)])

        audios = mm_data.get("audios", [])
        if not audios:
            return BatchFeature({"input_ids": input_ids})

        feature_extractor = self.info.get_feature_extractor()
        sr = int(feature_extractor.sampling_rate)
        # Clips shorter than this are zero-padded up to n_fft samples.
        min_samples = int(getattr(feature_extractor, "n_fft", 400) or 400)

        wavs: list[np.ndarray] = []
        speech_strs: list[str] = []

        speech_tokenizer = self.info.get_speech_tokenizer()
        pad_token = speech_tokenizer.pad_token or "<|audio_pad|>"
        for audio in audios:
            if isinstance(audio, torch.Tensor):
                audio = audio.detach().cpu().numpy()
            audio_np = np.asarray(audio, dtype=np.float32)

            if min_samples > 0 and audio_np.shape[0] < min_samples:
                audio_np = np.pad(
                    audio_np, (0, min_samples - audio_np.shape[0]), mode="constant"
                )

            wavs.append(audio_np)
            # One speech pad token per frame at token_fps; never fewer than 1.
            num_frames = int(
                (float(audio_np.shape[0]) / float(sr)) * float(self.info.token_fps)
            )
            speech_strs.append(pad_token * max(1, num_frames))

        audio_group_size = self.info.get_audio_group_size()
        speech_inputs = speech_tokenizer(
            speech_strs,
            return_attention_mask=True,
            return_token_type_ids=False,
            padding=True,
            pad_to_multiple_of=audio_group_size,
            return_tensors="pt",
        )

        wav_inputs = feature_extractor(
            wavs,
            sampling_rate=sr,
            return_attention_mask=True,
            padding="max_length",
            return_tensors="pt",
        )

        mm_inputs: dict[str, torch.Tensor] = {
            "speech_ids": speech_inputs["input_ids"],
            "speech_attention_mask": speech_inputs["attention_mask"],
            "input_features": wav_inputs["input_features"],
            "feature_attention_mask": wav_inputs["attention_mask"],
            "feature_exist_mask": torch.ones((len(wavs),), dtype=torch.bool),
        }

        return BatchFeature({"input_ids": input_ids, **mm_inputs})

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Return ``False``: ``_call_hf_processor`` never expands placeholders."""
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every audio input as batched along the ``audio`` modality."""
        return {
            "speech_ids": MultiModalFieldConfig.batched("audio"),
            "speech_attention_mask": MultiModalFieldConfig.batched("audio"),
            "input_features": MultiModalFieldConfig.batched("audio"),
            "feature_attention_mask": MultiModalFieldConfig.batched("audio"),
            "feature_exist_mask": MultiModalFieldConfig.batched("audio"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each ``<|AUDIO|>`` into one token per speech-token group.

        Item ``i`` receives ``ceil(n_i / group_size)`` placeholder tokens,
        where ``n_i`` is the sum of its ``speech_attention_mask``. When no mask
        is available, every item receives a single placeholder.

        Raises:
            ValueError: if an item would receive no placeholder tokens.
        """
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        audio_token = "<|AUDIO|>"
        audio_token_id = vocab[audio_token]

        out_mm_data = out_mm_kwargs.get_data()
        speech_attention_mask = out_mm_data.get("speech_attention_mask")
        if speech_attention_mask is None:
            audio_output_lengths: list[int] = []
        else:
            if isinstance(speech_attention_mask, torch.Tensor):
                speech_lengths = [
                    int(n) for n in speech_attention_mask.sum(-1).tolist()
                ]
            else:
                # NOTE(review): items produced by separate HF calls can be
                # padded to different lengths, in which case the batched field
                # presumably comes back as a sequence of per-item tensors
                # instead of one stacked tensor; sum each item on its own.
                speech_lengths = [int(mask.sum()) for mask in speech_attention_mask]
            group_size = self.info.get_audio_group_size()
            audio_output_lengths = [
                (n + group_size - 1) // group_size for n in speech_lengths
            ]

        def get_replacement_funaudiochat(item_idx: int):
            num_features = (
                int(audio_output_lengths[item_idx]) if audio_output_lengths else 1
            )
            if num_features <= 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)
                raise ValueError(
                    f"The audio (len={audio_len}) is too short to be "
                    "represented inside the model"
                )

            audio_tokens = [audio_token_id] * num_features
            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_funaudiochat,
            )
        ]

_call_hf_processor

_call_hf_processor(
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature
Source code in vllm/model_executor/models/funaudiochat.py
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Tokenize the prompt and turn every audio clip into model inputs.

    Each clip contributes a speech pad-token string (``token_fps`` tokens per
    second of audio, at least one) for the speech tokenizer, plus its raw
    waveform for the feature extractor. Clips shorter than ``n_fft`` samples
    (default 400) are zero-padded first. ``mm_kwargs`` is not used.
    """
    tokenizer = self.info.get_tokenizer()
    prompt_ids = tokenizer.encode(prompt, **tok_kwargs)
    input_ids = torch.tensor([prompt_ids])

    audios = mm_data.get("audios", [])
    if not audios:
        return BatchFeature({"input_ids": input_ids})

    feature_extractor = self.info.get_feature_extractor()
    sampling_rate = int(feature_extractor.sampling_rate)
    min_samples = int(getattr(feature_extractor, "n_fft", 400) or 400)

    wavs: list[np.ndarray] = []
    speech_strs: list[str] = []

    speech_tokenizer = self.info.get_speech_tokenizer()
    pad_token = speech_tokenizer.pad_token or "<|audio_pad|>"
    for raw in audios:
        if isinstance(raw, torch.Tensor):
            raw = raw.detach().cpu().numpy()
        wav = np.asarray(raw, dtype=np.float32)

        shortfall = min_samples - wav.shape[0]
        if min_samples > 0 and shortfall > 0:
            wav = np.pad(wav, (0, shortfall), mode="constant")
        wavs.append(wav)

        duration_s = float(wav.shape[0]) / float(sampling_rate)
        n_tokens = int(duration_s * float(self.info.token_fps))
        speech_strs.append(pad_token * max(1, n_tokens))

    group_size = self.info.get_audio_group_size()
    speech_inputs = speech_tokenizer(
        speech_strs,
        return_attention_mask=True,
        return_token_type_ids=False,
        padding=True,
        pad_to_multiple_of=group_size,
        return_tensors="pt",
    )
    wav_inputs = feature_extractor(
        wavs,
        sampling_rate=sampling_rate,
        return_attention_mask=True,
        padding="max_length",
        return_tensors="pt",
    )

    mm_inputs: dict[str, torch.Tensor] = dict(
        speech_ids=speech_inputs["input_ids"],
        speech_attention_mask=speech_inputs["attention_mask"],
        input_features=wav_inputs["input_features"],
        feature_attention_mask=wav_inputs["attention_mask"],
        feature_exist_mask=torch.ones((len(wavs),), dtype=torch.bool),
    )
    return BatchFeature({"input_ids": input_ids, **mm_inputs})

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/funaudiochat.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Mark every audio-related processor output as batched per audio item."""
    audio_fields = (
        "speech_ids",
        "speech_attention_mask",
        "input_features",
        "feature_attention_mask",
        "feature_exist_mask",
    )
    return {field: MultiModalFieldConfig.batched("audio") for field in audio_fields}

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/funaudiochat.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    """Expand each ``<|AUDIO|>`` into one token per speech-token group.

    Item ``i`` receives ``ceil(n_i / group_size)`` placeholder tokens, where
    ``n_i`` is the sum of its ``speech_attention_mask``. When no mask is
    available, every item receives a single placeholder.

    Raises:
        ValueError: if an item would receive no placeholder tokens.
    """
    tokenizer = self.info.get_tokenizer()
    vocab = tokenizer.get_vocab()

    audio_token = "<|AUDIO|>"
    audio_token_id = vocab[audio_token]

    out_mm_data = out_mm_kwargs.get_data()
    speech_attention_mask = out_mm_data.get("speech_attention_mask")
    if speech_attention_mask is None:
        audio_output_lengths: list[int] = []
    else:
        if isinstance(speech_attention_mask, torch.Tensor):
            speech_lengths = [int(n) for n in speech_attention_mask.sum(-1).tolist()]
        else:
            # NOTE(review): items produced by separate HF calls can be padded
            # to different lengths, in which case the batched field presumably
            # comes back as a sequence of per-item tensors instead of one
            # stacked tensor; sum each item on its own.
            speech_lengths = [int(mask.sum()) for mask in speech_attention_mask]
        group_size = self.info.get_audio_group_size()
        audio_output_lengths = [
            (n + group_size - 1) // group_size for n in speech_lengths
        ]

    def get_replacement_funaudiochat(item_idx: int):
        num_features = (
            int(audio_output_lengths[item_idx]) if audio_output_lengths else 1
        )
        if num_features <= 0:
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio_len = audios.get_audio_length(item_idx)
            raise ValueError(
                f"The audio (len={audio_len}) is too short to be "
                "represented inside the model"
            )

        audio_tokens = [audio_token_id] * num_features
        return PromptUpdateDetails.select_token_id(
            audio_tokens,
            embed_token_id=audio_token_id,
        )

    return [
        PromptReplacement(
            modality="audio",
            target=audio_token,
            replacement=get_replacement_funaudiochat,
        )
    ]

_hf_processor_applies_updates

_hf_processor_applies_updates(
    prompt_text: str,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
) -> bool
Source code in vllm/model_executor/models/funaudiochat.py
def _hf_processor_applies_updates(
    self,
    prompt_text: str,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
) -> bool:
    """Report that the HF processing step never expands audio placeholders.

    The answer is unconditional; none of the arguments are inspected.
    """
    applies_updates = False
    return applies_updates

FunAudioChatProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatProcessingInfo(BaseProcessingInfo):
    """Processing-time metadata and helpers for FunAudioChat audio inputs."""

    # Speech-token rate: tokens per second of audio used when sizing the
    # speech pad-token string for each clip.
    token_fps: int = 25

    @cached_property
    def feature_extractor(self) -> WhisperFeatureExtractor:
        """Whisper feature extractor from the model repo, loaded once."""
        return WhisperFeatureExtractor.from_pretrained(self.model_id)

    @cached_property
    def speech_tokenizer(self) -> PreTrainedTokenizerFast:
        """Tokenizer from the repo's ``speech_tokenizer`` subfolder, loaded once."""
        return PreTrainedTokenizerFast.from_pretrained(
            self.model_id, subfolder="speech_tokenizer"
        )

    def get_feature_extractor(self) -> WhisperFeatureExtractor:
        """Return the cached ``feature_extractor``."""
        extractor = self.feature_extractor
        return extractor

    def get_speech_tokenizer(self) -> PreTrainedTokenizerFast:
        """Return the cached ``speech_tokenizer``."""
        tokenizer = self.speech_tokenizer
        return tokenizer

    def get_data_parser(self):
        """Build a parser that targets the extractor's rate and mono audio."""
        sampling_rate = int(self.feature_extractor.sampling_rate)
        return MultiModalDataParser(
            target_sr=sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Audio is the only modality; ``None`` places no per-prompt limit."""
        limits: dict[str, int | None] = {"audio": None}
        return limits

    def get_target_channels(self) -> int:
        """Audio is parsed as single-channel."""
        mono = 1
        return mono

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int] | None:
        """Return the max ``<|AUDIO|>`` placeholders per audio item.

        Reads ``audio_config.max_source_positions`` (default 1500);
        ``seq_len`` and ``mm_counts`` are not used.
        """
        # The discrete audio encoder downsamples 25Hz frames with group_size=5,
        # so for a 300s clip the max number of `<|AUDIO|>` placeholders is 1500.
        audio_cfg = getattr(self.get_hf_config(), "audio_config", None)
        limit = getattr(audio_cfg, "max_source_positions", 1500)
        return {"audio": int(limit)}

    def get_audio_group_size(self) -> int:
        """Speech tokens per ``<|AUDIO|>`` placeholder (``audio_config.group_size``, default 5)."""
        audio_cfg = getattr(self.get_hf_config(), "audio_config", None)
        group_size = getattr(audio_cfg, "group_size", 5)
        return int(group_size)

feature_extractor cached property

feature_extractor: WhisperFeatureExtractor

speech_tokenizer cached property

speech_tokenizer: PreTrainedTokenizerFast

token_fps class-attribute instance-attribute

token_fps: int = 25

get_audio_group_size

get_audio_group_size() -> int
Source code in vllm/model_executor/models/funaudiochat.py
def get_audio_group_size(self) -> int:
    """Return how many speech tokens make up one ``<|AUDIO|>`` placeholder.

    Read from ``audio_config.group_size`` in the HF config; 5 when either the
    sub-config or the field is missing.
    """
    audio_cfg = getattr(self.get_hf_config(), "audio_config", None)
    group_size = getattr(audio_cfg, "group_size", 5)
    return int(group_size)

get_data_parser

get_data_parser()
Source code in vllm/model_executor/models/funaudiochat.py
def get_data_parser(self):
    """Build the multimodal data parser for audio inputs.

    Audio is targeted at the feature extractor's sampling rate and at
    ``get_target_channels()`` channels.
    """
    sampling_rate = int(self.feature_extractor.sampling_rate)
    channels = self.get_target_channels()
    hidden_size = self._get_expected_hidden_size()
    return MultiModalDataParser(
        target_sr=sampling_rate,
        target_channels=channels,
        expected_hidden_size=hidden_size,
    )

get_feature_extractor

get_feature_extractor() -> WhisperFeatureExtractor
Source code in vllm/model_executor/models/funaudiochat.py
def get_feature_extractor(self) -> WhisperFeatureExtractor:
    """Return the cached ``feature_extractor`` property value."""
    extractor = self.feature_extractor
    return extractor

get_mm_max_tokens_per_item

get_mm_max_tokens_per_item(
    seq_len: int, mm_counts: Mapping[str, int]
) -> Mapping[str, int] | None
Source code in vllm/model_executor/models/funaudiochat.py
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int] | None:
    # The discrete audio encoder downsamples 25Hz frames with group_size=5,
    # so for a 300s clip the max number of `<|AUDIO|>` placeholders is 1500.
    cfg = self.get_hf_config()
    audio_cfg = getattr(cfg, "audio_config", None)
    max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500))
    return {"audio": max_audio_tokens}

get_speech_tokenizer

get_speech_tokenizer() -> PreTrainedTokenizerFast
Source code in vllm/model_executor/models/funaudiochat.py
def get_speech_tokenizer(self) -> PreTrainedTokenizerFast:
    """Return the cached ``speech_tokenizer`` property value."""
    tokenizer = self.speech_tokenizer
    return tokenizer

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, int | None]
Source code in vllm/model_executor/models/funaudiochat.py
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
    return {"audio": None}

get_target_channels

get_target_channels() -> int
Source code in vllm/model_executor/models/funaudiochat.py
def get_target_channels(self) -> int:
    """Audio inputs are parsed as a single channel."""
    mono = 1
    return mono

_SinusoidsPositionEmbedding

Bases: Module

Source code in vllm/model_executor/models/funaudiochat.py
class _SinusoidsPositionEmbedding(nn.Module):
    """Fixed sinusoidal position table held in a non-persistent buffer.

    ``positional_embedding`` has shape ``(length, channels)``. The first
    ``channels // 2`` columns hold ``sin(pos * inv_timescale)`` and the rest
    hold ``cos(pos * inv_timescale)``. The inverse timescales are spaced
    geometrically from 1 down to ``1 / max_timescale``.
    """

    def __init__(self, length: int, channels: int, max_timescale: float = 10000.0):
        super().__init__()
        if channels % 2:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")

        half = channels // 2
        # Exponents run 0, -step, ..., -(half - 1) * step == -log(max_timescale).
        step = np.log(max_timescale) / (half - 1)
        inv_timescales = torch.exp(-step * torch.arange(half).float())
        positions = torch.arange(length).unsqueeze(1)
        angles = positions * inv_timescales.unsqueeze(0)
        table = torch.cat([angles.sin(), angles.cos()], dim=1)
        # Not persistent: the table is rebuilt from the constructor arguments,
        # so it is kept out of state_dict.
        self.register_buffer("positional_embedding", table, persistent=False)

__init__

__init__(
    length: int,
    channels: int,
    max_timescale: float = 10000.0,
)
Source code in vllm/model_executor/models/funaudiochat.py
def __init__(self, length: int, channels: int, max_timescale: float = 10000.0):
    """Build a ``(length, channels)`` sinusoidal table as a non-persistent buffer.

    The first half of the columns holds ``sin`` and the second half ``cos`` of
    ``position * inv_timescale``. The inverse timescales are spaced
    geometrically from 1 down to ``1 / max_timescale``.

    Raises:
        ValueError: if ``channels`` is odd.
    """
    super().__init__()
    if channels % 2:
        raise ValueError("SinusoidsPositionEmbedding needs even channels input")

    half = channels // 2
    # Exponents run 0, -step, ..., -(half - 1) * step == -log(max_timescale).
    step = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-step * torch.arange(half).float())
    positions = torch.arange(length).unsqueeze(1)
    angles = positions * inv_timescales.unsqueeze(0)
    table = torch.cat([angles.sin(), angles.cos()], dim=1)
    # Not persistent: derived from constructor arguments, so not saved.
    self.register_buffer("positional_embedding", table, persistent=False)