# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
# LICENSE is in incl_licenses directory.

from typing import Optional

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from transformers.modeling_outputs import CausalLMOutputWithPast
try:
    from .helpers import TransformerEncoder
    from .utils import apply_with_stopping_condition
except ImportError:
    from helpers import TransformerEncoder
    from utils import apply_with_stopping_condition


class Flamingo(nn.Module):
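    """An audio-understanding Flamingo model.

    CLAP embeds each audio window, a small transformer contextualizes the
    window embeddings, and gated cross-attention layers injected into
    `lang_encoder` condition text generation on them.

    Example (a minimal sketch; `clap_model`, `llm`, and the token ids are
    assumed to come from the surrounding setup, and the hyperparameter
    values are illustrative, not recommended settings):

        model = Flamingo(
            clap=clap_model,
            unfreeze_clap=False,
            lang_encoder=llm,
            eoc_token_id=eoc_id,
            media_token_id=media_id,
            sep_token_id=sep_id,
            audio_embed_dim=1024,
            audio_transformer_kwargs={
                "n_head": 8, "n_layers": 3, "d_inner": 4096,
                "max_num_media": 128, "max_window_per_audio": 16,
            },
        )
        output = model(audio_x, audio_x_mask, lang_x,
                       attention_mask=mask, labels=labels)
    """
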
    def __init__(
        self,
        clap: nn.Module,
        unfreeze_clap: bool,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        sep_token_id: int,
        audio_embed_dim: int,
        audio_transformer_kwargs: dict,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
    ):
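        """
        Args:
            clap (nn.Module): audio encoder; emits one embedding per audio window
            unfreeze_clap (bool): whether CLAP parameters are trainable
            lang_encoder (nn.Module): HF causal language model; augmented in
                place with gated cross-attention via `init_flamingo`
            eoc_token_id (int): token id of the end-of-chunk token
            media_token_id (int): token id marking audio positions in the text
            sep_token_id (int): token id of the separator token
            audio_embed_dim (int): dimensionality of the CLAP embeddings
            audio_transformer_kwargs (dict): config for the audio transformer;
                expects n_head, n_layers, d_inner, max_num_media, and
                max_window_per_audio
            cross_attn_every_n_layers (int): how often to insert a gated
                cross-attention layer between decoder layers
            gradient_checkpointing (bool): whether to use gradient checkpointing
        """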
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.sep_token_id = sep_token_id
        self.audio_embed_dim = audio_embed_dim
        self.clap = clap  # .to(torch.cuda.current_device())
        self.unfreeze_clap = unfreeze_clap
        self.clap.requires_grad_(unfreeze_clap)

        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        n_head = audio_transformer_kwargs["n_head"]
        n_layers = audio_transformer_kwargs["n_layers"]
        d_inner = audio_transformer_kwargs["d_inner"]
        max_num_media = audio_transformer_kwargs["max_num_media"]
        max_window_per_audio = audio_transformer_kwargs["max_window_per_audio"]
        assert audio_embed_dim % n_head == 0, "audio_embed_dim must be divisible by n_head"

        self.audio_transformer = TransformerEncoder(
            d_word_vec=audio_embed_dim,
            n_layers=n_layers,
            n_head=n_head,
            d_k=audio_embed_dim // n_head,
            d_v=audio_embed_dim // n_head,
            d_model=audio_embed_dim,
            d_inner=d_inner,
            dropout=0.0,
            n_position=max_num_media,
            scale_emb=True,
        )

        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            audio_hidden_size=self.audio_embed_dim,
            max_window_per_audio=max_window_per_audio,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        self.audio_transformer._use_gradient_checkpointing = gradient_checkpointing
        self.clap._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ) -> CausalLMOutputWithPast:
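        """Run the language model conditioned on audio.

        Args:
            audio_x (torch.Tensor): audio windows, shape (B, num_window, window_length)
            audio_x_mask (torch.Tensor): validity mask over windows, shape (B, num_window)
            lang_x (torch.Tensor): text token ids, shape (B, T)
            attention_mask: attention mask over lang_x
            labels: targets for the language-modeling loss
            clear_conditioned_layers (bool): drop the cached audio conditioning
                from the cross-attention layers after this pass
            past_key_values, use_cache: forwarded to `lang_encoder` for
                incremental decoding
        """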
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
        assert (
            self.lang_encoder._use_cached_audio_x or audio_x is not None
        ), "Must provide either audio_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_audio_x:
            assert (
                audio_x is None
            ), "Expect audio_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()
        else:
            self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()
        return output

    def generate(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
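        """Generate text conditioned on audio.

        Encodes the audio once, conditions the cross-attention layers, then
        delegates to `lang_encoder.generate`; remaining `kwargs` (e.g.
        max_new_tokens, do_sample) are forwarded to it. `eos_token_id`
        defaults to the end-of-chunk token.
        """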
        num_beams = kwargs.pop("num_beams", 1)
        if num_beams > 1:
            audio_x = audio_x.repeat_interleave(num_beams, dim=0)
            # keep the mask batch dim aligned with the beam-expanded audio;
            # a singleton mask is broadcast inside _encode_audio_x instead
            if audio_x_mask.shape[0] > 1:
                audio_x_mask = audio_x_mask.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_audio_x = True
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)

        eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=eos_token_id,
            num_beams=num_beams,
            **kwargs,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False
        return output

    def _encode_audio_x(self, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
        """
        Encode audio windows with CLAP, contextualize them with the audio
        transformer, and cache the result on the gated cross-attention layers.

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert audio_x.ndim == 3, "audio_x should be of shape (B, num_window, window_length)"

        # only track gradients through CLAP when it is unfrozen; otherwise
        # this is equivalent to the usual torch.no_grad() frozen-encoder path
        with torch.set_grad_enabled(self.unfreeze_clap):
            audio_embeds = self.clap(audio_x)

        B, L, D = audio_embeds.shape  # L is number of windows, D is feature dim
        assert D == self.audio_embed_dim, "CLAP output dim {} != audio_embed_dim {}".format(D, self.audio_embed_dim)

        assert audio_x_mask.ndim == 2, "audio_x_mask should be of shape (B, L)"
        if B > 1 and audio_x_mask.shape[0] == 1:
            audio_x_mask = audio_x_mask.repeat(B, 1)
        assert audio_x_mask.shape[0] == B and audio_x_mask.shape[1] == L, "{} != ({},{})".format(audio_x_mask.shape, B, L)

        audio_x_out = self.audio_transformer(audio_embeds)  # B, L, D
        audio_x_out = audio_x_out.unsqueeze(2)  # B, L, n=1, D
        audio_x_mask = audio_x_mask.unsqueeze(2)  # B, L, n=1

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_audio_x(audio_x_out, audio_x_mask)

    def wrap_fsdp(self, wrapper_kwargs, device_id):
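        """
        Wrap the model for FSDP training: the audio transformer, the original
        decoder blocks, the gated cross-attention layers, and the embeddings
        each become separate FSDP units. CLAP stays unwrapped (replicated per
        GPU), and the original decoder blocks are excluded from the optimizer.
        """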
        # unfreeze the decoder layers
        for block in self.lang_encoder.old_decoder_blocks:
            block.requires_grad_(True)

        # wrap in FSDP
        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
            self.audio_transformer = wrap(wrap(self.audio_transformer))
            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
            )
            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
                wrap(wrap(layer)) if layer is not None else None
                for layer in self.lang_encoder.gated_cross_attn_layers
            )
            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
            self.lang_encoder.set_input_embeddings(
                wrap(wrap(self.lang_encoder.get_input_embeddings()))
            )
            if hasattr(self.lang_encoder, "set_output_embeddings"):
                self.lang_encoder.set_output_embeddings(
                    wrap(wrap(self.lang_encoder.get_output_embeddings()))
                )
            else:
                print("skip wrapping output embeddings")

        # manually move non-FSDP managed parameters to device_id
        # these are all in lang_encoder
        apply_with_stopping_condition(
            module=self.lang_encoder,
            apply_fn=lambda m: m.to(device_id),
            apply_condition=lambda m: len(list(m.children())) == 0,
            stopping_condition=lambda m: isinstance(m, FSDP),
        )

        # clap shouldn't be wrapped; should be on each gpu
        if self.unfreeze_clap:
            apply_with_stopping_condition(
                module=self.clap,
                apply_fn=lambda m: m.to(device_id),
                apply_condition=lambda m: len(list(m.children())) == 0,
                stopping_condition=lambda m: isinstance(m, FSDP),
            )

        # exclude the original decoder layers from the optimizer
        for block in self.lang_encoder.old_decoder_blocks:
            for p in block.parameters():
                p.exclude_from_optimizer = True

        # set up clip_grad_norm_ function
        def clip_grad_norm_(max_norm):
            self.audio_transformer.clip_grad_norm_(max_norm)
            for layer in self.lang_encoder.gated_cross_attn_layers:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)
            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_

    def _condition_media_locations(self, input_ids: torch.Tensor):
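        """Register the positions of media tokens in `input_ids` with each
        decoder layer's cross-attention."""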
        media_locations = input_ids == self.media_token_id
        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
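        """Precompute audio conditioning and media locations so repeated
        calls over the same audio skip re-encoding; undo with `uncache_media`."""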
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_audio_x = True

    def uncache_media(self):
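        """Clear cached audio conditioning and return to per-call encoding."""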
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False