import random

import streamlit as st
st.set_page_config(layout="wide")
# Give some context
st.html("""
<h1 style="text-align: center; margin: 0px; text-wrap: balance;">🔀 Word-level alignment between two sentences</h1>
<div style="text-align: center; color: gray; text-wrap: balance;">Supports English, French, Dutch, and German.</div>
<style>
.stButton { text-align: center; }
</style>
""")
# Create a layout with a padding column on each side
_, mid_st, _ = st.columns([1, 2, 1])
# Allow the user to reroll the example sentences
reroll_button = mid_st.button("Try a new example!", key="reroll")
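# Note on Streamlit semantics: st.button returns True only for the rerun
# triggered by the click, so a fresh example pair is drawn exactly once per
# press; on all other reruns the else branch below keeps the current input.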
if reroll_button:
    example_sentences = [
        # Translations
        ("The book, which was on the table, is now missing.", "Het boek, dat op de tafel lag, is nu verdwenen."),
        ("If I had known, I would have acted differently.", "Si j'avais su, j'aurais agi différemment."),
        ("She can speak three languages fluently.", "Sie kann drei Sprachen fließend sprechen."),
        ("I wish I had more time to learn.", "Ich wünschte, ich hätte mehr Zeit zum Lernen."),
        ("The children were playing while their parents were talking.", "De kinderen speelden terwijl hun ouders aan het praten waren."),
        ("He would go to the gym every day if he had more energy.", "Il irait à la salle de sport tous les jours s'il avait plus d'énergie."),
        ("By the time I arrived, she had already left.", "Als ik aankwam, was zij al vertrokken."),
        ("Despite the rain, they went for a walk.", "Malgré la pluie, ils sont allés se promener."),
        ("If I were you, I wouldn't do that.", "Als ik jou was, zou ik dat niet doen."),
        ("The movie, which I watched yesterday, was fantastic.", "Der Film, den ich gestern gesehen habe, war fantastisch."),
        # Paraphrases
        ("She has a remarkable ability to solve problems quickly.", "Her problem-solving skills are impressive and rapid."),
        ("Despite the fact that the project was delayed, they managed to finish it on time.", "Even though the project was delayed, they were able to complete it by the deadline."),
        ("The teacher asked the students to submit their assignments by Friday.", "The students were required to hand in their assignments no later than Friday."),
        ("I haven't seen him in years, and I wonder how he's doing.", "It's been years since I last saw him, and I'm curious about his well-being."),
        ("He was hesitant to take the offer because it seemed too good to be true.", "He doubted the offer because it appeared to be too perfect to be genuine."),
        ("She didn't have the necessary qualifications, but she still managed to get the job.", "Even though she lacked the required qualifications, she succeeded in securing the position."),
        ("John said that he would be going to the meeting later.", "According to John, he planned to attend the meeting later."),
        ("The weather was terrible, so we decided to cancel the outdoor event.", "Due to the poor weather, we chose to call off the outdoor event."),
        ("They have lived in this city for a long time, and they're very familiar with it.", "Having resided in this city for many years, they know it quite well."),
        ("The book was so captivating that I couldn't put it down until I finished it.", "I found the book so engrossing that I read it all the way through without stopping.")
    ]
    random_pair = random.choice(example_sentences)
    sent0 = mid_st.text_input("Type your first sentence here", value=random_pair[0], key="sent0")
    sent1 = mid_st.text_input("Type your second sentence here", value=random_pair[1], key="sent1")
else:
    # Allow the user to input two sentences, keeping any previous input
    sent0 = mid_st.text_input("Type your first sentence here", value=st.session_state.get("sent0", "De fait, mon mari ne parlait jamais de ses affaires avec moi."), key="sent0")
    sent1 = mid_st.text_input("Type your second sentence here", value=st.session_state.get("sent1", "M'n man had het met mij nooit over z'n zaken, inderdaad."), key="sent1")
# Compute and display the word-level mapping between the two sentences
DEBUG = False
if DEBUG:
    # Use some dummy data
    tokens = [
        ["[0]\u2581De", "[0]\u2581fait", "[0],", "[0]\u2581mon", "[0]\u2581mari", "[0]\u2581ne", "[0]\u2581parlait", "[0]\u2581jamais", "[0]\u2581de", "[0]\u2581ses", "[0]\u2581affaires", "[0]\u2581avec", "[0]\u2581moi", "[0]."],
        ["[1]\u2581M", "[1]'", "[1]n", "[1]\u2581man", "[1]\u2581had", "[1]\u2581het", "[1]\u2581met", "[1]\u2581mij", "[1]\u2581nooit", "[1]\u2581over", "[1]\u2581z", "[1]'", "[1]n", "[1]\u2581zaken", "[1],", "[1]\u2581inderdaad", "[1]."],
    ]
    sentence_similarity = 0.7
    token_probabilities_12 = [
        [0.0001786323555279523, 0.029554476961493492, 0.0005334240850061178, 9.950459934771061e-05, 0.0069955079816281796, 0.002180722774937749, 0.0011584730818867683, 0.00024079506692942232, 0.004674288909882307, 0.0028222708497196436, 0.0010019985493272543, 0.04594381898641586, 0.005427392199635506, 0.0003972635604441166, 0.16267944872379303, 0.7361112833023071, 7.352152806561207e-07],
        [8.627762326796073e-06, 0.00021197207388468087, 4.6174409362720326e-05, 9.608593245502561e-06, 0.0007827585795894265, 0.00014010778977535665, 0.00010562615352682769, 2.1056943296571262e-05, 0.0001730725634843111, 0.00014431933232117444, 2.873636913136579e-05, 0.00034569500712677836, 0.00012183879152871668, 5.0419181206962094e-05, 0.0021323736291378736, 0.9956766963005066, 9.747334388521267e-07],
        [0.000363957486115396, 0.058086857199668884, 0.0010622636182233691, 0.0001401746121700853, 0.03782657906413078, 0.010060806758701801, 0.003844393650069833, 0.00017795785970520228, 0.006151353009045124, 0.004802143666893244, 0.00045746073010377586, 0.13868938386440277, 0.004628518130630255, 0.0001305590703850612, 0.7329095005989075, 0.0006681070663034916, 2.742818416834325e-08],
        [0.9904406070709229, 0.0029372733552008867, 0.002740614116191864, 0.00016679026884958148, 7.625997386639938e-05, 0.00013673932699020952, 3.424803799134679e-05, 0.0018435337115079165, 2.9864917451050133e-05, 3.524687417666428e-05, 0.0011246443027630448, 0.00029081859975121915, 0.00010983269021380693, 1.0424302672618069e-05, 2.242830305476673e-05, 6.100983682699734e-07, 1.5642478956578998e-07],
        [0.0014130950439721346, 0.0006491155363619328, 0.0039030632469803095, 0.9928011894226074, 3.293847112217918e-05, 0.00023004157992545515, 4.8409721784992144e-05, 1.1099789844593033e-05, 5.912710548727773e-06, 2.8217813451192342e-05, 0.0006040701409801841, 0.0001110830926336348, 5.0306276534684e-05, 8.831957529764622e-05, 1.970089942915365e-05, 1.0986423149006441e-06, 2.326722096768208e-06],
        [8.804490062175319e-05, 0.020074667409062386, 0.0004638279788196087, 6.011720688547939e-05, 0.3746979236602783, 0.018972501158714294, 0.0019666561856865883, 0.0003945657517760992, 0.5129392743110657, 0.010681135579943657, 0.0003376945969648659, 0.04529218748211861, 0.0050969235599040985, 0.00018651205755304545, 0.008603758178651333, 0.00014415399346034974, 7.929369161274735e-08],
        [0.00021422369172796607, 0.005102124996483326, 0.0003292255278211087, 0.0009069664520211518, 0.08789367973804474, 0.15337994694709778, 0.0511351116001606, 0.0014941217377781868, 0.003102638525888324, 0.6176480650901794, 0.0024865365121513605, 0.01969054341316223, 0.025343414396047592, 0.02977576106786728, 0.0013427335070446134, 0.00012596377928275615, 2.9010154321440496e-05],
        [1.408028879268386e-06, 0.00013552005111705512, 6.736800969520118e-06, 1.0955036486848257e-06, 0.0011538203107193112, 0.00019907338719349355, 3.0282362786238082e-05, 1.2565109500428662e-05, 0.9977654218673706, 0.00013098378258291632, 1.2177627468190622e-05, 0.0003589813131839037, 0.00010541738447500393, 7.141510650399141e-06, 6.763384590158239e-05, 1.1701029507094063e-05, 1.9685167274019477e-08],
        [5.773245902673807e-06, 0.0014488865854218602, 2.6845693355426192e-05, 1.0805002602864988e-05, 0.002086219610646367, 0.01130380667746067, 0.001883843680843711, 2.9443286621244624e-05, 0.00014186602493282408, 0.8935705423355103, 0.0006889772485010326, 0.016468364745378494, 0.0685788094997406, 0.0029715588316321373, 0.0007683428702875972, 1.5869531125645153e-05, 9.061647432417885e-09],
        [0.012304342351853848, 0.03311317041516304, 0.007781223859637976, 0.004408247768878937, 0.002244020812213421, 0.04515384882688522, 0.0010716691613197327, 0.0008402583771385252, 0.00038856116589158773, 0.003114827908575535, 0.7942622303962708, 0.026046255603432655, 0.0659388080239296, 0.002598112216219306, 0.0007164151757024229, 1.7107675375882536e-05, 8.561248137084476e-07],
        [1.8507088270780514e-06, 8.229719242081046e-05, 9.083240001928061e-06, 2.4296445189975202e-05, 6.340398158499738e-06, 0.0001343603798886761, 5.143981979927048e-06, 1.8609456446938566e-06, 1.2062999985573697e-06, 0.0006211695144884288, 0.0004705676983576268, 0.0002221767936134711, 0.008940674364566803, 0.9894717335700989, 6.331008989945985e-06, 6.910536853865779e-07, 1.8326457507100713e-07],
        [9.844085070653819e-06, 0.00017242366448044777, 1.5286375855794176e-05, 9.348634193884209e-06, 0.0001390141696901992, 0.0014548851177096367, 0.9944896697998047, 2.612253047118429e-05, 1.680908098933287e-05, 0.001790934125892818, 0.00016232591588050127, 0.0006340526742860675, 0.0008499748073518276, 0.0001129394950112328, 0.00011405935219954699, 2.3224697542900685e-06, 2.073006299951885e-08],
        [0.0001119379885494709, 4.149791129748337e-05, 4.742472356156213e-06, 2.8589572593773482e-06, 7.517372523579979e-06, 0.00013416734873317182, 2.4442895664833486e-05, 0.9989697933197021, 1.3636500625580084e-05, 2.8603359169210307e-05, 0.000470715545816347, 6.383401341736317e-05, 0.00010340050357626751, 1.8945609554066323e-05, 3.2612724680802785e-06, 4.755758311603131e-07, 2.1897558610817214e-07],
        [1.182226538887221e-09, 2.612263738654974e-10, 4.270485631785448e-10, 3.19147619443072e-09, 1.5840342926232154e-10, 1.1831424728825368e-09, 5.845964268225146e-10, 5.307323469594394e-09, 1.2535458226992091e-09, 7.667128598676243e-10, 6.178164646541973e-09, 8.621070524128527e-11, 6.898879245653688e-10, 1.4480447951825681e-08, 4.186111873805132e-11, 1.1763671148301569e-09, 1.0],
    ]
    token_probabilities_21 = [
        [6.291573299677111e-06, 1.2251668977114605e-06, 2.580553882580716e-05, 0.9942325949668884, 0.0006945801433175802, 2.1142666810192168e-05, 7.554778676421847e-06, 1.1657297363854013e-05, 1.2746155334752984e-05, 0.00413579260930419, 1.4904621821187902e-05, 3.4593394957482815e-05, 0.0008008066797628999, 2.480075806943205e-07],
        [0.03415770083665848, 0.0009877384873107076, 0.13514696061611176, 0.09675426036119461, 0.01046981941908598, 0.15818676352500916, 0.005904339253902435, 0.036817632615566254, 0.10496868193149567, 0.3652307093143463, 0.02174874022603035, 0.01988295465707779, 0.009741867892444134, 1.7982374629355036e-06],
        [0.0024134248960763216, 0.0008422881946898997, 0.009675124660134315, 0.35340237617492676, 0.24644418060779572, 0.014307855628430843, 0.001491453149355948, 0.007164766546338797, 0.007613700814545155, 0.33597758412361145, 0.009396923705935478, 0.006900560110807419, 0.004358294885605574, 1.150809748651227e-05],
        [7.152517355279997e-06, 2.7846685952681582e-06, 2.028374728979543e-05, 0.0003417014959268272, 0.995932400226593, 2.9462620659614913e-05, 6.527722143800929e-05, 1.8510456357034855e-05, 4.868561154580675e-05, 0.003024008125066757, 0.0003993394784629345, 6.704749830532819e-05, 4.174213609076105e-05, 1.3663803883900982e-06],
        [0.0022054731380194426, 0.0009949662489816546, 0.024007266387343407, 0.0006852351943962276, 0.0001449233095627278, 0.8054160475730896, 0.02774566411972046, 0.08550822734832764, 0.04122897982597351, 0.006751661188900471, 0.00045707033132202923, 0.004372807219624519, 0.00048139269347302616, 2.974482242734666e-07],
        [0.0012809201143682003, 0.0003318041271995753, 0.011896451003849506, 0.00228915479965508, 0.0018857300747185946, 0.07598057389259338, 0.09020800143480301, 0.027486657723784447, 0.416204035282135, 0.25311478972435, 0.0180458165705204, 0.08526463806629181, 0.01600734144449234, 4.139269094594056e-06],
        [1.1649769476207439e-05, 4.282531790522626e-06, 7.782561442581937e-05, 9.815836165216751e-06, 6.7938508436782286e-06, 0.00013483928341884166, 0.0005148796481080353, 7.158283551689237e-05, 0.0011875078780576587, 0.00010284745076205581, 1.1828081369458232e-05, 0.9978162050247192, 4.9926951760426164e-05, 3.501482837009462e-08],
        [1.1862812243634835e-06, 4.1824714003269037e-07, 1.7648992525209906e-06, 0.0002588517963886261, 7.631428502463677e-07, 1.3253040378913283e-05, 7.370218554569874e-06, 1.4551008462149184e-05, 9.092555956158321e-06, 3.950515019823797e-05, 2.0963175302313175e-06, 1.2840252566093113e-05, 0.9996381998062134, 1.557328772605615e-07],
        [1.9633629563031718e-05, 2.9309690035006497e-06, 5.201384556130506e-05, 3.575253003873513e-06, 3.46595953715223e-07, 0.01468951441347599, 1.304881698160898e-05, 0.9851461052894592, 3.735282734851353e-05, 1.557562245579902e-05, 1.1585748325160239e-06, 7.0444598350150045e-06, 1.1634260772552807e-05, 3.1361004459995456e-08],
        [4.94217783852946e-05, 1.0189252861891873e-05, 0.0001692848454695195, 1.7591419236850925e-05, 6.895964816067135e-06, 0.0012752452166751027, 0.010829685255885124, 0.0005391684826463461, 0.9808638095855713, 0.0005205412162467837, 0.0024872170761227608, 0.0031290908809751272, 0.00010173906048294157, 7.99681956209497e-08],
        [0.00012695191253442317, 1.4679194464406464e-05, 0.00011667808576021343, 0.0040611401200294495, 0.0010680991690605879, 0.0002917109231930226, 0.0003154439036734402, 0.0003626788384281099, 0.005471888929605484, 0.9603675603866577, 0.013632584363222122, 0.0020520074758678675, 0.012113837525248528, 4.662265382648911e-06],
        [0.02129807136952877, 0.0006461066077463329, 0.12942540645599365, 0.003842339850962162, 0.000718642957508564, 0.14315049350261688, 0.009139570407569408, 0.03911759331822395, 0.47854599356651306, 0.11522843688726425, 0.023550251498818398, 0.029326293617486954, 0.006010579876601696, 2.3803401916211442e-07],
        [0.0007556688506156206, 6.83949256199412e-05, 0.0012973148841410875, 0.0004358450823929161, 9.774936916073784e-05, 0.0048384349793195724, 0.00353313609957695, 0.00345016410574317, 0.5985360145568848, 0.08761582523584366, 0.2846389412879944, 0.011807692237198353, 0.0029242453165352345, 5.721153115700872e-07],
        [1.753839683260594e-06, 8.974393495009281e-07, 1.1603299299167702e-06, 1.311652454205614e-06, 5.441519078885904e-06, 5.614032033918193e-06, 0.00013162224786356091, 7.411200840579113e-06, 0.000822348112706095, 0.00010946377733489498, 0.9988458156585693, 4.974802868673578e-05, 1.698911546554882e-05, 3.8076589703450736e-07],
        [0.09078110009431839, 0.004797599744051695, 0.823334276676178, 0.00035671371733769774, 0.00015342659025918692, 0.03273462504148483, 0.0007502525695599616, 0.008871844969689846, 0.026876816526055336, 0.0038153002969920635, 0.0008078285027295351, 0.006350564770400524, 0.00036965846084058285, 1.3913560792389035e-07],
        [0.15473105013370514, 0.8438209295272827, 0.00028271129122003913, 3.655057753348956e-06, 3.2228654163191095e-06, 0.00020659371512010694, 2.6511515898164362e-05, 0.0005781562067568302, 0.00020910197054035962, 3.4318334655836225e-05, 3.321463373140432e-05, 4.870831980952062e-05, 2.0305074940552004e-05, 1.472792405365908e-06],
        [1.2343838451389644e-10, 6.598104285160389e-10, 9.270336234767917e-12, 7.485163600051692e-10, 5.4517017566979575e-09, 9.076759482917751e-11, 4.8768549198996425e-09, 7.768937160257394e-10, 9.536814393751314e-11, 1.3717442737259944e-09, 7.035541482736107e-09, 3.472595544451451e-10, 7.467613194478417e-09, 1.0],
    ]
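    # Shape check: token_probabilities_12 has one row per token of the first
    # sentence (14 x 17 here) and token_probabilities_21 one row per token of
    # the second (17 x 14); each row is a softmax distribution summing to 1.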
else:
    ## Load the model and tokenizer
    import torch
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel

    @st.cache_resource
    def load_model_and_tokenizer():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModel.from_pretrained('Parallia/Fairly-Multilingual-ModernBERT-Embed-BE', trust_remote_code=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained('Parallia/Fairly-Multilingual-ModernBERT-Embed-BE', trust_remote_code=True)
        return model, tokenizer

    model, tokenizer = load_model_and_tokenizer()
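    # st.cache_resource keeps a single model/tokenizer pair alive across reruns
    # and sessions, so the checkpoint is downloaded and loaded only once per
    # server process rather than on every interaction.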
    @st.cache_data
    def encode_sentences(sent0, sent1):
        sentences = [sent0, sent1]
        tokens = []
        embeddings = []
        for sentence in sentences:
            with torch.no_grad():
                encoded_sentence = tokenizer(sentence, padding=False, truncation=True, return_tensors="pt").to(model.device)
                embedded_sentence = model(**encoded_sentence).last_hidden_state[0].detach().cpu().clone()
            tokens.append(tokenizer.tokenize(sentence))
            embeddings.append(embedded_sentence)
        return tokens, embeddings
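    # Each entry of `embeddings` has shape (sequence_length, hidden_size).
    # tokenizer.tokenize() is assumed here to yield one string per embedded
    # position (i.e. this checkpoint adds no extra special tokens); otherwise
    # the token-level display below would be misaligned.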
    # Encode the sentences
    tokens, embeddings = encode_sentences(sent0, sent1)
    # Calculate the cross-token similarity
    token_similarities = F.normalize(embeddings[0], dim=1) @ F.normalize(embeddings[1], dim=1).T
    # Calculate the overall sentence similarity
    sentence_similarity = F.normalize(torch.mean(embeddings[0], dim=0), dim=-1) @ F.normalize(torch.mean(embeddings[1], dim=0), dim=-1)
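    # Both operands are L2-normalized, so each dot product above is a cosine
    # similarity in [-1, 1]; mean pooling collapses each sentence's token
    # embeddings into a single sentence-level vector before comparing.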
    # Map sentence 1 to sentence 2
    token_probabilities_12 = F.softmax(20 * token_similarities, dim=1)
    # Map sentence 2 to sentence 1
    token_probabilities_21 = F.softmax(20 * token_similarities.T, dim=1)
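    # The factor 20 acts as an inverse softmax temperature: it sharpens each
    # row's distribution so probability mass concentrates on the few
    # best-matching tokens. A hard 1-best alignment would instead be
    # token_similarities.argmax(dim=1) (a sketch only; the soft weights are
    # kept so the CSS below can render graded alignment strengths).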
    # Convert to native Python objects
    sentence_similarity = max(0, round(sentence_similarity.item(), 2))
    token_probabilities_12 = token_probabilities_12.numpy().tolist()
    token_probabilities_21 = token_probabilities_21.numpy().tolist()
# Simplify the tokens for display
tokens = [[token[3:].replace("\u2581", " ").replace("Ġ", " ") for token in sentence] for sentence in tokens]
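# The slice above drops the 3-character "[0]"/"[1]" sentence marker that this
# checkpoint's tokenizer appears to prepend (see the dummy data above), and the
# SentencePiece "▁" / BPE "Ġ" word-boundary markers become plain leading spaces.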
html = ''
html += """
<article>
<div style="color: gray">"""
html += f"""{("✅ Congrats!" if sentence_similarity >= 0.65 else "❌ Sorry!")} These sentences have {100*sentence_similarity:.0f}% similarity."""
html += """
</div>
<div><small style="color: silver; text-wrap: balance;">Hover over a word from either sentence to see which other tokens are mapped to it.</small></div>
<p id="sent0">"""
for token in tokens[0]:
    html += f"""{' ' if token[0] == ' ' else ''}<span>{token[1:] if token[0] == ' ' else token}</span>"""
html += """</p>
<p id="sent1">"""
for token in tokens[1]:
    html += f"""{' ' if token[0] == ' ' else ''}<span>{token[1:] if token[0] == ' ' else token}</span>"""
html += """</p>
</article>"""
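# Everything below implements the hover interaction in pure CSS, without any
# JavaScript: each <span> carries one custom property per token (--p0..., and
# --ip0... for the second sentence), the alignment probabilities are baked into
# per-token rules, and body:has(...:hover) selectors choose which property
# drives --p, the highlight strength of every span.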
html += """
<style>
article {
    font-family: sans-serif;
    text-align: center;
    margin-top: 2em;
}
p {
    margin: 0.5em;
    font-size: 2em;
    text-wrap: balance;
}
span {
    animation-name: rotate_bg;
    animation-duration: 15s;
    animation-iteration-count: infinite;
    text-decoration: underline;
    text-decoration-thickness: 0.3em;
    text-decoration-skip: none;
    text-decoration-skip-ink: none;
    color: rgba(0, 0, 0, calc((50% + 50% * var(--p))));
    text-decoration-color: hsla(161, 100%, 43%, var(--p));
    background-color: hsla(161, 100%, 43%, calc(var(--p) * 0.2));
    --p: var(--p0); """
# The step count must match the number of tokens in the first sentence
# (the original hard-coded steps(14, start) only fit the default example)
html += f"""animation-timing-function: steps({len(tokens[0])}, start); """
for i in range(len(tokens[0])):
    html += f"""--p{i}: 0; """
for j in range(len(tokens[1])):
    html += f"""--ip{j}: 0; """
html += """
}
"""
# Give each token its own indicator variable
for i in range(len(tokens[0])):
    html += f"""
#sent0 span:nth-child({i+1}) {{ --p{i}: 1; }}"""
for j in range(len(tokens[1])):
    html += f"""
#sent1 span:nth-child({j+1}) {{ --ip{j}: 1; }}"""
# Bake the alignment probabilities into per-token CSS rules
for i in range(len(tokens[0])):
    html += f"""
#sent0 span:nth-child({i+1}) {{"""
    for j in range(len(tokens[1])):
        if token_probabilities_21[j][i] < 0.01:
            continue  # skip negligible alignments to keep the CSS small
        html += f"""--ip{j}: {round(token_probabilities_21[j][i], 2)}; """
    html += """}"""
for j in range(len(tokens[1])):
    html += f"""
#sent1 span:nth-child({j+1}) {{"""
    for i in range(len(tokens[0])):
        if token_probabilities_12[i][j] < 0.01:
            continue
        html += f"""--p{i}: {round(token_probabilities_12[i][j], 2)}; """
    html += """}"""
# On hover, pause the animation and drive --p from the hovered token's variables
html += """
body:has(#sent0:hover,#sent1:hover) span { --p: 0 !important; animation-play-state: paused; }"""
for i in range(len(tokens[0])):
    html += f"""
body:has(#sent0 span:nth-child({i+1}):hover) span {{ --p: var(--p{i}) !important; }}"""
for j in range(len(tokens[1])):
    html += f"""
body:has(#sent1 span:nth-child({j+1}):hover) span {{ --p: var(--ip{j}) !important; }}"""
# When idle, cycle through the first sentence's tokens, one per animation step
html += """
@keyframes rotate_bg {"""
for i in range(len(tokens[0])):
    html += f"""
{100*i/len(tokens[0])}% {{ --p: var(--p{i}); }}"""
html += """
}
</style>
"""
st.html(html)