Spaces:
Runtime error
Runtime error
Implement model
Browse files- .env +2 -0
- .gitattributes +1 -0
- .gitignore +3 -1
- app.py +45 -4
- requirements.txt +79 -0
- train_data.csv +0 -0
- utils.py +97 -0
.env
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
SCALER_MAX_VALUE=10.589996337890625
|
2 |
+
SCALER_MAX_VOLUME=1460852400.0
|
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.data* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
env
|
|
|
|
|
|
1 |
+
env
|
2 |
+
.DS_Store
|
3 |
+
test.py
|
app.py
CHANGED
@@ -1,9 +1,50 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
iface.launch()
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
load_dotenv('.env')
|
3 |
+
|
4 |
import gradio as gr
|
5 |
+
from utils import *
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
days_to_plot = 50
|
10 |
+
|
11 |
+
data = get_data().iloc[-500:]
|
12 |
+
|
13 |
+
|
14 |
+
data_to_plot = data.iloc[-days_to_plot:][["Close"]]
|
15 |
+
data_to_plot['date'] = data_to_plot.index
|
16 |
|
17 |
+
with gr.Blocks() as demo:
|
18 |
+
gr.Markdown("# Apple Predictor")
|
19 |
+
predict_button = gr.Button("Predict")
|
20 |
+
with gr.Row() as row0:
|
21 |
+
with gr.Column() as col0:
|
22 |
+
gr.Markdown("## Last candle info")
|
23 |
+
last_open = gr.Textbox(get_last_candle_value(data, 'Open') ,label="Last Open")
|
24 |
+
last_max = gr.Textbox( get_last_candle_value(data, 'High') ,label="Last Max")
|
25 |
+
last_min = gr.Textbox( get_last_candle_value(data, 'Low') ,label="Last Min")
|
26 |
+
last_close = gr.Textbox( get_last_candle_value(data, 'Close') ,label="Last Close")
|
27 |
|
28 |
+
with gr.Column() as col1:
|
29 |
+
gr.Markdown("## Next Candle Prediction")
|
30 |
+
jump_text = gr.Textbox(label="Jump")
|
31 |
+
open_text = gr.Textbox(label="Open")
|
32 |
+
max_text = gr.Textbox(label="Max")
|
33 |
+
min_text = gr.Textbox(label="Min")
|
34 |
+
next_close_text = gr.Textbox(label="Close")
|
35 |
+
with gr.Row() as row1:
|
36 |
+
value_plot = gr.LinePlot(data_to_plot,
|
37 |
+
x="date",
|
38 |
+
y="Close",
|
39 |
+
label=f'Last {days_to_plot} days',
|
40 |
+
y_lim=[float(data_to_plot['Close'].min())-5, float(data_to_plot['Close'].max())+5])
|
41 |
|
42 |
+
outputs = [jump_text,
|
43 |
+
open_text,
|
44 |
+
max_text,
|
45 |
+
min_text,
|
46 |
+
next_close_text
|
47 |
+
]
|
48 |
+
predict_button.click(lambda: predict(data), outputs=outputs)
|
49 |
|
50 |
+
demo.launch(debug=True)
|
|
requirements.txt
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.1.0
|
2 |
+
aiohttp==3.8.4
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==4.2.2
|
5 |
+
anyio==3.6.2
|
6 |
+
appdirs==1.4.4
|
7 |
+
async-timeout==4.0.2
|
8 |
+
attrs==22.2.0
|
9 |
+
beautifulsoup4==4.12.2
|
10 |
+
certifi==2022.12.7
|
11 |
+
cffi==1.15.1
|
12 |
+
charset-normalizer==3.1.0
|
13 |
+
click==8.1.3
|
14 |
+
contourpy==1.0.7
|
15 |
+
cryptography==40.0.1
|
16 |
+
cycler==0.11.0
|
17 |
+
entrypoints==0.4
|
18 |
+
fastapi==0.95.0
|
19 |
+
ffmpy==0.3.0
|
20 |
+
filelock==3.11.0
|
21 |
+
fonttools==4.39.3
|
22 |
+
frozendict==2.3.7
|
23 |
+
frozenlist==1.3.3
|
24 |
+
fsspec==2023.4.0
|
25 |
+
gradio==3.24.1
|
26 |
+
gradio_client==0.0.8
|
27 |
+
h11==0.14.0
|
28 |
+
html5lib==1.1
|
29 |
+
httpcore==0.16.3
|
30 |
+
httpx==0.23.3
|
31 |
+
huggingface-hub==0.13.4
|
32 |
+
idna==3.4
|
33 |
+
importlib-resources==5.12.0
|
34 |
+
Jinja2==3.1.2
|
35 |
+
jsonschema==4.17.3
|
36 |
+
kiwisolver==1.4.4
|
37 |
+
linkify-it-py==2.0.0
|
38 |
+
lxml==4.9.2
|
39 |
+
markdown-it-py==2.2.0
|
40 |
+
MarkupSafe==2.1.2
|
41 |
+
matplotlib==3.7.1
|
42 |
+
mdit-py-plugins==0.3.3
|
43 |
+
mdurl==0.1.2
|
44 |
+
multidict==6.0.4
|
45 |
+
multitasking==0.0.11
|
46 |
+
numpy==1.24.2
|
47 |
+
orjson==3.8.10
|
48 |
+
packaging==23.0
|
49 |
+
pandas==2.0.0
|
50 |
+
pandas-datareader==0.10.0
|
51 |
+
Pillow==9.5.0
|
52 |
+
pycparser==2.21
|
53 |
+
pydantic==1.10.7
|
54 |
+
pydub==0.25.1
|
55 |
+
pyparsing==3.0.9
|
56 |
+
pyrsistent==0.19.3
|
57 |
+
python-dateutil==2.8.2
|
58 |
+
python-multipart==0.0.6
|
59 |
+
pytz==2023.3
|
60 |
+
PyYAML==6.0
|
61 |
+
requests==2.28.2
|
62 |
+
rfc3986==1.5.0
|
63 |
+
semantic-version==2.10.0
|
64 |
+
six==1.16.0
|
65 |
+
sniffio==1.3.0
|
66 |
+
soupsieve==2.4
|
67 |
+
starlette==0.26.1
|
68 |
+
toolz==0.12.0
|
69 |
+
tqdm==4.65.0
|
70 |
+
typing_extensions==4.5.0
|
71 |
+
tzdata==2023.3
|
72 |
+
uc-micro-py==1.0.1
|
73 |
+
urllib3==1.26.15
|
74 |
+
uvicorn==0.21.1
|
75 |
+
webencodings==0.5.1
|
76 |
+
websockets==11.0.1
|
77 |
+
yarl==1.8.2
|
78 |
+
yfinance==0.2.16
|
79 |
+
zipp==3.15.0
|
train_data.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
# import matplotlib.pyplot as plt
|
3 |
+
import pandas as pd
|
4 |
+
import pandas_datareader as web
|
5 |
+
import datetime as dt
|
6 |
+
import yfinance as yfin
|
7 |
+
# import tensorflow as tf
|
8 |
+
import os
|
9 |
+
# import re
|
10 |
+
|
11 |
+
from huggingface_hub import from_pretrained_keras
|
12 |
+
# from sklearn.preprocessing import MinMaxScaler
|
13 |
+
# from tensorflow.keras.models import Sequential
|
14 |
+
# from tensorflow.keras.layers import Dense, Dropout, LSTM
|
15 |
+
|
16 |
+
|
17 |
+
def get_data(ticker='AAPL', start=None, end=None):
|
18 |
+
if end is None:
|
19 |
+
end = dt.date.today()
|
20 |
+
if start is None:
|
21 |
+
start = end - dt.timedelta(days=800)
|
22 |
+
|
23 |
+
yfin.pdr_override()
|
24 |
+
data = web.data.get_data_yahoo(ticker, start, end)
|
25 |
+
# data = pd.read_csv('train_data.csv', index_col='Date')
|
26 |
+
return data
|
27 |
+
|
28 |
+
|
29 |
+
def get_last_candle_value(data, column):
|
30 |
+
val = data.iloc[-1][column]
|
31 |
+
return "{:.2f}".format(val)
|
32 |
+
|
33 |
+
|
34 |
+
# Preprocessing functions copied from notebook where model was trained
|
35 |
+
def create_remove_columns(data):
|
36 |
+
# create jump column
|
37 |
+
data = pd.DataFrame.copy(data)
|
38 |
+
data['Jump'] = data['Open'] - data['Close'].shift(1)
|
39 |
+
data['Jump'].fillna(0, inplace=True)
|
40 |
+
data.insert(0,'Jump', data.pop('Jump'))
|
41 |
+
return data
|
42 |
+
|
43 |
+
def normalize_data(data):
|
44 |
+
# Returns a tuple with the normalized data, the scaler and the decoder
|
45 |
+
# The normalized data is a dataframe with the following columns:
|
46 |
+
# ['Jump', 'High', 'Low', 'Close', 'Adj Close', 'Volume']
|
47 |
+
the_data = pd.DataFrame.copy(data)
|
48 |
+
# substract the open value to all columns but the first one and the last one which are "Jump" and "Volume"
|
49 |
+
the_data.iloc[:, 1:-1] = the_data.iloc[:,1:-1] - the_data['Open'].values[:, np.newaxis]
|
50 |
+
# print('the_data')
|
51 |
+
# print(the_data)
|
52 |
+
|
53 |
+
the_data.pop('Open')
|
54 |
+
# Create the scaler
|
55 |
+
max_value = float(os.getenv('SCALER_MAX_VALUE'))
|
56 |
+
max_volume = float(os.getenv('SCALER_MAX_VOLUME'))
|
57 |
+
def scaler(d):
|
58 |
+
data = pd.DataFrame.copy(d)
|
59 |
+
print('max_value: ', max_value)
|
60 |
+
print('max_volume: ', max_volume)
|
61 |
+
data.iloc[:, :-1] = data.iloc[:,:-1].apply(lambda x: x/max_value)
|
62 |
+
data.iloc[:, -1] = data.iloc[:,-1].apply(lambda x: x/max_volume)
|
63 |
+
return data
|
64 |
+
def decoder(values):
|
65 |
+
decoded_values = values * max_value
|
66 |
+
return decoded_values
|
67 |
+
|
68 |
+
normalized_data = scaler(the_data)
|
69 |
+
|
70 |
+
return normalized_data, scaler, decoder
|
71 |
+
|
72 |
+
def preprocessing(data):
|
73 |
+
# print(data.head(3))
|
74 |
+
data_0 = create_remove_columns(data)
|
75 |
+
# print(data_0.head(3))
|
76 |
+
#todo: save the_scaler somehow to use in new runtimes
|
77 |
+
norm_data, scaler, decoder = normalize_data(data_0)
|
78 |
+
# print(norm_data.head(3))
|
79 |
+
# print(x_train.shape, y_train.shape)
|
80 |
+
norm_data_array = np.array(norm_data)
|
81 |
+
return np.expand_dims(norm_data_array, axis=0), decoder
|
82 |
+
|
83 |
+
|
84 |
+
# Model prediction
|
85 |
+
# model = from_pretrained_keras("jsebdev/apple_stock_predictor")
|
86 |
+
def predict(data):
|
87 |
+
input, decoder = preprocessing(data)
|
88 |
+
print("input")
|
89 |
+
print(input.shape)
|
90 |
+
result = decoder(model.predict(input))
|
91 |
+
last_close = data.iloc[-1]['Close']
|
92 |
+
next_candle = result[0, -1]
|
93 |
+
print('next_candle')
|
94 |
+
print(next_candle)
|
95 |
+
jump = next_candle[0]
|
96 |
+
next_candle = next_candle + last_close
|
97 |
+
return (jump, next_candle[0], next_candle[1], next_candle[2], next_candle[3])
|