# poc / price_forecasting_ml / train_predict.py
# Last commit: 79e1719 "Added files" by ryanrahmadifa
# (header scraped from the GitHub file view: raw | history | blame | 1.21 kB)
from modules.transform import transformData, prepareData
from train import trainModel
from pathlib import Path
import pickle
import uuid
import os
def main():
    """Run one forecasting pipeline end to end: prepare, transform, train, save.

    Each invocation gets a fresh UUID ``run_id``. All outputs go into
    ``<directory of this script>/artifacts/<run_id>/``:

    * ``transformed_dataset.csv`` -- the training data returned by
      ``transformData``, written with ``.to_csv``.
    * ``transformations.pkl`` -- the ``transformations`` object returned by
      ``transformData``, pickled.
    * ``training_results.csv`` -- the results returned by ``trainModel``,
      written with ``.to_csv``.
    * ``model`` -- written by ``nf.save`` with ``overwrite=True`` and
      ``save_dataset=True``.

    Raises:
        FileExistsError: if the run directory already exists (only possible
            on a UUID collision); runs are never written into each other.
    """
    import logging

    directory = Path(__file__).parent.absolute()
    logging.basicConfig(level=logging.INFO)

    run_id = str(uuid.uuid4())
    artifacts_path = os.path.join(directory, 'artifacts', run_id)
    # makedirs (not mkdir) so the very first run also creates the parent
    # 'artifacts/' folder instead of failing with FileNotFoundError.
    # exist_ok=False keeps the one-fresh-directory-per-run guarantee.
    os.makedirs(artifacts_path, exist_ok=False)
    logging.info('Created forecasting pipeline with id %s', run_id)

    prepared_data = prepareData(dir=directory, id=run_id)
    train_data, transformations = transformData(prepared_data, dir=directory, id=run_id)
    train_data.to_csv(os.path.join(artifacts_path, 'transformed_dataset.csv'))

    # Save transformations including StandardScaler objects.
    # NOTE: pickle files execute code on load -- only unpickle this file from
    # artifact directories you trust.
    with open(os.path.join(artifacts_path, 'transformations.pkl'), 'wb') as fp:
        pickle.dump(transformations, fp)

    nf, results = trainModel(dataset=train_data, artifacts_path=artifacts_path)
    results.to_csv(os.path.join(artifacts_path, 'training_results.csv'))

    # presumably nf is a NeuralForecast-style object (it exposes .save with
    # these keywords) -- TODO confirm against modules/train.
    nf.save(path=os.path.join(artifacts_path, 'model'),
            model_index=None,
            overwrite=True,
            save_dataset=True)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()