FremyCompany's picture
Update app.py
f22721a verified
raw
history blame
16.7 kB
import streamlit as st
st.set_page_config(layout="wide")
DEBUG = True
if DEBUG:
# Use some dummy data
tokens = [
["[0]\u2581De", "[0]\u2581fait", "[0],", "[0]\u2581mon", "[0]\u2581mari", "[0]\u2581ne", "[0]\u2581parlait", "[0]\u2581jamais", "[0]\u2581de", "[0]\u2581ses", "[0]\u2581affaires", "[0]\u2581avec", "[0]\u2581moi", "[0]."],
["[1]\u2581M", "[1]'", "[1]n", "[1]\u2581man", "[1]\u2581had", "[1]\u2581het", "[1]\u2581met", "[1]\u2581mij", "[1]\u2581nooit", "[1]\u2581over", "[1]\u2581z", "[1]'", "[1]n", "[1]\u2581zaken", "[1],", "[1]\u2581inderdaad", "[1]."],
]
sentence_similarity = 0.7
token_probabilities_12 = [[0.0001786323555279523, 0.029554476961493492, 0.0005334240850061178, 9.950459934771061e-05, 0.0069955079816281796, 0.002180722774937749, 0.0011584730818867683, 0.00024079506692942232, 0.004674288909882307, 0.0028222708497196436, 0.0010019985493272543, 0.04594381898641586, 0.005427392199635506, 0.0003972635604441166, 0.16267944872379303, 0.7361112833023071, 7.352152806561207e-07], [8.627762326796073e-06, 0.00021197207388468087, 4.6174409362720326e-05, 9.608593245502561e-06, 0.0007827585795894265, 0.00014010778977535665, 0.00010562615352682769, 2.1056943296571262e-05, 0.0001730725634843111, 0.00014431933232117444, 2.873636913136579e-05, 0.00034569500712677836, 0.00012183879152871668, 5.0419181206962094e-05, 0.0021323736291378736, 0.9956766963005066, 9.747334388521267e-07], [0.000363957486115396, 0.058086857199668884, 0.0010622636182233691, 0.0001401746121700853, 0.03782657906413078, 0.010060806758701801, 0.003844393650069833, 0.00017795785970520228, 0.006151353009045124, 0.004802143666893244, 0.00045746073010377586, 0.13868938386440277, 0.004628518130630255, 0.0001305590703850612, 0.7329095005989075, 0.0006681070663034916, 2.742818416834325e-08], [0.9904406070709229, 0.0029372733552008867, 0.002740614116191864, 0.00016679026884958148, 7.625997386639938e-05, 0.00013673932699020952, 3.424803799134679e-05, 0.0018435337115079165, 2.9864917451050133e-05, 3.524687417666428e-05, 0.0011246443027630448, 0.00029081859975121915, 0.00010983269021380693, 1.0424302672618069e-05, 2.242830305476673e-05, 6.100983682699734e-07, 1.5642478956578998e-07], [0.0014130950439721346, 0.0006491155363619328, 0.0039030632469803095, 0.9928011894226074, 3.293847112217918e-05, 0.00023004157992545515, 4.8409721784992144e-05, 1.1099789844593033e-05, 5.912710548727773e-06, 2.8217813451192342e-05, 0.0006040701409801841, 0.0001110830926336348, 5.0306276534684e-05, 8.831957529764622e-05, 1.970089942915365e-05, 1.0986423149006441e-06, 2.326722096768208e-06], 
[8.804490062175319e-05, 0.020074667409062386, 0.0004638279788196087, 6.011720688547939e-05, 0.3746979236602783, 0.018972501158714294, 0.0019666561856865883, 0.0003945657517760992, 0.5129392743110657, 0.010681135579943657, 0.0003376945969648659, 0.04529218748211861, 0.0050969235599040985, 0.00018651205755304545, 0.008603758178651333, 0.00014415399346034974, 7.929369161274735e-08], [0.00021422369172796607, 0.005102124996483326, 0.0003292255278211087, 0.0009069664520211518, 0.08789367973804474, 0.15337994694709778, 0.0511351116001606, 0.0014941217377781868, 0.003102638525888324, 0.6176480650901794, 0.0024865365121513605, 0.01969054341316223, 0.025343414396047592, 0.02977576106786728, 0.0013427335070446134, 0.00012596377928275615, 2.9010154321440496e-05], [1.408028879268386e-06, 0.00013552005111705512, 6.736800969520118e-06, 1.0955036486848257e-06, 0.0011538203107193112, 0.00019907338719349355, 3.0282362786238082e-05, 1.2565109500428662e-05, 0.9977654218673706, 0.00013098378258291632, 1.2177627468190622e-05, 0.0003589813131839037, 0.00010541738447500393, 7.141510650399141e-06, 6.763384590158239e-05, 1.1701029507094063e-05, 1.9685167274019477e-08], [5.773245902673807e-06, 0.0014488865854218602, 2.6845693355426192e-05, 1.0805002602864988e-05, 0.002086219610646367, 0.01130380667746067, 0.001883843680843711, 2.9443286621244624e-05, 0.00014186602493282408, 0.8935705423355103, 0.0006889772485010326, 0.016468364745378494, 0.0685788094997406, 0.0029715588316321373, 0.0007683428702875972, 1.5869531125645153e-05, 9.061647432417885e-09], [0.012304342351853848, 0.03311317041516304, 0.007781223859637976, 0.004408247768878937, 0.002244020812213421, 0.04515384882688522, 0.0010716691613197327, 0.0008402583771385252, 0.00038856116589158773, 0.003114827908575535, 0.7942622303962708, 0.026046255603432655, 0.0659388080239296, 0.002598112216219306, 0.0007164151757024229, 1.7107675375882536e-05, 8.561248137084476e-07], [1.8507088270780514e-06, 8.229719242081046e-05, 9.083240001928061e-06, 
2.4296445189975202e-05, 6.340398158499738e-06, 0.0001343603798886761, 5.143981979927048e-06, 1.8609456446938566e-06, 1.2062999985573697e-06, 0.0006211695144884288, 0.0004705676983576268, 0.0002221767936134711, 0.008940674364566803, 0.9894717335700989, 6.331008989945985e-06, 6.910536853865779e-07, 1.8326457507100713e-07], [9.844085070653819e-06, 0.00017242366448044777, 1.5286375855794176e-05, 9.348634193884209e-06, 0.0001390141696901992, 0.0014548851177096367, 0.9944896697998047, 2.612253047118429e-05, 1.680908098933287e-05, 0.001790934125892818, 0.00016232591588050127, 0.0006340526742860675, 0.0008499748073518276, 0.0001129394950112328, 0.00011405935219954699, 2.3224697542900685e-06, 2.073006299951885e-08], [0.0001119379885494709, 4.149791129748337e-05, 4.742472356156213e-06, 2.8589572593773482e-06, 7.517372523579979e-06, 0.00013416734873317182, 2.4442895664833486e-05, 0.9989697933197021, 1.3636500625580084e-05, 2.8603359169210307e-05, 0.000470715545816347, 6.383401341736317e-05, 0.00010340050357626751, 1.8945609554066323e-05, 3.2612724680802785e-06, 4.755758311603131e-07, 2.1897558610817214e-07], [1.182226538887221e-09, 2.612263738654974e-10, 4.270485631785448e-10, 3.19147619443072e-09, 1.5840342926232154e-10, 1.1831424728825368e-09, 5.845964268225146e-10, 5.307323469594394e-09, 1.2535458226992091e-09, 7.667128598676243e-10, 6.178164646541973e-09, 8.621070524128527e-11, 6.898879245653688e-10, 1.4480447951825681e-08, 4.186111873805132e-11, 1.1763671148301569e-09, 1.0]]
token_probabilities_21 = [[6.291573299677111e-06, 1.2251668977114605e-06, 2.580553882580716e-05, 0.9942325949668884, 0.0006945801433175802, 2.1142666810192168e-05, 7.554778676421847e-06, 1.1657297363854013e-05, 1.2746155334752984e-05, 0.00413579260930419, 1.4904621821187902e-05, 3.4593394957482815e-05, 0.0008008066797628999, 2.480075806943205e-07], [0.03415770083665848, 0.0009877384873107076, 0.13514696061611176, 0.09675426036119461, 0.01046981941908598, 0.15818676352500916, 0.005904339253902435, 0.036817632615566254, 0.10496868193149567, 0.3652307093143463, 0.02174874022603035, 0.01988295465707779, 0.009741867892444134, 1.7982374629355036e-06], [0.0024134248960763216, 0.0008422881946898997, 0.009675124660134315, 0.35340237617492676, 0.24644418060779572, 0.014307855628430843, 0.001491453149355948, 0.007164766546338797, 0.007613700814545155, 0.33597758412361145, 0.009396923705935478, 0.006900560110807419, 0.004358294885605574, 1.150809748651227e-05], [7.152517355279997e-06, 2.7846685952681582e-06, 2.028374728979543e-05, 0.0003417014959268272, 0.995932400226593, 2.9462620659614913e-05, 6.527722143800929e-05, 1.8510456357034855e-05, 4.868561154580675e-05, 0.003024008125066757, 0.0003993394784629345, 6.704749830532819e-05, 4.174213609076105e-05, 1.3663803883900982e-06], [0.0022054731380194426, 0.0009949662489816546, 0.024007266387343407, 0.0006852351943962276, 0.0001449233095627278, 0.8054160475730896, 0.02774566411972046, 0.08550822734832764, 0.04122897982597351, 0.006751661188900471, 0.00045707033132202923, 0.004372807219624519, 0.00048139269347302616, 2.974482242734666e-07], [0.0012809201143682003, 0.0003318041271995753, 0.011896451003849506, 0.00228915479965508, 0.0018857300747185946, 0.07598057389259338, 0.09020800143480301, 0.027486657723784447, 0.416204035282135, 0.25311478972435, 0.0180458165705204, 0.08526463806629181, 0.01600734144449234, 4.139269094594056e-06], [1.1649769476207439e-05, 4.282531790522626e-06, 7.782561442581937e-05, 9.815836165216751e-06, 
6.7938508436782286e-06, 0.00013483928341884166, 0.0005148796481080353, 7.158283551689237e-05, 0.0011875078780576587, 0.00010284745076205581, 1.1828081369458232e-05, 0.9978162050247192, 4.9926951760426164e-05, 3.501482837009462e-08], [1.1862812243634835e-06, 4.1824714003269037e-07, 1.7648992525209906e-06, 0.0002588517963886261, 7.631428502463677e-07, 1.3253040378913283e-05, 7.370218554569874e-06, 1.4551008462149184e-05, 9.092555956158321e-06, 3.950515019823797e-05, 2.0963175302313175e-06, 1.2840252566093113e-05, 0.9996381998062134, 1.557328772605615e-07], [1.9633629563031718e-05, 2.9309690035006497e-06, 5.201384556130506e-05, 3.575253003873513e-06, 3.46595953715223e-07, 0.01468951441347599, 1.304881698160898e-05, 0.9851461052894592, 3.735282734851353e-05, 1.557562245579902e-05, 1.1585748325160239e-06, 7.0444598350150045e-06, 1.1634260772552807e-05, 3.1361004459995456e-08], [4.94217783852946e-05, 1.0189252861891873e-05, 0.0001692848454695195, 1.7591419236850925e-05, 6.895964816067135e-06, 0.0012752452166751027, 0.010829685255885124, 0.0005391684826463461, 0.9808638095855713, 0.0005205412162467837, 0.0024872170761227608, 0.0031290908809751272, 0.00010173906048294157, 7.99681956209497e-08], [0.00012695191253442317, 1.4679194464406464e-05, 0.00011667808576021343, 0.0040611401200294495, 0.0010680991690605879, 0.0002917109231930226, 0.0003154439036734402, 0.0003626788384281099, 0.005471888929605484, 0.9603675603866577, 0.013632584363222122, 0.0020520074758678675, 0.012113837525248528, 4.662265382648911e-06], [0.02129807136952877, 0.0006461066077463329, 0.12942540645599365, 0.003842339850962162, 0.000718642957508564, 0.14315049350261688, 0.009139570407569408, 0.03911759331822395, 0.47854599356651306, 0.11522843688726425, 0.023550251498818398, 0.029326293617486954, 0.006010579876601696, 2.3803401916211442e-07], [0.0007556688506156206, 6.83949256199412e-05, 0.0012973148841410875, 0.0004358450823929161, 9.774936916073784e-05, 0.0048384349793195724, 0.00353313609957695, 
0.00345016410574317, 0.5985360145568848, 0.08761582523584366, 0.2846389412879944, 0.011807692237198353, 0.0029242453165352345, 5.721153115700872e-07], [1.753839683260594e-06, 8.974393495009281e-07, 1.1603299299167702e-06, 1.311652454205614e-06, 5.441519078885904e-06, 5.614032033918193e-06, 0.00013162224786356091, 7.411200840579113e-06, 0.000822348112706095, 0.00010946377733489498, 0.9988458156585693, 4.974802868673578e-05, 1.698911546554882e-05, 3.8076589703450736e-07], [0.09078110009431839, 0.004797599744051695, 0.823334276676178, 0.00035671371733769774, 0.00015342659025918692, 0.03273462504148483, 0.0007502525695599616, 0.008871844969689846, 0.026876816526055336, 0.0038153002969920635, 0.0008078285027295351, 0.006350564770400524, 0.00036965846084058285, 1.3913560792389035e-07], [0.15473105013370514, 0.8438209295272827, 0.00028271129122003913, 3.655057753348956e-06, 3.2228654163191095e-06, 0.00020659371512010694, 2.6511515898164362e-05, 0.0005781562067568302, 0.00020910197054035962, 3.4318334655836225e-05, 3.321463373140432e-05, 4.870831980952062e-05, 2.0305074940552004e-05, 1.472792405365908e-06], [1.2343838451389644e-10, 6.598104285160389e-10, 9.270336234767917e-12, 7.485163600051692e-10, 5.4517017566979575e-09, 9.076759482917751e-11, 4.8768549198996425e-09, 7.768937160257394e-10, 9.536814393751314e-11, 1.3717442737259944e-09, 7.035541482736107e-09, 3.472595544451451e-10, 7.467613194478417e-09, 1.0]]
else:
## Load the model
import transformers
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
@st.cache_resource
def load_model_and_tokenizer():
    """Load the embedding model and its tokenizer once per server process.

    ``st.cache_resource`` is used (rather than ``st.cache_data``) because the
    returned objects are large, shared resources: ``cache_data`` would
    serialise the return value and hand back a fresh copy on every script
    rerun, which is slow for a model and can fail for non-picklable objects.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is moved to CUDA when
        ``torch.cuda.is_available()`` is true, otherwise it stays on the CPU.
    """
    model_name = 'Parallia/Fairly-Multilingual-ModernBERT-Embed-BE'
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(security): trust_remote_code=True executes Python code shipped in
    # the model repository; only keep it for repositories you trust.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return model, tokenizer
