Cloning from Github
This commit is contained in:
0
vibevoice/modular/__init__.py
Normal file
0
vibevoice/modular/__init__.py
Normal file
248
vibevoice/modular/configuration_vibevoice.py
Normal file
248
vibevoice/modular/configuration_vibevoice.py
Normal file
@@ -0,0 +1,248 @@
|
||||
""" VibeVoice_AcousticTokenizer model configuration"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
    """Configuration for the VibeVoice acoustic tokenizer.

    Each constructor argument is stored as an attribute of the same name.
    Keyword arguments not listed in the signature are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "vibevoice_acoustic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = 'gaussian',
        # common
        mixer_layer: str = 'depthwise_conv',
        conv_norm: str = 'none',
        pad_mode: str = 'constant',
        disable_last_norm: bool = True,
        layernorm: str = 'RMSNorm',
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        # decoder specific
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
        decoder_depths: Optional[str] = None,
        **kwargs
    ):
        """Store the tokenizer hyper-parameters on the config instance.

        Args:
            encoder_ratios: Stored as ``self.encoder_ratios``. List values are
                copied so the instance never aliases the shared default list.
            decoder_ratios: Stored as ``self.decoder_ratios``. When ``None``,
                the encoder ratios are used instead (as an independent copy).
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        All remaining arguments are stored unchanged under their own name.
        """
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # common parameters
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer

        # The default for `encoder_ratios` is a single list object shared by
        # every call. Copy list inputs so that editing one config in place can
        # never change the default or another config.
        if decoder_ratios is None:
            decoder_ratios = encoder_ratios
        if isinstance(encoder_ratios, list):
            encoder_ratios = list(encoder_ratios)
        if isinstance(decoder_ratios, list):
            decoder_ratios = list(decoder_ratios)

        # encoder specific parameters
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths

        # decoder specific parameters
        self.decoder_ratios = decoder_ratios
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths
|
||||
|
||||
|
||||
class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
    """Configuration for the VibeVoice semantic tokenizer.

    Each constructor argument is stored as an attribute of the same name.
    Keyword arguments not listed in the signature are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "vibevoice_semantic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = 'none',
        # common
        mixer_layer: str = 'depthwise_conv',
        conv_norm: str = 'none',
        pad_mode: str = 'constant',
        disable_last_norm: bool = True,
        layernorm: str = 'RMSNorm',
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs
    ):
        """Store the tokenizer hyper-parameters on the config instance.

        Args:
            encoder_ratios: Stored as ``self.encoder_ratios``. List values are
                copied so the instance never aliases the shared default list.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        All remaining arguments are stored unchanged under their own name.
        """
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # common parameters
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer

        # The default for `encoder_ratios` is a single list object shared by
        # every call. Copy list inputs so that editing one config in place can
        # never change the default or another config.
        if isinstance(encoder_ratios, list):
            encoder_ratios = list(encoder_ratios)

        # encoder specific parameters
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths
|
||||
|
||||
|
||||
class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
    """Configuration for the VibeVoice diffusion prediction head.

    Each constructor argument is stored unchanged as an attribute of the same
    name; remaining keyword arguments go to ``PretrainedConfig.__init__``.
    """

    model_type = "vibevoice_diffusion_head"

    def __init__(
        self,
        hidden_size=768,
        head_layers=4,
        head_ffn_ratio=3.0,
        rms_norm_eps=1e-5,
        latent_size=64,
        speech_vae_dim=None,
        prediction_type="v_prediction",
        diffusion_type="ddpm",
        ddpm_num_steps=1000,
        ddpm_num_inference_steps=20,
        ddpm_beta_schedule="cosine",
        ddpm_batch_mul=4,
        **kwargs
    ):
        # Head architecture.
        self.hidden_size = hidden_size
        self.head_layers = head_layers
        self.head_ffn_ratio = head_ffn_ratio
        self.rms_norm_eps = rms_norm_eps
        self.latent_size = latent_size
        self.speech_vae_dim = speech_vae_dim

        # Diffusion process settings.
        self.prediction_type = prediction_type
        self.diffusion_type = diffusion_type
        self.ddpm_num_steps = ddpm_num_steps
        self.ddpm_num_inference_steps = ddpm_num_inference_steps
        self.ddpm_beta_schedule = ddpm_beta_schedule
        self.ddpm_batch_mul = ddpm_batch_mul

        # The base initialiser runs last, after the attributes above are set.
        super().__init__(**kwargs)
|
||||
|
||||
class VibeVoiceConfig(PretrainedConfig):
    """Composite configuration for the full VibeVoice model.

    Bundles four sub-configurations: the acoustic tokenizer, the semantic
    tokenizer, the Qwen2 language-model decoder and the diffusion head.
    Each may be given as ``None`` (use the sub-config's defaults), as a
    ``dict`` of constructor arguments, or as a ready-made config instance.
    """

    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
        "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": VibeVoiceDiffusionHeadConfig,
    }
    # keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs
    ):
        """Build the sub-configurations and initialise the base config.

        Args:
            acoustic_tokenizer_config: ``None``, a ``dict`` or a
                ``VibeVoiceAcousticTokenizerConfig``.
            semantic_tokenizer_config: ``None``, a ``dict`` or a
                ``VibeVoiceSemanticTokenizerConfig``.
            decoder_config: ``None``, a ``dict`` whose ``model_type`` is
                ``"qwen2"``, or a ``Qwen2Config``.
            diffusion_head_config: ``None``, a ``dict`` or a
                ``VibeVoiceDiffusionHeadConfig``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``decoder_config`` is a dict with a ``model_type``
                other than ``"qwen2"``.
            TypeError: If any argument is of an unsupported type.
        """
        # kwargs["_attn_implementation"] = "flash_attention_2"
        kwargs["_attn_implementation_autoset"] = False

        if acoustic_tokenizer_config is None:
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
        elif isinstance(acoustic_tokenizer_config, dict):
            # Work on a copy so the caller's dict is not modified.
            acoustic_kwargs = dict(acoustic_tokenizer_config)
            acoustic_kwargs["model_type"] = "vibevoice_acoustic_tokenizer"
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_kwargs)
        elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
            # If an instance of the config class is provided
            self.acoustic_tokenizer_config = acoustic_tokenizer_config
        else:
            raise TypeError(
                "acoustic_tokenizer_config must be None, a dict or a "
                f"VibeVoiceAcousticTokenizerConfig, got {type(acoustic_tokenizer_config).__name__}"
            )

        if semantic_tokenizer_config is None:
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
        elif isinstance(semantic_tokenizer_config, dict):
            # Work on a copy so the caller's dict is not modified.
            semantic_kwargs = dict(semantic_tokenizer_config)
            semantic_kwargs["model_type"] = "vibevoice_semantic_tokenizer"
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_kwargs)
        elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
            # If an instance of the config class is provided
            self.semantic_tokenizer_config = semantic_tokenizer_config
        else:
            raise TypeError(
                "semantic_tokenizer_config must be None, a dict or a "
                f"VibeVoiceSemanticTokenizerConfig, got {type(semantic_tokenizer_config).__name__}"
            )

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            # Only Qwen2 decoders are supported; the dict must say so itself.
            if decoder_config.get("model_type", '') == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
        elif isinstance(decoder_config, (Qwen2Config,)):
            # If an instance of the config class is provided
            self.decoder_config = decoder_config
        else:
            raise TypeError(
                "decoder_config must be None, a dict or a Qwen2Config, "
                f"got {type(decoder_config).__name__}"
            )

        if diffusion_head_config is None:
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
        elif isinstance(diffusion_head_config, dict):
            # Work on a copy so the caller's dict is not modified.
            diffusion_kwargs = dict(diffusion_head_config)
            diffusion_kwargs["model_type"] = "vibevoice_diffusion_head"
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_kwargs)
        elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
            # If an instance of the config class is provided
            self.diffusion_head_config = diffusion_head_config
        else:
            raise TypeError(
                "diffusion_head_config must be None, a dict or a "
                f"VibeVoiceDiffusionHeadConfig, got {type(diffusion_head_config).__name__}"
            )

        # other parameters: mirror the tokenizers' latent sizes at top level
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)

        super().__init__(**kwargs)
|
||||
|
||||
# Public names exported by this configuration module.
__all__ = [
    "VibeVoiceAcousticTokenizerConfig",
    "VibeVoiceSemanticTokenizerConfig",
    "VibeVoiceDiffusionHeadConfig",
    "VibeVoiceConfig",
]
|
||||
488
vibevoice/modular/modeling_vibevoice.py
Normal file
488
vibevoice/modular/modeling_vibevoice.py
Normal file
@@ -0,0 +1,488 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union, Callable
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from transformers import modeling_utils
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
||||
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
||||
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Provide a value for `modeling_utils.ALL_PARALLEL_STYLES` when the installed
# transformers version lacks the attribute or leaves it set to None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||
|
||||
@dataclass
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
    """Output container returned by ``VibeVoiceForConditionalGeneration.forward``.

    Field order is significant: it defines the order of the tuple form.
    """

    # Language-modelling loss slot; the forward pass in this file leaves it None.
    loss: Optional[torch.FloatTensor] = None
    # Loss of the diffusion prediction head.
    diffusion_loss: Optional[torch.FloatTensor] = None
    # Number of speech latents that went into the diffusion loss.
    speech_token_num: Optional[int] = None
    # Output of the LM head.
    logits: torch.FloatTensor = None
    # Pass-through fields from the underlying language model.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
    """
    Result container for VibeVoice generation.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences.
        speech_outputs (`List[torch.FloatTensor]`, *optional*):
            Generated speech waveforms or latents, one entry per speech segment.
    """

    sequences: torch.LongTensor = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class SpeechConnector(nn.Module):
    """Project speech features to another width: linear -> RMSNorm -> linear."""

    def __init__(self, input_dim, output_dim):
        """Create the layers mapping ``input_dim`` features to ``output_dim``."""
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Apply the three layers in sequence; extra keyword arguments are ignored."""
        projected = self.fc1(features)
        normalized = self.norm(projected)
        return self.fc2(normalized)
|
||||
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoicePreTrainedModel(PreTrainedModel):
    """Base class for VibeVoice models: capability flags and weight initialisation."""

    config_class = VibeVoiceConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise the parameters of a single sub-module.

        * ``VibeVoiceDiffusionHead`` modules initialise themselves through
          their own ``initialize_weights`` method.
        * ``nn.Linear`` weights are drawn from ``N(0, std)`` and biases zeroed.
        * ``nn.LayerNorm`` weights are set to 1 and biases to 0.

        ``std`` is the language model's ``initializer_range`` when the config
        exposes one, otherwise 0.02.
        """
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return

        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight and bias are None when the layer was built with
            # elementwise_affine=False (and bias is None with bias=False).
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoiceModel(VibeVoicePreTrainedModel):
    """VibeVoice backbone.

    Holds the language model, the acoustic and semantic speech tokenizers,
    the two connectors that project speech latents to the language model's
    hidden size, the diffusion prediction head and its noise scheduler.
    ``forward`` only runs the language model.
    """

    def __init__(self, config):
        """Instantiate all sub-modules from ``config``."""
        super().__init__(config)

        # Resolve the dtype used for the speech components; config.torch_dtype
        # may be a string such as "bfloat16" or an actual torch.dtype.
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Initialize speech components if needed
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)

        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
        self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)

        # Scaling factors are registered as 0-dim (scalar) buffers that start
        # out as NaN.
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize noise scheduler
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )

    def get_input_embeddings(self):
        """Return the language model's token-embedding module.

        Raises:
            AssertionError: If no embedding can be located.
        """
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        for name, attr in self.language_model.fullmap.items():  # parallel by nnscaler, the name is changed
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)
        # Raised explicitly: a bare `assert False` is stripped under
        # `python -O`, which would make this method silently return None.
        raise AssertionError('should not arrive here')

    def set_input_embeddings(self, value):
        """Replace the language model's token-embedding module."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Put whichever tokenizers were supplied into evaluation mode.
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model and return its outputs.

        All arguments are passed straight through to ``self.language_model``.
        When ``return_dict`` resolves to False the language model's raw
        output is returned as-is; otherwise the result is repacked into a
        ``BaseModelOutputWithPast``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
|
||||
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
    """VibeVoice backbone plus an LM head and the diffusion training loss."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        """Build the backbone and the (bias-free) LM head, then run ``post_init``."""
        super().__init__(config)
        self.model = VibeVoiceModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_decoder(self, decoder):
        """Replace the backbone's language model."""
        self.model.language_model = decoder

    def get_decoder(self):
        """Return the backbone's language model."""
        return self.model.language_model

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        Only acts when ``decoder_config.tie_word_embeddings`` is truthy. The
        tie is a plain parameter-object assignment.
        """
        if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, 'weight'):
                output_embeddings.weight = input_embeddings.weight
            else:
                # get_input_embeddings may have returned the weight itself
                output_embeddings.weight = input_embeddings

            # If the head has a bias, zero-pad it to the (new) number of rows.
            if getattr(output_embeddings, "bias", None) is not None:
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                    "constant",
                    0,
                )
            print("✅ Tied input and output embeddings using standard assignment.")
        else:
            print("ℹ️ tie_word_embeddings is False, not tying weights.")

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head by rebinding ``self.lm_head`` to ``new_embeddings``."""
        self.lm_head = new_embeddings

    def forward_speech_features(
        self,
        speech_tensors=None,
        speech_masks=None,
        speech_type="audio",
        return_unmask=False
    ):
        """Turn speech input into normalised latents and connector features.

        Args:
            speech_tensors: Speech input, or ``None``. With ``None`` a single
                all-zero latent of shape ``(1, 1, vae_dim)`` is pushed through
                the acoustic connector and returned.
            speech_masks: Boolean mask used to select valid latents.
            speech_type: ``"audio"`` to encode ``speech_tensors`` with the
                acoustic tokenizer, or ``"vae"`` to treat them as latent means
                and add Gaussian noise.
            return_unmask: When True, return the full tensors instead of the
                entries selected by ``speech_masks``.

        Returns:
            Tuple ``(audio_features, connect_features)``.

        Raises:
            NotImplementedError: For any other ``speech_type``.

        Side effects:
            While either scaling buffer on ``self.model`` is NaN, both buffers
            are filled in place from the statistics of the current batch
            (averaged over ranks when ``torch.distributed`` is initialised).
        """
        if speech_tensors is None:
            # Use config to get vae_dim instead of non-existent self.args
            vae_dim = self.config.acoustic_tokenizer_config.vae_dim
            audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
            connect_features = self.model.acoustic_connector(audio_features)
            return audio_features, connect_features

        # Latent extraction and normalisation run without gradient tracking.
        with torch.no_grad():
            if speech_type == "audio":
                frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
                audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]

            elif speech_type == "vae":
                # Use config to get vae_dim instead of non-existent self.args
                vae_dim = self.config.acoustic_tokenizer_config.vae_dim
                speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)

                # gaussian sample from the speech_mode, with one randomly
                # drawn noise scale per batch element
                batch_size = speech_mode.size(0)
                value = self.model.acoustic_tokenizer.fix_std / 0.8
                std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
                std = std.view(-1, *[1] * (speech_mode.dim() - 1))
                audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
            else:
                raise NotImplementedError(f"Speech type {speech_type} not implemented")

            if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
                scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
                bias_factor = -audio_tokens[speech_masks].flatten().mean()

                # Only use distributed operations if the process group is initialized
                if dist.is_available() and dist.is_initialized():
                    dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
                    dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
                    world_size = dist.get_world_size()
                    self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
                    self.model.speech_bias_factor.copy_(bias_factor / world_size)
                    print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
                else:
                    # Single process case
                    self.model.speech_scaling_factor.copy_(scaling_factor)
                    self.model.speech_bias_factor.copy_(bias_factor)
                    print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)

            audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor

        # The connector runs outside no_grad so it can receive gradients.
        connect_features = self.model.acoustic_connector(audio_features)
        if return_unmask:
            return audio_features, connect_features
        return audio_features[speech_masks], connect_features[speech_masks]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # New arguments for speech processing and loss calculation
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speeches_loss_input: Optional[torch.FloatTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        acoustic_loss_mask: Optional[torch.BoolTensor] = None,
        ddpm_batch_mul: int = 1,
        **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
    ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
        """Training forward pass: LM logits plus the diffusion-head loss.

        Token embeddings are looked up from ``input_ids`` (``inputs_embeds``
        is accepted but not read). Positions selected by
        ``acoustic_input_mask`` are overwritten with speech connector
        features, the language model is run, and the diffusion loss is
        computed at the positions selected by ``acoustic_loss_mask``.

        Args:
            labels: Accepted for API compatibility; no loss is computed from
                it here, so ``loss`` is always ``None``.
            output_hidden_states: Accepted but not forwarded; the backbone is
                always called with ``output_hidden_states=False``.
            speech_tensors: Speech input for ``forward_speech_features``.
            speech_masks: Mask of valid speech latents.
            speeches_loss_input: When given, selects (together with
                ``speech_masks``) which latents feed the diffusion loss.
            speech_semantic_tensors: Semantic features; when given they are
                projected by the semantic connector.
            acoustic_input_mask: Positions of the input embeddings that are
                replaced by speech features.
            acoustic_loss_mask: Positions whose hidden states condition the
                diffusion head.
            ddpm_batch_mul: How many noise/timestep draws per speech latent.
            **kwargs: ``speech_type`` (default ``"audio"``) is read from here.

        Returns:
            ``VibeVoiceCausalLMOutputWithPast``, or when ``return_dict``
            resolves to False the tuple
            ``(loss, diffusion_loss, logits, speech_len, *rest)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        x = self.get_input_embeddings()(input_ids)

        # The semantic connector cannot run on None; skip it when no semantic
        # features were supplied so the fallback branch below is reachable.
        if speech_semantic_tensors is not None:
            semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
        else:
            semantic_speech_all_connect_features = None

        speech_type = kwargs.get("speech_type", "audio")
        typed_speech_tensors = speech_tensors.type_as(x) if speech_tensors is not None else None

        if speeches_loss_input is not None:
            # only part audio need diffuse
            speech_all_features, speech_all_connect_features = self.forward_speech_features(
                speech_tensors=typed_speech_tensors,
                speech_masks=speech_masks,
                speech_type=speech_type,
                return_unmask=True
            )
            if speech_tensors is not None:
                if semantic_speech_all_connect_features is not None:
                    x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
                else:
                    x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
                # only part audio need diffuse
                diffuse_mask = speeches_loss_input.unsqueeze(-1) & speech_masks
                speech_features = speech_all_features[diffuse_mask]
                speech_connect_features = speech_all_connect_features[diffuse_mask]
        else:
            speech_features, speech_connect_features = self.forward_speech_features(
                speech_tensors=typed_speech_tensors,
                speech_masks=speech_masks,
                speech_type=speech_type,
            )
            if speech_tensors is not None:
                x[acoustic_input_mask] = speech_connect_features

        # Always request the structured output from the backbone: the code
        # below reads `last_hidden_state` and calls `to_tuple()`, neither of
        # which exists on the plain tuple returned for return_dict=False.
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=x,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=False,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        # logits = logits.float()

        # The custom CE loss with masking is calculated in the training script,
        # so `labels` is not used and the standard loss stays None.
        loss = None

        # --- Diffusion Loss Calculation ---
        diffusion_loss = None
        # Defined up front so it exists even when the branch below is skipped.
        speech_len = 0
        # This block is executed only if we are in a context that involves speech.
        if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
            condition_features = hidden_states[acoustic_loss_mask]

            speech_len, latent_size = speech_features.shape

            noise = torch.randn(
                (speech_len * ddpm_batch_mul, latent_size),
                device=hidden_states.device,
                dtype=hidden_states.dtype
            )

            # Uniformly sampled timesteps, one per repeated latent.
            timesteps = torch.multinomial(
                torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
                speech_len * ddpm_batch_mul,
                replacement=True,
            ).to(hidden_states.device)

            speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
            condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)

            noisy_speech_features = self.model.noise_scheduler.add_noise(
                speech_features_repeated, noise, timesteps
            )

            model_output = self.model.prediction_head(
                noisy_speech_features,
                timesteps.type_as(x),
                condition_features_repeated
            )

            prediction_type = self.config.diffusion_head_config.prediction_type
            if prediction_type == "epsilon":
                target_for_loss = noise
            elif prediction_type == "v_prediction":
                target_for_loss = self.model.noise_scheduler.get_velocity(
                    speech_features_repeated, noise, timesteps
                )
            else:
                raise NotImplementedError(f"Prediction type {prediction_type} not implemented")

            # Summed squared error, divided by latent size and repeat count.
            diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
            if latent_size > 0 and ddpm_batch_mul > 0:
                diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
            else:
                diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)

        else:
            # Dummy loss for DDP to work when there are no speech samples in a batch,
            # but we are in a speech context.
            diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
            diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
            diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
        # --- End Diffusion Loss Calculation ---

        if not return_dict:
            output = (logits, speech_len) + outputs.to_tuple()[1:]
            return (loss, diffusion_loss) + output

        return VibeVoiceCausalLMOutputWithPast(
            loss=loss,
            diffusion_loss=diffusion_loss,
            speech_token_num=speech_len,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
# Make the VibeVoice classes resolvable through the transformers Auto* factories.
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)

# Public names exported by this modeling module.
__all__ = [
    "VibeVoiceModel",
    "VibeVoicePreTrainedModel",
    "VibeVoiceForConditionalGeneration",
    "VibeVoiceCausalLMOutputWithPast",
    "VibeVoiceGenerationOutput",
]
|
||||
715
vibevoice/modular/modeling_vibevoice_inference.py
Normal file
715
vibevoice/modular/modeling_vibevoice_inference.py
Normal file
@@ -0,0 +1,715 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union, Callable
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from transformers import modeling_utils
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
||||
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
|
||||
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
||||
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceConfig
|
||||
|
||||
from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
|
||||
|
||||
from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
|
||||
from .streamer import AudioStreamer, AsyncAudioStreamer
|
||||
|
||||
logger = logging.get_logger(__name__)

# Compatibility shim: guarantee that `modeling_utils.ALL_PARALLEL_STYLES` exists
# and is not None by installing a default list when it is missing or unset.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||
|
||||
@dataclass
class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
    """`BaseModelOutputWithPast` extended with the language-model head logits.

    `forward` of the inference model fills `logits`, `past_key_values`,
    `last_hidden_state` and `attentions`.
    """
    # Output of the LM head for the positions kept by `logits_to_keep`.
    logits: Optional[torch.FloatTensor] = None
|
||||
|
||||
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
    """
    Output type for VibeVoice generation.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences.
        speech_outputs (`List[torch.FloatTensor]`, *optional*):
            List of generated speech waveforms or latents for each speech segment.
        reach_max_step_sample (`torch.BoolTensor` of shape `(batch_size,)`, *optional*):
            Per-sample flags set by `generate` for samples that were cut off by a
            length / step limit instead of finishing on their own.
    """
    sequences: torch.LongTensor = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None
    reach_max_step_sample: Optional[torch.BoolTensor] = None
|
||||
|
||||
class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
    """Logits processor that only lets a fixed whitelist of token ids be picked.

    Every vocabulary column that is not whitelisted has ``-inf`` added to its
    score; whitelisted columns are left untouched.
    """

    def __init__(self, valid_token_ids: List[int], device: torch.device = None):
        # Kept as a long tensor so it can be used directly as a column index.
        self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Additive penalty with the dtype/device of `scores`:
        # 0 on whitelisted columns, -inf everywhere else.
        penalty = scores.new_full(scores.shape, float('-inf'))
        penalty[:, self.valid_token_ids] = 0
        return scores + penalty
|
||||
|
||||
class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
    """Inference-time VibeVoice model: a `VibeVoiceModel` backbone plus an LM head.

    The LM head produces next-token logits; the custom `generate` loop in this
    class combines token prediction with diffusion-based speech latent sampling.
    """

    # `lm_head.weight` can be shared with the input embeddings (see `tie_weights`).
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        """Build the backbone and LM head described by `config`.

        Reads `config.decoder_config.hidden_size` / `vocab_size` for the LM head
        and `config.diffusion_head_config.ddpm_num_inference_steps` as the
        default number of diffusion inference steps.
        """
        super().__init__(config)

        # Initialize the base model
        self.model = VibeVoiceModel(config)

        # LM head for text generation
        self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)

        # inference configuration
        self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps

        # Initialize weights and apply final processing
        self.post_init()
|
||||
|
||||
@property
def noise_scheduler(self):
    """Noise scheduler held by the wrapped base model (`self.model`)."""
    return self.model.noise_scheduler
|
||||
|
||||
@property
def prediction_head(self):
    """Diffusion prediction head held by the wrapped base model (`self.model`)."""
    return self.model.prediction_head
|
||||
|
||||
@property
def speech_scaling_factor(self):
    """Scaling factor of the base model, multiplied onto biased acoustic latents."""
    return self.model.speech_scaling_factor
|
||||
|
||||
@property
def speech_bias_factor(self):
    """Bias factor of the base model, added to acoustic latents before scaling."""
    return self.model.speech_bias_factor
|
||||
|
||||
@property
def acoustic_tokenizer(self):
    """Acoustic tokenizer held by the wrapped base model (`self.model`)."""
    return self.model.acoustic_tokenizer
|
||||
|
||||
@property
def semantic_tokenizer(self):
    """Semantic tokenizer held by the wrapped base model (`self.model`)."""
    return self.model.semantic_tokenizer
|
||||
|
||||
@property
def acoustic_connector(self):
    """Connector of the base model that maps acoustic features to LM embeddings."""
    return self.model.acoustic_connector
|
||||
|
||||
@property
def semantic_connector(self):
    """Connector of the base model that maps semantic features to LM embeddings."""
    return self.model.semantic_connector
|
||||
|
||||
def tie_weights(self):
    """
    Share the input embedding matrix with the LM head.

    Does nothing unless `config.tie_word_embeddings` is truthy, and only ties
    when both `lm_head` and `language_model.embed_tokens` are present.
    """
    tying_requested = getattr(self.config, 'tie_word_embeddings', False)
    if tying_requested and hasattr(self, 'lm_head'):
        language_model = self.model.language_model
        if hasattr(language_model, 'embed_tokens'):
            # lm_head.weight becomes the very same parameter as embed_tokens.weight
            self.lm_head.weight = language_model.embed_tokens.weight
|
||||
|
||||
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped base model."""
    return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
    """Replace the input embedding module of the wrapped base model with `value`."""
    self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
    """Return the LM head (`self.lm_head`)."""
    return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
    """Replace the LM head (`self.lm_head`) with `new_embeddings`."""
    self.lm_head = new_embeddings
|
||||
|
||||
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
    """Set the speech tokenizers used for encoding and decoding speech.

    Both arguments are forwarded unchanged to the wrapped base model's
    `set_speech_tokenizers`.
    """
    self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
|
||||
|
||||
def set_ddpm_inference_steps(self, num_steps=None):
    """Set the number of diffusion inference steps used when sampling speech.

    Any falsy `num_steps` (``None`` or ``0``) restores the value from
    `config.diffusion_head_config.ddpm_num_inference_steps`.
    """
    if num_steps:
        self.ddpm_inference_steps = num_steps
    else:
        self.ddpm_inference_steps = self.config.diffusion_head_config.ddpm_num_inference_steps
|
||||
|
||||
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
|
||||
"""Process speech inputs through tokenizers and connectors."""
|
||||
with torch.no_grad():
|
||||
if speech_type == "audio":
|
||||
# Encode audio to acoustic latents
|
||||
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
|
||||
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||
|
||||
# Apply scaling and bias
|
||||
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
||||
|
||||
# Connect to language model space
|
||||
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
||||
|
||||
return acoustic_features, acoustic_connected
|
||||
elif speech_type == "pt":
|
||||
encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
|
||||
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||
|
||||
# Apply scaling and bias
|
||||
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
||||
|
||||
# Connect to language model space
|
||||
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
||||
|
||||
return acoustic_features, acoustic_connected
|
||||
else:
|
||||
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
||||
|
||||
# @can_return_tuple
|
||||
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    logits_to_keep: Union[int, slice] = 0,
    **kwargs,
) -> Union[Tuple, "VibeVoiceCausalLMOutputWithPast"]:
    """
    Run the backbone and the LM head, optionally inserting speech embeddings.

    Args:
        labels (`torch.LongTensor`, *optional*):
            Not supported by this inference-only class; passing a value raises
            `NotImplementedError`.
        speech_tensors (`torch.FloatTensor`, *optional*):
            Input speech waveforms for voice cloning or speech understanding.
        speech_masks (`torch.BoolTensor`, *optional*):
            Masks indicating valid speech frames.
        speech_input_mask (`torch.BoolTensor`, *optional*):
            Positions in the input sequence where speech embeddings should be inserted.
        logits_to_keep (`int` or `slice`, defaults to 0):
            An int keeps the last `logits_to_keep` positions (0 keeps all of them);
            a slice is applied to the sequence dimension as is.

    Note:
        When `inputs_embeds` is supplied together with speech inputs, the speech
        embeddings are written into it in place.

    Returns:
        `VibeVoiceCausalLMOutputWithPast`, or its `to_tuple()` form when
        `return_dict` resolves to False.

    Raises:
        NotImplementedError: if `labels` is given.
    """
    # Fail fast: loss computation is unsupported, so do not pay for a forward pass first.
    if labels is not None:
        raise NotImplementedError("Loss computation is not implemented in this version.")

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Get embeddings
    if inputs_embeds is None:
        inputs_embeds = self.model.get_input_embeddings()(input_ids)

    # Process speech inputs if provided
    if speech_tensors is not None and speech_masks is not None:
        _, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks)
        if speech_input_mask is not None:
            inputs_embeds[speech_input_mask] = speech_embeds

    # Always ask the backbone for a structured output: the fields below are read
    # by attribute, which a plain tuple (return_dict=False) does not support.
    outputs = self.model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = self.lm_head(hidden_states[:, slice_indices, :])

    result = VibeVoiceCausalLMOutputWithPast(
        logits=logits,
        past_key_values=outputs.past_key_values,
        last_hidden_state=hidden_states,
        # Previously dropped even when output_hidden_states=True was requested.
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
    # Honour the caller's return_dict preference only at the outer boundary.
    return result if return_dict else result.to_tuple()
|
||||
|
||||
def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
    """Prepare the generation config, model kwargs and input ids for `generate`.

    Args:
        generation_config: `None`, a mapping of `GenerationConfig` fields, or a
            `GenerationConfig` instance. In every case the bos/eos/pad token ids
            are taken from `tokenizer` and override any value supplied here.
        inputs: Forwarded to `_prepare_model_inputs`.
        tokenizer: Must expose `bos_token_id`, `eos_token_id`, `pad_token_id`,
            `speech_start_id`, `speech_end_id` and `speech_diffusion_id`.
        return_processors: When True, also build and return the logits
            processors and stopping criteria.
        **kwargs: Generation / model keyword arguments, forwarded to
            `_prepare_generation_config`.

    Returns:
        `(generation_config, model_kwargs, input_ids)`, followed by
        `logits_processor, stopping_criteria` when `return_processors` is True.
    """
    # Normalise the three accepted forms into plain keyword arguments. Splatting a
    # GenerationConfig instance with ** is not possible, so convert it first.
    if generation_config is None:
        config_fields = {}
    elif isinstance(generation_config, GenerationConfig):
        config_fields = generation_config.to_dict()
    else:
        config_fields = dict(generation_config)
    # Tokenizer ids always win; updating the dict (instead of passing them as
    # extra keywords) avoids a duplicate-keyword TypeError.
    config_fields.update(
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    generation_config = GenerationConfig(**config_fields)

    generation_config, model_kwargs = self._prepare_generation_config(
        generation_config,
        True,
        speech_start_id=tokenizer.speech_start_id,
        speech_end_id=tokenizer.speech_end_id,
        speech_diffusion_id=tokenizer.speech_diffusion_id,
        **kwargs
    )
    # Expose the speech special-token ids on the config for the generation loop.
    generation_config.speech_start_id = tokenizer.speech_start_id
    generation_config.speech_end_id = tokenizer.speech_end_id
    generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id

    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
    batch_size = inputs_tensor.shape[0]
    device = self.device

    self._prepare_special_tokens(generation_config, True, device=device)
    # The generation loop relies on the KV cache, so it is always enabled.
    generation_config.use_cache = True
    model_kwargs["use_cache"] = generation_config.use_cache
    input_ids = inputs_tensor.to(self.device)

    input_ids_length = input_ids.shape[1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        has_default_min_length=has_default_min_length,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=input_ids_length,
    )

    max_cache_length = generation_config.max_length - 1
    self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
    model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
    # Move every tensor-valued kwarg to the model device (keys are unchanged,
    # so assigning while iterating is safe).
    for k, v in model_kwargs.items():
        if isinstance(v, torch.Tensor):
            model_kwargs[k] = v.to(device=device)

    if return_processors:
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=None,
            logits_processor=LogitsProcessorList(),
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
        )

        stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())

        return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
    else:
        return generation_config, model_kwargs, input_ids
