|
import json

import altair as alt
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

from modelling_cnn import CNNForNER, SentimentCNNModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Named Entity Recognition model ---------------------------------------
# Single checkpoint id shared by the NER model and its tokenizer, so the two
# can never drift apart.
NER_MODEL_NAME = "masakhane/afroxlmr-large-ner-masakhaner-1.0_2.0"

ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)

# analyze_text() looks this object up as `ner_tokenizer`; binding it under any
# other name makes every analysis fail with a NameError.
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
# Backward-compatible alias for the name this module originally exported.
ner_tokenizers = ner_tokenizer

# The model config provides the `id2label` mapping used to decode predictions.
ner_config = ner_model.config

# The model is only used for inference, so put it in evaluation mode.
ner_model.eval()
|
|
|
|
|
|
|
# --- Sentiment model -------------------------------------------------------
# Local file whose contents are loaded into SentimentCNNModel as a state dict.
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"
# Identifier handed to AutoTokenizer to obtain the sentiment tokenizer.
model_sent = "Testys/cnn_sent_yor"
sentiment_tokenizer = AutoTokenizer.from_pretrained(model_sent)

# Read the JSON config as UTF-8 explicitly: without `encoding`, open() falls
# back to the platform's locale encoding, which can mis-decode non-ASCII text.
with open("./sent_model/config.json", "r", encoding="utf-8") as f:
    sentiment_config = json.load(f)

# The architecture is rebuilt from the config values before weights are loaded.
sentiment_model = SentimentCNNModel(
    transformer_model_name=sentiment_config["pretrained_model_name"],
    num_classes=sentiment_config["num_classes"],
)

# map_location pins the loaded tensors to the CPU regardless of the device
# they were saved from.
sentiment_model.load_state_dict(
    torch.load(sentiment_model_name, map_location=torch.device("cpu"))
)
# Inference only: switch the model to evaluation mode.
sentiment_model.eval()
|
|
|
|
|
def analyze_text(text):
    """Run named-entity recognition and sentiment classification on ``text``.

    Args:
        text: The input string to analyse.

    Returns:
        A ``(ner_labels, sentiment)`` tuple:

        * ``ner_labels`` -- list of ``"token: label"`` strings, one per token
          produced by the NER tokenizer, in tokenizer order.
        * ``sentiment`` -- the entry of ``sentiment_config["id2label"]`` for
          the highest-scoring sentiment class.
    """
    # --- Named entity recognition ---
    # truncation=True clips inputs that exceed the tokenizer's maximum length
    # instead of letting an over-long sequence fail inside the model.
    ner_inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    tokens = ner_tokenizer.convert_ids_to_tokens(ner_inputs.input_ids[0])

    with torch.no_grad():
        ner_outputs = ner_model(**ner_inputs)

    # Highest-scoring class id per token; [0] selects the only sequence in
    # the batch.
    predicted_ids = torch.argmax(ner_outputs.logits, dim=-1)[0].tolist()
    ner_labels = [
        f"{token}: {ner_config.id2label[label_id]}"
        for token, label_id in zip(tokens, predicted_ids)
    ]

    # --- Sentiment classification ---
    sentiment_inputs = sentiment_tokenizer(
        text,
        max_length=514,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    with torch.no_grad():
        sentiment_outputs = sentiment_model(**sentiment_inputs)

    # argmax yields a class index (not a probability); take the first item
    # of the batch.
    sentiment_id = torch.argmax(sentiment_outputs, dim=1).tolist()[0]
    # The mapping was loaded from JSON, so its keys are strings.
    sentiment = sentiment_config["id2label"][str(sentiment_id)]

    return ner_labels, sentiment
|
|
|
def main():
    """Render the Streamlit page.

    Shows a text area and an "Analyze" button. When the button is pressed
    with non-blank text, the text is passed to ``analyze_text`` and the
    results are rendered as a token/entity table and a sentiment bar chart;
    with blank text a warning is shown instead. An explanatory expander and
    a small CSS override are always rendered.
    """
    st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")
    st.title("YorubaCNN Models for NER and Sentiment Analysis")

    text = st.text_area("Enter Yoruba text", "")

    if st.button("Analyze"):
        # Treat whitespace-only input like empty input.
        if text.strip():
            ner_labels, sentiment = analyze_text(text)

            st.header("Named Entities")
            # analyze_text() returns "token: label" strings. Split on the LAST
            # ": " only, so a token that itself contains ": " still yields
            # exactly two fields for the two-column frame.
            ner_df = pd.DataFrame(
                [label.rsplit(": ", 1) for label in ner_labels],
                columns=["Token", "Entity"],
            )
            st.dataframe(ner_df.style.highlight_max(axis=0, color="lightblue"))

            st.header("Sentiment Analysis")
            # Map the label onto a fixed position of the [-1, 1] axis. The
            # comparison is case-insensitive so "Positive" and "positive" are
            # treated the same; any other label sits at 0.
            polarity = sentiment.lower()
            if polarity == "positive":
                sentiment_score = 0.8
            elif polarity == "negative":
                sentiment_score = -0.8
            else:
                sentiment_score = 0

            sentiment_df = pd.DataFrame({"sentiment": [sentiment_score]})
            chart = (
                alt.Chart(sentiment_df)
                .mark_bar()
                .encode(
                    x=alt.X("sentiment", scale=alt.Scale(domain=(-1, 1))),
                    color=alt.condition(
                        alt.datum.sentiment > 0,
                        alt.value("green"),
                        alt.value("red"),
                    ),
                )
                .properties(width=600, height=100)
            )

            st.altair_chart(chart)
            st.write(f"Sentiment: {sentiment.capitalize()}")
        else:
            # Give feedback instead of silently ignoring the click.
            st.warning("Please enter some text to analyze.")

    with st.expander("About this analysis"):
        st.write("""
        This tool uses YorubaCNN models to perform two types of analysis on Yoruba text:

        1. **Named Entity Recognition (NER)**: Identifies and classifies named entities (e.g., person names, organizations) in the text.
        2. **Sentiment Analysis**: Determines the overall emotional tone of the text (positive, negative, or neutral).

        The models used are based on Convolutional Neural Networks (CNN) and are specifically trained for the Yoruba language.
        """)

    # Page-wide CSS overrides for alert boxes and dataframes.
    st.markdown("""
    <style>
    .stAlert > div {
        padding-top: 20px;
        padding-bottom: 20px;
    }
    .stDataFrame {
        padding: 10px;
        border-radius: 5px;
        background-color: #f0f2f6;
    }
    </style>
    """, unsafe_allow_html=True)
|
|
|
# Launch the UI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()