File size: 5,277 Bytes
b8cf6ae
 
 
 
 
1549ba5
 
b8cf6ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31cea2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16ba103
31cea2f
 
b8cf6ae
31cea2f
 
b8cf6ae
31cea2f
 
 
 
 
 
 
db225d0
31cea2f
b8cf6ae
 
1549ba5
 
b8cf6ae
1549ba5
b8cf6ae
 
1549ba5
b8cf6ae
 
1549ba5
 
b8cf6ae
1549ba5
 
 
 
 
 
 
 
b8cf6ae
1549ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8cf6ae
 
31cea2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import streamlit as st
import json
import torch
from transformers import AutoTokenizer
from modelling_cnn import CNNForNER, SentimentCNNModel
import pandas as pd
import altair as alt

# Yoruba NER model: local fine-tuned weights + tokenizer id for AutoTokenizer.
ner_model_name = "./my_model/pytorch_model.bin"
model_ner = "Testys/cnn_yor_ner"

# Yoruba sentiment model: local fine-tuned weights + tokenizer id for AutoTokenizer.
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"
model_sent = "Testys/cnn_sent_yor"


def _read_config(path):
    """Return the parsed JSON object stored at ``path`` (decoded as UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _load_checkpoint(model, weights_path):
    """Load CPU-mapped weights from ``weights_path`` into ``model``; return it in eval mode."""
    # NOTE(review): torch.load unpickles the file, so it must only ever point at
    # trusted local checkpoints; consider weights_only=True if the installed
    # torch version supports it.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model.eval()
    return model


# Streamlit re-executes this whole script on every widget interaction (e.g. each
# click on "Analyze"). Caching the loaders with st.cache_resource keeps one set
# of tokenizers/models alive across reruns instead of reloading them each time.
# show_spinner=False keeps the cached calls from emitting any UI element before
# main() calls st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_ner_resources():
    """Build (tokenizer, config dict, eval-mode CNNForNER) for the NER task."""
    tokenizer = AutoTokenizer.from_pretrained(model_ner)
    config = _read_config("./my_model/config.json")
    model = CNNForNER(
        pretrained_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    return tokenizer, config, _load_checkpoint(model, ner_model_name)


@st.cache_resource(show_spinner=False)
def _load_sentiment_resources():
    """Build (tokenizer, config dict, eval-mode SentimentCNNModel) for sentiment analysis."""
    tokenizer = AutoTokenizer.from_pretrained(model_sent)
    config = _read_config("./sent_model/config.json")
    model = SentimentCNNModel(
        transformer_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    return tokenizer, config, _load_checkpoint(model, sentiment_model_name)


# Module-level names consumed by analyze_text().
ner_tokenizer, ner_config, ner_model = _load_ner_resources()
sentiment_tokenizer, sentiment_config, sentiment_model = _load_sentiment_resources()


def analyze_text(text):
    """Run the NER model and the sentiment model on ``text``.

    Args:
        text: Raw Yoruba input string.

    Returns:
        A ``(ner_labels, sentiment)`` tuple:

        * ``ner_labels`` -- list of ``"token: label"`` strings, one per input id
          produced by ``ner_tokenizer`` (so including any special tokens it
          adds), paired positionally with the model's per-position predictions
          (``zip`` stops at the shorter of the two).
        * ``sentiment`` -- the label string looked up in
          ``sentiment_config["id2label"]`` for the highest-scoring class.
    """
    # truncation=True clips the input to the tokenizer's model_max_length so a
    # long text cannot overflow the encoder; tokens past that limit are dropped.
    ner_inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)

    # Batch of one: take row 0 and convert all ids back to token strings at once.
    tokens = ner_tokenizer.convert_ids_to_tokens(ner_inputs["input_ids"][0].tolist())

    with torch.no_grad():
        ner_outputs = ner_model(**ner_inputs)

    # presumably ner_outputs is (batch, seq_len, num_classes) -- argmax over the
    # class axis, then take the first (only) batch item. TODO confirm shape.
    ner_class_ids = torch.argmax(ner_outputs, dim=-1)[0].tolist()
    # Config keys are strings (JSON object keys), hence str(class_id).
    id2label = ner_config["id2labels"]
    ner_labels = [
        f"{token}: {id2label[str(class_id)]}"
        for token, class_id in zip(tokens, ner_class_ids)
    ]

    sentiment_inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        sentiment_outputs = sentiment_model(**sentiment_inputs)

    # argmax over dim=1 gives the predicted class index (not a probability);
    # [0] selects the single batch item.
    sentiment_id = torch.argmax(sentiment_outputs, dim=1)[0].item()
    sentiment = sentiment_config["id2label"][str(sentiment_id)]

    return ner_labels, sentiment

def _ner_table(ner_labels):
    """Turn ``"token: label"`` strings into a DataFrame with Token/Entity columns."""
    # Split on the LAST ": " only, so each row has exactly two fields even if a
    # token were ever to contain ": " itself.
    return pd.DataFrame(
        [item.rsplit(': ', 1) for item in ner_labels],
        columns=['Token', 'Entity'],
    )


def _highlight_entity_rows(row):
    """Row-wise Styler callback: tint rows whose Entity is not ``"O"``."""
    # NOTE(review): assumes IOB-style labels where "O" marks tokens outside any
    # entity -- confirm against ner_config["id2labels"].
    css = 'background-color: lightblue' if row['Entity'] != 'O' else ''
    return [css] * len(row)


def _sentiment_score(sentiment):
    """Map a sentiment label to a fixed bar value in [-1, 1] for the chart."""
    # Case-insensitive: the label text comes straight from the model config.
    label = sentiment.lower()
    if label == "positive":
        return 0.8
    if label == "negative":
        return -0.8
    return 0


def _render_ner(ner_labels):
    """Show the per-token NER predictions as a table with entity rows tinted."""
    st.header("Named Entities")
    ner_df = _ner_table(ner_labels)
    # Tint rows carrying a predicted entity (highlight_max on these text
    # columns would only mark the lexicographically largest string).
    st.dataframe(ner_df.style.apply(_highlight_entity_rows, axis=1))


def _render_sentiment(sentiment):
    """Show the predicted sentiment as a bar on a fixed [-1, 1] axis plus a caption."""
    st.header("Sentiment Analysis")

    # analyze_text() returns only a label, so the bar uses fixed display values
    # rather than a model confidence.
    sentiment_df = pd.DataFrame({'sentiment': [_sentiment_score(sentiment)]})
    chart = alt.Chart(sentiment_df).mark_bar().encode(
        x=alt.X('sentiment', scale=alt.Scale(domain=(-1, 1))),
        color=alt.condition(
            alt.datum.sentiment > 0,
            alt.value("green"),
            alt.value("red"),
        ),
    ).properties(width=600, height=100)

    st.altair_chart(chart)
    st.write(f"Sentiment: {sentiment.capitalize()}")


def main():
    """Render the Streamlit page: text input, NER table, sentiment chart and an about box."""
    # Must stay the first Streamlit command executed in the script.
    st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")

    st.title("YorubaCNN Models for NER and Sentiment Analysis")

    text = st.text_area("Enter Yoruba text", "")

    if st.button("Analyze"):
        # Whitespace-only input has nothing to analyze; tell the user instead
        # of silently doing nothing.
        if text.strip():
            ner_labels, sentiment = analyze_text(text)
            _render_ner(ner_labels)
            _render_sentiment(sentiment)
        else:
            st.warning("Please enter some Yoruba text to analyze.")

    # Explanatory section
    with st.expander("About this analysis"):
        st.write("""
        This tool uses YorubaCNN models to perform two types of analysis on Yoruba text:
        
        1. **Named Entity Recognition (NER)**: Identifies and classifies named entities (e.g., person names, organizations) in the text.
        2. **Sentiment Analysis**: Determines the overall emotional tone of the text (positive, negative, or neutral).
        
        The models used are based on Convolutional Neural Networks (CNN) and are specifically trained for the Yoruba language.
        """)

    # Styling
    st.markdown("""
        <style>
        .stAlert > div {
            padding-top: 20px;
            padding-bottom: 20px;
        }
        .stDataFrame {
            padding: 10px;
            border-radius: 5px;
            background-color: #f0f2f6;
        }
        </style>
        """, unsafe_allow_html=True)


# `streamlit run <file>` executes the script with __name__ == "__main__".
if __name__ == "__main__":
    main()