Update app.py
app.py CHANGED
@@ -1,57 +1,154 @@
 import random
-from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
-from datasets import load_dataset
 import gradio as gr

-# Load the dataset
-
-
-# Convert streaming dataset to an iterable
-dataset_iter = iter(dataset)
-
-# Load tokenizer and model
-model_name = "max-long/textile_machines_3_oct" # Replace with your model's name
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForTokenClassification.from_pretrained(model_name)
-
-# Initialize NER pipeline
-ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
-
-def get_random_snippet(stream_iter, tokenizer, max_tokens=350, max_attempts=1000):
-    for _ in range(max_attempts):
-        try:
-            sample = next(stream_iter)['text']
-            tokens = tokenizer.tokenize(sample)
-            if len(tokens) <= max_tokens:
-                return sample
-        except StopIteration:
-            break
-    return "No suitable snippet found."

-
-
-    textile_entities = [ent for ent in ner_results if ent['entity_group'] == 'TEXTILE_MACHINERY']
-    return textile_entities

-
-
-

-#
-
-

-

 # Build Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown(
-
     with gr.Row():
-
-
-

-

-
 import random
+from gliner import GLiNER
 import gradio as gr
+from datasets import load_dataset

+# Load the subset dataset from Hugging Face Hub
+subset_dataset = load_dataset("TheBritishLibrary/blbooks", split="train", streaming=True, trust_remote_code=True)

+# Load the GLiNER model
+model = GLiNER.from_pretrained("max-long/textile_machines_3_oct", trust_remote_code=True)

+# Define the NER function
+def ner(text: str, labels: str, threshold: float, nested_ner: bool):
+    labels = [label.strip() for label in labels.split(",")]
+    entities = model.predict_entities(text, labels, flat_ner=not nested_ner, threshold=threshold)

+    # Filter for "textile machinery" entities
+    textile_entities = [
+        {
+            "entity": ent["label"],
+            "word": ent["text"],
+            "start": ent["start"],
+            "end": ent["end"],
+            "score": ent.get("score", 0),
+        }
+        for ent in entities
+        if ent["label"].lower() == "textile machinery"
+    ]

+    # Highlight entities with HTML
+    highlighted_text = text
+    for ent in sorted(textile_entities, key=lambda x: x['start'], reverse=True):
+        highlighted_text = (
+            highlighted_text[:ent['start']] +
+            f"<span style='background-color: yellow'>{highlighted_text[ent['start']:ent['end']]}</span>" +
+            highlighted_text[ent['end']:]
+        )
+
+    return gr.HTML(highlighted_text), textile_entities

 # Build Gradio interface
+with gr.Blocks(title="Textile Machinery NER Demo") as demo:
+    gr.Markdown(
+        """
+        # Textile Machinery Entity Recognition Demo
+        This demo selects a random text snippet from the British Library's books dataset and identifies "textile machinery" entities using a fine-tuned GLiNER model.
+        """
+    )
+
+    with gr.Accordion("How to run this model locally", open=False):
+        gr.Markdown(
+            """
+            ## Installation
+            To use this model, you must install the GLiNER Python library:
+            ```
+            !pip install gliner
+            ```
+
+            ## Usage
+            Once you've downloaded the GLiNER library, you can import the GLiNER class. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
+            """
+        )
+        gr.Code(
+            '''
+            from gliner import GLiNER
+            model = GLiNER.from_pretrained("max-long/textile_machines_3_oct")
+            text = "Your sample text here."
+            labels = ["textile machinery"]
+            entities = model.predict_entities(text, labels)
+            for entity in entities:
+                print(entity["text"], "=>", entity["label"])
+            ''',
+            language="python",
+        )
+        gr.Code(
+            """
+            Textile Machine 1 => textile machinery
+            Textile Machine 2 => textile machinery
+            """
+        )
+
+    input_text = gr.Textbox(
+        value="Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.",
+        label="Text input",
+        placeholder="Enter your text here",
+        lines=5
+    )
+
     with gr.Row():
+        labels = gr.Textbox(
+            value="textile machinery",
+            label="Labels",
+            placeholder="Enter your labels here (comma separated)",
+            scale=2,
+        )
+        threshold = gr.Slider(
+            0,
+            1,
+            value=0.3,
+            step=0.01,
+            label="Threshold",
+            info="Lower the threshold to increase how many entities get predicted.",
+            scale=1,
+        )
+        nested_ner = gr.Checkbox(
+            value=False,
+            label="Nested NER",
+            info="Allow for nested NER?",
+            scale=0,
+        )
+
+    output = gr.HighlightedText(label="Predicted Entities")
+
+    submit_btn = gr.Button("Analyze Random Snippet")
+    refresh_btn = gr.Button("Get New Snippet")
+
+    # Function to fetch a new random snippet
+    def get_new_snippet():
+        # WARNING: Streaming datasets may have performance implications
+        try:
+            sample = next(iter(subset_dataset))['text']
+            return sample
+        except StopIteration:
+            return "No more snippets available."
+
+    refresh_btn.click(fn=get_new_snippet, outputs=input_text)
+
+    submit_btn.click(
+        fn=ner,
+        inputs=[input_text, labels, threshold, nested_ner],
+        outputs=[output, gr.JSON(label="Entities")]
+    )
+
+    examples = [
+        [
+            "However, both models lack other frequent DM symptoms including the fibre-type dependent atrophy, myotonia, cataract and male-infertility.",
+            "textile machinery",
+            0.3,
+            False,
+        ],
+        # Add more examples as needed
+    ]

+    gr.Examples(
+        examples=examples,
+        inputs=[input_text, labels, threshold, nested_ner],
+        outputs=[output, gr.JSON(label="Entities")],
+        fn=ner,
+        label="Examples",
+        cache_examples=True,
+    )

+demo.queue()
+demo.launch(debug=True)
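
One thing to note about the added code: ner() returns a gr.HTML value, but the first output component is gr.HighlightedText, which accepts either a list of (text, label) tuples or a dict with "text" and "entities" keys. Below is a minimal sketch of reshaping the results for that component, assuming the textile_entities dicts built in ner() above; the helper name to_highlight is illustrative, not part of the commit.

# Sketch only: convert ner() results into the dict format gr.HighlightedText accepts.
# Assumes textile_entities holds dicts with "entity", "start" and "end" keys,
# as constructed in the ner() function in app.py above.
def to_highlight(text, textile_entities):
    return {
        "text": text,
        "entities": [
            {"entity": ent["entity"], "start": ent["start"], "end": ent["end"]}
            for ent in textile_entities
        ],
    }

Returning a dict of this shape lets the component draw the highlights itself, so the manual <span> insertion would no longer be needed.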
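
Separately, because get_new_snippet() re-creates iter(subset_dataset) on every click, it keeps yielding the first record of the stream. A rough sketch of one way to vary the snippet, assuming the same streaming dataset and using the buffered shuffle that the datasets library provides for iterable datasets (the seed and buffer size here are arbitrary):

from datasets import load_dataset

# Assumption: the same streaming dataset that app.py loads at startup.
subset_dataset = load_dataset(
    "TheBritishLibrary/blbooks", split="train",
    streaming=True, trust_remote_code=True,
)

# Keep one shuffled iterator alive between clicks so each call advances
# through the stream instead of re-reading the first record.
snippet_iter = iter(subset_dataset.shuffle(seed=42, buffer_size=1000))

def get_new_snippet():
    try:
        return next(snippet_iter)["text"]
    except StopIteration:
        return "No more snippets available."

The shuffle buffer only randomises within a window of the stream, which is usually enough variety for a demo without downloading the full dataset.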