jrno's picture
refresh model
e2a6226
raw
history blame
1.09 kB
import heapq
import logging

import pandas as pd
from fastai.learner import Learner

from tracks import get_unlistened_tracks_for_user, predictions_to_tracks
def get_recommendations_for_user(learn: Learner, user_id: str, limit: int = 5) -> dict:
    """Recommend the highest-scoring tracks a user has not listened to yet.

    Every track returned by ``get_unlistened_tracks_for_user(user_id)`` is
    scored with ``learn``; the ``limit`` best-scoring ones are passed to
    ``predictions_to_tracks`` as ``(track, score)`` pairs, best first.

    Args:
        learn: Trained fastai learner. Its ``dls.test_dl`` is fed a DataFrame
            with ``user_id`` and ``entry`` columns.
        user_id: The user to recommend tracks for.
        limit: Maximum number of recommendations. Values <= 0 produce no
            recommendations and skip model inference entirely.

    Returns:
        A dict with keys ``"user_id"``, ``"limit"`` (echoed back as given) and
        ``"recommendations"`` (whatever ``predictions_to_tracks`` returns for
        the selected pairs).
    """
    logger = logging.getLogger(__name__)

    not_listened_tracks = get_unlistened_tracks_for_user(user_id)
    logger.debug("User %s has %d unlistened tracks", user_id, len(not_listened_tracks))

    if limit <= 0 or len(not_listened_tracks) == 0:
        # Nothing to rank: avoid building a DataLoader / running inference on
        # zero rows (or for a zero/negative limit, which previously slipped
        # through list-slice semantics) and hand predictions_to_tracks the same
        # empty selection the normal path would produce.
        top_tracks = []
    else:
        # One row per candidate track, all for the same user. The column names
        # presumably match the ones the learner's DataLoaders were trained on
        # — NOTE(review): confirm against the training code.
        input_dataframe = pd.DataFrame(
            {"user_id": [user_id] * len(not_listened_tracks), "entry": not_listened_tracks}
        )
        test_dl = learn.dls.test_dl(input_dataframe)
        # get_preds returns (predictions, targets); only predictions are used.
        scores = learn.get_preds(dl=test_dl)[0].numpy()
        # nlargest(k, ...) is equivalent to sorted(..., reverse=True)[:k]
        # (same tie order) without sorting every candidate track.
        top_tracks = heapq.nlargest(
            limit, zip(not_listened_tracks, scores), key=lambda pair: pair[1]
        )
        logger.debug("Top %d predictions for user %s: %s", limit, user_id, top_tracks)

    recommendations = predictions_to_tracks(top_tracks)
    return {
        "user_id": user_id,
        "limit": limit,
        "recommendations": recommendations,
    }