jsebdev committed
Commit 6475c61
1 Parent(s): 81e5f30

Implement model

Files changed (7)
  1. .env +2 -0
  2. .gitattributes +1 -0
  3. .gitignore +3 -1
  4. app.py +45 -4
  5. requirements.txt +79 -0
  6. train_data.csv +0 -0
  7. utils.py +97 -0
.env ADDED
@@ -0,0 +1,2 @@
+ SCALER_MAX_VALUE=10.589996337890625
+ SCALER_MAX_VOLUME=1460852400.0
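These two constants are the fixed scaling denominators recorded from training; utils.py reads them at runtime. A minimal sketch of how they are consumed (assuming python-dotenv is installed and .env sits next to the script, as in this repo):

    from dotenv import load_dotenv
    import os

    load_dotenv('.env')  # exposes SCALER_MAX_VALUE / SCALER_MAX_VOLUME via os.getenv
    max_value = float(os.getenv('SCALER_MAX_VALUE'))    # denominator for price offsets
    max_volume = float(os.getenv('SCALER_MAX_VOLUME'))  # denominator for traded volume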
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.data* filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1 +1,3 @@
- env
+ env
+ .DS_Store
+ test.py
app.py CHANGED
@@ -1,9 +1,50 @@
+ from dotenv import load_dotenv
+ load_dotenv('.env')
+
  import gradio as gr
- 
- 
- def greet(name):
-     return "Hello negrito " + name + "!!"
- 
- 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ from utils import *
+ import os
+
+
+ days_to_plot = 50
+
+ data = get_data().iloc[-500:]
+
+
+ data_to_plot = data.iloc[-days_to_plot:][["Close"]]
+ data_to_plot['date'] = data_to_plot.index
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Apple Predictor")
+     predict_button = gr.Button("Predict")
+     with gr.Row() as row0:
+         with gr.Column() as col0:
+             gr.Markdown("## Last candle info")
+             last_open = gr.Textbox(get_last_candle_value(data, 'Open'), label="Last Open")
+             last_max = gr.Textbox(get_last_candle_value(data, 'High'), label="Last Max")
+             last_min = gr.Textbox(get_last_candle_value(data, 'Low'), label="Last Min")
+             last_close = gr.Textbox(get_last_candle_value(data, 'Close'), label="Last Close")
+
+         with gr.Column() as col1:
+             gr.Markdown("## Next Candle Prediction")
+             jump_text = gr.Textbox(label="Jump")
+             open_text = gr.Textbox(label="Open")
+             max_text = gr.Textbox(label="Max")
+             min_text = gr.Textbox(label="Min")
+             next_close_text = gr.Textbox(label="Close")
+     with gr.Row() as row1:
+         value_plot = gr.LinePlot(data_to_plot,
+                                  x="date",
+                                  y="Close",
+                                  label=f'Last {days_to_plot} days',
+                                  y_lim=[float(data_to_plot['Close'].min()) - 5, float(data_to_plot['Close'].max()) + 5])
+
+     outputs = [jump_text,
+                open_text,
+                max_text,
+                min_text,
+                next_close_text
+                ]
+     predict_button.click(lambda: predict(data), outputs=outputs)
+
+ demo.launch(debug=True)
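Note that predict() in utils.py returns a 5-tuple (jump, open, high, low, close), which Gradio distributes in order over the five Textbox components listed in outputs. A slightly more explicit equivalent of the click wiring (a sketch; the handler name on_predict is introduced here for illustration):

    def on_predict():
        # predict() returns (jump, open, high, low, close); Gradio maps the tuple
        # onto the five output Textbox components in order.
        return predict(data)

    predict_button.click(on_predict, inputs=None, outputs=outputs)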
 
requirements.txt ADDED
@@ -0,0 +1,79 @@
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==4.2.2
+ anyio==3.6.2
+ appdirs==1.4.4
+ async-timeout==4.0.2
+ attrs==22.2.0
+ beautifulsoup4==4.12.2
+ certifi==2022.12.7
+ cffi==1.15.1
+ charset-normalizer==3.1.0
+ click==8.1.3
+ contourpy==1.0.7
+ cryptography==40.0.1
+ cycler==0.11.0
+ entrypoints==0.4
+ fastapi==0.95.0
+ ffmpy==0.3.0
+ filelock==3.11.0
+ fonttools==4.39.3
+ frozendict==2.3.7
+ frozenlist==1.3.3
+ fsspec==2023.4.0
+ gradio==3.24.1
+ gradio_client==0.0.8
+ h11==0.14.0
+ html5lib==1.1
+ httpcore==0.16.3
+ httpx==0.23.3
+ huggingface-hub==0.13.4
+ idna==3.4
+ importlib-resources==5.12.0
+ Jinja2==3.1.2
+ jsonschema==4.17.3
+ kiwisolver==1.4.4
+ linkify-it-py==2.0.0
+ lxml==4.9.2
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.2
+ matplotlib==3.7.1
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ multidict==6.0.4
+ multitasking==0.0.11
+ numpy==1.24.2
+ orjson==3.8.10
+ packaging==23.0
+ pandas==2.0.0
+ pandas-datareader==0.10.0
+ Pillow==9.5.0
+ pycparser==2.21
+ pydantic==1.10.7
+ pydub==0.25.1
+ pyparsing==3.0.9
+ pyrsistent==0.19.3
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0
+ requests==2.28.2
+ rfc3986==1.5.0
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ soupsieve==2.4
+ starlette==0.26.1
+ toolz==0.12.0
+ tqdm==4.65.0
+ typing_extensions==4.5.0
+ tzdata==2023.3
+ uc-micro-py==1.0.1
+ urllib3==1.26.15
+ uvicorn==0.21.1
+ webencodings==0.5.1
+ websockets==11.0.1
+ yarl==1.8.2
+ yfinance==0.2.16
+ zipp==3.15.0
train_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,97 @@
+ import numpy as np
+ # import matplotlib.pyplot as plt
+ import pandas as pd
+ import pandas_datareader as web
+ import datetime as dt
+ import yfinance as yfin
+ # import tensorflow as tf
+ import os
+ # import re
+
+ from huggingface_hub import from_pretrained_keras
+ # from sklearn.preprocessing import MinMaxScaler
+ # from tensorflow.keras.models import Sequential
+ # from tensorflow.keras.layers import Dense, Dropout, LSTM
+
+
+ def get_data(ticker='AAPL', start=None, end=None):
+     if end is None:
+         end = dt.date.today()
+     if start is None:
+         start = end - dt.timedelta(days=800)
+
+     yfin.pdr_override()
+     data = web.data.get_data_yahoo(ticker, start, end)
+     # data = pd.read_csv('train_data.csv', index_col='Date')
+     return data
+
+
+ def get_last_candle_value(data, column):
+     val = data.iloc[-1][column]
+     return "{:.2f}".format(val)
+
+
+ # Preprocessing functions copied from the notebook where the model was trained
+ def create_remove_columns(data):
+     # Create the "Jump" column: gap between today's Open and yesterday's Close
+     data = pd.DataFrame.copy(data)
+     data['Jump'] = data['Open'] - data['Close'].shift(1)
+     data['Jump'].fillna(0, inplace=True)
+     data.insert(0, 'Jump', data.pop('Jump'))
+     return data
+
+ def normalize_data(data):
+     # Returns a tuple with the normalized data, the scaler and the decoder.
+     # The normalized data is a dataframe with the following columns:
+     # ['Jump', 'High', 'Low', 'Close', 'Adj Close', 'Volume']
+     the_data = pd.DataFrame.copy(data)
+     # Subtract the Open value from all columns except the first and the last, which are "Jump" and "Volume"
+     the_data.iloc[:, 1:-1] = the_data.iloc[:, 1:-1] - the_data['Open'].values[:, np.newaxis]
+     # print('the_data')
+     # print(the_data)
+
+     the_data.pop('Open')
+     # Create the scaler from the constants stored in .env
+     max_value = float(os.getenv('SCALER_MAX_VALUE'))
+     max_volume = float(os.getenv('SCALER_MAX_VOLUME'))
+     def scaler(d):
+         data = pd.DataFrame.copy(d)
+         print('max_value: ', max_value)
+         print('max_volume: ', max_volume)
+         data.iloc[:, :-1] = data.iloc[:, :-1].apply(lambda x: x / max_value)
+         data.iloc[:, -1] = data.iloc[:, -1].apply(lambda x: x / max_volume)
+         return data
+     def decoder(values):
+         decoded_values = values * max_value
+         return decoded_values
+
+     normalized_data = scaler(the_data)
+
+     return normalized_data, scaler, decoder
+
+ def preprocessing(data):
+     # print(data.head(3))
+     data_0 = create_remove_columns(data)
+     # print(data_0.head(3))
+     # TODO: save the_scaler somehow to reuse in new runtimes
+     norm_data, scaler, decoder = normalize_data(data_0)
+     # print(norm_data.head(3))
+     # print(x_train.shape, y_train.shape)
+     norm_data_array = np.array(norm_data)
+     return np.expand_dims(norm_data_array, axis=0), decoder
+
+
+ # Model prediction: load the pretrained Keras model from the Hub
+ model = from_pretrained_keras("jsebdev/apple_stock_predictor")
+ def predict(data):
+     input, decoder = preprocessing(data)
+     print("input")
+     print(input.shape)
+     result = decoder(model.predict(input))
+     last_close = data.iloc[-1]['Close']
+     next_candle = result[0, -1]
+     print('next_candle')
+     print(next_candle)
+     jump = next_candle[0]
+     next_candle = next_candle + last_close
+     return (jump, next_candle[0], next_candle[1], next_candle[2], next_candle[3])
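For reference, a minimal sketch of how these helpers fit together outside Gradio. This assumes the .env file above is present, Yahoo Finance is reachable, and TensorFlow is available so from_pretrained_keras can load the model (TensorFlow is not listed in requirements.txt, so that last point is an assumption):

    from dotenv import load_dotenv
    load_dotenv('.env')   # scaler constants are read inside normalize_data() via os.getenv

    from utils import get_data, predict

    data = get_data().iloc[-500:]   # roughly the last 500 daily AAPL candles, as in app.py
    jump, open_, high, low, close = predict(data)
    print(f"jump={jump:.2f} open={open_:.2f} high={high:.2f} low={low:.2f} close={close:.2f}")

The returned prices are the model's offset predictions shifted by the last known Close, mirroring what the five prediction textboxes display in the app.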