YorubaCNN / main.py
Testys's picture
Upload main.py
31cea2f verified
raw
history blame
5.28 kB
import streamlit as st
import json
import torch
from transformers import AutoTokenizer
from modelling_cnn import CNNForNER, SentimentCNNModel
import pandas as pd
import altair as alt
# Load the Yoruba NER model.
# Local checkpoint holding the CNN NER state_dict (name kept for compatibility).
ner_model_name = "./my_model/pytorch_model.bin"
# Identifier passed to AutoTokenizer.from_pretrained for the NER tokenizer.
model_ner = "Testys/cnn_yor_ner"


@st.cache_resource(show_spinner=False)
def _load_ner_components(tokenizer_id, weights_path, config_path):
    """Load and return ``(tokenizer, config, model)`` for the NER task.

    Streamlit reruns this script on every widget interaction; caching the
    result with ``st.cache_resource`` keeps one shared tokenizer/model instead
    of reloading them from disk on each rerun. ``show_spinner=False`` keeps the
    loader from emitting any UI element before ``st.set_page_config`` runs.

    Args:
        tokenizer_id: Name or path given to ``AutoTokenizer.from_pretrained``.
        weights_path: Path of the saved ``state_dict`` for ``CNNForNER``.
        config_path: Path of the JSON config providing
            ``pretrained_model_name`` and ``num_classes``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    model = CNNForNER(
        pretrained_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # inference only
    return tokenizer, config, model


ner_tokenizer, ner_config, ner_model = _load_ner_components(
    model_ner, ner_model_name, "./my_model/config.json"
)
# Load the Yoruba sentiment analysis model.
# Local checkpoint holding the CNN sentiment state_dict (name kept for compatibility).
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"
# Identifier passed to AutoTokenizer.from_pretrained for the sentiment tokenizer.
model_sent = "Testys/cnn_sent_yor"


@st.cache_resource(show_spinner=False)
def _load_sentiment_components(tokenizer_id, weights_path, config_path):
    """Load and return ``(tokenizer, config, model)`` for sentiment analysis.

    Streamlit reruns this script on every widget interaction; caching the
    result with ``st.cache_resource`` keeps one shared tokenizer/model instead
    of reloading them from disk on each rerun. ``show_spinner=False`` keeps the
    loader from emitting any UI element before ``st.set_page_config`` runs.

    Args:
        tokenizer_id: Name or path given to ``AutoTokenizer.from_pretrained``.
        weights_path: Path of the saved ``state_dict`` for ``SentimentCNNModel``.
        config_path: Path of the JSON config providing
            ``pretrained_model_name`` and ``num_classes``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    model = SentimentCNNModel(
        transformer_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # inference only
    return tokenizer, config, model


sentiment_tokenizer, sentiment_config, sentiment_model = _load_sentiment_components(
    model_sent, sentiment_model_name, "./sent_model/config.json"
)
def analyze_text(text):
    """Run named-entity recognition and sentiment analysis on ``text``.

    Args:
        text: Input string to analyse.

    Returns:
        ``(ner_labels, sentiment)`` where ``ner_labels`` is a list of
        ``"token: label"`` strings, one per token produced by
        ``ner_tokenizer`` (special tokens included), and ``sentiment`` is the
        label string from ``sentiment_config["id2label"]`` for the predicted
        class.
    """
    # --- Named Entity Recognition ---
    # truncation=True caps the sequence at the tokenizer's maximum length so
    # very long inputs are shortened instead of being passed through as-is.
    ner_inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    # Take batch row 0 explicitly: .squeeze() collapses a one-token sequence
    # to a 0-d tensor whose .tolist() is a bare int that cannot be iterated.
    token_ids = ner_inputs["input_ids"][0].tolist()
    tokens = ner_tokenizer.convert_ids_to_tokens(token_ids)

    with torch.no_grad():
        ner_outputs = ner_model(**ner_inputs)
    # NOTE(review): assumes ner_model returns per-token logits shaped
    # (batch, seq_len, num_classes), so argmax over the last axis yields one
    # class id per token — confirm against CNNForNER.
    ner_class_ids = torch.argmax(ner_outputs, dim=-1)[0].tolist()
    # NOTE(review): the NER config is read with key "id2labels" while the
    # sentiment config uses "id2label" — confirm the NER config's key name.
    id2labels = ner_config["id2labels"]
    # Pair each token with its predicted label (config keys are strings).
    ner_labels = [
        f"{token}: {id2labels[str(class_id)]}"
        for token, class_id in zip(tokens, ner_class_ids)
    ]

    # --- Sentiment analysis ---
    sentiment_inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        sentiment_outputs = sentiment_model(**sentiment_inputs)
    # argmax over dim=1 gives the predicted class id (not a probability).
    sentiment_id = torch.argmax(sentiment_outputs, dim=1)[0].item()
    sentiment = sentiment_config["id2label"][str(sentiment_id)]
    return ner_labels, sentiment
def _render_ner(ner_labels):
    """Render NER results as a table, highlighting entity-tagged tokens.

    Args:
        ner_labels: ``"token: label"`` strings as returned by ``analyze_text``.
    """
    st.header("Named Entities")
    # Split each "token: label" string on its LAST separator so a token that
    # happens to contain ": " still yields exactly two columns.
    rows = [item.rsplit(": ", 1) for item in ner_labels]
    ner_df = pd.DataFrame(rows, columns=["Token", "Entity"])

    # Styler.highlight_max is meant for numeric data and gives no meaningful
    # result on these text columns; highlight tagged entities instead.
    # NOTE(review): assumes "O" is the "no entity" tag in id2labels — confirm.
    def _highlight_entities(column):
        return ["background-color: lightblue" if value != "O" else "" for value in column]

    st.dataframe(ner_df.style.apply(_highlight_entities, subset=["Entity"]))


def _render_sentiment(sentiment):
    """Render the predicted sentiment as a bar on a [-1, 1] axis plus text.

    Args:
        sentiment: Label string as returned by ``analyze_text``.
    """
    st.header("Sentiment Analysis")
    # Map the label to a fixed display score (adjust if the model exposes
    # real scores). Compare case-insensitively because the label text comes
    # from the model config.
    label = sentiment.lower()
    if label == "positive":
        sentiment_score = 0.8
    elif label == "negative":
        sentiment_score = -0.8
    else:
        sentiment_score = 0

    sentiment_df = pd.DataFrame({"sentiment": [sentiment_score]})
    chart = alt.Chart(sentiment_df).mark_bar().encode(
        x=alt.X("sentiment", scale=alt.Scale(domain=(-1, 1))),
        color=alt.condition(
            alt.datum.sentiment > 0,
            alt.value("green"),
            alt.value("red"),
        ),
    ).properties(width=600, height=100)
    st.altair_chart(chart)
    st.write(f"Sentiment: {sentiment.capitalize()}")


def _render_about():
    """Show a collapsible explanation of the two analyses."""
    with st.expander("About this analysis"):
        # Blank lines separate the paragraphs from the list so markdown does
        # not fold the closing sentence into list item 2.
        st.write("""
        This tool uses YorubaCNN models to perform two types of analysis on Yoruba text:

        1. **Named Entity Recognition (NER)**: Identifies and classifies named entities (e.g., person names, organizations) in the text.
        2. **Sentiment Analysis**: Determines the overall emotional tone of the text (positive, negative, or neutral).

        The models used are based on Convolutional Neural Networks (CNN) and are specifically trained for the Yoruba language.
        """)


def _inject_styles():
    """Inject page-level CSS for alert padding and dataframe appearance."""
    st.markdown("""
    <style>
    .stAlert > div {
    padding-top: 20px;
    padding-bottom: 20px;
    }
    .stDataFrame {
    padding: 10px;
    border-radius: 5px;
    background-color: #f0f2f6;
    }
    </style>
    """, unsafe_allow_html=True)


def main():
    """Build the Streamlit page: text input, analysis results, and notes."""
    # Kept as the first st.* call; several Streamlit versions require
    # set_page_config to precede every other Streamlit command.
    st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")
    st.title("YorubaCNN Models for NER and Sentiment Analysis")

    text = st.text_area("Enter Yoruba text", "")

    if st.button("Analyze"):
        # Reject whitespace-only input instead of running both models on it.
        if text.strip():
            ner_labels, sentiment = analyze_text(text)
            _render_ner(ner_labels)
            _render_sentiment(sentiment)
        else:
            st.warning("Please enter some Yoruba text to analyze.")

    _render_about()
    _inject_styles()
# Launch the app only when this file is executed as a script; importing the
# module (e.g. to reuse analyze_text) does not render the page.
if __name__ == "__main__":
    main()