"""Gradio demo: entity recognition followed by relation extraction.

Two pretrained pytorch-ie pipelines are loaded at import time. ``predict``
runs the span-classification (NER) pipeline on the input text, feeds the
predicted entities to the relation-classification pipeline and renders the
predicted relations as an HTML table.
"""

import re
from dataclasses import dataclass
from typing import List

import gradio as gr
from prettytable import PrettyTable
from pytorch_ie import (
    AnnotationList,
    BinaryRelation,
    Span,
    LabeledSpan,
    Pipeline,
    TextDocument,
    annotation_field,
)
from pytorch_ie.models import (
    TransformerSpanClassificationModel,
    TransformerTextClassificationModel,
)
from pytorch_ie.taskmodules import (
    TransformerSpanClassificationTaskModule,
    TransformerRETextClassificationTaskModule,
)


@dataclass
class ExampleDocument(TextDocument):
    """Text document carrying an entity layer and a relation layer."""

    # Labeled spans; the annotation field targets the document text.
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
    # Binary relations; the annotation field targets the `entities` layer.
    relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


# --- Named-entity recognition pipeline ------------------------------------
model_name_or_path = "pie/example-ner-spanclf-conll03"
ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(model_name_or_path)
ner_model = TransformerSpanClassificationModel.from_pretrained(model_name_or_path)
ner_pipeline = Pipeline(model=ner_model, taskmodule=ner_taskmodule, device=-1, num_workers=0)

# --- Relation extraction pipeline -----------------------------------------
# NOTE: `model_name_or_path` is deliberately rebound, as in the original
# script; after this point it names the relation-extraction checkpoint.
model_name_or_path = "pie/example-re-textclf-tacred"
re_taskmodule = TransformerRETextClassificationTaskModule.from_pretrained(model_name_or_path)
re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path)
re_pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1, num_workers=0)


def predict(text: str) -> str:
    """Extract relations from *text* and return them as an HTML table.

    Args:
        text: Raw input text from the Gradio textbox.

    Returns:
        An HTML string holding a table with the columns ``head``, ``tail``
        and ``relation``, one row per predicted relation.
    """
    document = ExampleDocument(text)

    # Step 1: predict entities.
    ner_pipeline(document, predict_field="entities")

    # Copy the predicted entities into the regular `entities` layer before
    # running relation extraction (presumably the relation taskmodule reads
    # its candidate arguments from that layer rather than from predictions).
    for entity in document.entities.predictions:
        document.entities.append(entity)

    # Step 2: predict relations between the entities.
    re_pipeline(document, predict_field="relations")

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    for relation in document.relations.predictions:
        table.add_row([str(relation.head), str(relation.tail), relation.label])

    html = table.get_html_string(format=True)
    # NOTE(review): in the source as received these two wrapper strings were
    # split across physical lines (a syntax error) and contain nothing but a
    # newline. If wrapper markup was intended here it has been lost; confirm
    # against the upstream version of this script.
    html = "\n" + html + "\n"
    return html


iface = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs="html",
)
iface.launch()