jxm commited on
Commit
7d508f1
1 Parent(s): 3792612

Upload DatasetTransformer

Browse files
Files changed (5) hide show
  1. README.md +369 -369
  2. config.json +7 -1
  3. misc.py +514 -0
  4. model.py +607 -0
  5. model.safetensors +2 -2
README.md CHANGED
@@ -4,12 +4,14 @@ tags:
4
  model-index:
5
  - name: cde-small-v1
6
  results:
7
- - dataset:
8
- config: en
 
9
  name: MTEB AmazonCounterfactualClassification (en)
10
- revision: e8379541af4e31359cca9fbcf4b00f2671dba205
11
- split: test
12
  type: mteb/amazon_counterfactual
 
 
 
13
  metrics:
14
  - type: accuracy
15
  value: 87.01492537313433
@@ -23,14 +25,14 @@ model-index:
23
  value: 87.74802754480477
24
  - type: main_score
25
  value: 87.01492537313433
26
- task:
27
  type: Classification
28
- - dataset:
29
- config: default
30
  name: MTEB AmazonPolarityClassification (default)
31
- revision: e2d317d38cd51312af73b3d32a06d1a08b442046
32
- split: test
33
  type: mteb/amazon_polarity
 
 
 
34
  metrics:
35
  - type: accuracy
36
  value: 94.652275
@@ -44,14 +46,14 @@ model-index:
44
  value: 94.64655930708355
45
  - type: main_score
46
  value: 94.652275
47
- task:
48
  type: Classification
49
- - dataset:
50
- config: en
51
  name: MTEB AmazonReviewsClassification (en)
52
- revision: 1399c76144fd37290681b995c656ef9b2e06e26d
53
- split: test
54
  type: mteb/amazon_reviews_multi
 
 
 
55
  metrics:
56
  - type: accuracy
57
  value: 55.75599999999999
@@ -61,14 +63,14 @@ model-index:
61
  value: 55.07058630829347
62
  - type: main_score
63
  value: 55.75599999999999
64
- task:
65
- type: Classification
66
- - dataset:
67
- config: default
68
  name: MTEB ArguAna (default)
69
- revision: c22ab2a51041ffd869aaddef7af8d8215647e41a
70
- split: test
71
  type: mteb/arguana
 
 
 
72
  metrics:
73
  - type: main_score
74
  value: 69.959
@@ -352,14 +354,14 @@ model-index:
352
  value: 74.182
353
  - type: recall_at_5
354
  value: 84.495
355
- task:
356
- type: Retrieval
357
- - dataset:
358
- config: default
359
  name: MTEB ArxivClusteringP2P (default)
360
- revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
361
- split: test
362
  type: mteb/arxiv-clustering-p2p
 
 
 
363
  metrics:
364
  - type: main_score
365
  value: 48.54672141116669
@@ -367,14 +369,14 @@ model-index:
367
  value: 48.54672141116669
368
  - type: v_measure_std
369
  value: 14.037498386768362
370
- task:
371
  type: Clustering
372
- - dataset:
373
- config: default
374
  name: MTEB ArxivClusteringS2S (default)
375
- revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
376
- split: test
377
  type: mteb/arxiv-clustering-s2s
 
 
 
378
  metrics:
379
  - type: main_score
380
  value: 40.5914039166466
@@ -382,14 +384,14 @@ model-index:
382
  value: 40.5914039166466
383
  - type: v_measure_std
384
  value: 14.385069818910331
385
- task:
386
- type: Clustering
387
- - dataset:
388
- config: default
389
  name: MTEB AskUbuntuDupQuestions (default)
390
- revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
391
- split: test
392
  type: mteb/askubuntudupquestions-reranking
 
 
 
393
  metrics:
394
  - type: main_score
395
  value: 61.13621260261507
@@ -409,14 +411,14 @@ model-index:
409
  value: 31.484257486448364
410
  - type: nAUC_mrr_std
411
  value: 21.252659250011632
412
- task:
413
- type: Reranking
414
- - dataset:
415
- config: default
416
  name: MTEB BIOSSES (default)
417
- revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
418
- split: test
419
  type: mteb/biosses-sts
 
 
 
420
  metrics:
421
  - type: cosine_pearson
422
  value: 89.07028016646942
@@ -436,14 +438,14 @@ model-index:
436
  value: 89.07028016646942
437
  - type: spearman
438
  value: 86.69595132967805
439
- task:
440
- type: STS
441
- - dataset:
442
- config: default
443
  name: MTEB Banking77Classification (default)
444
- revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
445
- split: test
446
  type: mteb/banking77
 
 
 
447
  metrics:
448
  - type: accuracy
449
  value: 88.6038961038961
@@ -453,14 +455,14 @@ model-index:
453
  value: 88.56824205739822
454
  - type: main_score
455
  value: 88.6038961038961
456
- task:
457
- type: Classification
458
- - dataset:
459
- config: default
460
  name: MTEB BiorxivClusteringP2P (default)
461
- revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
462
- split: test
463
  type: mteb/biorxiv-clustering-p2p
 
 
 
464
  metrics:
465
  - type: main_score
466
  value: 44.77800814327256
@@ -468,14 +470,14 @@ model-index:
468
  value: 44.77800814327256
469
  - type: v_measure_std
470
  value: 0.6462535527471919
471
- task:
472
  type: Clustering
473
- - dataset:
474
- config: default
475
  name: MTEB BiorxivClusteringS2S (default)
476
- revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
477
- split: test
478
  type: mteb/biorxiv-clustering-s2s
 
 
 
479
  metrics:
480
  - type: main_score
481
  value: 38.16110272459102
@@ -483,14 +485,14 @@ model-index:
483
  value: 38.16110272459102
484
  - type: v_measure_std
485
  value: 0.7456916212435019
486
- task:
487
- type: Clustering
488
- - dataset:
489
- config: default
490
  name: MTEB CQADupstackAndroidRetrieval (default)
491
- revision: f46a197baaae43b4f621051089b82a364682dfeb
492
- split: test
493
  type: mteb/cqadupstack-android
 
 
 
494
  metrics:
495
  - type: main_score
496
  value: 49.376
@@ -774,14 +776,14 @@ model-index:
774
  value: 47.591
775
  - type: recall_at_5
776
  value: 54.245
777
- task:
778
  type: Retrieval
779
- - dataset:
780
- config: default
781
  name: MTEB CQADupstackEnglishRetrieval (default)
782
- revision: ad9991cb51e31e31e430383c75ffb2885547b5f0
783
- split: test
784
  type: mteb/cqadupstack-english
 
 
 
785
  metrics:
786
  - type: main_score
787
  value: 44.727
@@ -1065,14 +1067,14 @@ model-index:
1065
  value: 42.085
1066
  - type: recall_at_5
1067
  value: 47.5
1068
- task:
1069
  type: Retrieval
1070
- - dataset:
1071
- config: default
1072
  name: MTEB CQADupstackGamingRetrieval (default)
1073
- revision: 4885aa143210c98657558c04aaf3dc47cfb54340
1074
- split: test
1075
  type: mteb/cqadupstack-gaming
 
 
 
1076
  metrics:
1077
  - type: main_score
1078
  value: 59.001999999999995
@@ -1356,14 +1358,14 @@ model-index:
1356
  value: 57.916000000000004
1357
  - type: recall_at_5
1358
  value: 65.44
1359
- task:
1360
  type: Retrieval
1361
- - dataset:
1362
- config: default
1363
  name: MTEB CQADupstackGisRetrieval (default)
1364
- revision: 5003b3064772da1887988e05400cf3806fe491f2
1365
- split: test
1366
  type: mteb/cqadupstack-gis
 
 
 
1367
  metrics:
1368
  - type: main_score
1369
  value: 37.501
@@ -1647,14 +1649,14 @@ model-index:
1647
  value: 37.218
1648
  - type: recall_at_5
1649
  value: 42.559000000000005
1650
- task:
1651
  type: Retrieval
1652
- - dataset:
1653
- config: default
1654
  name: MTEB CQADupstackMathematicaRetrieval (default)
1655
- revision: 90fceea13679c63fe563ded68f3b6f06e50061de
1656
- split: test
1657
  type: mteb/cqadupstack-mathematica
 
 
 
1658
  metrics:
1659
  - type: main_score
1660
  value: 27.653
@@ -1938,14 +1940,14 @@ model-index:
1938
  value: 25.469
1939
  - type: recall_at_5
1940
  value: 31.316
1941
- task:
1942
  type: Retrieval
1943
- - dataset:
1944
- config: default
1945
  name: MTEB CQADupstackPhysicsRetrieval (default)
1946
- revision: 79531abbd1fb92d06c6d6315a0cbbbf5bb247ea4
1947
- split: test
1948
  type: mteb/cqadupstack-physics
 
 
 
1949
  metrics:
1950
  - type: main_score
1951
  value: 45.314
@@ -2229,14 +2231,14 @@ model-index:
2229
  value: 43.679
2230
  - type: recall_at_5
2231
  value: 49.735
2232
- task:
2233
  type: Retrieval
2234
- - dataset:
2235
- config: default
2236
  name: MTEB CQADupstackProgrammersRetrieval (default)
2237
- revision: 6184bc1440d2dbc7612be22b50686b8826d22b32
2238
- split: test
2239
  type: mteb/cqadupstack-programmers
 
 
 
2240
  metrics:
2241
  - type: main_score
2242
  value: 41.972
@@ -2520,27 +2522,27 @@ model-index:
2520
  value: 39.363
2521
  - type: recall_at_5
2522
  value: 44.665
2523
- task:
2524
  type: Retrieval
2525
- - dataset:
2526
- config: default
2527
  name: MTEB CQADupstackRetrieval (default)
2528
- revision: CQADupstackRetrieval_is_a_combined_dataset
2529
- split: test
2530
  type: CQADupstackRetrieval_is_a_combined_dataset
 
 
 
2531
  metrics:
2532
  - type: main_score
2533
  value: 39.823499999999996
2534
  - type: ndcg_at_10
2535
  value: 39.823499999999996
2536
- task:
2537
  type: Retrieval
2538
- - dataset:
2539
- config: default
2540
  name: MTEB CQADupstackStatsRetrieval (default)
2541
- revision: 65ac3a16b8e91f9cee4c9828cc7c335575432a2a
2542
- split: test
2543
  type: mteb/cqadupstack-stats
 
 
 
2544
  metrics:
2545
  - type: main_score
2546
  value: 34.943000000000005
@@ -2824,14 +2826,14 @@ model-index:
2824
  value: 33.427
2825
  - type: recall_at_5
2826
  value: 37.643
2827
- task:
2828
  type: Retrieval
2829
- - dataset:
2830
- config: default
2831
  name: MTEB CQADupstackTexRetrieval (default)
2832
- revision: 46989137a86843e03a6195de44b09deda022eec7
2833
- split: test
2834
  type: mteb/cqadupstack-tex
 
 
 
2835
  metrics:
2836
  - type: main_score
2837
  value: 27.271
@@ -3115,14 +3117,14 @@ model-index:
3115
  value: 25.592
3116
  - type: recall_at_5
3117
  value: 30.279
3118
- task:
3119
  type: Retrieval
3120
- - dataset:
3121
- config: default
3122
  name: MTEB CQADupstackUnixRetrieval (default)
3123
- revision: 6c6430d3a6d36f8d2a829195bc5dc94d7e063e53
3124
- split: test
3125
  type: mteb/cqadupstack-unix
 
 
 
3126
  metrics:
3127
  - type: main_score
3128
  value: 38.237
@@ -3406,14 +3408,14 @@ model-index:
3406
  value: 36.275
3407
  - type: recall_at_5
3408
  value: 42.199
3409
- task:
3410
  type: Retrieval
3411
- - dataset:
3412
- config: default
3413
  name: MTEB CQADupstackWebmastersRetrieval (default)
3414
- revision: 160c094312a0e1facb97e55eeddb698c0abe3571
3415
- split: test
3416
  type: mteb/cqadupstack-webmasters
 
 
 
3417
  metrics:
3418
  - type: main_score
3419
  value: 38.702
@@ -3697,14 +3699,14 @@ model-index:
3697
  value: 37.634
3698
  - type: recall_at_5
3699
  value: 42.021
3700
- task:
3701
  type: Retrieval
3702
- - dataset:
3703
- config: default
3704
  name: MTEB CQADupstackWordpressRetrieval (default)
3705
- revision: 4ffe81d471b1924886b33c7567bfb200e9eec5c4
3706
- split: test
3707
  type: mteb/cqadupstack-wordpress
 
 
 
3708
  metrics:
3709
  - type: main_score
3710
  value: 33.184000000000005
@@ -3988,14 +3990,14 @@ model-index:
3988
  value: 32.683
3989
  - type: recall_at_5
3990
  value: 36.756
3991
- task:
3992
  type: Retrieval
3993
- - dataset:
3994
- config: default
3995
  name: MTEB ClimateFEVER (default)
3996
- revision: 47f2ac6acb640fc46020b02a5b59fdda04d39380
3997
- split: test
3998
  type: mteb/climate-fever
 
 
 
3999
  metrics:
4000
  - type: main_score
4001
  value: 25.068
@@ -4279,14 +4281,14 @@ model-index:
4279
  value: 18.312
4280
  - type: recall_at_5
4281
  value: 22.776
4282
- task:
4283
  type: Retrieval
4284
- - dataset:
4285
- config: default
4286
  name: MTEB DBPedia (default)
4287
- revision: c0f706b76e590d620bd6618b3ca8efdd34e2d659
4288
- split: test
4289
  type: mteb/dbpedia
 
 
 
4290
  metrics:
4291
  - type: main_score
4292
  value: 40.128
@@ -4570,14 +4572,14 @@ model-index:
4570
  value: 14.562
4571
  - type: recall_at_5