# Scale applied to the cosine similarities before the softmax; larger values
# make each token's alignment distribution more sharply peaked.
SOFTMAX_SCALE = 20

model, tokenizer = load_model_and_tokenizer()

# Put both inputs in the middle column of a 1:2:1 layout.
_, mid_st, _ = st.columns([1, 2, 1])
sent0 = mid_st.text_input("Type your first sentence here", value="De fait, mon mari ne parlait jamais de ses affaires avec moi.", key="sent0")
sent1 = mid_st.text_input("Type your second sentence here", value="M'n man had het met mij nooit over z'n zaken, inderdaad.", key="sent1")
sentences = [sent0, sent1]

# For each sentence keep its token strings and its per-token embeddings
# (the first and only item of the batch, copied to the CPU).
tokens = []
embeddings = []
for sentence in sentences:
    with torch.no_grad():
        encoded_sentence = tokenizer(sentence, padding=False, truncation=True, return_tensors="pt").to(model.device)
        embedded_sentence = model(**encoded_sentence).last_hidden_state[0].detach().cpu().clone()
    # The encoding above is truncated to the model's maximum length but
    # tokenize() is not, so cap the token list at the number of embedding
    # rows; otherwise a very long input would index past the probability
    # matrices when the page is rendered.
    # NOTE(review): this assumes tokenize() and the encoded ids line up
    # one-to-one (no extra special tokens) - confirm for this tokenizer.
    tokens.append(tokenizer.tokenize(sentence)[:embedded_sentence.shape[0]])
    embeddings.append(embedded_sentence)

