|
import os
|
|
import ujson
|
|
|
|
from functools import partial
|
|
from colbert.utils.utils import print_message
|
|
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples
|
|
|
|
from colbert.utils.runs import Run
|
|
|
|
|
|
class EagerBatcher():
    """Iterator that streams training triples from a TSV file in fixed-size batches.

    Every line of the triples file is split on tabs into exactly three fields,
    which are treated as (query, positive, negative). When several workers
    (``nranks`` > 1) read the same file, each call to ``__next__`` consumes up
    to ``bsize * nranks`` lines and keeps only the lines whose running line
    index modulo ``nranks`` equals this worker's ``rank``.
    """

    def __init__(self, args, rank=0, nranks=1):
        """Set up tokenizers and open the triples file.

        Args:
            args: Object exposing ``bsize``, ``accumsteps``, ``query_maxlen``,
                ``doc_maxlen`` and ``triples`` (path to the triples file).
            rank: Index of this worker among the ``nranks`` workers.
            nranks: Total number of workers sharing the triples file.
        """
        self.rank, self.nranks = rank, nranks
        self.bsize, self.accumsteps = args.bsize, args.accumsteps

        self.query_tokenizer = QueryTokenizer(args.query_maxlen)
        self.doc_tokenizer = DocTokenizer(args.doc_maxlen)
        # Pre-bind both tokenizers so collate() only has to pass the raw text.
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)

        self.triples_path = args.triples
        self._reset_triples()

    def _reset_triples(self):
        """(Re)open the triples file at its beginning and reset the line counter."""
        # Close any previously opened handle so repeated resets (for example via
        # skip_to_batch) do not leak file descriptors. On the first call from
        # __init__ the attribute does not exist yet, hence getattr.
        previous_reader = getattr(self, 'reader', None)
        if previous_reader is not None:
            previous_reader.close()

        self.reader = open(self.triples_path, mode='r', encoding="utf-8")
        self.position = 0

    def __iter__(self):
        """Return ``self``; the batcher is its own iterator."""
        return self

    def __next__(self):
        """Read the next batch of triples for this rank.

        Returns:
            Whatever ``collate`` returns for the ``bsize`` triples gathered.

        Raises:
            StopIteration: If fewer than ``bsize`` triples could be gathered
                for this rank (end of file; a trailing partial batch is
                dropped).
            ValueError: If a kept line does not split into exactly three
                tab-separated fields.
        """
        queries, positives, negatives = [], [], []

        # Track how many lines were actually consumed. The loop variable alone
        # is not enough: when the reader is already exhausted the loop body
        # never runs and ``line_idx`` would be unbound.
        lines_read = 0

        # zip() pulls from range() first, so the reader is never advanced past
        # the bsize * nranks lines that belong to this step.
        for line_idx, line in zip(range(self.bsize * self.nranks), self.reader):
            lines_read = line_idx + 1

            # Lines are dealt round-robin to ranks; skip those owned by others.
            if (self.position + line_idx) % self.nranks != self.rank:
                continue

            query, pos, neg = line.strip().split('\t')

            queries.append(query)
            positives.append(pos)
            negatives.append(neg)

        self.position += lines_read

        if len(queries) < self.bsize:
            raise StopIteration

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tensorize one full batch of triples.

        Args:
            queries: ``bsize`` query strings.
            positives: ``bsize`` positive passage strings.
            negatives: ``bsize`` negative passage strings.

        Returns:
            The result of ``tensorize_triples`` called with the two tokenizers,
            the three lists, and ``bsize // accumsteps`` as the last argument
            (presumably the per-accumulation-step sub-batch size -- confirm
            against ``colbert.modeling.tokenization``).
        """
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

    def skip_to_batch(self, batch_idx, intended_batch_size):
        """Restart from the top of the file and skip ahead to a given batch.

        Advances the reader by ``batch_idx * intended_batch_size`` lines.
        ``self.position`` is left at 0 after the skip.

        Args:
            batch_idx: Number of batches to skip.
            intended_batch_size: Number of lines counted per skipped batch.

        Returns:
            None.
        """
        self._reset_triples()

        Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')

        # Advance the reader without keeping the skipped lines in memory.
        for _ in range(batch_idx * intended_batch_size):
            self.reader.readline()

        return None
|
|
|