4572
  value: 18.779
4573
- task:
4574
- type: Retrieval
4575
- - dataset:
4576
- config: default
4577
  name: MTEB EmotionClassification (default)
4578
- revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
4579
- split: test
4580
  type: mteb/emotion
 
 
 
4581
  metrics:
4582
  - type: accuracy
4583
  value: 74.86
@@ -4587,14 +4589,14 @@ model-index:
4587
  value: 75.96499621761998
4588
  - type: main_score
4589
  value: 74.86
4590
- task:
4591
- type: Classification
4592
- - dataset:
4593
- config: default
4594
  name: MTEB FEVER (default)
4595
- revision: bea83ef9e8fb933d90a2f1d5515737465d613e12
4596
- split: test
4597
  type: mteb/fever
 
 
 
4598
  metrics:
4599
  - type: main_score
4600
  value: 86.029
@@ -4878,14 +4880,14 @@ model-index:
4878
  value: 88.382
4879
  - type: recall_at_5
4880
  value: 90.908
4881
- task:
4882
  type: Retrieval
4883
- - dataset:
4884
- config: default
4885
  name: MTEB FiQA2018 (default)
4886
- revision: 27a168819829fe9bcd655c2df245fb19452e8e06
4887
- split: test
4888
  type: mteb/fiqa
 
 
 
4889
  metrics:
4890
  - type: main_score
4891
  value: 45.238
@@ -5169,14 +5171,14 @@ model-index:
5169
  value: 37.656
5170
  - type: recall_at_5
5171
  value: 44.766
5172
- task:
5173
  type: Retrieval
5174
- - dataset:
5175
- config: default
5176
  name: MTEB HotpotQA (default)
5177
- revision: ab518f4d6fcca38d87c25209f94beba119d02014
5178
- split: test
5179
  type: mteb/hotpotqa
 
 
 
5180
  metrics:
5181
  - type: main_score
5182
  value: 66.672
@@ -5460,14 +5462,14 @@ model-index:
5460
  value: 57.522
5461
  - type: recall_at_5
5462
  value: 62.134
5463
- task:
5464
- type: Retrieval
5465
- - dataset:
5466
- config: default
5467
  name: MTEB ImdbClassification (default)
5468
- revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
5469
- split: test
5470
  type: mteb/imdb
 
 
 
5471
  metrics:
5472
  - type: accuracy
5473
  value: 93.5944
@@ -5481,14 +5483,14 @@ model-index:
5481
  value: 93.58945949328377
5482
  - type: main_score
5483
  value: 93.5944
5484
- task:
5485
- type: Classification
5486
- - dataset:
5487
- config: default
5488
  name: MTEB MSMARCO (default)
5489
- revision: c5a29a104738b98a9e76336939199e264163d4a0
5490
- split: dev
5491
  type: mteb/msmarco
 
 
 
5492
  metrics:
5493
  - type: main_score
5494
  value: 41.448
@@ -5772,14 +5774,14 @@ model-index:
5772
  value: 41.304
5773
  - type: recall_at_5
5774
  value: 51.076
5775
- task:
5776
- type: Retrieval
5777
- - dataset:
5778
- config: en
5779
  name: MTEB MTOPDomainClassification (en)
5780
- revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
5781
- split: test
5782
  type: mteb/mtop_domain
 
 
 
5783
  metrics:
5784
  - type: accuracy
5785
  value: 96.03967168262655
@@ -5789,14 +5791,14 @@ model-index:
5789
  value: 96.06623245823347
5790
  - type: main_score
5791
  value: 96.03967168262655
5792
- task:
5793
  type: Classification
5794
- - dataset:
5795
- config: en
5796
  name: MTEB MTOPIntentClassification (en)
5797
- revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
5798
- split: test
5799
  type: mteb/mtop_intent
 
 
 
5800
  metrics:
5801
  - type: accuracy
5802
  value: 89.12904696762428
@@ -5806,14 +5808,14 @@ model-index:
5806
  value: 90.41290566743324
5807
  - type: main_score
5808
  value: 89.12904696762428
5809
- task:
5810
  type: Classification
5811
- - dataset:
5812
- config: en
5813
  name: MTEB MassiveIntentClassification (en)
5814
- revision: 4672e20407010da34463acc759c162ca9734bca6
5815
- split: test
5816
  type: mteb/amazon_massive_intent
 
 
 
5817
  metrics:
5818
  - type: accuracy
5819
  value: 76.49630127774041
@@ -5823,14 +5825,14 @@ model-index:
5823
  value: 76.42436195016484
5824
  - type: main_score
5825
  value: 76.49630127774041
5826
- task:
5827
  type: Classification
5828
- - dataset:
5829
- config: en
5830
  name: MTEB MassiveScenarioClassification (en)
5831
- revision: fad2c6e8459f9e1c45d9315f4953d921437d70f8
5832
- split: test
5833
  type: mteb/amazon_massive_scenario
 
 
 
5834
  metrics:
5835
  - type: accuracy
5836
  value: 78.9340954942838
@@ -5840,14 +5842,14 @@ model-index:
5840
  value: 78.87787647838971
5841
  - type: main_score
5842
  value: 78.9340954942838
5843
- task:
5844
- type: Classification
5845
- - dataset:
5846
- config: default
5847
  name: MTEB MedrxivClusteringP2P (default)
5848
- revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
5849
- split: test
5850
  type: mteb/medrxiv-clustering-p2p
 
 
 
5851
  metrics:
5852
  - type: main_score
5853
  value: 37.50182848656019
@@ -5855,14 +5857,14 @@ model-index:
5855
  value: 37.50182848656019
5856
  - type: v_measure_std
5857
  value: 1.1708518023877268
5858
- task:
5859
  type: Clustering
5860
- - dataset:
5861
- config: default
5862
  name: MTEB MedrxivClusteringS2S (default)
5863
- revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
5864
- split: test
5865
  type: mteb/medrxiv-clustering-s2s
 
 
 
5866
  metrics:
5867
  - type: main_score
5868
  value: 35.72762609825363
@@ -5870,14 +5872,14 @@ model-index:
5870
  value: 35.72762609825363
5871
  - type: v_measure_std
5872
  value: 1.4555014772914985
5873
- task:
5874
- type: Clustering
5875
- - dataset:
5876
- config: default
5877
  name: MTEB MindSmallReranking (default)
5878
- revision: 59042f120c80e8afa9cdbb224f67076cec0fc9a7
5879
- split: test
5880
  type: mteb/mind_small
 
 
 
5881
  metrics:
5882
  - type: main_score
5883
  value: 30.47716416454022
@@ -5897,14 +5899,14 @@ model-index:
5897
  value: -15.78941850629242
5898
  - type: nAUC_mrr_std
5899
  value: -1.1330442292510805
5900
- task:
5901
- type: Reranking
5902
- - dataset:
5903
- config: default
5904
  name: MTEB NFCorpus (default)
5905
- revision: ec0fa4fe99da2ff19ca1214b7966684033a58814
5906
- split: test
5907
  type: mteb/nfcorpus
 
 
 
5908
  metrics:
5909
  - type: main_score
5910
  value: 34.648
@@ -6188,14 +6190,14 @@ model-index:
6188
  value: 10.037
6189
  - type: recall_at_5
6190
  value: 12.717999999999998
6191
- task:
6192
  type: Retrieval
6193
- - dataset:
6194
- config: default
6195
  name: MTEB NQ (default)
6196
- revision: b774495ed302d8c44a3a7ea25c90dbce03968f31
6197
- split: test
6198
  type: mteb/nq
 
 
 
6199
  metrics:
6200
  - type: main_score
6201
  value: 60.06
@@ -6479,14 +6481,14 @@ model-index:
6479
  value: 61.114000000000004
6480
  - type: recall_at_5
6481
  value: 69.812
6482
- task:
6483
  type: Retrieval
6484
- - dataset:
6485
- config: default
6486
  name: MTEB QuoraRetrieval (default)
6487
- revision: e4e08e0b7dbe3c8700f0daef558ff32256715259
6488
- split: test
6489
  type: mteb/quora
 
 
 
6490
  metrics:
6491
  - type: main_score
6492
  value: 89.821
@@ -6770,14 +6772,14 @@ model-index:
6770
  value: 88.714
6771
  - type: recall_at_5
6772
  value: 92.96799999999999
6773
- task:
6774
- type: Retrieval
6775
- - dataset:
6776
- config: default
6777
  name: MTEB RedditClustering (default)
6778
- revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
6779
- split: test
6780
  type: mteb/reddit-clustering
 
 
 
6781
  metrics:
6782
  - type: main_score
6783
  value: 59.36038828851887
@@ -6785,14 +6787,14 @@ model-index:
6785
  value: 59.36038828851887
6786
  - type: v_measure_std
6787
  value: 4.1958765965154425
6788
- task:
6789
  type: Clustering
6790
- - dataset:
6791
- config: default
6792
  name: MTEB RedditClusteringP2P (default)
6793
- revision: 385e3cb46b4cfa89021f56c4380204149d0efe33
6794
- split: test
6795
  type: mteb/reddit-clustering-p2p
 
 
 
6796
  metrics:
6797
  - type: main_score
6798
  value: 64.67522832408089
@@ -6800,14 +6802,14 @@ model-index:
6800
  value: 64.67522832408089
6801
  - type: v_measure_std
6802
  value: 12.473765016158698
6803
- task:
6804
- type: Clustering
6805
- - dataset:
6806
- config: default
6807
  name: MTEB SCIDOCS (default)
6808
- revision: f8c2fcf00f625baaa80f62ec5bd9e1fff3b8ae88
6809
- split: test
6810
  type: mteb/scidocs
 
 
 
6811
  metrics:
6812
  - type: main_score
6813
  value: 21.751
@@ -7091,14 +7093,14 @@ model-index:
7091
  value: 11.648
7092
  - type: recall_at_5
7093
  value: 15.883
7094
- task:
7095
- type: Retrieval
7096
- - dataset:
7097
- config: default
7098
  name: MTEB SICK-R (default)
7099
- revision: 20a6d6f312dd54037fe07a32d58e5e168867909d
7100
- split: test
7101
  type: mteb/sickr-sts
 
 
 
7102
  metrics:
7103
  - type: cosine_pearson
7104
  value: 84.0161170579997
@@ -7118,14 +7120,14 @@ model-index:
7118
  value: 84.0161170579997
7119
  - type: spearman
7120
  value: 77.52025923874551
7121
- task:
7122
  type: STS
7123
- - dataset:
7124
- config: default
7125
  name: MTEB STS12 (default)
7126
- revision: a0d554a64d88156834ff5ae9920b964011b16384
7127
- split: test
7128
  type: mteb/sts12-sts
 
 
 
7129
  metrics:
7130
  - type: cosine_pearson
7131
  value: 81.32328780209225
@@ -7145,14 +7147,14 @@ model-index:
7145
  value: 81.32328780209225
7146
  - type: spearman
7147
  value: 74.17570679745272
7148
- task:
7149
  type: STS
7150
- - dataset:
7151
- config: default
7152
  name: MTEB STS13 (default)
7153
- revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
7154
- split: test
7155
  type: mteb/sts13-sts
 
 
 
7156
  metrics:
7157
  - type: cosine_pearson
7158
  value: 85.53224141249392
@@ -7172,14 +7174,14 @@ model-index:
7172
  value: 85.53224141249392
7173
  - type: spearman
7174
  value: 86.16981525069227
7175
- task:
7176
  type: STS
7177
- - dataset:
7178
- config: default
7179
  name: MTEB STS14 (default)
7180
- revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
7181
- split: test
7182
  type: mteb/sts14-sts
 
 
 
7183
  metrics:
7184
  - type: cosine_pearson
7185
  value: 82.234064045301
@@ -7199,14 +7201,14 @@ model-index:
7199
  value: 82.234064045301
7200
  - type: spearman
7201
  value: 78.86920830792957
7202
- task:
7203
  type: STS
7204
- - dataset:
7205
- config: default
7206
  name: MTEB STS15 (default)
7207
- revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
7208
- split: test
7209
  type: mteb/sts15-sts
 
 
 
7210
  metrics:
7211
  - type: cosine_pearson
7212
  value: 86.23114543080261
@@ -7226,14 +7228,14 @@ model-index:
7226
  value: 86.23114543080261
7227
  - type: spearman
7228
  value: 87.481042891123
7229
- task:
7230
  type: STS
7231
- - dataset:
7232
- config: default
7233
  name: MTEB STS16 (default)
7234
- revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
7235
- split: test
7236
  type: mteb/sts16-sts
 
 
 
7237
  metrics:
7238
  - type: cosine_pearson
7239
  value: 82.9156629047782
@@ -7253,14 +7255,14 @@ model-index:
7253
  value: 82.9156629047782
7254
  - type: spearman
7255
  value: 84.28381329207937
7256
- task:
7257
  type: STS
7258
- - dataset:
7259
- config: en-en
7260
  name: MTEB STS17 (en-en)
7261
- revision: faeb762787bd10488a50c8b5be4a3b82e411949c
7262
- split: test
7263
  type: mteb/sts17-crosslingual-sts
 
 
 
7264
  metrics:
7265
  - type: cosine_pearson
7266
  value: 88.91985349746744
@@ -7280,14 +7282,14 @@ model-index:
7280
  value: 88.91985349746744
7281
  - type: spearman
7282
  value: 89.69151633966257
7283
- task:
7284
  type: STS
7285
- - dataset:
7286
- config: en
7287
  name: MTEB STS22 (en)
7288
- revision: de9d86b3b84231dc21f76c7b7af1f28e2f57f6e3
7289
- split: test
7290
  type: mteb/sts22-crosslingual-sts
 
 
 
7291
  metrics:
7292
  - type: cosine_pearson
