Spaces:
Runtime error
Runtime error
mvy
committed on
Commit
•
8e19b14
1
Parent(s):
6f3f044
add validations checks
Browse files
app.py
CHANGED
@@ -24,7 +24,7 @@ examples = [
|
|
24 |
ner = NER('knowledgator/UTC-DeBERTa-small')
|
25 |
|
26 |
gradio_app = gr.Interface(
|
27 |
-
ner,
|
28 |
inputs = [
|
29 |
'text',
|
30 |
gr.Textbox(placeholder="Enter sentence here..."),
|
|
|
24 |
ner = NER('knowledgator/UTC-DeBERTa-small')
|
25 |
|
26 |
gradio_app = gr.Interface(
|
27 |
+
ner.process,
|
28 |
inputs = [
|
29 |
'text',
|
30 |
gr.Textbox(placeholder="Enter sentence here..."),
|
ner.py
CHANGED
@@ -4,6 +4,8 @@ import string
|
|
4 |
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
5 |
import spacy
|
6 |
import torch
|
|
|
|
|
7 |
|
8 |
class NER:
|
9 |
prompt: str = """
|
@@ -13,8 +15,14 @@ Identify entities in the text having the following classes:
|
|
13 |
Text:
|
14 |
"""
|
15 |
|
16 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
17 |
self.sents_batch = sents_batch
|
|
|
18 |
|
19 |
self.nlp: spacy.Language = spacy.load(
|
20 |
'en_core_web_sm',
|
@@ -23,13 +31,13 @@ Text:
|
|
23 |
self.nlp.add_pipe('sentencizer')
|
24 |
|
25 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
26 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
27 |
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
28 |
|
29 |
self.pipeline = pipeline(
|
30 |
"ner",
|
31 |
model=model,
|
32 |
-
tokenizer=tokenizer,
|
33 |
aggregation_strategy='first',
|
34 |
batch_size=12,
|
35 |
device=device
|
@@ -115,14 +123,47 @@ Text:
|
|
115 |
return outputs
|
116 |
|
117 |
|
118 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
self, labels: str, text: str, threshold: float=0.
|
120 |
) -> dict[str, any]:
|
121 |
-
labels_list =
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
chunks, chunks_starts = self.chunkanize(text)
|
124 |
inputs, prompts_lens = self.get_inputs(chunks, labels_list)
|
125 |
|
|
|
|
|
126 |
outputs = self.predict(
|
127 |
text, inputs, labels_list, chunks_starts, prompts_lens, threshold
|
128 |
)
|
|
|
4 |
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
5 |
import spacy
|
6 |
import torch
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
|
10 |
class NER:
|
11 |
prompt: str = """
|
|
|
15 |
Text:
|
16 |
"""
|
17 |
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
model_name: str,
|
21 |
+
sents_batch: int=10,
|
22 |
+
tokens_limit: int=2048
|
23 |
+
):
|
24 |
self.sents_batch = sents_batch
|
25 |
+
self.tokens_limit = tokens_limit
|
26 |
|
27 |
self.nlp: spacy.Language = spacy.load(
|
28 |
'en_core_web_sm',
|
|
|
31 |
self.nlp.add_pipe('sentencizer')
|
32 |
|
33 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
34 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
35 |
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
36 |
|
37 |
self.pipeline = pipeline(
|
38 |
"ner",
|
39 |
model=model,
|
40 |
+
tokenizer=self.tokenizer,
|
41 |
aggregation_strategy='first',
|
42 |
batch_size=12,
|
43 |
device=device
|
|
|
123 |
return outputs
|
124 |
|
125 |
|
126 |
+
def check_text(self, text: str) -> None:
    """Validate that the user supplied some text to analyse.

    Args:
        text: Raw text entered by the user.

    Raises:
        gr.Error: If ``text`` is empty, ``None``, or consists only of
            whitespace.
    """
    # Whitespace-only input carries nothing to analyse, so it is treated
    # the same as missing text (labels are likewise stripped before they
    # are validated in ``process``).
    if not text or not text.strip():
        raise gr.Error('No text provided. Please provide text.')
|
129 |
+
|
130 |
+
|
131 |
+
def check_labels(self, labels: list[str]) -> None:
    """Validate that at least one label was supplied.

    Args:
        labels: Labels to check.

    Raises:
        gr.Error: If ``labels`` is empty.
    """
    # Guard clause: nothing to report when labels are present.
    if labels:
        return

    message = (
        'No labels provided. Please provide labels.'
        ' Multiple labels should be divided by commas.'
        ' See examples below.'
    )
    raise gr.Error(message)
|
138 |
+
|
139 |
+
|
140 |
+
def check_tokens_limit(self, inputs: list[str]) -> None:
    """Validate that the inputs together stay within ``self.tokens_limit``.

    Each input is encoded with ``self.tokenizer`` and the lengths of the
    encodings are accumulated; the check runs after every input so the
    error is raised as soon as the running total exceeds the limit.

    Args:
        inputs: Strings to be encoded and counted.

    Raises:
        gr.Error: If the accumulated token count exceeds
            ``self.tokens_limit``.
    """
    running_total = 0
    for prompt in inputs:
        running_total += len(self.tokenizer.encode(prompt))
        if running_total <= self.tokens_limit:
            continue
        raise gr.Error(
            'Too many tokens! Please reduce size of text or amount of labels.'
            f' Max tokens count is: {self.tokens_limit}.'
        )
|
149 |
+
|
150 |
+
|
151 |
+
def process(
|
152 |
self, labels: str, text: str, threshold: float=0.
|
153 |
) -> dict[str, any]:
|
154 |
+
labels_list = list({
|
155 |
+
l for label in labels.split(',')
|
156 |
+
if (l:=label.strip())
|
157 |
+
})
|
158 |
+
|
159 |
+
self.check_labels(labels_list)
|
160 |
+
self.check_text(text)
|
161 |
|
162 |
chunks, chunks_starts = self.chunkanize(text)
|
163 |
inputs, prompts_lens = self.get_inputs(chunks, labels_list)
|
164 |
|
165 |
+
self.check_tokens_limit(inputs)
|
166 |
+
|
167 |
outputs = self.predict(
|
168 |
text, inputs, labels_list, chunks_starts, prompts_lens, threshold
|
169 |
)
|