Spaces:
Sleeping
Sleeping
Add demo
Browse files- .gitignore +2 -0
- app.py +34 -0
- data/quotes.csv +0 -0
- requirements.txt +3 -0
- utils.py +68 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
__pycache__/
|
app.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from utils import Recommender
|
3 |
+
import pandas as pd
|
4 |
+
from transformers import RobertaModel
|
5 |
+
|
6 |
+
st.title("Quotes RecSys")


@st.cache_resource
def get_recommender():
    """Build the quote recommender once per server process (cached by Streamlit)."""
    quotes_df = pd.read_csv('data/quotes.csv')
    return Recommender(
        quotes_df,
        "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest",
        base_model=RobertaModel,
        ckpt="models/twitter.pt",
    )


recommender = get_recommender()

# Conversation history is kept in session state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation recorded so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("What is up?"):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Look up the best-matching quote and render it as a blockquote + author.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        quote, author = recommender.recommend(prompt)
        full_response = f"> {quote}\n\n _{author}_"
        message_placeholder.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})
|
data/quotes.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
torch
|
3 |
+
transformers
|
utils.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
from stqdm import stqdm
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from transformers import AutoTokenizer, AutoModel
|
9 |
+
|
# Prefer the GPU when one is available; models and tensors are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
11 |
+
|
12 |
+
|
13 |
+
class DSSM(nn.Module):
    """Two-tower (dual-encoder) model.

    One transformer encodes diary-style text, the other encodes quotes,
    so each text type gets its own embedding space.
    """

    def __init__(self, base_model_name, base_model=AutoModel):
        super().__init__()
        # add_pooling_layer=False: the CLS hidden state is extracted downstream.
        self.diary_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)
        self.quote_emb = base_model.from_pretrained(base_model_name, add_pooling_layer=False)

    def forward(self, diary, quote):
        """Encode both inputs; returns (diary_output, quote_output)."""
        diary_out = self.diary_emb(**diary)
        quote_out = self.quote_emb(**quote)
        return diary_out, quote_out
|
21 |
+
|
22 |
+
|
23 |
+
def get_models_and_tokenizer(base_model_name, base_model=AutoModel, ckpt=None):
    """Build a DSSM and its tokenizer, optionally restoring trained weights.

    Args:
        base_model_name: Hugging Face model id/path for both towers.
        base_model: transformers model class used to instantiate the towers.
        ckpt: optional path to a saved DSSM state dict.

    Returns:
        (diary_encoder, quote_encoder, tokenizer) with the encoders on `device`.
    """
    model = DSSM(base_model_name, base_model=base_model)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if ckpt:
        print("use ckpt")
        model.load_state_dict(torch.load(ckpt, map_location=device))
    model.to(device)
    return model.diary_emb, model.quote_emb, tokenizer
|
31 |
+
|
32 |
+
|
33 |
+
def model_inference(model, tokenizer, text):
    """Embed `text` with `model` and return its CLS-token vector.

    Args:
        model: a transformer encoder whose first output is the last hidden state.
        tokenizer: matching Hugging Face tokenizer.
        text: the string to embed (truncated to the model's max length).

    Returns:
        Tensor of shape (1, hidden_size): the first-token hidden state of the
        last layer.
    """
    tokenized_text = tokenizer(text, return_tensors="pt", truncation=True)
    tokenized_text = tokenized_text.to(device)
    # Fix: inference needs no gradients; running under no_grad avoids building
    # and retaining the autograd graph for every embedded text.
    with torch.no_grad():
        output = model(**tokenized_text)
    # output[0] is the last hidden state: (batch, seq_len, hidden).
    return output[0][:, 0, :]
|
38 |
+
|
39 |
+
|
40 |
+
class Recommender:
    """Recommends a quote for a diary-style prompt by embedding similarity."""

    # Minimum cosine similarity for a quote to count as a good match.
    SIMILARITY_THRESHOLD = 0.8

    def __init__(self, quotes_df, base_model_name, base_model=AutoModel, ckpt=None):
        """
        Args:
            quotes_df: DataFrame with 'Quote' and 'Author' columns.
            base_model_name: Hugging Face model id/path for both towers.
            base_model: transformers model class used to build the towers.
            ckpt: optional path to a DSSM state-dict checkpoint.
        """
        (self.diary_embedder,
         self.quote_embedder,
         self.tokenizer) = get_models_and_tokenizer(base_model_name, base_model, ckpt)

        self.quotes = quotes_df['Quote'].to_list()
        self.authors = quotes_df['Author'].to_list()

        # NOTE(review): only the first 50 quotes are embedded (demo-sized
        # index), so only those can ever be recommended.
        self.quote_embeddings = torch.tensor(np.array(
            [model_inference(self.quote_embedder, self.tokenizer, q).cpu().detach().numpy()
             for q in stqdm(self.quotes[:50])]
        )).squeeze(1)  # shape: (num_quotes, hidden)

    def recommend(self, d):
        """Return (quote, author) best matching diary text `d`.

        Picks uniformly at random among quotes whose cosine similarity to `d`
        exceeds SIMILARITY_THRESHOLD; if none qualify, falls back to the single
        most similar quote.
        """
        d_emb = model_inference(self.diary_embedder, self.tokenizer, d).squeeze().cpu()
        # Fix: compare the (hidden,) query against each row of the
        # (num_quotes, hidden) matrix along the feature axis (dim=1), giving
        # one score per quote. The original dim=0 reduced over the quote axis
        # after broadcasting, producing hidden_size scores whose indices could
        # select the wrong quote or run past the embedded set.
        similarities = F.cosine_similarity(d_emb.unsqueeze(0), self.quote_embeddings, dim=1)
        above_threshold_indices = (similarities > self.SIMILARITY_THRESHOLD).nonzero().flatten().tolist()
        if above_threshold_indices:
            index = random.choice(above_threshold_indices)
        else:
            index = torch.argmax(similarities).item()
        return self.quotes[index], self.authors[index]
|
64 |
+
|
65 |
+
|
66 |
+
def get_quote_embeddings(model, tokenizer, csv_path='quotes-recsys/data/quotes.csv'):
    """Embed every quote in the CSV at `csv_path` with `model`.

    Args:
        model: quote-tower encoder (see `model_inference`).
        tokenizer: matching tokenizer.
        csv_path: CSV with a 'Quote' column; default kept for compatibility.

    Returns:
        Tensor of shape (num_quotes, hidden) of CLS embeddings.
    """
    quotes = pd.read_csv(csv_path)['Quote'].to_list()
    # Fix: torch.tensor() cannot build a tensor from a list of multi-element
    # tensors; stack the (1, hidden) outputs and drop the batch dim instead.
    embeddings = torch.stack(
        [model_inference(model, tokenizer, q).detach() for q in quotes]
    )
    return embeddings.squeeze(1)
|