7293
  value: 65.0979772547511
@@ -7307,14 +7309,14 @@ model-index:
7307
  value: 65.0979772547511
7308
  - type: spearman
7309
  value: 65.78126527764236
7310
- task:
7311
  type: STS
7312
- - dataset:
7313
- config: default
7314
  name: MTEB STSBenchmark (default)
7315
- revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
7316
- split: test
7317
  type: mteb/stsbenchmark-sts
 
 
 
7318
  metrics:
7319
  - type: cosine_pearson
7320
  value: 85.6426635049971
@@ -7334,14 +7336,14 @@ model-index:
7334
  value: 85.6426635049971
7335
  - type: spearman
7336
  value: 85.609856578385
7337
- task:
7338
- type: STS
7339
- - dataset:
7340
- config: default
7341
  name: MTEB SciDocsRR (default)
7342
- revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
7343
- split: test
7344
  type: mteb/scidocs-reranking
 
 
 
7345
  metrics:
7346
  - type: main_score
7347
  value: 82.85163332499799
@@ -7361,14 +7363,14 @@ model-index:
7361
  value: 89.47202967481866
7362
  - type: nAUC_mrr_std
7363
  value: 85.40446996933892
7364
- task:
7365
- type: Reranking
7366
- - dataset:
7367
- config: default
7368
  name: MTEB SciFact (default)
7369
- revision: 0228b52cf27578f30900b9e5271d331663a030d7
7370
- split: test
7371
  type: mteb/scifact
 
 
 
7372
  metrics:
7373
  - type: main_score
7374
  value: 71.655
@@ -7652,14 +7654,14 @@ model-index:
7652
  value: 71.61699999999999
7653
  - type: recall_at_5
7654
  value: 78.361
7655
- task:
7656
- type: Retrieval
7657
- - dataset:
7658
- config: default
7659
  name: MTEB SprintDuplicateQuestions (default)
7660
- revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
7661
- split: test
7662
  type: mteb/sprintduplicatequestions-pairclassification
 
 
 
7663
  metrics:
7664
  - type: cosine_accuracy
7665
  value: 99.8019801980198
@@ -7743,14 +7745,14 @@ model-index:
7743
  value: 90.79754601226993
7744
  - type: similarity_recall
7745
  value: 88.8
7746
- task:
7747
- type: PairClassification
7748
- - dataset:
7749
- config: default
7750
  name: MTEB StackExchangeClustering (default)
7751
- revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
7752
- split: test
7753
  type: mteb/stackexchange-clustering
 
 
 
7754
  metrics:
7755
  - type: main_score
7756
  value: 66.63931197758824
@@ -7758,14 +7760,14 @@ model-index:
7758
  value: 66.63931197758824
7759
  - type: v_measure_std
7760
  value: 3.896206781511776
7761
- task:
7762
  type: Clustering
7763
- - dataset:
7764
- config: default
7765
  name: MTEB StackExchangeClusteringP2P (default)
7766
- revision: 815ca46b2622cec33ccafc3735d572c266efdb44
7767
- split: test
7768
  type: mteb/stackexchange-clustering-p2p
 
 
 
7769
  metrics:
7770
  - type: main_score
7771
  value: 38.984892653301884
@@ -7773,14 +7775,14 @@ model-index:
7773
  value: 38.984892653301884
7774
  - type: v_measure_std
7775
  value: 1.3308552162270453
7776
- task:
7777
- type: Clustering
7778
- - dataset:
7779
- config: default
7780
  name: MTEB StackOverflowDupQuestions (default)
7781
- revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
7782
- split: test
7783
  type: mteb/stackoverflowdupquestions-reranking
 
 
 
7784
  metrics:
7785
  - type: main_score
7786
  value: 52.71499643455044
@@ -7800,14 +7802,14 @@ model-index:
7800
  value: 13.931448578334379
7801
  - type: nAUC_mrr_std
7802
  value: 10.441860004959661
7803
- task:
7804
- type: Reranking
7805
- - dataset:
7806
- config: default
7807
  name: MTEB SummEval (default)
7808
- revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
7809
- split: test
7810
  type: mteb/summeval
 
 
 
7811
  metrics:
7812
  - type: cosine_pearson
7813
  value: 31.5167525286909
@@ -7823,14 +7825,14 @@ model-index:
7823
  value: 31.5167525286909
7824
  - type: spearman
7825
  value: 31.218862970706496
7826
- task:
7827
- type: Summarization
7828
- - dataset:
7829
- config: default
7830
  name: MTEB TRECCOVID (default)
7831
- revision: bb9466bac8153a0349341eb1b22e06409e78ef4e
7832
- split: test
7833
  type: mteb/trec-covid
 
 
 
7834
  metrics:
7835
  - type: main_score
7836
  value: 78.996
@@ -8114,14 +8116,14 @@ model-index:
8114
  value: 0.705
8115
  - type: recall_at_5
8116
  value: 1.162
8117
- task:
8118
  type: Retrieval
8119
- - dataset:
8120
- config: default
8121
  name: MTEB Touche2020 (default)
8122
- revision: a34f9a33db75fa0cbb21bb5cfc3dae8dc8bec93f
8123
- split: test
8124
  type: mteb/touche2020
 
 
 
8125
  metrics:
8126
  - type: main_score
8127
  value: 24.234
@@ -8405,14 +8407,14 @@ model-index:
8405
  value: 6.625
8406
  - type: recall_at_5
8407
  value: 9.094
8408
- task:
8409
- type: Retrieval
8410
- - dataset:
8411
- config: default
8412
  name: MTEB ToxicConversationsClassification (default)
8413
- revision: edfaf9da55d3dd50d43143d90c1ac476895ae6de
8414
- split: test
8415
  type: mteb/toxic_conversations_50k
 
 
 
8416
  metrics:
8417
  - type: accuracy
8418
  value: 72.822265625
@@ -8426,14 +8428,14 @@ model-index:
8426
  value: 78.7454393727821
8427
  - type: main_score
8428
  value: 72.822265625
8429
- task:
8430
  type: Classification
8431
- - dataset:
8432
- config: default
8433
  name: MTEB TweetSentimentExtractionClassification (default)
8434
- revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
8435
- split: test
8436
  type: mteb/tweet_sentiment_extraction
 
 
 
8437
  metrics:
8438
  - type: accuracy
8439
  value: 72.54385964912281
@@ -8443,14 +8445,14 @@ model-index:
8443
  value: 72.18022450339639
8444
  - type: main_score
8445
  value: 72.54385964912281
8446
- task:
8447
- type: Classification
8448
- - dataset:
8449
- config: default
8450
  name: MTEB TwentyNewsgroupsClustering (default)
8451
- revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
8452
- split: test
8453
  type: mteb/twentynewsgroups-clustering
 
 
 
8454
  metrics:
8455
  - type: main_score
8456
  value: 57.41861450414374
@@ -8458,14 +8460,14 @@ model-index:
8458
  value: 57.41861450414374
8459
  - type: v_measure_std
8460
  value: 1.1732394227153524
8461
- task:
8462
- type: Clustering
8463
- - dataset:
8464
- config: default
8465
  name: MTEB TwitterSemEval2015 (default)
8466
- revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
8467
- split: test
8468
  type: mteb/twittersemeval2015-pairclassification
 
 
 
8469
  metrics:
8470
  - type: cosine_accuracy
8471
  value: 85.65893783155511
@@ -8549,14 +8551,14 @@ model-index:
8549
  value: 64.0855106888361
8550
  - type: similarity_recall
8551
  value: 71.18733509234828
8552
- task:
8553
  type: PairClassification
8554
- - dataset:
8555
- config: default
8556
  name: MTEB TwitterURLCorpus (default)
8557
- revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
8558
- split: test
8559
  type: mteb/twitterurlcorpus-pairclassification
 
 
 
8560
  metrics:
8561
  - type: cosine_accuracy
8562
  value: 88.86754375751931
@@ -8640,8 +8642,6 @@ model-index:
8640
  value: 74.19310344827586
8641
  - type: similarity_recall
8642
  value: 82.83030489682784
8643
- task:
8644
- type: PairClassification
8645
  ---
8646
  # Contextual Document Embeddings (CDE)
8647
 
 
4
  model-index:
5
  - name: cde-small-v1
6
  results:
7
+ - task:
8
+ type: Classification
9
+ dataset:
10
  name: MTEB AmazonCounterfactualClassification (en)
 
 
11
  type: mteb/amazon_counterfactual
12
+ config: en
13
+ split: test
14
+ revision: e8379541af4e31359cca9fbcf4b00f2671dba205
15
  metrics:
16
  - type: accuracy
17
  value: 87.01492537313433
 
25
  value: 87.74802754480477
26
  - type: main_score
27
  value: 87.01492537313433
28
+ - task:
29
  type: Classification
30
+ dataset:
 
31
  name: MTEB AmazonPolarityClassification (default)
 
 
32
  type: mteb/amazon_polarity
33
+ config: default
34
+ split: test
35
+ revision: e2d317d38cd51312af73b3d32a06d1a08b442046
36
  metrics:
37
  - type: accuracy
38
  value: 94.652275
 
46
  value: 94.64655930708355
47
  - type: main_score
48
  value: 94.652275
49
+ - task:
50
  type: Classification
51
+ dataset:
 
52
  name: MTEB AmazonReviewsClassification (en)
 
 
53
  type: mteb/amazon_reviews_multi
54
+ config: en
55
+ split: test
56
+ revision: 1399c76144fd37290681b995c656ef9b2e06e26d
57
  metrics:
58
  - type: accuracy
59
  value: 55.75599999999999
 
63
  value: 55.07058630829347
64
  - type: main_score
65
  value: 55.75599999999999
66
+ - task:
67
+ type: Retrieval
68
+ dataset:
 
69
  name: MTEB ArguAna (default)
 
 
70
  type: mteb/arguana
71
+ config: default
72
+ split: test
73
+ revision: c22ab2a51041ffd869aaddef7af8d8215647e41a
74
  metrics:
75
  - type: main_score
76
  value: 69.959
 
354
  value: 74.182
355
  - type: recall_at_5
356
  value: 84.495
357
+ - task:
358
+ type: Clustering
359
+ dataset:
 
360
  name: MTEB ArxivClusteringP2P (default)
 
 
361
  type: mteb/arxiv-clustering-p2p
362
+ config: default
363
+ split: test
364
+ revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
365
  metrics:
366
  - type: main_score
367
  value: 48.54672141116669
 
369
  value: 48.54672141116669
370
  - type: v_measure_std
371
  value: 14.037498386768362
372
+ - task:
373
  type: Clustering
374
+ dataset:
 
375
  name: MTEB ArxivClusteringS2S (default)
 
 
376
  type: mteb/arxiv-clustering-s2s
377
+ config: default
378
+ split: test
379
+ revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
380
  metrics:
381
  - type: main_score
382
  value: 40.5914039166466
 
384
  value: 40.5914039166466
385
  - type: v_measure_std
386
  value: 14.385069818910331
387
+ - task:
388
+ type: Reranking
389
+ dataset:
 
390
  name: MTEB AskUbuntuDupQuestions (default)
 
 
391
  type: mteb/askubuntudupquestions-reranking
392
+ config: default
393
+ split: test
394
+ revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
395
  metrics:
396
  - type: main_score
397
  value: 61.13621260261507
 
411
  value: 31.484257486448364
412
  - type: nAUC_mrr_std
413
  value: 21.252659250011632
414
+ - task:
415
+ type: STS
416
+ dataset:
 
417
  name: MTEB BIOSSES (default)
 
 
418
  type: mteb/biosses-sts
419
+ config: default
420
+ split: test
421
+ revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
422
  metrics:
423
  - type: cosine_pearson
424
  value: 89.07028016646942
 
438
  value: 89.07028016646942
439
  - type: spearman
440
  value: 86.69595132967805
441
+ - task:
442
+ type: Classification
443
+ dataset:
 
444
  name: MTEB Banking77Classification (default)
 
 
445
  type: mteb/banking77
446
+ config: default
447
+ split: test
448
+ revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
449
  metrics:
450
  - type: accuracy
451
  value: 88.6038961038961
 
455
  value: 88.56824205739822
456
  - type: main_score
457
  value: 88.6038961038961
458
+ - task:
459
+ type: Clustering
460
+ dataset:
 
461
  name: MTEB BiorxivClusteringP2P (default)
 
 
462
  type: mteb/biorxiv-clustering-p2p
463
+ config: default
464
+ split: test
465
+ revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
466
  metrics:
467
  - type: main_score
468
  value: 44.77800814327256
 
470
  value: 44.77800814327256
471
  - type: v_measure_std
472
  value: 0.6462535527471919
473
+ - task:
474
  type: Clustering
475
+ dataset:
 
476
  name: MTEB BiorxivClusteringS2S (default)
 
 
477
  type: mteb/biorxiv-clustering-s2s
478
+ config: default
479
+ split: test
480
+ revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
481
  metrics:
482
  - type: main_score
483
  value: 38.16110272459102
 
485
  value: 38.16110272459102
486
  - type: v_measure_std
487
  value: 0.7456916212435019
488
+ - task:
489
+ type: Retrieval
490
+ dataset:
 
491
  name: MTEB CQADupstackAndroidRetrieval (default)
 
 
492
  type: mteb/cqadupstack-android
493
+ config: default
494
+ split: test
495
+ revision: f46a197baaae43b4f621051089b82a364682dfeb
496
  metrics:
497
  - type: main_score
498
  value: 49.376
 
776
  value: 47.591
777
  - type: recall_at_5
778
  value: 54.245
779
+ - task:
780
  type: Retrieval
781
+ dataset:
 
782
  name: MTEB CQADupstackEnglishRetrieval (default)
 
 
783
  type: mteb/cqadupstack-english
