Federico Galatolo
first commit
717aa8f
raw
history blame
2.4 kB
import os
import streamlit as st
from elasticsearch import Elasticsearch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import plotly.express as plx
def _short_labels(titles, width=10):
    """Map each distinct title to a short legend label, keeping labels unique.

    Titles longer than ``width`` characters are cut to ``width`` characters and
    suffixed with "..."; shorter titles are kept as-is. If two titles shorten
    to the same text, " (2)", " (3)", ... is appended so they remain separate
    classes in the plot legend.
    """
    labels = {}
    used = set()
    for title in titles:
        if title in labels:
            continue
        base = title if len(title) <= width else title[:width] + "..."
        label = base
        suffix = 2
        while label in used:
            label = f"{base} ({suffix})"
            suffix += 1
        labels[title] = label
        used.add(label)
    return labels


def compare():
    """Project the selected documents' sentence embeddings to 2-D and plot them.

    Reads the module-level widget values ``multiselect``, ``model`` and
    ``limit``, the ``documents`` (label -> id) mapping and the ``es`` client.
    For each selected document it fetches up to ``limit`` sentences from the
    "sentences" index, stacks their ``<model>_features`` vectors, standardizes
    them, runs a cosine-metric t-SNE to two components and draws a plotly
    scatter plot coloured by document with the sentence text on hover.
    """
    if not multiselect:
        st.warning("Select at least one document to compare.")
        return
    target_field = f"{model}_features"

    features = []
    classes = []
    sentences = []
    for title in multiselect:
        doc_id = documents[title]
        result = es.search(
            index="sentences",
            query={
                "constant_score": {
                    "filter": {
                        "term": {
                            "document": doc_id
                        }
                    }
                }
            },
            size=limit,
        )
        hits = result["hits"]["hits"]
        if not hits:
            # An empty result would yield a (0,)-shaped array and make the
            # np.concatenate below fail, so skip it with a notice instead.
            st.warning(f"No sentences found for '{title}'; skipping it.")
            continue
        features.append(np.asarray([hit["_source"][target_field] for hit in hits]))
        classes.extend([title] * len(hits))
        sentences.extend(hit["_source"]["sentence"] for hit in hits)

    n_samples = len(sentences)
    if n_samples < 2:
        st.warning("At least two sentences are needed to compute a t-SNE projection.")
        return

    scaled = StandardScaler().fit_transform(np.concatenate(features))
    # t-SNE requires perplexity < n_samples; 30 is scikit-learn's default,
    # so shrink it only when there are too few sentences.
    tsne = TSNE(
        n_components=2,
        metric="cosine",
        init="pca",
        perplexity=min(30.0, n_samples - 1),
    )
    embedded = tsne.fit_transform(scaled)

    # Short but unique labels: plain truncation could merge two documents
    # sharing the same 10-character prefix into a single legend class.
    labels = _short_labels(multiselect)
    df = pd.DataFrame({
        "x": embedded[:, 0],
        "y": embedded[:, 1],
        "classes": [labels[c] for c in classes],
        "sentences": sentences,
    })
    st.plotly_chart(plx.scatter(
        data_frame=df,
        x="x",
        y="y",
        color="classes",
        hover_name="sentences"
    ))
# Elasticsearch client. ELASTIC_AUTH is read as "user:password"; splitting only
# on the first ":" yields the (user, password) tuple the client documents for
# basic_auth.
es = Elasticsearch(
    os.environ["ELASTIC_HOST"],
    basic_auth=tuple(os.environ["ELASTIC_AUTH"].split(":", 1)),
)

# Elasticsearch returns only 10 hits when `size` is omitted. Request up to the
# default index.max_result_window so every library document can be selected.
MAX_RESULT_WINDOW = 10000
results = es.search(index="documents", query={"match_all": {}}, size=MAX_RESULT_WINDOW)
results = [result["_source"] for result in results["hits"]["hits"]]
# Sidebar label ("title - author") -> document id matched against the
# "document" field of the "sentences" index by compare().
documents = {f"{result['title']} - {result['author']}": result["id"] for result in results}

st.sidebar.title("Semantic compare")
st.sidebar.write("Select 2 or more documents from the SERICA library to semantically compare them")
multiselect = st.sidebar.multiselect("Documents", list(documents.keys()))
model = st.sidebar.selectbox("Model", ["LaBSE"])
# number_input's second positional argument is min_value. Pass value=
# explicitly so 1000 is the default, not a floor. The upper bound keeps the
# per-document search within the default result window.
limit = st.sidebar.number_input(
    "Sentences per document",
    min_value=1,
    max_value=MAX_RESULT_WINDOW,
    value=1000,
)
st.sidebar.button("Compare", on_click=compare)