singhjagpreet's picture
prediction pipeline
d7e1d8f
raw
history blame contribute delete
No virus
1.63 kB
import os
import sys
import numpy as np
import pandas as pd
import pickle
from sklearn.metrics import r2_score
from sklearn.model_selection import GridSearchCV
from src.logger import logging
from src.exception import CustomException
def save_object(file_path, obj):
    """Serialize ``obj`` to ``file_path`` with pickle.

    The parent directory of ``file_path`` is created first when the path
    has one, so callers may pass a path inside a folder that does not
    exist yet.

    Args:
        file_path: Destination path of the pickle file.
        obj: The object to pickle.

    Raises:
        CustomException: Wraps any exception raised while creating the
            directory or writing the file.
    """
    try:
        dir_path = os.path.dirname(file_path)
        # A bare filename has an empty dirname and os.makedirs('') raises
        # FileNotFoundError; the current directory never needs creating.
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        with open(file_path, 'wb') as file_obj:
            pickle.dump(obj, file_obj)
    except Exception as e:
        raise CustomException(e, sys) from e
def evaluate_models(X_train, y_train, X_test, y_test, models, params):
    """Tune, refit and score every candidate model.

    For each entry in ``models`` a 5-fold ``GridSearchCV`` is run over the
    grid stored in ``params`` under the same key. The estimator object
    taken from ``models`` is then updated with the best parameters found
    and fit on the training data, so the objects in ``models`` are left
    fitted when this function returns.

    Args:
        X_train: Training features.
        y_train: Training targets.
        X_test: Test features.
        y_test: Test targets.
        models: Dict mapping a model name to an estimator instance.
        params: Dict mapping the same model names to a parameter grid
            accepted by ``GridSearchCV``.

    Returns:
        Dict mapping each model name to its R^2 score on the test data.

    Raises:
        CustomException: Wraps any exception raised during the search,
            fitting or scoring (including a missing key in ``params``).
    """
    try:
        report = {}

        for name, model in models.items():
            param = params[name]
            logging.info('training started')

            gs = GridSearchCV(model, param_grid=param, cv=5, verbose=False)
            gs.fit(X_train, y_train)

            # Apply the winning hyper-parameters to the caller's own
            # estimator object and fit it, so it can be used afterwards.
            model.set_params(**gs.best_params_)
            model.fit(X_train, y_train)

            report[name] = r2_score(y_test, model.predict(X_test))

        return report
    except Exception as e:
        # Use the same wrapper exception as the other helpers in this
        # module; a ConnectionAbortedError here would be misleading.
        raise CustomException(e, sys) from e
def load_object(file_path):
    """Read a pickled object back from ``file_path`` and return it.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The object reconstructed by ``pickle.load``.

    Raises:
        CustomException: Wraps any exception raised while opening or
            unpickling the file.
    """
    try:
        with open(file_path, 'rb') as handle:
            restored = pickle.load(handle)
    except Exception as err:
        raise CustomException(err, sys)
    else:
        return restored