|
import streamlit as st |
|
import pandas as pd |
|
import re |
|
import json |
|
import transformers |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer |
|
|
|
# Browser-tab title and icon for the app page.
st.set_page_config(page_title="Named Entity Recognition Wolof", page_icon="📘")
|
|
|
def convert_df(df: pd.DataFrame):
    """Serialize ``df`` to CSV text (index column omitted) and return it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
|
|
|
def convert_json(df: pd.DataFrame):
    """Return ``df`` as a JSON string keyed by row index.

    The frame is exported with ``orient="index"`` and then round-tripped
    through ``json.loads`` / ``json.dumps``, so the returned text uses the
    standard-library formatting (``", "`` and ``": "`` separators).
    """
    rows_by_index = json.loads(df.to_json(orient="index"))
    return json.dumps(rows_by_index)
|
|
|
@st.cache_resource
def load_model():
    """Load the fine-tuned Wolof NER model, a Trainer wrapping it, and its tokenizer.

    The result is cached with ``st.cache_resource`` so the checkpoint is
    loaded once instead of on every call: this function is called from both
    ``tag_sentence`` and ``align_word_ids``, i.e. twice per tagged sentence.

    Returns:
        tuple: ``(trainer, model, tokenizer)``. The trainer is kept in the
        return value for backward compatibility; none of the callers in this
        file use it.
    """
    checkpoint = "vonewman/wolof-finetuned-ner"
    model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    trainer = Trainer(model=model)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return trainer, model, tokenizer
|
|
|
def align_word_ids(texts, label_all_tokens=False):
    """Build a per-token mask that marks which token positions carry a word label.

    ``texts`` is tokenized with the same settings used for prediction
    (padded/truncated to 218 tokens), and each token position is mapped to:

    * ``-100`` when the tokenizer reports no word index for it (per the
      tokenizer documentation these are special and padding tokens);
    * ``1`` for the first sub-token of every word;
    * ``1`` if ``label_all_tokens`` else ``-100`` for the remaining
      sub-tokens of a word.

    Args:
        texts: Sentence to tokenize.
        label_all_tokens: Whether continuation sub-tokens are also marked.
            Defaults to ``False``, which matches the previous effective
            behaviour (the name used to be undefined and the resulting
            ``NameError`` was silently turned into ``-100``).

    Returns:
        list[int]: One entry (``1`` or ``-100``) per token position.
    """
    _, _, tokenizer = load_model()

    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=218, truncation=True)
    word_ids = tokenized_inputs.word_ids()

    label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            # First sub-token of a new word.
            label_ids.append(1)
        else:
            # Continuation sub-token of the current word.
            label_ids.append(1 if label_all_tokens else -100)
        previous_word_idx = word_idx

    return label_ids
|
|
|
|
|
def predict_ner_labels(model, tokenizer, sentence):
    """Predict NER tag names for ``sentence``, one per position kept by ``align_word_ids``.

    Args:
        model: Token-classification model; it is called with ``input_ids``
            and ``attention_mask`` and its first output is used as the logits.
        tokenizer: Tokenizer used to encode ``sentence`` (padded/truncated
            to 218 tokens, returned as PyTorch tensors).
        sentence: Text to tag.

    Returns:
        list[str]: Tag names looked up in ``id2tag`` for every token position
        whose entry in ``align_word_ids(sentence)`` is not ``-100``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep the model and its inputs on the same device.
    model = model.to(device)

    encoded = tokenizer(sentence, padding='max_length', max_length=218,
                        truncation=True, return_tensors="pt")
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Shape (1, sequence_length) so it lines up with the batch of one above.
    label_ids = torch.tensor(align_word_ids(sentence), device=device).unsqueeze(0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Drop every position masked out with -100, then pick the best class.
    logits_clean = outputs[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()

    return [id2tag[i] for i in predictions]
|
|
|
# Class index -> tag name, used to decode the argmax of the model logits.
# NOTE(review): presumably mirrors the label order of the fine-tuned
# checkpoint -- confirm against the model config.
id2tag = {
    0: 'O',
    1: 'B-LOC',
    2: 'B-PER',
    3: 'I-PER',
    4: 'B-ORG',
    5: 'I-DATE',
    6: 'B-DATE',
    7: 'I-ORG',
    8: 'I-LOC',
}
|
|
|
|
|
def tag_sentence(text):
    """Tag every whitespace-separated word of ``text`` and prepare it for display.

    Args:
        text: Sentence to tag.

    Returns:
        pandas.DataFrame: Columns ``word`` and ``tag``. ``tag`` holds the
        predicted tag name; ``word`` holds the HTML-escaped word wrapped in a
        ``<span>`` that is coloured blue unless the tag is ``'O'``.
    """
    import html  # local import keeps this block self-contained

    _, model, tokenizer = load_model()

    # predict_ner_labels already returns tag names, so they must not be
    # mapped through id2tag a second time (that turned every tag into NaN
    # and, as a consequence, coloured every word).
    predictions = predict_ner_labels(model, tokenizer, text)

    df = pd.DataFrame({'word': text.split(), 'tag': predictions})

    def color_tags(tag):
        return '' if tag == 'O' else 'color: blue'

    # The result is rendered as raw HTML by the caller, so the user-supplied
    # words are escaped before being embedded in markup.
    df['word'] = [
        f'<span style="{color_tags(tag)}">{html.escape(word)}</span>'
        for word, tag in zip(df['word'], df['tag'])
    ]

    return df
|
|
|
st.title("📘 Named Entity Recognition Wolof")

# Input form: the sentence is only handled after the submit button is pressed.
with st.form(key='my_form'):
    x1 = st.text_input(label='Enter a sentence:', max_chars=250)
    submit_button = st.form_submit_button(label='🏷️ Create tags')

if submit_button:
    if not x1.strip():
        st.error('Please enter a non-empty sentence.')
    elif re.match(r'\A\s*\w+\s*\Z', x1):
        # The pattern matches input made of exactly one word, so the message
        # asks for more than one (it used to say "at least one word").
        st.error("Please enter a sentence with more than one word")
    else:
        st.markdown("### Tagged Sentence")
        st.header("")

        results = tag_sentence(x1)

        # `results['word']` holds HTML markup for on-screen display; the
        # downloadable files get the plain words instead. tag_sentence builds
        # its rows from `text.split()`, so the lengths match.
        export_df = results.assign(word=x1.split())
        csv_bytes = convert_df(export_df)

        _, c1, c2, c3, _ = st.columns([0.75, 1.5, 1.5, 1.5, 0.75])

        with c1:
            st.download_button(label="📥 Download .csv", data=csv_bytes,
                               file_name="results.csv", mime='text/csv', key='csv')
        with c2:
            st.download_button(label="📥 Download .txt", data=csv_bytes,
                               file_name="results.txt", mime='text/plain', key='text')
        with c3:
            st.download_button(label="📥 Download .json", data=convert_json(export_df),
                               file_name="results.json", mime='application/json', key='json')

        st.header("")

        _, center, _ = st.columns([1, 3, 1])

        with center:
            st.write(results.to_html(escape=False), unsafe_allow_html=True)

# NOTE(review): indentation was lost in the source; the spacer headers and the
# About section are assumed to be top-level (always shown) -- confirm.
st.header("")
st.header("")
st.header("")

with st.expander("ℹ️ - About this app", expanded=True):
    st.write(
        """
- The **Named Entity Recognition Wolof** app is a tool that performs named entity recognition in Wolof.
- The available entities are: *corporation*, *location*, *person*, and *date*.
- The app uses the [XLMRoberta model](https://huggingface.co/xlm-roberta-base), fine-tuned on the [masakhaNER](https://huggingface.co/datasets/masakhane/masakhaner2) dataset.
- The model uses the **byte-level BPE tokenizer**. Each sentence is first tokenized.
"""
    )