2
0

Cloning from Github

This commit is contained in:
2025-10-20 18:58:25 +02:00
commit 9643e6494d
50 changed files with 9257 additions and 0 deletions

View File

View File

@@ -0,0 +1,248 @@
""" VibeVoice_AcousticTokenizer model configuration"""
from typing import Dict, List, Optional, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
model_type = "vibevoice_acoustic_tokenizer"
def __init__(
self,
channels: int = 1,
corpus_normalize: float = 0.0,
causal: bool = True,
vae_dim: int = 64,
fix_std: float = 0.5,
std_dist_type: str = 'gaussian',
# common
mixer_layer: str = 'depthwise_conv',
conv_norm: str = 'none',
pad_mode: str = 'constant',
disable_last_norm: bool = True,
layernorm: str = 'RMSNorm',
layernorm_eps: float = 1e-5,
layernorm_elementwise_affine: bool = True,
conv_bias: bool = True,
layer_scale_init_value: float = 1e-6,
weight_init_value: float = 1e-2,
# encoder specific
encoder_n_filters: int = 32,
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
encoder_depths: str = "3-3-3-3-3-3-8",
# decoder specific
decoder_n_filters: int = 32,
decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
decoder_depths: Optional[str] = None,
**kwargs
):
super().__init__(**kwargs)
self.channels = channels
self.corpus_normalize = corpus_normalize
self.causal = causal
self.vae_dim = vae_dim
self.fix_std = fix_std
self.std_dist_type = std_dist_type
# common parameters
self.conv_norm = conv_norm
self.pad_mode = pad_mode
self.layernorm_eps = layernorm_eps
self.disable_last_norm = disable_last_norm
self.layernorm = layernorm
self.layernorm_elementwise_affine = layernorm_elementwise_affine
self.conv_bias = conv_bias
self.layer_scale_init_value = layer_scale_init_value
self.weight_init_value = weight_init_value
self.mixer_layer = mixer_layer
# encoder specific parameters
self.encoder_n_filters = encoder_n_filters
self.encoder_ratios = encoder_ratios
self.encoder_depths = encoder_depths
# decoder specific parameters
self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
self.decoder_n_filters = decoder_n_filters
self.decoder_depths = decoder_depths
class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
model_type = "vibevoice_semantic_tokenizer"
def __init__(
self,
channels: int = 1,
corpus_normalize: float = 0.0,
causal: bool = True,
vae_dim: int = 64,
fix_std: float = 0,
std_dist_type: str = 'none',
# common
mixer_layer: str = 'depthwise_conv',
conv_norm: str = 'none',
pad_mode: str = 'constant',
disable_last_norm: bool = True,
layernorm: str = 'RMSNorm',
layernorm_eps: float = 1e-5,
layernorm_elementwise_affine: bool = True,
conv_bias: bool = True,
layer_scale_init_value: float = 1e-6,
weight_init_value: float = 1e-2,
# encoder specific
encoder_n_filters: int = 32,
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
encoder_depths: str = "3-3-3-3-3-3-8",
**kwargs
):
super().__init__(**kwargs)
self.channels = channels
self.corpus_normalize = corpus_normalize
self.causal = causal
self.vae_dim = vae_dim
self.fix_std = fix_std
self.std_dist_type = std_dist_type
# common parameters
self.conv_norm = conv_norm
self.pad_mode = pad_mode
self.layernorm_eps = layernorm_eps
self.disable_last_norm = disable_last_norm
self.layernorm = layernorm
self.layernorm_elementwise_affine = layernorm_elementwise_affine
self.conv_bias = conv_bias
self.layer_scale_init_value = layer_scale_init_value
self.weight_init_value = weight_init_value
self.mixer_layer = mixer_layer
# encoder specific parameters
self.encoder_n_filters = encoder_n_filters
self.encoder_ratios = encoder_ratios
self.encoder_depths = encoder_depths
class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
model_type = "vibevoice_diffusion_head"
def __init__(
self,
hidden_size=768,
head_layers=4,
head_ffn_ratio=3.0,
rms_norm_eps=1e-5,
latent_size=64,
speech_vae_dim=None,
prediction_type="v_prediction",
diffusion_type="ddpm",
ddpm_num_steps=1000,
ddpm_num_inference_steps=20,
ddpm_beta_schedule="cosine",
ddpm_batch_mul=4,
**kwargs
):
self.hidden_size = hidden_size
self.head_layers = head_layers
self.head_ffn_ratio = head_ffn_ratio
self.rms_norm_eps = rms_norm_eps
self.latent_size = latent_size
self.speech_vae_dim = speech_vae_dim
self.prediction_type = prediction_type
self.diffusion_type = diffusion_type
self.ddpm_num_steps = ddpm_num_steps
self.ddpm_num_inference_steps = ddpm_num_inference_steps
self.ddpm_beta_schedule = ddpm_beta_schedule
self.ddpm_batch_mul = ddpm_batch_mul
super().__init__(**kwargs)
class VibeVoiceConfig(PretrainedConfig):
model_type = "vibevoice"
is_composition = True
sub_configs = {
"acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
"semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
"decoder_config": Qwen2Config,
"diffusion_head_config": VibeVoiceDiffusionHeadConfig,
}
# keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,
acoustic_tokenizer_config=None,
semantic_tokenizer_config=None,
decoder_config=None,
diffusion_head_config=None,
**kwargs
):
# kwargs["_attn_implementation"] = "flash_attention_2"
kwargs["_attn_implementation_autoset"] = False
if acoustic_tokenizer_config is None:
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
elif isinstance(acoustic_tokenizer_config, dict):
acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
# If an instance of the config class is provided
self.acoustic_tokenizer_config = acoustic_tokenizer_config
if semantic_tokenizer_config is None:
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
elif isinstance(semantic_tokenizer_config, dict):
semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
# If an instance of the config class is provided
self.semantic_tokenizer_config = semantic_tokenizer_config
if decoder_config is None:
self.decoder_config = self.sub_configs["decoder_config"]()
elif isinstance(decoder_config, dict):
# If a dictionary is provided, instantiate the config class with it
# self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
if decoder_config.get("model_type", '') == "qwen2":
self.decoder_config = Qwen2Config(**decoder_config)
else:
raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
elif isinstance(decoder_config, (Qwen2Config,)):
# If an instance of the config class is provided
self.decoder_config = decoder_config
if diffusion_head_config is None:
self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
elif isinstance(diffusion_head_config, dict):
diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
# If an instance of the config class is provided
self.diffusion_head_config = diffusion_head_config
# other parameters
self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
super().__init__(**kwargs)
__all__ = [
"VibeVoiceAcousticTokenizerConfig",
"VibeVoiceSemanticTokenizerConfig",
"VibeVoiceDiffusionHeadConfig",
"VibeVoiceConfig"
]

