Upload DatasetTransformer
Browse files- README.md +369 -369
- config.json +7 -1
- misc.py +514 -0
- model.py +607 -0
- model.safetensors +2 -2
README.md
CHANGED
@@ -4,12 +4,14 @@ tags:
|
|
4 |
model-index:
|
5 |
- name: cde-small-v1
|
6 |
results:
|
7 |
-
-
|
8 |
-
|
|
|
9 |
name: MTEB AmazonCounterfactualClassification (en)
|
10 |
-
revision: e8379541af4e31359cca9fbcf4b00f2671dba205
|
11 |
-
split: test
|
12 |
type: mteb/amazon_counterfactual
|
|
|
|
|
|
|
13 |
metrics:
|
14 |
- type: accuracy
|
15 |
value: 87.01492537313433
|
@@ -23,14 +25,14 @@ model-index:
|
|
23 |
value: 87.74802754480477
|
24 |
- type: main_score
|
25 |
value: 87.01492537313433
|
26 |
-
|
27 |
type: Classification
|
28 |
-
|
29 |
-
config: default
|
30 |
name: MTEB AmazonPolarityClassification (default)
|
31 |
-
revision: e2d317d38cd51312af73b3d32a06d1a08b442046
|
32 |
-
split: test
|
33 |
type: mteb/amazon_polarity
|
|
|
|
|
|
|
34 |
metrics:
|
35 |
- type: accuracy
|
36 |
value: 94.652275
|
@@ -44,14 +46,14 @@ model-index:
|
|
44 |
value: 94.64655930708355
|
45 |
- type: main_score
|
46 |
value: 94.652275
|
47 |
-
|
48 |
type: Classification
|
49 |
-
|
50 |
-
config: en
|
51 |
name: MTEB AmazonReviewsClassification (en)
|
52 |
-
revision: 1399c76144fd37290681b995c656ef9b2e06e26d
|
53 |
-
split: test
|
54 |
type: mteb/amazon_reviews_multi
|
|
|
|
|
|
|
55 |
metrics:
|
56 |
- type: accuracy
|
57 |
value: 55.75599999999999
|
@@ -61,14 +63,14 @@ model-index:
|
|
61 |
value: 55.07058630829347
|
62 |
- type: main_score
|
63 |
value: 55.75599999999999
|
64 |
-
|
65 |
-
type:
|
66 |
-
|
67 |
-
config: default
|
68 |
name: MTEB ArguAna (default)
|
69 |
-
revision: c22ab2a51041ffd869aaddef7af8d8215647e41a
|
70 |
-
split: test
|
71 |
type: mteb/arguana
|
|
|
|
|
|
|
72 |
metrics:
|
73 |
- type: main_score
|
74 |
value: 69.959
|
@@ -352,14 +354,14 @@ model-index:
|
|
352 |
value: 74.182
|
353 |
- type: recall_at_5
|
354 |
value: 84.495
|
355 |
-
|
356 |
-
type:
|
357 |
-
|
358 |
-
config: default
|
359 |
name: MTEB ArxivClusteringP2P (default)
|
360 |
-
revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
|
361 |
-
split: test
|
362 |
type: mteb/arxiv-clustering-p2p
|
|
|
|
|
|
|
363 |
metrics:
|
364 |
- type: main_score
|
365 |
value: 48.54672141116669
|
@@ -367,14 +369,14 @@ model-index:
|
|
367 |
value: 48.54672141116669
|
368 |
- type: v_measure_std
|
369 |
value: 14.037498386768362
|
370 |
-
|
371 |
type: Clustering
|
372 |
-
|
373 |
-
config: default
|
374 |
name: MTEB ArxivClusteringS2S (default)
|
375 |
-
revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
|
376 |
-
split: test
|
377 |
type: mteb/arxiv-clustering-s2s
|
|
|
|
|
|
|
378 |
metrics:
|
379 |
- type: main_score
|
380 |
value: 40.5914039166466
|
@@ -382,14 +384,14 @@ model-index:
|
|
382 |
value: 40.5914039166466
|
383 |
- type: v_measure_std
|
384 |
value: 14.385069818910331
|
385 |
-
|
386 |
-
type:
|
387 |
-
|
388 |
-
config: default
|
389 |
name: MTEB AskUbuntuDupQuestions (default)
|
390 |
-
revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
|
391 |
-
split: test
|
392 |
type: mteb/askubuntudupquestions-reranking
|
|
|
|
|
|
|
393 |
metrics:
|
394 |
- type: main_score
|
395 |
value: 61.13621260261507
|
@@ -409,14 +411,14 @@ model-index:
|
|
409 |
value: 31.484257486448364
|
410 |
- type: nAUC_mrr_std
|
411 |
value: 21.252659250011632
|
412 |
-
|
413 |
-
type:
|
414 |
-
|
415 |
-
config: default
|
416 |
name: MTEB BIOSSES (default)
|
417 |
-
revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
|
418 |
-
split: test
|
419 |
type: mteb/biosses-sts
|
|
|
|
|
|
|
420 |
metrics:
|
421 |
- type: cosine_pearson
|
422 |
value: 89.07028016646942
|
@@ -436,14 +438,14 @@ model-index:
|
|
436 |
value: 89.07028016646942
|
437 |
- type: spearman
|
438 |
value: 86.69595132967805
|
439 |
-
|
440 |
-
type:
|
441 |
-
|
442 |
-
config: default
|
443 |
name: MTEB Banking77Classification (default)
|
444 |
-
revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
|
445 |
-
split: test
|
446 |
type: mteb/banking77
|
|
|
|
|
|
|
447 |
metrics:
|
448 |
- type: accuracy
|
449 |
value: 88.6038961038961
|
@@ -453,14 +455,14 @@ model-index:
|
|
453 |
value: 88.56824205739822
|
454 |
- type: main_score
|
455 |
value: 88.6038961038961
|
456 |
-
|
457 |
-
type:
|
458 |
-
|
459 |
-
config: default
|
460 |
name: MTEB BiorxivClusteringP2P (default)
|
461 |
-
revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
|
462 |
-
split: test
|
463 |
type: mteb/biorxiv-clustering-p2p
|
|
|
|
|
|
|
464 |
metrics:
|
465 |
- type: main_score
|
466 |
value: 44.77800814327256
|
@@ -468,14 +470,14 @@ model-index:
|
|
468 |
value: 44.77800814327256
|
469 |
- type: v_measure_std
|
470 |
value: 0.6462535527471919
|
471 |
-
|
472 |
type: Clustering
|
473 |
-
|
474 |
-
config: default
|
475 |
name: MTEB BiorxivClusteringS2S (default)
|
476 |
-
revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
|
477 |
-
split: test
|
478 |
type: mteb/biorxiv-clustering-s2s
|
|
|
|
|
|
|
479 |
metrics:
|
480 |
- type: main_score
|
481 |
value: 38.16110272459102
|
@@ -483,14 +485,14 @@ model-index:
|
|
483 |
value: 38.16110272459102
|
484 |
- type: v_measure_std
|
485 |
value: 0.7456916212435019
|
486 |
-
|
487 |
-
type:
|
488 |
-
|
489 |
-
config: default
|
490 |
name: MTEB CQADupstackAndroidRetrieval (default)
|
491 |
-
revision: f46a197baaae43b4f621051089b82a364682dfeb
|
492 |
-
split: test
|
493 |
type: mteb/cqadupstack-android
|
|
|
|
|
|
|
494 |
metrics:
|
495 |
- type: main_score
|
496 |
value: 49.376
|
@@ -774,14 +776,14 @@ model-index:
|
|
774 |
value: 47.591
|
775 |
- type: recall_at_5
|
776 |
value: 54.245
|
777 |
-
|
778 |
type: Retrieval
|
779 |
-
|
780 |
-
config: default
|
781 |
name: MTEB CQADupstackEnglishRetrieval (default)
|
782 |
-
revision: ad9991cb51e31e31e430383c75ffb2885547b5f0
|
783 |
-
split: test
|
784 |
type: mteb/cqadupstack-english
|
|
|
|
|
|
|
785 |
metrics:
|
786 |
- type: main_score
|
787 |
value: 44.727
|
@@ -1065,14 +1067,14 @@ model-index:
|
|
1065 |
value: 42.085
|
1066 |
- type: recall_at_5
|
1067 |
value: 47.5
|
1068 |
-
|
1069 |
type: Retrieval
|
1070 |
-
|
1071 |
-
config: default
|
1072 |
name: MTEB CQADupstackGamingRetrieval (default)
|
1073 |
-
revision: 4885aa143210c98657558c04aaf3dc47cfb54340
|
1074 |
-
split: test
|
1075 |
type: mteb/cqadupstack-gaming
|
|
|
|
|
|
|
1076 |
metrics:
|
1077 |
- type: main_score
|
1078 |
value: 59.001999999999995
|
@@ -1356,14 +1358,14 @@ model-index:
|
|
1356 |
value: 57.916000000000004
|
1357 |
- type: recall_at_5
|
1358 |
value: 65.44
|
1359 |
-
|
1360 |
type: Retrieval
|
1361 |
-
|
1362 |
-
config: default
|
1363 |
name: MTEB CQADupstackGisRetrieval (default)
|
1364 |
-
revision: 5003b3064772da1887988e05400cf3806fe491f2
|
1365 |
-
split: test
|
1366 |
type: mteb/cqadupstack-gis
|
|
|
|
|
|
|
1367 |
metrics:
|
1368 |
- type: main_score
|
1369 |
value: 37.501
|
@@ -1647,14 +1649,14 @@ model-index:
|
|
1647 |
value: 37.218
|
1648 |
- type: recall_at_5
|
1649 |
value: 42.559000000000005
|
1650 |
-
|
1651 |
type: Retrieval
|
1652 |
-
|
1653 |
-
config: default
|
1654 |
name: MTEB CQADupstackMathematicaRetrieval (default)
|
1655 |
-
revision: 90fceea13679c63fe563ded68f3b6f06e50061de
|
1656 |
-
split: test
|
1657 |
type: mteb/cqadupstack-mathematica
|
|
|
|
|
|
|
1658 |
metrics:
|
1659 |
- type: main_score
|
1660 |
value: 27.653
|
@@ -1938,14 +1940,14 @@ model-index:
|
|
1938 |
value: 25.469
|
1939 |
- type: recall_at_5
|
1940 |
value: 31.316
|
1941 |
-
|
1942 |
type: Retrieval
|
1943 |
-
|
1944 |
-
config: default
|
1945 |
name: MTEB CQADupstackPhysicsRetrieval (default)
|
1946 |
-
revision: 79531abbd1fb92d06c6d6315a0cbbbf5bb247ea4
|
1947 |
-
split: test
|
1948 |
type: mteb/cqadupstack-physics
|
|
|
|
|
|
|
1949 |
metrics:
|
1950 |
- type: main_score
|
1951 |
value: 45.314
|
@@ -2229,14 +2231,14 @@ model-index:
|
|
2229 |
value: 43.679
|
2230 |
- type: recall_at_5
|
2231 |
value: 49.735
|
2232 |
-
|
2233 |
type: Retrieval
|
2234 |
-
|
2235 |
-
config: default
|
2236 |
name: MTEB CQADupstackProgrammersRetrieval (default)
|
2237 |
-
revision: 6184bc1440d2dbc7612be22b50686b8826d22b32
|
2238 |
-
split: test
|
2239 |
type: mteb/cqadupstack-programmers
|
|
|
|
|
|
|
2240 |
metrics:
|
2241 |
- type: main_score
|
2242 |
value: 41.972
|
@@ -2520,27 +2522,27 @@ model-index:
|
|
2520 |
value: 39.363
|
2521 |
- type: recall_at_5
|
2522 |
value: 44.665
|
2523 |
-
|
2524 |
type: Retrieval
|
2525 |
-
|
2526 |
-
config: default
|
2527 |
name: MTEB CQADupstackRetrieval (default)
|
2528 |
-
revision: CQADupstackRetrieval_is_a_combined_dataset
|
2529 |
-
split: test
|
2530 |
type: CQADupstackRetrieval_is_a_combined_dataset
|
|
|
|
|
|
|
2531 |
metrics:
|
2532 |
- type: main_score
|
2533 |
value: 39.823499999999996
|
2534 |
- type: ndcg_at_10
|
2535 |
value: 39.823499999999996
|
2536 |
-
|
2537 |
type: Retrieval
|
2538 |
-
|
2539 |
-
config: default
|
2540 |
name: MTEB CQADupstackStatsRetrieval (default)
|
2541 |
-
revision: 65ac3a16b8e91f9cee4c9828cc7c335575432a2a
|
2542 |
-
split: test
|
2543 |
type: mteb/cqadupstack-stats
|
|
|
|
|
|
|
2544 |
metrics:
|
2545 |
- type: main_score
|
2546 |
value: 34.943000000000005
|
@@ -2824,14 +2826,14 @@ model-index:
|
|
2824 |
value: 33.427
|
2825 |
- type: recall_at_5
|
2826 |
value: 37.643
|
2827 |
-
|
2828 |
type: Retrieval
|
2829 |
-
|
2830 |
-
config: default
|
2831 |
name: MTEB CQADupstackTexRetrieval (default)
|
2832 |
-
revision: 46989137a86843e03a6195de44b09deda022eec7
|
2833 |
-
split: test
|
2834 |
type: mteb/cqadupstack-tex
|
|
|
|
|
|
|
2835 |
metrics:
|
2836 |
- type: main_score
|
2837 |
value: 27.271
|
@@ -3115,14 +3117,14 @@ model-index:
|
|
3115 |
value: 25.592
|
3116 |
- type: recall_at_5
|
3117 |
value: 30.279
|
3118 |
-
|
3119 |
type: Retrieval
|
3120 |
-
|
3121 |
-
config: default
|
3122 |
name: MTEB CQADupstackUnixRetrieval (default)
|
3123 |
-
revision: 6c6430d3a6d36f8d2a829195bc5dc94d7e063e53
|
3124 |
-
split: test
|
3125 |
type: mteb/cqadupstack-unix
|
|
|
|
|
|
|
3126 |
metrics:
|
3127 |
- type: main_score
|
3128 |
value: 38.237
|
@@ -3406,14 +3408,14 @@ model-index:
|
|
3406 |
value: 36.275
|
3407 |
- type: recall_at_5
|
3408 |
value: 42.199
|
3409 |
-
|
3410 |
type: Retrieval
|
3411 |
-
|
3412 |
-
config: default
|
3413 |
name: MTEB CQADupstackWebmastersRetrieval (default)
|
3414 |
-
revision: 160c094312a0e1facb97e55eeddb698c0abe3571
|
3415 |
-
split: test
|
3416 |
type: mteb/cqadupstack-webmasters
|
|
|
|
|
|
|
3417 |
metrics:
|
3418 |
- type: main_score
|
3419 |
value: 38.702
|
@@ -3697,14 +3699,14 @@ model-index:
|
|
3697 |
value: 37.634
|
3698 |
- type: recall_at_5
|
3699 |
value: 42.021
|
3700 |
-
|
3701 |
type: Retrieval
|
3702 |
-
|
3703 |
-
config: default
|
3704 |
name: MTEB CQADupstackWordpressRetrieval (default)
|
3705 |
-
revision: 4ffe81d471b1924886b33c7567bfb200e9eec5c4
|
3706 |
-
split: test
|
3707 |
type: mteb/cqadupstack-wordpress
|
|
|
|
|
|
|
3708 |
metrics:
|
3709 |
- type: main_score
|
3710 |
value: 33.184000000000005
|
@@ -3988,14 +3990,14 @@ model-index:
|
|
3988 |
value: 32.683
|
3989 |
- type: recall_at_5
|
3990 |
value: 36.756
|
3991 |
-
|
3992 |
type: Retrieval
|
3993 |
-
|
3994 |
-
config: default
|
3995 |
name: MTEB ClimateFEVER (default)
|
3996 |
-
revision: 47f2ac6acb640fc46020b02a5b59fdda04d39380
|
3997 |
-
split: test
|
3998 |
type: mteb/climate-fever
|
|
|
|
|
|
|
3999 |
metrics:
|
4000 |
- type: main_score
|
4001 |
value: 25.068
|
@@ -4279,14 +4281,14 @@ model-index:
|
|
4279 |
value: 18.312
|
4280 |
- type: recall_at_5
|
4281 |
value: 22.776
|
4282 |
-
|
4283 |
type: Retrieval
|
4284 |
-
|
4285 |
-
config: default
|
4286 |
name: MTEB DBPedia (default)
|
4287 |
-
revision: c0f706b76e590d620bd6618b3ca8efdd34e2d659
|
4288 |
-
split: test
|
4289 |
type: mteb/dbpedia
|
|
|
|
|
|
|
4290 |
metrics:
|
4291 |
- type: main_score
|
4292 |
value: 40.128
|
@@ -4570,14 +4572,14 @@ model-index:
|
|
4570 |
value: 14.562
|
4571 |
- type: recall_at_5
|
4572 |
value: 18.779
|
4573 |
-
|
4574 |
-
type:
|
4575 |
-
|
4576 |
-
config: default
|
4577 |
name: MTEB EmotionClassification (default)
|
4578 |
-
revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
|
4579 |
-
split: test
|
4580 |
type: mteb/emotion
|
|
|
|
|
|
|
4581 |
metrics:
|
4582 |
- type: accuracy
|
4583 |
value: 74.86
|
@@ -4587,14 +4589,14 @@ model-index:
|
|
4587 |
value: 75.96499621761998
|
4588 |
- type: main_score
|
4589 |
value: 74.86
|
4590 |
-
|
4591 |
-
type:
|
4592 |
-
|
4593 |
-
config: default
|
4594 |
name: MTEB FEVER (default)
|
4595 |
-
revision: bea83ef9e8fb933d90a2f1d5515737465d613e12
|
4596 |
-
split: test
|
4597 |
type: mteb/fever
|
|
|
|
|
|
|
4598 |
metrics:
|
4599 |
- type: main_score
|
4600 |
value: 86.029
|
@@ -4878,14 +4880,14 @@ model-index:
|
|
4878 |
value: 88.382
|
4879 |
- type: recall_at_5
|
4880 |
value: 90.908
|
4881 |
-
|
4882 |
type: Retrieval
|
4883 |
-
|
4884 |
-
config: default
|
4885 |
name: MTEB FiQA2018 (default)
|
4886 |
-
revision: 27a168819829fe9bcd655c2df245fb19452e8e06
|
4887 |
-
split: test
|
4888 |
type: mteb/fiqa
|
|
|
|
|
|
|
4889 |
metrics:
|
4890 |
- type: main_score
|
4891 |
value: 45.238
|
@@ -5169,14 +5171,14 @@ model-index:
|
|
5169 |
value: 37.656
|
5170 |
- type: recall_at_5
|
5171 |
value: 44.766
|
5172 |
-
|
5173 |
type: Retrieval
|
5174 |
-
|
5175 |
-
config: default
|
5176 |
name: MTEB HotpotQA (default)
|
5177 |
-
revision: ab518f4d6fcca38d87c25209f94beba119d02014
|
5178 |
-
split: test
|
5179 |
type: mteb/hotpotqa
|
|
|
|
|
|
|
5180 |
metrics:
|
5181 |
- type: main_score
|
5182 |
value: 66.672
|
@@ -5460,14 +5462,14 @@ model-index:
|
|
5460 |
value: 57.522
|
5461 |
- type: recall_at_5
|
5462 |
value: 62.134
|
5463 |
-
|
5464 |
-
type:
|
5465 |
-
|
5466 |
-
config: default
|
5467 |
name: MTEB ImdbClassification (default)
|
5468 |
-
revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
|
5469 |
-
split: test
|
5470 |
type: mteb/imdb
|
|
|
|
|
|
|
5471 |
metrics:
|
5472 |
- type: accuracy
|
5473 |
value: 93.5944
|
@@ -5481,14 +5483,14 @@ model-index:
|
|
5481 |
value: 93.58945949328377
|
5482 |
- type: main_score
|
5483 |
value: 93.5944
|
5484 |
-
|
5485 |
-
type:
|
5486 |
-
|
5487 |
-
config: default
|
5488 |
name: MTEB MSMARCO (default)
|
5489 |
-
revision: c5a29a104738b98a9e76336939199e264163d4a0
|
5490 |
-
split: dev
|
5491 |
type: mteb/msmarco
|
|
|
|
|
|
|
5492 |
metrics:
|
5493 |
- type: main_score
|
5494 |
value: 41.448
|
@@ -5772,14 +5774,14 @@ model-index:
|
|
5772 |
value: 41.304
|
5773 |
- type: recall_at_5
|
5774 |
value: 51.076
|
5775 |
-
|
5776 |
-
type:
|
5777 |
-
|
5778 |
-
config: en
|
5779 |
name: MTEB MTOPDomainClassification (en)
|
5780 |
-
revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
|
5781 |
-
split: test
|
5782 |
type: mteb/mtop_domain
|
|
|
|
|
|
|
5783 |
metrics:
|
5784 |
- type: accuracy
|
5785 |
value: 96.03967168262655
|
@@ -5789,14 +5791,14 @@ model-index:
|
|
5789 |
value: 96.06623245823347
|
5790 |
- type: main_score
|
5791 |
value: 96.03967168262655
|
5792 |
-
|
5793 |
type: Classification
|
5794 |
-
|
5795 |
-
config: en
|
5796 |
name: MTEB MTOPIntentClassification (en)
|
5797 |
-
revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
|
5798 |
-
split: test
|
5799 |
type: mteb/mtop_intent
|
|
|
|
|
|
|
5800 |
metrics:
|
5801 |
- type: accuracy
|
5802 |
value: 89.12904696762428
|
@@ -5806,14 +5808,14 @@ model-index:
|
|
5806 |
value: 90.41290566743324
|
5807 |
- type: main_score
|
5808 |
value: 89.12904696762428
|
5809 |
-
|
5810 |
type: Classification
|
5811 |
-
|
5812 |
-
config: en
|
5813 |
name: MTEB MassiveIntentClassification (en)
|
5814 |
-
revision: 4672e20407010da34463acc759c162ca9734bca6
|
5815 |
-
split: test
|
5816 |
type: mteb/amazon_massive_intent
|
|
|
|
|
|
|
5817 |
metrics:
|
5818 |
- type: accuracy
|
5819 |
value: 76.49630127774041
|
@@ -5823,14 +5825,14 @@ model-index:
|
|
5823 |
value: 76.42436195016484
|
5824 |
- type: main_score
|
5825 |
value: 76.49630127774041
|
5826 |
-
|
5827 |
type: Classification
|
5828 |
-
|
5829 |
-
config: en
|
5830 |
name: MTEB MassiveScenarioClassification (en)
|
5831 |
-
revision: fad2c6e8459f9e1c45d9315f4953d921437d70f8
|
5832 |
-
split: test
|
5833 |
type: mteb/amazon_massive_scenario
|
|
|
|
|
|
|
5834 |
metrics:
|
5835 |
- type: accuracy
|
5836 |
value: 78.9340954942838
|
@@ -5840,14 +5842,14 @@ model-index:
|
|
5840 |
value: 78.87787647838971
|
5841 |
- type: main_score
|
5842 |
value: 78.9340954942838
|
5843 |
-
|
5844 |
-
type:
|
5845 |
-
|
5846 |
-
config: default
|
5847 |
name: MTEB MedrxivClusteringP2P (default)
|
5848 |
-
revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
|
5849 |
-
split: test
|
5850 |
type: mteb/medrxiv-clustering-p2p
|
|
|
|
|
|
|
5851 |
metrics:
|
5852 |
- type: main_score
|
5853 |
value: 37.50182848656019
|
@@ -5855,14 +5857,14 @@ model-index:
|
|
5855 |
value: 37.50182848656019
|
5856 |
- type: v_measure_std
|
5857 |
value: 1.1708518023877268
|
5858 |
-
|
5859 |
type: Clustering
|
5860 |
-
|
5861 |
-
config: default
|
5862 |
name: MTEB MedrxivClusteringS2S (default)
|
5863 |
-
revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
|
5864 |
-
split: test
|
5865 |
type: mteb/medrxiv-clustering-s2s
|
|
|
|
|
|
|
5866 |
metrics:
|
5867 |
- type: main_score
|
5868 |
value: 35.72762609825363
|
@@ -5870,14 +5872,14 @@ model-index:
|
|
5870 |
value: 35.72762609825363
|
5871 |
- type: v_measure_std
|
5872 |
value: 1.4555014772914985
|
5873 |
-
|
5874 |
-
type:
|
5875 |
-
|
5876 |
-
config: default
|
5877 |
name: MTEB MindSmallReranking (default)
|
5878 |
-
revision: 59042f120c80e8afa9cdbb224f67076cec0fc9a7
|
5879 |
-
split: test
|
5880 |
type: mteb/mind_small
|
|
|
|
|
|
|
5881 |
metrics:
|
5882 |
- type: main_score
|
5883 |
value: 30.47716416454022
|
@@ -5897,14 +5899,14 @@ model-index:
|
|
5897 |
value: -15.78941850629242
|
5898 |
- type: nAUC_mrr_std
|
5899 |
value: -1.1330442292510805
|
5900 |
-
|
5901 |
-
type:
|
5902 |
-
|
5903 |
-
config: default
|
5904 |
name: MTEB NFCorpus (default)
|
5905 |
-
revision: ec0fa4fe99da2ff19ca1214b7966684033a58814
|
5906 |
-
split: test
|
5907 |
type: mteb/nfcorpus
|
|
|
|
|
|
|
5908 |
metrics:
|
5909 |
- type: main_score
|
5910 |
value: 34.648
|
@@ -6188,14 +6190,14 @@ model-index:
|
|
6188 |
value: 10.037
|
6189 |
- type: recall_at_5
|
6190 |
value: 12.717999999999998
|
6191 |
-
|
6192 |
type: Retrieval
|
6193 |
-
|
6194 |
-
config: default
|
6195 |
name: MTEB NQ (default)
|
6196 |
-
revision: b774495ed302d8c44a3a7ea25c90dbce03968f31
|
6197 |
-
split: test
|
6198 |
type: mteb/nq
|
|
|
|
|
|
|
6199 |
metrics:
|
6200 |
- type: main_score
|
6201 |
value: 60.06
|
@@ -6479,14 +6481,14 @@ model-index:
|
|
6479 |
value: 61.114000000000004
|
6480 |
- type: recall_at_5
|
6481 |
value: 69.812
|
6482 |
-
|
6483 |
type: Retrieval
|
6484 |
-
|
6485 |
-
config: default
|
6486 |
name: MTEB QuoraRetrieval (default)
|
6487 |
-
revision: e4e08e0b7dbe3c8700f0daef558ff32256715259
|
6488 |
-
split: test
|
6489 |
type: mteb/quora
|
|
|
|
|
|
|
6490 |
metrics:
|
6491 |
- type: main_score
|
6492 |
value: 89.821
|
@@ -6770,14 +6772,14 @@ model-index:
|
|
6770 |
value: 88.714
|
6771 |
- type: recall_at_5
|
6772 |
value: 92.96799999999999
|
6773 |
-
|
6774 |
-
type:
|
6775 |
-
|
6776 |
-
config: default
|
6777 |
name: MTEB RedditClustering (default)
|
6778 |
-
revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
|
6779 |
-
split: test
|
6780 |
type: mteb/reddit-clustering
|
|
|
|
|
|
|
6781 |
metrics:
|
6782 |
- type: main_score
|
6783 |
value: 59.36038828851887
|
@@ -6785,14 +6787,14 @@ model-index:
|
|
6785 |
value: 59.36038828851887
|
6786 |
- type: v_measure_std
|
6787 |
value: 4.1958765965154425
|
6788 |
-
|
6789 |
type: Clustering
|
6790 |
-
|
6791 |
-
config: default
|
6792 |
name: MTEB RedditClusteringP2P (default)
|
6793 |
-
revision: 385e3cb46b4cfa89021f56c4380204149d0efe33
|
6794 |
-
split: test
|
6795 |
type: mteb/reddit-clustering-p2p
|
|
|
|
|
|
|
6796 |
metrics:
|
6797 |
- type: main_score
|
6798 |
value: 64.67522832408089
|
@@ -6800,14 +6802,14 @@ model-index:
|
|
6800 |
value: 64.67522832408089
|
6801 |
- type: v_measure_std
|
6802 |
value: 12.473765016158698
|
6803 |
-
|
6804 |
-
type:
|
6805 |
-
|
6806 |
-
config: default
|
6807 |
name: MTEB SCIDOCS (default)
|
6808 |
-
revision: f8c2fcf00f625baaa80f62ec5bd9e1fff3b8ae88
|
6809 |
-
split: test
|
6810 |
type: mteb/scidocs
|
|
|
|
|
|
|
6811 |
metrics:
|
6812 |
- type: main_score
|
6813 |
value: 21.751
|
@@ -7091,14 +7093,14 @@ model-index:
|
|
7091 |
value: 11.648
|
7092 |
- type: recall_at_5
|
7093 |
value: 15.883
|
7094 |
-
|
7095 |
-
type:
|
7096 |
-
|
7097 |
-
config: default
|
7098 |
name: MTEB SICK-R (default)
|
7099 |
-
revision: 20a6d6f312dd54037fe07a32d58e5e168867909d
|
7100 |
-
split: test
|
7101 |
type: mteb/sickr-sts
|
|
|
|
|
|
|
7102 |
metrics:
|
7103 |
- type: cosine_pearson
|
7104 |
value: 84.0161170579997
|
@@ -7118,14 +7120,14 @@ model-index:
|
|
7118 |
value: 84.0161170579997
|
7119 |
- type: spearman
|
7120 |
value: 77.52025923874551
|
7121 |
-
|
7122 |
type: STS
|
7123 |
-
|
7124 |
-
config: default
|
7125 |
name: MTEB STS12 (default)
|
7126 |
-
revision: a0d554a64d88156834ff5ae9920b964011b16384
|
7127 |
-
split: test
|
7128 |
type: mteb/sts12-sts
|
|
|
|
|
|
|
7129 |
metrics:
|
7130 |
- type: cosine_pearson
|
7131 |
value: 81.32328780209225
|
@@ -7145,14 +7147,14 @@ model-index:
|
|
7145 |
value: 81.32328780209225
|
7146 |
- type: spearman
|
7147 |
value: 74.17570679745272
|
7148 |
-
|
7149 |
type: STS
|
7150 |
-
|
7151 |
-
config: default
|
7152 |
name: MTEB STS13 (default)
|
7153 |
-
revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
|
7154 |
-
split: test
|
7155 |
type: mteb/sts13-sts
|
|
|
|
|
|
|
7156 |
metrics:
|
7157 |
- type: cosine_pearson
|
7158 |
value: 85.53224141249392
|
@@ -7172,14 +7174,14 @@ model-index:
|
|
7172 |
value: 85.53224141249392
|
7173 |
- type: spearman
|
7174 |
value: 86.16981525069227
|
7175 |
-
|
7176 |
type: STS
|
7177 |
-
|
7178 |
-
config: default
|
7179 |
name: MTEB STS14 (default)
|
7180 |
-
revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
|
7181 |
-
split: test
|
7182 |
type: mteb/sts14-sts
|
|
|
|
|
|
|
7183 |
metrics:
|
7184 |
- type: cosine_pearson
|
7185 |
value: 82.234064045301
|
@@ -7199,14 +7201,14 @@ model-index:
|
|
7199 |
value: 82.234064045301
|
7200 |
- type: spearman
|
7201 |
value: 78.86920830792957
|
7202 |
-
|
7203 |
type: STS
|
7204 |
-
|
7205 |
-
config: default
|
7206 |
name: MTEB STS15 (default)
|
7207 |
-
revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
|
7208 |
-
split: test
|
7209 |
type: mteb/sts15-sts
|
|
|
|
|
|
|
7210 |
metrics:
|
7211 |
- type: cosine_pearson
|
7212 |
value: 86.23114543080261
|
@@ -7226,14 +7228,14 @@ model-index:
|
|
7226 |
value: 86.23114543080261
|
7227 |
- type: spearman
|
7228 |
value: 87.481042891123
|
7229 |
-
|
7230 |
type: STS
|
7231 |
-
|
7232 |
-
config: default
|
7233 |
name: MTEB STS16 (default)
|
7234 |
-
revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
|
7235 |
-
split: test
|
7236 |
type: mteb/sts16-sts
|
|
|
|
|
|
|
7237 |
metrics:
|
7238 |
- type: cosine_pearson
|
7239 |
value: 82.9156629047782
|
@@ -7253,14 +7255,14 @@ model-index:
|
|
7253 |
value: 82.9156629047782
|
7254 |
- type: spearman
|
7255 |
value: 84.28381329207937
|
7256 |
-
|
7257 |
type: STS
|
7258 |
-
|
7259 |
-
config: en-en
|
7260 |
name: MTEB STS17 (en-en)
|
7261 |
-
revision: faeb762787bd10488a50c8b5be4a3b82e411949c
|
7262 |
-
split: test
|
7263 |
type: mteb/sts17-crosslingual-sts
|
|
|
|
|
|
|
7264 |
metrics:
|
7265 |
- type: cosine_pearson
|
7266 |
value: 88.91985349746744
|
@@ -7280,14 +7282,14 @@ model-index:
|
|
7280 |
value: 88.91985349746744
|
7281 |
- type: spearman
|
7282 |
value: 89.69151633966257
|
7283 |
-
|
7284 |
type: STS
|
7285 |
-
|
7286 |
-
config: en
|
7287 |
name: MTEB STS22 (en)
|
7288 |
-
revision: de9d86b3b84231dc21f76c7b7af1f28e2f57f6e3
|
7289 |
-
split: test
|
7290 |
type: mteb/sts22-crosslingual-sts
|
|
|
|
|
|
|
7291 |
metrics:
|
7292 |
- type: cosine_pearson
|
7293 |
value: 65.0979772547511
|
@@ -7307,14 +7309,14 @@ model-index:
|
|
7307 |
value: 65.0979772547511
|
7308 |
- type: spearman
|
7309 |
value: 65.78126527764236
|
7310 |
-
|
7311 |
type: STS
|
7312 |
-
|
7313 |
-
config: default
|
7314 |
name: MTEB STSBenchmark (default)
|
7315 |
-
revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
|
7316 |
-
split: test
|
7317 |
type: mteb/stsbenchmark-sts
|
|
|
|
|
|
|
7318 |
metrics:
|
7319 |
- type: cosine_pearson
|
7320 |
value: 85.6426635049971
|
@@ -7334,14 +7336,14 @@ model-index:
|
|
7334 |
value: 85.6426635049971
|
7335 |
- type: spearman
|
7336 |
value: 85.609856578385
|
7337 |
-
|
7338 |
-
type:
|
7339 |
-
|
7340 |
-
config: default
|
7341 |
name: MTEB SciDocsRR (default)
|
7342 |
-
revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
|
7343 |
-
split: test
|
7344 |
type: mteb/scidocs-reranking
|
|
|
|
|
|
|
7345 |
metrics:
|
7346 |
- type: main_score
|
7347 |
value: 82.85163332499799
|
@@ -7361,14 +7363,14 @@ model-index:
|
|
7361 |
value: 89.47202967481866
|
7362 |
- type: nAUC_mrr_std
|
7363 |
value: 85.40446996933892
|
7364 |
-
|
7365 |
-
type:
|
7366 |
-
|
7367 |
-
config: default
|
7368 |
name: MTEB SciFact (default)
|
7369 |
-
revision: 0228b52cf27578f30900b9e5271d331663a030d7
|
7370 |
-
split: test
|
7371 |
type: mteb/scifact
|
|
|
|
|
|
|
7372 |
metrics:
|
7373 |
- type: main_score
|
7374 |
value: 71.655
|
@@ -7652,14 +7654,14 @@ model-index:
|
|
7652 |
value: 71.61699999999999
|
7653 |
- type: recall_at_5
|
7654 |
value: 78.361
|
7655 |
-
|
7656 |
-
type:
|
7657 |
-
|
7658 |
-
config: default
|
7659 |
name: MTEB SprintDuplicateQuestions (default)
|
7660 |
-
revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
|
7661 |
-
split: test
|
7662 |
type: mteb/sprintduplicatequestions-pairclassification
|
|
|
|
|
|
|
7663 |
metrics:
|
7664 |
- type: cosine_accuracy
|
7665 |
value: 99.8019801980198
|
@@ -7743,14 +7745,14 @@ model-index:
|
|
7743 |
value: 90.79754601226993
|
7744 |
- type: similarity_recall
|
7745 |
value: 88.8
|
7746 |
-
|
7747 |
-
type:
|
7748 |
-
|
7749 |
-
config: default
|
7750 |
name: MTEB StackExchangeClustering (default)
|
7751 |
-
revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
|
7752 |
-
split: test
|
7753 |
type: mteb/stackexchange-clustering
|
|
|
|
|
|
|
7754 |
metrics:
|
7755 |
- type: main_score
|
7756 |
value: 66.63931197758824
|
@@ -7758,14 +7760,14 @@ model-index:
|
|
7758 |
value: 66.63931197758824
|
7759 |
- type: v_measure_std
|
7760 |
value: 3.896206781511776
|
7761 |
-
|
7762 |
type: Clustering
|
7763 |
-
|
7764 |
-
config: default
|
7765 |
name: MTEB StackExchangeClusteringP2P (default)
|
7766 |
-
revision: 815ca46b2622cec33ccafc3735d572c266efdb44
|
7767 |
-
split: test
|
7768 |
type: mteb/stackexchange-clustering-p2p
|
|
|
|
|
|
|
7769 |
metrics:
|
7770 |
- type: main_score
|
7771 |
value: 38.984892653301884
|
@@ -7773,14 +7775,14 @@ model-index:
|
|
7773 |
value: 38.984892653301884
|
7774 |
- type: v_measure_std
|
7775 |
value: 1.3308552162270453
|
7776 |
-
|
7777 |
-
type:
|
7778 |
-
|
7779 |
-
config: default
|
7780 |
name: MTEB StackOverflowDupQuestions (default)
|
7781 |
-
revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
|
7782 |
-
split: test
|
7783 |
type: mteb/stackoverflowdupquestions-reranking
|
|
|
|
|
|
|
7784 |
metrics:
|
7785 |
- type: main_score
|
7786 |
value: 52.71499643455044
|
@@ -7800,14 +7802,14 @@ model-index:
|
|
7800 |
value: 13.931448578334379
|
7801 |
- type: nAUC_mrr_std
|
7802 |
value: 10.441860004959661
|
7803 |
-
|
7804 |
-
type:
|
7805 |
-
|
7806 |
-
config: default
|
7807 |
name: MTEB SummEval (default)
|
7808 |
-
revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
|
7809 |
-
split: test
|
7810 |
type: mteb/summeval
|
|
|
|
|
|
|
7811 |
metrics:
|
7812 |
- type: cosine_pearson
|
7813 |
value: 31.5167525286909
|
@@ -7823,14 +7825,14 @@ model-index:
|
|
7823 |
value: 31.5167525286909
|
7824 |
- type: spearman
|
7825 |
value: 31.218862970706496
|
7826 |
-
|
7827 |
-
type:
|
7828 |
-
|
7829 |
-
config: default
|
7830 |
name: MTEB TRECCOVID (default)
|
7831 |
-
revision: bb9466bac8153a0349341eb1b22e06409e78ef4e
|
7832 |
-
split: test
|
7833 |
type: mteb/trec-covid
|
|
|
|
|
|
|
7834 |
metrics:
|
7835 |
- type: main_score
|
7836 |
value: 78.996
|
@@ -8114,14 +8116,14 @@ model-index:
|
|
8114 |
value: 0.705
|
8115 |
- type: recall_at_5
|
8116 |
value: 1.162
|
8117 |
-
|
8118 |
type: Retrieval
|
8119 |
-
|
8120 |
-
config: default
|
8121 |
name: MTEB Touche2020 (default)
|
8122 |
-
revision: a34f9a33db75fa0cbb21bb5cfc3dae8dc8bec93f
|
8123 |
-
split: test
|
8124 |
type: mteb/touche2020
|
|
|
|
|
|
|
8125 |
metrics:
|
8126 |
- type: main_score
|
8127 |
value: 24.234
|
@@ -8405,14 +8407,14 @@ model-index:
|
|
8405 |
value: 6.625
|
8406 |
- type: recall_at_5
|
8407 |
value: 9.094
|
8408 |
-
|
8409 |
-
type:
|
8410 |
-
|
8411 |
-
config: default
|
8412 |
name: MTEB ToxicConversationsClassification (default)
|
8413 |
-
revision: edfaf9da55d3dd50d43143d90c1ac476895ae6de
|
8414 |
-
split: test
|
8415 |
type: mteb/toxic_conversations_50k
|
|
|
|
|
|
|
8416 |
metrics:
|
8417 |
- type: accuracy
|
8418 |
value: 72.822265625
|
@@ -8426,14 +8428,14 @@ model-index:
|
|
8426 |
value: 78.7454393727821
|
8427 |
- type: main_score
|
8428 |
value: 72.822265625
|
8429 |
-
|
8430 |
type: Classification
|
8431 |
-
|
8432 |
-
config: default
|
8433 |
name: MTEB TweetSentimentExtractionClassification (default)
|
8434 |
-
revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
|
8435 |
-
split: test
|
8436 |
type: mteb/tweet_sentiment_extraction
|
|
|
|
|
|
|
8437 |
metrics:
|
8438 |
- type: accuracy
|
8439 |
value: 72.54385964912281
|
@@ -8443,14 +8445,14 @@ model-index:
|
|
8443 |
value: 72.18022450339639
|
8444 |
- type: main_score
|
8445 |
value: 72.54385964912281
|
8446 |
-
|
8447 |
-
type:
|
8448 |
-
|
8449 |
-
config: default
|
8450 |
name: MTEB TwentyNewsgroupsClustering (default)
|
8451 |
-
revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
|
8452 |
-
split: test
|
8453 |
type: mteb/twentynewsgroups-clustering
|
|
|
|
|
|
|
8454 |
metrics:
|
8455 |
- type: main_score
|
8456 |
value: 57.41861450414374
|
@@ -8458,14 +8460,14 @@ model-index:
|
|
8458 |
value: 57.41861450414374
|
8459 |
- type: v_measure_std
|
8460 |
value: 1.1732394227153524
|
8461 |
-
|
8462 |
-
type:
|
8463 |
-
|
8464 |
-
config: default
|
8465 |
name: MTEB TwitterSemEval2015 (default)
|
8466 |
-
revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
|
8467 |
-
split: test
|
8468 |
type: mteb/twittersemeval2015-pairclassification
|
|
|
|
|
|
|
8469 |
metrics:
|
8470 |
- type: cosine_accuracy
|
8471 |
value: 85.65893783155511
|
@@ -8549,14 +8551,14 @@ model-index:
|
|
8549 |
value: 64.0855106888361
|
8550 |
- type: similarity_recall
|
8551 |
value: 71.18733509234828
|
8552 |
-
|
8553 |
type: PairClassification
|
8554 |
-
|
8555 |
-
config: default
|
8556 |
name: MTEB TwitterURLCorpus (default)
|
8557 |
-
revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
|
8558 |
-
split: test
|
8559 |
type: mteb/twitterurlcorpus-pairclassification
|
|
|
|
|
|
|
8560 |
metrics:
|
8561 |
- type: cosine_accuracy
|
8562 |
value: 88.86754375751931
|
@@ -8640,8 +8642,6 @@ model-index:
|
|
8640 |
value: 74.19310344827586
|
8641 |
- type: similarity_recall
|
8642 |
value: 82.83030489682784
|
8643 |
-
task:
|
8644 |
-
type: PairClassification
|
8645 |
---
|
8646 |
# Contextual Document Embeddings (CDE)
|
8647 |
|
|
|
4 |
model-index:
|
5 |
- name: cde-small-v1
|
6 |
results:
|
7 |
+
- task:
|
8 |
+
type: Classification
|
9 |
+
dataset:
|
10 |
name: MTEB AmazonCounterfactualClassification (en)
|
|
|
|
|
11 |
type: mteb/amazon_counterfactual
|
12 |
+
config: en
|
13 |
+
split: test
|
14 |
+
revision: e8379541af4e31359cca9fbcf4b00f2671dba205
|
15 |
metrics:
|
16 |
- type: accuracy
|
17 |
value: 87.01492537313433
|
|
|
25 |
value: 87.74802754480477
|
26 |
- type: main_score
|
27 |
value: 87.01492537313433
|
28 |
+
- task:
|
29 |
type: Classification
|
30 |
+
dataset:
|
|
|
31 |
name: MTEB AmazonPolarityClassification (default)
|
|
|
|
|
32 |
type: mteb/amazon_polarity
|
33 |
+
config: default
|
34 |
+
split: test
|
35 |
+
revision: e2d317d38cd51312af73b3d32a06d1a08b442046
|
36 |
metrics:
|
37 |
- type: accuracy
|
38 |
value: 94.652275
|
|
|
46 |
value: 94.64655930708355
|
47 |
- type: main_score
|
48 |
value: 94.652275
|
49 |
+
- task:
|
50 |
type: Classification
|
51 |
+
dataset:
|
|
|
52 |
name: MTEB AmazonReviewsClassification (en)
|
|
|
|
|
53 |
type: mteb/amazon_reviews_multi
|
54 |
+
config: en
|
55 |
+
split: test
|
56 |
+
revision: 1399c76144fd37290681b995c656ef9b2e06e26d
|
57 |
metrics:
|
58 |
- type: accuracy
|
59 |
value: 55.75599999999999
|
|
|
63 |
value: 55.07058630829347
|
64 |
- type: main_score
|
65 |
value: 55.75599999999999
|
66 |
+
- task:
|
67 |
+
type: Retrieval
|
68 |
+
dataset:
|
|
|
69 |
name: MTEB ArguAna (default)
|
|
|
|
|
70 |
type: mteb/arguana
|
71 |
+
config: default
|
72 |
+
split: test
|
73 |
+
revision: c22ab2a51041ffd869aaddef7af8d8215647e41a
|
74 |
metrics:
|
75 |
- type: main_score
|
76 |
value: 69.959
|
|
|
354 |
value: 74.182
|
355 |
- type: recall_at_5
|
356 |
value: 84.495
|
357 |
+
- task:
|
358 |
+
type: Clustering
|
359 |
+
dataset:
|
|
|
360 |
name: MTEB ArxivClusteringP2P (default)
|
|
|
|
|
361 |
type: mteb/arxiv-clustering-p2p
|
362 |
+
config: default
|
363 |
+
split: test
|
364 |
+
revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
|
365 |
metrics:
|
366 |
- type: main_score
|
367 |
value: 48.54672141116669
|
|
|
369 |
value: 48.54672141116669
|
370 |
- type: v_measure_std
|
371 |
value: 14.037498386768362
|
372 |
+
- task:
|
373 |
type: Clustering
|
374 |
+
dataset:
|
|
|
375 |
name: MTEB ArxivClusteringS2S (default)
|
|
|
|
|
376 |
type: mteb/arxiv-clustering-s2s
|
377 |
+
config: default
|
378 |
+
split: test
|
379 |
+
revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
|
380 |
metrics:
|
381 |
- type: main_score
|
382 |
value: 40.5914039166466
|
|
|
384 |
value: 40.5914039166466
|
385 |
- type: v_measure_std
|
386 |
value: 14.385069818910331
|
387 |
+
- task:
|
388 |
+
type: Reranking
|
389 |
+
dataset:
|
|
|
390 |
name: MTEB AskUbuntuDupQuestions (default)
|
|
|
|
|
391 |
type: mteb/askubuntudupquestions-reranking
|
392 |
+
config: default
|
393 |
+
split: test
|
394 |
+
revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
|
395 |
metrics:
|
396 |
- type: main_score
|
397 |
value: 61.13621260261507
|
|
|
411 |
value: 31.484257486448364
|
412 |
- type: nAUC_mrr_std
|
413 |
value: 21.252659250011632
|
414 |
+
- task:
|
415 |
+
type: STS
|
416 |
+
dataset:
|
|
|
417 |
name: MTEB BIOSSES (default)
|
|
|
|
|
418 |
type: mteb/biosses-sts
|
419 |
+
config: default
|
420 |
+
split: test
|
421 |
+
revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
|
422 |
metrics:
|
423 |
- type: cosine_pearson
|
424 |
value: 89.07028016646942
|
|
|
438 |
value: 89.07028016646942
|
439 |
- type: spearman
|
440 |
value: 86.69595132967805
|
441 |
+
- task:
|
442 |
+
type: Classification
|
443 |
+
dataset:
|
|
|
444 |
name: MTEB Banking77Classification (default)
|
|
|
|
|
445 |
type: mteb/banking77
|
446 |
+
config: default
|
447 |
+
split: test
|
448 |
+
revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
|
449 |
metrics:
|
450 |
- type: accuracy
|
451 |
value: 88.6038961038961
|
|
|
455 |
value: 88.56824205739822
|
456 |
- type: main_score
|
457 |
value: 88.6038961038961
|
458 |
+
- task:
|
459 |
+
type: Clustering
|
460 |
+
dataset:
|
|
|
461 |
name: MTEB BiorxivClusteringP2P (default)
|
|
|
|
|
462 |
type: mteb/biorxiv-clustering-p2p
|
463 |
+
config: default
|
464 |
+
split: test
|
465 |
+
revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
|
466 |
metrics:
|
467 |
- type: main_score
|
468 |
value: 44.77800814327256
|
|
|
470 |
value: 44.77800814327256
|
471 |
- type: v_measure_std
|
472 |
value: 0.6462535527471919
|
473 |
+
- task:
|
474 |
type: Clustering
|
475 |
+
dataset:
|
|
|
476 |
name: MTEB BiorxivClusteringS2S (default)
|
|
|
|
|
477 |
type: mteb/biorxiv-clustering-s2s
|
478 |
+
config: default
|
479 |
+
split: test
|
480 |
+
revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
|
481 |
metrics:
|
482 |
- type: main_score
|
483 |
value: 38.16110272459102
|
|
|
485 |
value: 38.16110272459102
|
486 |
- type: v_measure_std
|
487 |
value: 0.7456916212435019
|
488 |
+
- task:
|
489 |
+
type: Retrieval
|
490 |
+
dataset:
|
|
|
491 |
name: MTEB CQADupstackAndroidRetrieval (default)
|
|
|
|
|
492 |
type: mteb/cqadupstack-android
|
493 |
+
config: default
|
494 |
+
split: test
|
495 |
+
revision: f46a197baaae43b4f621051089b82a364682dfeb
|
496 |
metrics:
|
497 |
- type: main_score
|
498 |
value: 49.376
|
|
|
776 |
value: 47.591
|
777 |
- type: recall_at_5
|
778 |
value: 54.245
|
779 |
+
- task:
|
780 |
type: Retrieval
|
781 |
+
dataset:
|
|
|
782 |
name: MTEB CQADupstackEnglishRetrieval (default)
|
|
|
|
|
783 |
type: mteb/cqadupstack-english
|
784 |
+
config: default
|
785 |
+
split: test
|
786 |
+
revision: ad9991cb51e31e31e430383c75ffb2885547b5f0
|
787 |
metrics:
|
788 |
- type: main_score
|
789 |
value: 44.727
|
|
|
1067 |
value: 42.085
|
1068 |
- type: recall_at_5
|
1069 |
value: 47.5
|
1070 |
+
- task:
|
1071 |
type: Retrieval
|
1072 |
+
dataset:
|
|
|
1073 |
name: MTEB CQADupstackGamingRetrieval (default)
|
|
|
|
|
1074 |
type: mteb/cqadupstack-gaming
|
1075 |
+
config: default
|
1076 |
+
split: test
|
1077 |
+
revision: 4885aa143210c98657558c04aaf3dc47cfb54340
|
1078 |
metrics:
|
1079 |
- type: main_score
|
1080 |
value: 59.001999999999995
|
|
|
1358 |
value: 57.916000000000004
|
1359 |
- type: recall_at_5
|
1360 |
value: 65.44
|
1361 |
+
- task:
|
1362 |
type: Retrieval
|
1363 |
+
dataset:
|
|
|
1364 |
name: MTEB CQADupstackGisRetrieval (default)
|
|
|
|
|
1365 |
type: mteb/cqadupstack-gis
|
1366 |
+
config: default
|
1367 |
+
split: test
|
1368 |
+
revision: 5003b3064772da1887988e05400cf3806fe491f2
|
1369 |
metrics:
|
1370 |
- type: main_score
|
1371 |
value: 37.501
|
|
|
1649 |
value: 37.218
|
1650 |
- type: recall_at_5
|
1651 |
value: 42.559000000000005
|
1652 |
+
- task:
|
1653 |
type: Retrieval
|
1654 |
+
dataset:
|
|
|
1655 |
name: MTEB CQADupstackMathematicaRetrieval (default)
|
|
|
|
|
1656 |
type: mteb/cqadupstack-mathematica
|
1657 |
+
config: default
|
1658 |
+
split: test
|
1659 |
+
revision: 90fceea13679c63fe563ded68f3b6f06e50061de
|
1660 |
metrics:
|
1661 |
- type: main_score
|
1662 |
value: 27.653
|
|
|
1940 |
value: 25.469
|
1941 |
- type: recall_at_5
|
1942 |
value: 31.316
|
1943 |
+
- task:
|
1944 |
type: Retrieval
|
1945 |
+
dataset:
|
|
|
1946 |
name: MTEB CQADupstackPhysicsRetrieval (default)
|
|
|
|
|
1947 |
type: mteb/cqadupstack-physics
|
1948 |
+
config: default
|
1949 |
+
split: test
|
1950 |
+
revision: 79531abbd1fb92d06c6d6315a0cbbbf5bb247ea4
|
1951 |
metrics:
|
1952 |
- type: main_score
|
1953 |
value: 45.314
|
|
|
2231 |
value: 43.679
|
2232 |
- type: recall_at_5
|
2233 |
value: 49.735
|
2234 |
+
- task:
|
2235 |
type: Retrieval
|
2236 |
+
dataset:
|
|
|
2237 |
name: MTEB CQADupstackProgrammersRetrieval (default)
|
|
|
|
|
2238 |
type: mteb/cqadupstack-programmers
|
2239 |
+
config: default
|
2240 |
+
split: test
|
2241 |
+
revision: 6184bc1440d2dbc7612be22b50686b8826d22b32
|
2242 |
metrics:
|
2243 |
- type: main_score
|
2244 |
value: 41.972
|
|
|
2522 |
value: 39.363
|
2523 |
- type: recall_at_5
|
2524 |
value: 44.665
|
2525 |
+
- task:
|
2526 |
type: Retrieval
|
2527 |
+
dataset:
|
|
|
2528 |
name: MTEB CQADupstackRetrieval (default)
|
|
|
|
|
2529 |
type: CQADupstackRetrieval_is_a_combined_dataset
|
2530 |
+
config: default
|
2531 |
+
split: test
|
2532 |
+
revision: CQADupstackRetrieval_is_a_combined_dataset
|
2533 |
metrics:
|
2534 |
- type: main_score
|
2535 |
value: 39.823499999999996
|
2536 |
- type: ndcg_at_10
|
2537 |
value: 39.823499999999996
|
2538 |
+
- task:
|
2539 |
type: Retrieval
|
2540 |
+
dataset:
|
|
|
2541 |
name: MTEB CQADupstackStatsRetrieval (default)
|
|
|
|
|
2542 |
type: mteb/cqadupstack-stats
|
2543 |
+
config: default
|
2544 |
+
split: test
|
2545 |
+
revision: 65ac3a16b8e91f9cee4c9828cc7c335575432a2a
|
2546 |
metrics:
|
2547 |
- type: main_score
|
2548 |
value: 34.943000000000005
|
|
|
2826 |
value: 33.427
|
2827 |
- type: recall_at_5
|
2828 |
value: 37.643
|
2829 |
+
- task:
|
2830 |
type: Retrieval
|
2831 |
+
dataset:
|
|
|
2832 |
name: MTEB CQADupstackTexRetrieval (default)
|
|
|
|
|
2833 |
type: mteb/cqadupstack-tex
|
2834 |
+
config: default
|
2835 |
+
split: test
|
2836 |
+
revision: 46989137a86843e03a6195de44b09deda022eec7
|
2837 |
metrics:
|
2838 |
- type: main_score
|
2839 |
value: 27.271
|
|
|
3117 |
value: 25.592
|
3118 |
- type: recall_at_5
|
3119 |
value: 30.279
|
3120 |
+
- task:
|
3121 |
type: Retrieval
|
3122 |
+
dataset:
|
|
|
3123 |
name: MTEB CQADupstackUnixRetrieval (default)
|
|
|
|
|
3124 |
type: mteb/cqadupstack-unix
|
3125 |
+
config: default
|
3126 |
+
split: test
|
3127 |
+
revision: 6c6430d3a6d36f8d2a829195bc5dc94d7e063e53
|
3128 |
metrics:
|
3129 |
- type: main_score
|
3130 |
value: 38.237
|
|
|
3408 |
value: 36.275
|
3409 |
- type: recall_at_5
|
3410 |
value: 42.199
|
3411 |
+
- task:
|
3412 |
type: Retrieval
|
3413 |
+
dataset:
|
|
|
3414 |
name: MTEB CQADupstackWebmastersRetrieval (default)
|
|
|
|
|
3415 |
type: mteb/cqadupstack-webmasters
|
3416 |
+
config: default
|
3417 |
+
split: test
|
3418 |
+
revision: 160c094312a0e1facb97e55eeddb698c0abe3571
|
3419 |
metrics:
|
3420 |
- type: main_score
|
3421 |
value: 38.702
|
|
|
3699 |
value: 37.634
|
3700 |
- type: recall_at_5
|
3701 |
value: 42.021
|
3702 |
+
- task:
|
3703 |
type: Retrieval
|
3704 |
+
dataset:
|
|
|
3705 |
name: MTEB CQADupstackWordpressRetrieval (default)
|
|
|
|
|
3706 |
type: mteb/cqadupstack-wordpress
|
3707 |
+
config: default
|
3708 |
+
split: test
|
3709 |
+
revision: 4ffe81d471b1924886b33c7567bfb200e9eec5c4
|
3710 |
metrics:
|
3711 |
- type: main_score
|
3712 |
value: 33.184000000000005
|
|
|
3990 |
value: 32.683
|
3991 |
- type: recall_at_5
|
3992 |
value: 36.756
|
3993 |
+
- task:
|
3994 |
type: Retrieval
|
3995 |
+
dataset:
|
|
|
3996 |
name: MTEB ClimateFEVER (default)
|
|
|
|
|
3997 |
type: mteb/climate-fever
|
3998 |
+
config: default
|
3999 |
+
split: test
|
4000 |
+
revision: 47f2ac6acb640fc46020b02a5b59fdda04d39380
|
4001 |
metrics:
|
4002 |
- type: main_score
|
4003 |
value: 25.068
|
|
|
4281 |
value: 18.312
|
4282 |
- type: recall_at_5
|
4283 |
value: 22.776
|
4284 |
+
- task:
|
4285 |
type: Retrieval
|
4286 |
+
dataset:
|
|
|
4287 |
name: MTEB DBPedia (default)
|
|
|
|
|
4288 |
type: mteb/dbpedia
|
4289 |
+
config: default
|
4290 |
+
split: test
|
4291 |
+
revision: c0f706b76e590d620bd6618b3ca8efdd34e2d659
|
4292 |
metrics:
|
4293 |
- type: main_score
|
4294 |
value: 40.128
|
|
|
4572 |
value: 14.562
|
4573 |
- type: recall_at_5
|
4574 |
value: 18.779
|
4575 |
+
- task:
|
4576 |
+
type: Classification
|
4577 |
+
dataset:
|
|
|
4578 |
name: MTEB EmotionClassification (default)
|
|
|
|
|
4579 |
type: mteb/emotion
|
4580 |
+
config: default
|
4581 |
+
split: test
|
4582 |
+
revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
|
4583 |
metrics:
|
4584 |
- type: accuracy
|
4585 |
value: 74.86
|
|
|
4589 |
value: 75.96499621761998
|
4590 |
- type: main_score
|
4591 |
value: 74.86
|
4592 |
+
- task:
|
4593 |
+
type: Retrieval
|
4594 |
+
dataset:
|
|
|
4595 |
name: MTEB FEVER (default)
|
|
|
|
|
4596 |
type: mteb/fever
|
4597 |
+
config: default
|
4598 |
+
split: test
|
4599 |
+
revision: bea83ef9e8fb933d90a2f1d5515737465d613e12
|
4600 |
metrics:
|
4601 |
- type: main_score
|
4602 |
value: 86.029
|
|
|
4880 |
value: 88.382
|
4881 |
- type: recall_at_5
|
4882 |
value: 90.908
|
4883 |
+
- task:
|
4884 |
type: Retrieval
|
4885 |
+
dataset:
|
|
|
4886 |
name: MTEB FiQA2018 (default)
|
|
|
|
|
4887 |
type: mteb/fiqa
|
4888 |
+
config: default
|
4889 |
+
split: test
|
4890 |
+
revision: 27a168819829fe9bcd655c2df245fb19452e8e06
|
4891 |
metrics:
|
4892 |
- type: main_score
|
4893 |
value: 45.238
|
|
|
5171 |
value: 37.656
|
5172 |
- type: recall_at_5
|
5173 |
value: 44.766
|
5174 |
+
- task:
|
5175 |
type: Retrieval
|
5176 |
+
dataset:
|
|
|
5177 |
name: MTEB HotpotQA (default)
|
|
|
|
|
5178 |
type: mteb/hotpotqa
|
5179 |
+
config: default
|
5180 |
+
split: test
|
5181 |
+
revision: ab518f4d6fcca38d87c25209f94beba119d02014
|
5182 |
metrics:
|
5183 |
- type: main_score
|
5184 |
value: 66.672
|
|
|
5462 |
value: 57.522
|
5463 |
- type: recall_at_5
|
5464 |
value: 62.134
|
5465 |
+
- task:
|
5466 |
+
type: Classification
|
5467 |
+
dataset:
|
|
|
5468 |
name: MTEB ImdbClassification (default)
|
|
|
|
|
5469 |
type: mteb/imdb
|
5470 |
+
config: default
|
5471 |
+
split: test
|
5472 |
+
revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
|
5473 |
metrics:
|
5474 |
- type: accuracy
|
5475 |
value: 93.5944
|
|
|
5483 |
value: 93.58945949328377
|
5484 |
- type: main_score
|
5485 |
value: 93.5944
|
5486 |
+
- task:
|
5487 |
+
type: Retrieval
|
5488 |
+
dataset:
|
|
|
5489 |
name: MTEB MSMARCO (default)
|
|
|
|
|
5490 |
type: mteb/msmarco
|
5491 |
+
config: default
|
5492 |
+
split: dev
|
5493 |
+
revision: c5a29a104738b98a9e76336939199e264163d4a0
|
5494 |
metrics:
|
5495 |
- type: main_score
|
5496 |
value: 41.448
|
|
|
5774 |
value: 41.304
|
5775 |
- type: recall_at_5
|
5776 |
value: 51.076
|
5777 |
+
- task:
|
5778 |
+
type: Classification
|
5779 |
+
dataset:
|
|
|
5780 |
name: MTEB MTOPDomainClassification (en)
|
|
|
|
|
5781 |
type: mteb/mtop_domain
|
5782 |
+
config: en
|
5783 |
+
split: test
|
5784 |
+
revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
|
5785 |
metrics:
|
5786 |
- type: accuracy
|
5787 |
value: 96.03967168262655
|
|
|
5791 |
value: 96.06623245823347
|
5792 |
- type: main_score
|
5793 |
value: 96.03967168262655
|
5794 |
+
- task:
|
5795 |
type: Classification
|
5796 |
+
dataset:
|
|
|
5797 |
name: MTEB MTOPIntentClassification (en)
|
|
|
|
|
5798 |
type: mteb/mtop_intent
|
5799 |
+
config: en
|
5800 |
+
split: test
|
5801 |
+
revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
|
5802 |
metrics:
|
5803 |
- type: accuracy
|
5804 |
value: 89.12904696762428
|
|
|
5808 |
value: 90.41290566743324
|
5809 |
- type: main_score
|
5810 |
value: 89.12904696762428
|
5811 |
+
- task:
|
5812 |
type: Classification
|
5813 |
+
dataset:
|
|
|
5814 |
name: MTEB MassiveIntentClassification (en)
|
|
|
|
|
5815 |
type: mteb/amazon_massive_intent
|
5816 |
+
config: en
|
5817 |
+
split: test
|
5818 |
+
revision: 4672e20407010da34463acc759c162ca9734bca6
|
5819 |
metrics:
|
5820 |
- type: accuracy
|
5821 |
value: 76.49630127774041
|
|
|
5825 |
value: 76.42436195016484
|
5826 |
- type: main_score
|
5827 |
value: 76.49630127774041
|
5828 |
+
- task:
|
5829 |
type: Classification
|
5830 |
+
dataset:
|
|
|
5831 |
name: MTEB MassiveScenarioClassification (en)
|
|
|
|
|
5832 |
type: mteb/amazon_massive_scenario
|
5833 |
+
config: en
|
5834 |
+
split: test
|
5835 |
+
revision: fad2c6e8459f9e1c45d9315f4953d921437d70f8
|
5836 |
metrics:
|
5837 |
- type: accuracy
|
5838 |
value: 78.9340954942838
|
|
|
5842 |
value: 78.87787647838971
|
5843 |
- type: main_score
|
5844 |
value: 78.9340954942838
|
5845 |
+
- task:
|
5846 |
+
type: Clustering
|
5847 |
+
dataset:
|
|
|
5848 |
name: MTEB MedrxivClusteringP2P (default)
|
|
|
|
|
5849 |
type: mteb/medrxiv-clustering-p2p
|
5850 |
+
config: default
|
5851 |
+
split: test
|
5852 |
+
revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
|
5853 |
metrics:
|
5854 |
- type: main_score
|
5855 |
value: 37.50182848656019
|
|
|
5857 |
value: 37.50182848656019
|
5858 |
- type: v_measure_std
|
5859 |
value: 1.1708518023877268
|
5860 |
+
- task:
|
5861 |
type: Clustering
|
5862 |
+
dataset:
|
|
|
5863 |
name: MTEB MedrxivClusteringS2S (default)
|
|
|
|
|
5864 |
type: mteb/medrxiv-clustering-s2s
|
5865 |
+
config: default
|
5866 |
+
split: test
|
5867 |
+
revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
|
5868 |
metrics:
|
5869 |
- type: main_score
|
5870 |
value: 35.72762609825363
|
|
|
5872 |
value: 35.72762609825363
|
5873 |
- type: v_measure_std
|
5874 |
value: 1.4555014772914985
|
5875 |
+
- task:
|
5876 |
+
type: Reranking
|
5877 |
+
dataset:
|
|
|
5878 |
name: MTEB MindSmallReranking (default)
|
|
|
|
|
5879 |
type: mteb/mind_small
|
5880 |
+
config: default
|
5881 |
+
split: test
|
5882 |
+
revision: 59042f120c80e8afa9cdbb224f67076cec0fc9a7
|
5883 |
metrics:
|
5884 |
- type: main_score
|
5885 |
value: 30.47716416454022
|
|
|
5899 |
value: -15.78941850629242
|
5900 |
- type: nAUC_mrr_std
|
5901 |
value: -1.1330442292510805
|
5902 |
+
- task:
|
5903 |
+
type: Retrieval
|
5904 |
+
dataset:
|
|
|
5905 |
name: MTEB NFCorpus (default)
|
|
|
|
|
5906 |
type: mteb/nfcorpus
|
5907 |
+
config: default
|
5908 |
+
split: test
|
5909 |
+
revision: ec0fa4fe99da2ff19ca1214b7966684033a58814
|
5910 |
metrics:
|
5911 |
- type: main_score
|
5912 |
value: 34.648
|
|
|
6190 |
value: 10.037
|
6191 |
- type: recall_at_5
|
6192 |
value: 12.717999999999998
|
6193 |
+
- task:
|
6194 |
type: Retrieval
|
6195 |
+
dataset:
|
|
|
6196 |
name: MTEB NQ (default)
|
|
|
|
|
6197 |
type: mteb/nq
|
6198 |
+
config: default
|
6199 |
+
split: test
|
6200 |
+
revision: b774495ed302d8c44a3a7ea25c90dbce03968f31
|
6201 |
metrics:
|
6202 |
- type: main_score
|
6203 |
value: 60.06
|
|
|
6481 |
value: 61.114000000000004
|
6482 |
- type: recall_at_5
|
6483 |
value: 69.812
|
6484 |
+
- task:
|
6485 |
type: Retrieval
|
6486 |
+
dataset:
|
|
|
6487 |
name: MTEB QuoraRetrieval (default)
|
|
|
|
|
6488 |
type: mteb/quora
|
6489 |
+
config: default
|
6490 |
+
split: test
|
6491 |
+
revision: e4e08e0b7dbe3c8700f0daef558ff32256715259
|
6492 |
metrics:
|
6493 |
- type: main_score
|
6494 |
value: 89.821
|
|
|
6772 |
value: 88.714
|
6773 |
- type: recall_at_5
|
6774 |
value: 92.96799999999999
|
6775 |
+
- task:
|
6776 |
+
type: Clustering
|
6777 |
+
dataset:
|
|
|
6778 |
name: MTEB RedditClustering (default)
|
|
|
|
|
6779 |
type: mteb/reddit-clustering
|
6780 |
+
config: default
|
6781 |
+
split: test
|
6782 |
+
revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
|
6783 |
metrics:
|
6784 |
- type: main_score
|
6785 |
value: 59.36038828851887
|
|
|
6787 |
value: 59.36038828851887
|
6788 |
- type: v_measure_std
|
6789 |
value: 4.1958765965154425
|
6790 |
+
- task:
|
6791 |
type: Clustering
|
6792 |
+
dataset:
|
|
|
6793 |
name: MTEB RedditClusteringP2P (default)
|
|
|
|
|
6794 |
type: mteb/reddit-clustering-p2p
|
6795 |
+
config: default
|
6796 |
+
split: test
|
6797 |
+
revision: 385e3cb46b4cfa89021f56c4380204149d0efe33
|
6798 |
metrics:
|
6799 |
- type: main_score
|
6800 |
value: 64.67522832408089
|
|
|
6802 |
value: 64.67522832408089
|
6803 |
- type: v_measure_std
|
6804 |
value: 12.473765016158698
|
6805 |
+
- task:
|
6806 |
+
type: Retrieval
|
6807 |
+
dataset:
|
|
|
6808 |
name: MTEB SCIDOCS (default)
|
|
|
|
|
6809 |
type: mteb/scidocs
|
6810 |
+
config: default
|
6811 |
+
split: test
|
6812 |
+
revision: f8c2fcf00f625baaa80f62ec5bd9e1fff3b8ae88
|
6813 |
metrics:
|
6814 |
- type: main_score
|
6815 |
value: 21.751
|
|
|
7093 |
value: 11.648
|
7094 |
- type: recall_at_5
|
7095 |
value: 15.883
|
7096 |
+
- task:
|
7097 |
+
type: STS
|
7098 |
+
dataset:
|
|
|
7099 |
name: MTEB SICK-R (default)
|
|
|
|
|
7100 |
type: mteb/sickr-sts
|
7101 |
+
config: default
|
7102 |
+
split: test
|
7103 |
+
revision: 20a6d6f312dd54037fe07a32d58e5e168867909d
|
7104 |
metrics:
|
7105 |
- type: cosine_pearson
|
7106 |
value: 84.0161170579997
|
|
|
7120 |
value: 84.0161170579997
|
7121 |
- type: spearman
|
7122 |
value: 77.52025923874551
|
7123 |
+
- task:
|
7124 |
type: STS
|
7125 |
+
dataset:
|
|
|
7126 |
name: MTEB STS12 (default)
|
|
|
|
|
7127 |
type: mteb/sts12-sts
|
7128 |
+
config: default
|
7129 |
+
split: test
|
7130 |
+
revision: a0d554a64d88156834ff5ae9920b964011b16384
|
7131 |
metrics:
|
7132 |
- type: cosine_pearson
|
7133 |
value: 81.32328780209225
|
|
|
7147 |
value: 81.32328780209225
|
7148 |
- type: spearman
|
7149 |
value: 74.17570679745272
|
7150 |
+
- task:
|
7151 |
type: STS
|
7152 |
+
dataset:
|
|
|
7153 |
name: MTEB STS13 (default)
|
|
|
|
|
7154 |
type: mteb/sts13-sts
|
7155 |
+
config: default
|
7156 |
+
split: test
|
7157 |
+
revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
|
7158 |
metrics:
|
7159 |
- type: cosine_pearson
|
7160 |
value: 85.53224141249392
|
|
|
7174 |
value: 85.53224141249392
|
7175 |
- type: spearman
|
7176 |
value: 86.16981525069227
|
7177 |
+
- task:
|
7178 |
type: STS
|
7179 |
+
dataset:
|
|
|
7180 |
name: MTEB STS14 (default)
|
|
|
|
|
7181 |
type: mteb/sts14-sts
|
7182 |
+
config: default
|
7183 |
+
split: test
|
7184 |
+
revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
|
7185 |
metrics:
|
7186 |
- type: cosine_pearson
|
7187 |
value: 82.234064045301
|
|
|
7201 |
value: 82.234064045301
|
7202 |
- type: spearman
|
7203 |
value: 78.86920830792957
|
7204 |
+
- task:
|
7205 |
type: STS
|
7206 |
+
dataset:
|
|
|
7207 |
name: MTEB STS15 (default)
|
|
|
|
|
7208 |
type: mteb/sts15-sts
|
7209 |
+
config: default
|
7210 |
+
split: test
|
7211 |
+
revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
|
7212 |
metrics:
|
7213 |
- type: cosine_pearson
|
7214 |
value: 86.23114543080261
|
|
|
7228 |
value: 86.23114543080261
|
7229 |
- type: spearman
|
7230 |
value: 87.481042891123
|
7231 |
+
- task:
|
7232 |
type: STS
|
7233 |
+
dataset:
|
|
|
7234 |
name: MTEB STS16 (default)
|
|
|
|
|
7235 |
type: mteb/sts16-sts
|
7236 |
+
config: default
|
7237 |
+
split: test
|
7238 |
+
revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
|
7239 |
metrics:
|
7240 |
- type: cosine_pearson
|
7241 |
value: 82.9156629047782
|
|
|
7255 |
value: 82.9156629047782
|
7256 |
- type: spearman
|
7257 |
value: 84.28381329207937
|
7258 |
+
- task:
|
7259 |
type: STS
|
7260 |
+
dataset:
|
|
|
7261 |
name: MTEB STS17 (en-en)
|
|
|
|
|
7262 |
type: mteb/sts17-crosslingual-sts
|
7263 |
+
config: en-en
|
7264 |
+
split: test
|
7265 |
+
revision: faeb762787bd10488a50c8b5be4a3b82e411949c
|
7266 |
metrics:
|
7267 |
- type: cosine_pearson
|
7268 |
value: 88.91985349746744
|
|
|
7282 |
value: 88.91985349746744
|
7283 |
- type: spearman
|
7284 |
value: 89.69151633966257
|
7285 |
+
- task:
|
7286 |
type: STS
|
7287 |
+
dataset:
|
|
|
7288 |
name: MTEB STS22 (en)
|
|
|
|
|
7289 |
type: mteb/sts22-crosslingual-sts
|
7290 |
+
config: en
|
7291 |
+
split: test
|
7292 |
+
revision: de9d86b3b84231dc21f76c7b7af1f28e2f57f6e3
|
7293 |
metrics:
|
7294 |
- type: cosine_pearson
|
7295 |
value: 65.0979772547511
|
|
|
7309 |
value: 65.0979772547511
|
7310 |
- type: spearman
|
7311 |
value: 65.78126527764236
|
7312 |
+
- task:
|
7313 |
type: STS
|
7314 |
+
dataset:
|
|
|
7315 |
name: MTEB STSBenchmark (default)
|
|
|
|
|
7316 |
type: mteb/stsbenchmark-sts
|
7317 |
+
config: default
|
7318 |
+
split: test
|
7319 |
+
revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
|
7320 |
metrics:
|
7321 |
- type: cosine_pearson
|
7322 |
value: 85.6426635049971
|
|
|
7336 |
value: 85.6426635049971
|
7337 |
- type: spearman
|
7338 |
value: 85.609856578385
|
7339 |
+
- task:
|
7340 |
+
type: Reranking
|
7341 |
+
dataset:
|
|
|
7342 |
name: MTEB SciDocsRR (default)
|
|
|
|
|
7343 |
type: mteb/scidocs-reranking
|
7344 |
+
config: default
|
7345 |
+
split: test
|
7346 |
+
revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
|
7347 |
metrics:
|
7348 |
- type: main_score
|
7349 |
value: 82.85163332499799
|
|
|
7363 |
value: 89.47202967481866
|
7364 |
- type: nAUC_mrr_std
|
7365 |
value: 85.40446996933892
|
7366 |
+
- task:
|
7367 |
+
type: Retrieval
|
7368 |
+
dataset:
|
|
|
7369 |
name: MTEB SciFact (default)
|
|
|
|
|
7370 |
type: mteb/scifact
|
7371 |
+
config: default
|
7372 |
+
split: test
|
7373 |
+
revision: 0228b52cf27578f30900b9e5271d331663a030d7
|
7374 |
metrics:
|
7375 |
- type: main_score
|
7376 |
value: 71.655
|
|
|
7654 |
value: 71.61699999999999
|
7655 |
- type: recall_at_5
|
7656 |
value: 78.361
|
7657 |
+
- task:
|
7658 |
+
type: PairClassification
|
7659 |
+
dataset:
|
|
|
7660 |
name: MTEB SprintDuplicateQuestions (default)
|
|
|
|
|
7661 |
type: mteb/sprintduplicatequestions-pairclassification
|
7662 |
+
config: default
|
7663 |
+
split: test
|
7664 |
+
revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
|
7665 |
metrics:
|
7666 |
- type: cosine_accuracy
|
7667 |
value: 99.8019801980198
|
|
|
7745 |
value: 90.79754601226993
|
7746 |
- type: similarity_recall
|
7747 |
value: 88.8
|
7748 |
+
- task:
|
7749 |
+
type: Clustering
|
7750 |
+
dataset:
|
|
|
7751 |
name: MTEB StackExchangeClustering (default)
|
|
|
|
|
7752 |
type: mteb/stackexchange-clustering
|
7753 |
+
config: default
|
7754 |
+
split: test
|
7755 |
+
revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
|
7756 |
metrics:
|
7757 |
- type: main_score
|
7758 |
value: 66.63931197758824
|
|
|
7760 |
value: 66.63931197758824
|
7761 |
- type: v_measure_std
|
7762 |
value: 3.896206781511776
|
7763 |
+
- task:
|
7764 |
type: Clustering
|
7765 |
+
dataset:
|
|
|
7766 |
name: MTEB StackExchangeClusteringP2P (default)
|
|
|
|
|
7767 |
type: mteb/stackexchange-clustering-p2p
|
7768 |
+
config: default
|
7769 |
+
split: test
|
7770 |
+
revision: 815ca46b2622cec33ccafc3735d572c266efdb44
|
7771 |
metrics:
|
7772 |
- type: main_score
|
7773 |
value: 38.984892653301884
|
|
|
7775 |
value: 38.984892653301884
|
7776 |
- type: v_measure_std
|
7777 |
value: 1.3308552162270453
|
7778 |
+
- task:
|
7779 |
+
type: Reranking
|
7780 |
+
dataset:
|
|
|
7781 |
name: MTEB StackOverflowDupQuestions (default)
|
|
|
|
|
7782 |
type: mteb/stackoverflowdupquestions-reranking
|
7783 |
+
config: default
|
7784 |
+
split: test
|
7785 |
+
revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
|
7786 |
metrics:
|
7787 |
- type: main_score
|
7788 |
value: 52.71499643455044
|
|
|
7802 |
value: 13.931448578334379
|
7803 |
- type: nAUC_mrr_std
|
7804 |
value: 10.441860004959661
|
7805 |
+
- task:
|
7806 |
+
type: Summarization
|
7807 |
+
dataset:
|
|
|
7808 |
name: MTEB SummEval (default)
|
|
|
|
|
7809 |
type: mteb/summeval
|
7810 |
+
config: default
|
7811 |
+
split: test
|
7812 |
+
revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
|
7813 |
metrics:
|
7814 |
- type: cosine_pearson
|
7815 |
value: 31.5167525286909
|
|
|
7825 |
value: 31.5167525286909
|
7826 |
- type: spearman
|
7827 |
value: 31.218862970706496
|
7828 |
+
- task:
|
7829 |
+
type: Retrieval
|
7830 |
+
dataset:
|
|
|
7831 |
name: MTEB TRECCOVID (default)
|
|
|
|
|
7832 |
type: mteb/trec-covid
|
7833 |
+
config: default
|
7834 |
+
split: test
|
7835 |
+
revision: bb9466bac8153a0349341eb1b22e06409e78ef4e
|
7836 |
metrics:
|
7837 |
- type: main_score
|
7838 |
value: 78.996
|
|
|
8116 |
value: 0.705
|
8117 |
- type: recall_at_5
|
8118 |
value: 1.162
|
8119 |
+
- task:
|
8120 |
type: Retrieval
|
8121 |
+
dataset:
|
|
|
8122 |
name: MTEB Touche2020 (default)
|
|
|
|
|
8123 |
type: mteb/touche2020
|
8124 |
+
config: default
|
8125 |
+
split: test
|
8126 |
+
revision: a34f9a33db75fa0cbb21bb5cfc3dae8dc8bec93f
|
8127 |
metrics:
|
8128 |
- type: main_score
|
8129 |
value: 24.234
|
|
|
8407 |
value: 6.625
|
8408 |
- type: recall_at_5
|
8409 |
value: 9.094
|
8410 |
+
- task:
|
8411 |
+
type: Classification
|
8412 |
+
dataset:
|
|
|
8413 |
name: MTEB ToxicConversationsClassification (default)
|
|
|
|
|
8414 |
type: mteb/toxic_conversations_50k
|
8415 |
+
config: default
|
8416 |
+
split: test
|
8417 |
+
revision: edfaf9da55d3dd50d43143d90c1ac476895ae6de
|
8418 |
metrics:
|
8419 |
- type: accuracy
|
8420 |
value: 72.822265625
|
|
|
8428 |
value: 78.7454393727821
|
8429 |
- type: main_score
|
8430 |
value: 72.822265625
|
8431 |
+
- task:
|
8432 |
type: Classification
|
8433 |
+
dataset:
|
|
|
8434 |
name: MTEB TweetSentimentExtractionClassification (default)
|
|
|
|
|
8435 |
type: mteb/tweet_sentiment_extraction
|
8436 |
+
config: default
|
8437 |
+
split: test
|
8438 |
+
revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
|
8439 |
metrics:
|
8440 |
- type: accuracy
|
8441 |
value: 72.54385964912281
|
|
|
8445 |
value: 72.18022450339639
|
8446 |
- type: main_score
|
8447 |
value: 72.54385964912281
|
8448 |
+
- task:
|
8449 |
+
type: Clustering
|
8450 |
+
dataset:
|
|
|
8451 |
name: MTEB TwentyNewsgroupsClustering (default)
|
|
|
|
|
8452 |
type: mteb/twentynewsgroups-clustering
|
8453 |
+
config: default
|
8454 |
+
split: test
|
8455 |
+
revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
|
8456 |
metrics:
|
8457 |
- type: main_score
|
8458 |
value: 57.41861450414374
|
|
|
8460 |
value: 57.41861450414374
|
8461 |
- type: v_measure_std
|
8462 |
value: 1.1732394227153524
|
8463 |
+
- task:
|
8464 |
+
type: PairClassification
|
8465 |
+
dataset:
|
|
|
8466 |
name: MTEB TwitterSemEval2015 (default)
|
|
|
|
|
8467 |
type: mteb/twittersemeval2015-pairclassification
|
8468 |
+
config: default
|
8469 |
+
split: test
|
8470 |
+
revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
|
8471 |
metrics:
|
8472 |
- type: cosine_accuracy
|
8473 |
value: 85.65893783155511
|
|
|
8551 |
value: 64.0855106888361
|
8552 |
- type: similarity_recall
|
8553 |
value: 71.18733509234828
|
8554 |
+
- task:
|
8555 |
type: PairClassification
|
8556 |
+
dataset:
|
|
|
8557 |
name: MTEB TwitterURLCorpus (default)
|
|
|
|
|
8558 |
type: mteb/twitterurlcorpus-pairclassification
|
8559 |
+
config: default
|
8560 |
+
split: test
|
8561 |
+
revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
|
8562 |
metrics:
|
8563 |
- type: cosine_accuracy
|
8564 |
value: 88.86754375751931
|
|
|
8642 |
value: 74.19310344827586
|
8643 |
- type: similarity_recall
|
8644 |
value: 82.83030489682784
|
|
|
|
|
8645 |
---
|
8646 |
# Contextual Document Embeddings (CDE)
|
8647 |
|
config.json
CHANGED
@@ -1,8 +1,14 @@
|
|
1 |
{
|
|
|
2 |
"architecture": "transductive",
|
3 |
"architectures": [
|
4 |
-
"
|
5 |
],
|
|
|
|
|
|
|
|
|
|
|
6 |
"biencoder_pooling_strategy": "mean",
|
7 |
"cache_dir": null,
|
8 |
"config_name": null,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "/fsx-checkpoints/jxm/cde/2024-09-18-supervised-final-bge--epoch-4/checkpoint-1820",
|
3 |
"architecture": "transductive",
|
4 |
"architectures": [
|
5 |
+
"DatasetTransformer"
|
6 |
],
|
7 |
+
"attn_implementation": null,
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "misc.ContextualModelConfig",
|
10 |
+
"AutoModel": "model.DatasetTransformer"
|
11 |
+
},
|
12 |
"biencoder_pooling_strategy": "mean",
|
13 |
"cache_dir": null,
|
14 |
"config_name": null,
|
misc.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Iterable, List, Tuple, Union
|
2 |
+
|
3 |
+
import collections
|
4 |
+
import functools
|
5 |
+
import glob
|
6 |
+
import json
|
7 |
+
import hashlib
|
8 |
+
import itertools
|
9 |
+
import logging
|
10 |
+
import multiprocessing
|
11 |
+
import os
|
12 |
+
import pickle
|
13 |
+
import random
|
14 |
+
import requests
|
15 |
+
import sys
|
16 |
+
import zipfile
|
17 |
+
|
18 |
+
import datasets
|
19 |
+
import numpy as np
|
20 |
+
import safetensors
|
21 |
+
import torch
|
22 |
+
import tqdm
|
23 |
+
import transformers
|
24 |
+
|
25 |
+
from cde.lib.dist import get_num_proc, get_rank
|
26 |
+
|
27 |
+
|
28 |
+
def get_cde_cache_dir() -> str:
|
29 |
+
script_directory = os.path.normpath(
|
30 |
+
os.path.join(
|
31 |
+
os.path.dirname(os.path.abspath(__file__)),
|
32 |
+
os.pardir, os.pardir,
|
33 |
+
)
|
34 |
+
)
|
35 |
+
return os.path.join(script_directory, "data")
|
36 |
+
|
37 |
+
|
38 |
+
def get_cache_location_from_kwargs(**kwargs):
|
39 |
+
cache_location = os.path.join(
|
40 |
+
get_cde_cache_dir(), "cluster"
|
41 |
+
)
|
42 |
+
os.makedirs(cache_location, exist_ok=True)
|
43 |
+
return os.path.join(cache_location, md5_hash_kwargs(**kwargs))
|
44 |
+
|
45 |
+
|
46 |
+
def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
|
47 |
+
qrels_idxs = collections.defaultdict(list)
|
48 |
+
qrels_scores = collections.defaultdict(list)
|
49 |
+
corpus_ids = np.array(corpus['_id'])
|
50 |
+
skipped_qrels = 0
|
51 |
+
|
52 |
+
for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
|
53 |
+
#
|
54 |
+
# example:
|
55 |
+
# {
|
56 |
+
# 'query-id': 1,
|
57 |
+
# 'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
|
58 |
+
# 'score': 2
|
59 |
+
# }
|
60 |
+
#
|
61 |
+
q_id = str(ex['query-id'])
|
62 |
+
c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
|
63 |
+
#
|
64 |
+
assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
|
65 |
+
#
|
66 |
+
if len(c_idxs):
|
67 |
+
qrels_idxs[q_id].append(c_idxs[0])
|
68 |
+
qrels_scores[q_id].append(ex['score'])
|
69 |
+
else:
|
70 |
+
skipped_qrels += 1
|
71 |
+
#
|
72 |
+
|
73 |
+
if skipped_qrels > 0:
|
74 |
+
logging.warning(f'Warning: Skipped {skipped_qrels}/{len(qrels)} qrels.')
|
75 |
+
|
76 |
+
return qrels_idxs, qrels_scores
|
77 |
+
|
78 |
+
|
79 |
+
def process_qrels(
|
80 |
+
corpus: datasets.Dataset, qrels: datasets.Dataset,
|
81 |
+
use_cache: bool = True
|
82 |
+
) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
|
83 |
+
dataset_cache_file = '_'.join(
|
84 |
+
(corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
|
85 |
+
)
|
86 |
+
cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
|
87 |
+
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
88 |
+
|
89 |
+
if not (use_cache and os.path.exists(cache_file)):
|
90 |
+
qrels_idxs, qrels_scores = process_qrels_uncached(
|
91 |
+
corpus=corpus, qrels=qrels
|
92 |
+
)
|
93 |
+
if use_cache:
|
94 |
+
pickle.dump((qrels_idxs, qrels_scores), open(cache_file, 'wb'))
|
95 |
+
else:
|
96 |
+
qrels_idxs, qrels_scores = pickle.load(open(cache_file, 'rb'))
|
97 |
+
|
98 |
+
return qrels_idxs, qrels_scores
|
99 |
+
|
100 |
+
|
101 |
+
def strip_extension(filename: str) -> str:
|
102 |
+
"""Strips file extension.
|
103 |
+
|
104 |
+
Ex:
|
105 |
+
>> strip_extension('/root/dir/sub/file.ext')
|
106 |
+
'/root/dir/sub/file'
|
107 |
+
"""
|
108 |
+
return os.path.splitext(filename)[0]
|
109 |
+
|
110 |
+
|
111 |
+
def md5_hash(t: Tuple[str]) -> str:
|
112 |
+
return hashlib.md5('__'.join(t).encode()).hexdigest()
|
113 |
+
|
114 |
+
|
115 |
+
def md5_hash_kwargs(**kwargs) -> str:
|
116 |
+
# We ignore special hf args that start with _ like '__cached__setup_devices'.
|
117 |
+
safe_kwargs = {k: str(v) for k,v in kwargs.items() if not k.startswith('_')}
|
118 |
+
s = json.dumps(safe_kwargs, sort_keys=True)
|
119 |
+
return hashlib.md5(s.encode()).hexdigest()
|
120 |
+
|
121 |
+
def download_url(url: str, save_path: str, chunk_size: int = 1024):
|
122 |
+
"""Download url with progress bar using tqdm
|
123 |
+
https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
|
124 |
+
Args:
|
125 |
+
url (str): downloadable url
|
126 |
+
save_path (str): local path to save the downloaded file
|
127 |
+
chunk_size (int, optional): chunking of files. Defaults to 1024.
|
128 |
+
"""
|
129 |
+
r = requests.get(url, stream=True)
|
130 |
+
total = int(r.headers.get('Content-Length', 0))
|
131 |
+
with open(save_path, 'wb') as fd, tqdm.tqdm(
|
132 |
+
desc=save_path,
|
133 |
+
total=total,
|
134 |
+
unit='iB',
|
135 |
+
unit_scale=True,
|
136 |
+
unit_divisor=chunk_size,
|
137 |
+
) as bar:
|
138 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
139 |
+
size = fd.write(data)
|
140 |
+
bar.update(size)
|
141 |
+
|
142 |
+
|
143 |
+
def unzip(zip_file: str, out_dir: str):
|
144 |
+
print("unzipping =>", zip_file)
|
145 |
+
zip_ = zipfile.ZipFile(zip_file, "r")
|
146 |
+
zip_.extractall(path=out_dir)
|
147 |
+
zip_.close()
|
148 |
+
|
149 |
+
|
150 |
+
def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
|
151 |
+
os.makedirs(out_dir, exist_ok=True)
|
152 |
+
dataset = url.split("/")[-1]
|
153 |
+
zip_file = os.path.join(out_dir, dataset)
|
154 |
+
|
155 |
+
if not os.path.isfile(zip_file):
|
156 |
+
logging.info("Downloading {} ...".format(dataset))
|
157 |
+
download_url(url, zip_file, chunk_size)
|
158 |
+
|
159 |
+
if not os.path.isdir(zip_file.replace(".zip", "")):
|
160 |
+
logging.info("Unzipping {} ...".format(dataset))
|
161 |
+
unzip(zip_file, out_dir)
|
162 |
+
|
163 |
+
return os.path.join(out_dir, dataset.replace(".zip", ""))
|
164 |
+
|
165 |
+
|
166 |
+
def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
|
167 |
+
if get_rank() == 0:
|
168 |
+
return tqdm.tqdm(iterable, **kwargs)
|
169 |
+
else:
|
170 |
+
return iterable
|
171 |
+
|
172 |
+
|
173 |
+
class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
|
174 |
+
"""We create a dummy configuration class that will just set properties
|
175 |
+
based on whatever kwargs we pass in.
|
176 |
+
|
177 |
+
When this class is initialized (see experiments.py) we pass in the
|
178 |
+
union of all data, model, and training args, all of which should
|
179 |
+
get saved to the config json.
|
180 |
+
"""
|
181 |
+
|
182 |
+
def __init__(self, **kwargs):
|
183 |
+
for key, value in kwargs.items():
|
184 |
+
try:
|
185 |
+
json.dumps(value)
|
186 |
+
setattr(self, key, value)
|
187 |
+
except TypeError:
|
188 |
+
# value was not JSON-serializable, skip
|
189 |
+
continue
|
190 |
+
super().__init__()
|
191 |
+
|
192 |
+
|
193 |
+
def independent_crop(
|
194 |
+
input_ids: torch.Tensor, pad_token_id: int,
|
195 |
+
l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
|
196 |
+
"""Returns two independent crops from input_ids.
|
197 |
+
|
198 |
+
Assumes input_ids has a beginning and end token, like
|
199 |
+
[101, ..., 102, 0, 0, 0].
|
200 |
+
|
201 |
+
Args:
|
202 |
+
input_ids: tensor of IDs
|
203 |
+
pad_token_id: ID of pad tokens in input_ids
|
204 |
+
l1: length of span 1, cropped
|
205 |
+
l2: length of span 2, cropped
|
206 |
+
Returns:
|
207 |
+
span1: first crop (of length l1)
|
208 |
+
span2: second crop (of length l2)
|
209 |
+
"""
|
210 |
+
# Count tokens until pad.
|
211 |
+
if (input_ids == pad_token_id).sum() == 0:
|
212 |
+
N = len(input_ids)
|
213 |
+
else:
|
214 |
+
N = (input_ids == pad_token_id).int().argmax().item()
|
215 |
+
|
216 |
+
####
|
217 |
+
###
|
218 |
+
##
|
219 |
+
## Contriever: We use the random cropping data
|
220 |
+
## augmentation, with documents of 256 tokens and span
|
221 |
+
## sizes sampled between 5% and 50% of the document
|
222 |
+
## length
|
223 |
+
##
|
224 |
+
###
|
225 |
+
#####
|
226 |
+
####### LaPraDor: The maximum lengths set for queries and
|
227 |
+
####### documents are 64 and 350...
|
228 |
+
#####
|
229 |
+
# TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
|
230 |
+
nl1 = min(N//2, l1)
|
231 |
+
nl2 = min(N//2, l2)
|
232 |
+
|
233 |
+
s1_start = random.randint(1, N-nl1)
|
234 |
+
s2_start = random.randint(1, N-nl2)
|
235 |
+
|
236 |
+
s1_idxs = itertools.chain(
|
237 |
+
[0], range(s1_start, s1_start+nl1), [N-1]
|
238 |
+
)
|
239 |
+
s1 = input_ids[torch.tensor(list(s1_idxs))]
|
240 |
+
s2_idxs = itertools.chain(
|
241 |
+
[0], range(s2_start, s2_start+nl2), [N-1]
|
242 |
+
)
|
243 |
+
s2 = input_ids[torch.tensor(list(s2_idxs))]
|
244 |
+
return (s1, s2)
|
245 |
+
|
246 |
+
|
247 |
+
def load_dataset_tables(
|
248 |
+
files: Iterable[str], num_workers: int = 16
|
249 |
+
) -> Iterable[datasets.table.MemoryMappedTable]:
|
250 |
+
import concurrent
|
251 |
+
from multiprocessing import Pool
|
252 |
+
|
253 |
+
# num_workers = min(num_workers, len(files))
|
254 |
+
num_workers = min(32, len(files))
|
255 |
+
|
256 |
+
use_threads = True
|
257 |
+
if use_threads:
|
258 |
+
pool_cls = concurrent.futures.ThreadPoolExecutor
|
259 |
+
pool_kwargs = {"max_workers": num_workers}
|
260 |
+
else:
|
261 |
+
pool_cls = Pool
|
262 |
+
pool_kwargs = {"processes": num_workers}
|
263 |
+
|
264 |
+
with pool_cls(**pool_kwargs) as pool:
|
265 |
+
if len(files) > 10:
|
266 |
+
files = tqdm_if_main_worker(
|
267 |
+
files,
|
268 |
+
desc=f"Loading {len(files)} files with {num_workers} workers",
|
269 |
+
total=len(files),
|
270 |
+
colour="#ffbd88"
|
271 |
+
)
|
272 |
+
|
273 |
+
result = list(
|
274 |
+
pool.map(datasets.table.MemoryMappedTable.from_file, files)
|
275 |
+
)
|
276 |
+
return result
|
277 |
+
|
278 |
+
|
279 |
+
def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
|
280 |
+
logging.info(f"fast_load_from_disk called with path:", cache_path)
|
281 |
+
dataset_info_path = os.path.join(cache_path, "dataset_info.json")
|
282 |
+
with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
|
283 |
+
dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))
|
284 |
+
|
285 |
+
dataset_state_path = os.path.join(cache_path, "state.json")
|
286 |
+
with open(dataset_state_path, encoding="utf-8") as state_file:
|
287 |
+
state = json.load(state_file)
|
288 |
+
|
289 |
+
files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
|
290 |
+
files = sorted(files)
|
291 |
+
num_workers = get_num_proc()
|
292 |
+
ds_tables = load_dataset_tables(
|
293 |
+
files=files,
|
294 |
+
num_workers=num_workers
|
295 |
+
)
|
296 |
+
arrow_table = datasets.table.concat_tables(ds_tables)
|
297 |
+
|
298 |
+
split = state["_split"]
|
299 |
+
split = datasets.splits.Split(split) if split is not None else split
|
300 |
+
|
301 |
+
# print("returning dataset")
|
302 |
+
return datasets.Dataset(
|
303 |
+
arrow_table=arrow_table,
|
304 |
+
info=dataset_info,
|
305 |
+
split=split,
|
306 |
+
fingerprint=state["_fingerprint"],
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
def tokenize_dataset(
|
311 |
+
dataset: datasets.Dataset,
|
312 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
313 |
+
max_length: int,
|
314 |
+
text_key: str,
|
315 |
+
padding_strategy: str
|
316 |
+
) -> datasets.Dataset:
|
317 |
+
def tokenize_text(ex: Dict) -> Dict:
|
318 |
+
tt = tokenizer(
|
319 |
+
ex[text_key],
|
320 |
+
max_length=max_length,
|
321 |
+
truncation=True,
|
322 |
+
padding=padding_strategy,
|
323 |
+
)
|
324 |
+
for k,v in tt.items():
|
325 |
+
ex[f"{text_key}_{k}"] = v
|
326 |
+
ex["length"] = [len(tt) for tt in ex[f"{text_key}_input_ids"]]
|
327 |
+
return ex
|
328 |
+
|
329 |
+
# generate unique hash for tokenizer
|
330 |
+
vocab = tokenizer.vocab
|
331 |
+
vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
|
332 |
+
vocab_hash = md5_hash(vocab_words)
|
333 |
+
|
334 |
+
data_fingerprint = '__'.join((
|
335 |
+
dataset._fingerprint, str(vocab_hash), str(max_length),
|
336 |
+
text_key, padding_strategy
|
337 |
+
))
|
338 |
+
data_fingerprint = md5_hash(data_fingerprint)
|
339 |
+
dataset = dataset.map(
|
340 |
+
tokenize_text,
|
341 |
+
new_fingerprint=data_fingerprint,
|
342 |
+
batched=True,
|
343 |
+
load_from_cache_file=True,
|
344 |
+
)
|
345 |
+
return dataset
|
346 |
+
|
347 |
+
|
348 |
+
class TensorRunningAverages:
|
349 |
+
_store_sum: Dict[str, torch.Tensor]
|
350 |
+
_store_total: Dict[str, torch.Tensor]
|
351 |
+
|
352 |
+
def __init__(self):
|
353 |
+
self._store_sum = {}
|
354 |
+
self._store_total = {}
|
355 |
+
|
356 |
+
def __iter__(self) -> Iterable[str]:
|
357 |
+
return iter(self._store_sum.keys())
|
358 |
+
|
359 |
+
def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
|
360 |
+
if key not in self._store_sum:
|
361 |
+
self.clear(key)
|
362 |
+
if isinstance(val, torch.Tensor):
|
363 |
+
val = val.item() # tensor -> num
|
364 |
+
self._store_sum[key] += val
|
365 |
+
self._store_total[key] += 1
|
366 |
+
|
367 |
+
def get(self, key: str) -> float:
|
368 |
+
total = max(self._store_total.get(key).item(), 1.0)
|
369 |
+
return (self._store_sum[key] / float(total)).item() or 0.0
|
370 |
+
|
371 |
+
def clear(self, key: str) -> None:
|
372 |
+
self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
|
373 |
+
self._store_total[key] = torch.tensor(0, dtype=torch.int32)
|
374 |
+
|
375 |
+
def clear_all(self) -> None:
|
376 |
+
for key in self._store_sum:
|
377 |
+
self.clear(key)
|
378 |
+
|
379 |
+
def get_and_clear_all(self) -> Dict[str, float]:
|
380 |
+
metrics = {}
|
381 |
+
for key in self:
|
382 |
+
metrics[key] = self.get(key)
|
383 |
+
self.clear(key)
|
384 |
+
return metrics
|
385 |
+
|
386 |
+
def load_embedder_and_tokenizer(name: str) -> Tuple[
|
387 |
+
transformers.PreTrainedModel,
|
388 |
+
transformers.PreTrainedTokenizer
|
389 |
+
]:
|
390 |
+
if name.startswith("nomic") or (name == "bert-base-uncased"):
|
391 |
+
from cde.lib.nomic_bert import NomicBertModel
|
392 |
+
if name.endswith("--from-scratch"):
|
393 |
+
name = name.replace("--from-scratch", "")
|
394 |
+
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
|
395 |
+
model = NomicBertModel._from_config(config)
|
396 |
+
else:
|
397 |
+
model = NomicBertModel.from_pretrained(
|
398 |
+
name, add_pooling_layer=False
|
399 |
+
)
|
400 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
|
401 |
+
elif name in ["gtr-base", "gtr_base"]:
|
402 |
+
model = transformers.AutoModel.from_pretrained(
|
403 |
+
"sentence-transformers/gtr-t5-base"
|
404 |
+
).encoder
|
405 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
406 |
+
"sentence-transformers/gtr-t5-base"
|
407 |
+
)
|
408 |
+
elif name == "pile-t5-base-encoder":
|
409 |
+
model = transformers.AutoModel.from_pretrained(
|
410 |
+
"EleutherAI/pile-t5-base"
|
411 |
+
).encoder
|
412 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
413 |
+
"EleutherAI/pile-t5-base"
|
414 |
+
)
|
415 |
+
tokenizer.pad_token = tokenizer.eos_token
|
416 |
+
elif name == "pile-t5-base-decoder":
|
417 |
+
model = transformers.AutoModel.from_pretrained(
|
418 |
+
"EleutherAI/pile-t5-base"
|
419 |
+
).decoder
|
420 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
421 |
+
"EleutherAI/pile-t5-base"
|
422 |
+
)
|
423 |
+
tokenizer.pad_token = tokenizer.eos_token
|
424 |
+
elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
|
425 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
426 |
+
name,
|
427 |
+
# torch_dtype=torch.bfloat16,
|
428 |
+
attn_implementation="flash_attention_2",
|
429 |
+
low_cpu_mem_usage=True,
|
430 |
+
# device_map="auto",
|
431 |
+
)
|
432 |
+
model.padding_side = "right"
|
433 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
|
434 |
+
tokenizer.pad_token = tokenizer.eos_token
|
435 |
+
tokenizer.add_eos_token = True
|
436 |
+
else:
|
437 |
+
model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
|
438 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
|
439 |
+
|
440 |
+
# if use_bettertransformer:
|
441 |
+
# from optimum.bettertransformer import BetterTransformer
|
442 |
+
# model = BetterTransformer.transform(model)
|
443 |
+
return model, tokenizer
|
444 |
+
|
445 |
+
|
446 |
+
def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
|
447 |
+
key += "_"
|
448 |
+
return {k.replace(key, ""): v for k,v in inputs.items() if k.startswith(key)}
|
449 |
+
|
450 |
+
|
451 |
+
def load_model_state_dict_from_path(folder: str) -> Dict:
|
452 |
+
checkpoint_folder = transformers.trainer_utils.get_last_checkpoint(folder)
|
453 |
+
if checkpoint_folder is None:
|
454 |
+
raise FileNotFoundError(f"no checkpoint found in {folder}")
|
455 |
+
WEIGHTS_NAME = "model.safetensors"
|
456 |
+
weights_path = os.path.join(checkpoint_folder, WEIGHTS_NAME)
|
457 |
+
if not os.path.exists(weights_path):
|
458 |
+
raise FileNotFoundError(f"no model weights found at {weights_path}")
|
459 |
+
return safetensors.torch.load_file(weights_path, device="cpu")
|
460 |
+
|
461 |
+
def count_cpus() -> int:
|
462 |
+
try:
|
463 |
+
return len(os.sched_getaffinity(0))
|
464 |
+
except AttributeError:
|
465 |
+
return multiprocessing.cpu_count()
|
466 |
+
|
467 |
+
|
468 |
+
def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
|
469 |
+
all_indices = []
|
470 |
+
for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
|
471 |
+
rand_perm = torch.randperm(len(batch_tensor), generator=g)
|
472 |
+
batch_list = batch_tensor[rand_perm].tolist()
|
473 |
+
all_indices.extend(batch_list)
|
474 |
+
return all_indices
|
475 |
+
|
476 |
+
|
477 |
+
# def shuffle_batches_multiproc(g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8) -> List[int]:
|
478 |
+
# all_indices = []
|
479 |
+
# print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
|
480 |
+
# pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
|
481 |
+
# pool = multiprocessing.Pool(processes=num_processes)
|
482 |
+
# chunk_size = len(list_of_tensors) // num_processes
|
483 |
+
# chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
|
484 |
+
# worker_func = functools.partial(shuffle_batches, g=g)
|
485 |
+
# results = pool.map(worker_func, chunks)
|
486 |
+
# all_indices = []
|
487 |
+
# for result in results:
|
488 |
+
# all_indices.extend(result)
|
489 |
+
# pbar.update()
|
490 |
+
# return all_indices
|
491 |
+
|
492 |
+
|
493 |
+
def exit_if_running_or_finished_wandb(
|
494 |
+
project_name: str,
|
495 |
+
exp_group: str, exp_name: str
|
496 |
+
) -> None:
|
497 |
+
print("Checking if experiment is already running...")
|
498 |
+
import wandb
|
499 |
+
|
500 |
+
api = wandb.Api()
|
501 |
+
running_runs = api.runs(
|
502 |
+
path="tti-nomic-7",
|
503 |
+
filters={
|
504 |
+
"display_name": exp_name,
|
505 |
+
"state": {"$regex": "Running|Finished"},
|
506 |
+
"config.exp_group": exp_group,
|
507 |
+
}
|
508 |
+
)
|
509 |
+
print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")
|
510 |
+
|
511 |
+
if len(running_runs) > 0:
|
512 |
+
print("Exiting because experiment is already running or completed.")
|
513 |
+
sys.exit(0)
|
514 |
+
|
model.py
ADDED
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Union
|
2 |
+
|
3 |
+
import copy
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import transformers
|
7 |
+
|
8 |
+
from cde.lib.dist import print0
|
9 |
+
from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool
|
10 |
+
|
11 |
+
from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig
|
12 |
+
|
13 |
+
|
14 |
+
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
|
15 |
+
if hasattr(model, 'transformer'):
|
16 |
+
if hasattr(model.transformer, 'h'):
|
17 |
+
# gpt2
|
18 |
+
model.transformer.h = model.transformer.h[:n_layers]
|
19 |
+
else:
|
20 |
+
model.transformer.layer = model.transformer.layer[:n_layers]
|
21 |
+
elif hasattr(model, 'encoder'):
|
22 |
+
if hasattr(model.encoder, 'layers'):
|
23 |
+
model.encoder.layers = model.encoder.layers[:n_layers]
|
24 |
+
else:
|
25 |
+
model.encoder.layer = model.encoder.layer[:n_layers]
|
26 |
+
else:
|
27 |
+
raise RuntimeError(f"unknown how to limit layers of model {type(model)}")
|
28 |
+
|
29 |
+
|
30 |
+
def disable_dropout(model: torch.nn.Module):
|
31 |
+
dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
|
32 |
+
for m in dropout_modules:
|
33 |
+
m.p = 0.0
|
34 |
+
print0(
|
35 |
+
f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
def disable_causality(model: torch.nn.Module):
|
40 |
+
disabled_modules = 0
|
41 |
+
for m in model.modules():
|
42 |
+
if hasattr(m, "is_causal"):
|
43 |
+
m.is_causal = False
|
44 |
+
disabled_modules += 1
|
45 |
+
print0(
|
46 |
+
f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
|
47 |
+
)
|
48 |
+
|
49 |
+
class ContextualModelMixin(nn.Module):
|
50 |
+
@property
|
51 |
+
def num_corpus_tokens(self) -> int:
|
52 |
+
return self.transductive_corpus_size * self.transductive_tokens_per_document
|
53 |
+
|
54 |
+
def contextual_init(self):
|
55 |
+
self.n_soft_prompt = 8
|
56 |
+
self.prompt_projection = torch.nn.Sequential(
|
57 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
58 |
+
torch.nn.ReLU(),
|
59 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
|
60 |
+
)
|
61 |
+
self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
|
62 |
+
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
|
63 |
+
self.randomize_dataset_sequence_order = True
|
64 |
+
self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
|
65 |
+
if self.sequence_dropout_prob > 0.0:
|
66 |
+
self.sequence_dropout_null_embedding = torch.nn.Parameter(
|
67 |
+
torch.randn(self.hidden_size) * 0.01,
|
68 |
+
requires_grad = True
|
69 |
+
)
|
70 |
+
self.output_projection = torch.nn.Sequential(
|
71 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
72 |
+
torch.nn.ReLU(),
|
73 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size)
|
74 |
+
)
|
75 |
+
|
76 |
+
def _prepare_dataset_embeddings(
|
77 |
+
self,
|
78 |
+
input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
|
79 |
+
null_dataset_embedding: bool = False,
|
80 |
+
) -> torch.Tensor:
|
81 |
+
if not isinstance(dataset_embeddings, torch.Tensor):
|
82 |
+
dataset_embeddings = torch.tensor(dataset_embeddings)
|
83 |
+
|
84 |
+
if len(dataset_embeddings.shape) == 2:
|
85 |
+
# Auto-expand for a batch.
|
86 |
+
dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
|
87 |
+
dataset_embeddings = dataset_embeddings.to(input_ids.device)
|
88 |
+
|
89 |
+
batch_size = input_ids.shape[0]
|
90 |
+
if (self.transductive_tokens_per_document > 1):
|
91 |
+
if self.training:
|
92 |
+
# Choose N random documents to fill our context window with.
|
93 |
+
# This logic is a little confusing but allows us to sample a
|
94 |
+
# different batch *per-document*
|
95 |
+
assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
|
96 |
+
R = torch.randint(
|
97 |
+
low=0,
|
98 |
+
high=len(dataset_embeddings),
|
99 |
+
size=(batch_size, self.config.transductive_corpus_size),
|
100 |
+
device=dataset_embeddings.device
|
101 |
+
)
|
102 |
+
# TODO make this deterministic somehow for evaluation?
|
103 |
+
dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
|
104 |
+
else:
|
105 |
+
dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
|
106 |
+
# print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)
|
107 |
+
|
108 |
+
if dataset_embeddings.shape[1] > self.num_corpus_tokens:
|
109 |
+
# If too many dataset embeddings are passed in, just take the first N until
|
110 |
+
# we have the proper number.
|
111 |
+
dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
|
112 |
+
|
113 |
+
_, corpus_size, _hidden_size = dataset_embeddings.shape
|
114 |
+
if _ == 1:
|
115 |
+
# Auto-expand for a batch.
|
116 |
+
dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))
|
117 |
+
|
118 |
+
if self.training and self.sequence_dropout_prob > 0.0:
|
119 |
+
sequence_dropout_mask = (
|
120 |
+
torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
|
121 |
+
)
|
122 |
+
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
|
123 |
+
dataset_embeddings = torch.where(
|
124 |
+
sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
|
125 |
+
)
|
126 |
+
elif null_dataset_embedding:
|
127 |
+
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
|
128 |
+
dataset_embeddings = null_embeddings
|
129 |
+
|
130 |
+
# print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}")
|
131 |
+
|
132 |
+
# backbone_max_seq_length = self.backbone.config.max_trained_positions
|
133 |
+
# assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
|
134 |
+
soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
|
135 |
+
soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
|
136 |
+
soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1)) # -> (b, 4+b, d) # soft_prompt.repeat((len(input_ids), 1, 1))
|
137 |
+
soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)
|
138 |
+
|
139 |
+
# print(f"[ContextualModelMixin] soft_prompt.shape = {soft_prompt.shape}")
|
140 |
+
|
141 |
+
if self.training and self.randomize_dataset_sequence_order:
|
142 |
+
randomized_order = torch.stack(
|
143 |
+
[
|
144 |
+
torch.cat(
|
145 |
+
(
|
146 |
+
torch.randperm(corpus_size, device=soft_prompt.device),
|
147 |
+
torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
|
148 |
+
), dim=0)
|
149 |
+
for _ in range(batch_size)])
|
150 |
+
randomized_order = randomized_order.to(soft_prompt.device)
|
151 |
+
soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))
|
152 |
+
|
153 |
+
return soft_prompt
|
154 |
+
|
155 |
+
class BiEncoder(transformers.PreTrainedModel):
|
156 |
+
embedder: transformers.PreTrainedModel
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
config, #: transformers.PreTrainedConfig,
|
160 |
+
):
|
161 |
+
super().__init__(config=config)
|
162 |
+
embedder, _ = load_embedder_and_tokenizer(
|
163 |
+
config.embedder,
|
164 |
+
)
|
165 |
+
|
166 |
+
if config.limit_layers:
|
167 |
+
print0(f"Limiting layers to {config.limit_layers}")
|
168 |
+
limit_layers(embedder, config.limit_layers)
|
169 |
+
|
170 |
+
self.embedder = embedder
|
171 |
+
# if ("t5" in embedder.config.model_type):
|
172 |
+
# print0(f"using torch.compile() on embedder of type `{embedder.config.model_type}`")
|
173 |
+
# self.embedder = torch.compile(self.embedder)
|
174 |
+
self.hidden_size = self.embedder.config.hidden_size
|
175 |
+
# Allow pooling to multiple tokens per document
|
176 |
+
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
|
177 |
+
self.mlp = torch.nn.Sequential(
|
178 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
179 |
+
torch.nn.GELU(),
|
180 |
+
torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
|
181 |
+
)
|
182 |
+
self.temp = config.logit_scale
|
183 |
+
|
184 |
+
if config.disable_dropout:
|
185 |
+
disable_dropout(self)
|
186 |
+
self.pooling_strategy = vars(config).get("pooling_strategy", "mean")
|
187 |
+
|
188 |
+
def forward(
|
189 |
+
self,
|
190 |
+
input_ids: torch.Tensor,
|
191 |
+
attention_mask: torch.Tensor,
|
192 |
+
dataset_input_ids: Optional[torch.Tensor] = None,
|
193 |
+
dataset_attention_mask: Optional[torch.Tensor] = None,
|
194 |
+
token_type_ids = None,
|
195 |
+
output_hidden_states: bool = False,
|
196 |
+
) -> torch.Tensor:
|
197 |
+
"""
|
198 |
+
query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim)
|
199 |
+
document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim)
|
200 |
+
where the corpus_size >= batch_size and is structured like this:
|
201 |
+
[d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
|
202 |
+
for a corpus with three documents and two hard negatives per document
|
203 |
+
"""
|
204 |
+
# del dataset_input_ids
|
205 |
+
# del dataset_attention_mask
|
206 |
+
del token_type_ids
|
207 |
+
|
208 |
+
# from cde.lib.dist import get_rank
|
209 |
+
# tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
|
210 |
+
# if get_rank() == 0:
|
211 |
+
# breakpoint()
|
212 |
+
# torch.distributed.barrier()
|
213 |
+
outputs = (
|
214 |
+
self.embedder(
|
215 |
+
input_ids=input_ids,
|
216 |
+
attention_mask=attention_mask,
|
217 |
+
).last_hidden_state
|
218 |
+
)
|
219 |
+
|
220 |
+
if self.transductive_tokens_per_document > 1:
|
221 |
+
document_embeddings = None
|
222 |
+
batch_size, seq_length, output_dim = outputs.shape
|
223 |
+
|
224 |
+
if seq_length % self.transductive_tokens_per_document != 0:
|
225 |
+
# Pad to nearest multiple
|
226 |
+
n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
|
227 |
+
outputs = torch.cat(
|
228 |
+
(outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device)),
|
229 |
+
dim=1
|
230 |
+
)
|
231 |
+
attention_mask = torch.cat(
|
232 |
+
(attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device)),
|
233 |
+
dim=1
|
234 |
+
)
|
235 |
+
seq_length += n_extra_embeds
|
236 |
+
print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")
|
237 |
+
|
238 |
+
# print("ftransductive_tokens_per_document {self.transductive_tokens_per_document} outputs.shape =", outputs.shape)
|
239 |
+
|
240 |
+
outputs = outputs.reshape(
|
241 |
+
(batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
|
242 |
+
)
|
243 |
+
|
244 |
+
attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
|
245 |
+
document_embeddings = mean_pool_3d(outputs, attention_mask)
|
246 |
+
|
247 |
+
document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
|
248 |
+
else:
|
249 |
+
if self.pooling_strategy == "mean":
|
250 |
+
document_embeddings = mean_pool(outputs, attention_mask)
|
251 |
+
else:
|
252 |
+
document_embeddings = document_embeddings.max(dim=1)
|
253 |
+
output = self.mlp(document_embeddings)
|
254 |
+
|
255 |
+
if output_hidden_states:
|
256 |
+
return {
|
257 |
+
"hidden_states": outputs,
|
258 |
+
"pooled": output,
|
259 |
+
}
|
260 |
+
else:
|
261 |
+
return output
|
262 |
+
|
263 |
+
|
264 |
+
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
|
265 |
+
def __init__(
|
266 |
+
self,
|
267 |
+
config,
|
268 |
+
dataset_backbone: transformers.PreTrainedModel,
|
269 |
+
first_stage_hidden_size: int,
|
270 |
+
):
|
271 |
+
super().__init__(config=config)
|
272 |
+
self.backbone = dataset_backbone
|
273 |
+
self.backbone_hidden_size = self.backbone.config.hidden_size
|
274 |
+
self.hidden_size = first_stage_hidden_size # Input token size
|
275 |
+
self.contextual_init()
|
276 |
+
disable_causality(self.backbone)
|
277 |
+
|
278 |
+
self.input_ln = torch.nn.LayerNorm(
|
279 |
+
self.backbone_hidden_size,
|
280 |
+
eps=1e-5
|
281 |
+
)
|
282 |
+
|
283 |
+
# Override contextual init
|
284 |
+
self.output_projection = torch.nn.Sequential(
|
285 |
+
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
|
286 |
+
torch.nn.ReLU(),
|
287 |
+
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
|
288 |
+
)
|
289 |
+
self._shift_rotary_embedding()
|
290 |
+
|
291 |
+
@property
|
292 |
+
def num_corpus_tokens(self) -> int:
|
293 |
+
return self.config.transductive_corpus_size * self.transductive_tokens_per_document
|
294 |
+
|
295 |
+
@property
|
296 |
+
def corpus_token_ratio(self) -> float:
|
297 |
+
# How many tokens from the first stage make one token in the second
|
298 |
+
# stage?
|
299 |
+
return self.backbone_hidden_size / self.hidden_size
|
300 |
+
|
301 |
+
def corpus_token_pad_size(self, n_tokens: int) -> int:
|
302 |
+
return self.hidden_size % self.backbone_hidden_size
|
303 |
+
|
304 |
+
def _shift_rotary_embedding(self) -> None:
|
305 |
+
disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
|
306 |
+
# TODO: Can we do this for LLAMA?
|
307 |
+
print("Warning: Positional embedding disabling not implemented for LLAMA.")
|
308 |
+
|
309 |
+
def forward(
|
310 |
+
self,
|
311 |
+
input_ids: torch.Tensor,
|
312 |
+
attention_mask: torch.Tensor,
|
313 |
+
dataset_embeddings: torch.Tensor,
|
314 |
+
output_hidden_states: bool = False,
|
315 |
+
null_dataset_embedding: bool = False,
|
316 |
+
) -> torch.Tensor:
|
317 |
+
soft_prompt = self._prepare_dataset_embeddings(
|
318 |
+
input_ids=input_ids,
|
319 |
+
dataset_embeddings=dataset_embeddings,
|
320 |
+
null_dataset_embedding=null_dataset_embedding,
|
321 |
+
)
|
322 |
+
|
323 |
+
# Reshape for this model.
|
324 |
+
# print("[DatasetConditionedAutoregressive] 1 -> soft_prompt.shape =", soft_prompt.shape)
|
325 |
+
num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
|
326 |
+
soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
|
327 |
+
num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
|
328 |
+
padding = torch.ones((soft_prompt.shape[0], num_padding_elements), device=soft_prompt.device)
|
329 |
+
soft_prompt = torch.cat((soft_prompt, padding), dim=1)
|
330 |
+
soft_prompt = soft_prompt.reshape(
|
331 |
+
(soft_prompt.shape[0], -1, self.backbone_hidden_size)
|
332 |
+
)
|
333 |
+
soft_prompt = self.input_ln(soft_prompt)
|
334 |
+
# print("[DatasetConditionedAutoregressive] 2 -> soft_prompt.shape =", soft_prompt.shape)
|
335 |
+
|
336 |
+
backbone_attention_mask = torch.ones(
|
337 |
+
soft_prompt.shape[0:2],
|
338 |
+
dtype=torch.long,
|
339 |
+
device=soft_prompt.device,
|
340 |
+
)
|
341 |
+
token_embeddings = self.backbone.get_input_embeddings()
|
342 |
+
inputs_embeds = token_embeddings(input_ids) # (b, s) -> (b, s, d)
|
343 |
+
# print("[2] inputs_embeds.shape =", inputs_embeds.shape)
|
344 |
+
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
|
345 |
+
# print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
|
346 |
+
input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
|
347 |
+
# print("[3.b] attention_mask.shape =", attention_mask.shape)
|
348 |
+
|
349 |
+
output = self.backbone(
|
350 |
+
inputs_embeds=inputs_embeds,
|
351 |
+
attention_mask=input_attention_mask,
|
352 |
+
output_hidden_states=True,
|
353 |
+
) # (1, 4 + b + s, d)
|
354 |
+
# trim soft prompt
|
355 |
+
last_hidden_state = output.hidden_states[-1]
|
356 |
+
n_soft_prompt_tokens = soft_prompt.shape[1]
|
357 |
+
|
358 |
+
output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
|
359 |
+
output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]
|
360 |
+
|
361 |
+
# Take last token position
|
362 |
+
if vars(self.config).get("pooling_strategy") == "last_token":
|
363 |
+
output_pooled = last_token_pool(output_vectors, output_attention_mask)
|
364 |
+
elif vars(self.config).get("pooling_strategy") == "mean":
|
365 |
+
output_pooled = mean_pool(output_vectors, output_attention_mask)
|
366 |
+
else:
|
367 |
+
output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)
|
368 |
+
|
369 |
+
# average with original vectors
|
370 |
+
# TODO: Argparse for pooling strategy.
|
371 |
+
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
|
372 |
+
|
373 |
+
if output_hidden_states:
|
374 |
+
return {
|
375 |
+
"hidden_states": output_vectors,
|
376 |
+
"pooled": output,
|
377 |
+
}
|
378 |
+
else:
|
379 |
+
return output
|
380 |
+
|
381 |
+
|
382 |
+
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
|
383 |
+
def __init__(
|
384 |
+
self,
|
385 |
+
config,
|
386 |
+
dataset_backbone: transformers.PreTrainedModel,
|
387 |
+
):
|
388 |
+
super().__init__(config=config)
|
389 |
+
self.backbone = dataset_backbone
|
390 |
+
self.hidden_size = self.backbone.config.hidden_size
|
391 |
+
self.hidden_size = dataset_backbone.config.hidden_size
|
392 |
+
# self.input_ln = torch.nn.LayerNorm(
|
393 |
+
# self.hidden_size,
|
394 |
+
# eps=self.backbone.config.layer_norm_epsilon
|
395 |
+
# )
|
396 |
+
self.contextual_init()
|
397 |
+
self._shift_rotary_embedding()
|
398 |
+
|
399 |
+
@property
|
400 |
+
def num_corpus_tokens(self) -> int:
|
401 |
+
return self.config.transductive_corpus_size * self.transductive_tokens_per_document
|
402 |
+
|
403 |
+
def _shift_rotary_embedding(self) -> None:
|
404 |
+
disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
|
405 |
+
if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
|
406 |
+
# We only want to apply positional embeddings to the
|
407 |
+
# *text* portion of the backbone network.
|
408 |
+
self.backbone.config.rotary_start_pos = 0.0
|
409 |
+
rotary_disabled = 0
|
410 |
+
|
411 |
+
rotary_start_pos = self.num_corpus_tokens
|
412 |
+
for module in self.backbone.modules():
|
413 |
+
if hasattr(module, "rotary_emb_dim"):
|
414 |
+
module.rotary_start_pos = rotary_start_pos
|
415 |
+
rotary_disabled += 1
|
416 |
+
print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")
|
417 |
+
|
418 |
+
def forward(
|
419 |
+
self,
|
420 |
+
input_ids: torch.Tensor,
|
421 |
+
attention_mask: torch.Tensor,
|
422 |
+
dataset_embeddings: torch.Tensor,
|
423 |
+
output_hidden_states: bool = False,
|
424 |
+
null_dataset_embedding: bool = False,
|
425 |
+
) -> torch.Tensor:
|
426 |
+
# print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape)
|
427 |
+
soft_prompt = self._prepare_dataset_embeddings(
|
428 |
+
input_ids=input_ids,
|
429 |
+
dataset_embeddings=dataset_embeddings,
|
430 |
+
null_dataset_embedding=null_dataset_embedding,
|
431 |
+
)
|
432 |
+
# print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}")
|
433 |
+
backbone_attention_mask = torch.ones(
|
434 |
+
soft_prompt.shape[0:2],
|
435 |
+
dtype=torch.long,
|
436 |
+
device=soft_prompt.device,
|
437 |
+
)
|
438 |
+
inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d)
|
439 |
+
# print("[2] inputs_embeds.shape =", inputs_embeds.shape)
|
440 |
+
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
|
441 |
+
# print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
|
442 |
+
attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
|
443 |
+
# print("[3.b] attention_mask.shape =", attention_mask.shape)
|
444 |
+
output = self.backbone(
|
445 |
+
inputs_embeds=inputs_embeds,
|
446 |
+
attention_mask=attention_mask,
|
447 |
+
) # (1, 4 + b + s, d)
|
448 |
+
# trim soft prompt
|
449 |
+
output_vectors = output.last_hidden_state
|
450 |
+
|
451 |
+
# use only these tokens
|
452 |
+
n_soft_prompt_tokens = soft_prompt.shape[1]
|
453 |
+
# print("n_soft_prompt_tokens =", n_soft_prompt_tokens)
|
454 |
+
|
455 |
+
output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
|
456 |
+
output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
|
457 |
+
|
458 |
+
# print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape)
|
459 |
+
output_pooled = mean_pool(output_vectors, output_attention_mask)
|
460 |
+
|
461 |
+
# average with original vectors
|
462 |
+
# TODO: Argparse for pooling strategy.
|
463 |
+
# output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d)
|
464 |
+
# print("output_pooled.shape =", output_pooled.shape)
|
465 |
+
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
|
466 |
+
|
467 |
+
# print("returning output.shape =", output.shape)
|
468 |
+
|
469 |
+
if output_hidden_states:
|
470 |
+
return {
|
471 |
+
"hidden_states": output_vectors,
|
472 |
+
"pooled": output,
|
473 |
+
}
|
474 |
+
else:
|
475 |
+
return output
|
476 |
+
|
477 |
+
|
478 |
+
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
|
479 |
+
def __init__(
|
480 |
+
self,
|
481 |
+
config, #: transformers.PreTrainedConfig,
|
482 |
+
embedder: transformers.PreTrainedModel,
|
483 |
+
):
|
484 |
+
super().__init__(config=config)
|
485 |
+
self.embedder = embedder
|
486 |
+
self.hidden_size = self.embedder.config.hidden_size
|
487 |
+
self.contextual_init()
|
488 |
+
|
489 |
+
def forward(
|
490 |
+
self,
|
491 |
+
input_ids: torch.Tensor,
|
492 |
+
attention_mask: torch.Tensor,
|
493 |
+
dataset_input_ids: torch.Tensor,
|
494 |
+
dataset_attention_mask: torch.Tensor,
|
495 |
+
output_hidden_states: bool = False,
|
496 |
+
) -> torch.Tensor:
|
497 |
+
R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)
|
498 |
+
|
499 |
+
dataset_input_ids = dataset_input_ids[R]
|
500 |
+
input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)
|
501 |
+
|
502 |
+
dataset_attention_mask = torch.ones_like(dataset_attention_mask, device=dataset_attention_mask.device)
|
503 |
+
input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
|
504 |
+
output_attention_mask = torch.cat(
|
505 |
+
(torch.zeros_like(dataset_input_ids), attention_mask), dim=1
|
506 |
+
)
|
507 |
+
|
508 |
+
output = self.embedder(
|
509 |
+
input_ids=input_ids,
|
510 |
+
attention_mask=input_attention_mask,
|
511 |
+
)
|
512 |
+
|
513 |
+
output_vectors = output.last_hidden_state
|
514 |
+
output_pooled = mean_pool(output_vectors, output_attention_mask)
|
515 |
+
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
|
516 |
+
|
517 |
+
if output_hidden_states:
|
518 |
+
S_d = dataset_attention_mask.shape[1]
|
519 |
+
output_vectors = output_vectors[:, S_d:, :]
|
520 |
+
return {
|
521 |
+
"hidden_states": output_vectors,
|
522 |
+
"pooled": output,
|
523 |
+
}
|
524 |
+
else:
|
525 |
+
return output
|
526 |
+
|
527 |
+
|
528 |
+
class DatasetTransformer(transformers.PreTrainedModel):
|
529 |
+
config_class = ContextualModelConfig
|
530 |
+
embedder: transformers.PreTrainedModel
|
531 |
+
dataset_backbone: transformers.PreTrainedModel
|
532 |
+
def __init__(
|
533 |
+
self,
|
534 |
+
config,
|
535 |
+
):
|
536 |
+
super().__init__(config=config)
|
537 |
+
dataset_backbone, _ = load_embedder_and_tokenizer(
|
538 |
+
vars(config).get("dataset_backbone", config.embedder)
|
539 |
+
)
|
540 |
+
|
541 |
+
if config.limit_layers:
|
542 |
+
print0(f"Limiting layers to {config.limit_layers}")
|
543 |
+
limit_layers(dataset_backbone, config.limit_layers)
|
544 |
+
|
545 |
+
biencoder_config = copy.deepcopy(config)
|
546 |
+
biencoder_config.embedding_output_dim = None
|
547 |
+
biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
|
548 |
+
self.first_stage_model = BiEncoder(
|
549 |
+
config=biencoder_config,
|
550 |
+
)
|
551 |
+
|
552 |
+
if vars(config).get("autoregressive_backbone", False):
|
553 |
+
self.second_stage_model = DatasetConditionedAutoregressive(
|
554 |
+
config=config,
|
555 |
+
dataset_backbone=dataset_backbone,
|
556 |
+
first_stage_hidden_size=self.first_stage_model.hidden_size,
|
557 |
+
)
|
558 |
+
else:
|
559 |
+
self.second_stage_model = DatasetConditionedBiencoder(
|
560 |
+
config=config,
|
561 |
+
dataset_backbone=dataset_backbone
|
562 |
+
)
|
563 |
+
|
564 |
+
self.temp = config.logit_scale
|
565 |
+
if config.disable_dropout:
|
566 |
+
disable_dropout(self)
|
567 |
+
|
568 |
+
transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
|
569 |
+
if transductive_tie_token_embeddings:
|
570 |
+
self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
|
571 |
+
self.first_stage_model.embedder.embeddings.word_embeddings.weight
|
572 |
+
)
|
573 |
+
|
574 |
+
def forward(
|
575 |
+
self,
|
576 |
+
input_ids: torch.Tensor,
|
577 |
+
attention_mask: torch.Tensor,
|
578 |
+
dataset_input_ids: Optional[torch.Tensor],
|
579 |
+
dataset_attention_mask: Optional[torch.Tensor],
|
580 |
+
output_hidden_states: bool = False,
|
581 |
+
) -> torch.Tensor:
|
582 |
+
"""
|
583 |
+
input_ids (long torch.Tensor) – ids of input tokens
|
584 |
+
attention_mask (bool torch.Tensor)
|
585 |
+
"""
|
586 |
+
dataset_embeddings = self.first_stage_model(
|
587 |
+
input_ids=dataset_input_ids,
|
588 |
+
attention_mask=dataset_attention_mask
|
589 |
+
)
|
590 |
+
return self.second_stage_model(
|
591 |
+
input_ids=input_ids,
|
592 |
+
attention_mask=attention_mask,
|
593 |
+
dataset_embeddings=dataset_embeddings,
|
594 |
+
output_hidden_states=output_hidden_states,
|
595 |
+
)
|
596 |
+
|
597 |
+
|
598 |
+
|
599 |
+
def get_model_class(name: str):
|
600 |
+
if name in 'transductive':
|
601 |
+
return DatasetTransformer
|
602 |
+
elif name == 'biencoder':
|
603 |
+
return BiEncoder
|
604 |
+
elif name == "dataset_prefix_biencoder":
|
605 |
+
return DatasetPrefixBiencoder
|
606 |
+
else:
|
607 |
+
raise ValueError(f'unknown model cls {name}')
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6ec79407ada665817aebe929bdabbe83eecd816b75f7f26e3bdd8b4c092efb2a
|
3 |
+
size 1124594680
|