784
+ config: default
785
+ split: test
786
+ revision: ad9991cb51e31e31e430383c75ffb2885547b5f0
787
  metrics:
788
  - type: main_score
789
  value: 44.727
 
1067
  value: 42.085
1068
  - type: recall_at_5
1069
  value: 47.5
1070
+ - task:
1071
  type: Retrieval
1072
+ dataset:
 
1073
  name: MTEB CQADupstackGamingRetrieval (default)
 
 
1074
  type: mteb/cqadupstack-gaming
1075
+ config: default
1076
+ split: test
1077
+ revision: 4885aa143210c98657558c04aaf3dc47cfb54340
1078
  metrics:
1079
  - type: main_score
1080
  value: 59.001999999999995
 
1358
  value: 57.916000000000004
1359
  - type: recall_at_5
1360
  value: 65.44
1361
+ - task:
1362
  type: Retrieval
1363
+ dataset:
 
1364
  name: MTEB CQADupstackGisRetrieval (default)
 
 
1365
  type: mteb/cqadupstack-gis
1366
+ config: default
1367
+ split: test
1368
+ revision: 5003b3064772da1887988e05400cf3806fe491f2
1369
  metrics:
1370
  - type: main_score
1371
  value: 37.501
 
1649
  value: 37.218
1650
  - type: recall_at_5
1651
  value: 42.559000000000005
1652
+ - task:
1653
  type: Retrieval
1654
+ dataset:
 
1655
  name: MTEB CQADupstackMathematicaRetrieval (default)
 
 
1656
  type: mteb/cqadupstack-mathematica
1657
+ config: default
1658
+ split: test
1659
+ revision: 90fceea13679c63fe563ded68f3b6f06e50061de
1660
  metrics:
1661
  - type: main_score
1662
  value: 27.653
 
1940
  value: 25.469
1941
  - type: recall_at_5
1942
  value: 31.316
1943
+ - task:
1944
  type: Retrieval
1945
+ dataset:
 
1946
  name: MTEB CQADupstackPhysicsRetrieval (default)
 
 
1947
  type: mteb/cqadupstack-physics
1948
+ config: default
1949
+ split: test
1950
+ revision: 79531abbd1fb92d06c6d6315a0cbbbf5bb247ea4
1951
  metrics:
1952
  - type: main_score
1953
  value: 45.314
 
2231
  value: 43.679
2232
  - type: recall_at_5
2233
  value: 49.735
2234
+ - task:
2235
  type: Retrieval
2236
+ dataset:
 
2237
  name: MTEB CQADupstackProgrammersRetrieval (default)
 
 
2238
  type: mteb/cqadupstack-programmers
2239
+ config: default
2240
+ split: test
2241
+ revision: 6184bc1440d2dbc7612be22b50686b8826d22b32
2242
  metrics:
2243
  - type: main_score
2244
  value: 41.972
 
2522
  value: 39.363
2523
  - type: recall_at_5
2524
  value: 44.665
2525
+ - task:
2526
  type: Retrieval
2527
+ dataset:
 
2528
  name: MTEB CQADupstackRetrieval (default)
 
 
2529
  type: CQADupstackRetrieval_is_a_combined_dataset
2530
+ config: default
2531
+ split: test
2532
+ revision: CQADupstackRetrieval_is_a_combined_dataset
2533
  metrics:
2534
  - type: main_score
2535
  value: 39.823499999999996
2536
  - type: ndcg_at_10
2537
  value: 39.823499999999996
2538
+ - task:
2539
  type: Retrieval
2540
+ dataset:
 
2541
  name: MTEB CQADupstackStatsRetrieval (default)
 
 
2542
  type: mteb/cqadupstack-stats
2543
+ config: default
2544
+ split: test
2545
+ revision: 65ac3a16b8e91f9cee4c9828cc7c335575432a2a
2546
  metrics:
2547
  - type: main_score
2548
  value: 34.943000000000005
 
2826
  value: 33.427
2827
  - type: recall_at_5
2828
  value: 37.643
2829
+ - task:
2830
  type: Retrieval
2831
+ dataset:
 
2832
  name: MTEB CQADupstackTexRetrieval (default)
 
 
2833
  type: mteb/cqadupstack-tex
2834
+ config: default
2835
+ split: test
2836
+ revision: 46989137a86843e03a6195de44b09deda022eec7
2837
  metrics:
2838
  - type: main_score
2839
  value: 27.271
 
3117
  value: 25.592
3118
  - type: recall_at_5
3119
  value: 30.279
3120
+ - task:
3121
  type: Retrieval
3122
+ dataset:
 
3123
  name: MTEB CQADupstackUnixRetrieval (default)
 
 
3124
  type: mteb/cqadupstack-unix
3125
+ config: default
3126
+ split: test
3127
+ revision: 6c6430d3a6d36f8d2a829195bc5dc94d7e063e53
3128
  metrics:
3129
  - type: main_score
3130
  value: 38.237
 
3408
  value: 36.275
3409
  - type: recall_at_5
3410
  value: 42.199
3411
+ - task:
3412
  type: Retrieval
3413
+ dataset:
 
3414
  name: MTEB CQADupstackWebmastersRetrieval (default)
 
 
3415
  type: mteb/cqadupstack-webmasters
3416
+ config: default
3417
+ split: test
3418
+ revision: 160c094312a0e1facb97e55eeddb698c0abe3571
3419
  metrics:
3420
  - type: main_score
3421
  value: 38.702
 
3699
  value: 37.634
3700
  - type: recall_at_5
3701
  value: 42.021
3702
+ - task:
3703
  type: Retrieval
3704
+ dataset:
 
3705
  name: MTEB CQADupstackWordpressRetrieval (default)
 
 
3706
  type: mteb/cqadupstack-wordpress
3707
+ config: default
3708
+ split: test
3709
+ revision: 4ffe81d471b1924886b33c7567bfb200e9eec5c4
3710
  metrics:
3711
  - type: main_score
3712
  value: 33.184000000000005
 
3990
  value: 32.683
3991
  - type: recall_at_5
3992
  value: 36.756
3993
+ - task:
3994
  type: Retrieval
3995
+ dataset:
 
3996
  name: MTEB ClimateFEVER (default)
 
 
3997
  type: mteb/climate-fever
3998
+ config: default
3999
+ split: test
4000
+ revision: 47f2ac6acb640fc46020b02a5b59fdda04d39380
4001
  metrics:
4002
  - type: main_score
4003
  value: 25.068
 
4281
  value: 18.312
4282
  - type: recall_at_5
4283
  value: 22.776
4284
+ - task:
4285
  type: Retrieval
4286
+ dataset:
 
4287
  name: MTEB DBPedia (default)
 
 
4288
  type: mteb/dbpedia
4289
+ config: default
4290
+ split: test
4291
+ revision: c0f706b76e590d620bd6618b3ca8efdd34e2d659
4292
  metrics:
4293
  - type: main_score
4294
  value: 40.128
 
4572
  value: 14.562
4573
  - type: recall_at_5
4574
  value: 18.779
4575
+ - task:
4576
+ type: Classification
4577
+ dataset:
 
4578
  name: MTEB EmotionClassification (default)
 
 
4579
  type: mteb/emotion
4580
+ config: default
4581
+ split: test
4582
+ revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
4583
  metrics:
4584
  - type: accuracy
4585
  value: 74.86
 
4589
  value: 75.96499621761998
4590
  - type: main_score
4591
  value: 74.86
4592
+ - task:
4593
+ type: Retrieval
4594
+ dataset:
 
4595
  name: MTEB FEVER (default)
 
 
4596
  type: mteb/fever
4597
+ config: default
4598
+ split: test
4599
+ revision: bea83ef9e8fb933d90a2f1d5515737465d613e12
4600
  metrics:
4601
  - type: main_score
4602
  value: 86.029
 
4880
  value: 88.382
4881
  - type: recall_at_5
4882
  value: 90.908
4883
+ - task:
4884
  type: Retrieval
4885
+ dataset:
 
4886
  name: MTEB FiQA2018 (default)
 
 
4887
  type: mteb/fiqa
4888
+ config: default
4889
+ split: test
4890
+ revision: 27a168819829fe9bcd655c2df245fb19452e8e06
4891
  metrics:
4892
  - type: main_score
4893
  value: 45.238
 
5171
  value: 37.656
5172
  - type: recall_at_5
5173
  value: 44.766
5174
+ - task:
5175
  type: Retrieval
5176
+ dataset:
 
5177
  name: MTEB HotpotQA (default)
 
 
5178
  type: mteb/hotpotqa
5179
+ config: default
5180
+ split: test
5181
+ revision: ab518f4d6fcca38d87c25209f94beba119d02014
5182
  metrics:
5183
  - type: main_score
5184
  value: 66.672
 
5462
  value: 57.522
5463
  - type: recall_at_5
5464
  value: 62.134
5465
+ - task:
5466
+ type: Classification
5467
+ dataset:
 
5468
  name: MTEB ImdbClassification (default)
 
 
5469
  type: mteb/imdb
5470
+ config: default
5471
+ split: test
5472
+ revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
5473
  metrics:
5474
  - type: accuracy
5475
  value: 93.5944
 
5483
  value: 93.58945949328377
5484
  - type: main_score
5485
  value: 93.5944
5486
+ - task:
5487
+ type: Retrieval
5488
+ dataset:
 
5489
  name: MTEB MSMARCO (default)
 
 
5490
  type: mteb/msmarco
5491
+ config: default
5492
+ split: dev
5493
+ revision: c5a29a104738b98a9e76336939199e264163d4a0
5494
  metrics:
5495
  - type: main_score
5496
  value: 41.448
 
5774
  value: 41.304
5775
  - type: recall_at_5
5776
  value: 51.076
5777
+ - task:
5778
+ type: Classification
5779
+ dataset:
 
5780
  name: MTEB MTOPDomainClassification (en)
 
 
5781
  type: mteb/mtop_domain
5782
+ config: en
5783
+ split: test
5784
+ revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
5785
  metrics:
5786
  - type: accuracy
5787
  value: 96.03967168262655
 
5791
  value: 96.06623245823347
5792
  - type: main_score
5793
  value: 96.03967168262655
5794
+ - task:
5795
  type: Classification
5796
+ dataset:
 
5797
  name: MTEB MTOPIntentClassification (en)
 
 
5798
  type: mteb/mtop_intent
5799
+ config: en
5800
+ split: test
5801
+ revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
5802
  metrics:
5803
  - type: accuracy
5804
  value: 89.12904696762428
 
5808
  value: 90.41290566743324
5809
  - type: main_score
5810
  value: 89.12904696762428
5811
+ - task:
5812
  type: Classification
5813
+ dataset:
 
5814
  name: MTEB MassiveIntentClassification (en)
 
 
5815
  type: mteb/amazon_massive_intent
5816
+ config: en
5817
+ split: test
5818
+ revision: 4672e20407010da34463acc759c162ca9734bca6
5819
  metrics:
5820
  - type: accuracy
5821
  value: 76.49630127774041
 
5825
  value: 76.42436195016484
5826
  - type: main_score
5827
  value: 76.49630127774041
5828
+ - task:
5829
  type: Classification
5830
+ dataset:
 
5831
  name: MTEB MassiveScenarioClassification (en)
 
 
5832
  type: mteb/amazon_massive_scenario
5833
+ config: en
5834
+ split: test
5835
+ revision: fad2c6e8459f9e1c45d9315f4953d921437d70f8
5836
  metrics:
5837
  - type: accuracy
5838
  value: 78.9340954942838
 
5842
  value: 78.87787647838971
5843
  - type: main_score
5844
  value: 78.9340954942838
5845
+ - task:
5846
+ type: Clustering
5847
+ dataset:
 
5848
  name: MTEB MedrxivClusteringP2P (default)
 
 
5849
  type: mteb/medrxiv-clustering-p2p
5850
+ config: default
5851
+ split: test
5852
+ revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
5853
  metrics:
5854
  - type: main_score
5855
  value: 37.50182848656019
 
5857
  value: 37.50182848656019
5858
  - type: v_measure_std
5859
  value: 1.1708518023877268
5860
+ - task:
5861
  type: Clustering
5862
+ dataset:
 
5863
  name: MTEB MedrxivClusteringS2S (default)
 
 
5864
  type: mteb/medrxiv-clustering-s2s
5865
+ config: default
5866
+ split: test
5867
+ revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
5868
  metrics:
5869
  - type: main_score
5870
  value: 35.72762609825363
 
5872
  value: 35.72762609825363
5873
  - type: v_measure_std
5874
  value: 1.4555014772914985
5875
+ - task:
5876
+ type: Reranking
5877
+ dataset:
 
5878
  name: MTEB MindSmallReranking (default)
 
 
5879
  type: mteb/mind_small
5880
+ config: default
5881
+ split: test
5882
+ revision: 59042f120c80e8afa9cdbb224f67076cec0fc9a7
5883
  metrics:
5884
  - type: main_score
5885
  value: 30.47716416454022
 
5899
  value: -15.78941850629242
5900
  - type: nAUC_mrr_std
5901
  value: -1.1330442292510805
5902
+ - task:
5903
+ type: Retrieval
5904
+ dataset:
 
5905
  name: MTEB NFCorpus (default)
 
 
5906
  type: mteb/nfcorpus
5907
+ config: default
5908
+ split: test
5909
+ revision: ec0fa4fe99da2ff19ca1214b7966684033a58814
5910
  metrics:
5911
  - type: main_score
5912
  value: 34.648
 
6190
  value: 10.037
6191
  - type: recall_at_5
6192
  value: 12.717999999999998
6193
+ - task:
6194
  type: Retrieval
6195
+ dataset:
 
6196
  name: MTEB NQ (default)
 
 
6197
  type: mteb/nq
