Spaces:
Runtime error
Runtime error
# Import baseline dependencies | |
import csv | |
import time | |
from datetime import date | |
import numpy as np | |
import pandas as pd | |
import pandas_datareader as data | |
import requests | |
import streamlit as st | |
from bs4 import BeautifulSoup | |
from plotly import graph_objs as go | |
from prophet import Prophet | |
from prophet.plot import plot_plotly | |
# summarisation (Pegasus) and sentiment analysis (BERT) models | |
from transformers import (BertForSequenceClassification, BertTokenizer, | |
PegasusTokenizer, TFPegasusForConditionalGeneration, | |
pipeline) | |
# Lay the app out across the full browser width.
st.set_page_config(layout="wide")
# Summarisation model (fine-tuned financial Pegasus)
def get_summarisation_model():
    """Load the financial-news Pegasus summariser from the Hugging Face Hub.

    Returns:
        A ``(model, tokenizer)`` pair: the TensorFlow Pegasus
        conditional-generation model and its matching ``PegasusTokenizer``.
    """
    model_id = "human-centered-summarization/financial-summarization-pegasus"
    tokenizer = PegasusTokenizer.from_pretrained(model_id)
    model = TFPegasusForConditionalGeneration.from_pretrained(model_id)
    return model, tokenizer
# Sentiment analysis model (FinancialBERT)
def get_sentiment_pepeline():
    """Build the FinancialBERT sentiment-analysis pipeline.

    Returns:
        A transformers ``"sentiment-analysis"`` pipeline wrapping a
        ``BertForSequenceClassification`` model configured with 3 labels and
        its ``BertTokenizer``.
    """
    model_id = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
    tokenizer = BertTokenizer.from_pretrained(model_id)
    classifier = BertForSequenceClassification.from_pretrained(
        model_id, num_labels=3)
    return pipeline("sentiment-analysis", model=classifier, tokenizer=tokenizer)
# Get all links from Google News
def search_urls(ticker, num, date, timeout=10):
    """Search Google News for Yahoo Finance articles about ``ticker``.

    Args:
        ticker: Exact phrase to search for (sent as ``as_epq``).
        num: Number of search results requested per page.
        date: ``"Past week"`` or ``"Past day"`` limits the age of results;
            any other value applies no time filter.
        timeout: Seconds to wait for Google's response before giving up.

    Returns:
        List of article URLs (the ``href`` of result anchors with class
        ``WlydOe``). The list is empty when the request fails or when the page
        contains no such anchors (e.g. the request was blocked), so callers
        can treat ``[]`` as "try again later".
    """
    # https://developers.google.com/custom-search/docs/xml_results_appendices#interfaceLanguages
    # Request headers and parameters
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
    }
    params = {
        "as_sitesearch": "finance.yahoo.com",  # we only want results from Yahoo Finance
        "hl": "en",  # language of the interface
        "gl": "us",  # country of the search
        "tbm": "nws",  # news results
        "lr": "lang_en",  # language filter
        "as_epq": ticker,  # exact phrase to search for
        # Where the terms may occur: "any", "title", "body", "url" or "links".
        # The ticker itself is not a valid value here.
        "as_occt": "any",
        "num": num,  # number of search results per page
    }
    # Articles timeframe: d = past 24h, h = past hour, w = past week,
    # m = past month. Only the two options offered in the UI are mapped.
    timeframe = {"Past week": "w", "Past day": "d"}.get(date)
    if timeframe is not None:
        params["as_qdr"] = timeframe
    try:
        r = requests.get("https://www.google.com/search", headers=headers,
                         params=params, cookies={'CONSENT': 'YES+'},
                         timeout=timeout)
    except requests.RequestException as err:
        # Network failure or timeout: report it and return no URLs so the
        # caller shows its "please wait" message instead of crashing.
        st.write(f"Search request failed: {err}")
        return []
    # Pause after the request (presumably to avoid being rate-limited by Google).
    time.sleep(5)
    st.write("Searched URL:")
    st.write(r.url)  # debugging
    soup = BeautifulSoup(r.text, "html.parser")
    # Skip result anchors without an href rather than raising KeyError.
    return [link["href"] for link in soup.find_all("a", "WlydOe")
            if link.get("href")]
