import streamlit as st, random
st.set_page_config(layout="wide")
# Give some context
st.html("""
<h1>🔀 Word-level alignment between two sentences</h1>
<p>Supports English, French, Dutch, and German.</p>
""")
# Create a layout with a padding column on each side
_, mid_st, _ = st.columns([1, 2, 1])
# Allow the user to reroll the example sentences
reroll_button = mid_st.button("Try a new example!", key="reroll")
if reroll_button:
    example_sentences = [
        # translations
        ("The book, which was on the table, is now missing.", "Het boek, dat op de tafel lag, is nu verdwenen."),
        ("If I had known, I would have acted differently.", "Si j'avais su, j'aurais agi différemment."),
        ("She can speak three languages fluently.", "Sie kann drei Sprachen fließend sprechen."),
        ("I wish I had more time to learn.", "Ich wünschte, ich hätte mehr Zeit zum Lernen."),
        ("The children were playing while their parents were talking.", "De kinderen speelden terwijl hun ouders aan het praten waren."),
        ("He would go to the gym every day if he had more energy.", "Il irait à la salle de sport tous les jours s'il avait plus d'énergie."),
        ("By the time I arrived, she had already left.", "Toen ik aankwam, was zij al vertrokken."),
        ("Despite the rain, they went for a walk.", "Malgré la pluie, ils sont allés se promener."),
        ("If I were you, I wouldn't do that.", "Als ik jou was, zou ik dat niet doen."),
        ("The movie, which I watched yesterday, was fantastic.", "Der Film, den ich gestern gesehen habe, war fantastisch."),
        # paraphrases
        ("She has a remarkable ability to solve problems quickly.", "Her problem-solving skills are impressive and rapid."),
        ("Despite the fact that the project was delayed, they managed to finish it on time.", "Even though the project was delayed, they were able to complete it by the deadline."),
        ("The teacher asked the students to submit their assignments by Friday.", "The students were required to hand in their assignments no later than Friday."),
        ("I haven't seen him in years, and I wonder how he's doing.", "It's been years since I last saw him, and I'm curious about his well-being."),
        ("He was hesitant to take the offer because it seemed too good to be true.", "He doubted the offer because it appeared to be too perfect to be genuine."),
        ("She didn't have the necessary qualifications, but she still managed to get the job.", "Even though she lacked the required qualifications, she succeeded in securing the position."),
        ("John said that he would be going to the meeting later.", "According to John, he planned to attend the meeting later."),
        ("The weather was terrible, so we decided to cancel the outdoor event.", "Due to the poor weather, we chose to call off the outdoor event."),
        ("They have lived in this city for a long time, and they're very familiar with it.", "Having resided in this city for many years, they know it quite well."),
        ("The book was so captivating that I couldn't put it down until I finished it.", "I found the book so engrossing that I read it all the way through without stopping.")
    ]
    random_sentences = random.choice(example_sentences)
    # Write the new example directly into the widget state: a keyed widget
    # ignores its `value=` argument once the key already exists in
    # st.session_state, so passing the example as a default would not work.
    st.session_state["sent0"], st.session_state["sent1"] = random_sentences
elif "sent0" not in st.session_state:
    # Seed the inputs with a default French/Dutch pair on first load
    st.session_state["sent0"] = "De fait, mon mari ne parlait jamais de ses affaires avec moi."
    st.session_state["sent1"] = "M'n man had het met mij nooit over z'n zaken, inderdaad."
# Allow the user to input two sentences
sent0 = mid_st.text_input("Type your first sentence here", key="sent0")
sent1 = mid_st.text_input("Type your second sentence here", key="sent1")
# Compute (or, in DEBUG mode, hard-code) the token mapping between the two sentences
DEBUG = False
# Dummy data for DEBUG mode: tokens, sentence similarity, and alignment
# probabilities precomputed for the default French/Dutch pair. The values are
# assigned unconditionally and overwritten below whenever DEBUG is off.
tokens = [
["[0]\u2581De", "[0]\u2581fait", "[0],", "[0]\u2581mon", "[0]\u2581mari", "[0]\u2581ne", "[0]\u2581parlait", "[0]\u2581jamais", "[0]\u2581de", "[0]\u2581ses", "[0]\u2581affaires", "[0]\u2581avec", "[0]\u2581moi", "[0]."],
["[1]\u2581M", "[1]'", "[1]n", "[1]\u2581man", "[1]\u2581had", "[1]\u2581het", "[1]\u2581met", "[1]\u2581mij", "[1]\u2581nooit", "[1]\u2581over", "[1]\u2581z", "[1]'", "[1]n", "[1]\u2581zaken", "[1],", "[1]\u2581inderdaad", "[1]."],
]
sentence_similarity = 0.7
token_probabilities_12 = [[0.0001786323555279523, 0.029554476961493492, 0.0005334240850061178, 9.950459934771061e-05, 0.0069955079816281796, 0.002180722774937749, 0.0011584730818867683, 0.00024079506692942232, 0.004674288909882307, 0.0028222708497196436, 0.0010019985493272543, 0.04594381898641586, 0.005427392199635506, 0.0003972635604441166, 0.16267944872379303, 0.7361112833023071, 7.352152806561207e-07], [8.627762326796073e-06, 0.00021197207388468087, 4.6174409362720326e-05, 9.608593245502561e-06, 0.0007827585795894265, 0.00014010778977535665, 0.00010562615352682769, 2.1056943296571262e-05, 0.0001730725634843111, 0.00014431933232117444, 2.873636913136579e-05, 0.00034569500712677836, 0.00012183879152871668, 5.0419181206962094e-05, 0.0021323736291378736, 0.9956766963005066, 9.747334388521267e-07], [0.000363957486115396, 0.058086857199668884, 0.0010622636182233691, 0.0001401746121700853, 0.03782657906413078, 0.010060806758701801, 0.003844393650069833, 0.00017795785970520228, 0.006151353009045124, 0.004802143666893244, 0.00045746073010377586, 0.13868938386440277, 0.004628518130630255, 0.0001305590703850612, 0.7329095005989075, 0.0006681070663034916, 2.742818416834325e-08], [0.9904406070709229, 0.0029372733552008867, 0.002740614116191864, 0.00016679026884958148, 7.625997386639938e-05, 0.00013673932699020952, 3.424803799134679e-05, 0.0018435337115079165, 2.9864917451050133e-05, 3.524687417666428e-05, 0.0011246443027630448, 0.00029081859975121915, 0.00010983269021380693, 1.0424302672618069e-05, 2.242830305476673e-05, 6.100983682699734e-07, 1.5642478956578998e-07], [0.0014130950439721346, 0.0006491155363619328, 0.0039030632469803095, 0.9928011894226074, 3.293847112217918e-05, 0.00023004157992545515, 4.8409721784992144e-05, 1.1099789844593033e-05, 5.912710548727773e-06, 2.8217813451192342e-05, 0.0006040701409801841, 0.0001110830926336348, 5.0306276534684e-05, 8.831957529764622e-05, 1.970089942915365e-05, 1.0986423149006441e-06, 2.326722096768208e-06], [8.804490062175319e-05, 0.020074667409062386, 0.0004638279788196087, 6.011720688547939e-05, 0.3746979236602783, 0.018972501158714294, 0.0019666561856865883, 0.0003945657517760992, 0.5129392743110657, 0.010681135579943657, 0.0003376945969648659, 0.04529218748211861, 0.0050969235599040985, 0.00018651205755304545, 0.008603758178651333, 0.00014415399346034974, 7.929369161274735e-08], [0.00021422369172796607, 0.005102124996483326, 0.0003292255278211087, 0.0009069664520211518, 0.08789367973804474, 0.15337994694709778, 0.0511351116001606, 0.0014941217377781868, 0.003102638525888324, 0.6176480650901794, 0.0024865365121513605, 0.01969054341316223, 0.025343414396047592, 0.02977576106786728, 0.0013427335070446134, 0.00012596377928275615, 2.9010154321440496e-05], [1.408028879268386e-06, 0.00013552005111705512, 6.736800969520118e-06, 1.0955036486848257e-06, 0.0011538203107193112, 0.00019907338719349355, 3.0282362786238082e-05, 1.2565109500428662e-05, 0.9977654218673706, 0.00013098378258291632, 1.2177627468190622e-05, 0.0003589813131839037, 0.00010541738447500393, 7.141510650399141e-06, 6.763384590158239e-05, 1.1701029507094063e-05, 1.9685167274019477e-08], [5.773245902673807e-06, 0.0014488865854218602, 2.6845693355426192e-05, 1.0805002602864988e-05, 0.002086219610646367, 0.01130380667746067, 0.001883843680843711, 2.9443286621244624e-05, 0.00014186602493282408, 0.8935705423355103, 0.0006889772485010326, 0.016468364745378494, 0.0685788094997406, 0.0029715588316321373, 0.0007683428702875972, 1.5869531125645153e-05, 9.061647432417885e-09], [0.012304342351853848, 
0.03311317041516304, 0.007781223859637976, 0.004408247768878937, 0.002244020812213421, 0.04515384882688522, 0.0010716691613197327, 0.0008402583771385252, 0.00038856116589158773, 0.003114827908575535, 0.7942622303962708, 0.026046255603432655, 0.0659388080239296, 0.002598112216219306, 0.0007164151757024229, 1.7107675375882536e-05, 8.561248137084476e-07], [1.8507088270780514e-06, 8.229719242081046e-05, 9.083240001928061e-06, 2.4296445189975202e-05, 6.340398158499738e-06, 0.0001343603798886761, 5.143981979927048e-06, 1.8609456446938566e-06, 1.2062999985573697e-06, 0.0006211695144884288, 0.0004705676983576268, 0.0002221767936134711, 0.008940674364566803, 0.9894717335700989, 6.331008989945985e-06, 6.910536853865779e-07, 1.8326457507100713e-07], [9.844085070653819e-06, 0.00017242366448044777, 1.5286375855794176e-05, 9.348634193884209e-06, 0.0001390141696901992, 0.0014548851177096367, 0.9944896697998047, 2.612253047118429e-05, 1.680908098933287e-05, 0.001790934125892818, 0.00016232591588050127, 0.0006340526742860675, 0.0008499748073518276, 0.0001129394950112328, 0.00011405935219954699, 2.3224697542900685e-06, 2.073006299951885e-08], [0.0001119379885494709, 4.149791129748337e-05, 4.742472356156213e-06, 2.8589572593773482e-06, 7.517372523579979e-06, 0.00013416734873317182, 2.4442895664833486e-05, 0.9989697933197021, 1.3636500625580084e-05, 2.8603359169210307e-05, 0.000470715545816347, 6.383401341736317e-05, 0.00010340050357626751, 1.8945609554066323e-05, 3.2612724680802785e-06, 4.755758311603131e-07, 2.1897558610817214e-07], [1.182226538887221e-09, 2.612263738654974e-10, 4.270485631785448e-10, 3.19147619443072e-09, 1.5840342926232154e-10, 1.1831424728825368e-09, 5.845964268225146e-10, 5.307323469594394e-09, 1.2535458226992091e-09, 7.667128598676243e-10, 6.178164646541973e-09, 8.621070524128527e-11, 6.898879245653688e-10, 1.4480447951825681e-08, 4.186111873805132e-11, 1.1763671148301569e-09, 1.0]]
token_probabilities_21 = [[6.291573299677111e-06, 1.2251668977114605e-06, 2.580553882580716e-05, 0.9942325949668884, 0.0006945801433175802, 2.1142666810192168e-05, 7.554778676421847e-06, 1.1657297363854013e-05, 1.2746155334752984e-05, 0.00413579260930419, 1.4904621821187902e-05, 3.4593394957482815e-05, 0.0008008066797628999, 2.480075806943205e-07], [0.03415770083665848, 0.0009877384873107076, 0.13514696061611176, 0.09675426036119461, 0.01046981941908598, 0.15818676352500916, 0.005904339253902435, 0.036817632615566254, 0.10496868193149567, 0.3652307093143463, 0.02174874022603035, 0.01988295465707779, 0.009741867892444134, 1.7982374629355036e-06], [0.0024134248960763216, 0.0008422881946898997, 0.009675124660134315, 0.35340237617492676, 0.24644418060779572, 0.014307855628430843, 0.001491453149355948, 0.007164766546338797, 0.007613700814545155, 0.33597758412361145, 0.009396923705935478, 0.006900560110807419, 0.004358294885605574, 1.150809748651227e-05], [7.152517355279997e-06, 2.7846685952681582e-06, 2.028374728979543e-05, 0.0003417014959268272, 0.995932400226593, 2.9462620659614913e-05, 6.527722143800929e-05, 1.8510456357034855e-05, 4.868561154580675e-05, 0.003024008125066757, 0.0003993394784629345, 6.704749830532819e-05, 4.174213609076105e-05, 1.3663803883900982e-06], [0.0022054731380194426, 0.0009949662489816546, 0.024007266387343407, 0.0006852351943962276, 0.0001449233095627278, 0.8054160475730896, 0.02774566411972046, 0.08550822734832764, 0.04122897982597351, 0.006751661188900471, 0.00045707033132202923, 0.004372807219624519, 0.00048139269347302616, 2.974482242734666e-07], [0.0012809201143682003, 0.0003318041271995753, 0.011896451003849506, 0.00228915479965508, 0.0018857300747185946, 0.07598057389259338, 0.09020800143480301, 0.027486657723784447, 0.416204035282135, 0.25311478972435, 0.0180458165705204, 0.08526463806629181, 0.01600734144449234, 4.139269094594056e-06], [1.1649769476207439e-05, 4.282531790522626e-06, 7.782561442581937e-05, 9.815836165216751e-06, 6.7938508436782286e-06, 0.00013483928341884166, 0.0005148796481080353, 7.158283551689237e-05, 0.0011875078780576587, 0.00010284745076205581, 1.1828081369458232e-05, 0.9978162050247192, 4.9926951760426164e-05, 3.501482837009462e-08], [1.1862812243634835e-06, 4.1824714003269037e-07, 1.7648992525209906e-06, 0.0002588517963886261, 7.631428502463677e-07, 1.3253040378913283e-05, 7.370218554569874e-06, 1.4551008462149184e-05, 9.092555956158321e-06, 3.950515019823797e-05, 2.0963175302313175e-06, 1.2840252566093113e-05, 0.9996381998062134, 1.557328772605615e-07], [1.9633629563031718e-05, 2.9309690035006497e-06, 5.201384556130506e-05, 3.575253003873513e-06, 3.46595953715223e-07, 0.01468951441347599, 1.304881698160898e-05, 0.9851461052894592, 3.735282734851353e-05, 1.557562245579902e-05, 1.1585748325160239e-06, 7.0444598350150045e-06, 1.1634260772552807e-05, 3.1361004459995456e-08], [4.94217783852946e-05, 1.0189252861891873e-05, 0.0001692848454695195, 1.7591419236850925e-05, 6.895964816067135e-06, 0.0012752452166751027, 0.010829685255885124, 0.0005391684826463461, 0.9808638095855713, 0.0005205412162467837, 0.0024872170761227608, 0.0031290908809751272, 0.00010173906048294157, 7.99681956209497e-08], [0.00012695191253442317, 1.4679194464406464e-05, 0.00011667808576021343, 0.0040611401200294495, 0.0010680991690605879, 0.0002917109231930226, 0.0003154439036734402, 0.0003626788384281099, 0.005471888929605484, 0.9603675603866577, 0.013632584363222122, 0.0020520074758678675, 0.012113837525248528, 4.662265382648911e-06], [0.02129807136952877, 
0.0006461066077463329, 0.12942540645599365, 0.003842339850962162, 0.000718642957508564, 0.14315049350261688, 0.009139570407569408, 0.03911759331822395, 0.47854599356651306, 0.11522843688726425, 0.023550251498818398, 0.029326293617486954, 0.006010579876601696, 2.3803401916211442e-07], [0.0007556688506156206, 6.83949256199412e-05, 0.0012973148841410875, 0.0004358450823929161, 9.774936916073784e-05, 0.0048384349793195724, 0.00353313609957695, 0.00345016410574317, 0.5985360145568848, 0.08761582523584366, 0.2846389412879944, 0.011807692237198353, 0.0029242453165352345, 5.721153115700872e-07], [1.753839683260594e-06, 8.974393495009281e-07, 1.1603299299167702e-06, 1.311652454205614e-06, 5.441519078885904e-06, 5.614032033918193e-06, 0.00013162224786356091, 7.411200840579113e-06, 0.000822348112706095, 0.00010946377733489498, 0.9988458156585693, 4.974802868673578e-05, 1.698911546554882e-05, 3.8076589703450736e-07], [0.09078110009431839, 0.004797599744051695, 0.823334276676178, 0.00035671371733769774, 0.00015342659025918692, 0.03273462504148483, 0.0007502525695599616, 0.008871844969689846, 0.026876816526055336, 0.0038153002969920635, 0.0008078285027295351, 0.006350564770400524, 0.00036965846084058285, 1.3913560792389035e-07], [0.15473105013370514, 0.8438209295272827, 0.00028271129122003913, 3.655057753348956e-06, 3.2228654163191095e-06, 0.00020659371512010694, 2.6511515898164362e-05, 0.0005781562067568302, 0.00020910197054035962, 3.4318334655836225e-05, 3.321463373140432e-05, 4.870831980952062e-05, 2.0305074940552004e-05, 1.472792405365908e-06], [1.2343838451389644e-10, 6.598104285160389e-10, 9.270336234767917e-12, 7.485163600051692e-10, 5.4517017566979575e-09, 9.076759482917751e-11, 4.8768549198996425e-09, 7.768937160257394e-10, 9.536814393751314e-11, 1.3717442737259944e-09, 7.035541482736107e-09, 3.472595544451451e-10, 7.467613194478417e-09, 1.0]]
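# For instance, in the dummy matrices above, the row for "[0]▁mon" puts ~0.99
# of its probability mass on "[1]▁M" (of "M'n", i.e. "mijn"), and the row for
# "[0]▁mari" puts ~0.99 on "[1]▁man": those alignments are close to one-to-one.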
if not DEBUG:
    ## Load the model
    import torch
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel

    @st.cache_resource
    def load_model_and_tokenizer():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModel.from_pretrained('Parallia/Fairly-Multilingual-ModernBERT-Embed-BE', trust_remote_code=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained('Parallia/Fairly-Multilingual-ModernBERT-Embed-BE', trust_remote_code=True)
        return model, tokenizer

    model, tokenizer = load_model_and_tokenizer()

    @st.cache_data
    def encode_sentences(sent0, sent1):
        sentences = [sent0, sent1]
        tokens = []
        embeddings = []
        for i, sentence in enumerate(sentences):
            with torch.no_grad():
                encoded_sentence = tokenizer(sentence, padding=False, truncation=True, return_tensors="pt").to(model.device)
                embedded_sentence = model(**encoded_sentence).last_hidden_state[0].detach().cpu().clone()
            # Tag each token with its sentence index (e.g. "[0]▁De"), matching
            # the dummy data above; the 3-character tag is stripped for display
            tokens.append([f"[{i}]{token}" for token in tokenizer.tokenize(sentence)])
            # Drop the first and last embedding rows so the rows line up
            # one-to-one with the tokens (assumption: the tokenizer adds
            # exactly the two special tokens [CLS] and [SEP])
            embeddings.append(embedded_sentence[1:-1])
        return tokens, embeddings

    # Encode the sentences
    tokens, embeddings = encode_sentences(sent0, sent1)
    # Calculate the cross-token similarity (cosine similarity of every token
    # of the first sentence against every token of the second)
    token_similarities = F.normalize(embeddings[0], dim=1) @ F.normalize(embeddings[1], dim=1).T
    # Calculate the overall sentence similarity from the mean-pooled embeddings
    sentence_similarity = F.normalize(torch.mean(embeddings[0], dim=0), dim=-1) @ F.normalize(torch.mean(embeddings[1], dim=0), dim=-1)
    # Map sentence 1 onto sentence 2: a sharpened softmax (factor 20) over each
    # row turns the similarities into alignment probabilities
    token_probabilities_12 = F.softmax(20 * token_similarities, dim=1)
    # Map sentence 2 onto sentence 1
    token_probabilities_21 = F.softmax(20 * token_similarities.T, dim=1)
    # Convert to native Python objects
    sentence_similarity = max(0, round(sentence_similarity.item(), 2))
    token_probabilities_12 = token_probabilities_12.numpy().tolist()
    token_probabilities_21 = token_probabilities_21.numpy().tolist()
# Simplify the tokens for display: drop the 3-character "[i]" tag and turn the
# sub-word markers ("▁" for SentencePiece, "Ġ" for BPE) into spaces
tokens = [[token[3:].replace("\u2581", " ").replace("Ġ", " ") for token in sentence] for sentence in tokens]
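# Sanity check: each row of the two alignment matrices is a probability
# distribution over the other sentence's tokens, so it should sum to ~1.0
assert all(abs(sum(row) - 1.0) < 1e-4 for row in token_probabilities_12)
assert all(abs(sum(row) - 1.0) < 1e-4 for row in token_probabilities_21)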
# Display the mapping between the two sentences. st.html does not execute
# JavaScript, so the hover effect is implemented in pure CSS (a minimal
# reconstruction; class names and colors are illustrative choices): for every
# likely-enough token pair, a generated `:has()` rule highlights the aligned
# tokens of the other sentence, with opacity proportional to the probability.
html = '<div class="aligner">'
html += """<style>
.aligner span { border-radius: 4px; }
.aligner span:hover { background: rgba(255, 213, 79, 1); }
"""
for i in range(len(tokens[0])):
    for j in range(len(tokens[1])):
        if token_probabilities_12[i][j] > 0.05:
            html += f".aligner:has(.s0t{i}:hover) .s1t{j} {{ background: rgba(255, 213, 79, {token_probabilities_12[i][j]:.2f}); }}\n"
        if token_probabilities_21[j][i] > 0.05:
            html += f".aligner:has(.s1t{j}:hover) .s0t{i} {{ background: rgba(255, 213, 79, {token_probabilities_21[j][i]:.2f}); }}\n"
html += "</style>"
html += f"""<h3>{("✅ Congrats!" if sentence_similarity >= 0.65 else "❌ Sorry!")} These sentences have {round(100 * sentence_similarity)}% similarity.</h3>"""
html += """<p>Hover over a word from either sentence to see which other tokens are mapped to it.</p>"""
html += '<p>'
for i, token in enumerate(tokens[0]):
    html += f"""{' ' if token[0] == ' ' else ''}<span class="s0t{i}">{token[1:] if token[0] == ' ' else token}</span>"""
html += '</p><p>'
for i, token in enumerate(tokens[1]):
    html += f"""{' ' if token[0] == ' ' else ''}<span class="s1t{i}">{token[1:] if token[0] == ' ' else token}</span>"""
html += '</p></div>'
st.html(html)
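# If a discrete word alignment is ever needed instead of the hover view, it
# can be read off the same matrices, e.g. by keeping each row's argmax:
#   best_12 = [max(range(len(row)), key=row.__getitem__) for row in token_probabilities_12]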