Cloning from Github
This commit is contained in:
287
vibevoice/modular/modular_vibevoice_diffusion_head.py
Normal file
287
vibevoice/modular/modular_vibevoice_diffusion_head.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.models.auto import AutoModel
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
# from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.utils import logging
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if self.weight is not None:
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
"""Apply modulation to input tensor."""
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`): Size of the output embedding
|
||||
frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=False),
|
||||
# nn.SiLU(),
|
||||
ACT2FN['silu'],
|
||||
nn.Linear(hidden_size, hidden_size, bias=False),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
dim (`int`): The dimension of the output.
|
||||
max_period (`int`, optional): Controls the minimum frequency of the embeddings.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: An [N, D] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding.to(t.dtype)
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FeedForwardNetwork(nn.Module):
|
||||
"""
|
||||
Standard feed-forward network with SwiGLU activation.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): Input dimension
|
||||
ffn_dim (`int`): Hidden dimension
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
ffn_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
||||
self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
||||
self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
|
||||
self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
|
||||
|
||||
def forward(self, x):
|
||||
gate = self.gate_proj(x)
|
||||
up = self.up_proj(x)
|
||||
|
||||
# SwiGLU activation
|
||||
# gate = F.silu(gate)
|
||||
gate = self.act_fn(gate)
|
||||
return self.down_proj(gate * up)
|
||||
|
||||
|
||||
class HeadLayer(nn.Module):
|
||||
"""
|
||||
A layer in the diffusion head.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): Input dimension
|
||||
ffn_dim (`int`): Hidden dimension
|
||||
cond_dim (`int`): Condition embedding dimension
|
||||
norm_eps (`float`, optional): Epsilon for normalization
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
ffn_dim,
|
||||
cond_dim,
|
||||
norm_eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.cond_dim = cond_dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.ffn = FeedForwardNetwork(
|
||||
self.embed_dim,
|
||||
self.ffn_dim,
|
||||
)
|
||||
self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
# nn.SiLU(),
|
||||
ACT2FN['silu'],
|
||||
nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
|
||||
x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
Final layer in the diffusion head.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`): Input dimension
|
||||
output_size (`int`): Output dimension
|
||||
cond_size (`int`): Condition embedding dimension
|
||||
norm_eps (`float`, optional): Epsilon for normalization
|
||||
"""
|
||||
def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
|
||||
super().__init__()
|
||||
self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
|
||||
self.linear = nn.Linear(hidden_size, output_size, bias=False)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
# nn.SiLU(),
|
||||
ACT2FN['silu'],
|
||||
nn.Linear(cond_size, 2 * hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class VibeVoiceDiffusionHead(PreTrainedModel):
|
||||
"""
|
||||
Diffusion head model for vibevoice.
|
||||
|
||||
Args:
|
||||
config (`VibeVoiceDiffusionHeadConfig`): Model configuration
|
||||
latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
|
||||
"""
|
||||
config_class = VibeVoiceDiffusionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.cond_dim = config.hidden_size
|
||||
latent_size = config.latent_size
|
||||
|
||||
self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
|
||||
self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
|
||||
self.t_embedder = TimestepEmbedder(self.cond_dim)
|
||||
|
||||
ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
|
||||
|
||||
# Create the intermediate layers
|
||||
self.layers = nn.ModuleList([
|
||||
HeadLayer(
|
||||
embed_dim=config.hidden_size,
|
||||
ffn_dim=ffn_dim,
|
||||
cond_dim=self.cond_dim,
|
||||
norm_eps=config.rms_norm_eps
|
||||
)
|
||||
for _ in range(config.head_layers)
|
||||
])
|
||||
|
||||
# Final layer for output
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=config.hidden_size,
|
||||
output_size=latent_size,
|
||||
cond_size=self.cond_dim,
|
||||
norm_eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
"""Initialize the weights of the model."""
|
||||
# Initialize timestep embedder
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out adaLN modulation layers
|
||||
for layer in self.layers:
|
||||
nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
|
||||
|
||||
# Zero-out output layers
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
noisy_images,
|
||||
timesteps,
|
||||
condition,
|
||||
):
|
||||
"""
|
||||
Forward pass of the prediction head.
|
||||
|
||||
Args:
|
||||
noisy_images (`torch.Tensor`): Noisy images/latents to denoise
|
||||
timesteps (`torch.Tensor`): Timesteps for diffusion
|
||||
condition (`torch.Tensor`): Conditioning information
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The predicted noise/velocity
|
||||
"""
|
||||
x = self.noisy_images_proj(noisy_images)
|
||||
t = self.t_embedder(timesteps)
|
||||
condition = self.cond_proj(condition)
|
||||
c = condition + t
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, c)
|
||||
|
||||
x = self.final_layer(x, c)
|
||||
return x
|
||||
|
||||
|
||||
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceDiffusionHead",
|
||||
]
|
||||
Reference in New Issue
Block a user