import random

import numpy as np
import pandas as pd
import torch
from stqdm import stqdm
from torch import nn
from torch.nn import functional as F
from transformers import AutoModel, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class DSSM(nn.Module):
    """Two-tower model: separate encoders for diary entries and quotes."""

    def __init__(self, base_model_name, base_model=AutoModel):
        super().__init__()
        self.diary_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)
        self.quote_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)

    def forward(self, diary, quote):
        return self.diary_emb(**diary), self.quote_emb(**quote)


def get_models_and_tokenizer(base_model_name, base_model=AutoModel, ckpt=None):
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    model = DSSM(base_model_name, base_model=base_model)
    if ckpt:
        print("use ckpt")
        model.load_state_dict(torch.load(ckpt, map_location=device))
    model.to(device)
    return model.diary_emb, model.quote_emb, tokenizer


def model_inference(model, tokenizer, text):
    tokenized_text = tokenizer(text, return_tensors="pt", truncation=True)
    tokenized_text = tokenized_text.to(device)
    with torch.no_grad():  # inference only; no need to build the autograd graph
        output = model(**tokenized_text)
    # Use the [CLS] token of the last hidden state as the sentence embedding.
    return output[0][:, 0, :]


class Recommender:
    SIMILARITY_THRESHOLD = 0.8

    def __init__(self, quotes_df, base_model_name, base_model=AutoModel, ckpt=None):
        (self.diary_embedder,
         self.quote_embedder,
         self.tokenizer) = get_models_and_tokenizer(base_model_name, base_model, ckpt)
        self.quotes = quotes_df['Quote'].to_list()
        self.authors = quotes_df['Author'].to_list()
        # Pre-compute embeddings for the first 50 quotes only (keeps startup
        # fast); as a result, only these 50 quotes can be recommended.
        self.quote_embeddings = torch.tensor(np.array(
            [model_inference(self.quote_embedder, self.tokenizer, q).cpu().numpy()
             for q in stqdm(self.quotes[:50])]
        )).squeeze(1)  # shape: (num_quotes, hidden_size)

    def recommend(self, d):
        d_emb = model_inference(self.diary_embedder, self.tokenizer, d).squeeze().cpu()
        # Compare the diary embedding against every quote embedding. Reduce
        # along the hidden dimension (dim=1) so we get one score per quote;
        # dim=0 would instead mix values across quotes and yield wrong shapes.
        similarities = F.cosine_similarity(d_emb.unsqueeze(0), self.quote_embeddings, dim=1)
        above_threshold_indices = (similarities > self.SIMILARITY_THRESHOLD).nonzero().flatten().tolist()
        if above_threshold_indices:
            # Sample among all sufficiently similar quotes for variety.
            index = random.choice(above_threshold_indices)
        else:
            # No quote clears the threshold: fall back to the closest one.
            index = torch.argmax(similarities).item()
        return self.quotes[index], self.authors[index]


def get_quote_embeddings(model, tokenizer):
    quotes = pd.read_csv('quotes-recsys/data/quotes.csv')['Quote'].to_list()
    # torch.tensor() cannot be built from a list of multi-element tensors,
    # so stack the per-quote embeddings instead.
    return torch.stack([model_inference(model, tokenizer, q).cpu() for q in quotes]).squeeze(1)
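

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module's public API): assumes a CSV
# with 'Quote' and 'Author' columns and a BERT-family base model that accepts
# add_pooling_layer=False. The model name, checkpoint, and diary text below
# are illustrative placeholders, not values this module defines.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    quotes_df = pd.read_csv('quotes-recsys/data/quotes.csv')
    # ckpt=None falls back to the pretrained base weights; pass a trained
    # DSSM state_dict path to use fine-tuned towers instead.
    recommender = Recommender(quotes_df, 'bert-base-uncased', ckpt=None)
    quote, author = recommender.recommend("Today felt heavy, but I kept going.")
    print(f'"{quote}" - {author}')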