# Cosine similarity between every token of sentence 0 (rows) and every token
# of sentence 1 (columns).
token_similarities = F.normalize(embeddings[0], dim=1) @ F.normalize(embeddings[1], dim=1).T

# Sentence-level score: cosine similarity of the mean token embeddings,
# rounded to two decimals and clamped so it is never negative.
sentence_similarity = F.normalize(torch.mean(embeddings[0], dim=0), dim=-1) @ F.normalize(torch.mean(embeddings[1], dim=0), dim=-1)
sentence_similarity = max(0, round(sentence_similarity.item(), 2))

# Row-wise softmax in both directions:
#   token_probabilities_12[i][j] - weight of sentence-1 token j for sentence-0 token i
#   token_probabilities_21[j][i] - weight of sentence-0 token i for sentence-1 token j
token_probabilities_12 = F.softmax(SOFTMAX_SCALE * token_similarities, dim=1)
token_probabilities_21 = F.softmax(SOFTMAX_SCALE * token_similarities.T, dim=1)
# Only the escape function is imported, so the `html` string built below does
# not clash with the standard-library module of the same name.
from html import escape as _escape_text

# Minimum sentence similarity that is reported as a success.
_SIMILARITY_THRESHOLD = 0.65

# Simplify the tokens for display: drop the 3-character prefix (e.g. "[0]")
# and turn the U+2581 word-boundary marker into a regular space.
tokens = [[token[3:].replace("\u2581", " ") for token in sentence] for sentence in tokens]


