Cloning from Github
vibevoice/processor/vibevoice_processor.py (677 lines added, new file)
@@ -0,0 +1,677 @@
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re

import numpy as np
import torch

from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging

from .vibevoice_tokenizer_processor import AudioNormalizer

logger = logging.get_logger(__name__)


class VibeVoiceProcessor:
    r"""
    Constructs a VibeVoice processor which wraps a VibeVoice text tokenizer and an audio processor into a single processor.

    [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTextTokenizer`] and [`VibeVoiceTokenizerProcessor`].
    See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] methods for more information.

    Args:
        tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
            The tokenizer for text processing.
        audio_processor (`VibeVoiceTokenizerProcessor`):
            The audio processor for speech processing.
        speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
            The compression ratio for speech tokenization, i.e. the number of waveform samples represented by one speech token.
        db_normalize (`bool`, *optional*, defaults to `True`):
            Whether to apply decibel normalization to audio inputs.
    """

    def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None
        self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained model
                - a path to a *directory* containing the processor config

        Returns:
            [`VibeVoiceProcessor`]: The processor object instantiated from the pretrained model.
        """
        import os
        import json
        from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
        from vibevoice.modular.modular_vibevoice_text_tokenizer import (
            VibeVoiceTextTokenizer,
            VibeVoiceTextTokenizerFast,
        )

        # Load processor configuration
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
            config = {
                "speech_tok_compress_ratio": 3200,
                "db_normalize": True,
            }

        # Extract main processor parameters
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)

        # Load tokenizer - use the name from the config (or kwargs), falling back to Qwen
        language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
        logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
        if 'qwen' in language_model_pretrained_name.lower():
            tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
                language_model_pretrained_name,
                **kwargs
            )
        else:
            raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Only Qwen tokenizers are currently supported.")

        # Load audio processor
        if "audio_processor" in config:
            # Create audio processor from config
            audio_config = config["audio_processor"]
            audio_processor = VibeVoiceTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-6),
            )
        else:
            # Create default audio processor
            audio_processor = VibeVoiceTokenizerProcessor()

        # Create and return the processor
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

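    # Illustrative usage (a sketch, not part of the original file; the directory path
    # is hypothetical and should contain a preprocessor_config.json):
    #
    #   processor = VibeVoiceProcessor.from_pretrained("path/to/processor_dir")
    #   print(processor.speech_tok_compress_ratio)  # 3200 unless overridden in the config
    #   print(processor.db_normalize)               # True unless overridden in the config
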
    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save a processor to a directory, so that it can be re-loaded using the
        [`~VibeVoiceProcessor.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor will be saved.
        """
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)

        # Save processor configuration
        processor_config = {
            "processor_class": "VibeVoiceProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "VibeVoiceTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
                "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
                "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
                "eps": getattr(self.audio_processor, 'eps', 1e-6),
            }
        }

        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, 'w') as f:
            json.dump(processor_config, f, indent=2)

        logger.info(f"Processor configuration saved in {config_path}")

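    # Round-trip sketch (illustrative; "my_processor_dir" is a hypothetical path):
    #
    #   processor.save_pretrained("my_processor_dir")
    #   reloaded = VibeVoiceProcessor.from_pretrained("my_processor_dir")
    #
    # Only preprocessor_config.json is written here; on reload the tokenizer is
    # resolved again from "language_model_pretrained_name" in the config (or the
    # Qwen default), since it is not saved by this method.
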
    def __call__(
        self,
        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to process one or more podcast scripts with optional voice samples.

        Args:
            text (`str`, `List[str]`):
                The input text(s) to process. Can be:
                - A single script string
                - A list of script strings for batch processing
                - A path to a .json or .txt file
                - A list of paths
            voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
                Voice samples for each script. Can be:
                - A list of samples for a single script
                - A list of lists for batch processing
            padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
                Whether to pad sequences to the same length.
            truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
                Whether to truncate sequences.
            max_length (`int`, *optional*):
                Maximum length of the returned sequences.
            return_tensors (`str` or `TensorType`, *optional*):
                If set, will return tensors of a particular framework.
            return_attention_mask (`bool`, defaults to `True`):
                Whether to return the attention mask.

        Returns:
            `BatchEncoding`: A BatchEncoding with the following fields:
            - **input_ids** -- List of token id sequences or tensor
            - **attention_mask** -- List of attention masks or tensor
            - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
            - **speech_masks** -- Speech masks (if voice_samples provided)
            - **speech_input_mask** -- Boolean masks indicating speech token positions
        """
        # Handle single vs batch input: a plain string (or a pre-tokenized, non-string
        # sequence) is treated as a single script; a list of strings is a batch.
        if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
            # Single input
            texts = [text]
            is_batched = False
        else:
            # Batch input
            texts = text
            is_batched = True

        # Handle voice samples
        if voice_samples is not None:
            if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
                # Single set of voice samples
                voice_samples_list = [voice_samples]
            else:
                # Batch of voice samples
                voice_samples_list = voice_samples
        else:
            voice_samples_list = [None] * len(texts)

        # Process each input
        all_encodings = []
        for text_input, voice_input in zip(texts, voice_samples_list):
            encoding = self._process_single(text_input, voice_input)
            all_encodings.append(encoding)

        # Combine batch
        batch_encoding = self._batch_encode(
            all_encodings,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

        return batch_encoding

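    # Example call (a sketch; the script text and .wav paths are hypothetical):
    #
    #   inputs = processor(
    #       text="Speaker 1: Welcome to the show.\nSpeaker 2: Thanks for having me.",
    #       voice_samples=["voices/speaker0.wav", "voices/speaker1.wav"],
    #       return_tensors="pt",
    #   )
    #   # inputs["input_ids"], inputs["attention_mask"], inputs["speech_input_mask"],
    #   # inputs["speech_tensors"] and inputs["speech_masks"] are then ready for the model.
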
    def _process_single(
        self,
        text: Union[str, TextInput],
        voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
    ) -> Dict[str, Any]:
        """Process a single podcast script."""
        # Determine if text is a file path or direct script
        script = None
        if isinstance(text, str):
            # Check if it's a file path
            if text.endswith('.json') and os.path.exists(text):
                script = self._convert_json_to_script(text)
            elif text.endswith('.txt') and os.path.exists(text):
                script = self._convert_text_to_script(text)
            else:
                # Assume it's the script content directly
                script = text

        if script is None:
            raise ValueError(f"Could not process input text: {text}")

        # Parse the script
        parsed_lines = self._parse_script(script)
        all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))

        # Create system prompt
        # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
        system_tokens = self.tokenizer.encode(self.system_prompt)

        # Process voice samples if provided
        if voice_samples:
            voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
        else:
            voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []

        # Build full token sequence
        full_tokens = system_tokens + voice_tokens
        speech_input_mask = [False] * len(system_tokens) + voice_speech_masks

        # Add text input section
        full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
        speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))

        for speaker_id, speaker_text in parsed_lines:
            speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
            full_tokens += speaker_text_tokens
            speech_input_mask += [False] * len(speaker_text_tokens)

        # Add speech output section
        full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
        speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)

        return {
            "input_ids": full_tokens,
            "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
            "speech_input_mask": speech_input_mask,
            "parsed_script": parsed_lines,
            "all_speakers": all_speakers,
        }

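    # Sketch of the prompt layout built above (illustrative, not part of the original file):
    #
    #   [system prompt tokens]
    #   [" Voice input:\n" Speaker 0: <speech_start> <diffusion placeholders> <speech_end> ...]  (only if voice_samples given)
    #   [" Text input:\n" " Speaker 0: ...\n" " Speaker 1: ...\n" ...]
    #   [" Speech output:\n" <speech_start>]
    #
    # speech_input_mask is True only at the diffusion-placeholder positions, which is
    # presumably where the model injects acoustic features for the reference voices.
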
    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """Combine multiple encodings into a batch with padding."""
        # Extract input_ids and create attention_mask
        input_ids_list = [enc["input_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]

        # Determine padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
        elif isinstance(padding, str):
            padding_strategy = PaddingStrategy(padding)
        else:
            padding_strategy = padding

        # Apply padding to input_ids
        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
            if padding_strategy == PaddingStrategy.LONGEST:
                max_len = max(len(ids) for ids in input_ids_list)
            elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            # Pad sequences
            padded_input_ids = []
            attention_masks = []
            padded_speech_input_masks = []

            for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
                # Truncate if needed
                if truncation and len(input_ids) > max_len:
                    input_ids = input_ids[:max_len]
                    speech_mask = speech_mask[:max_len]

                # Pad
                padding_length = max_len - len(input_ids)
                # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
                padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
                attention_mask = [0] * padding_length + [1] * len(input_ids)
                padded_speech_mask = [False] * padding_length + speech_mask

                padded_input_ids.append(padded_ids)
                attention_masks.append(attention_mask)
                padded_speech_input_masks.append(padded_speech_mask)

            input_ids_list = padded_input_ids
            speech_input_masks_list = padded_speech_input_masks
        else:
            # No padding, just create attention masks
            attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None

        # Process speech inputs
        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True

        # Prepare batch encoding
        batch_encoding = BatchEncoding()

        # Handle tensor conversion
        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
            batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
        else:
            batch_encoding["input_ids"] = input_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list

        # Process speech tensors if present
        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs,
                return_tensors=return_tensors,
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None

        # Add metadata
        batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
        batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]

        return batch_encoding

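    # Padding behaviour sketch (illustrative): sequences are left-padded with pad_id
    # so that every prompt ends at the same position. For two encodings of lengths
    # 7 and 10 with padding=True (LONGEST):
    #
    #   input_ids[0]      -> [pad, pad, pad, t1, ..., t7]
    #   attention_mask[0] -> [0,   0,   0,   1,  ..., 1 ]
    #
    # and speech_input_mask is padded with False in the same positions.
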
    def _create_voice_prompt(
        self,
        speaker_samples: List[Union[str, np.ndarray]]
    ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
        """
        Create voice prompt tokens and process audio samples.

        Returns:
            tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
        """
        vae_token_id = self.tokenizer.speech_diffusion_id

        voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
        voice_speech_inputs = []
        voice_speech_masks = [False] * len(voice_full_tokens)

        for speaker_id, speaker_audio in enumerate(speaker_samples):
            prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)

            # Process audio
            if isinstance(speaker_audio, str):
                # Load audio from file
                wav = self.audio_processor._load_audio_from_path(speaker_audio)
            else:
                wav = np.array(speaker_audio, dtype=np.float32)

            # Apply normalization if needed
            if self.db_normalize and self.audio_normalizer:
                wav = self.audio_normalizer(wav)

            # Calculate token length based on compression ratio
            # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
            #     vae_tok_len = wav.shape[0]
            # else:
            vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)

            # Build tokens and masks
            speaker_tokens = (prefix_tokens +
                              [self.tokenizer.speech_start_id] +
                              [vae_token_id] * vae_tok_len +
                              [self.tokenizer.speech_end_id] +
                              self.tokenizer.encode('\n', add_special_tokens=False))

            vae_input_mask = ([False] * len(prefix_tokens) +
                              [False] +
                              [True] * vae_tok_len +
                              [False] +
                              [False])

            voice_full_tokens.extend(speaker_tokens)
            voice_speech_masks.extend(vae_input_mask)
            voice_speech_inputs.append(wav)

        return voice_full_tokens, voice_speech_inputs, voice_speech_masks

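    # Worked example of the length calculation above (numbers for illustration only):
    # with the default 24 kHz sampling rate and speech_tok_compress_ratio = 3200,
    # a 10-second reference clip has 240_000 samples, so
    #   vae_tok_len = ceil(240_000 / 3200) = 75
    # and 75 speech_diffusion_id placeholders are inserted between <speech_start>
    # and <speech_end> for that speaker.
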
    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """
        Prepare speech inputs for model consumption.

        Args:
            speech_inputs: List of speech arrays
            return_tensors: Output tensor type
            device: Device to place tensors on
            dtype: Data type for tensors

        Returns:
            Dictionary with padded_speeches and speech_masks
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}

        # Calculate sequence lengths
        vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
        # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
        max_speech_length = max(s.shape[0] for s in speech_inputs)

        # Pad speeches
        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
        else:
            padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
        speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)

        for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
            padded_speeches[i, :len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True

        result = {
            "padded_speeches": padded_speeches,
            "speech_masks": speech_masks,
        }

        # Convert to tensors if requested
        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
            result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)

        return result

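    # Shape sketch (illustrative): for two mono clips of 48_000 and 96_000 samples,
    #   padded_speeches.shape == (2, 96_000)   # zero-padded waveforms
    #   speech_masks.shape    == (2, 30)       # 30 = ceil(96_000 / 3200)
    # and speech_masks[0, :15] is True (15 = ceil(48_000 / 3200)), the rest False.
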
    def _convert_json_to_script(self, json_file: str) -> str:
        """
        Convert JSON format to script format.
        Expected JSON format:
        [
            {"speaker": "1", "text": "Hello everyone..."},
            {"speaker": "2", "text": "Great to be here..."}
        ]
        """
        import json

        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of speaker entries")

        script_lines = []
        for item in data:
            if not isinstance(item, dict):
                logger.warning(f"Skipping non-dict entry: {item}")
                continue

            speaker = item.get('speaker')
            text = item.get('text')

            if speaker is None or text is None:
                logger.warning(f"Skipping entry missing speaker or text: {item}")
                continue

            # Ensure speaker ID is valid
            try:
                speaker_id = int(speaker)
            except (ValueError, TypeError):
                logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
                continue

            # Clean up text
            text = text.strip()
            if text:
                script_lines.append(f"Speaker {speaker_id}: {text}")

        if not script_lines:
            raise ValueError("No valid entries found in JSON file")

        return "\n".join(script_lines)

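    # Conversion sketch: given a JSON file containing
    #   [{"speaker": "1", "text": "Hello everyone."},
    #    {"speaker": "2", "text": "Great to be here."}]
    # this method returns the script string
    #   "Speaker 1: Hello everyone.\nSpeaker 2: Great to be here."
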
    def _convert_text_to_script(self, text_file: str) -> str:
        """
        Convert a text file to script format.
        Handles multiple formats:
        1. Already formatted as "Speaker X: text"
        2. Plain text (assigned to Speaker 1)

        Handles edge cases like multiple colons in a line.
        """
        with open(text_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        script_lines = []
        current_speaker = 1

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Try to parse as "Speaker X: text" format
            # Use regex to be more robust
            speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)

            if speaker_match:
                speaker_id = int(speaker_match.group(1))
                text = speaker_match.group(2).strip()
                if text:
                    script_lines.append(f"Speaker {speaker_id}: {text}")
            else:
                # Treat as plain text - assign to current speaker
                script_lines.append(f"Speaker {current_speaker}: {line}")

        if not script_lines:
            raise ValueError("No valid content found in text file")

        return "\n".join(script_lines)

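    # Conversion sketch: a plain .txt file with the lines
    #   Speaker 1: Hi!
    #   And this line has no speaker tag.
    # becomes
    #   "Speaker 1: Hi!\nSpeaker 1: And this line has no speaker tag."
    # (untagged lines are attributed to Speaker 1).
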
    def _parse_script(self, script: str) -> List[Tuple[int, str]]:
        """Parse a script into a list of (speaker_id, text) tuples."""
        lines = script.strip().split("\n")
        parsed_lines = []
        speaker_ids = []

        # First pass: parse all lines and collect speaker IDs
        for line in lines:
            if not line.strip():
                continue

            # Use regex to handle edge cases like multiple colons
            match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)

            if match:
                speaker_id = int(match.group(1))
                text = ' ' + match.group(2).strip()
                parsed_lines.append((speaker_id, text))
                speaker_ids.append(speaker_id)
            else:
                logger.warning(f"Could not parse line: '{line}'")

        if not parsed_lines:
            raise ValueError("No valid speaker lines found in script")

        # Check whether speaker IDs need normalizing (only when every ID is > 0)
        min_speaker_id = min(speaker_ids)
        if min_speaker_id > 0:
            # Shift all IDs down by one so 1-based scripts become 0-based
            normalized_lines = []
            for speaker_id, text in parsed_lines:
                normalized_lines.append((speaker_id - 1, text))
            return normalized_lines
        else:
            # Keep original IDs
            return parsed_lines

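    # Parsing sketch: the script
    #   "Speaker 1: Hello\nSpeaker 2: Hi there"
    # is parsed to [(1, ' Hello'), (2, ' Hi there')] and, because the smallest ID is
    # greater than 0, normalized to [(0, ' Hello'), (1, ' Hi there')].
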
    def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
        """Merge text and audio inputs into a single BatchEncoding."""
        # Start with text inputs
        merged = BatchEncoding(text_inputs)

        # Add audio-specific fields
        if "audio" in audio_inputs:
            merged["speech_inputs"] = audio_inputs["audio"]
        if "streaming" in audio_inputs:
            merged["streaming"] = audio_inputs["streaming"]

        return merged

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Return the list of inputs accepted by the model.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))

    def save_audio(self,
                   audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
                   output_path: str = "output.wav",
                   sampling_rate: Optional[int] = None,
                   normalize: bool = False,
                   batch_prefix: str = "audio_",
                   ) -> str:
        """
        Save audio data to a file.

        Args:
            audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
                The audio data to save. Can be a single tensor/array or a list of them.
            output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
            sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
            normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
            batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".

        Returns:
            str: The path to the saved audio file.
        """
        return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)

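    # Usage sketch (the variable `generated_waveform` and the path are hypothetical):
    #
    #   processor.save_audio(generated_waveform, output_path="podcast.wav")
    #
    # When a list of waveforms is passed, `batch_prefix` is used to name the
    # individual files; the exact naming is handled by the underlying audio
    # processor's save_audio.
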

__all__ = [
    "VibeVoiceProcessor",
]