#!/usr/bin/env python3

import re
from typing import Optional

"""
Extracts code snippets from the TypeScript source "Libraries.ts".
(The source is embedded in this script as the string `file`, so no
separate "Libraries.ts" file needs to be present next to it.)
"""

# Placeholder for the embedded "Libraries.ts" source text. It is reassigned to
# the real string further down in this module; read_file() looks the name up
# at call time, so it sees that later value.
file = None

def read_file(library: str, model_name: str, source: Optional[str] = None) -> str:
    """Extract the template-literal snippet declared as ``const <library>``.

    The search text (``source``, or the module-level ``file`` when omitted)
    is scanned for the first occurrence of ``"const " + library``. From that
    point on, the text between the first backtick and the first backtick
    followed by a semicolon is returned, with every ``${model.id}`` replaced
    by ``model_name``. Any other ``${...}`` placeholder is left as-is.

    Args:
        library: Name (or name prefix) of the TypeScript constant, e.g.
            ``"keras"`` or ``"sentenceTransformers"``. It is matched as
            literal text (regex metacharacters have no special meaning) and
            as a prefix, so e.g. ``"espnet"`` also matches ``const espnetTTS``.
        model_name: Text substituted for every ``${model.id}``.
        source: Text to search. Defaults to the module-level ``file``.

    Returns:
        The extracted snippet. If no ``const <library>`` declaration exists,
        or no complete template literal follows it, the search text is
        returned unchanged.
    """
    text = file if source is None else source

    # The library name is literal text, not a pattern.
    match = re.search('const ' + re.escape(library), text)
    if match is None:
        # Nothing to extract: fall back to the unmodified text instead of
        # crashing on a None match.
        return text

    tail = text[match.start():]
    start = tail.find('`')
    end = tail.find('`;')
    if start == -1 or end == -1:
        # Declaration found but no template literal after it.
        return text

    return tail[start + 1:end].replace('${model.id}', model_name)