def _render_token_spans(sentence_tokens):
    """Return one ``<span>`` per token, with word-initial spaces kept outside.

    A token that starts with a space gets that space emitted before its span,
    so the highlight only covers the visible text. Every token, even an empty
    one, yields exactly one span so that the ``nth-child`` selectors in the
    stylesheet stay aligned with the token indices. Token text is
    HTML-escaped because it originates from user-typed sentences.
    """
    pieces = []
    for token in sentence_tokens:
        # startswith() is safe on an empty token, unlike token[0].
        starts_word = token.startswith(" ")
        text = token[1:] if starts_word else token
        pieces.append(f"{' ' if starts_word else ''}<span>{_escape_text(text, quote=False)}</span>")
    return "".join(pieces)


_verdict = "✅ Congrats!" if sentence_similarity >= _SIMILARITY_THRESHOLD else "❌ Sorry!"

# The percentage is formatted with no decimals: a plain `100 * x` can print
# float noise such as 56.99999999999999.
html = "".join([
    "\n<article>\n<div>",
    f"{_verdict} These sentences have {100 * sentence_similarity:.0f}% similarity.",
    '\n</div>\n<p id="sent0">',
    _render_token_spans(tokens[0]),
    '\n</p>\n<p id="sent1">',
    _render_token_spans(tokens[1]),
    "\n</p>\n</article>",
])
# Build the stylesheet that drives the highlighting, then render the page.
#
# Every <span> carries one custom property per token of either sentence:
#   --p{a}  : how strongly this span relates to token a of sentence 0
#   --ip{b} : how strongly this span relates to token b of sentence 1
# The active value `--p` is picked by the hover rules, or cycled through the
# sentence-0 tokens by the `rotate_bg` animation.
first_count = len(tokens[0])
second_count = len(tokens[1])