# Extract title, date, and content of the article from all given URLs
def search_scrape(urls, max_words=350, delay=5, timeout=10):
    """Scrape title, posting date and body text from Yahoo Finance article pages.

    Args:
        urls: Article URLs to fetch.
        max_words: Number of leading whitespace-separated words of the body
            to keep. This limits the input size for the summarisation model.
        delay: Seconds to sleep after each request.
        timeout: Seconds to wait for each HTTP response.

    Returns:
        ``(titles, post_dates, articles)``: three lists aligned index-for-index
        with ``urls``. A missing title or date is ``"N/A"``. A page that could
        not be fetched contributes ``"N/A"``, ``"N/A"`` and an empty article,
        so the lists always keep the same length as ``urls``.
    """
    titles = []
    post_dates = []
    articles = []
    for url in urls:
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException:
            response = None
        time.sleep(delay)
        if response is None:
            # Keep the outputs aligned with urls; callers index them together.
            titles.append("N/A")
            post_dates.append("N/A")
            articles.append("")
            continue
        soup = BeautifulSoup(response.text, "html.parser")
        # title ("N/A" when the page has no title header)
        title_tag = soup.find("header", "caas-title-wrapper")
        titles.append(title_tag.text if title_tag is not None else "N/A")
        # posting date of the article ("N/A" when missing)
        time_tag = soup.find("time", "caas-attr-meta-time")
        post_dates.append(time_tag.text if time_tag is not None else "N/A")
        # Article content: all body divs, truncated to the first max_words
        # words. split() without an argument splits on any whitespace run, so
        # newlines and repeated spaces are not counted as words.
        paragraphs = soup.find_all("div", "caas-body")
        words = " ".join(paragraph.text for paragraph in paragraphs).split()
        articles.append(" ".join(words[:max_words]))
    return titles, post_dates, articles
# Summarise all given articles using a fine-tuned Pegasus Transformers model
def summarise_articles(sum_model, sum_tokenizer, articles):
    """Summarise each article with the Pegasus model.

    Args:
        sum_model: Seq2seq model exposing ``generate`` (the TF Pegasus model
            returned by ``get_summarisation_model``).
        sum_tokenizer: The matching tokenizer.
        articles: Article texts to summarise.

    Returns:
        One summary string per article, in input order.
    """
    summaries = []
    for article in articles:
        # source
        # https://huggingface.co/human-centered-summarization/financial-summarization-pegasus
        # truncation=True clips inputs longer than the tokenizer's maximum
        # length instead of feeding an over-long sequence to the model.
        input_ids = sum_tokenizer(
            article, return_tensors="tf", truncation=True).input_ids
        # Beam search with 5 beams. The keyword is num_beams; the previous
        # "num_beans" typo was never understood as a beam setting.
        output = sum_model.generate(
            input_ids, max_length=55, num_beams=5, early_stopping=True)
        summaries.append(
            sum_tokenizer.decode(output[0], skip_special_tokens=True))
    return summaries
# Join all data into rows (one row per summarised article)
def create_output_array(titles, post_dates, summarised_articles, sentiment_scores, raw_urls):
    """Assemble one display row per summarised article.

    Each row is ``[title, post_date, summary, Label, "NN%", url]``. The label
    comes from ``sentiment_scores[i]["label"]`` with only its first letter
    upper-cased. The confidence is ``sentiment_scores[i]["score"]`` formatted
    as a whole percentage. The number of rows equals
    ``len(summarised_articles)``; the other sequences are indexed in lockstep.
    """
    return [
        [
            titles[i],
            post_dates[i],
            summary,
            sentiment_scores[i]["label"].capitalize(),
            "{:.0%}".format(sentiment_scores[i]["score"]),
            raw_urls[i],
        ]
        for i, summary in enumerate(summarised_articles)
    ]