# Embedded copy of the TypeScript source "Libraries.ts" that read_file()
# searches for `const <name> ... => `...`;` template-literal snippets.
# Note: this is a regular (non-raw) string, so escape sequences inside the
# TypeScript, such as the "\n" passed to `.join(...)`, become real characters.
file = """
import type { ModelData } from "./Types";
/**
 * Add your new library here.
 */
export enum ModelLibrary {
	"adapter-transformers"   = "Adapter Transformers",
	"allennlp"               = "allenNLP",
	"asteroid"               = "Asteroid",
	"diffusers"              = "Diffusers",
	"espnet"                 = "ESPnet",
	"fairseq"                = "Fairseq",
	"flair"                  = "Flair",
	"keras"                  = "Keras",
	"nemo"                   = "NeMo",
	"pyannote-audio"         = "pyannote.audio",
	"sentence-transformers"  = "Sentence Transformers",
	"sklearn"                = "Scikit-learn",
	"spacy"                  = "spaCy",
	"speechbrain"            = "speechbrain",
	"tensorflowtts"          = "TensorFlowTTS",
	"timm"                   = "Timm",
	"fastai"                 = "fastai",
	"transformers"           = "Transformers",
	"stanza"                 = "Stanza",
	"fasttext"               = "fastText",
	"stable-baselines3"      = "Stable-Baselines3",
	"ml-agents"              = "ML-Agents",
}

export const ALL_MODEL_LIBRARY_KEYS = Object.keys(ModelLibrary) as (keyof typeof ModelLibrary)[];


/**
 * Elements configurable by a model library.
 */
export interface LibraryUiElement {
	/**
	 * Name displayed on the main
	 * call-to-action button on the model page.
	 */
	btnLabel:  string;
	/**
	 * Repo name
	 */
	repoName: string;
	/**
	 * URL to library's repo
	 */
	repoUrl:   string;
	/**
	 * Code snippet displayed on model page
	 */
	snippet:   (model: ModelData) => string;
}

function nameWithoutNamespace(modelId: string): string {
	const splitted = modelId.split("/");
	return splitted.length === 1 ? splitted[0] : splitted[1];
}

//#region snippets

const adapter_transformers = (model: ModelData) =>
	`from transformers import ${model.config?.adapter_transformers?.model_class}

model = ${model.config?.adapter_transformers?.model_class}.from_pretrained("${model.config?.adapter_transformers?.model_name}")
model.load_adapter("${model.id}", source="hf")`;

const allennlpUnknown = (model: ModelData) =>
	`import allennlp_models
from allennlp.predictors.predictor import Predictor

predictor = Predictor.from_path("hf://${model.id}")`;

const allennlpQuestionAnswering = (model: ModelData) =>
	`import allennlp_models
from allennlp.predictors.predictor import Predictor

predictor = Predictor.from_path("hf://${model.id}")
predictor_input = {"passage": "My name is Wolfgang and I live in Berlin", "question": "Where do I live?"}
predictions = predictor.predict_json(predictor_input)`;

const allennlp = (model: ModelData) => {
	if (model.tags?.includes("question-answering")) {
		return allennlpQuestionAnswering(model);
	}
	return allennlpUnknown(model);
};

const asteroid = (model: ModelData) =>
	`from asteroid.models import BaseModel

model = BaseModel.from_pretrained("${model.id}")`;

const diffusers = (model: ModelData) =>
	`from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`;

const espnetTTS = (model: ModelData) =>
	`from espnet2.bin.tts_inference import Text2Speech

model = Text2Speech.from_pretrained("${model.id}")

speech, *_ = model("text to generate speech from")`;

const espnetASR = (model: ModelData) =>
	`from espnet2.bin.asr_inference import Speech2Text

model = Speech2Text.from_pretrained(
  "${model.id}"
)

speech, rate = soundfile.read("speech.wav")
text, *_ = model(speech)`;

const espnetUnknown = () =>
	`unknown model type (must be text-to-speech or automatic-speech-recognition)`;

const espnet = (model: ModelData) => {
	if (model.tags?.includes("text-to-speech")) {
		return espnetTTS(model);
	} else if (model.tags?.includes("automatic-speech-recognition")) {
		return espnetASR(model);
	}
	return espnetUnknown();
};

const fairseq = (model: ModelData) =>
	`from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub

models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
    "${model.id}"
)`;


const flair = (model: ModelData) =>
	`from flair.models import SequenceTagger

tagger = SequenceTagger.load("${model.id}")`;

const keras = (model: ModelData) =>
	`from huggingface_hub import from_pretrained_keras

model = from_pretrained_keras("${model.id}")
`;

const pyannote_audio_pipeline = (model: ModelData) =>
	`from pyannote.audio import Pipeline
  
pipeline = Pipeline.from_pretrained("${model.id}")

# inference on the whole file
pipeline("file.wav")

# inference on an excerpt
from pyannote.core import Segment
excerpt = Segment(start=2.0, end=5.0)

from pyannote.audio import Audio
waveform, sample_rate = Audio().crop("file.wav", excerpt)
pipeline({"waveform": waveform, "sample_rate": sample_rate})`;

const pyannote_audio_model = (model: ModelData) =>
	`from pyannote.audio import Model, Inference

model = Model.from_pretrained("${model.id}")
inference = Inference(model)

# inference on the whole file
inference("file.wav")

# inference on an excerpt
from pyannote.core import Segment
excerpt = Segment(start=2.0, end=5.0)
inference.crop("file.wav", excerpt)`;

const pyannote_audio = (model: ModelData) => {
	if (model.tags?.includes("pyannote-audio-pipeline")) {
		return pyannote_audio_pipeline(model);
	}
	return pyannote_audio_model(model);
};

const tensorflowttsTextToMel = (model: ModelData) =>
	`from tensorflow_tts.inference import AutoProcessor, TFAutoModel

processor = AutoProcessor.from_pretrained("${model.id}")
model = TFAutoModel.from_pretrained("${model.id}")
`;

const tensorflowttsMelToWav = (model: ModelData) =>
	`from tensorflow_tts.inference import TFAutoModel

model = TFAutoModel.from_pretrained("${model.id}")
audios = model.inference(mels)
`;

const tensorflowttsUnknown = (model: ModelData) =>
	`from tensorflow_tts.inference import TFAutoModel

model = TFAutoModel.from_pretrained("${model.id}")
`;

const tensorflowtts = (model: ModelData) => {
	if (model.tags?.includes("text-to-mel")) {
		return tensorflowttsTextToMel(model);
	} else if (model.tags?.includes("mel-to-wav")) {
		return tensorflowttsMelToWav(model);
	}
	return tensorflowttsUnknown(model);
};

const timm = (model: ModelData) =>
	`import timm

model = timm.create_model("hf_hub:${model.id}", pretrained=True)`;

const sklearn = (model: ModelData) =>
	`from huggingface_hub import hf_hub_download
import joblib

model = joblib.load(
	hf_hub_download("${model.id}", "sklearn_model.joblib")
)`;

const fastai = (model: ModelData) =>
	`from huggingface_hub import from_pretrained_fastai

learn = from_pretrained_fastai("${model.id}")`;

const sentenceTransformers = (model: ModelData) =>
	`from sentence_transformers import SentenceTransformer

model = SentenceTransformer("${model.id}")`;

const spacy = (model: ModelData) =>
	`!pip install https://huggingface.co/${model.id}/resolve/main/${nameWithoutNamespace(model.id)}-any-py3-none-any.whl

# Using spacy.load().
import spacy
nlp = spacy.load("${nameWithoutNamespace(model.id)}")

# Importing as module.
import ${nameWithoutNamespace(model.id)}
nlp = ${nameWithoutNamespace(model.id)}.load()`;

const stanza = (model: ModelData) =>
	`import stanza

stanza.download("${nameWithoutNamespace(model.id).replace("stanza-", "")}")
nlp = stanza.Pipeline("${nameWithoutNamespace(model.id).replace("stanza-", "")}")`;


const speechBrainMethod = (speechbrainInterface: string) => {
	switch (speechbrainInterface) {
		case "EncoderClassifier":
		   return "classify_file";
		case "EncoderDecoderASR":
		case "EncoderASR":
			return "transcribe_file";
		case "SpectralMaskEnhancement":
			return "enhance_file";
		case "SepformerSeparation":
			return "separate_file";
		default:
			return undefined;
	}
};

const speechbrain = (model: ModelData) => {
	const speechbrainInterface = model.config?.speechbrain?.interface;
	if (speechbrainInterface === undefined) {
		return `# interface not specified in config.json`;
	}

	const speechbrainMethod = speechBrainMethod(speechbrainInterface);
	if (speechbrainMethod === undefined) {
		return `# interface in config.json invalid`;
	}

	return `from speechbrain.pretrained import ${speechbrainInterface}
model = ${speechbrainInterface}.from_hparams(
  "${model.id}"
)
model.${speechbrainMethod}("file.wav")`;
};

const transformers = (model: ModelData) => {
	const info = model.transformersInfo;
	if (!info) {
		return `# ⚠️ Type of model unknown`;
	}
	if (info.processor) {
		const varName = info.processor === "AutoTokenizer" ? "tokenizer"
			: info.processor === "AutoFeatureExtractor" ? "extractor"
				: "processor"
		;
		return [
			`from transformers import ${info.processor}, ${info.auto_model}`,
			"",
			`${varName} = ${info.processor}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
			"",
			`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
		].join("\n");
	} else {
		return [
			`from transformers import ${info.auto_model}`,
			"",
			`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
		].join("\n");
	}
};

const fasttext = (model: ModelData) =>
	`from huggingface_hub import hf_hub_download
import fasttext

model = fasttext.load_model(hf_hub_download("${model.id}", "model.bin"))`;

const stableBaselines3 = (model: ModelData) =>
	`from huggingface_sb3 import load_from_hub
checkpoint = load_from_hub(
	repo_id="${model.id}",
	filename="{MODEL FILENAME}.zip",
)`;

const nemoDomainResolver = (domain: string, model: ModelData): string | undefined => {
	const modelName = `${nameWithoutNamespace(model.id)}.nemo`;

	switch (domain) {
		case "ASR":
			return `import nemo.collections.asr as nemo_asr
asr_model = nemo_asr.models.ASRModel.from_pretrained("${model.id}")

transcriptions = asr_model.transcribe(["file.wav"])`;
		default:
			return undefined;
	}
};

const mlAgents = (model: ModelData) =>
	`mlagents-load-from-hf --repo-id="${model.id}" --local-dir="./downloads"`;
	
const nemo = (model: ModelData) => {
	let command: string | undefined = undefined;
	// Resolve the tag to a nemo domain/sub-domain 
	if (model.tags?.includes("automatic-speech-recognition")) {
		command = nemoDomainResolver("ASR", model);
	}
	
	return command ?? `# tag did not correspond to a valid NeMo domain.`;
};

//#endregion



export const MODEL_LIBRARIES_UI_ELEMENTS: { [key in keyof typeof ModelLibrary]?: LibraryUiElement } = {
	// ^^ TODO(remove the optional ? marker when Stanza snippet is available)
	"adapter-transformers": {
		btnLabel: "Adapter Transformers",
		repoName: "adapter-transformers",
		repoUrl:  "https://github.com/Adapter-Hub/adapter-transformers",
		snippet:  adapter_transformers,
	},
	"allennlp": {
		btnLabel: "AllenNLP",
		repoName: "AllenNLP",
		repoUrl:  "https://github.com/allenai/allennlp",
		snippet:  allennlp,
	},
	"asteroid": {
		btnLabel: "Asteroid",
		repoName: "Asteroid",
		repoUrl:  "https://github.com/asteroid-team/asteroid",
		snippet:  asteroid,
	},
	"diffusers": {
		btnLabel: "Diffusers",
		repoName: "🤗/diffusers",
		repoUrl:  "https://github.com/huggingface/diffusers",
		snippet:  diffusers,
	},
	"espnet": {
		btnLabel: "ESPnet",
		repoName: "ESPnet",
		repoUrl:  "https://github.com/espnet/espnet",
		snippet:  espnet,
	},
	"fairseq": {
		btnLabel: "Fairseq",
		repoName: "fairseq",
		repoUrl:  "https://github.com/pytorch/fairseq",
		snippet:  fairseq,
	},
	"flair": {
		btnLabel: "Flair",
		repoName: "Flair",
		repoUrl:  "https://github.com/flairNLP/flair",
		snippet:  flair,
	},
	"keras": {
		btnLabel: "Keras",
		repoName: "Keras",
		repoUrl:  "https://github.com/keras-team/keras",
		snippet:  keras,
	},
	"nemo": {
		btnLabel: "NeMo",
		repoName: "NeMo",
		repoUrl:  "https://github.com/NVIDIA/NeMo",
		snippet:  nemo,
	},
	"pyannote-audio": {
		btnLabel: "pyannote.audio",
		repoName: "pyannote-audio",
		repoUrl:  "https://github.com/pyannote/pyannote-audio",
		snippet:  pyannote_audio,
	},
	"sentence-transformers": {
		btnLabel: "sentence-transformers",
		repoName: "sentence-transformers",
		repoUrl:  "https://github.com/UKPLab/sentence-transformers",
		snippet:  sentenceTransformers,
	},
	"sklearn": {
		btnLabel: "Scikit-learn",
		repoName: "Scikit-learn",
		repoUrl:  "https://github.com/scikit-learn/scikit-learn",
		snippet:  sklearn,
	},
	"fastai": {
		btnLabel: "fastai",
		repoName: "fastai",
		repoUrl:  "https://github.com/fastai/fastai",
		snippet:  fastai,
	},
	"spacy": {
		btnLabel: "spaCy",
		repoName: "spaCy",
		repoUrl:  "https://github.com/explosion/spaCy",
		snippet:  spacy,
	},
	"speechbrain": {
		btnLabel: "speechbrain",
		repoName: "speechbrain",
		repoUrl:  "https://github.com/speechbrain/speechbrain",
		snippet:  speechbrain,
	},
	"stanza": {
		btnLabel: "Stanza",
		repoName: "stanza",
		repoUrl: "https://github.com/stanfordnlp/stanza",
		snippet: stanza,
	},
	"tensorflowtts": {
		btnLabel: "TensorFlowTTS",
		repoName: "TensorFlowTTS",
		repoUrl:  "https://github.com/TensorSpeech/TensorFlowTTS",
		snippet:  tensorflowtts,
	},
	"timm": {
		btnLabel: "timm",
		repoName: "pytorch-image-models",
		repoUrl:  "https://github.com/rwightman/pytorch-image-models",
		snippet:  timm,
	},
	"transformers": {
		btnLabel: "Transformers",
		repoName: "🤗/transformers",
		repoUrl:  "https://github.com/huggingface/transformers",
		snippet:  transformers,
	},
	"fasttext": {
		btnLabel: "fastText",
		repoName: "fastText",
		repoUrl:  "https://fasttext.cc/",
		snippet:  fasttext,
	},
	"stable-baselines3": {
		btnLabel: "stable-baselines3",
		repoName: "stable-baselines3",
		repoUrl:  "https://github.com/huggingface/huggingface_sb3",
		snippet:  stableBaselines3,
	},
	"ml-agents": {
		btnLabel: "ml-agents",
		repoName: "ml-agents",
		repoUrl:  "https://github.com/huggingface/ml-agents",
		snippet:  mlAgents,
	},
} as const;
"""


if __name__ == '__main__':
    import sys

    # Usage: <script> [LIBRARY [MODEL_NAME]]
    # Both arguments are optional. Without them the demo values "keras" and
    # "Distillgpt2" are used, so running with no arguments prints the same
    # snippet as before. Extra arguments are ignored.
    args = sys.argv[1:]
    library_name = args[0] if len(args) >= 1 else "keras"
    model_name = args[1] if len(args) >= 2 else "Distillgpt2"
    print(read_file(library_name, model_name))