View File

@@ -0,0 +1,488 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice import VibeVoiceConfig
logger = logging.get_logger(__name__)
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
@dataclass
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
diffusion_loss: Optional[torch.FloatTensor] = None
speech_token_num: Optional[int] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
"""
Output type for VibeVoice generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences.
speech_outputs (`List[torch.FloatTensor]`, *optional*):
List of generated speech waveforms or latents for each speech segment.
"""
sequences: torch.LongTensor = None
speech_outputs: Optional[List[torch.FloatTensor]] = None
class SpeechConnector(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, output_dim)
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
self.fc2 = nn.Linear(output_dim, output_dim)
def forward(self, features, **kwargs):
x = self.fc1(features)
x = self.norm(x)
x = self.fc2(x)
return x
# @auto_docstring
class VibeVoicePreTrainedModel(PreTrainedModel):
config_class = VibeVoiceConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
if isinstance(module, VibeVoiceDiffusionHead):
module.initialize_weights()
return
# Use the language model's initializer_range if available
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
std = self.config.language_model_config.initializer_range
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
std = self.config.decoder_config.initializer_range
else:
std = 0.02 # Default value
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
# @auto_docstring
class VibeVoiceModel(VibeVoicePreTrainedModel):
def __init__(self, config):
super().__init__(config)
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
if isinstance(config.torch_dtype, str):
dtype = getattr(torch, config.torch_dtype)
else:
dtype = config.torch_dtype
else:
dtype = torch.float32
# Initialize Qwen2 model for language modeling
lm_config = config.decoder_config
self.language_model = AutoModel.from_config(lm_config)
# Initialize speech components if needed
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
# Register scaling factors as buffers - use 1D tensors for FSDP compatibility
self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
# Initialize prediction head for speech generation
self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
# Initialize noise scheduler
self.noise_scheduler = DPMSolverMultistepScheduler(
num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
prediction_type=config.diffusion_head_config.prediction_type
)
def get_input_embeddings(self):
if hasattr(self.language_model, 'embed_tokens'):
# If the language model has an embed_tokens attribute, return it
return self.language_model.embed_tokens
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
if attr.orig_name == 'embed_tokens.weight':
return getattr(self.language_model, name)
assert False, 'should not arrive here'
def set_input_embeddings(self, value):
self.language_model.embed_tokens = value
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
"""Set the speech tokenizers used for encoding and decoding speech."""
self.acoustic_tokenizer = acoustic_tokenizer
self.semantic_tokenizer = semantic_tokenizer
# Reset the encoder to evaluation mode
if self.acoustic_tokenizer is not None:
self.acoustic_tokenizer.eval()
if self.semantic_tokenizer is not None:
self.semantic_tokenizer.eval()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Forward through language model
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
if not return_dict:
return outputs
return BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
self.model = VibeVoiceModel(config)
self.vocab_size = config.decoder_config.vocab_size
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_decoder(self, decoder):
self.model.language_model = decoder
def get_decoder(self):
return self.model.language_model
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
"""
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
# The standard PreTrainedModel method will handle the tying.
# It typically does a simple parameter object assignment, which is
# CORRECT to do BEFORE FSDP wraps the model.
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
if hasattr(input_embeddings, 'weight'):
output_embeddings.weight = input_embeddings.weight
else:
# maybe returned input_embeddings a tensor directly
output_embeddings.weight = input_embeddings
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
"constant",
0,
)
print("✅ Tied input and output embeddings using standard assignment.")
else:
print(" tie_word_embeddings is False, not tying weights.")
# Also, ensure set_output_embeddings is safe, though your implementation looks okay.
# The key is to avoid calling it after accelerator.prepare().
def set_output_embeddings(self, new_embeddings):
# Your current implementation using data.copy_ is good practice,
# but the best way is to not call this after prepare().
self.lm_head = new_embeddings
def forward_speech_features(
self,
speech_tensors=None,
speech_masks=None,
speech_type="audio",
return_unmask=False
):
if speech_tensors is None:
# Use config to get vae_dim instead of non-existent self.args
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
connect_features = self.model.acoustic_connector(audio_features)
return audio_features, connect_features
else:
with torch.no_grad():
if speech_type == "audio":
with torch.no_grad():
frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
elif speech_type == "vae":
# Use config to get vae_dim instead of non-existent self.args
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
# gaussian sample from the speech_mode
batch_size = speech_mode.size(0)
value = self.model.acoustic_tokenizer.fix_std / 0.8
std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
std = std.view(-1, *[1] * (speech_mode.dim() - 1))
audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
else:
raise NotImplementedError(f"Speech type {speech_type} not implemented")
if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
bias_factor = -audio_tokens[speech_masks].flatten().mean()
# Only use distributed operations if the process group is initialized
if dist.is_available() and dist.is_initialized():
dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
world_size = dist.get_world_size()
self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
self.model.speech_bias_factor.copy_(bias_factor / world_size)
print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
else:
# Single process case
self.model.speech_scaling_factor.copy_(scaling_factor)
self.model.speech_bias_factor.copy_(bias_factor)
print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
connect_features = self.model.acoustic_connector(audio_features)
if return_unmask:
return audio_features, connect_features
return audio_features[speech_masks], connect_features[speech_masks]
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
# New arguments for speech processing and loss calculation
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speeches_loss_input: Optional[torch.FloatTensor] = None,
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
acoustic_input_mask: Optional[torch.BoolTensor] = None,
acoustic_loss_mask: Optional[torch.BoolTensor] = None,
ddpm_batch_mul: int = 1,
**kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
x = self.get_input_embeddings()(input_ids)
semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
if speeches_loss_input is not None:
# only part audio need diffuse
speech_all_features, speech_all_connect_features = self.forward_speech_features(
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
speech_masks=speech_masks,
speech_type=kwargs.get("speech_type", "audio"),
return_unmask=True
)
if speech_tensors is not None:
if semantic_speech_all_connect_features is not None:
x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
else:
x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only part audio need diffuse
speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
else:
speech_features, speech_connect_features = self.forward_speech_features(
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
speech_masks=speech_masks,
speech_type=kwargs.get("speech_type", "audio"),
)
if speech_tensors is not None:
x[acoustic_input_mask] = speech_connect_features
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=x,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=False,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
# logits = logits.float()
loss = None
if labels is not None:
# The custom CE loss with masking is calculated in the training script.
# We leave the standard loss calculation here as None.
pass
# --- Diffusion Loss Calculation ---
diffusion_loss = None
# This block is executed only if we are in a context that involves speech.
if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
condition_features = hidden_states[acoustic_loss_mask]
speech_len, latent_size = speech_features.shape
noise = torch.randn(
(speech_len * ddpm_batch_mul, latent_size),
device=hidden_states.device,
dtype=hidden_states.dtype
)
timesteps = torch.multinomial(
torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
speech_len * ddpm_batch_mul,
replacement=True,
).to(hidden_states.device)
speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
noisy_speech_features = self.model.noise_scheduler.add_noise(
speech_features_repeated, noise, timesteps
)
model_output = self.model.prediction_head(
noisy_speech_features,
timesteps.type_as(x),
condition_features_repeated
)
prediction_type = self.config.diffusion_head_config.prediction_type
if prediction_type == "epsilon":
target_for_loss = noise
elif prediction_type == "v_prediction":
target_for_loss = self.model.noise_scheduler.get_velocity(
speech_features_repeated, noise, timesteps
)
else:
raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
if latent_size > 0 and ddpm_batch_mul > 0:
diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
else:
diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
else:
# Dummy loss for DDP to work when there are no speech samples in a batch,
# but we are in a speech context.
diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
# --- End Diffusion Loss Calculation ---
if not return_dict:
output = (logits, speech_len) + outputs.to_tuple()[1:]
return (loss, diffusion_loss) + output
return VibeVoiceCausalLMOutputWithPast(
loss=loss,
diffusion_loss=diffusion_loss,
speech_token_num=speech_len if speech_tensors is not None else 0,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
__all__ = [
"VibeVoiceModel",
"VibeVoicePreTrainedModel",
"VibeVoiceForConditionalGeneration",
"VibeVoiceCausalLMOutputWithPast",
"VibeVoiceGenerationOutput",
]

View File

@@ -0,0 +1,715 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice import VibeVoiceConfig
from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
from .streamer import AudioStreamer, AsyncAudioStreamer
logger = logging.get_logger(__name__)
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
@dataclass
class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
logits: Optional[torch.FloatTensor] = None
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
"""
Output type for VibeVoice generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences.
speech_outputs (`List[torch.FloatTensor]`, *optional*):
List of generated speech waveforms or latents for each speech segment.
"""
sequences: torch.LongTensor = None
speech_outputs: Optional[List[torch.FloatTensor]] = None
reach_max_step_sample: Optional[torch.BoolTensor] = None
class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
"""Constrains token generation to only valid tokens during speech generation."""
def __init__(self, valid_token_ids: List[int], device: torch.device = None):
self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Create a mask for valid tokens
mask = torch.full_like(scores, float('-inf'))
mask[:, self.valid_token_ids] = 0
# Apply mask to scores
scores = scores + mask
return scores
class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
# Initialize the base model
self.model = VibeVoiceModel(config)
# LM head for text generation
self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
# inference configuration
self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
# Initialize weights and apply final processing
self.post_init()
@property
def noise_scheduler(self):
return self.model.noise_scheduler
@property
def prediction_head(self):
return self.model.prediction_head
@property
def speech_scaling_factor(self):
return self.model.speech_scaling_factor
@property
def speech_bias_factor(self):
return self.model.speech_bias_factor
@property
def acoustic_tokenizer(self):
return self.model.acoustic_tokenizer
@property
def semantic_tokenizer(self):
return self.model.semantic_tokenizer
@property
def acoustic_connector(self):
return self.model.acoustic_connector
@property
def semantic_connector(self):
return self.model.semantic_connector
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
"""
# Tie lm_head.weight to language_model.embed_tokens.weight
if not getattr(self.config, 'tie_word_embeddings', False):
return
if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
self.lm_head.weight = self.model.language_model.embed_tokens.weight
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
"""Set the speech tokenizers used for encoding and decoding speech."""
self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
def set_ddpm_inference_steps(self, num_steps=None):
self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
"""Process speech inputs through tokenizers and connectors."""
with torch.no_grad():
if speech_type == "audio":
# Encode audio to acoustic latents
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
# Apply scaling and bias
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
# Connect to language model space
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
return acoustic_features, acoustic_connected
elif speech_type == "pt":
encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
# Apply scaling and bias
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
# Connect to language model space
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
return acoustic_features, acoustic_connected
else:
raise NotImplementedError(f"Speech type {speech_type} not implemented")
# @can_return_tuple
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
logits_to_keep: Union[int, slice] = 0,
**kwargs,
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
speech_tensors (`torch.FloatTensor`, *optional*):
Input speech waveforms for voice cloning or speech understanding.
speech_masks (`torch.BoolTensor`, *optional*):
Masks indicating valid speech frames.
speech_input_mask (`torch.BoolTensor`, *optional*):
Positions in the input sequence where speech embeddings should be inserted.
Returns:
`VibeVoiceCausalLMOutputWithPast` or tuple
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get embeddings
if inputs_embeds is None:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
# Process speech inputs if provided
if speech_tensors is not None and speech_masks is not None:
acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks)
if speech_input_mask is not None:
inputs_embeds[speech_input_mask] = speech_embeds
outputs = self.model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
raise NotImplementedError("Loss computation is not implemented in this version.")
return VibeVoiceCausalLMOutputWithPast(
logits=logits,
past_key_values=outputs.past_key_values,
last_hidden_state=hidden_states,
attentions=outputs.attentions,
)
def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
if generation_config is None:
generation_config = GenerationConfig(
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id = tokenizer.pad_token_id
)
else:
generation_config = GenerationConfig(
**generation_config,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id = tokenizer.pad_token_id
)
generation_config, model_kwargs = self._prepare_generation_config(
generation_config,
True,
speech_start_id=tokenizer.speech_start_id,
speech_end_id=tokenizer.speech_end_id,
speech_diffusion_id=tokenizer.speech_diffusion_id,
**kwargs
)
generation_config.speech_start_id = tokenizer.speech_start_id
generation_config.speech_end_id = tokenizer.speech_end_id
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
batch_size = inputs_tensor.shape[0]
device = self.device
self._prepare_special_tokens(generation_config, True, device=device)
generation_config.use_cache = True
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = inputs_tensor.to(self.device)
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
max_cache_length = generation_config.max_length - 1
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
for k, v in model_kwargs.items():
if isinstance(v, torch.Tensor):
model_kwargs[k] = v.to(device=device)
if return_processors:
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=LogitsProcessorList(),
device=inputs_tensor.device,
model_kwargs=model_kwargs,
)
stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
else:
return generation_config, model_kwargs, input_ids
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
**kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
"""
Generates sequences of token ids and optionally speech outputs.
Args:
All standard generation arguments from GenerationMixin
negative_prompt_ids: Negative prompt for CFG in speech generation
negative_prompt_attention_mask: Attention mask for negative prompt
speech_tensors: Input speech for voice cloning
speech_masks: Masks for speech tensors
speech_input_mask: Positions to insert speech embeddings
return_speech: Whether to decode and return speech outputs
cfg_scale: CFG scale for speech generation
stop_check_fn: Optional callable that returns True if generation should stop
Returns:
Generated token sequences and optionally speech outputs
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
parsed_scripts = kwargs.pop("parsed_scripts", None)
all_speakers_list = kwargs.pop("all_speakers_list", None)
max_length_times = kwargs.pop("max_length_times", 2)
if kwargs.get('max_new_tokens', None) is None:
kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
negative_kwargs = {
'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
'max_new_tokens': kwargs.get('max_new_tokens', 100)
}
negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
acoustic_cache = VibeVoiceTokenizerStreamingCache()
semantic_cache = VibeVoiceTokenizerStreamingCache()
batch_size = input_ids.shape[0]
device = input_ids.device
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
is_prefill = True
inputs_embeds = None
verbose = kwargs.get("verbose", False)
# Initialize audio chunks storage for each sample
audio_chunks = [[] for _ in range(batch_size)]
initial_length = input_ids.shape[-1]
initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
# Define all valid tokens that can be generated
valid_tokens = [
generation_config.speech_start_id,
generation_config.speech_end_id,
generation_config.speech_diffusion_id,
generation_config.eos_token_id
]
# Add bos_token_id if it exists
if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
valid_tokens.append(generation_config.bos_token_id)
# Add custom processor to constrain token generation
token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(token_constraint_processor)
max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
# Create progress iterator if verbose
if kwargs.get("show_progress_bar", True):
progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
else:
progress_bar = range(max_steps)
for step in progress_bar:
# Check for external stop signal
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step + 1}")
# End the audio streamer if it exists
if audio_streamer is not None:
audio_streamer.end()
break
# Check if audio_streamer has been ended (stopped externally)
if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
if any(audio_streamer.finished_flags):
if verbose:
print(f"Audio generation stopped externally at step {step + 1}")
break
if finished_tags.all():
if hasattr(progress_bar, 'set_description'):
progress_bar.set_description("Generation complete")
break
if input_ids.shape[-1] >= generation_config.max_length:
print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
# Update progress bar description with active samples
if hasattr(progress_bar, 'set_description'):
active_samples = (~finished_tags).sum().item()
progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
if is_prefill:
# we process the speech inputs only during the first generation step
prefill_inputs = {
"speech_tensors": speech_tensors.to(device=device),
"speech_masks": speech_masks.to(device),
"speech_input_mask": speech_input_mask.to(device),
}
is_prefill = False
else:
_ = model_inputs.pop('inputs_embeds', None)
prefill_inputs = {'inputs_embeds': inputs_embeds}
# Forward pass through the model
outputs = self(
**model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False,
)
# Get logits and apply logits processor
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
# next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
next_token_scores = logits_processor(input_ids, next_token_logits)
# token selection
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
next_tokens[finished_tags] = generation_config.eos_token_id
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if not kwargs.get('refresh_negative', True):
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
# Forward negative pass through the model
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
negative_model_inputs['inputs_embeds'] = inputs_embeds
negative_model_inputs['input_ids'] = None
negative_outputs = self(
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
)
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
# reached end of generation
if (next_tokens == generation_config.eos_token_id).any():
eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
# Only print for samples that are newly finished (not already marked as finished)
new_eos_indices = eos_indices[~finished_tags[eos_indices]]
if new_eos_indices.numel() > 0:
finished_tags[new_eos_indices] = True
if verbose:
print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
if audio_streamer is not None:
audio_streamer.end(new_eos_indices)
# Check if any sample reached its maximum generation length
max_length_reached = step >= max_step_per_sample
new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
if new_max_length_indices.numel() > 0:
finished_tags[new_max_length_indices] = True
reach_max_step_sample[new_max_length_indices] = True
if verbose:
print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
if audio_streamer is not None:
audio_streamer.end(new_max_length_indices)
# speech_end
diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
if diffusion_end_indices.numel() > 0:
# Clear tokenizer caches for samples that reached speech end
acoustic_cache.set_to_zero(diffusion_end_indices)
semantic_cache.set_to_zero(diffusion_end_indices)
# speech_begin
diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
# update attention mask
for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
negative_model_kwargs['attention_mask'][sample_idx, :] = 0
negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
# update past key values
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
negative_model_kwargs['past_key_values'].value_cache)):
# Process each non-diffusion sample
for sample_idx in diffusion_start_indices.tolist():
# Shift cache for this sample
k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
# update negative_input_ids
for sample_idx in diffusion_start_indices.tolist():
negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
# Prepare inputs_embeds for next iteration
# Initialize with default embeddings for all tokens
next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size]
# forward diffusion
# Diffusion indices are those that are not finished and not special tokens
diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
if diffusion_indices.numel() > 0:
if kwargs.get('refresh_negative', True):
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
# Forward negative pass through the model
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
negative_model_inputs['inputs_embeds'] = inputs_embeds
negative_model_inputs['input_ids'] = None
negative_outputs = self(
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
)
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
# correct the non-diffusion indices
# we forward all samples' negative outputs even if
# they are not in diffusion mode to keep the cache consistent
# So we need to correct the kv cache of non-diffusion samples
non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
if non_diffusion_mask.any():
non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
start_indices = correct_cnt[non_diffusion_indices]
# 1. Update attention_mask - need to handle each sample separately
seq_len = negative_model_kwargs['attention_mask'].shape[1]
for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
# Shift the attention mask for this sample
if start_idx + 1 < seq_len - 1:
negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
# 2. Update past_key_values
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
negative_model_kwargs['past_key_values'].value_cache)):
# Process each non-diffusion sample
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
if start_idx + 1 < k_cache.shape[2] - 1:
# Shift cache for this sample
k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
# 3. Update negative_input_ids
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
if start_idx + 1 < negative_input_ids.shape[1] - 1:
negative_input_ids[sample_idx, start_idx+1:] = \
negative_input_ids[sample_idx, start_idx:-1].clone()
correct_cnt[non_diffusion_indices] += 1
positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
speech_latent = self.sample_speech_tokens(
positive_condition,
negative_condition,
cfg_scale=cfg_scale,
).unsqueeze(1)
# Decode acoustic latent to audio using acoustic streaming cache
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache, # Use acoustic-specific cache
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
use_cache=True,
debug=False
)
# Store audio chunks for each sample
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
# Only append audio chunk if the sample is not finished
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
# Add streaming support here
if audio_streamer is not None:
# Stream the audio chunks immediately
audio_streamer.put(audio_chunk, diffusion_indices)
# Encode audio to semantic features using semantic streaming cache
semantic_features = self.model.semantic_tokenizer.encode(
audio_chunk,
cache=semantic_cache, # Use semantic-specific cache
sample_indices=diffusion_indices,
use_cache=True,
debug=False
).mean # semantic tokenizer has no VAE.
# Combine acoustic and semantic features for next input
acoustic_embed = self.model.acoustic_connector(speech_latent)
semantic_embed = self.model.semantic_connector(semantic_features)
diffusion_embeds = acoustic_embed + semantic_embed
# Update embeddings for diffusion indices
next_inputs_embeds[diffusion_indices] = diffusion_embeds
# Set inputs_embeds for next iteration
inputs_embeds = next_inputs_embeds
if audio_streamer is not None:
audio_streamer.end()
# Concatenate audio chunks for each sample
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
# Concatenate all chunks along the time dimension (assumed to be the last dimension)
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
# If no audio was generated for this sample, append None
final_audio_outputs.append(None)
return VibeVoiceGenerationOutput(
sequences=input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
for t in self.model.noise_scheduler.timesteps:
half = speech[: len(speech) // 2]
combined = torch.cat([half, half], dim=0)
eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
return speech[: len(speech) // 2]
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
__all__ = [
"VibeVoiceForConditionalGenerationInference",
]

View File

@@ -0,0 +1,287 @@
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.activations import ACT2FN
from transformers.utils import logging
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
super().__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output
def extra_repr(self) -> str:
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
def modulate(x, shift, scale):
"""Apply modulation to input tensor."""
return x * (1 + scale) + shift
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
Args:
hidden_size (`int`): Size of the output embedding
frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=False),
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(hidden_size, hidden_size, bias=False),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
These may be fractional.
dim (`int`): The dimension of the output.
max_period (`int`, optional): Controls the minimum frequency of the embeddings.
Returns:
`torch.Tensor`: An [N, D] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding.to(t.dtype)
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class FeedForwardNetwork(nn.Module):
"""
Standard feed-forward network with SwiGLU activation.
Args:
embed_dim (`int`): Input dimension
ffn_dim (`int`): Hidden dimension
"""
def __init__(
self,
embed_dim,
ffn_dim,
):
super().__init__()
self.embed_dim = embed_dim
self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
def forward(self, x):
gate = self.gate_proj(x)
up = self.up_proj(x)
# SwiGLU activation
# gate = F.silu(gate)
gate = self.act_fn(gate)
return self.down_proj(gate * up)
class HeadLayer(nn.Module):
"""
A layer in the diffusion head.
Args:
embed_dim (`int`): Input dimension
ffn_dim (`int`): Hidden dimension
cond_dim (`int`): Condition embedding dimension
norm_eps (`float`, optional): Epsilon for normalization
"""
def __init__(
self,
embed_dim,
ffn_dim,
cond_dim,
norm_eps=1e-5,
):
super().__init__()
self.embed_dim = embed_dim
self.cond_dim = cond_dim
self.ffn_dim = ffn_dim
self.ffn = FeedForwardNetwork(
self.embed_dim,
self.ffn_dim,
)
self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
)
def forward(self, x, c):
shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
return x
class FinalLayer(nn.Module):
"""
Final layer in the diffusion head.
Args:
hidden_size (`int`): Input dimension
output_size (`int`): Output dimension
cond_size (`int`): Condition embedding dimension
norm_eps (`float`, optional): Epsilon for normalization
"""
def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
super().__init__()
self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
self.linear = nn.Linear(hidden_size, output_size, bias=False)
self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(cond_size, 2 * hidden_size, bias=False)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class VibeVoiceDiffusionHead(PreTrainedModel):
"""
Diffusion head model for vibevoice.
Args:
config (`VibeVoiceDiffusionHeadConfig`): Model configuration
latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
"""
config_class = VibeVoiceDiffusionHeadConfig
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
config,
):
super().__init__(config)
self.config = config
self.cond_dim = config.hidden_size
latent_size = config.latent_size
self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
self.t_embedder = TimestepEmbedder(self.cond_dim)
ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
# Create the intermediate layers
self.layers = nn.ModuleList([
HeadLayer(
embed_dim=config.hidden_size,
ffn_dim=ffn_dim,
cond_dim=self.cond_dim,
norm_eps=config.rms_norm_eps
)
for _ in range(config.head_layers)
])
# Final layer for output
self.final_layer = FinalLayer(
hidden_size=config.hidden_size,
output_size=latent_size,
cond_size=self.cond_dim,
norm_eps=config.rms_norm_eps
)
self.initialize_weights()
def initialize_weights(self):
"""Initialize the weights of the model."""
# Initialize timestep embedder
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers
for layer in self.layers:
nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
# Zero-out output layers
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
def forward(
self,
noisy_images,
timesteps,
condition,
):
"""
Forward pass of the prediction head.
Args:
noisy_images (`torch.Tensor`): Noisy images/latents to denoise
timesteps (`torch.Tensor`): Timesteps for diffusion
condition (`torch.Tensor`): Conditioning information
Returns:
`torch.Tensor`: The predicted noise/velocity
"""
x = self.noisy_images_proj(noisy_images)
t = self.t_embedder(timesteps)
condition = self.cond_proj(condition)
c = condition + t
for layer in self.layers:
x = layer(x, c)
x = self.final_layer(x, c)
return x
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
__all__ = [
"VibeVoiceDiffusionHead",
]

View File

@@ -0,0 +1,214 @@
"""Tokenization classes for vibevoice."""
from typing import List, Optional, Union
from transformers.utils import logging
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
logger = logging.get_logger(__name__)
class VibeVoiceTextTokenizer(Qwen2Tokenizer):
"""
Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token.
bos_token (`str`, *optional*):
The beginning of sequence token. Not used for vibevoice.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding.
add_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to add special tokens when encoding.
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
add_prefix_space=False,
add_special_tokens=True,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
add_special_tokens=add_special_tokens,
**kwargs,
)
# Add VibeVoice-specific special tokens
self._add_vibevoice_special_tokens()
def _add_vibevoice_special_tokens(self):
"""Add VibeVoice-specific special tokens."""
special_tokens = {
"additional_special_tokens": [
"<|vision_start|>", # Speech start (reusing vision tokens)
"<|vision_end|>", # Speech end
"<|vision_pad|>", # Speech diffusion pad
]
}
num_added = self.add_special_tokens(special_tokens)
# Cache special token IDs
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
return num_added
@property
def eos_id(self) -> int:
"""Id of the end of sequence token."""
return self._eos_id
@property
def speech_start_id(self) -> int:
"""Id of the speech start token."""
return self._speech_start_id
@property
def speech_end_id(self) -> int:
"""Id of the speech end token."""
return self._speech_end_id
@property
def speech_diffusion_id(self) -> int:
"""Id of the speech diffusion token."""
return self._speech_diffusion_id
@property
def pad_id(self) -> int:
"""Id used for padding (returns -100 for loss masking)."""
return -100
class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
"""
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
Based on the Qwen2 tokenizer with additional special tokens for speech.
Args:
vocab_file (`str`, *optional*):
Path to the vocabulary file.
merges_file (`str`, *optional*):
Path to the merges file.
tokenizer_file (`str`, *optional*):
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token.
bos_token (`str`, *optional*):
The beginning of sequence token. Not used for vibevoice.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding.
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file=None,
merges_file=None,
tokenizer_file=None,
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
add_prefix_space=False,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
# Add VibeVoice-specific special tokens
self._add_vibevoice_special_tokens()
def _add_vibevoice_special_tokens(self):
"""Add VibeVoice-specific special tokens."""
special_tokens = {
"additional_special_tokens": [
"<|vision_start|>", # Speech start (reusing vision tokens)
"<|vision_end|>", # Speech end
"<|vision_pad|>", # Speech diffusion pad
]
}
num_added = self.add_special_tokens(special_tokens)
# Cache special token IDs
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
# self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
self._eos_id = self.eos_token_id # qwen2 / qwen3
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
return num_added
@property
def eos_id(self) -> int:
"""Id of the end of sequence token."""
return self._eos_id
@property
def speech_start_id(self) -> int:
"""Id of the speech start token."""
return self._speech_start_id
@property
def speech_end_id(self) -> int:
"""Id of the speech end token."""
return self._speech_end_id
@property
def speech_diffusion_id(self) -> int:
"""Id of the speech diffusion token."""
return self._speech_diffusion_id
@property
def pad_id(self) -> int:
"""Id used for padding (returns -100 for loss masking)."""
return self._pad_id
__all__ = [
"VibeVoiceTextTokenizer",
"VibeVoiceTextTokenizerFast",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,264 @@
from __future__ import annotations
import torch
import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional
from transformers.generation import BaseStreamer
class AudioStreamer(BaseStreamer):
"""
Audio streamer that stores audio chunks in queues for each sample in the batch.
This allows streaming audio generation for multiple samples simultaneously.
Parameters:
batch_size (`int`):
The batch size for generation
stop_signal (`any`, *optional*):
The signal to put in the queue when generation ends. Defaults to None.
timeout (`float`, *optional*):
The timeout for the audio queue. If `None`, the queue will block indefinitely.
"""
def __init__(
self,
batch_size: int,
stop_signal: Optional[any] = None,
timeout: Optional[float] = None,
):
self.batch_size = batch_size
self.stop_signal = stop_signal
self.timeout = timeout
# Create a queue for each sample in the batch
self.audio_queues = [Queue() for _ in range(batch_size)]
self.finished_flags = [False for _ in range(batch_size)]
self.sample_indices_map = {} # Maps from sample index to queue index
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
"""
Receives audio chunks and puts them in the appropriate queues.
Args:
audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
sample_indices: Tensor indicating which samples these chunks belong to
"""
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
# Convert to numpy or keep as tensor based on preference
audio_chunk = audio_chunks[i].detach().cpu()
self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
def end(self, sample_indices: Optional[torch.Tensor] = None):
"""
Signals the end of generation for specified samples or all samples.
Args:
sample_indices: Optional tensor of sample indices to end. If None, ends all.
"""
if sample_indices is None:
# End all samples
for idx in range(self.batch_size):
if not self.finished_flags[idx]:
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
self.finished_flags[idx] = True
else:
# End specific samples
for sample_idx in sample_indices:
idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
if idx < self.batch_size and not self.finished_flags[idx]:
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
self.finished_flags[idx] = True
def __iter__(self):
"""Returns an iterator over the batch of audio streams."""
return AudioBatchIterator(self)
def get_stream(self, sample_idx: int):
"""Get the audio stream for a specific sample."""
if sample_idx >= self.batch_size:
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
return AudioSampleIterator(self, sample_idx)
class AudioSampleIterator:
"""Iterator for a single audio stream from the batch."""
def __init__(self, streamer: AudioStreamer, sample_idx: int):
self.streamer = streamer
self.sample_idx = sample_idx
def __iter__(self):
return self
def __next__(self):
value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
if value == self.streamer.stop_signal:
raise StopIteration()
return value
class AudioBatchIterator:
"""Iterator that yields audio chunks for all samples in the batch."""
def __init__(self, streamer: AudioStreamer):
self.streamer = streamer
self.active_samples = set(range(streamer.batch_size))
def __iter__(self):
return self
def __next__(self):
if not self.active_samples:
raise StopIteration()
batch_chunks = {}
samples_to_remove = set()
# Try to get chunks from all active samples
for idx in self.active_samples:
try:
value = self.streamer.audio_queues[idx].get(block=False)
if value == self.streamer.stop_signal:
samples_to_remove.add(idx)
else:
batch_chunks[idx] = value
except:
# Queue is empty for this sample, skip it this iteration
pass
# Remove finished samples
self.active_samples -= samples_to_remove
if batch_chunks:
return batch_chunks
elif self.active_samples:
# If no chunks were ready but we still have active samples,
# wait a bit and try again
import time
time.sleep(0.01)
return self.__next__()
else:
raise StopIteration()
class AsyncAudioStreamer(AudioStreamer):
"""
Async version of AudioStreamer for use in async contexts.
"""
def __init__(
self,
batch_size: int,
stop_signal: Optional[any] = None,
timeout: Optional[float] = None,
):
super().__init__(batch_size, stop_signal, timeout)
# Replace regular queues with async queues
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
self.loop = asyncio.get_running_loop()
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
"""Put audio chunks in the appropriate async queues."""
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
audio_chunk = audio_chunks[i].detach().cpu()
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, audio_chunk
)
def end(self, sample_indices: Optional[torch.Tensor] = None):
"""Signal the end of generation for specified samples."""
if sample_indices is None:
indices_to_end = range(self.batch_size)
else:
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, self.stop_signal
)
self.finished_flags[idx] = True
async def get_stream(self, sample_idx: int):
"""Get async iterator for a specific sample's audio stream."""
if sample_idx >= self.batch_size:
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
while True:
value = await self.audio_queues[sample_idx].get()
if value == self.stop_signal:
break
yield value
def __aiter__(self):
"""Returns an async iterator over all audio streams."""
return AsyncAudioBatchIterator(self)
class AsyncAudioBatchIterator:
"""Async iterator for batch audio streaming."""
def __init__(self, streamer: AsyncAudioStreamer):
self.streamer = streamer
self.active_samples = set(range(streamer.batch_size))
def __aiter__(self):
return self
async def __anext__(self):
if not self.active_samples:
raise StopAsyncIteration()
batch_chunks = {}
samples_to_remove = set()
# Create tasks for all active samples
tasks = {
idx: asyncio.create_task(self._get_chunk(idx))
for idx in self.active_samples
}
# Wait for at least one chunk to be ready
done, pending = await asyncio.wait(
tasks.values(),
return_when=asyncio.FIRST_COMPLETED,
timeout=self.streamer.timeout
)
# Cancel pending tasks
for task in pending:
task.cancel()
# Process completed tasks
for idx, task in tasks.items():
if task in done:
try:
value = await task
if value == self.streamer.stop_signal:
samples_to_remove.add(idx)
else:
batch_chunks[idx] = value
except asyncio.CancelledError:
pass
self.active_samples -= samples_to_remove
if batch_chunks:
return batch_chunks
elif self.active_samples:
# Try again if we still have active samples
return await self.__anext__()
else:
raise StopAsyncIteration()
async def _get_chunk(self, idx):
"""Helper to get a chunk from a specific queue."""
return await self.streamer.audio_queues[idx].get()