# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from einops import rearrange
from torch import nn
from torch.distributed.fsdp.wrap import (
enable_wrap,
wrap,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
try:
from .helpers import TransformerEncoder
from .utils import apply_with_stopping_condition
except ImportError:  # fall back to local imports when run outside the package
from helpers import TransformerEncoder
from utils import apply_with_stopping_condition
class Flamingo(nn.Module):
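    """
    Audio Flamingo model: a CLAP audio encoder produces one embedding per audio
    window, an audio TransformerEncoder contextualizes those embeddings, and the
    result conditions a causal language model through gated cross-attention
    layers (installed via lang_encoder.init_flamingo). Adapted from open_flamingo.
    """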
def __init__(
self,
clap: nn.Module,
unfreeze_clap: bool,
lang_encoder: nn.Module,
eoc_token_id: int,
media_token_id: int,
sep_token_id: int,
audio_embed_dim: int,
audio_transformer_kwargs: dict,
cross_attn_every_n_layers: int = 1,
gradient_checkpointing: bool = False,
):
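        """
        Args:
            clap (nn.Module): CLAP audio encoder producing one embedding per audio window.
            unfreeze_clap (bool): whether CLAP parameters receive gradients.
            lang_encoder (nn.Module): causal language model to be extended with gated cross-attention.
            eoc_token_id (int): id of the end-of-chunk token.
            media_token_id (int): id of the audio placeholder (media) token in the text.
            sep_token_id (int): id of the separator token.
            audio_embed_dim (int): dimension of the CLAP audio embeddings.
            audio_transformer_kwargs (dict): config for the audio TransformerEncoder
                (n_head, n_layers, d_inner, max_num_media, max_window_per_audio).
            cross_attn_every_n_layers (int): insert a cross-attention layer every n decoder layers.
            gradient_checkpointing (bool): enable gradient checkpointing.
        """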
super().__init__()
self.eoc_token_id = eoc_token_id
self.media_token_id = media_token_id
self.sep_token_id = sep_token_id
self.audio_embed_dim = audio_embed_dim
self.clap = clap # .to(torch.cuda.current_device())
self.unfreeze_clap = unfreeze_clap
self.clap.requires_grad_(unfreeze_clap)
if hasattr(lang_encoder.config, "d_model"):
self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
else:
self.lang_dim = lang_encoder.config.hidden_size
n_head = audio_transformer_kwargs["n_head"]
n_layers = audio_transformer_kwargs["n_layers"]
d_inner = audio_transformer_kwargs["d_inner"]
max_num_media = audio_transformer_kwargs["max_num_media"]
max_window_per_audio = audio_transformer_kwargs["max_window_per_audio"]
assert audio_embed_dim % n_head == 0
self.audio_transformer = TransformerEncoder(
d_word_vec=audio_embed_dim,
n_layers=n_layers,
n_head=n_head,
d_k=audio_embed_dim // n_head,
d_v=audio_embed_dim // n_head,
d_model=audio_embed_dim,
d_inner=d_inner,
dropout=0.0,
n_position=max_num_media,
scale_emb=True
)
self.lang_encoder = lang_encoder
self.lang_encoder.init_flamingo(
media_token_id=media_token_id,
lang_hidden_size=self.lang_dim,
audio_hidden_size=self.audio_embed_dim,
max_window_per_audio=max_window_per_audio,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
)
self._use_gradient_checkpointing = gradient_checkpointing
self.audio_transformer._use_gradient_checkpointing = gradient_checkpointing
self.clap._use_gradient_checkpointing = gradient_checkpointing
def forward(
self,
audio_x: torch.Tensor,
audio_x_mask: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
clear_conditioned_layers: bool = True,
past_key_values=None,
use_cache: bool = False,
):
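        """
        Condition the language model on audio_x and run a forward pass over lang_x.

        Args:
            audio_x (torch.Tensor): audio windows of shape (B, num_window, window_length),
                or None if media was precached with cache_media().
            audio_x_mask (torch.Tensor): window-level mask of shape (B, num_window).
            lang_x (torch.Tensor): text token ids of shape (B, T).
            attention_mask (torch.Tensor, optional): attention mask for lang_x.
            labels (torch.Tensor, optional): labels for the language-modeling loss.
            clear_conditioned_layers (bool): clear the audio conditioning after the pass.
            past_key_values: cached decoder key/values for incremental decoding.
            use_cache (bool): whether to return new past_key_values.
        """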
assert (
self.lang_encoder.initialized_flamingo
), "Flamingo layers are not initialized. Please call `init_flamingo` first."
assert (
self.lang_encoder._use_cached_audio_x or audio_x is not None
), "Must provide either audio_x or have precached media using cache_media()."
if self.lang_encoder._use_cached_audio_x:
assert (
audio_x is None
), "Expect audio_x to be None when media has been cached using cache_media(). Try uncache_media() first."
assert self.lang_encoder.is_conditioned()
else:
self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
self._condition_media_locations(input_ids=lang_x)
output = self.lang_encoder(
input_ids=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
use_cache=use_cache,
)
if clear_conditioned_layers:
self.lang_encoder.clear_conditioned_layers()
return output
def generate(
self,
audio_x: torch.Tensor,
audio_x_mask: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
**kwargs,
):
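        """
        Generate text conditioned on audio.

        Args:
            audio_x (torch.Tensor): audio windows of shape (B, num_window, window_length).
            audio_x_mask (torch.Tensor): window-level mask of shape (B, num_window).
            lang_x (torch.Tensor): prompt token ids of shape (B, T).
            attention_mask (torch.Tensor, optional): attention mask for lang_x.
            **kwargs: forwarded to lang_encoder.generate() (e.g. num_beams, max_new_tokens).

        Returns:
            torch.Tensor: generated token ids.
        """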
num_beams = kwargs.pop("num_beams", 1)
if num_beams > 1:
audio_x = audio_x.repeat_interleave(num_beams, dim=0)
self.lang_encoder._use_cached_audio_x = True
self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
output = self.lang_encoder.generate(
input_ids=lang_x,
attention_mask=attention_mask,
eos_token_id=eos_token_id,
num_beams=num_beams,
**kwargs,
)
self.lang_encoder.clear_conditioned_layers()
self.lang_encoder._use_cached_audio_x = False
return output
def _encode_audio_x(self, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
"""
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
assert audio_x.ndim == 3, "audio_x should be of shape (B, num_window, window_length)"
with torch.no_grad():
audio_embeds = self.clap(audio_x)
B, L, D = audio_embeds.shape # L is number of windows, D is feature dim
assert D == self.audio_embed_dim
assert audio_x_mask.ndim == 2, "audio_x_mask should be of shape (B, L)"
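        # audio_x may have been expanded (e.g. repeat_interleave for beam search)
        # while the mask was not; broadcast a batch-size-1 mask to match.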
if B > 1 and audio_x_mask.shape[0] == 1:
audio_x_mask = audio_x_mask.repeat(B, 1)
assert audio_x_mask.shape[0] == B and audio_x_mask.shape[1] == L, "{} != ({},{})".format(audio_x_mask.shape, B, L)
audio_x_out = self.audio_transformer(audio_embeds) # B, L, D
audio_x_out = audio_x_out.unsqueeze(2) # B, L, n=1, D
audio_x_mask = audio_x_mask.unsqueeze(2) # B, L, n=1
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_audio_x(audio_x_out, audio_x_mask)
def wrap_fsdp(self, wrapper_kwargs, device_id):
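        """
        Manually wrap submodules for FSDP training and move the remaining
        (non-FSDP-managed) parameters to device_id.

        The audio transformer, decoder blocks, gated cross-attention layers,
        and token embeddings are wrapped individually; CLAP is left unwrapped
        so it is replicated on each GPU, and the original decoder blocks are
        flagged for exclusion from the optimizer. Also installs a
        clip_grad_norm_ helper covering the trainable wrapped submodules.
        """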
# unfreeze the decoder layers
for block in self.lang_encoder.old_decoder_blocks:
block.requires_grad_(True)
# wrap in FSDP
with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
self.audio_transformer = wrap(wrap(self.audio_transformer))
self.lang_encoder.old_decoder_blocks = nn.ModuleList(
wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
)
self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
wrap(wrap(layer)) if layer is not None else None
for layer in self.lang_encoder.gated_cross_attn_layers
)
self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
self.lang_encoder.set_input_embeddings(
wrap(wrap(self.lang_encoder.get_input_embeddings()))
)
if hasattr(self.lang_encoder, 'set_output_embeddings'):
self.lang_encoder.set_output_embeddings(
wrap(wrap(self.lang_encoder.get_output_embeddings()))
)
else:
print('skip wrapping output embeddings')
# manually move non-FSDP managed parameters to device_id
# these are all in lang_encoder
apply_with_stopping_condition(
module=self.lang_encoder,
apply_fn=lambda m: m.to(device_id),
apply_condition=lambda m: len(list(m.children())) == 0,
stopping_condition=lambda m: isinstance(m, FSDP),
)
# clap shouldn't be wrapped; should be on each gpu
if self.unfreeze_clap:
apply_with_stopping_condition(
module=self.clap,
apply_fn=lambda m: m.to(device_id),
apply_condition=lambda m: len(list(m.children())) == 0,
stopping_condition=lambda m: isinstance(m, FSDP),
)
# exclude the original decoder layers from the optimizer
for block in self.lang_encoder.old_decoder_blocks:
for p in block.parameters():
p.exclude_from_optimizer = True
# set up clip_grad_norm_ function
def clip_grad_norm_(max_norm):
self.audio_transformer.clip_grad_norm_(max_norm)
for layer in self.lang_encoder.gated_cross_attn_layers:
if layer is not None:
layer.clip_grad_norm_(max_norm)
self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)
self.clip_grad_norm_ = clip_grad_norm_
def _condition_media_locations(self, input_ids: torch.Tensor):
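        """
        Record the positions of media (audio placeholder) tokens in input_ids
        on every decoder layer, so the gated cross-attention layers know which
        text positions follow which audio.
        """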
media_locations = (input_ids == self.media_token_id)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_media_locations(media_locations)
def cache_media(self, input_ids: torch.Tensor, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
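        """
        Pre-compute and cache the audio conditioning so that forward() can
        subsequently be called with audio_x=None; undo with uncache_media().
        """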
self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
self._condition_media_locations(input_ids=input_ids)
self.lang_encoder._use_cached_audio_x = True
def uncache_media(self):
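        """
        Clear the cached audio conditioning set by cache_media().
        """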
self.lang_encoder.clear_conditioned_layers()
self.lang_encoder._use_cached_audio_x = False