import re import gradio as gr from dataclasses import dataclass from prettytable import PrettyTable import logging from pytorch_ie.annotations import LabeledSpan, BinaryRelation from pytorch_ie.auto import AutoPipeline from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument import transformer_re_text_classification2 from typing import List logger = logging.getLogger(__name__) @dataclass class ExampleDocument(TextDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="text") relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") ner_model_name_or_path = "pie/example-ner-spanclf-conll03" re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited" ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0) re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0) ner_tag_mapping = { 'ORG': 'ORGANIZATION', 'PER': 'PERSON', 'LOC': 'LOCATION' } def predict(text): document = ExampleDocument(text) ner_pipeline(document) while len(document.entities.predictions) > 0: entity = document.entities.predictions.pop(0) if entity.label in ner_tag_mapping: entity = LabeledSpan(start=entity.start, end=entity.end, label=ner_tag_mapping[entity.label], score=entity.score) if entity.label in re_pipeline.taskmodule.entity_labels: document.entities.append(entity) logger.warning(f"detected entity: {entity} (added)") else: logger.warning(f"detected entity: {entity} (NOT added)") re_pipeline(document) t = PrettyTable() t.field_names = ["head", "tail", "relation"] t.align = "l" for relation in document.relations.predictions: t.add_row([str(relation.head), str(relation.tail), relation.label]) html = t.get_html_string(format=True) html = ( "