#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for AllenNLP text-classification models hosted on Hugging Face.

Models are git-cloned on first use into ``<project>/trained_models/huggingface``
and the resulting predictor is cached in-process for later requests.
"""
import argparse
import functools
import platform
import shutil
import subprocess
from pathlib import Path

from allennlp.models.archival import archive_model, load_archive
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr

from project_settings import project_path
# NOTE(review): not referenced directly below; presumably imported so the reader
# class is registered with AllenNLP before `load_archive` rebuilds the archive's
# dataset reader -- confirm before removing.
from toolbox.allennlp.data.dataset_readers.text_classification_json import TextClassificationJsonReader
from toolbox.os.command import Command


def get_args():
    """Parse command-line arguments (no options are defined yet)."""
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    return args


# Task name -> {Hugging Face repo id -> git clone URL}. Only repos listed here
# may be downloaded and served.
model_names = {
    "allennlp_text_classification": {
        "qgyd2021/language_identification": "https://huggingface.co/qgyd2021/language_identification"
    }
}

trained_model_dir = project_path / "trained_models/huggingface"
trained_model_dir.mkdir(parents=True, exist_ok=True)


def _ensure_model_downloaded(model_name: str) -> Path:
    """Return the local checkout of ``model_name``, cloning it first if missing.

    Raises:
        ValueError: if ``model_name`` is not an allowed text-classification model.
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
    """
    repo_urls = model_names["allennlp_text_classification"]
    # The click handler is reachable through Gradio's HTTP API, so the dropdown
    # does not constrain the value: an arbitrary name could inject "../" segments
    # into the local path or point the clone at another repository.
    if model_name not in repo_urls:
        raise ValueError("unknown model_name: {!r}".format(model_name))

    model_path = trained_model_dir / model_name
    if not model_path.exists():
        model_path.parent.mkdir(parents=True, exist_ok=True)
        # Clone into an explicit target rather than chdir-ing first: changing the
        # working directory affects the whole server process, not just this call.
        try:
            subprocess.run(
                ["git", "clone", repo_urls[model_name], model_path.as_posix()],
                check=True,
            )
        except subprocess.CalledProcessError:
            # Remove any half-cloned directory; otherwise the exists() check above
            # would treat it as a complete model on every later request.
            shutil.rmtree(model_path, ignore_errors=True)
            raise
    return model_path


# Unbounded cache is safe here: unknown names raise before anything is cached,
# so the key set is limited to the entries of `model_names`.
@functools.lru_cache(maxsize=None)
def _load_predictor(model_name: str) -> TextClassifierPredictor:
    """Load the AllenNLP archive for ``model_name`` once and wrap it in a predictor."""
    model_path = _ensure_model_downloaded(model_name)
    archive = load_archive(archive_file=model_path.as_posix())
    return TextClassifierPredictor(
        model=archive.model,
        dataset_reader=archive.dataset_reader,
    )


def click_button_allennlp_text_classification(text: str, model_name: str):
    """Classify ``text`` with ``model_name``.

    Returns:
        ``(label, prob)`` where ``label`` is the predictor's ``"label"`` output and
        ``prob`` is the largest value in its ``"probs"`` output, rounded to 4 places.

    Raises:
        ValueError: if ``model_name`` is not an allowed text-classification model.
    """
    predictor = _load_predictor(model_name)
    outputs = predictor.predict_json({"sentence": text})
    label = outputs["label"]
    probs = outputs["probs"]
    return label, round(max(probs), 4)


def main():
    """Build the Gradio UI and start the web server (blocks until shutdown)."""
    # No options are defined yet; parsing still rejects unknown CLI arguments.
    get_args()

    brief_description = """
## NLP Tools

NLP Tools Demo
"""

    text_classification_models = list(model_names["allennlp_text_classification"].keys())

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Tabs():
            with gr.TabItem("AllenNLP Text Classification"):
                with gr.Row():
                    with gr.Column(scale=3):
                        text = gr.Text(label="text")
                        ground_true = gr.Text(label="ground_true")
                        # Preselect a model so clicking "infer" immediately does
                        # not submit model_name=None.
                        model_name = gr.Dropdown(
                            choices=text_classification_models,
                            value=text_classification_models[0],
                        )
                        button = gr.Button("infer", variant="primary")
                    with gr.Column(scale=3):
                        label = gr.Text(label="label")
                        prob = gr.Text(label="prob")
                gr.Examples(
                    examples=[
                        ["你好", "zh", "qgyd2021/language_identification"]
                    ],
                    inputs=[text, ground_true, model_name],
                    outputs=[label, prob],
                )
                button.click(
                    click_button_allennlp_text_classification,
                    inputs=[text, model_name],
                    outputs=[label, prob],
                )

    blocks.queue().launch(
        # Public share links are disabled on every platform.
        share=False,
        # Bind to loopback on Windows, all interfaces elsewhere.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860,
    )


if __name__ == '__main__':
    main()