import re


def batch_as_list(a, batch_size=100000):
    """Split an iterable into consecutive batches of at most ``batch_size`` items.

    Args:
        a: Any iterable. It is consumed exactly once, so generators work.
        batch_size: Maximum number of elements per batch; must be >= 1.

    Returns:
        A list of lists. Every batch except possibly the last holds exactly
        ``batch_size`` elements. An empty input yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # A non-positive size cannot hold any element; without this guard the
        # loop below would emit an empty leading batch plus singleton batches.
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    batches = []
    for ele in a:
        # Open a new batch for the first element or once the current one is full.
        if not batches or len(batches[-1]) >= batch_size:
            batches.append([])
        batches[-1].append(ele)
    return batches


class Obj:
    """Pair a generation model with its tokenizer for text-to-text prediction.

    NOTE(review): the ``encode(..., return_tensors="pt")`` / ``generate`` /
    ``decode`` calls look like the Hugging Face ``transformers`` API. The
    concrete model/tokenizer types are not visible here — confirm.
    """

    def __init__(self, model, tokenizer, device="cpu"):
        """Store the model, tokenizer and target device.

        Args:
            model: Object exposing ``generate(**kwargs)``.
            tokenizer: Object exposing ``encode(...)`` and ``decode(...)``.
            device: Device that ``predict`` moves the encoded input to. The
                model itself is not moved here; the caller is responsible for
                placing it on the same device.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def predict(
        self,
        source_text: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """Generate output text(s) for ``source_text``.

        The text is encoded with special tokens added and moved to
        ``self.device``. It is then passed to ``self.model.generate`` with the
        given decoding options. Each returned sequence is decoded with
        ``self.tokenizer.decode``.

        Args:
            source_text: Input text to encode.
            max_length: Forwarded to ``generate``.
            num_return_sequences: Forwarded to ``generate``.
            num_beams: Forwarded to ``generate``.
            top_k: Forwarded to ``generate``.
            top_p: Forwarded to ``generate``.
            do_sample: Forwarded to ``generate``.
            repetition_penalty: Forwarded to ``generate``.
            length_penalty: Forwarded to ``generate``.
            early_stopping: Forwarded to ``generate``.
            skip_special_tokens: Forwarded to ``decode``.
            clean_up_tokenization_spaces: Forwarded to ``decode``.

        Returns:
            A list with one ``decode`` result per sequence yielded by
            ``generate``. These are presumably strings, ``num_return_sequences``
            of them — confirm against the tokenizer/model in use.
        """
        input_ids = self.tokenizer.encode(
            source_text, return_tensors="pt", add_special_tokens=True
        )
        input_ids = input_ids.to(self.device)
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            # Forward the caller's choice. In the HF generate API, top_k/top_p
            # are only consulted when do_sample is True (assumed API — confirm).
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
        )
        return [
            self.tokenizer.decode(
                g,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )
            for g in generated_ids
        ]