# NazmusAshrafi's picture
# updated
# 04aa744
import streamlit as st
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
@st.cache_resource
def get_model():
    """Load the tokenizer and the fine-tuned topic classification model.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the (expensive) load
    happens once per server process instead of once per rerun.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``distilbert-base-uncased``
        tokenizer and the sequence-classification model pulled from
        ``NazmusAshrafi/large_dataset_stock_twitter_topic_Bert``.
    """
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    pulled_model = AutoModelForSequenceClassification.from_pretrained(
        "NazmusAshrafi/large_dataset_stock_twitter_topic_Bert")
    return tokenizer, pulled_model
# Build the classification pipeline from the loaded tokenizer/model pair.
# The "sentiment-analysis" task name is kept as-is; the labels it returns
# come from the pulled model's own configuration.
tokenizer, model = get_model()
classifier = pipeline(
    task="sentiment-analysis",
    tokenizer=tokenizer,
    model=model,
)
# print(classifier("AAP Up Theory = Will go up in future, buy now"))
# Page header: title plus a short description of the five topics the model
# can predict (colour tags use Streamlit's ":color[text]" markdown syntax).
st.title("Find the topic of a stock-related tweet")
st.subheader(
    'This model can predict 5 topics - :blue[Tesla Investment News], :green[Technology stock watchlist], :orange[Esports News], :blue[Apple Market Updates], :green[Amazon Updates] - Entering a topic related tweet will yield the best results')
# Reference (taken from a disabled elif chain): topic id -> topic label
#   0 -> 'Tesla Investment News'
#   1 -> 'Technology stock watchlist'
#   2 -> 'Esports News'
#   3 -> 'Apple Market Updates'
#   4 -> 'Amazon Updates'
# "Examples" section: one colour-tagged sample tweet per topic, rendered in
# the order listed, followed by a rainbow divider that closes the section.
st.subheader('Examples:')
_example_lines = (
    ':blue[Tesla Investment News: "Electric cars stock going to go up in future, buy now"]',
    ':green[Technology stock watchlist: "Keep a look out for that stock, its really good!"]',
    ':orange[Esports News: "Ninja is going to play in the next tournament, the best is back playing Fortnite."]',
    ':blue[Apple Market Updates: "$AAPL released a new phone but the looks never change much. Buy anyway!"]',
    ':green[Amazon Updates: "Black Fridays at Amazon, hail Jeff Bezos"]',
)
for _example_line in _example_lines:
    st.markdown(_example_line)
st.subheader("", divider='rainbow')
# Read a tweet from the user and classify it when "Analyze" is clicked.
user_input = st.text_area("Enter a tweet about a stock")
button = st.button("Analyze")
# print(classifier("AAP Down Theory = Will go down in future,do not buy now"))
# Ignore empty or whitespace-only input so the model is never run on nothing.
if user_input.strip() and button:
    # Run inference once and reuse the result for both output lines; the
    # pipeline returns a list with one dict per input text.
    prediction = classifier(user_input)[0]
    st.write("Topic Prediction: ", prediction['label'])
    st.write("Confidence Score: ", prediction['score'])
# st.markdown(
# 'Note: Model may be predicting "Investment decision" way too often, this is because of the data imbalance. The model was trained on 3 topics, but the data was not balanced. The model was trained on a dataset where a large portion of the data was on "Investment decision" topic. This is why the model is biased towards "Investment decision" topic. This is a flaw of this model, and it can be fixed by training the model on a balanced dataset.')