Cloning from Github

vibevoice/modular/streamer.py (new file, 264 lines)
from __future__ import annotations

import asyncio
import time
from queue import Empty, Queue
from typing import Any, Optional

import torch

from transformers.generation import BaseStreamer


class AudioStreamer(BaseStreamer):
    """
    Audio streamer that stores audio chunks in queues, one queue per sample in the
    batch. This allows streaming audio generation for multiple samples simultaneously.

    Parameters:
        batch_size (`int`):
            The batch size for generation.
        stop_signal (`Any`, *optional*):
            The sentinel to put in a queue when generation ends for that sample.
            Defaults to `None`.
        timeout (`float`, *optional*):
            The timeout for audio queue operations. If `None`, queue operations
            block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout

        # Create a queue for each sample in the batch
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}  # Maps from sample index to queue index

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Receives audio chunks and puts them in the appropriate queues.

        Args:
            audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks.
            sample_indices: Tensor indicating which samples these chunks belong to.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                # Detach and move to CPU so the consumer owns the chunk
                # independently of the generation device.
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for specified samples or all samples.

        Args:
            sample_indices: Optional tensor of sample indices to end. If None, ends all.
        """
        if sample_indices is None:
            # End all samples
            for idx in range(self.batch_size):
                if not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True
        else:
            # End specific samples
            for sample_idx in sample_indices:
                idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
                if idx < self.batch_size and not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True

    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Get the audio stream for a specific sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)
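

# Usage sketch (illustrative, not part of the original file): a producer
# thread stands in for a real `model.generate(..., streamer=...)` call, and a
# consumer iterates one sample's stream. `_fake_generate` is hypothetical.
def _example_single_stream():
    import threading

    streamer = AudioStreamer(batch_size=2)

    def _fake_generate():
        # A real model would call put()/end() from its generation loop.
        for _ in range(3):
            streamer.put(torch.randn(2, 480), torch.tensor([0, 1]))
        streamer.end()

    threading.Thread(target=_fake_generate).start()
    for chunk in streamer.get_stream(0):  # blocks until sample 0 finishes
        print(chunk.shape)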


class AudioSampleIterator:
    """Iterator for a single audio stream from the batch."""

    def __init__(self, streamer: AudioStreamer, sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until a chunk (or the stop sentinel) is available; raises
        # queue.Empty if a timeout is configured and nothing arrives in time.
        value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value


class AudioBatchIterator:
    """Iterator that yields audio chunks for all samples in the batch."""

    def __init__(self, streamer: AudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        # Poll all active queues until at least one chunk is ready; a loop is
        # used instead of recursion so long waits cannot exhaust the stack.
        while True:
            if not self.active_samples:
                raise StopIteration()

            batch_chunks = {}
            samples_to_remove = set()

            # Try to get chunks from all active samples without blocking
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                    if value is self.streamer.stop_signal:
                        samples_to_remove.add(idx)
                    else:
                        batch_chunks[idx] = value
                except Empty:
                    # Queue is empty for this sample; skip it this iteration
                    pass

            # Remove finished samples
            self.active_samples -= samples_to_remove

            if batch_chunks:
                return batch_chunks
            if not self.active_samples:
                raise StopIteration()
            # No chunks were ready but samples remain active: back off
            # briefly before polling again.
            time.sleep(0.01)
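

# Usage sketch: iterating the streamer itself yields dicts mapping sample
# index -> chunk for whichever samples had data ready. `_fake_generate` is
# again a hypothetical stand-in for a real generation loop.
def _example_batch_stream():
    import threading

    streamer = AudioStreamer(batch_size=2)

    def _fake_generate():
        for _ in range(3):
            streamer.put(torch.randn(2, 480), torch.tensor([0, 1]))
        streamer.end()

    threading.Thread(target=_fake_generate).start()
    for chunks in streamer:  # AudioBatchIterator via __iter__
        for idx, chunk in chunks.items():
            print(idx, chunk.shape)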


class AsyncAudioStreamer(AudioStreamer):
    """
    Async version of AudioStreamer for use in async contexts.

    Must be instantiated inside a running event loop, since the constructor
    captures the loop via `asyncio.get_running_loop()`.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace the regular queues with async queues
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Put audio chunks in the appropriate async queues.

        Safe to call from a worker thread: the enqueue is marshalled onto the
        event loop with `call_soon_threadsafe`.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                audio_chunk = audio_chunks[i].detach().cpu()
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Signal the end of generation for specified samples (all if None)."""
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Async generator over a specific sample's audio stream."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")

        while True:
            value = await self.audio_queues[sample_idx].get()
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        """Returns an async iterator over all audio streams."""
        return AsyncAudioBatchIterator(self)
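

# Usage sketch: construct the async streamer inside a running loop, run the
# (hypothetical) blocking generation in an executor thread, and consume
# chunks with `async for`.
async def _example_async_single_stream():
    streamer = AsyncAudioStreamer(batch_size=1)

    def _fake_generate():
        for _ in range(3):
            streamer.put(torch.randn(1, 480), torch.tensor([0]))
        streamer.end()

    producer = asyncio.get_running_loop().run_in_executor(None, _fake_generate)
    async for chunk in streamer.get_stream(0):
        print(chunk.shape)
    await producer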


class AsyncAudioBatchIterator:
    """Async iterator for batch audio streaming."""

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Loop (rather than recurse) until at least one chunk is ready or all
        # samples have finished.
        while True:
            if not self.active_samples:
                raise StopAsyncIteration()

            batch_chunks = {}
            samples_to_remove = set()

            # Create one queue-get task per active sample
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }

            # Wait for at least one chunk to be ready
            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=self.streamer.timeout,
            )

            # Cancel tasks that did not complete this round and await them so
            # the cancellation is fully processed before the next poll.
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

            # Collect results from completed tasks
            for idx, task in tasks.items():
                if task in done:
                    value = task.result()
                    if value is self.streamer.stop_signal:
                        samples_to_remove.add(idx)
                    else:
                        batch_chunks[idx] = value

            # Remove finished samples
            self.active_samples -= samples_to_remove

            if batch_chunks:
                return batch_chunks
            if not self.active_samples:
                raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Helper to get a chunk from a specific queue."""
        return await self.streamer.audio_queues[idx].get()
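

# Usage sketch: `async for` over the streamer yields dicts of ready chunks
# per sample via AsyncAudioBatchIterator; `_fake_generate` is a hypothetical
# producer running in an executor thread.
async def _example_async_batch_stream():
    streamer = AsyncAudioStreamer(batch_size=2)

    def _fake_generate():
        for _ in range(3):
            streamer.put(torch.randn(2, 480), torch.tensor([0, 1]))
        streamer.end()

    producer = asyncio.get_running_loop().run_in_executor(None, _fake_generate)
    async for chunks in streamer:
        for idx, chunk in chunks.items():
            print(idx, chunk.shape)
    await producer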