# import torch
# import torchaudio
# from fairseq2.assets import InProcAssetMetadataProvider, asset_store
# from fairseq2.data import Collater, SequenceData
# from fairseq2.data.audio import (
#     AudioDecoder,
#     WaveformToFbankConverter,
#     WaveformToFbankOutput,
# )
# from fairseq2.generation import SequenceGeneratorOptions
# from fairseq2.memory import MemoryBlock
# from fairseq2.typing import DataType, Device
# from huggingface_hub import snapshot_download
# from seamless_communication.inference import BatchedSpeechOutput, Translator
# from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
# from seamless_communication.models.unity import (
#     UnitTokenizer,
#     load_gcmvn_stats,
#     load_unity_text_tokenizer,
#     load_unity_unit_tokenizer,
# )
# from torch.nn import Module

# class PretsselGenerator(Module):
#     def __init__(
#         self,
#         pretssel_name_or_card: str,
#         unit_tokenizer: UnitTokenizer,
#         device: Device,
#         dtype: DataType = torch.float16,
#     ):
#         super().__init__()
#         # Load the model.
#         if device == torch.device("cpu"):
#             dtype = torch.float32
#         self.device = device
#         self.dtype = dtype

#         self.pretssel_model = load_pretssel_vocoder_model(
#             pretssel_name_or_card,
#             device=device,
#             dtype=dtype,
#         )
#         self.pretssel_model.eval()

#         vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
#         self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)

#         self.unit_tokenizer = unit_tokenizer
#         self.unit_collate = Collater(pad_value=unit_tokenizer.vocab_info.pad_idx)
#         self.duration_collate = Collater(pad_value=0)

#     @torch.inference_mode()
#     def predict(
#         self,
#         units: list[list[int]],
#         tgt_lang: str,
#         prosody_encoder_input: SequenceData,
#     ) -> BatchedSpeechOutput:
#         audio_wavs = []
#         unit_eos_token = torch.tensor(
#             [self.unit_tokenizer.vocab_info.eos_idx],
#             device=self.device,
#         )

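#         # fairseq2 SequenceData fields: "seqs" is the padded batch of prosody
#         # encoder features, "seq_lens" holds each example's unpadded length.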
#         prosody_input_seqs = prosody_encoder_input["seqs"]
#         prosody_input_lens = prosody_encoder_input["seq_lens"]

#         for i, u in enumerate(units):
#             unit = torch.tensor(u).to(unit_eos_token)

#             # adjust the control symbols for the embedding
#             unit += 4
#             unit = torch.cat([unit, unit_eos_token], dim=0)

#             unit, duration = torch.unique_consecutive(unit, return_counts=True)

#             # adjust for the last eos token
#             duration[-1] = 0

#             duration *= 2

#             prosody_input_seq = prosody_input_seqs[i][: prosody_input_lens[i]]

#             audio_wav = self.pretssel_model(
#                 unit,
#                 tgt_lang,
#                 prosody_input_seq,
#                 durations=duration.unsqueeze(0),
#             )

#             audio_wavs.append(audio_wav)

#         return BatchedSpeechOutput(
#             units=units,
#             audio_wavs=audio_wavs,
#             sample_rate=self.output_sample_rate,
#         )

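# # Illustrative usage of the class above (kept commented out as well); the
# # "vocoder_pretssel" card and "seamlessM4T_v2_large" tokenizer name are
# # assumptions drawn from the seamless_communication asset store:
# #
# #     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# #     unit_tokenizer = load_unity_unit_tokenizer("seamlessM4T_v2_large")
# #     pretssel_generator = PretsselGenerator(
# #         "vocoder_pretssel",
# #         unit_tokenizer=unit_tokenizer,
# #         device=device,
# #     )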

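# Mapping from three-letter language codes to human-readable language names.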
LANGUAGE_CODE_TO_NAME = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "arb": "Modern Standard Arabic",
    "ary": "Moroccan Arabic",
    "arz": "Egyptian Arabic",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "North Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "ckb": "Central Kurdish",
    "cmn": "Mandarin Chinese",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "ell": "Greek",
    "eng": "English",
    "est": "Estonian",
    "eus": "Basque",
    "fin": "Finnish",
    "fra": "French",
    "gaz": "West Central Oromo",
    "gle": "Irish",
    "glg": "Galician",
    "guj": "Gujarati",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kam": "Kamba",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "kea": "Kabuverdianu",
    "khk": "Halh Mongolian",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lit": "Lithuanian",
    "ltz": "Luxembourgish",
    "lug": "Ganda",
    "luo": "Luo",
    "lvs": "Standard Latvian",
    "mai": "Maithili",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mni": "Meitei",
    "mya": "Burmese",
    "nld": "Dutch",
    "nno": "Norwegian Nynorsk",
    "nob": "Norwegian Bokm\u00e5l",
    "npi": "Nepali",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "pan": "Punjabi",
    "pbt": "Southern Pashto",
    "pes": "Western Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "ron": "Romanian",
    "rus": "Russian",
    "slk": "Slovak",
    "slv": "Slovenian",
    "sna": "Shona",
    "snd": "Sindhi",
    "som": "Somali",
    "spa": "Spanish",
    "srp": "Serbian",
    "swe": "Swedish",
    "swh": "Swahili",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tgl": "Tagalog",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzn": "Northern Uzbek",
    "vie": "Vietnamese",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "yue": "Cantonese",
    "zlm": "Colloquial Malay",
    "zsm": "Standard Malay",
    "zul": "Zulu",
}
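

# Illustrative helper (not part of the original file): look up a display name
# for a language code, falling back to the raw code when it is unknown.
def language_display_name(code: str) -> str:
    return LANGUAGE_CODE_TO_NAME.get(code, code)


# Example: language_display_name("fra") -> "French"; language_display_name("xx") -> "xx"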