6198
+ config: default
6199
+ split: test
6200
+ revision: b774495ed302d8c44a3a7ea25c90dbce03968f31
6201
  metrics:
6202
  - type: main_score
6203
  value: 60.06
 
6481
  value: 61.114000000000004
6482
  - type: recall_at_5
6483
  value: 69.812
6484
+ - task:
6485
  type: Retrieval
6486
+ dataset:
 
6487
  name: MTEB QuoraRetrieval (default)
 
 
6488
  type: mteb/quora
6489
+ config: default
6490
+ split: test
6491
+ revision: e4e08e0b7dbe3c8700f0daef558ff32256715259
6492
  metrics:
6493
  - type: main_score
6494
  value: 89.821
 
6772
  value: 88.714
6773
  - type: recall_at_5
6774
  value: 92.96799999999999
6775
+ - task:
6776
+ type: Clustering
6777
+ dataset:
 
6778
  name: MTEB RedditClustering (default)
 
 
6779
  type: mteb/reddit-clustering
6780
+ config: default
6781
+ split: test
6782
+ revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
6783
  metrics:
6784
  - type: main_score
6785
  value: 59.36038828851887
 
6787
  value: 59.36038828851887
6788
  - type: v_measure_std
6789
  value: 4.1958765965154425
6790
+ - task:
6791
  type: Clustering
6792
+ dataset:
 
6793
  name: MTEB RedditClusteringP2P (default)
 
 
6794
  type: mteb/reddit-clustering-p2p
6795
+ config: default
6796
+ split: test
6797
+ revision: 385e3cb46b4cfa89021f56c4380204149d0efe33
6798
  metrics:
6799
  - type: main_score
6800
  value: 64.67522832408089
 
6802
  value: 64.67522832408089
6803
  - type: v_measure_std
6804
  value: 12.473765016158698
6805
+ - task:
6806
+ type: Retrieval
6807
+ dataset:
 
6808
  name: MTEB SCIDOCS (default)
 
 
6809
  type: mteb/scidocs
6810
+ config: default
6811
+ split: test
6812
+ revision: f8c2fcf00f625baaa80f62ec5bd9e1fff3b8ae88
6813
  metrics:
6814
  - type: main_score
6815
  value: 21.751
 
7093
  value: 11.648
7094
  - type: recall_at_5
7095
  value: 15.883
7096
+ - task:
7097
+ type: STS
7098
+ dataset:
 
7099
  name: MTEB SICK-R (default)
 
 
7100
  type: mteb/sickr-sts
7101
+ config: default
7102
+ split: test
7103
+ revision: 20a6d6f312dd54037fe07a32d58e5e168867909d
7104
  metrics:
7105
  - type: cosine_pearson
7106
  value: 84.0161170579997
 
7120
  value: 84.0161170579997
7121
  - type: spearman
7122
  value: 77.52025923874551
7123
+ - task:
7124
  type: STS
7125
+ dataset:
 
7126
  name: MTEB STS12 (default)
 
 
7127
  type: mteb/sts12-sts
7128
+ config: default
7129
+ split: test
7130
+ revision: a0d554a64d88156834ff5ae9920b964011b16384
7131
  metrics:
7132
  - type: cosine_pearson
7133
  value: 81.32328780209225
 
7147
  value: 81.32328780209225
7148
  - type: spearman
7149
  value: 74.17570679745272
7150
+ - task:
7151
  type: STS
7152
+ dataset:
 
7153
  name: MTEB STS13 (default)
 
 
7154
  type: mteb/sts13-sts
7155
+ config: default
7156
+ split: test
7157
+ revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
7158
  metrics:
7159
  - type: cosine_pearson
7160
  value: 85.53224141249392
 
7174
  value: 85.53224141249392
7175
  - type: spearman
7176
  value: 86.16981525069227
7177
+ - task:
7178
  type: STS
7179
+ dataset:
 
7180
  name: MTEB STS14 (default)
 
 
7181
  type: mteb/sts14-sts
7182
+ config: default
7183
+ split: test
7184
+ revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
7185
  metrics:
7186
  - type: cosine_pearson
7187
  value: 82.234064045301
 
7201
  value: 82.234064045301
7202
  - type: spearman
7203
  value: 78.86920830792957
7204
+ - task:
7205
  type: STS
7206
+ dataset:
 
7207
  name: MTEB STS15 (default)
 
 
7208
  type: mteb/sts15-sts
7209
+ config: default
7210
+ split: test
7211
+ revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
7212
  metrics:
7213
  - type: cosine_pearson
7214
  value: 86.23114543080261
 
7228
  value: 86.23114543080261
7229
  - type: spearman
7230
  value: 87.481042891123
7231
+ - task:
7232
  type: STS
7233
+ dataset:
 
7234
  name: MTEB STS16 (default)
 
 
7235
  type: mteb/sts16-sts
7236
+ config: default
7237
+ split: test
7238
+ revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
7239
  metrics:
7240
  - type: cosine_pearson
7241
  value: 82.9156629047782
 
7255
  value: 82.9156629047782
7256
  - type: spearman
7257
  value: 84.28381329207937
7258
+ - task:
7259
  type: STS
7260
+ dataset:
 
7261
  name: MTEB STS17 (en-en)
 
 
7262
  type: mteb/sts17-crosslingual-sts
7263
+ config: en-en
7264
+ split: test
7265
+ revision: faeb762787bd10488a50c8b5be4a3b82e411949c
7266
  metrics:
7267
  - type: cosine_pearson
7268
  value: 88.91985349746744
 
7282
  value: 88.91985349746744
7283
  - type: spearman
7284
  value: 89.69151633966257
7285
+ - task:
7286
  type: STS
7287
+ dataset:
 
7288
  name: MTEB STS22 (en)
 
 
7289
  type: mteb/sts22-crosslingual-sts
7290
+ config: en
7291
+ split: test
7292
+ revision: de9d86b3b84231dc21f76c7b7af1f28e2f57f6e3
7293
  metrics:
7294
  - type: cosine_pearson
7295
  value: 65.0979772547511
 
7309
  value: 65.0979772547511
7310
  - type: spearman
7311
  value: 65.78126527764236
7312
+ - task:
7313
  type: STS
7314
+ dataset:
 
7315
  name: MTEB STSBenchmark (default)
 
 
7316
  type: mteb/stsbenchmark-sts
7317
+ config: default
7318
+ split: test
7319
+ revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
7320
  metrics:
7321
  - type: cosine_pearson
7322
  value: 85.6426635049971
 
7336
  value: 85.6426635049971
7337
  - type: spearman
7338
  value: 85.609856578385
7339
+ - task:
7340
+ type: Reranking
7341
+ dataset:
 
7342
  name: MTEB SciDocsRR (default)
 
 
7343
  type: mteb/scidocs-reranking
7344
+ config: default
7345
+ split: test
7346
+ revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
7347
  metrics:
7348
  - type: main_score
7349
  value: 82.85163332499799
 
7363
  value: 89.47202967481866
7364
  - type: nAUC_mrr_std
7365
  value: 85.40446996933892
7366
+ - task:
7367
+ type: Retrieval
7368
+ dataset:
 
7369
  name: MTEB SciFact (default)
 
 
7370
  type: mteb/scifact
7371
+ config: default
7372
+ split: test
7373
+ revision: 0228b52cf27578f30900b9e5271d331663a030d7
7374
  metrics:
7375
  - type: main_score
7376
  value: 71.655
 
7654
  value: 71.61699999999999
7655
  - type: recall_at_5
7656
  value: 78.361
7657
+ - task:
7658
+ type: PairClassification
7659
+ dataset:
 
7660
  name: MTEB SprintDuplicateQuestions (default)
 
 
7661
  type: mteb/sprintduplicatequestions-pairclassification
7662
+ config: default
7663
+ split: test
7664
+ revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
7665
  metrics:
7666
  - type: cosine_accuracy
7667
  value: 99.8019801980198
 
7745
  value: 90.79754601226993
7746
  - type: similarity_recall
7747
  value: 88.8
7748
+ - task:
7749
+ type: Clustering
7750
+ dataset:
 
7751
  name: MTEB StackExchangeClustering (default)
 
 
7752
  type: mteb/stackexchange-clustering
7753
+ config: default
7754
+ split: test
7755
+ revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
7756
  metrics:
7757
  - type: main_score
7758
  value: 66.63931197758824
 
7760
  value: 66.63931197758824
7761
  - type: v_measure_std
7762
  value: 3.896206781511776
7763
+ - task:
7764
  type: Clustering
7765
+ dataset:
 
7766
  name: MTEB StackExchangeClusteringP2P (default)
 
 
7767
  type: mteb/stackexchange-clustering-p2p
7768
+ config: default
7769
+ split: test
7770
+ revision: 815ca46b2622cec33ccafc3735d572c266efdb44
7771
  metrics:
7772
  - type: main_score
7773
  value: 38.984892653301884
 
7775
  value: 38.984892653301884
7776
  - type: v_measure_std
7777
  value: 1.3308552162270453
7778
+ - task:
7779
+ type: Reranking
7780
+ dataset:
 
7781
  name: MTEB StackOverflowDupQuestions (default)
 
 
7782
  type: mteb/stackoverflowdupquestions-reranking
7783
+ config: default
7784
+ split: test
7785
+ revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
7786
  metrics:
7787
  - type: main_score
7788
  value: 52.71499643455044
 
7802
  value: 13.931448578334379
7803
  - type: nAUC_mrr_std
7804
  value: 10.441860004959661
7805
+ - task:
7806
+ type: Summarization
7807
+ dataset:
 
7808
  name: MTEB SummEval (default)
 
 
7809
  type: mteb/summeval
7810
+ config: default
7811
+ split: test
7812
+ revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
7813
  metrics:
7814
  - type: cosine_pearson
7815
  value: 31.5167525286909
 
7825
  value: 31.5167525286909
7826
  - type: spearman
7827
  value: 31.218862970706496
7828
+ - task:
7829
+ type: Retrieval
7830
+ dataset:
 
7831
  name: MTEB TRECCOVID (default)
 
 
7832
  type: mteb/trec-covid
7833
+ config: default
7834
+ split: test
7835
+ revision: bb9466bac8153a0349341eb1b22e06409e78ef4e
7836
  metrics:
7837
  - type: main_score
7838
  value: 78.996
 
8116
  value: 0.705
8117
  - type: recall_at_5
8118
  value: 1.162
8119
+ - task:
8120
  type: Retrieval
8121
+ dataset:
 
8122
  name: MTEB Touche2020 (default)
 
 
8123
  type: mteb/touche2020
8124
+ config: default
8125
+ split: test
8126
+ revision: a34f9a33db75fa0cbb21bb5cfc3dae8dc8bec93f
8127
  metrics:
8128
  - type: main_score
8129
  value: 24.234
 
8407
  value: 6.625
8408
  - type: recall_at_5
8409
  value: 9.094
8410
+ - task:
8411
+ type: Classification
8412
+ dataset:
 
8413
  name: MTEB ToxicConversationsClassification (default)
 
 
8414
  type: mteb/toxic_conversations_50k
8415
+ config: default
8416
+ split: test
8417
+ revision: edfaf9da55d3dd50d43143d90c1ac476895ae6de
8418
  metrics:
8419
  - type: accuracy
8420
  value: 72.822265625
 
8428
  value: 78.7454393727821
8429
  - type: main_score
8430
  value: 72.822265625
8431
+ - task:
8432
  type: Classification
8433
+ dataset:
 
8434
  name: MTEB TweetSentimentExtractionClassification (default)
 
 
8435
  type: mteb/tweet_sentiment_extraction
8436
+ config: default
8437
+ split: test
8438
+ revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
8439
  metrics:
8440
  - type: accuracy
8441
  value: 72.54385964912281
 
8445
  value: 72.18022450339639
8446
  - type: main_score
8447
  value: 72.54385964912281
8448
+ - task:
8449
+ type: Clustering
8450
+ dataset:
 
8451
  name: MTEB TwentyNewsgroupsClustering (default)
 
 
8452
  type: mteb/twentynewsgroups-clustering
8453
+ config: default
8454
+ split: test
8455
+ revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
8456
  metrics:
8457
  - type: main_score
8458
  value: 57.41861450414374
 
8460
  value: 57.41861450414374
8461
  - type: v_measure_std
8462
  value: 1.1732394227153524
8463
+ - task:
8464
+ type: PairClassification
8465
+ dataset:
 
8466
  name: MTEB TwitterSemEval2015 (default)
 
 
8467
  type: mteb/twittersemeval2015-pairclassification
8468
+ config: default
8469
+ split: test
8470
+ revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
8471
  metrics:
8472
  - type: cosine_accuracy
8473
  value: 85.65893783155511
 
8551
  value: 64.0855106888361
8552
  - type: similarity_recall
8553
  value: 71.18733509234828
8554
+ - task:
8555
  type: PairClassification
8556
+ dataset:
 
8557
  name: MTEB TwitterURLCorpus (default)
 
 
8558
  type: mteb/twitterurlcorpus-pairclassification
8559
+ config: default
8560
+ split: test
8561
+ revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
8562
  metrics:
8563
  - type: cosine_accuracy
8564
  value: 88.86754375751931
 
8642
  value: 74.19310344827586
8643
  - type: similarity_recall
8644
  value: 82.83030489682784
 
 
8645
  ---
8646
  # Contextual Document Embeddings (CDE)
8647
 
