YorubaCNN / main.py
Testys's picture
Push
6303acb
raw
history blame
2.77 kB
import streamlit as st
import json
import torch
from transformers import AutoTokenizer
from modelling_cnn import CNNForNER, SentimentCNNModel
# Load the Yoruba NER model.
# `ner_model_name` is the local path of the saved weights; `model_ner` is the
# identifier handed to AutoTokenizer.from_pretrained to obtain the tokenizer.
ner_model_name = "./my_model/pytorch_model.bin"
model_ner = "Testys/cnn_yor_ner"
ner_tokenizer = AutoTokenizer.from_pretrained(model_ner)

# Read the config with an explicit encoding so the result does not depend on
# the platform's default locale encoding.
with open("./my_model/config.json", "r", encoding="utf-8") as f:
    ner_config = json.load(f)

# Build the network from the hyper-parameters stored in the config.
ner_model = CNNForNER(
    pretrained_model_name=ner_config["pretrained_model_name"],
    num_classes=ner_config["num_classes"],
)

# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source. Weights are mapped onto the CPU regardless of the
# device they were saved from.
ner_model.load_state_dict(
    torch.load(ner_model_name, map_location=torch.device("cpu"))
)
# The app only runs inference, so switch the model to evaluation mode.
ner_model.eval()
# Load the Yoruba sentiment analysis model.
# `sentiment_model_name` is the local path of the saved weights; `model_sent`
# is the identifier handed to AutoTokenizer.from_pretrained.
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"
model_sent = "Testys/cnn_sent_yor"
sentiment_tokenizer = AutoTokenizer.from_pretrained(model_sent)

# Read the config with an explicit encoding so the result does not depend on
# the platform's default locale encoding.
with open("./sent_model/config.json", "r", encoding="utf-8") as f:
    sentiment_config = json.load(f)

# Build the network from the hyper-parameters stored in the config.
sentiment_model = SentimentCNNModel(
    transformer_model_name=sentiment_config["pretrained_model_name"],
    num_classes=sentiment_config["num_classes"],
)

# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source. Weights are mapped onto the CPU regardless of the
# device they were saved from.
sentiment_model.load_state_dict(
    torch.load(sentiment_model_name, map_location=torch.device("cpu"))
)
# The app only runs inference, so switch the model to evaluation mode.
sentiment_model.eval()
def analyze_text(text):
    """Run NER and sentiment analysis on *text*.

    Args:
        text: The raw input string to analyse.

    Returns:
        A ``(ner_labels, sentiment)`` tuple. ``ner_labels`` is the argmax of
        the NER model output over its last dimension, converted with
        ``Tensor.tolist()``. ``sentiment`` is the ``id2label`` entry of the
        sentiment config for the highest-scoring class.

    Raises:
        KeyError: If the predicted class id is missing from a mapping-style
            ``id2label``.
        IndexError: If the predicted class id is out of range for a
            list-style ``id2label``.
    """
    # Tokenize input text for NER
    ner_inputs = ner_tokenizer(text, return_tensors="pt")

    # Perform Named Entity Recognition (no gradients needed for inference)
    with torch.no_grad():
        ner_outputs = ner_model(**ner_inputs)
    # NOTE(review): presumably one row of label ids per input sequence --
    # confirm against CNNForNER's output shape.
    ner_labels = torch.argmax(ner_outputs, dim=-1).tolist()

    # Tokenize input text for sentiment analysis
    sentiment_inputs = sentiment_tokenizer(text, return_tensors="pt")

    # Perform sentiment analysis
    with torch.no_grad():
        sentiment_outputs = sentiment_model(**sentiment_inputs)
    predicted_id = torch.argmax(sentiment_outputs).item()

    # The config is loaded with json.load, and JSON object keys are always
    # decoded as strings, so a mapping must be probed with str(predicted_id).
    # Fall back to the raw int key, and support a plain list of labels too.
    id2label = sentiment_config["id2label"]
    if isinstance(id2label, dict):
        key = str(predicted_id)
        sentiment = id2label[key] if key in id2label else id2label[predicted_id]
    else:
        sentiment = id2label[predicted_id]

    return ner_labels, sentiment
def main():
    """Render the Streamlit page and show analysis results for user input.

    Draws a title, a text area and an "Analyze" button. When the button is
    pressed with non-blank text, the text is passed to ``analyze_text`` and
    the returned NER labels and sentiment are written to the page. Blank
    input produces a warning instead of silently doing nothing.
    """
    st.title("YorubaCNN Models for NER and Sentiment Analysis")

    # Input text
    text = st.text_area("Enter Yoruba text", "")

    if st.button("Analyze"):
        # Guard clause: tell the user why nothing was analysed.
        if not text.strip():
            st.warning("Please enter some Yoruba text to analyze.")
            return

        ner_labels, sentiment = analyze_text(text)

        # Display Named Entities
        st.subheader("Named Entities")
        for label in ner_labels:
            st.write(f"- {label}")

        # Display Sentiment Analysis
        st.subheader("Sentiment Analysis")
        st.write(f"Sentiment: {sentiment}")


# Only launch the UI when this file is run as the entry script.
if __name__ == "__main__":
    main()