use proper model output dim
Browse files
- finetune.py +8 -8
- onnx_convert.py +2 -2
- pytorch_model.bin +1 -1
- pytorch_model.onnx +2 -2
- test-small.json.gz +0 -3
- train-small.json.gz +0 -3
finetune.py
CHANGED
@@ -12,8 +12,8 @@ import gzip
|
|
12 |
|
13 |
model_name = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
|
14 |
|
15 |
-
train_batch_size =
|
16 |
-
max_seq_length =
|
17 |
num_epochs = 1
|
18 |
warmup_steps = 1000
|
19 |
model_save_path = '.'
|
@@ -27,13 +27,13 @@ class ESCIDataset(Dataset):
|
|
27 |
for line in jsonfile.readlines():
|
28 |
query = json.loads(line)
|
29 |
for doc in query['e']:
|
30 |
-
self.queries.append(InputExample(texts=[query['query'], doc['title']], label=1.0))
|
31 |
for doc in query['s']:
|
32 |
-
self.queries.append(InputExample(texts=[query['query'], doc['title']], label=0.1))
|
33 |
for doc in query['c']:
|
34 |
-
self.queries.append(InputExample(texts=[query['query'], doc['title']], label=0.01))
|
35 |
for doc in query['i']:
|
36 |
-
self.queries.append(InputExample(texts=[query['query'], doc['title']], label=0.0))
|
37 |
|
38 |
def __getitem__(self, item):
|
39 |
return self.queries[item]
|
@@ -49,9 +49,9 @@ class ESCIEvalDataset(Dataset):
|
|
49 |
query = json.loads(line)
|
50 |
if len(query['e']) > 0 and len(query['i']) > 0:
|
51 |
for p in query['e']:
|
52 |
-
positive = p['title']
|
53 |
for n in query['i']:
|
54 |
-
negative = n['title']
|
55 |
self.queries.append(InputExample(texts=[query['query'], positive, negative]))
|
56 |
|
57 |
def __getitem__(self, item):
|
|
|
12 |
|
13 |
model_name = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
|
14 |
|
15 |
+
train_batch_size = 8
|
16 |
+
max_seq_length = 384
|
17 |
num_epochs = 1
|
18 |
warmup_steps = 1000
|
19 |
model_save_path = '.'
|
|
|
27 |
for line in jsonfile.readlines():
|
28 |
query = json.loads(line)
|
29 |
for doc in query['e']:
|
30 |
+
self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=1.0))
|
31 |
for doc in query['s']:
|
32 |
+
self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=0.1))
|
33 |
for doc in query['c']:
|
34 |
+
self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=0.01))
|
35 |
for doc in query['i']:
|
36 |
+
self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=0.0))
|
37 |
|
38 |
def __getitem__(self, item):
|
39 |
return self.queries[item]
|
|
|
49 |
query = json.loads(line)
|
50 |
if len(query['e']) > 0 and len(query['i']) > 0:
|
51 |
for p in query['e']:
|
52 |
+
positive = p['title'] + ' ' + p['title']
|
53 |
for n in query['i']:
|
54 |
+
negative = n['title'] + ' ' + n['title']
|
55 |
self.queries.append(InputExample(texts=[query['query'], positive, negative]))
|
56 |
|
57 |
def __getitem__(self, item):
|
onnx_convert.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-
from transformers import AutoTokenizer, AutoModel
|
2 |
import torch
|
3 |
|
4 |
max_seq_length=128
|
5 |
|
6 |
-
model =
|
7 |
model.eval()
|
8 |
|
9 |
inputs = {"input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
|
2 |
import torch
|
3 |
|
4 |
max_seq_length=128
|
5 |
|
6 |
+
model = AutoModelForSequenceClassification.from_pretrained(".")
|
7 |
model.eval()
|
8 |
|
9 |
inputs = {"input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 133514357
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8eb5889a76cfd3d6beaaf62bf061723ebf7edd212329fc527ff36c5ed1b571a
|
3 |
size 133514357
|
pytorch_model.onnx
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5a0fe068eded0383c63e7e63e8d5fef4e6d30a5e4d3011b4e7d1602844fcd251
|
3 |
+
size 133717601
|
test-small.json.gz
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:fb557251b12addb55d94af30120d121dfa6391e58bcc4a9aee0f1d35cc2ea1c8
|
3 |
-
size 8522018
|
|
|
|
|
|
|
|
train-small.json.gz
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:9c7c14a8910a3a6c09421a08a84cfc0e74fd198d0aaf43ab2c39250a8ae4e4dd
|
3 |
-
size 19430577
|
|
|
|
|
|
|
|