# Static rules; the `span` rule is deliberately left open so the per-token
# default values can be appended to it.
css_parts = ["""
<style>
article {
font-family: sans-serif;
text-align: center;
}
button:hover {
background-color: #0056b3;
}
p {
margin: 0.5em;
font-size: 2em;
text-wrap: balance;
}
span {
animation-name: rotate_bg;
animation-duration: 15s;
animation-timing-function: steps(14, start);
animation-iteration-count: infinite;
text-decoration: underline 0.3em;
text-decoration-skip: none;
text-decoration-skip-ink: none;
color: rgba(0, 0, 0, calc((50% + 50% * var(--p))));
text-decoration-color: hsla(161, 100%, 43%, var(--p));
background-color: hsla(161, 100%, 43%, calc(var(--p) * 0.2));
--p: var(--p0); """]

# Defaults: every alignment variable is 0 unless a later rule overrides it.
css_parts.extend(f"--p{a}: 0; " for a in range(first_count))
css_parts.extend(f"--ip{b}: 0; " for b in range(second_count))
css_parts.append("\n}\n")

# Each token is fully related to itself.
css_parts.extend(
    f"\n#sent0 span:nth-child({a+1}) {{ --p{a}: 1; }}" for a in range(first_count)
)
css_parts.extend(
    f"\n#sent1 span:nth-child({b+1}) {{ --ip{b}: 1; }}" for b in range(second_count)
)

# Sentence-0 spans: weight they receive from each sentence-1 token.
for a in range(first_count):
    css_parts.append(f"\n#sent0 span:nth-child({a+1}) {{")
    css_parts.extend(
        f"--ip{b}: {token_probabilities_21[b][a]}; " for b in range(second_count)
    )
    css_parts.append("}")

# Sentence-1 spans: weight they receive from each sentence-0 token.
for b in range(second_count):
    css_parts.append(f"\n#sent1 span:nth-child({b+1}) {{")
    css_parts.extend(
        f"--p{a}: {token_probabilities_12[a][b]}; " for a in range(first_count)
    )
    css_parts.append("}")

# Hovering a token makes every span display its relation to that token.
css_parts.extend(
    f"\nbody:has(#sent0 span:nth-child({a+1}):hover) span {{ --p: var(--p{a}) !important; }}"
    for a in range(first_count)
)
css_parts.extend(
    f"\nbody:has(#sent1 span:nth-child({b+1}):hover) span {{ --p: var(--ip{b}) !important; }}"
    for b in range(second_count)
)

# Without hover, the animation steps through the sentence-0 tokens in turn.
css_parts.append("\n@keyframes rotate_bg {")
css_parts.extend(
    f"\n{100*a/first_count}% {{ --p: var(--p{a}); }}" for a in range(first_count)
)
css_parts.append("\n}\n</style>\n")

html += "".join(css_parts)
st.html(html)