batalovme committed on
Commit
50bcd75
1 Parent(s): d857dec
Files changed (5) hide show
  1. .gitignore +2 -0
  2. app.py +34 -0
  3. data/quotes.csv +0 -0
  4. requirements.txt +3 -0
  5. utils.py +68 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .DS_Store
2
+ __pycache__/
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils import Recommender
3
+ import pandas as pd
4
+ from transformers import RobertaModel
5
+
6
+ st.title("Quotes RecSys")
7
+
8
@st.cache_resource
def get_recommender():
    """Build the quote recommender once and cache it for the app's lifetime.

    Loads the quotes CSV and a fine-tuned RoBERTa two-tower checkpoint.
    """
    quotes_df = pd.read_csv('data/quotes.csv')
    return Recommender(
        quotes_df,
        "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest",
        base_model=RobertaModel,
        ckpt="models/twitter.pt",
    )
13
+
14
+
15
recommender = get_recommender()

# First visit in this session: start an empty chat history.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation recorded so far.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Handle a newly submitted user message, if any.
if prompt := st.chat_input("What is up?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        # Recommend a quote matching the user's message and render it
        # as a blockquote with the author in italics.
        quote, author = recommender.recommend(prompt)
        full_response = f"> {quote}\n\n _{author}_"
        message_placeholder.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})
data/quotes.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch
3
+ transformers
utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from stqdm import stqdm
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from transformers import AutoTokenizer, AutoModel
9
+
10
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+
12
+
13
class DSSM(nn.Module):
    """Two-tower (dual-encoder) model: one encoder for diary entries,
    one for quotes, both initialised from the same pretrained backbone.
    """

    def __init__(self, base_model_name, base_model=AutoModel):
        super().__init__()
        # The towers do not share weights; each gets its own copy of the
        # pretrained backbone (pooling head dropped — we use raw hidden states).
        self.diary_emb = base_model.from_pretrained(
            base_model_name, add_pooling_layer=False)
        self.quote_emb = base_model.from_pretrained(
            base_model_name, add_pooling_layer=False)

    def forward(self, diary, quote):
        # Encode each side independently; no cross-attention between towers.
        return self.diary_emb(**diary), self.quote_emb(**quote)
21
+
22
+
23
def get_models_and_tokenizer(base_model_name, base_model=AutoModel, ckpt=None):
    """Build the DSSM towers and tokenizer, optionally restoring a checkpoint.

    Returns a (diary_encoder, quote_encoder, tokenizer) triple; the encoders
    are moved to the module-level `device`.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    dssm = DSSM(base_model_name, base_model=base_model)
    if ckpt:
        print("use ckpt")
        dssm.load_state_dict(torch.load(ckpt, map_location=device))
    dssm.to(device)
    return dssm.diary_emb, dssm.quote_emb, tokenizer
31
+
32
+
33
def model_inference(model, tokenizer, text):
    """Embed `text` with `model` and return the first-token hidden state.

    Result shape is (1, hidden_dim) for a single input string; the input is
    truncated to the tokenizer's max length and moved to `device`.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    hidden_states = model(**encoded)[0]
    # The first token's last hidden state serves as the sequence embedding.
    return hidden_states[:, 0, :]
38
+
39
+
40
class Recommender:
    """Recommends the quote whose embedding best matches a diary entry.

    Quote embeddings are precomputed once at construction; `recommend`
    embeds the diary text and compares by cosine similarity.
    """

    # Quotes scoring above this are considered good matches; one is
    # picked at random to add variety.
    SIMILARITY_THRESHOLD = 0.8

    def __init__(self, quotes_df, base_model_name, base_model=AutoModel,
                 ckpt=None, max_quotes=50):
        """Load models and precompute embeddings for up to `max_quotes` quotes.

        `quotes_df` must have 'Quote' and 'Author' columns. `max_quotes`
        caps the embedding work (default 50, matching prior behavior).
        """
        (self.diary_embedder,
         self.quote_embedder,
         self.tokenizer) = get_models_and_tokenizer(base_model_name, base_model, ckpt)

        # Keep quotes/authors aligned with the embeddings actually computed;
        # previously the full lists were kept while only 50 were embedded.
        self.quotes = quotes_df['Quote'].to_list()[:max_quotes]
        self.authors = quotes_df['Author'].to_list()[:max_quotes]

        # Shape (num_quotes, hidden_dim); stack directly instead of the
        # tensor -> numpy -> tensor round-trip.
        self.quote_embeddings = torch.stack([
            model_inference(self.quote_embedder, self.tokenizer, q)
            .squeeze(0).detach().cpu()
            for q in stqdm(self.quotes)
        ])

    def recommend(self, d):
        """Return a (quote, author) pair matching diary text `d`.

        Picks randomly among quotes above SIMILARITY_THRESHOLD, falling back
        to the single best match when none clears the threshold.
        """
        d_emb = model_inference(self.diary_embedder, self.tokenizer, d).squeeze().cpu()
        # BUG FIX: reduce over the embedding axis (dim=-1) so we get one score
        # per quote. The previous dim=0 reduced over the *quote* axis, yielding
        # hidden_dim meaningless scores whose indices could exceed len(quotes).
        similarities = F.cosine_similarity(d_emb, self.quote_embeddings, dim=-1)
        above_threshold_indices = (
            similarities > self.SIMILARITY_THRESHOLD
        ).nonzero().flatten().tolist()
        if above_threshold_indices:
            index = random.choice(above_threshold_indices)
        else:
            index = torch.argmax(similarities).item()
        return self.quotes[index], self.authors[index]
64
+
65
+
66
def get_quote_embeddings(model, tokenizer):
    """Embed every quote in the CSV; returns a (num_quotes, hidden_dim) tensor.

    NOTE(review): the hard-coded path 'quotes-recsys/data/quotes.csv' differs
    from app.py's 'data/quotes.csv' — confirm which location is intended.
    """
    quotes = pd.read_csv('quotes-recsys/data/quotes.csv')['Quote'].to_list()
    # BUG FIX: torch.tensor() cannot be built from a list of tensors (it
    # raises ValueError); concatenate the (1, hidden_dim) outputs along the
    # batch axis instead. Detach and move to CPU so no graph is retained.
    return torch.cat(
        [model_inference(model, tokenizer, q).detach().cpu() for q in quotes]
    )