max-long committed
Commit 999a2cb · verified · 1 Parent(s): dc3b839

Update app.py

Files changed (1)
  1. app.py +142 -45
app.py CHANGED
@@ -1,57 +1,154 @@
  import random
- from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
- from datasets import load_dataset
  import gradio as gr

- # Load the dataset with streaming
- dataset = load_dataset("TheBritishLibrary/blbooks", split="train", trust_remote_code=True, streaming=True)
-
- # Convert streaming dataset to an iterable
- dataset_iter = iter(dataset)
-
- # Load tokenizer and model
- model_name = "max-long/textile_machines_3_oct"  # Replace with your model's name
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForTokenClassification.from_pretrained(model_name)
-
- # Initialize NER pipeline
- ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
-
- def get_random_snippet(stream_iter, tokenizer, max_tokens=350, max_attempts=1000):
-     for _ in range(max_attempts):
-         try:
-             sample = next(stream_iter)['text']
-             tokens = tokenizer.tokenize(sample)
-             if len(tokens) <= max_tokens:
-                 return sample
-         except StopIteration:
-             break
-     return "No suitable snippet found."
-
- def extract_textile_machinery_entities(text):
-     ner_results = ner_pipeline(text)
-     textile_entities = [ent for ent in ner_results if ent['entity_group'] == 'TEXTILE_MACHINERY']
-     return textile_entities
-
- def analyze_text():
-     snippet = get_random_snippet(dataset_iter, tokenizer)
-     entities = extract_textile_machinery_entities(snippet)
-
-     # Highlight entities in the text
-     for ent in sorted(entities, key=lambda x: x['start'], reverse=True):
-         snippet = snippet[:ent['start']] + f"**{snippet[ent['start']:ent['end']]}**" + snippet[ent['end']:]
-
-     return snippet, entities

  # Build Gradio interface
- with gr.Blocks() as demo_interface:
-     gr.Markdown("# Textile Machinery Entity Recognition Demo")
-     gr.Markdown("Click the button below to analyze a random text snippet.")
      with gr.Row():
-         analyze_button = gr.Button("Analyze Random Snippet")
-         output_text = gr.Markdown()
-         output_entities = gr.JSON()
-
-     analyze_button.click(fn=analyze_text, outputs=[output_text, output_entities])
-
-     demo_interface.launch()
  import random
+ from gliner import GLiNER
  import gradio as gr
+ from datasets import load_dataset

+ # Load the subset dataset from Hugging Face Hub
+ subset_dataset = load_dataset("TheBritishLibrary/blbooks", split="train", streaming=True, trust_remote_code=True)
+
+ # Load the GLiNER model
+ model = GLiNER.from_pretrained("max-long/textile_machines_3_oct", trust_remote_code=True)
+
+ # Define the NER function
+ def ner(text: str, labels: str, threshold: float, nested_ner: bool):
+     labels = [label.strip() for label in labels.split(",")]
+     entities = model.predict_entities(text, labels, flat_ner=not nested_ner, threshold=threshold)
+
+     # Filter for "textile machinery" entities
+     textile_entities = [
+         {
+             "entity": ent["label"],
+             "word": ent["text"],
+             "start": ent["start"],
+             "end": ent["end"],
+             "score": ent.get("score", 0),
+         }
+         for ent in entities
+         if ent["label"].lower() == "textile machinery"
+     ]
+
+     # Highlight entities with HTML
+     highlighted_text = text
+     for ent in sorted(textile_entities, key=lambda x: x['start'], reverse=True):
+         highlighted_text = (
+             highlighted_text[:ent['start']] +
+             f"<span style='background-color: yellow'>{highlighted_text[ent['start']:ent['end']]}</span>" +
+             highlighted_text[ent['end']:]
+         )
+
+     return gr.HTML(highlighted_text), textile_entities

  # Build Gradio interface
+ with gr.Blocks(title="Textile Machinery NER Demo") as demo:
+     gr.Markdown(
+         """
+         # Textile Machinery Entity Recognition Demo
+         This demo selects a random text snippet from the British Library's books dataset and identifies "textile machinery" entities using a fine-tuned GLiNER model.
+         """
+     )
+
+     with gr.Accordion("How to run this model locally", open=False):
+         gr.Markdown(
+             """
+             ## Installation
+             To use this model, you must install the GLiNER Python library:
+             ```
+             pip install gliner
+             ```
+
+             ## Usage
+             Once you've installed the GLiNER library, you can import the GLiNER class. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
+             """
+         )
+         gr.Code(
+             '''
+             from gliner import GLiNER
+             model = GLiNER.from_pretrained("max-long/textile_machines_3_oct")
+             text = "Your sample text here."
+             labels = ["textile machinery"]
+             entities = model.predict_entities(text, labels)
+             for entity in entities:
+                 print(entity["text"], "=>", entity["label"])
+             ''',
+             language="python",
+         )
+         gr.Code(
+             """
+             Textile Machine 1 => textile machinery
+             Textile Machine 2 => textile machinery
+             """
+         )
+
+     input_text = gr.Textbox(
+         value="Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.",
+         label="Text input",
+         placeholder="Enter your text here",
+         lines=5
+     )
+
      with gr.Row():
+         labels = gr.Textbox(
+             value="textile machinery",
+             label="Labels",
+             placeholder="Enter your labels here (comma separated)",
+             scale=2,
+         )
+         threshold = gr.Slider(
+             0,
+             1,
+             value=0.3,
+             step=0.01,
+             label="Threshold",
+             info="Lower the threshold to increase how many entities get predicted.",
+             scale=1,
+         )
+         nested_ner = gr.Checkbox(
+             value=False,
+             label="Nested NER",
+             info="Allow for nested NER?",
+             scale=0,
+         )
+
+     output = gr.HighlightedText(label="Predicted Entities")
+
+     submit_btn = gr.Button("Analyze Random Snippet")
+     refresh_btn = gr.Button("Get New Snippet")
+
+     # Function to fetch a new random snippet
+     def get_new_snippet():
+         # WARNING: Streaming datasets may have performance implications
+         try:
+             sample = next(iter(subset_dataset))['text']
+             return sample
+         except StopIteration:
+             return "No more snippets available."
+
+     refresh_btn.click(fn=get_new_snippet, outputs=input_text)
+
+     submit_btn.click(
+         fn=ner,
+         inputs=[input_text, labels, threshold, nested_ner],
+         outputs=[output, gr.JSON(label="Entities")]
+     )
+
+     examples = [
+         [
+             "However, both models lack other frequent DM symptoms including the fibre-type dependent atrophy, myotonia, cataract and male-infertility.",
+             "textile machinery",
+             0.3,
+             False,
+         ],
+         # Add more examples as needed
+     ]

+     gr.Examples(
+         examples=examples,
+         inputs=[input_text, labels, threshold, nested_ner],
+         outputs=[output, gr.JSON(label="Entities")],
+         fn=ner,
+         label="Examples",
+         cache_examples=True,
+     )

+ demo.queue()
+ demo.launch(debug=True)
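
For quick review outside the Space, here is a minimal standalone sketch of the new inference path, kept close to the committed `ner` function: load the GLiNER checkpoint, predict spans for the "textile machinery" label, and wrap the matched character ranges in the same `<span>` highlight the app renders. The sample sentence is illustrative only, and whether it yields any hits depends on the fine-tuned model; treat this as a sketch, not part of the commit.

```python
# Sketch only: mirrors the logic inside the committed `ner` function, minus the Gradio UI.
# Assumes `pip install gliner` and access to the Hugging Face Hub.
from gliner import GLiNER

model = GLiNER.from_pretrained("max-long/textile_machines_3_oct")

text = "The mill installed a new power loom beside the older spinning frames."  # illustrative sample
labels = ["textile machinery"]

# Same call the app makes: flat (non-nested) spans at the UI's default threshold of 0.3.
entities = model.predict_entities(text, labels, flat_ner=True, threshold=0.3)

# Keep only "textile machinery" hits and highlight them right-to-left so earlier
# character offsets stay valid while the string is being edited.
highlighted = text
for ent in sorted(entities, key=lambda e: e["start"], reverse=True):
    if ent["label"].lower() != "textile machinery":
        continue
    highlighted = (
        highlighted[:ent["start"]]
        + f"<span style='background-color: yellow'>{highlighted[ent['start']:ent['end']]}</span>"
        + highlighted[ent["end"]:]
    )

print(highlighted)
for ent in entities:
    print(ent["text"], "=>", ent["label"], round(ent.get("score", 0.0), 3))
```

Applying the replacements in reverse `start` order is the same trick both the old and new versions of `app.py` rely on: it keeps the remaining character offsets valid as the string grows.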
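
The refresh button pulls its sample text from the streamed blbooks split. As a rough sketch of what that data access looks like on its own (assuming the `datasets` library is installed; the dataset name, arguments, and the `text` field all follow the committed code):

```python
# Sketch only: fetch a few raw snippets from the streamed dataset used by `get_new_snippet`.
from datasets import load_dataset

books = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)

# Streaming mode yields records lazily; walking a single iterator returns successive records.
stream = iter(books)
for _ in range(3):
    record = next(stream)
    print(record["text"][:120].replace("\n", " "), "...")
```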