# Convert dataframe to .csv bytes for downloading
def convert_df(df):
    """Serialise ``df`` to CSV (index included) and return it as UTF-8 bytes."""
    csv_text = df.to_csv()
    return bytes(csv_text, "utf-8")
# ------------------------------------------------------------------------------
# Load data from Yahoo Finance
def load_data(ticker, start, end):
    """Download daily price history for ``ticker`` between ``start`` and ``end``.

    Uses pandas-datareader's ``"yahoo"`` source. The date index is moved into
    a regular column so callers can access ``df["Date"]``.

    NOTE(review): pandas-datareader's Yahoo source has been reported broken
    since Yahoo changed its API; if loading fails, consider another data
    source. Confirm against the installed version.
    """
    prices = data.DataReader(ticker, "yahoo", start, end)
    return prices.reset_index()
# Predict stock trend for N years using Prophet
def predict(df, period):
    """Fit Prophet on the closing prices and forecast ``period`` days ahead.

    Args:
        df: Price frame with ``Date`` and ``Close`` columns.
        period: Number of future periods (days) to add to the history.

    Returns:
        ``(model, forecast)``: the fitted Prophet model and its prediction
        frame covering both the history and the future periods.
    """
    # Prophet expects the columns to be named "ds" (timestamp) and "y" (value).
    history = df.loc[:, ["Date", "Close"]].rename(
        columns={"Date": "ds", "Close": "y"})
    prophet_model = Prophet()
    prophet_model.fit(history)
    horizon = prophet_model.make_future_dataframe(periods=period)
    return prophet_model, prophet_model.predict(horizon)
def main_page():
    """Render the "Financial News Analysis" page.

    When "Search" is pressed, the page:
    - collects Yahoo Finance news links for the ticker via Google;
    - scrapes and summarises each article;
    - scores the sentiment of each summary;
    - shows the results in a table that can be downloaded as CSV.

    Relies on the module-level ``sum_model``, ``sum_tokenizer`` and
    ``sentiment_pipeline`` created in the ``__main__`` block.
    """
    # Sidebar description of this feature
    st.sidebar.markdown("## Financial News Analysis")
    st.sidebar.write(
        "Scrape, auto summarise and calculate sentiment for stock and crypto news.")

    # User input
    ticker = st.text_input("Ticker:", "TSLA")
    num = st.number_input("Number of articles:", 5, 15, 10)
    timeline = st.selectbox("Timeline:", ["Past week", "Past day"])
    search_clicked = st.button("Search")
    st.info("Please do not spam the search button")
    st.markdown("---")

    # Nothing more to render until the button is pressed.
    if not search_clicked:
        return

    with st.spinner("Processing articles, please wait..."):
        raw_urls = search_urls(ticker, num, timeline)
        # No links (e.g. the request was blocked by Google): tell the user and stop.
        if not raw_urls:
            st.error("Please wait a few minutes before trying again")
            return

        titles, post_dates, articles = search_scrape(raw_urls)
        summaries = summarise_articles(sum_model, sum_tokenizer, articles)
        # Sentiment model source:
        # https://huggingface.co/ahmedrachid/FinancialBERT-Sentiment-Analysis
        scores = sentiment_pipeline(summaries)

        table = pd.DataFrame(
            create_output_array(titles, post_dates, summaries, scores, raw_urls),
            columns=["Title", "Date", "Summary", "Label", "Confidence", "URL"],
        )
        st.dataframe(table)
        # Offer the same table as a CSV download.
        st.download_button(
            "Save data to CSV", convert_df(table), "assetsummaries.csv", "text/csv", key="download-csv")