|
||||
|
||||
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    return_speech: bool = True,
    cfg_scale: float = 1.0,
    stop_check_fn: Optional[Callable[[], bool]] = None,
    **kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
    """
    Generates sequences of token ids and optionally speech outputs.

    Each step predicts one token per sample, restricted to the speech
    start / end / diffusion tokens, EOS and (if defined) BOS. For samples whose
    token is the speech-diffusion token, a speech latent is sampled with
    classifier-free guidance, decoded to audio, re-encoded to semantic features,
    and the combined embedding is fed back as the next input.

    Args:
        audio_streamer: Optional streamer that receives audio chunks as they are produced.
        speech_tensors: Input speech for voice cloning. May be None.
        speech_masks: Masks for speech tensors. May be None.
        speech_input_mask: Positions to insert speech embeddings. May be None.
        return_speech: Whether to return the concatenated audio per sample.
        cfg_scale: CFG scale for speech generation.
        stop_check_fn: Optional callable that returns True if generation should stop.
        **kwargs: Must contain `input_ids` (read directly from kwargs) and
            `tokenizer` (source of the special token ids). Also recognised:
            `max_new_tokens`, `max_length_times` (default 2), `verbose`,
            `show_progress_bar` (default True), `refresh_negative` (default True).
            `parsed_scripts` and `all_speakers_list` are accepted and ignored.

    Note:
        `logits_processor` and `stopping_criteria` passed by the caller are
        replaced by the ones built internally. `prefix_allowed_tokens_fn`,
        `synced_gpus`, `assistant_model`, `negative_prompt_ids` and
        `negative_prompt_attention_mask` are not read by this implementation.

    Returns:
        `VibeVoiceGenerationOutput` with the token sequences, the per-sample audio
        (or None for samples without audio; None overall if `return_speech` is
        False) and the per-sample `reach_max_step_sample` flags.
    """
    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    tokenizer = kwargs.pop("tokenizer", None)  # provides the special token ids used below
    # Accepted for interface compatibility; removed so they are not treated as model kwargs.
    kwargs.pop("parsed_scripts", None)
    kwargs.pop("all_speakers_list", None)
    max_length_times = kwargs.pop("max_length_times", 2)

    if kwargs.get('max_new_tokens', None) is None:
        kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]

    generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
        generation_config, inputs, tokenizer, return_processors=True, **kwargs
    )

    # The negative (unconditional) branch for CFG starts from a lone speech-start token per sample.
    negative_kwargs = {
        'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
        'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
        'max_new_tokens': kwargs.get('max_new_tokens', 100)
    }
    _, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
        None, None, tokenizer, return_processors=False, **negative_kwargs
    )

    acoustic_cache = VibeVoiceTokenizerStreamingCache()
    semantic_cache = VibeVoiceTokenizerStreamingCache()

    batch_size = input_ids.shape[0]
    device = input_ids.device
    finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
    correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
    is_prefill = True
    inputs_embeds = None
    verbose = kwargs.get("verbose", False)
    # Loop-invariant: kwargs is not modified inside the generation loop.
    refresh_negative = kwargs.get('refresh_negative', True)

    # Initialize audio chunks storage for each sample
    audio_chunks = [[] for _ in range(batch_size)]

    initial_length = input_ids.shape[-1]
    initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)

    # Define all valid tokens that can be generated
    valid_tokens = [
        generation_config.speech_start_id,
        generation_config.speech_end_id,
        generation_config.speech_diffusion_id,
        generation_config.eos_token_id
    ]
    # Add bos_token_id if it exists
    if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
        valid_tokens.append(generation_config.bos_token_id)

    # Add custom processor to constrain token generation
    token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(token_constraint_processor)

    # Step budget: bounded by max_length and by `max_length_times` x the prompt length.
    max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
    max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
    reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # Wrap the step range in a tqdm progress bar unless show_progress_bar is disabled
    if kwargs.get("show_progress_bar", True):
        progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
    else:
        progress_bar = range(max_steps)

    for step in progress_bar:
        # Check for external stop signal
        if stop_check_fn is not None and stop_check_fn():
            if verbose:
                print(f"Generation stopped externally at step {step + 1}")
            # End the audio streamer if it exists
            if audio_streamer is not None:
                audio_streamer.end()
            break

        # Check if audio_streamer has been ended (stopped externally)
        if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
            if any(audio_streamer.finished_flags):
                if verbose:
                    print(f"Audio generation stopped externally at step {step + 1}")
                break

        if finished_tags.all():
            if hasattr(progress_bar, 'set_description'):
                progress_bar.set_description("Generation complete")
            break

        if input_ids.shape[-1] >= generation_config.max_length:
            print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
            reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
            if reached_samples.numel() > 0:
                reach_max_step_sample[reached_samples] = True
            break

        # Update progress bar description with active samples
        if hasattr(progress_bar, 'set_description'):
            active_samples = (~finished_tags).sum().item()
            progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        if is_prefill:
            # Speech inputs are only processed during the first generation step.
            # Each of them is optional; forward() skips speech handling for the
            # ones that are missing, so only pass along those actually provided.
            prefill_inputs = {
                name: tensor.to(device)
                for name, tensor in (
                    ("speech_tensors", speech_tensors),
                    ("speech_masks", speech_masks),
                    ("speech_input_mask", speech_input_mask),
                )
                if tensor is not None
            }
            is_prefill = False
        else:
            # After prefill the embedding built at the end of the previous step is the input.
            model_inputs.pop('inputs_embeds', None)
            prefill_inputs = {'inputs_embeds': inputs_embeds}

        # Forward pass through the model
        outputs = self(
            **model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
        )
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False,
        )

        # Get logits and apply logits processor
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # token selection
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Finished samples keep emitting EOS so the batch stays rectangular.
        next_tokens[finished_tags] = generation_config.eos_token_id
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if not refresh_negative:
            # Without refreshing, the negative branch is advanced on every step.
            negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
            # Forward negative pass through the model
            if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
                negative_model_inputs['inputs_embeds'] = inputs_embeds
                negative_model_inputs['input_ids'] = None

            negative_outputs = self(
                **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
            )
            negative_model_kwargs = self._update_model_kwargs_for_generation(
                negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
            )
            negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)

        # reached end of generation
        if (next_tokens == generation_config.eos_token_id).any():
            eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
            # Only handle samples that are newly finished (not already marked as finished)
            new_eos_indices = eos_indices[~finished_tags[eos_indices]]
            if new_eos_indices.numel() > 0:
                finished_tags[new_eos_indices] = True
                if verbose:
                    print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
                if audio_streamer is not None:
                    audio_streamer.end(new_eos_indices)

        # Check if any sample reached its maximum generation length
        max_length_reached = step >= max_step_per_sample
        new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
        if new_max_length_indices.numel() > 0:
            finished_tags[new_max_length_indices] = True
            reach_max_step_sample[new_max_length_indices] = True
            if verbose:
                print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
            if audio_streamer is not None:
                audio_streamer.end(new_max_length_indices)

        # speech_end
        diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
        if diffusion_end_indices.numel() > 0:
            # Clear tokenizer caches for samples that reached speech end
            acoustic_cache.set_to_zero(diffusion_end_indices)
            semantic_cache.set_to_zero(diffusion_end_indices)

        # speech_begin
        diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
        if diffusion_start_indices.numel() > 0 and refresh_negative:
            started_samples = diffusion_start_indices.tolist()
            # Reset the negative branch of these samples to a single visible position (the last one).
            for sample_idx in started_samples:
                negative_model_kwargs['attention_mask'][sample_idx, :] = 0
                negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
            # Copy the first cached key/value entry into the last slot for these samples.
            for k_cache, v_cache in zip(negative_model_kwargs['past_key_values'].key_cache,
                                        negative_model_kwargs['past_key_values'].value_cache):
                for sample_idx in started_samples:
                    k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
                    v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
            # update negative_input_ids
            for sample_idx in started_samples:
                negative_input_ids[sample_idx, -1] = generation_config.speech_start_id

        # Prepare inputs_embeds for next iteration
        # Initialize with default embeddings for all tokens
        next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1)  # [batch_size, 1, hidden_size]

        # forward diffusion
        # Diffusion indices are the unfinished samples whose new token is the speech-diffusion token
        diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]

        if diffusion_indices.numel() > 0:
            if refresh_negative:
                negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
                # Forward negative pass through the model
                if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
                    negative_model_inputs['inputs_embeds'] = inputs_embeds
                    negative_model_inputs['input_ids'] = None

                negative_outputs = self(
                    **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
                )
                negative_model_kwargs = self._update_model_kwargs_for_generation(
                    negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
                )
                negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
                # correct the non-diffusion indices
                # we forward all samples' negative outputs even if
                # they are not in diffusion mode to keep the cache consistent
                # So we need to correct the kv cache of non-diffusion samples
                non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
                if non_diffusion_mask.any():
                    non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
                    start_indices = correct_cnt[non_diffusion_indices]
                    corrections = list(zip(non_diffusion_indices.tolist(), start_indices.tolist()))

                    # 1. Update attention_mask - need to handle each sample separately
                    seq_len = negative_model_kwargs['attention_mask'].shape[1]
                    for sample_idx, start_idx in corrections:
                        # Shift the attention mask for this sample one position to the right
                        if start_idx + 1 < seq_len - 1:
                            negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
                                negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
                        negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0

                    # 2. Update past_key_values
                    for k_cache, v_cache in zip(negative_model_kwargs['past_key_values'].key_cache,
                                                negative_model_kwargs['past_key_values'].value_cache):
                        # Process each non-diffusion sample
                        for sample_idx, start_idx in corrections:
                            if start_idx + 1 < k_cache.shape[2] - 1:
                                # Shift cache for this sample
                                k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
                                v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()

                    # 3. Update negative_input_ids
                    for sample_idx, start_idx in corrections:
                        if start_idx + 1 < negative_input_ids.shape[1] - 1:
                            negative_input_ids[sample_idx, start_idx+1:] = \
                                negative_input_ids[sample_idx, start_idx:-1].clone()

                    correct_cnt[non_diffusion_indices] += 1

            positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
            negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]

            speech_latent = self.sample_speech_tokens(
                positive_condition,
                negative_condition,
                cfg_scale=cfg_scale,
            ).unsqueeze(1)

            # Decode acoustic latent to audio using acoustic streaming cache
            # (inverse of the bias-then-scale applied to encoder latents).
            scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
            audio_chunk = self.model.acoustic_tokenizer.decode(
                scaled_latent.to(self.model.acoustic_tokenizer.device),
                cache=acoustic_cache,  # Use acoustic-specific cache
                sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
                use_cache=True,
                debug=False
            )

            # Store audio chunks for each sample
            for i, idx in enumerate(diffusion_indices.tolist()):
                # Only append audio chunk if the sample is not finished
                if not finished_tags[idx]:
                    audio_chunks[idx].append(audio_chunk[i])

            # Add streaming support here
            if audio_streamer is not None:
                # Stream the audio chunks immediately
                audio_streamer.put(audio_chunk, diffusion_indices)

            # Encode audio to semantic features using semantic streaming cache
            semantic_features = self.model.semantic_tokenizer.encode(
                audio_chunk,
                cache=semantic_cache,  # Use semantic-specific cache
                sample_indices=diffusion_indices,
                use_cache=True,
                debug=False
            ).mean  # semantic tokenizer has no VAE.

            # Combine acoustic and semantic features for next input
            acoustic_embed = self.model.acoustic_connector(speech_latent)
            semantic_embed = self.model.semantic_connector(semantic_features)
            diffusion_embeds = acoustic_embed + semantic_embed

            # Update embeddings for diffusion indices
            next_inputs_embeds[diffusion_indices] = diffusion_embeds

        # Set inputs_embeds for next iteration
        inputs_embeds = next_inputs_embeds

    if audio_streamer is not None:
        audio_streamer.end()

    # Concatenate all chunks of each sample along the time dimension (assumed to
    # be the last dimension); samples without any audio get None.
    final_audio_outputs = [
        torch.cat(sample_chunks, dim=-1) if sample_chunks else None
        for sample_chunks in audio_chunks
    ]

    return VibeVoiceGenerationOutput(
        sequences=input_ids,
        speech_outputs=final_audio_outputs if return_speech else None,
        reach_max_step_sample=reach_max_step_sample,
    )
