# Provenance (scraped page header): MusicGen2 / audiocraft/metrics/clap_consistency.py,
# last changed by reach-vb (HF staff) in commit 5325fcc, "Stereo demo update (#60)".
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import typing as tp
import torch
import torchmetrics
from transformers import RobertaTokenizer # type: ignore
from ..data.audio_utils import convert_audio
from ..environment import AudioCraftEnvironment
from ..utils.utils import load_clap_state_dict
# `laion_clap` is an optional dependency: it stays None when the package is not
# installed, and CLAPTextConsistencyMetric raises a helpful ImportError in that case.
laion_clap = None
try:
    import laion_clap  # type: ignore  # noqa: F811
except ImportError:
    pass
class TextConsistencyMetric(torchmetrics.Metric):
    """Abstract base for metrics scoring how well audio matches its text description.

    Subclasses accumulate per-batch statistics in ``update`` and reduce them to a
    single score in ``compute``; this base class only fixes the method signatures.
    """

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Accumulate statistics for a batch of audio clips and their text descriptions.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = "implement how to update the metric from the audio and text pairs."
        raise NotImplementedError(message)

    def compute(self):
        """Return the final metric value from the accumulated statistics.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = "implement how to compute the final metric score."
        raise NotImplementedError(message)
class CLAPTextConsistencyMetric(TextConsistencyMetric):
    """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).

    This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
    or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).

    As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
    similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
    well as the generated audio based on them, and define the MCC metric as the average cosine similarity
    between these embeddings.

    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP

    Args:
        model_path: Checkpoint path (or AudioCraft reference path, resolved via
            ``AudioCraftEnvironment.resolve_reference_path``) of the CLAP weights.
        model_arch: Audio encoder architecture, passed as ``amodel`` to ``laion_clap.CLAP_Module``.
        enable_fusion: Forwarded to ``laion_clap.CLAP_Module``.

    Raises:
        ImportError: if the ``laion_clap`` package is not installed.
    """

    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
        super().__init__()
        if laion_clap is None:
            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
        # Running sum of per-pair cosine similarities and number of pairs seen;
        # both are summed across processes when the metric state is synchronized.
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
        self._initialize_model(model_path, model_arch, enable_fusion)

    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
        """Build the CLAP module and its RoBERTa tokenizer, then load the checkpoint at ``model_path``."""
        model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
        self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
        # Audio is resampled to this rate before being embedded (see `update`).
        self.model_sample_rate = 48_000
        load_clap_state_dict(self.model, model_path)
        self.model.eval()

    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
        """Tokenize ``texts`` into PyTorch tensors padded/truncated to 77 tokens."""
        # we use the default params from CLAP module here as well
        return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.

        Args:
            audio: Batch of audio, expected as ``[B, C, T]``.
            text: One text description per audio item (``len(text) == B``).
            sizes: Accepted for interface compatibility but not used. NOTE(review): embeddings
                are computed on the full ``audio`` tensor, so any padding is included.
            sample_rates: Per-item sample rates; all entries must be equal.

        An empty batch is a no-op.
        """
        assert audio.size(0) == len(text), "Number of audio and text samples should match"
        if audio.size(0) == 0:
            # Nothing to score; also avoids indexing sample_rates[0] on an empty batch.
            return
        first_rate = sample_rates[0].item()
        assert torch.all(sample_rates == first_rate), "All items in batch should have the same sample rate"
        sample_rate = int(first_rate)
        # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
        audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
        audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
        text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
        # cosine similarity between the text and the audio embedding
        cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
        self.cosine_sum += cosine_sim.sum(dim=0)
        # Add the plain int count in place: no per-call CPU int64 tensor allocation,
        # and the float state keeps its own dtype/device.
        self.weight += cosine_sim.size(0)

    def compute(self):
        """Computes the average cosine similarity across all audio/text pairs.

        Returns:
            The mean cosine similarity as a Python float.

        Raises:
            AssertionError: if no audio/text pair has been accumulated yet.
        """
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        return (self.cosine_sum / self.weight).item()  # type: ignore