def page2():
    """Render the "Stock Trend Forecasting" page.

    The page:
    - loads daily prices for the chosen ticker from 2010-01-01 to today;
    - shows summary statistics;
    - plots the closing price with 100- and 200-day moving averages;
    - once "Predict" is pressed, shows a Prophet forecast of
      ``n_years * 365`` days together with its components.

    If loading fails or returns no rows, the problem is shown with
    ``st.error`` and the rest of the page is not rendered, instead of raising
    into Streamlit.
    """
    # Streamlit text
    st.sidebar.markdown("## Stock Trend Forecasting")
    st.sidebar.write(
        "A simple dashboard for stock trend forecasting and analysis.")
    # Start and end date of data
    start = "2010-01-01"
    end = date.today().strftime("%Y-%m-%d")
    # Ticker selection
    ticker = st.text_input("Ticker:", "AAPL")
    # Load data from Yahoo Finance. An unknown ticker or an unreachable data
    # source raises here; report it on the page instead of crashing.
    try:
        df = load_data(ticker, start, end)
    except Exception as err:  # UI boundary: any loader failure is shown to the user
        st.error(f"Could not load price data for {ticker!r}: {err}")
        return
    # describe() and Prophet cannot work with an empty history.
    if df.empty:
        st.error(f"No price data found for {ticker!r}.")
        return
    # Period selection
    n_years = st.number_input("Years of prediction:", 1, 4, 1)
    period = n_years * 365
    # Start prediction button
    init = st.button("Predict")
    st.markdown("---")
    # The adjusted close column is not used; ignore it if the source omitted it.
    df = df.drop(columns=["Adj Close"], errors="ignore")
    # Exploratory analysis
    st.subheader("Exploratory analysis")
    st.write(df.describe())
    # Plot raw closing data with 100 and 200 days MA (for simple analysis)
    st.subheader("Closing data, MA100 and MA200")
    ma100 = df.Close.rolling(100).mean()
    ma200 = df.Close.rolling(200).mean()
    fig = go.Figure()
    fig.update_layout(
        margin=dict(
            l=0,
            r=0,
            b=0,
            t=50,
            pad=4
        )
    )
    fig.add_trace(go.Scatter(x=df["Date"],
                             y=df['Close'], name="stock_close"))
    fig.add_trace(go.Scatter(x=df["Date"], y=ma100, name="ma100"))
    fig.add_trace(go.Scatter(x=df["Date"], y=ma200, name="ma200"))
    fig.layout.update(xaxis_rangeslider_visible=True)
    st.plotly_chart(fig, use_container_width=True)
    # If button is pressed, start forecasting
    if init:
        with st.spinner("Please wait..."):
            model, forecast = predict(df, period)
            st.markdown("---")
            st.subheader("Forecast data")
            st.write(forecast.tail())
            st.subheader(f"Forecast plot for {n_years} years")
            fig = plot_plotly(model, forecast)
            fig.update_layout(
                margin=dict(
                    l=0,
                    r=0,
                    b=0,
                    t=0,
                    pad=4
                )
            )
            st.plotly_chart(fig, use_container_width=True)
            st.subheader("Forecast components")
            fig = model.plot_components(forecast)
            st.write(fig)
if __name__ == "__main__":
    # Streamlit re-executes this whole script on every widget interaction.
    # When the installed Streamlit provides st.cache_resource, wrap the model
    # loaders with it so the models are loaded once and shared across reruns.
    # Otherwise, fall back to loading them on every run, as before.
    cache_models = getattr(st, "cache_resource", lambda loader: loader)
    with st.spinner("Loading all models..."):
        # Creating summariser and sentiment models
        sum_model, sum_tokenizer = cache_models(get_summarisation_model)()
        sentiment_pipeline = cache_models(get_sentiment_pepeline)()
    # Sidebar navigation: page title -> render function
    page_names_to_funcs = {
        "Financial News Analysis": main_page,
        "Stock Trend Forecasting": page2
    }
    st.sidebar.markdown("# Financial Researcher")
    selected_page = st.sidebar.selectbox(
        "Select a page", page_names_to_funcs.keys())
    st.sidebar.markdown("---")
    page_names_to_funcs[selected_page]()