# quotes-recsys/utils.py
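"""Utilities for the quotes recommender demo.

A two-tower DSSM encodes diary entries and quotes with separate
transformer encoders; a quote is recommended by cosine similarity
between the diary embedding and precomputed quote embeddings.
"""
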
import random
import torch
import numpy as np
import pandas as pd
from stqdm import stqdm
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModel

# Run on GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class DSSM(nn.Module):
    """Two-tower model: separate encoders for diary entries and quotes."""

    def __init__(self, base_model_name, base_model=AutoModel):
        super().__init__()
        self.diary_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)
        self.quote_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)

    def forward(self, diary, quote):
        # Each tower encodes its own tokenized input independently.
        return self.diary_emb(**diary), self.quote_emb(**quote)


def get_models_and_tokenizer(base_model_name, base_model=AutoModel, ckpt=None):
    """Build the two towers and tokenizer, optionally loading a fine-tuned checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    model = DSSM(base_model_name, base_model=base_model)
    if ckpt:
        print("Loading checkpoint:", ckpt)
        model.load_state_dict(torch.load(ckpt, map_location=device))
    model.to(device)
    model.eval()  # disable dropout for deterministic inference
    return model.diary_emb, model.quote_emb, tokenizer


def model_inference(model, tokenizer, text):
    """Embed a text as the encoder's [CLS] hidden state; returns shape (1, hidden)."""
    tokenized_text = tokenizer(text, return_tensors="pt", truncation=True)
    tokenized_text = tokenized_text.to(device)
    with torch.no_grad():  # inference only, no gradients needed
        output = model(**tokenized_text)
    return output[0][:, 0, :]  # [CLS] token of the last hidden state


class Recommender:
    """Recommends a quote for a diary entry via embedding similarity."""

    SIMILARITY_THRESHOLD = 0.8

    def __init__(self, quotes_df, base_model_name, base_model=AutoModel, ckpt=None):
        (self.diary_embedder,
         self.quote_embedder,
         self.tokenizer) = get_models_and_tokenizer(base_model_name, base_model, ckpt)
        # The demo embeds only the first 50 quotes to keep startup fast;
        # quotes and authors are sliced to match so indices stay aligned.
        self.quotes = quotes_df['Quote'].to_list()[:50]
        self.authors = quotes_df['Author'].to_list()[:50]
        # Precompute quote embeddings once; stqdm shows progress in the Streamlit UI.
        self.quote_embeddings = torch.cat(
            [model_inference(self.quote_embedder, self.tokenizer, q).cpu()
             for q in stqdm(self.quotes)]
        )  # shape: (num_quotes, hidden)

    def recommend(self, d):
        """Return (quote, author) for diary entry `d`.

        Picks a random quote among those above the similarity threshold,
        or the single most similar quote if none clears it.
        """
        d_emb = model_inference(self.diary_embedder, self.tokenizer, d).squeeze().cpu()
        # Compare the diary embedding against every quote embedding. dim=1 (not
        # dim=0) so we get one similarity score per quote, not per feature.
        similarities = F.cosine_similarity(d_emb, self.quote_embeddings, dim=1)
        above_threshold_indices = (similarities > self.SIMILARITY_THRESHOLD).nonzero().flatten().tolist()
        if above_threshold_indices:
            index = random.choice(above_threshold_indices)
        else:
            index = torch.argmax(similarities).item()
        return self.quotes[index], self.authors[index]


def get_quote_embeddings(model, tokenizer):
    """Embed every quote in the dataset; returns a (num_quotes, hidden) tensor."""
    quotes = pd.read_csv('quotes-recsys/data/quotes.csv')['Quote'].to_list()
    # torch.cat stacks the (1, hidden) outputs row-wise; torch.tensor() cannot
    # be built directly from a list of multi-element tensors.
    return torch.cat([model_inference(model, tokenizer, q).cpu() for q in quotes])
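

# Minimal usage sketch. Assumptions not confirmed by this repo: the base model
# name below is a hypothetical placeholder (any BERT-family encoder that
# accepts add_pooling_layer=False should work), and quotes.csv is expected to
# have 'Quote' and 'Author' columns as read above.
if __name__ == "__main__":
    quotes_df = pd.read_csv('quotes-recsys/data/quotes.csv')
    recommender = Recommender(
        quotes_df,
        base_model_name='sentence-transformers/all-MiniLM-L6-v2',  # hypothetical choice
        ckpt=None,  # or a path to a fine-tuned DSSM state_dict
    )
    quote, author = recommender.recommend("Today I felt hopeful about the future.")
    print(f'"{quote}" - {author}')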