config.json CHANGED
@@ -1,8 +1,14 @@
1
  {
 
2
  "architecture": "transductive",
3
  "architectures": [
4
- "DatasetConditionedBiencoder"
5
  ],
 
 
 
 
 
6
  "biencoder_pooling_strategy": "mean",
7
  "cache_dir": null,
8
  "config_name": null,
 
1
  {
2
+ "_name_or_path": "/fsx-checkpoints/jxm/cde/2024-09-18-supervised-final-bge--epoch-4/checkpoint-1820",
3
  "architecture": "transductive",
4
  "architectures": [
5
+ "DatasetTransformer"
6
  ],
7
+ "attn_implementation": null,
8
+ "auto_map": {
9
+ "AutoConfig": "misc.ContextualModelConfig",
10
+ "AutoModel": "model.DatasetTransformer"
11
+ },
12
  "biencoder_pooling_strategy": "mean",
13
  "cache_dir": null,
14
  "config_name": null,
misc.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Iterable, List, Tuple, Union
2
+
3
+ import collections
4
+ import functools
5
+ import glob
6
+ import json
7
+ import hashlib
8
+ import itertools
9
+ import logging
10
+ import multiprocessing
11
+ import os
12
+ import pickle
13
+ import random
14
+ import requests
15
+ import sys
16
+ import zipfile
17
+
18
+ import datasets
19
+ import numpy as np
20
+ import safetensors
21
+ import torch
22
+ import tqdm
23
+ import transformers
24
+
25
+ from cde.lib.dist import get_num_proc, get_rank
26
+
27
+
28
+ def get_cde_cache_dir() -> str:
29
+ script_directory = os.path.normpath(
30
+ os.path.join(
31
+ os.path.dirname(os.path.abspath(__file__)),
32
+ os.pardir, os.pardir,
33
+ )
34
+ )
35
+ return os.path.join(script_directory, "data")
36
+
37
+
38
+ def get_cache_location_from_kwargs(**kwargs):
39
+ cache_location = os.path.join(
40
+ get_cde_cache_dir(), "cluster"
41
+ )
42
+ os.makedirs(cache_location, exist_ok=True)
43
+ return os.path.join(cache_location, md5_hash_kwargs(**kwargs))
44
+
45
+
46
+ def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
47
+ qrels_idxs = collections.defaultdict(list)
48
+ qrels_scores = collections.defaultdict(list)
49
+ corpus_ids = np.array(corpus['_id'])
50
+ skipped_qrels = 0
51
+
52
+ for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
53
+ #
54
+ # example:
55
+ # {
56
+ # 'query-id': 1,
57
+ # 'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
58
+ # 'score': 2
59
+ # }
60
+ #
61
+ q_id = str(ex['query-id'])
62
+ c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
63
+ #
64
+ assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
65
+ #
66
+ if len(c_idxs):
67
+ qrels_idxs[q_id].append(c_idxs[0])
68
+ qrels_scores[q_id].append(ex['score'])
69
+ else:
70
+ skipped_qrels += 1
71
+ #
72
+
73
+ if skipped_qrels > 0:
74
+ logging.warning(f'Warning: Skipped {skipped_qrels}/{len(qrels)} qrels.')
75
+
76
+ return qrels_idxs, qrels_scores
77
+
78
+
79
+ def process_qrels(
80
+ corpus: datasets.Dataset, qrels: datasets.Dataset,
81
+ use_cache: bool = True
82
+ ) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
83
+ dataset_cache_file = '_'.join(
84
+ (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
85
+ )
86
+ cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
87
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
88
+
89
+ if not (use_cache and os.path.exists(cache_file)):
90
+ qrels_idxs, qrels_scores = process_qrels_uncached(
91
+ corpus=corpus, qrels=qrels
92
+ )
93
+ if use_cache:
94
+ pickle.dump((qrels_idxs, qrels_scores), open(cache_file, 'wb'))
95
+ else:
96
+ qrels_idxs, qrels_scores = pickle.load(open(cache_file, 'rb'))
97
+
98
+ return qrels_idxs, qrels_scores
99
+
100
+
101
+ def strip_extension(filename: str) -> str:
102
+ """Strips file extension.
103
+
104
+ Ex:
105
+ >> strip_extension('/root/dir/sub/file.ext')
106
+ '/root/dir/sub/file'
107
+ """
108
+ return os.path.splitext(filename)[0]
109
+
110
+
111
+ def md5_hash(t: Tuple[str]) -> str:
112
+ return hashlib.md5('__'.join(t).encode()).hexdigest()
113
+
114
+
115
+ def md5_hash_kwargs(**kwargs) -> str:
116
+ # We ignore special hf args that start with _ like '__cached__setup_devices'.
117
+ safe_kwargs = {k: str(v) for k,v in kwargs.items() if not k.startswith('_')}
118
+ s = json.dumps(safe_kwargs, sort_keys=True)
119
+ return hashlib.md5(s.encode()).hexdigest()
120
+
121
+ def download_url(url: str, save_path: str, chunk_size: int = 1024):
122
+ """Download url with progress bar using tqdm
123
+ https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
124
+ Args:
125
+ url (str): downloadable url
126
+ save_path (str): local path to save the downloaded file
127
+ chunk_size (int, optional): chunking of files. Defaults to 1024.
128
+ """
129
+ r = requests.get(url, stream=True)
130
+ total = int(r.headers.get('Content-Length', 0))
131
+ with open(save_path, 'wb') as fd, tqdm.tqdm(
132
+ desc=save_path,
133
+ total=total,
134
+ unit='iB',
135
+ unit_scale=True,
136
+ unit_divisor=chunk_size,
137
+ ) as bar:
138
+ for data in r.iter_content(chunk_size=chunk_size):
139
+ size = fd.write(data)
140
+ bar.update(size)
141
+
142
+
143
+ def unzip(zip_file: str, out_dir: str):
144
+ print("unzipping =>", zip_file)
145
+ zip_ = zipfile.ZipFile(zip_file, "r")
146
+ zip_.extractall(path=out_dir)
147
+ zip_.close()
148
+
149
+
150
+ def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
151
+ os.makedirs(out_dir, exist_ok=True)
152
+ dataset = url.split("/")[-1]
153
+ zip_file = os.path.join(out_dir, dataset)
154
+
155
+ if not os.path.isfile(zip_file):
156
+ logging.info("Downloading {} ...".format(dataset))
157
+ download_url(url, zip_file, chunk_size)
158
+
159
+ if not os.path.isdir(zip_file.replace(".zip", "")):
160
+ logging.info("Unzipping {} ...".format(dataset))
161
+ unzip(zip_file, out_dir)
162
+
163
+ return os.path.join(out_dir, dataset.replace(".zip", ""))
164
+
165
+
166
+ def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
167
+ if get_rank() == 0:
168
+ return tqdm.tqdm(iterable, **kwargs)
169
+ else:
170
+ return iterable
171
+
172
+
173
+ class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
174
+ """We create a dummy configuration class that will just set properties
175
+ based on whatever kwargs we pass in.
176
+
177
+ When this class is initialized (see experiments.py) we pass in the
178
+ union of all data, model, and training args, all of which should
179
+ get saved to the config json.
180
+ """
181
+
182
+ def __init__(self, **kwargs):
183
+ for key, value in kwargs.items():
184
+ try:
185
+ json.dumps(value)
186
+ setattr(self, key, value)
187
+ except TypeError:
188
+ # value was not JSON-serializable, skip
189
+ continue
190
+ super().__init__()
191
+
192
+
193
+ def independent_crop(
194
+ input_ids: torch.Tensor, pad_token_id: int,
195
+ l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
196
+ """Returns two independent crops from input_ids.
197
+
198
+ Assumes input_ids has a beginning and end token, like
199
+ [101, ..., 102, 0, 0, 0].
200
+
201
+ Args:
202
+ input_ids: tensor of IDs
203
+ pad_token_id: ID of pad tokens in input_ids
204
+ l1: length of span 1, cropped
205
+ l2: length of span 2, cropped
206
+ Returns:
207
+ span1: first crop (of length l1)
208
+ span2: second crop (of length l2)
209
+ """
210
+ # Count tokens until pad.
211
+ if (input_ids == pad_token_id).sum() == 0:
212
+ N = len(input_ids)
213
+ else:
214
+ N = (input_ids == pad_token_id).int().argmax().item()
215
+
216
+ ####
217
+ ###
218
+ ##
219
+ ## Contriever: We use the random cropping data
220
+ ## augmentation, with documents of 256 tokens and span
221
+ ## sizes sampled between 5% and 50% of the document
222
+ ## length
223
+ ##
224
+ ###
225
+ #####
226
+ ####### LaPraDor: The maximum lengths set for queries and
227
+ ####### documents are 64 and 350...
228
+ #####
229
+ # TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
230
+ nl1 = min(N//2, l1)
231
+ nl2 = min(N//2, l2)
232
+
233
+ s1_start = random.randint(1, N-nl1)
234
+ s2_start = random.randint(1, N-nl2)
235
+
236
+ s1_idxs = itertools.chain(
237
+ [0], range(s1_start, s1_start+nl1), [N-1]
238
+ )
239
+ s1 = input_ids[torch.tensor(list(s1_idxs))]
240
+ s2_idxs = itertools.chain(
241
+ [0], range(s2_start, s2_start+nl2), [N-1]
242
+ )
243
+ s2 = input_ids[torch.tensor(list(s2_idxs))]
244
+ return (s1, s2)
245
+
246
+
247
+ def load_dataset_tables(
248
+ files: Iterable[str], num_workers: int = 16
249
+ ) -> Iterable[datasets.table.MemoryMappedTable]:
250
+ import concurrent
251
+ from multiprocessing import Pool
252
+
253
+ # num_workers = min(num_workers, len(files))
254
+ num_workers = min(32, len(files))
255
+
256
+ use_threads = True
257
+ if use_threads:
258
+ pool_cls = concurrent.futures.ThreadPoolExecutor
259
+ pool_kwargs = {"max_workers": num_workers}
260
+ else:
261
+ pool_cls = Pool
262
+ pool_kwargs = {"processes": num_workers}
263
+
264
+ with pool_cls(**pool_kwargs) as pool:
265
+ if len(files) > 10:
266
+ files = tqdm_if_main_worker(
267
+ files,
268
+ desc=f"Loading {len(files)} files with {num_workers} workers",
269
+ total=len(files),
270
+ colour="#ffbd88"
271
+ )
272
+
273
+ result = list(
274
+ pool.map(datasets.table.MemoryMappedTable.from_file, files)
275
+ )
276
+ return result
277
+
278
+
279
+ def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
280
+ logging.info(f"fast_load_from_disk called with path:", cache_path)
281
+ dataset_info_path = os.path.join(cache_path, "dataset_info.json")
282
+ with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
283
+ dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))
284
+
285
+ dataset_state_path = os.path.join(cache_path, "state.json")
286
+ with open(dataset_state_path, encoding="utf-8") as state_file:
287
+ state = json.load(state_file)
288
+
289
+ files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
290
+ files = sorted(files)
291
+ num_workers = get_num_proc()
292
+ ds_tables = load_dataset_tables(
293
+ files=files,
294
+ num_workers=num_workers
295
+ )
296
+ arrow_table = datasets.table.concat_tables(ds_tables)
297
+
298
+ split = state["_split"]
299
+ split = datasets.splits.Split(split) if split is not None else split
300
+
301
+ # print("returning dataset")
302
+ return datasets.Dataset(
303
+ arrow_table=arrow_table,
304
+ info=dataset_info,
305
+ split=split,
306
+ fingerprint=state["_fingerprint"],
307
+ )
308
+
309
+
310
+ def tokenize_dataset(
311
+ dataset: datasets.Dataset,
312
+ tokenizer: transformers.PreTrainedTokenizer,
313
+ max_length: int,
314
+ text_key: str,
315
+ padding_strategy: str
316
+ ) -> datasets.Dataset:
317
+ def tokenize_text(ex: Dict) -> Dict:
318
+ tt = tokenizer(
319
+ ex[text_key],
320
+ max_length=max_length,
321
+ truncation=True,
322
+ padding=padding_strategy,
323
+ )
324
+ for k,v in tt.items():
325
+ ex[f"{text_key}_{k}"] = v
326
+ ex["length"] = [len(tt) for tt in ex[f"{text_key}_input_ids"]]
327
+ return ex
328
+
329
+ # generate unique hash for tokenizer
330
+ vocab = tokenizer.vocab
331
+ vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
332
+ vocab_hash = md5_hash(vocab_words)
333
+
334
+ data_fingerprint = '__'.join((
335
+ dataset._fingerprint, str(vocab_hash), str(max_length),
336
+ text_key, padding_strategy
337
+ ))
338
+ data_fingerprint = md5_hash(data_fingerprint)
339
+ dataset = dataset.map(
340
+ tokenize_text,
341
+ new_fingerprint=data_fingerprint,
342
+ batched=True,
343
+ load_from_cache_file=True,
344
+ )
345
+ return dataset
346
+
347
+
348
+ class TensorRunningAverages:
349
+ _store_sum: Dict[str, torch.Tensor]
350
+ _store_total: Dict[str, torch.Tensor]
351
+
352
+ def __init__(self):
353
+ self._store_sum = {}
354
+ self._store_total = {}
355
+
356
+ def __iter__(self) -> Iterable[str]:
357
+ return iter(self._store_sum.keys())
358
+
359
+ def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
360
+ if key not in self._store_sum:
361
+ self.clear(key)
362
+ if isinstance(val, torch.Tensor):
363
+ val = val.item() # tensor -> num
364
+ self._store_sum[key] += val
365
+ self._store_total[key] += 1
366
+
367
+ def get(self, key: str) -> float:
368
+ total = max(self._store_total.get(key).item(), 1.0)
369
+ return (self._store_sum[key] / float(total)).item() or 0.0
370
+
371
+ def clear(self, key: str) -> None:
372
+ self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
373
+ self._store_total[key] = torch.tensor(0, dtype=torch.int32)
374
+
375
+ def clear_all(self) -> None:
376
+ for key in self._store_sum:
377
+ self.clear(key)
378
+
379
+ def get_and_clear_all(self) -> Dict[str, float]:
380
+ metrics = {}
381
+ for key in self:
382
+ metrics[key] = self.get(key)
383
+ self.clear(key)
384
+ return metrics
385
+
386
+ def load_embedder_and_tokenizer(name: str) -> Tuple[
387
+ transformers.PreTrainedModel,
388
+ transformers.PreTrainedTokenizer
389
+ ]:
390
+ if name.startswith("nomic") or (name == "bert-base-uncased"):
391
+ from cde.lib.nomic_bert import NomicBertModel
392
+ if name.endswith("--from-scratch"):
393
+ name = name.replace("--from-scratch", "")
394
+ config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
395
+ model = NomicBertModel._from_config(config)
396
+ else:
397
+ model = NomicBertModel.from_pretrained(
398
+ name, add_pooling_layer=False
399
+ )
400
+ tokenizer = transformers.AutoTokenizer.from_pretrained(name)
401
+ elif name in ["gtr-base", "gtr_base"]:
402
+ model = transformers.AutoModel.from_pretrained(
403
+ "sentence-transformers/gtr-t5-base"
404
+ ).encoder
405
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
406
+ "sentence-transformers/gtr-t5-base"
407
+ )
408
+ elif name == "pile-t5-base-encoder":
409
+ model = transformers.AutoModel.from_pretrained(
410
+ "EleutherAI/pile-t5-base"
411
+ ).encoder
412
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
413
+ "EleutherAI/pile-t5-base"
414
+ )
415
+ tokenizer.pad_token = tokenizer.eos_token
416
+ elif name == "pile-t5-base-decoder":
417
+ model = transformers.AutoModel.from_pretrained(
418
+ "EleutherAI/pile-t5-base"
419
+ ).decoder
420
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
421
+ "EleutherAI/pile-t5-base"
422
+ )
423
+ tokenizer.pad_token = tokenizer.eos_token
424
+ elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
425
+ model = transformers.AutoModelForCausalLM.from_pretrained(
426
+ name,
427
+ # torch_dtype=torch.bfloat16,
428
+ attn_implementation="flash_attention_2",
429
+ low_cpu_mem_usage=True,
430
+ # device_map="auto",
431
+ )
432
+ model.padding_side = "right"
433
+ tokenizer = transformers.AutoTokenizer.from_pretrained(name)
434
+ tokenizer.pad_token = tokenizer.eos_token
435
+ tokenizer.add_eos_token = True
436
+ else:
437
+ model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
438
+ tokenizer = transformers.AutoTokenizer.from_pretrained(name)
439
+
440
+ # if use_bettertransformer:
441
+ # from optimum.bettertransformer import BetterTransformer
442
+ # model = BetterTransformer.transform(model)
443
+ return model, tokenizer
444
+
445
+
446
+ def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
447
+ key += "_"
448
+ return {k.replace(key, ""): v for k,v in inputs.items() if k.startswith(key)}
449
+
450
+
451
+ def load_model_state_dict_from_path(folder: str) -> Dict:
452
+ checkpoint_folder = transformers.trainer_utils.get_last_checkpoint(folder)
453
+ if checkpoint_folder is None:
454
+ raise FileNotFoundError(f"no checkpoint found in {folder}")
455
+ WEIGHTS_NAME = "model.safetensors"
456
+ weights_path = os.path.join(checkpoint_folder, WEIGHTS_NAME)
457
+ if not os.path.exists(weights_path):
458
+ raise FileNotFoundError(f"no model weights found at {weights_path}")
459
+ return safetensors.torch.load_file(weights_path, device="cpu")
460
+
461
+ def count_cpus() -> int:
462
+ try:
463
+ return len(os.sched_getaffinity(0))
464
+ except AttributeError:
465
+ return multiprocessing.cpu_count()
466
+
467
+
468
+ def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
469
+ all_indices = []
470
+ for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
471
+ rand_perm = torch.randperm(len(batch_tensor), generator=g)
472
+ batch_list = batch_tensor[rand_perm].tolist()
473
+ all_indices.extend(batch_list)
474
+ return all_indices
475
+
476
+
477
+ # def shuffle_batches_multiproc(g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8) -> List[int]:
478
+ # all_indices = []
479
+ # print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
480
+ # pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
481
+ # pool = multiprocessing.Pool(processes=num_processes)
482
+ # chunk_size = len(list_of_tensors) // num_processes
483
+ # chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
484
+ # worker_func = functools.partial(shuffle_batches, g=g)
485
+ # results = pool.map(worker_func, chunks)
486
+ # all_indices = []
487
+ # for result in results:
488
+ # all_indices.extend(result)
489
+ # pbar.update()
490
+ # return all_indices
491
+
492
+
493
+ def exit_if_running_or_finished_wandb(
494
+ project_name: str,
495
+ exp_group: str, exp_name: str
496
+ ) -> None:
497
+ print("Checking if experiment is already running...")
498
+ import wandb
499
+
500
+ api = wandb.Api()
501
+ running_runs = api.runs(
502
+ path="tti-nomic-7",
503
+ filters={
504
+ "display_name": exp_name,
505
+ "state": {"$regex": "Running|Finished"},
506
+ "config.exp_group": exp_group,
507
+ }
508
+ )
509
+ print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")
510
+
511
+ if len(running_runs) > 0:
512
+ print("Exiting because experiment is already running or completed.")
513
+ sys.exit(0)
514
+
model.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Union
2
+
3
+ import copy
4
+ import torch
5
+ import torch.nn as nn
6
+ import transformers
7
+
8
+ from cde.lib.dist import print0
9
+ from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool
10
+
11
+ from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig
12
+
13
+
14
+ def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
15
+ if hasattr(model, 'transformer'):
16
+ if hasattr(model.transformer, 'h'):
17
+ # gpt2
18
+ model.transformer.h = model.transformer.h[:n_layers]
19
+ else:
20
+ model.transformer.layer = model.transformer.layer[:n_layers]
21
+ elif hasattr(model, 'encoder'):
22
+ if hasattr(model.encoder, 'layers'):
23
+ model.encoder.layers = model.encoder.layers[:n_layers]
24
+ else:
25
+ model.encoder.layer = model.encoder.layer[:n_layers]
26
+ else:
27
+ raise RuntimeError(f"unknown how to limit layers of model {type(model)}")
28
+
29
+
30
+ def disable_dropout(model: torch.nn.Module):
31
+ dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
32
+ for m in dropout_modules:
33
+ m.p = 0.0
34
+ print0(
35
+ f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
36
+ )
37
+
38
+
39
+ def disable_causality(model: torch.nn.Module):
40
+ disabled_modules = 0
41
+ for m in model.modules():
42
+ if hasattr(m, "is_causal"):
43
+ m.is_causal = False
44
+ disabled_modules += 1
45
+ print0(
46
+ f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
47
+ )
48
+
49
+ class ContextualModelMixin(nn.Module):
50
+ @property
51
+ def num_corpus_tokens(self) -> int:
52
+ return self.transductive_corpus_size * self.transductive_tokens_per_document
53
+
54
+ def contextual_init(self):
55
+ self.n_soft_prompt = 8
56
+ self.prompt_projection = torch.nn.Sequential(
57
+ torch.nn.Linear(self.hidden_size, self.hidden_size),
58
+ torch.nn.ReLU(),
59
+ torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
60
+ )
61
+ self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
62
+ self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
63
+ self.randomize_dataset_sequence_order = True
64
+ self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
65
+ if self.sequence_dropout_prob > 0.0:
66
+ self.sequence_dropout_null_embedding = torch.nn.Parameter(
67
+ torch.randn(self.hidden_size) * 0.01,
68
+ requires_grad = True
69
+ )
70
+ self.output_projection = torch.nn.Sequential(
71
+ torch.nn.Linear(self.hidden_size, self.hidden_size),
72
+ torch.nn.ReLU(),
73
+ torch.nn.Linear(self.hidden_size, self.hidden_size)
74
+ )
75
+
76
+ def _prepare_dataset_embeddings(
77
+ self,
78
+ input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
79
+ null_dataset_embedding: bool = False,
80
+ ) -> torch.Tensor:
81
+ if not isinstance(dataset_embeddings, torch.Tensor):
82
+ dataset_embeddings = torch.tensor(dataset_embeddings)
83
+
84
+ if len(dataset_embeddings.shape) == 2:
85
+ # Auto-expand for a batch.
86
+ dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
87
+ dataset_embeddings = dataset_embeddings.to(input_ids.device)
88
+
89
+ batch_size = input_ids.shape[0]
90
+ if (self.transductive_tokens_per_document > 1):
91
+ if self.training:
92
+ # Choose N random documents to fill our context window with.
93
+ # This logic is a little confusing but allows us to sample a
94
+ # different batch *per-document*
95
+ assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
96
+ R = torch.randint(
97
+ low=0,
98
+ high=len(dataset_embeddings),
99
+ size=(batch_size, self.config.transductive_corpus_size),
100
+ device=dataset_embeddings.device
101
+ )
102
+ # TODO make this deterministic somehow for evaluation?
103
+ dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
104
+ else:
105
+ dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
106
+ # print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)
107
+
108
+ if dataset_embeddings.shape[1] > self.num_corpus_tokens:
109
+ # If too many dataset embeddings are passed in, just take the first N until
110
+ # we have the proper number.
111
+ dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
112
+
113
+ _, corpus_size, _hidden_size = dataset_embeddings.shape
114
+ if _ == 1:
115
+ # Auto-expand for a batch.
116
+ dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))
117
+
118
+ if self.training and self.sequence_dropout_prob > 0.0:
119
+ sequence_dropout_mask = (
120
+ torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
121
+ )
122
+ null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
123
+ dataset_embeddings = torch.where(
124
+ sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
125
+ )
126
+ elif null_dataset_embedding:
127
+ null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
128
+ dataset_embeddings = null_embeddings
129
+
130
+ # print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}")
131
+
132
+ # backbone_max_seq_length = self.backbone.config.max_trained_positions
133
+ # assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
134
+ soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
135
+ soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
136
+ soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1)) # -> (b, 4+b, d) # soft_prompt.repeat((len(input_ids), 1, 1))
137
+ soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)
138
+
139
+ # print(f"[ContextualModelMixin] soft_prompt.shape = {soft_prompt.shape}")
140
+
141
+ if self.training and self.randomize_dataset_sequence_order:
142
+ randomized_order = torch.stack(
143
+ [
144
+ torch.cat(
145
+ (
146
+ torch.randperm(corpus_size, device=soft_prompt.device),
147
+ torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
148
+ ), dim=0)
149
+ for _ in range(batch_size)])
150
+ randomized_order = randomized_order.to(soft_prompt.device)
151
+ soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))
152
+
153
+ return soft_prompt
154
+
155
+ class BiEncoder(transformers.PreTrainedModel):
156
+ embedder: transformers.PreTrainedModel
157
+ def __init__(
158
+ self,
159
+ config, #: transformers.PreTrainedConfig,
160
+ ):
161
+ super().__init__(config=config)
162
+ embedder, _ = load_embedder_and_tokenizer(
163
+ config.embedder,
164
+ )
165
+
166
+ if config.limit_layers:
167
+ print0(f"Limiting layers to {config.limit_layers}")
168
+ limit_layers(embedder, config.limit_layers)
169
+
170
+ self.embedder = embedder
171
+ # if ("t5" in embedder.config.model_type):
172
+ # print0(f"using torch.compile() on embedder of type `{embedder.config.model_type}`")
173
+ # self.embedder = torch.compile(self.embedder)
174
+ self.hidden_size = self.embedder.config.hidden_size
175
+ # Allow pooling to multiple tokens per document
176
+ self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
177
+ self.mlp = torch.nn.Sequential(
178
+ torch.nn.Linear(self.hidden_size, self.hidden_size),
179
+ torch.nn.GELU(),
180
+ torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
181
+ )
182
+ self.temp = config.logit_scale
183
+
184
+ if config.disable_dropout:
185
+ disable_dropout(self)
186
+ self.pooling_strategy = vars(config).get("pooling_strategy", "mean")
187
+
188
+ def forward(
189
+ self,
190
+ input_ids: torch.Tensor,
191
+ attention_mask: torch.Tensor,
192
+ dataset_input_ids: Optional[torch.Tensor] = None,
193
+ dataset_attention_mask: Optional[torch.Tensor] = None,
194
+ token_type_ids = None,
195
+ output_hidden_states: bool = False,
196
+ ) -> torch.Tensor:
197
+ """
198
+ query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim)
199
+ document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim)
200
+ where the corpus_size >= batch_size and is structured like this:
201
+ [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
202
+ for a corpus with three documents and two hard negatives per document
203
+ """
204
+ # del dataset_input_ids
205
+ # del dataset_attention_mask
206
+ del token_type_ids
207
+
208
+ # from cde.lib.dist import get_rank
209
+ # tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
210
+ # if get_rank() == 0:
211
+ # breakpoint()
212
+ # torch.distributed.barrier()
213
+ outputs = (
214
+ self.embedder(
215
+ input_ids=input_ids,
216
+ attention_mask=attention_mask,
217
+ ).last_hidden_state
218
+ )
219
+
220
+ if self.transductive_tokens_per_document > 1:
221
+ document_embeddings = None
222
+ batch_size, seq_length, output_dim = outputs.shape
223
+
224
+ if seq_length % self.transductive_tokens_per_document != 0:
225
+ # Pad to nearest multiple
226
+ n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
227
+ outputs = torch.cat(
228
+ (outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device)),
229
+ dim=1
230
+ )
231
+ attention_mask = torch.cat(
232
+ (attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device)),
233
+ dim=1
234
+ )
235
+ seq_length += n_extra_embeds
236
+ print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")
237
+
238
+ # print("ftransductive_tokens_per_document {self.transductive_tokens_per_document} outputs.shape =", outputs.shape)
239
+
240
+ outputs = outputs.reshape(
241
+ (batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
242
+ )
243
+
244
+ attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
245
+ document_embeddings = mean_pool_3d(outputs, attention_mask)
246
+
247
+ document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
248
+ else:
249
+ if self.pooling_strategy == "mean":
250
+ document_embeddings = mean_pool(outputs, attention_mask)
251
+ else:
252
+ document_embeddings = document_embeddings.max(dim=1)
253
+ output = self.mlp(document_embeddings)
254
+
255
+ if output_hidden_states:
256
+ return {
257
+ "hidden_states": outputs,
258
+ "pooled": output,
259
+ }
260
+ else:
261
+ return output
262
+
263
+
264
+ class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
265
+ def __init__(
266
+ self,
267
+ config,
268
+ dataset_backbone: transformers.PreTrainedModel,
269
+ first_stage_hidden_size: int,
270
+ ):
271
+ super().__init__(config=config)
272
+ self.backbone = dataset_backbone
273
+ self.backbone_hidden_size = self.backbone.config.hidden_size
274
+ self.hidden_size = first_stage_hidden_size # Input token size
275
+ self.contextual_init()
276
+ disable_causality(self.backbone)
277
+
278
+ self.input_ln = torch.nn.LayerNorm(
279
+ self.backbone_hidden_size,
280
+ eps=1e-5
281
+ )
282
+
283
+ # Override contextual init
284
+ self.output_projection = torch.nn.Sequential(
285
+ torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
286
+ torch.nn.ReLU(),
287
+ torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
288
+ )
289
+ self._shift_rotary_embedding()
290
+
291
+ @property
292
+ def num_corpus_tokens(self) -> int:
293
+ return self.config.transductive_corpus_size * self.transductive_tokens_per_document
294
+
295
+ @property
296
+ def corpus_token_ratio(self) -> float:
297
+ # How many tokens from the first stage make one token in the second
298
+ # stage?
299
+ return self.backbone_hidden_size / self.hidden_size
300
+
301
+ def corpus_token_pad_size(self, n_tokens: int) -> int:
302
+ return self.hidden_size % self.backbone_hidden_size
303
+
304
+ def _shift_rotary_embedding(self) -> None:
305
+ disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
306
+ # TODO: Can we do this for LLAMA?
307
+ print("Warning: Positional embedding disabling not implemented for LLAMA.")
308
+
309
+ def forward(
310
+ self,
311
+ input_ids: torch.Tensor,
312
+ attention_mask: torch.Tensor,
313
+ dataset_embeddings: torch.Tensor,
314
+ output_hidden_states: bool = False,
315
+ null_dataset_embedding: bool = False,
316
+ ) -> torch.Tensor:
317
+ soft_prompt = self._prepare_dataset_embeddings(
318
+ input_ids=input_ids,
319
+ dataset_embeddings=dataset_embeddings,
320
+ null_dataset_embedding=null_dataset_embedding,
321
+ )
322
+
323
+ # Reshape for this model.
324
+ # print("[DatasetConditionedAutoregressive] 1 -> soft_prompt.shape =", soft_prompt.shape)
325
+ num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
326
+ soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
327
+ num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
328
+ padding = torch.ones((soft_prompt.shape[0], num_padding_elements), device=soft_prompt.device)
329
+ soft_prompt = torch.cat((soft_prompt, padding), dim=1)
330
+ soft_prompt = soft_prompt.reshape(
331
+ (soft_prompt.shape[0], -1, self.backbone_hidden_size)
332
+ )
333
+ soft_prompt = self.input_ln(soft_prompt)
334
+ # print("[DatasetConditionedAutoregressive] 2 -> soft_prompt.shape =", soft_prompt.shape)
335
+
336
+ backbone_attention_mask = torch.ones(
337
+ soft_prompt.shape[0:2],
338
+ dtype=torch.long,
339
+ device=soft_prompt.device,
340
+ )
341
+ token_embeddings = self.backbone.get_input_embeddings()
342
+ inputs_embeds = token_embeddings(input_ids) # (b, s) -> (b, s, d)
343
+ # print("[2] inputs_embeds.shape =", inputs_embeds.shape)
344
+ inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
345
+ # print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
346
+ input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
347
+ # print("[3.b] attention_mask.shape =", attention_mask.shape)
348
+
349
+ output = self.backbone(
350
+ inputs_embeds=inputs_embeds,
351
+ attention_mask=input_attention_mask,
352
+ output_hidden_states=True,
353
+ ) # (1, 4 + b + s, d)
354
+ # trim soft prompt
355
+ last_hidden_state = output.hidden_states[-1]
356
+ n_soft_prompt_tokens = soft_prompt.shape[1]
357
+
358
+ output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
359
+ output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]
360
+
361
+ # Take last token position
362
+ if vars(self.config).get("pooling_strategy") == "last_token":
363
+ output_pooled = last_token_pool(output_vectors, output_attention_mask)
364
+ elif vars(self.config).get("pooling_strategy") == "mean":
365
+ output_pooled = mean_pool(output_vectors, output_attention_mask)
366
+ else:
367
+ output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)
368
+
369
+ # average with original vectors
370
+ # TODO: Argparse for pooling strategy.
371
+ output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
372
+
373
+ if output_hidden_states:
374
+ return {
375
+ "hidden_states": output_vectors,
376
+ "pooled": output,
377
+ }
378
+ else:
379
+ return output
380
+
381
+
382
+ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
383
+ def __init__(
384
+ self,
385
+ config,
386
+ dataset_backbone: transformers.PreTrainedModel,
387
+ ):
388
+ super().__init__(config=config)
389
+ self.backbone = dataset_backbone
390
+ self.hidden_size = self.backbone.config.hidden_size
391
+ self.hidden_size = dataset_backbone.config.hidden_size
392
+ # self.input_ln = torch.nn.LayerNorm(
393
+ # self.hidden_size,
394
+ # eps=self.backbone.config.layer_norm_epsilon
395
+ # )
396
+ self.contextual_init()
397
+ self._shift_rotary_embedding()
398
+
399
+ @property
400
+ def num_corpus_tokens(self) -> int:
401
+ return self.config.transductive_corpus_size * self.transductive_tokens_per_document
402
+
403
+ def _shift_rotary_embedding(self) -> None:
404
+ disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
405
+ if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
406
+ # We only want to apply positional embeddings to the
407
+ # *text* portion of the backbone network.
408
+ self.backbone.config.rotary_start_pos = 0.0
409
+ rotary_disabled = 0
410
+
411
+ rotary_start_pos = self.num_corpus_tokens
412
+ for module in self.backbone.modules():
413
+ if hasattr(module, "rotary_emb_dim"):
414
+ module.rotary_start_pos = rotary_start_pos
415
+ rotary_disabled += 1
416
+ print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")
417
+
418
+ def forward(
419
+ self,
420
+ input_ids: torch.Tensor,
421
+ attention_mask: torch.Tensor,
422
+ dataset_embeddings: torch.Tensor,
423
+ output_hidden_states: bool = False,
424
+ null_dataset_embedding: bool = False,
425
+ ) -> torch.Tensor:
426
+ # print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape)
427
+ soft_prompt = self._prepare_dataset_embeddings(
428
+ input_ids=input_ids,
429
+ dataset_embeddings=dataset_embeddings,
430
+ null_dataset_embedding=null_dataset_embedding,
431
+ )
432
+ # print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}")
433
+ backbone_attention_mask = torch.ones(
434
+ soft_prompt.shape[0:2],
435
+ dtype=torch.long,
436
+ device=soft_prompt.device,
437
+ )
438
+ inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d)
439
+ # print("[2] inputs_embeds.shape =", inputs_embeds.shape)
440
+ inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
441
+ # print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
442
+ attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
443
+ # print("[3.b] attention_mask.shape =", attention_mask.shape)
444
+ output = self.backbone(
445
+ inputs_embeds=inputs_embeds,
446
+ attention_mask=attention_mask,
447
+ ) # (1, 4 + b + s, d)
448
+ # trim soft prompt
449
+ output_vectors = output.last_hidden_state
450
+
451
+ # use only these tokens
452
+ n_soft_prompt_tokens = soft_prompt.shape[1]
453
+ # print("n_soft_prompt_tokens =", n_soft_prompt_tokens)
454
+
455
+ output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
456
+ output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
457
+
458
+ # print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape)
459
+ output_pooled = mean_pool(output_vectors, output_attention_mask)
460
+
461
+ # average with original vectors
462
+ # TODO: Argparse for pooling strategy.
463
+ # output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d)
464
+ # print("output_pooled.shape =", output_pooled.shape)
465
+ output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
466
+
467
+ # print("returning output.shape =", output.shape)
468
+
469
+ if output_hidden_states:
470
+ return {
471
+ "hidden_states": output_vectors,
472
+ "pooled": output,
473
+ }
474
+ else:
475
+ return output
476
+
477
+
478
+ class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
479
+ def __init__(
480
+ self,
481
+ config, #: transformers.PreTrainedConfig,
482
+ embedder: transformers.PreTrainedModel,
483
+ ):
484
+ super().__init__(config=config)
485
+ self.embedder = embedder
486
+ self.hidden_size = self.embedder.config.hidden_size
487
+ self.contextual_init()
488
+
489
+ def forward(
490
+ self,
491
+ input_ids: torch.Tensor,
492
+ attention_mask: torch.Tensor,
493
+ dataset_input_ids: torch.Tensor,
494
+ dataset_attention_mask: torch.Tensor,
495
+ output_hidden_states: bool = False,
496
+ ) -> torch.Tensor:
497
+ R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)
498
+
499
+ dataset_input_ids = dataset_input_ids[R]
500
+ input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)
501
+
502
+ dataset_attention_mask = torch.ones_like(dataset_attention_mask, device=dataset_attention_mask.device)
503
+ input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
504
+ output_attention_mask = torch.cat(
505
+ (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
506
+ )
507
+
508
+ output = self.embedder(
509
+ input_ids=input_ids,
510
+ attention_mask=input_attention_mask,
511
+ )
512
+
513
+ output_vectors = output.last_hidden_state
514
+ output_pooled = mean_pool(output_vectors, output_attention_mask)
515
+ output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
516
+
517
+ if output_hidden_states:
518
+ S_d = dataset_attention_mask.shape[1]
519
+ output_vectors = output_vectors[:, S_d:, :]
520
+ return {
521
+ "hidden_states": output_vectors,
522
+ "pooled": output,
523
+ }
524
+ else:
525
+ return output
526
+
527
+
528
+ class DatasetTransformer(transformers.PreTrainedModel):
529
+ config_class = ContextualModelConfig
530
+ embedder: transformers.PreTrainedModel
531
+ dataset_backbone: transformers.PreTrainedModel
532
+ def __init__(
533
+ self,
534
+ config,
535
+ ):
536
+ super().__init__(config=config)
537
+ dataset_backbone, _ = load_embedder_and_tokenizer(
538
+ vars(config).get("dataset_backbone", config.embedder)
539
+ )
540
+
541
+ if config.limit_layers:
542
+ print0(f"Limiting layers to {config.limit_layers}")
543
+ limit_layers(dataset_backbone, config.limit_layers)
544
+
545
+ biencoder_config = copy.deepcopy(config)
546
+ biencoder_config.embedding_output_dim = None
547
+ biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
548
+ self.first_stage_model = BiEncoder(
549
+ config=biencoder_config,
550
+ )
551
+
552
+ if vars(config).get("autoregressive_backbone", False):
553
+ self.second_stage_model = DatasetConditionedAutoregressive(
554
+ config=config,
555
+ dataset_backbone=dataset_backbone,
556
+ first_stage_hidden_size=self.first_stage_model.hidden_size,
557
+ )
558
+ else:
559
+ self.second_stage_model = DatasetConditionedBiencoder(
560
+ config=config,
561
+ dataset_backbone=dataset_backbone
562
+ )
563
+
564
+ self.temp = config.logit_scale
565
+ if config.disable_dropout:
566
+ disable_dropout(self)
567
+
568
+ transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
569
+ if transductive_tie_token_embeddings:
570
+ self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
571
+ self.first_stage_model.embedder.embeddings.word_embeddings.weight
572
+ )
573
+
574
+ def forward(
575
+ self,
576
+ input_ids: torch.Tensor,
577
+ attention_mask: torch.Tensor,
578
+ dataset_input_ids: Optional[torch.Tensor],
579
+ dataset_attention_mask: Optional[torch.Tensor],
580
+ output_hidden_states: bool = False,
581
+ ) -> torch.Tensor:
582
+ """
583
+ input_ids (long torch.Tensor) – ids of input tokens
584
+ attention_mask (bool torch.Tensor)
585
+ """
586
+ dataset_embeddings = self.first_stage_model(
587
+ input_ids=dataset_input_ids,
588
+ attention_mask=dataset_attention_mask
589
+ )
590
+ return self.second_stage_model(
591
+ input_ids=input_ids,
592
+ attention_mask=attention_mask,
593
+ dataset_embeddings=dataset_embeddings,
594
+ output_hidden_states=output_hidden_states,
595
+ )
596
+
597
+
598
+
599
+ def get_model_class(name: str):
600
+ if name in 'transductive':
601
+ return DatasetTransformer
602
+ elif name == 'biencoder':
603
+ return BiEncoder
604
+ elif name == "dataset_prefix_biencoder":
605
+ return DatasetPrefixBiencoder
606
+ else:
607
+ raise ValueError(f'unknown model cls {name}')
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8232363a55e0327e0b9cc85762662f474f314a2975127f70c9ba0777857d19d7
3
- size 572926008
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ec79407ada665817aebe929bdabbe83eecd816b75f7f26e3bdd8b4c092efb2a
3
+ size 1124594680