vllm.model_executor.models.funaudiochat

Inference-only FunAudioChat model compatible with HuggingFace weights.

FunAudioChat is a Qwen3 text model augmented with
  • a continuous audio encoder (Whisper-mel frontend + transformer)
  • a discrete audio encoder (speech tokenizer + projector)

In the HF implementation, audio features are scattered into <|AUDIO|> token positions via inputs_embeds, while position_ids (RoPE) remains standard 1D.
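
As a rough sketch of that scatter step (hypothetical audio_token_id and toy shapes; not the actual vLLM merge code):

import torch

# Hypothetical shapes: 10 text tokens, hidden size 8, 3 of them are <|AUDIO|>.
hidden_size = 8
audio_token_id = 151646  # placeholder id, assumed only for this illustration
input_ids = torch.tensor(
    [1, 2, audio_token_id, audio_token_id, audio_token_id, 3, 4, 5, 6, 7]
)
inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)  # stand-in for embed_tokens(input_ids)
audio_embeds = torch.randn(3, hidden_size)                    # stand-in for the audio encoder output

# Scatter audio features into the <|AUDIO|> positions.
audio_mask = input_ids == audio_token_id
inputs_embeds[audio_mask] = audio_embeds

# RoPE positions stay standard 1D, independent of the audio placeholders.
position_ids = torch.arange(input_ids.shape[0])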

FunAudioChatAudioAttention

Bases: Module

Multi-headed attention used inside the continuous audio tower.

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatAudioAttention(nn.Module):
    """Multi-headed attention used inside the continuous audio tower."""

    def __init__(self, config: Any):
        super().__init__()
        self.embed_dim = int(config.d_model)
        self.total_num_heads = int(config.encoder_attention_heads)
        self.dropout = float(getattr(config, "attention_dropout", 0.0))
        self.head_dim = self.embed_dim // self.total_num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.config = config

        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got embed_dim={self.embed_dim}, "
                f"num_heads={self.total_num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_decoder = False
        self.is_causal = False

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
        )
        self.num_heads = self.qkv_proj.num_heads
        self.num_kv_heads = self.qkv_proj.num_kv_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = MMEncoderAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            prefix="funaudiochat_audio_tower.attn",
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        with torch.no_grad():
            if self.qkv_proj.bias is not None:
                # HF FunAudioChat uses bias=False for k_proj. Ensure the missing
                # shard starts as zeros, while allowing q/v shards to load.
                self.qkv_proj.bias.zero_()

        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        del kwargs
        del attention_mask
        seq_length, _ = hidden_states.size()

        qkv, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        max_seqlen: torch.Tensor | None = None
        if cu_seqlens is not None:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        attn_output = self.attn(
            query_states.reshape(1, seq_length, self.q_size),
            key_states.reshape(1, seq_length, self.kv_size),
            value_states.reshape(1, seq_length, self.kv_size),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        ).reshape(seq_length, -1)

        output, _ = self.out_proj(attn_output)
        return output
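
To make the cu_seqlens convention in forward concrete: chunks from all audios are packed into one flat (seq_length, embed_dim) tensor, and cu_seqlens holds the cumulative chunk boundaries, so attention is block-diagonal (each chunk attends only within itself). Below is a standalone toy sketch of that semantics using plain SDPA; it illustrates the masking, not the FlashAttention varlen kernel used by MMEncoderAttention.

import torch
import torch.nn.functional as F

# Two chunks of lengths 3 and 5 packed into one flat sequence.
lens = torch.tensor([3, 5])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lens.cumsum(0)])  # [0, 3, 8]

total, num_heads, head_dim = int(lens.sum()), 2, 4
q = torch.randn(num_heads, total, head_dim)
k = torch.randn(num_heads, total, head_dim)
v = torch.randn(num_heads, total, head_dim)

# Block-diagonal mask: a position may only attend within its own chunk.
mask = torch.zeros(total, total, dtype=torch.bool)
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    mask[start:end, start:end] = True

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # (num_heads, total, head_dim)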

FunAudioChatAudioEncoder

Bases: Module

Continuous audio tower.

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatAudioEncoder(nn.Module):
    """Continuous audio tower."""

    def __init__(self, config: Any):
        super().__init__()
        self.config = config

        embed_dim = int(config.d_model)
        self.num_mel_bins = int(config.num_mel_bins)
        self.max_source_positions = int(config.max_source_positions)
        self.embed_scale = (embed_dim**0.5) if bool(config.scale_embedding) else 1.0
        self.n_window = int(config.n_window)

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.layers = nn.ModuleList(
            [
                FunAudioChatAudioEncoderLayer(config)
                for _ in range(int(config.encoder_layers))
            ]
        )
        self.ln_post = nn.LayerNorm(embed_dim)
        self.avg_pooler = nn.AvgPool1d(2, stride=2)
        self.proj = nn.Linear(embed_dim, int(config.output_dim))
        self.positional_embedding = _SinusoidsPositionEmbedding(
            self.max_source_positions, embed_dim
        )

        # Present in HF weights even if unused during S2T.
        self.audio_bos_eos_token = nn.Embedding(2, int(config.output_dim))

    @property
    def dtype(self) -> torch.dtype:
        return self.conv1.weight.dtype

    def _prepare_attention_mask(
        self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor
    ) -> torch.Tensor | None:
        if getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            (1, 1, seq_length, seq_length),
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            start = int(cu_seqlens[i - 1].item())
            end = int(cu_seqlens[i].item())
            attention_mask[..., start:end, start:end] = 0
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: torch.Tensor,
        aftercnn_lens: torch.Tensor,
        speech_maxlen: int,
        **kwargs: object,
    ) -> BaseModelOutput:
        # For max-length audio (300s => ~7500 speech frames at 25Hz), the
        # Torch SDPA path can be prohibitively memory hungry (~O(n^2) inside the
        # longest chunks). Require FlashAttention for such inputs to avoid OOM
        # and performance cliffs.
        if int(speech_maxlen) >= 7500:
            if not _has_module("flash_attn"):
                raise RuntimeError(
                    "FunAudioChat long audio (~300s) requires FlashAttention-2 "
                    "for the continuous audio tower, but `flash_attn` is not "
                    "installed in the runtime environment."
                )
            if not getattr(
                self.layers[0].self_attn.attn, "is_flash_attn_backend", False
            ):
                raise RuntimeError(
                    "FunAudioChat long audio (~300s) requires FlashAttention for the "
                    "continuous audio tower, but the selected MM encoder attention "
                    "backend is not FlashAttention."
                )

        # Handle empty / invalid items (feature_lens == 0) without crashing.
        original_batch_size = int(feature_lens.size(0))
        device = input_features.device

        valid_mask = feature_lens > 0
        valid_indices = torch.where(valid_mask)[0]

        if valid_indices.numel() == 0:
            output_dim = int(self.proj.out_features)
            return BaseModelOutput(
                last_hidden_state=torch.zeros(
                    (original_batch_size, speech_maxlen, output_dim),
                    device=device,
                    dtype=self.proj.weight.dtype,
                )
            )

        input_features_list = input_features.split(feature_lens.tolist(), dim=1)
        valid_input_features_list = [input_features_list[int(i)] for i in valid_indices]
        valid_input_features = torch.cat(valid_input_features_list, dim=1)

        valid_feature_lens = feature_lens[valid_mask]
        valid_aftercnn_lens = aftercnn_lens[valid_mask]

        chunk_num = torch.ceil(valid_feature_lens / (self.n_window * 2)).long()

        chunk_lengths_list: list[int] = []
        full_chunk_len = self.n_window * 2
        for i, length in enumerate(valid_feature_lens):
            num_chunks_for_sample = int(chunk_num[i].item())
            if num_chunks_for_sample == 0:
                continue
            chunk_lengths_list.extend([full_chunk_len] * (num_chunks_for_sample - 1))
            last_chunk_len = int(length.item()) % full_chunk_len
            if last_chunk_len == 0:
                last_chunk_len = full_chunk_len
            chunk_lengths_list.append(last_chunk_len)

        chunk_lengths = torch.tensor(
            chunk_lengths_list, dtype=torch.long, device=device
        )

        chunk_list = valid_input_features.split(chunk_lengths.tolist(), dim=1)
        padded_feature, padded_mask, padded_mask_after_cnn = (
            self.padded_and_mask_function(
                chunk_list, chunk_lengths, padding_value=0, padding_side="right"
            )
        )

        padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
        padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)

        padded_embed = padded_embed + self.positional_embedding.positional_embedding[
            : padded_embed.shape[1], :
        ].unsqueeze(0).to(padded_embed.dtype)

        hidden_states = padded_embed[padded_mask_after_cnn]
        cu_seqlens = torch.cat(
            (
                torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
                padded_mask_after_cnn.sum(1).cumsum(0),
            )
        ).to(torch.int32)

        for encoder_layer in self.layers:
            (hidden_states,) = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                **kwargs,
            )

        hidden_states_list = hidden_states.split(valid_aftercnn_lens.tolist(), dim=0)

        pooled_list: list[torch.Tensor] = []
        pooled_lengths: list[int] = []
        for each_audio_states in hidden_states_list:
            seq_len = int(each_audio_states.shape[0])
            if seq_len >= 2:
                pooled = nn.functional.avg_pool1d(
                    each_audio_states.transpose(0, 1), kernel_size=2, stride=2
                ).transpose(0, 1)
            else:
                pooled = each_audio_states
            pooled_list.append(pooled)
            pooled_lengths.append(int(pooled.shape[0]))

        pooled_concat = torch.cat(pooled_list, dim=0)
        processed_concat = self.proj(self.ln_post(pooled_concat))
        processed_audio_list = list(processed_concat.split(pooled_lengths, dim=0))

        output_dim = (
            int(processed_audio_list[0].shape[-1])
            if processed_audio_list
            else int(self.proj.out_features)
        )
        output_hidden_states = torch.zeros(
            (original_batch_size, speech_maxlen, output_dim),
            dtype=processed_audio_list[0].dtype
            if processed_audio_list
            else self.proj.weight.dtype,
            device=device,
        )

        for valid_idx, processed in zip(valid_indices, processed_audio_list):
            seq_len = min(int(processed.shape[0]), int(speech_maxlen))
            output_hidden_states[int(valid_idx), :seq_len] = processed[:seq_len]

        return BaseModelOutput(last_hidden_state=output_hidden_states)

    def padded_and_mask_function(
        self,
        tensor_list: Sequence[torch.Tensor],
        tensor_len: torch.Tensor,
        padding_value: float = 0.0,
        padding_side: str = "right",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        max_len = int(tensor_len.max().item())
        dim = int(tensor_list[0].shape[0])
        padded_tensor = torch.full(
            size=(len(tensor_list), dim, max_len),
            fill_value=padding_value,
            dtype=self.dtype,
            device=tensor_list[0].device,
        )

        batch_mask = torch.zeros(
            (len(tensor_len), max_len), dtype=torch.long, device=padded_tensor.device
        )
        for i, length in enumerate(tensor_len):
            length_val = int(length.item())
            batch_mask[i, :length_val] = 1
            padded_tensor[i, :, :length_val] = tensor_list[i]

        feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
        max_len_after_cnn = int(feature_lens_after_cnn.max().item())
        batch_mask_after_cnn = torch.zeros(
            (len(tensor_len), max_len_after_cnn),
            dtype=torch.long,
            device=padded_tensor.device,
        )
        for i, length in enumerate(feature_lens_after_cnn):
            batch_mask_after_cnn[i, : int(length.item())] = 1

        if padding_side != "right":
            raise NotImplementedError("Only right padding is supported.")

        return (
            padded_tensor,
            batch_mask.unsqueeze(1).to(padded_tensor.dtype),
            batch_mask_after_cnn.bool(),
        )

    # From the HF FunAudioChat implementation.
    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> tuple[torch.LongTensor, torch.LongTensor]:
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        return input_lengths, output_lengths
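
For a concrete feel of the length arithmetic in _get_feat_extract_output_lengths: the stride-2 conv roughly halves the mel-frame count and the kernel-2 average pool halves it again, which (assuming the usual 100 Hz Whisper mel frame rate) is where the ~7500-frame figure for 300 s audio in forward comes from. A standalone sketch:

import torch

def feat_extract_output_lengths(input_lengths: torch.Tensor):
    # Same arithmetic as FunAudioChatAudioEncoder._get_feat_extract_output_lengths:
    # conv2 (kernel 3, stride 2, padding 1), then the kernel-2/stride-2 average pool.
    after_cnn = (input_lengths - 1) // 2 + 1
    after_pool = (after_cnn - 2) // 2 + 1
    return after_cnn, after_pool

mel_lens = torch.tensor([100, 101, 30000])   # example mel-frame counts (30000 ~= 300 s at 100 Hz)
print(feat_extract_output_lengths(mel_lens)) # (tensor([50, 51, 15000]), tensor([25, 25, 7500]))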

FunAudioChatDiscreteEncoder

Bases: Module

Discrete audio encoder (speech tokenizer -> grouped embeddings).

Source code in vllm/model_executor/models/funaudiochat.py
class FunAudioChatDiscreteEncoder(nn.Module):
    """Discrete audio encoder (speech tokenizer -> grouped embeddings)."""

    def __init__(self, config: Any):
        super().__init__()
        self.padding_idx = int(config.pad_token_id)
        self.group_size = int(config.group_size)
        self.hidden_size = int(config.output_dim)
        self.continuous_features_mode = getattr(
            config, "continuous_features_mode", "add"
        )
        self.embed_tokens = nn.Embedding(
            int(config.codebook_size), self.hidden_size, self.padding_idx
        )
        self.output_matching = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.continual_output_matching = nn.Linear(
            self.hidden_size, self.hidden_size, bias=False
        )

    def forward(
        self,
        audio_ids: torch.Tensor,
        continuous_audio_features: torch.Tensor | None = None,
        continuous_audio_output_lengths: torch.Tensor | None = None,
        feature_exist_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        del continuous_audio_output_lengths

        inputs_embeds = self.embed_tokens(audio_ids)
        hidden_states = inputs_embeds.reshape(
            inputs_embeds.shape[0], -1, self.group_size * self.hidden_size
        )
        hidden_states = hidden_states.reshape(
            hidden_states.shape[0], -1, self.group_size, self.hidden_size
        ).mean(dim=2)
        hidden_states = self.output_matching(hidden_states)

        if continuous_audio_features is not None:
            continuous_audio_features = continuous_audio_features.reshape(
                continuous_audio_features.shape[0],
                -1,
                self.group_size,
                self.hidden_size,
            ).mean(dim=2)
            continuous_audio_hidden_states = self.continual_output_matching(
                continuous_audio_features
            )

            if feature_exist_mask is None:
                feature_exist_mask = torch.ones(
                    (hidden_states.shape[0],),
                    dtype=torch.bool,
                    device=hidden_states.device,
                )
            if self.continuous_features_mode == "add":
                hidden_states[feature_exist_mask] += continuous_audio_hidden_states
            else:
                hidden_states[feature_exist_mask] = continuous_audio_hidden_states

        return hidden_states

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> tuple[torch.LongTensor, torch.LongTensor]:
        output_lengths = (input_lengths + self.group_size - 1) // self.group_size
        return input_lengths, output_lengths