# clfegg's picture
# Update handler.py
# 790c4c6 verified
# raw
# history blame
# 1.44 kB
from typing import Dict, List, Any
import pickle
import numpy as np
import pandas as pd
import os
import dill
class ContentBasedRecommender:
    """Recommend books similar to the ones a user already appears with.

    ``train_data`` is a pandas DataFrame that ``predict`` reads through the
    columns ``user_id``, ``book_id`` and ``similar_books``. Each
    ``similar_books`` cell is unioned into a set, so it must be an iterable of
    book ids (presumably a list; TODO confirm against the training pipeline).
    """

    def __init__(self, train_data):
        """Store the interaction/similarity table used by :meth:`predict`."""
        self.train_data = train_data

    def predict(self, user_id, k=10):
        """Return up to ``k`` book ids similar to the books of ``user_id``.

        Candidates are built as follows:

        * Start from every book the user appears with.
        * Take the ``similar_books`` of each such book, using only the first
          row for that book.
        * Union them, then remove the books the user already has.

        If at least ``k`` candidates exist, ``k`` of them are drawn uniformly
        at random without replacement (``np.random.choice``). Otherwise all
        candidates are returned.

        Args:
            user_id: Value matched against the ``user_id`` column.
            k: Maximum number of recommendations to return.

        Returns:
            A 1-D ``numpy.ndarray`` of book ids in every case. It may be empty,
            e.g. for an unknown user, so callers can always use ``.tolist()``.
        """
        data = self.train_data
        user_books = set(data.loc[data["user_id"] == user_id, "book_id"])
        # A single pass over the table instead of one full scan per owned
        # book. drop_duplicates keeps the first row per book, which matches
        # picking `.iloc[0]` among that book's rows.
        owned_rows = data[data["book_id"].isin(user_books)].drop_duplicates("book_id")
        similar_books = set().union(*owned_rows["similar_books"])
        recommended_books = list(similar_books - user_books)
        if len(recommended_books) >= k:
            return np.random.choice(recommended_books, size=k, replace=False)
        # Return an ndarray here too, so the result type does not depend on
        # how many candidates were found. Callers invoke `.tolist()` on it.
        return np.array(recommended_books)
class EndpointHandler:
    """Inference-endpoint wrapper around a dill-pickled recommender model.

    The model is loaded once at construction. Each call reads the request
    parameters and returns a single-element list. That element holds either
    ``{"recommended_books": [...]}`` or ``{"error": "..."}``.
    """

    def __init__(self, path=""):
        """Load the model stored at ``<path>/model.pkl``.

        Args:
            path: Directory containing ``model.pkl``. The default ``""``
                resolves to the current working directory.
        """
        model_path = os.path.join(path, "model.pkl")
        # NOTE(review): dill, like pickle, can execute arbitrary code while
        # loading. This is only acceptable because model.pkl is expected to be
        # a trusted artifact shipped with the model. Never load user input here.
        with open(model_path, 'rb') as f:
            self.model = dill.load(f)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run a recommendation request.

        Args:
            data: Request payload. ``"user_id"`` is required. ``"k"`` is
                optional: a missing or null value defaults to 10, and any
                other value is coerced with ``int``. The dict is not modified.

        Returns:
            ``[{"recommended_books": [...]}]`` on success. Otherwise
            ``[{"error": message}]``, covering a missing ``user_id`` and any
            exception raised while predicting.
        """
        user_id = data.get("user_id")
        k = data.get("k")
        if k is None:
            k = 10  # Default to 10 if not provided
        if user_id is None:
            return [{"error": "user_id is required"}]
        try:
            # JSON clients may send k as a string; predict needs an int.
            k = int(k)
            recommended_books = self.model.predict(user_id, k=k)
            # predict may return an ndarray or a plain list (e.g. when fewer
            # than k candidates exist). np.asarray handles both and converts
            # numpy scalars to JSON-serialisable Python values.
            return [{"recommended_books": np.asarray(recommended_books).tolist()}]
        except Exception as e:
            # Endpoint boundary: report the failure to the client instead of
            # crashing the worker.
            return [{"error": str(e)}]
def load_model(model_path):
    """Create an :class:`EndpointHandler` for the model directory ``model_path``.

    Despite its name, ``model_path`` is passed straight through as the
    handler's ``path``. The handler then reads ``model.pkl`` from that
    directory.
    """
    return EndpointHandler(path=model_path)