|
||||
|
||||
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
    """
    Run the diffusion sampling loop with classifier-free guidance.

    The positive and negative conditions are stacked along dim 0, a random
    starting latent is drawn, and for every scheduler timestep the prediction
    head output is combined as ``uncond + cfg_scale * (cond - uncond)`` before
    being handed to the scheduler's ``step``.

    Args:
        condition: Conditioning tensor for the guided half of the batch.
        neg_condition: Conditioning tensor for the unguided half; it is
            concatenated after ``condition`` along dim 0.
        cfg_scale: Guidance weight applied to ``cond - uncond``.

    Returns:
        The first half (along dim 0) of the final sample, i.e. the rows that
        correspond to ``condition``.
    """
    scheduler = self.model.noise_scheduler
    head = self.model.prediction_head
    scheduler.set_timesteps(self.ddpm_inference_steps)

    # Rows [0, B) carry the positive condition, rows [B, 2B) the negative one.
    stacked_condition = torch.cat([condition, neg_condition], dim=0).to(head.device)
    latents = torch.randn(stacked_condition.shape[0], self.config.acoustic_vae_dim).to(stacked_condition)

    for step in scheduler.timesteps:
        # Both halves are fed the same latent so only the condition differs.
        guided_half = latents[: len(latents) // 2]
        model_input = torch.cat([guided_half, guided_half], dim=0)
        step_batch = step.repeat(model_input.shape[0]).to(model_input)
        prediction = head(model_input, step_batch, condition=stacked_condition)

        cond_pred, uncond_pred = torch.split(prediction, len(prediction) // 2, dim=0)
        guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
        prediction = torch.cat([guided_pred, guided_pred], dim=0)
        latents = scheduler.step(prediction, step, latents).prev_sample

    return latents[: len(latents) // 2]
|
||||
|
||||
|
||||
# Register the inference model class for VibeVoiceConfig with AutoModelForCausalLM.
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)

# Public names exported by this module.
__all__ = ["VibeVoiceForConditionalGenerationInference"]
|
||||
287
vibevoice/modular/modular_vibevoice_diffusion_head.py
Normal file
287
vibevoice/modular/modular_vibevoice_diffusion_head.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.models.auto import AutoModel
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
# from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.utils import logging
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension.

    Args:
        dim (`int`): Size of the learned per-feature gain (when enabled).
        eps (`float`, optional): Value added to the mean square before the reciprocal square root.
        elementwise_affine (`bool`, optional): If True, multiply the normalized output by a
            learnable `weight` initialised to ones; otherwise `weight` is registered as None.
        memory_efficient (`bool`, optional): Accepted but not used by this implementation.
    """
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            self.register_parameter('weight', None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide `x` by the root of its mean square taken over the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32, then cast back to the input's dtype before applying the gain.
        normalized = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normalized
        return normalized * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
||||
|
||||
def modulate(x, shift, scale):
    """Scale `x` by `1 + scale`, then add `shift`."""
    scaled = x * (1 + scale)
    return scaled + shift
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal frequency embedding of the timestep is passed through a
    two-layer bias-free MLP with a SiLU activation in between.

    Args:
        hidden_size (`int`): Size of the output embedding
        frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
                            These may be fractional.
            dim (`int`): The dimension of the output.
            max_period (`int`, optional): Controls the minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, D] Tensor of positional embeddings. The result has the
            dtype of `t` when `t` is floating point, and float32 otherwise.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd target size: pad one zero column so the output has exactly `dim` features.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        # Casting sin/cos values to an integer dtype would truncate them, so only
        # follow the dtype of `t` when it is a floating-point tensor.
        if torch.is_floating_point(t):
            return embedding.to(t.dtype)
        return embedding

    def forward(self, t):
        """Return the MLP projection of the sinusoidal embedding of `t`."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
|
||||
|
||||
|
||||
class FeedForwardNetwork(nn.Module):
    """
    Gated feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`.

    All three projections are bias-free linear layers.

    Args:
        embed_dim (`int`): Input (and output) dimension
        ffn_dim (`int`): Hidden dimension
    """
    def __init__(self, embed_dim, ffn_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Two parallel projections into the hidden width, one back to the model width.
        self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']

    def forward(self, x):
        # The activated gate branch multiplies the linear "up" branch element-wise.
        activated_gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated_gate * self.up_proj(x))
|
||||
|
||||
|
||||
class HeadLayer(nn.Module):
    """
    One residual block of the diffusion head.

    The conditioning vector is projected to a shift, a scale and a gate; the
    normalized input is modulated with shift/scale, passed through the
    feed-forward network, multiplied by the gate and added back to the input.

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
        cond_dim (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(self.embed_dim, self.ffn_dim)
        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
        # SiLU followed by a bias-free projection to 3 * embed_dim (shift, scale, gate).
        modulation_proj = nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
        self.adaLN_modulation = nn.Sequential(ACT2FN['silu'], modulation_proj)

    def forward(self, x, c):
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        modulated = modulate(self.norm(x), shift, scale)
        return x + gate * self.ffn(modulated)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
    """
    Output block of the diffusion head.

    Normalizes the hidden state (without a learned gain), modulates it with a
    shift and scale derived from the conditioning vector, and projects it to
    the output size with a bias-free linear layer.

    Args:
        hidden_size (`int`): Input dimension
        output_size (`int`): Output dimension
        cond_size (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        # SiLU followed by a bias-free projection to 2 * hidden_size (shift, scale).
        modulation_proj = nn.Linear(cond_size, 2 * hidden_size, bias=False)
        self.adaLN_modulation = nn.Sequential(ACT2FN['silu'], modulation_proj)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
||||
|
||||
|
||||
class VibeVoiceDiffusionHead(PreTrainedModel):
    """
    Diffusion head model for vibevoice.

    Projects the noisy latent and the condition into the hidden width, adds a
    timestep embedding to the condition, runs a stack of `HeadLayer` blocks and
    finishes with a `FinalLayer` that maps back to the latent size.

    Args:
        config (`VibeVoiceDiffusionHeadConfig`): Model configuration. The fields read here are
            `hidden_size`, `latent_size`, `head_ffn_ratio`, `head_layers` and `rms_norm_eps`.
    """
    config_class = VibeVoiceDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.config = config
        hidden_size = config.hidden_size
        latent_size = config.latent_size
        # The conditioning width equals the hidden width.
        self.cond_dim = hidden_size

        # Input projections and the timestep embedder.
        self.noisy_images_proj = nn.Linear(latent_size, hidden_size, bias=False)
        self.cond_proj = nn.Linear(hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)

        ffn_dim = int(hidden_size * config.head_ffn_ratio)

        # Intermediate residual blocks.
        self.layers = nn.ModuleList()
        for _ in range(config.head_layers):
            self.layers.append(
                HeadLayer(
                    embed_dim=hidden_size,
                    ffn_dim=ffn_dim,
                    cond_dim=self.cond_dim,
                    norm_eps=config.rms_norm_eps,
                )
            )

        # Output block mapping back to the latent size.
        self.final_layer = FinalLayer(
            hidden_size=hidden_size,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps,
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the weights of the model."""
        # Timestep-embedder linears: normal init with std 0.02.
        embedder_mlp = self.t_embedder.mlp
        for position in (0, 2):
            nn.init.normal_(embedder_mlp[position].weight, std=0.02)

        # Zero the modulation projections of every block, including the final one.
        for block in (*self.layers, self.final_layer):
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)

        # Zero the output projection.
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(
        self,
        noisy_images,
        timesteps,
        condition,
    ):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`): Noisy images/latents to denoise
            timesteps (`torch.Tensor`): Timesteps for diffusion
            condition (`torch.Tensor`): Conditioning information

        Returns:
            `torch.Tensor`: The predicted noise/velocity
        """
        hidden = self.noisy_images_proj(noisy_images)
        time_embedding = self.t_embedder(timesteps)
        # The conditioning signal is the projected condition plus the timestep embedding.
        conditioning = self.cond_proj(condition) + time_embedding

        for block in self.layers:
            hidden = block(hidden, conditioning)

        return self.final_layer(hidden, conditioning)
|
||||
|
||||
|
||||
# Register the diffusion head class for VibeVoiceDiffusionHeadConfig with AutoModel.
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)

# Public names exported by this module.
__all__ = ["VibeVoiceDiffusionHead"]
|
||||
214
vibevoice/modular/modular_vibevoice_text_tokenizer.py
Normal file
214
vibevoice/modular/modular_vibevoice_text_tokenizer.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Tokenization classes for vibevoice."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from transformers.utils import logging
|
||||
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
|
||||
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VibeVoiceTextTokenizer(Qwen2Tokenizer):
    """
    VibeVoice text tokenizer: the Qwen2 tokenizer plus special tokens used for speech.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not used for vibevoice.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to the Qwen2 tokenizer.
        add_special_tokens (`bool`, *optional*, defaults to `True`):
            Whether or not to add special tokens when encoding.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        add_special_tokens=True,
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )

        # Register the speech tokens and cache their ids.
        self._add_vibevoice_special_tokens()

    def _add_vibevoice_special_tokens(self):
        """Register the speech special tokens and cache their ids; return the count added."""
        # The Qwen vision tokens are reused as speech markers.
        start_token = "<|vision_start|>"    # speech start
        end_token = "<|vision_end|>"        # speech end
        diffusion_token = "<|vision_pad|>"  # speech diffusion pad

        num_added = self.add_special_tokens(
            {"additional_special_tokens": [start_token, end_token, diffusion_token]}
        )

        # Cache the ids so the properties below are plain attribute reads.
        self._speech_start_id = self.convert_tokens_to_ids(start_token)
        self._speech_end_id = self.convert_tokens_to_ids(end_token)
        self._speech_diffusion_id = self.convert_tokens_to_ids(diffusion_token)

        self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')

        return num_added

    @property
    def eos_id(self) -> int:
        """Id of the end of sequence token."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Id of the speech start token."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Id of the speech end token."""
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        """Id of the speech diffusion token."""
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        """Id used for padding (returns -100 for loss masking)."""
        return -100
|
||||
|
||||
|
||||
class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
    """
    "Fast" VibeVoice text tokenizer (backed by HuggingFace's *tokenizers* library).

    It is the Qwen2 fast tokenizer plus special tokens used for speech.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not used for vibevoice.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to the Qwen2 fast tokenizer.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

        # Register the speech tokens and cache their ids.
        self._add_vibevoice_special_tokens()

    def _add_vibevoice_special_tokens(self):
        """Register the speech special tokens and cache their ids; return the count added."""
        # The Qwen vision tokens are reused as speech markers.
        start_token = "<|vision_start|>"    # speech start
        end_token = "<|vision_end|>"        # speech end
        diffusion_token = "<|vision_pad|>"  # speech diffusion pad

        num_added = self.add_special_tokens(
            {"additional_special_tokens": [start_token, end_token, diffusion_token]}
        )

        # Cache the ids so the properties below are plain attribute reads.
        self._speech_start_id = self.convert_tokens_to_ids(start_token)
        self._speech_end_id = self.convert_tokens_to_ids(end_token)
        self._speech_diffusion_id = self.convert_tokens_to_ids(diffusion_token)

        # EOS comes from the tokenizer's configured eos token (qwen2 / qwen3).
        self._eos_id = self.eos_token_id
        self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')

        return num_added

    @property
    def eos_id(self) -> int:
        """Id of the end of sequence token."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Id of the speech start token."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Id of the speech end token."""
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        """Id of the speech diffusion token."""
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        """Id used for padding: the cached id of the `<|image_pad|>` token."""
        return self._pad_id
|
||||
|
||||
|
||||
# Public names exported by this module.
__all__ = ["VibeVoiceTextTokenizer", "VibeVoiceTextTokenizerFast"]
|
||||
1195
vibevoice/modular/modular_vibevoice_tokenizer.py
Normal file
1195
vibevoice/modular/modular_vibevoice_tokenizer.py
Normal file
File diff suppressed because it is too large
Load Diff
264
vibevoice/modular/streamer.py
Normal file
264
vibevoice/modular/streamer.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import asyncio
|
||||
from queue import Queue
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
|
||||
from transformers.generation import BaseStreamer
|
||||
|
||||
|
||||
class AudioStreamer(BaseStreamer):
    """
    Streamer that keeps one queue of audio chunks per sample in the batch.

    Chunks are pushed with `put`, a stop signal is pushed with `end`, and
    consumers read either a single sample's stream (`get_stream`) or all of
    them together (iterating the streamer).

    Parameters:
        batch_size (`int`):
            The batch size for generation
        stop_signal (`any`, *optional*):
            The signal to put in the queue when generation ends. Defaults to None.
        timeout (`float`, *optional*):
            The timeout for the audio queue. If `None`, the queue will block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout

        # Per-sample state: one queue and one "finished" flag each.
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}  # Maps from sample index to queue index

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Push each chunk onto the queue of the sample it belongs to.

        Chunks whose sample index is not below `batch_size`, or whose sample
        has already been ended, are dropped.

        Args:
            audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
            sample_indices: Tensor indicating which samples these chunks belong to
        """
        for position, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx >= self.batch_size or self.finished_flags[idx]:
                continue
            # Queue a detached CPU copy of this sample's chunk.
            cpu_chunk = audio_chunks[position].detach().cpu()
            self.audio_queues[idx].put(cpu_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for specified samples or all samples.

        Each affected sample that is not already finished receives the stop
        signal on its queue and is marked finished.

        Args:
            sample_indices: Optional tensor of sample indices to end. If None, ends all.
        """
        if sample_indices is None:
            targets = range(self.batch_size)
        else:
            # Entries may be tensors or plain ints; convert lazily, one at a time.
            targets = (
                entry.item() if torch.is_tensor(entry) else entry
                for entry in sample_indices
            )

        for idx in targets:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True

    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Get the audio stream for a specific sample.

        Raises:
            ValueError: If `sample_idx` is not below `batch_size`.
        """
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)
|
||||
|
||||
|
||||
class AudioSampleIterator:
    """Iterator for a single audio stream from the batch.

    Each `next()` blocks on the sample's queue (honouring the streamer's
    `timeout`) and stops when the streamer's stop signal is dequeued.
    """

    def __init__(self, streamer: "AudioStreamer", sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next chunk for this sample.

        Raises:
            StopIteration: When the stop signal is dequeued.
            queue.Empty: Raised by the queue if `timeout` elapses with no item.
        """
        value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
        # Identity, not equality: the streamer enqueues the stop_signal object itself,
        # and `==` against a chunk can be element-wise (or raise) instead of giving a bool.
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value
|
||||
|
||||
|
||||
class AudioBatchIterator:
    """Iterator that yields audio chunks for all samples in the batch.

    Each item is a dict mapping sample index to the chunk dequeued for that
    sample in this round; samples with nothing ready are simply absent.
    """

    def __init__(self, streamer: "AudioStreamer"):
        self.streamer = streamer
        # Samples whose stop signal has not been seen yet.
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next non-empty ``{sample_idx: chunk}`` dict.

        Polls every active queue without blocking; if nothing is ready and
        samples are still active it sleeps 10 ms and polls again.

        Raises:
            StopIteration: Once every sample has delivered its stop signal.
        """
        import time
        from queue import Empty

        streamer = self.streamer
        # A loop rather than recursion: a long wait must not exhaust the call stack.
        while self.active_samples:
            batch_chunks = {}
            finished = set()

            for idx in self.active_samples:
                try:
                    value = streamer.audio_queues[idx].get(block=False)
                except Empty:
                    # Nothing ready for this sample in this round.
                    continue
                # Identity, not equality: the streamer enqueues the stop_signal object
                # itself, and `==` against a chunk may not produce a plain bool.
                if value is streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value

            self.active_samples -= finished

            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                # No chunk was ready but some samples are still running: wait briefly.
                time.sleep(0.01)

        raise StopIteration()
|
||||
|
||||
|
||||
class AsyncAudioStreamer(AudioStreamer):
    """
    Async version of AudioStreamer for use in async contexts.

    The per-sample queues are `asyncio.Queue` objects. `put` and `end`
    schedule their queue writes on the event loop captured at construction
    via `call_soon_threadsafe`.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace regular queues with async queues
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        # get_running_loop() raises RuntimeError when no event loop is running,
        # so the streamer has to be created from inside a coroutine or callback.
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Put audio chunks in the appropriate async queues.

        Chunks whose sample index is not below `batch_size`, or whose sample
        has already been ended, are dropped.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                audio_chunk = audio_chunks[i].detach().cpu()
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Signal the end of generation for specified samples (all samples when None)."""
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Get async iterator for a specific sample's audio stream.

        Raises:
            ValueError: If `sample_idx` is not below `batch_size` (raised when
                iteration starts, since this is an async generator).
        """
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")

        while True:
            value = await self.audio_queues[sample_idx].get()
            # Identity, not equality: `end` enqueues the stop_signal object itself,
            # and `==` against a tensor chunk can be element-wise instead of a bool.
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        """Returns an async iterator over all audio streams."""
        return AsyncAudioBatchIterator(self)
|
||||
|
||||
|
||||
class AsyncAudioBatchIterator:
    """Async iterator for batch audio streaming.

    Each item is a dict mapping sample index to the chunk received for that
    sample in this round.
    """

    def __init__(self, streamer: "AsyncAudioStreamer"):
        self.streamer = streamer
        # Samples whose stop signal has not been seen yet.
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next non-empty ``{sample_idx: chunk}`` dict.

        Waits until at least one active queue delivers an item (or the
        streamer's `timeout` elapses) and retries while samples remain active.

        Raises:
            StopAsyncIteration: Once every sample has delivered its stop signal.
        """
        streamer = self.streamer
        # A loop rather than recursion: repeated timeouts must not grow the call stack.
        while self.active_samples:
            # One getter task per active sample.
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }

            try:
                # Wait for at least one chunk to be ready
                done, pending = await asyncio.wait(
                    tasks.values(),
                    return_when=asyncio.FIRST_COMPLETED,
                    timeout=streamer.timeout,
                )
            except asyncio.CancelledError:
                # The caller was cancelled: do not leave getter tasks behind,
                # they would otherwise consume items nobody reads.
                for task in tasks.values():
                    task.cancel()
                raise

            # Getters that did not finish are cancelled; their items stay queued.
            for task in pending:
                task.cancel()

            batch_chunks = {}
            finished = set()
            for idx, task in tasks.items():
                if task not in done:
                    continue
                try:
                    # The task is already done, so result() returns without awaiting.
                    value = task.result()
                except asyncio.CancelledError:
                    continue
                # Identity, not equality: the streamer enqueues the stop_signal object
                # itself, and `==` against a chunk may not produce a plain bool.
                if value is streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value

            self.active_samples -= finished

            if batch_chunks:
                return batch_chunks

        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Helper to get a chunk from a specific queue."""
        return await self.streamer.audio_queues[idx].get()
|
||||
Reference in New Issue
Block a user