# Hugging Face Spaces page banner ("Spaces: Sleeping") captured by extraction —
# kept as a comment; not part of the application code.
import random

import numpy as np
import pandas as pd
import torch
from stqdm import stqdm
from torch import nn
from torch.nn import functional as F
from transformers import AutoModel, AutoTokenizer

# Run on GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class DSSM(nn.Module):
    """Dual-encoder (DSSM-style) model.

    Holds two independent transformer encoders initialised from the same
    pretrained checkpoint: one embeds diary entries, the other embeds quotes.
    The pooling layer is disabled; downstream code reads hidden states directly.
    """

    def __init__(self, base_model_name, base_model=AutoModel):
        super().__init__()
        # Separate weights per tower so diary and quote encoders can
        # specialise independently during fine-tuning.
        self.diary_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)
        self.quote_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)

    def forward(self, diary, quote):
        """Encode a tokenized (diary, quote) pair; returns both raw outputs."""
        diary_out = self.diary_emb(**diary)
        quote_out = self.quote_emb(**quote)
        return diary_out, quote_out
def get_models_and_tokenizer(base_model_name, base_model=AutoModel, ckpt=None):
    """Build a DSSM and its tokenizer, optionally restoring fine-tuned weights.

    Args:
        base_model_name: pretrained model identifier / path.
        base_model: transformers auto-class used to load the encoders.
        ckpt: optional path to a saved ``state_dict`` to load.

    Returns:
        (diary_encoder, quote_encoder, tokenizer) — the two towers of the
        DSSM, moved to the module-level ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    model = DSSM(base_model_name, base_model=base_model)
    if ckpt:
        print("use ckpt")
        state = torch.load(ckpt, map_location=device)
        model.load_state_dict(state)
    model.to(device)
    return model.diary_emb, model.quote_emb, tokenizer
def model_inference(model, tokenizer, text,
                    device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Embed ``text`` with ``model`` and return the first-token hidden state.

    Args:
        model: encoder whose output at index 0 is the last hidden state of
            shape (1, seq_len, dim).
        tokenizer: callable returning a batch that supports ``.to(device)``
            and keyword expansion into the model.
        text: input string; truncated to the model's maximum length.
        device: target device. Defaults to CUDA when available (same value
            the module-level ``device`` constant would hold), so existing
            callers are unaffected.

    Returns:
        Tensor of shape (1, dim): the hidden state of the first token
        (the [CLS] position for BERT-style models — confirm for other bases).
    """
    tokenized_text = tokenizer(text, return_tensors="pt", truncation=True)
    tokenized_text = tokenized_text.to(device)
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(**tokenized_text)
    return output[0][:, 0, :]
class Recommender:
    """Recommends a quote for a diary entry via embedding cosine similarity."""

    # Minimum cosine similarity for a quote to enter the random-choice pool;
    # if no quote clears it, the single best match is returned instead.
    SIMILARITY_THRESHOLD = 0.8

    def __init__(self, quotes_df, base_model_name, base_model=AutoModel,
                 ckpt=None, max_quotes=50):
        """Load the two-tower model and pre-embed the quote corpus.

        Args:
            quotes_df: DataFrame with 'Quote' and 'Author' columns.
            base_model_name: pretrained model identifier / path.
            base_model: transformers auto-class used to load the encoders.
            ckpt: optional fine-tuned checkpoint path.
            max_quotes: number of quotes to pre-embed (previously a
                hard-coded 50, presumably to keep startup fast).
        """
        (self.diary_embedder,
         self.quote_embedder,
         self.tokenizer) = get_models_and_tokenizer(base_model_name, base_model, ckpt)
        self.quotes = quotes_df['Quote'].to_list()
        self.authors = quotes_df['Author'].to_list()
        # Pre-compute quote embeddings once; result shape is (N, dim).
        self.quote_embeddings = torch.tensor(np.array(
            [model_inference(self.quote_embedder, self.tokenizer, q).cpu().detach().numpy()
             for q in stqdm(self.quotes[:max_quotes])]
        )).squeeze(1)

    def recommend(self, d):
        """Return a (quote, author) pair matching diary text ``d``.

        Picks uniformly among quotes above SIMILARITY_THRESHOLD, or the
        single most similar quote if none clears the threshold.
        """
        d_emb = model_inference(self.diary_embedder, self.tokenizer, d).squeeze().cpu()
        # BUG FIX: reduce over the embedding dimension (dim=1) so we get one
        # score per quote, shape (N,). The previous dim=0 reduced over the
        # quote axis instead, yielding `dim`-many meaningless scores whose
        # argmax could index past the quote list.
        similarities = F.cosine_similarity(
            d_emb.unsqueeze(0), self.quote_embeddings, dim=1)
        above_threshold = (similarities > self.SIMILARITY_THRESHOLD).nonzero().flatten().tolist()
        if above_threshold:
            index = random.choice(above_threshold)
        else:
            index = torch.argmax(similarities).item()
        return self.quotes[index], self.authors[index]
def get_quote_embeddings(model, tokenizer, csv_path='quotes-recsys/data/quotes.csv'):
    """Embed every quote in a CSV file.

    Args:
        model: quote encoder passed through to ``model_inference``.
        tokenizer: matching tokenizer.
        csv_path: CSV with a 'Quote' column (default preserves the
            previously hard-coded path).

    Returns:
        Tensor of shape (N, dim) with one embedding per quote.
    """
    quotes = pd.read_csv(csv_path)['Quote'].to_list()
    # BUG FIX: torch.tensor() rejects a list of multi-element tensors
    # ("only one element tensors can be converted..."); torch.stack joins
    # the (1, dim) embeddings into (N, 1, dim), squeezed to (N, dim).
    embeddings = torch.stack([model_inference(model, tokenizer, q) for q in quotes])
    return embeddings.squeeze(1)