Trent committed
Commit • 6ae27e8
1 Parent(s): 6c6e636

Port demo
Files changed:
- app.py +73 -0
- backend/config.py +3 -0
- backend/inference.py +41 -0
- backend/main.py +19 -0
app.py
ADDED
@@ -0,0 +1,73 @@
import streamlit as st
import pandas as pd
import base64
import requests

st.title('Demo using Flax-Sentence-Transformers')

st.sidebar.title('')

st.markdown('''
Hi! This is the demo for the [flax sentence embeddings](https://huggingface.co/flax-sentence-embeddings) created for the **Flax/JAX community week 🤗**. We are going to use three flax-sentence-embeddings models: a **distilroberta base**, an **mpnet base** and a **minilm-l6**. All were trained on the full 1B+ sentence-pairs training corpus with the v3 setup.

---

**Instructions**: You can compare the similarity of a main text with other texts of your choice (in the sidebar). In the background, we create an embedding for each text and then use cosine similarity to compute a similarity score between the main sentence and each of the others.

For more background on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).

Please enjoy!!
''')


anchor = st.text_input(
    'Please enter the main text you want to compare:'
)

if anchor:
    n_texts = st.sidebar.number_input(
        f'''How many texts do you want to compare with: '{anchor}'?''',
        value=2,
        min_value=2)

    inputs = []

    for i in range(n_texts):
        # Collect each comparison text from the sidebar.
        text = st.sidebar.text_input(f'Text {i+1}:')
        inputs.append(text)


api_base_url = 'http://127.0.0.1:8000/similarity'

if anchor:
    if st.sidebar.button('Tell me the similarity.'):
        # Query the FastAPI backend once per model.
        res_distilroberta = requests.get(url=api_base_url,
                                         params=dict(anchor=anchor,
                                                     inputs=inputs,
                                                     model='distilroberta'))
        res_mpnet = requests.get(url=api_base_url,
                                 params=dict(anchor=anchor,
                                             inputs=inputs,
                                             model='mpnet'))
        res_minilm_l6 = requests.get(url=api_base_url,
                                     params=dict(anchor=anchor,
                                                 inputs=inputs,
                                                 model='minilm_l6'))

        d_distilroberta = res_distilroberta.json()['dataframe']
        d_mpnet = res_mpnet.json()['dataframe']
        d_minilm_l6 = res_minilm_l6.json()['dataframe']

        # Each response is sorted by its own scores, so align every model's
        # scores with the same input text before building the combined table.
        index = list(d_distilroberta['inputs'].values())
        df_total = pd.DataFrame(index=index)
        df_total['distilroberta'] = list(d_distilroberta['score'].values())
        scores_mpnet = dict(zip(d_mpnet['inputs'].values(), d_mpnet['score'].values()))
        scores_minilm_l6 = dict(zip(d_minilm_l6['inputs'].values(), d_minilm_l6['score'].values()))
        df_total['mpnet'] = [scores_mpnet[text] for text in index]
        df_total['minilm_l6'] = [scores_minilm_l6[text] for text in index]

        st.write('Here are the results for our three models:')
        st.write(df_total)
        st.write('Visualize the results of each model:')
        st.area_chart(df_total)
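For reference, here is a minimal sketch of the request this frontend issues, assuming the FastAPI backend from backend/main.py below is already running at http://127.0.0.1:8000; the anchor and comparison texts are made-up examples.

import requests

# Hypothetical texts, only to illustrate the query-string format:
# `inputs` is sent once per comparison text.
params = dict(anchor='How do I sort a list in Python?',
              inputs=['Use the sorted() builtin', 'The weather is nice today'],
              model='distilroberta')
response = requests.get('http://127.0.0.1:8000/similarity', params=params)

# The backend answers with the scores table serialised as nested dicts,
# e.g. {'dataframe': {'inputs': {...}, 'score': {...}}}.
print(response.json()['dataframe'])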
backend/config.py
ADDED
@@ -0,0 +1,3 @@
# Hugging Face Hub ids of the sentence-embedding models used by the backend.
MODELS_ID = dict(distilroberta='flax-sentence-embeddings/st-codesearch-distilroberta-base',
                 mpnet='flax-sentence-embeddings/all_datasets_v3_mpnet-base',
                 minilm_l6='flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
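As a quick sanity check of these ids, any of the checkpoints can be loaded directly with sentence-transformers; a minimal sketch (the example sentence is made up, and the printed shape is what one would expect for the mpnet base model):

from sentence_transformers import SentenceTransformer

from config import MODELS_ID

# Load one checkpoint from the Hugging Face Hub and embed a single sentence.
model = SentenceTransformer(MODELS_ID['mpnet'])
embedding = model.encode('Sentence embeddings map text to fixed-size vectors.')
print(embedding.shape)  # expected: (768,) for the mpnet base model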
backend/inference.py
ADDED
@@ -0,0 +1,41 @@
from sentence_transformers import SentenceTransformer
import pandas as pd
import jax.numpy as jnp

from typing import List
import config

# We download the models we will be using.
# If you do not want to use all of them, you can comment out the unused ones.
distilroberta_model = SentenceTransformer(config.MODELS_ID['distilroberta'])
mpnet_model = SentenceTransformer(config.MODELS_ID['mpnet'])
minilm_l6_model = SentenceTransformer(config.MODELS_ID['minilm_l6'])


# Cosine similarity implemented with jax.numpy.
def cos_sim(a, b):
    # Normalise each row so that every dot product is a true cosine similarity,
    # even when `b` holds several input embeddings.
    a_norm = a / jnp.linalg.norm(a, axis=-1, keepdims=True)
    b_norm = b / jnp.linalg.norm(b, axis=-1, keepdims=True)
    return jnp.matmul(a_norm, jnp.transpose(b_norm))


# Similarity between the anchor and each of the inputs with the chosen model.
def text_similarity(anchor: str, inputs: List[str], model: str = 'distilroberta'):

    # Create the embeddings: anchor as a (1, dim) array, inputs as (n, dim).
    if model == 'distilroberta':
        anchor_emb = distilroberta_model.encode(anchor)[None, :]
        inputs_emb = distilroberta_model.encode(list(inputs))
    elif model == 'mpnet':
        anchor_emb = mpnet_model.encode(anchor)[None, :]
        inputs_emb = mpnet_model.encode(list(inputs))
    elif model == 'minilm_l6':
        anchor_emb = minilm_l6_model.encode(anchor)[None, :]
        inputs_emb = minilm_l6_model.encode(list(inputs))
    else:
        raise ValueError(f'Unknown model: {model}')

    # Similarity of the anchor against every input, as a flat list of scores.
    similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb), axis=0))

    # Return a pandas DataFrame sorted by score, best match first.
    d = {'inputs': list(inputs),
         'score': [round(float(s), 3) for s in similarity]}
    df = pd.DataFrame(d, columns=['inputs', 'score'])

    return df.sort_values('score', ascending=False)
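To make the cosine-similarity step concrete, here is a tiny self-contained sketch that applies the same row-normalised formula to made-up 3-dimensional vectors (importing inference itself would download all three models, so the function is repeated inline):

import jax.numpy as jnp

def cos_sim(a, b):
    # Same formula as inference.cos_sim: row-normalise, then take dot products.
    a_norm = a / jnp.linalg.norm(a, axis=-1, keepdims=True)
    b_norm = b / jnp.linalg.norm(b, axis=-1, keepdims=True)
    return jnp.matmul(a_norm, jnp.transpose(b_norm))

# Toy stand-ins for an anchor embedding (1, 3) and two input embeddings (2, 3).
anchor = jnp.array([[1.0, 0.0, 0.0]])
inputs = jnp.array([[2.0, 0.0, 0.0],    # same direction       -> similarity 1.0
                    [0.0, 1.0, 0.0]])   # orthogonal direction -> similarity 0.0

print(cos_sim(anchor, inputs))  # [[1. 0.]]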
backend/main.py
ADDED
@@ -0,0 +1,19 @@
from fastapi import Query, FastAPI

import config
import inference
from typing import List

app = FastAPI()


@app.get("/")
def read_root():
    return {"message": "Welcome to the API of flax-sentence-embeddings."}


@app.get('/similarity')
def get_similarity(anchor: str, inputs: List[str] = Query([]), model: str = 'distilroberta'):
    # Convert the DataFrame to plain dicts so the response serialises cleanly;
    # the frontend reads response.json()['dataframe']['inputs' / 'score'].
    return {'dataframe': inference.text_similarity(anchor, inputs, model).to_dict()}


# To run the API locally (app.py expects it at http://127.0.0.1:8000):
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run("main:app", host="0.0.0.0", port=8000)
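A minimal way to exercise this endpoint without starting a server is FastAPI's TestClient; the sketch below assumes the test-client dependency that ships with FastAPI is installed, the query texts are made up, and importing main will load all three models.

from fastapi.testclient import TestClient

from main import app

client = TestClient(app)

# Hypothetical query: `inputs` is repeated once per comparison text.
response = client.get('/similarity',
                      params=[('anchor', 'I love JAX'),
                              ('inputs', 'JAX makes array code fast'),
                              ('inputs', 'Completely unrelated sentence'),
                              ('model', 'minilm_l6')])
print(response.json()['dataframe'])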