mouliraj56's picture
Update app.py
df66752
raw
history blame
4.83 kB
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.seasonal import seasonal_decompose
from sklearn.preprocessing import MinMaxScaler
from keras.preprocessing.sequence import TimeseriesGenerator
from keras.models import Sequential
from keras.layers import Dense, LSTM
import yfinance as yf
def _load_close_series(ticker, period, interval):
    """Download price history via yfinance and return the Close prices as a one-column DataFrame."""
    df = yf.download(ticker, period=period, interval=interval)
    if df.empty:
        raise ValueError(
            f"no data returned for ticker={ticker!r}, period={period!r}, interval={interval!r}"
        )
    close = df['Close']
    # Recent yfinance versions return (Price, Ticker) MultiIndex columns, in
    # which case df['Close'] is a DataFrame with one column per ticker; older
    # versions return a plain Series.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    if interval == '1d':
        # Daily bars: pad to a regular calendar-day index, forward-filling
        # weekends/holidays, as the original app did for daily data.
        close = close.asfreq('D').ffill()
    # Other intervals keep their native bar spacing: re-indexing them to 'D'
    # would discard intraday bars or duplicate weekly/monthly bars.
    close = close.dropna()  # MinMaxScaler rejects NaN
    return close.to_frame('Close')


def _fit_and_forecast(train_values, n_steps, n_input=24, n_features=1, epochs=50):
    """Fit a one-layer LSTM on scaled training values and forecast n_steps recursively.

    Returns an array of shape (n_steps, 1) in the scaled space.
    """
    generator = TimeseriesGenerator(train_values, train_values, length=n_input, batch_size=1)
    model = Sequential()
    model.add(LSTM(100, activation='relu', input_shape=(n_input, n_features)))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mse')
    model.fit(generator, epochs=epochs, verbose=0)

    predictions = []
    # Seed the rolling window with the last n_input training observations.
    current_batch = train_values[-n_input:].reshape((1, n_input, n_features))
    for _ in range(n_steps):
        # Calling the model directly avoids predict()'s per-call overhead,
        # which dominates when forecasting one step at a time.
        current_pred = np.asarray(model(current_batch, training=False))[0]
        predictions.append(current_pred)
        # Slide the window: drop the oldest step, append the new prediction.
        current_batch = np.append(current_batch[:, 1:, :], [[current_pred]], axis=1)
    return np.array(predictions)


def _plot_forecast(ts, forecast_index, forecast_values):
    """Build a figure overlaying the full price history and the hold-out forecast."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(ts.index, ts.values, label='Original Data')
    ax.plot(forecast_index, forecast_values, label='Forecasted Data')
    ax.legend()
    ax.set_xlabel('Time')
    ax.set_ylabel('Value')
    ax.set_title('Stock Price Forecast')
    # Remove the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures in the server process; the Figure object
    # itself remains renderable by the caller.
    plt.close(fig)
    return fig


def forecast_stock(ticker, period, interval):
    """Train an LSTM on a ticker's closing prices and plot its hold-out forecast.

    The first 80% of the downloaded series trains the model. The remaining 20%
    is forecast recursively, one step at a time, starting from the end of the
    training data. The forecast is plotted over the full history.

    Args:
        ticker: Yahoo Finance symbol, e.g. "AAPL".
        period: yfinance history period, e.g. "1y" or "max".
        interval: yfinance bar interval, e.g. "1d", "1h" or "1wk".

    Returns:
        A matplotlib Figure on success. If any Exception occurs, returns the
        string "An error occurred: <message>" instead of raising.
    """
    n_input = 24  # length of the LSTM input window, in bars
    try:
        ts = _load_close_series(ticker.strip(), period.strip(), interval.strip())

        split = int(len(ts) * 0.8)
        ts_train = ts.iloc[:split]
        ts_test = ts.iloc[split:]
        # The generator needs at least one full window plus a target, and the
        # forecast needs at least one hold-out point.
        if len(ts_train) <= n_input or ts_test.empty:
            raise ValueError(
                f"not enough data to forecast: {len(ts)} points, but the model needs "
                f"more than {n_input} training points (80% split) plus at least one "
                f"test point; try a longer period or a shorter interval"
            )

        # Scale using training data only, so the hold-out span does not leak
        # into the normalisation.
        scaler = MinMaxScaler()
        scaled_train = scaler.fit_transform(ts_train.values)

        scaled_forecast = _fit_and_forecast(scaled_train, len(ts_test), n_input=n_input)
        true_predictions = scaler.inverse_transform(scaled_forecast)
        return _plot_forecast(ts, ts_test.index, true_predictions)
    except Exception as e:
        # NOTE(review): the Gradio output component is a plot, so this string is
        # presumably not shown as text to the user; consider raising gr.Error.
        return f"An error occurred: {e}"
# Markdown reference list of example Yahoo Finance ticker symbols shown in the UI.
# Blank lines separate the sections. Without them, each bold section title
# would render as a continuation of the previous list item.
tickers_info = """
**Ticker Examples**

---

**Common Stocks**

- **AAPL:** Apple Inc.
- **MSFT:** Microsoft Corp.
- **GOOG:** Alphabet Inc. (Google)
- **AMZN:** Amazon.com Inc.
- **TSLA:** Tesla Inc.
- **META:** Meta Platforms Inc.

**Indices**

- **^GSPC:** S&P 500 Index
- **^IXIC:** Nasdaq Composite Index
- **^DJI:** Dow Jones Industrial Average

**ETFs**

- **SPY:** SPDR S&P 500 ETF Trust
- **QQQ:** Invesco QQQ Trust
- **IWM:** iShares Russell 2000 ETF
"""
# Markdown table of sample (ticker, period, interval) inputs for forecast_stock.
# Each combination must be one yfinance can serve. Hourly data is only
# available for roughly the last 730 days, so 1h rows use periods shorter
# than that (a "max" period with a 1h interval always fails).
# Blank lines around the table keep the closing sentence out of the table.
examples = """This table demonstrates examples of stock forecasts you can generate using the application:

| Ticker | Period | Interval |
|---|---|---|
| AAPL | 2mo | 1d |
| GOOG | 1y | 1d |
| MSFT | 5y | 1wk |
| TSLA | 6mo | 1h |
| AMZN | 1y | 1h |
| NVDA | 3mo | 1d |
| META | 1y | 1wk |
| JNJ | 2y | 1d |
| BAC | 6mo | 1d |
| XOM | 1y | 1wk |

To generate a forecast for a specific stock, simply enter the ticker symbol, desired period, and interval into the interface.
"""
with gr.Blocks() as demo:
    # Three free-text fields, passed positionally to forecast_stock as
    # (ticker, period, interval).
    forecast_inputs = [
        gr.Textbox(label="Ticker", placeholder="e.g., AAPL, MSFT, GOOG"),
        gr.Textbox(label="Period", placeholder="e.g., 1mo, 5y, max"),
        gr.Textbox(
            label="Interval",
            placeholder="e.g., 1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo",
        ),
    ]
    gr.Interface(
        fn=forecast_stock,
        inputs=forecast_inputs,
        outputs="plot",
        live=False,  # run only when the user presses submit
        title="Stock Price Forecast",
        description="Enter a stock ticker, desired data period, and interval to generate a forecast.",
    )
    # Collapsible reference panels rendered below the form, in this order.
    for panel_title, panel_markdown in (
        ("Example Stock Forecasts", examples),
        ("Open for More info", tickers_info),
    ):
        with gr.Accordion(panel_title):
            gr.Markdown(panel_markdown)

demo.launch()