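"""HuBERT pre-training / fine-tuning task for fairseq.

Loads audio from a ``{split}.tsv`` manifest together with frame-level labels
(pre-training) or sequence-level labels (fine-tuning), wrapping them in a
HubertDataset.

Example (a sketch; the data path, label extension, and label rate below are
placeholders, not values taken from this repository)::

    cfg = HubertPretrainingConfig(
        data="/path/to/manifests", labels=["km"], label_rate=100
    )
    task = HubertPretrainingTask.setup_task(cfg)
    task.load_dataset("train")
"""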
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
from omegaconf import MISSING

from fairseq.data import Dictionary, HubertDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask

logger = logging.getLogger(__name__)


class LabelEncoder(object):
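    """Encode a raw label line into dictionary indices via ``Dictionary.encode_line``."""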
    def __init__(self, dictionary: Dictionary) -> None:
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )


@dataclass
class HubertPretrainingConfig(FairseqDataclass):
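    """Options for audio and label loading during HuBERT pre-training and fine-tuning."""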
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    fine_tuning: bool = field(
        default=False, metadata={"help": "set to true if fine-tuning Hubert"}
    )
    labels: List[str] = field(
        default_factory=lambda: ["ltr"],
        metadata={
            "help": (
                "extension of the label files to load, frame-level labels for"
                " pre-training, and sequence-level label for fine-tuning"
            )
        },
    )
    label_dir: Optional[str] = field(
        default=None,
        metadata={"help": "if set, looks for labels in this directory instead"},
    )
    label_rate: float = field(
        default=-1.0,
        metadata={"help": "label frame rate. -1.0 for sequence label"},
    )
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down "
            "sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
    )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"},
    )
    max_keep_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude samples longer than this"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to crop to for batching"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to crop to for batching"},
    )
    single_target: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, the dataset outputs the same keys as AddTargetDataset"
        },
    )
    random_crop: Optional[bool] = field(
        default=True,
        metadata={"help": "always crop from the beginning if false"},
    )
    pad_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "pad audio to the longest one in the batch if true"},
    )


@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig)
class HubertPretrainingTask(FairseqTask):
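    """Task wrapping HubertDataset for HuBERT training.

    Pre-training keeps one dictionary per label stream (``dictionaries``);
    fine-tuning exposes a single ``target_dictionary`` instead.
    """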
    cfg: HubertPretrainingConfig

    def __init__(
        self,
        cfg: HubertPretrainingConfig,
    ) -> None:
        super().__init__(cfg)

        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"HubertPretrainingTask Config {cfg}")

        self.cfg = cfg
        self.fine_tuning = cfg.fine_tuning

        # Fine-tuning registers a single target dictionary; pre-training keeps
        # one dictionary per label stream.
        if cfg.fine_tuning:
            self.state.add_factory("target_dictionary", self.load_dictionaries)
        else:
            self.state.add_factory("dictionaries", self.load_dictionaries)

        self.blank_symbol = "<s>"

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        return None

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        return self.state.target_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        return self.state.dictionaries

    @classmethod
    def setup_task(
        cls, cfg: HubertPretrainingConfig, **kwargs
    ) -> "HubertPretrainingTask":
        return cls(cfg)

    def load_dictionaries(self):
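        """Load one Dictionary per label extension; return a single
        dictionary when fine-tuning, else the full list."""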
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def get_label_dir(self) -> str:
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def load_dataset(self, split: str, **kwargs) -> None:
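        """Create a HubertDataset for ``split`` from ``{data}/{split}.tsv``
        and the matching label files in the label directory."""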
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
        pad_list = [d.pad() for d in dicts]
        eos_list = [d.eos() for d in dicts]
        procs = [LabelEncoder(d) for d in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{label}" for label in self.cfg.labels]

        self.datasets[split] = HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )

    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(self, indices: np.ndarray, *args, **kwargs) -> np.ndarray:
        # Size filtering is handled inside HubertDataset, so indices pass through unchanged.
        return indices