import torch

from .utils import freeze


class BaseEmbedder:
    # Wraps a (frozen) language model used to embed text prompts.
    def __init__(self, conf):
        self.checkpoint_path = conf.text_embedder.params.checkpoint_path
        self.tokenizer_path = conf.text_embedder.params.tokenizer_path
        self.max_length = conf.text_embedder.tokens_lenght  # key spelled this way in the config
        self.llm = None

    def to(self, device='cpu', dtype=torch.float32):
        # Move the wrapped model; returns self so calls can be chained.
        self.llm = self.llm.to(device=device, dtype=dtype)
        return self

    def freeze(self):
        # Disable gradients on the wrapped model (see .utils.freeze).
        self.llm = freeze(self.llm)
        return self

    def compile(self):
        self.llm = torch.compile(self.llm)
        return self


class EmbedderWithTokenizer(BaseEmbedder):
    def __init__(self, conf):
        super().__init__(conf)
        self.tokenizer = None

    def tokenize(self, text):
        # Pad/truncate every prompt to a fixed length so batches stack cleanly,
        # then move the token ids onto the model's device.
        model_input = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=True,
            padding='max_length',
            return_tensors='pt',
        )
        return model_input.input_ids.to(self.llm.device)

    def __call__(self, text):
        # Index 0 of the encoder output is last_hidden_state: a
        # (batch, max_length, hidden_size) tensor of token embeddings.
        return self.llm(self.tokenize(text), output_hidden_states=True)[0]


class T5TextEmbedder(EmbedderWithTokenizer):
    def __init__(self, conf):
        # Local import keeps transformers an optional dependency until
        # a T5 embedder is actually constructed.
        from transformers import T5EncoderModel, T5Tokenizer

        super().__init__(conf)
        self.llm = T5EncoderModel.from_pretrained(self.checkpoint_path)
        self.tokenizer = T5Tokenizer.from_pretrained(self.tokenizer_path, clean_up_tokenization_spaces=False)


def get_text_embedder(conf):
    # Factory used by the rest of the pipeline; currently always returns T5.
    return T5TextEmbedder(conf)
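

# ---------------------------------------------------------------------------
# Minimal usage sketch, guarded by __main__ so importing the module stays
# side-effect free. Everything below is illustrative: the checkpoint name
# 'google/flan-t5-base' and the 77-token budget are assumptions, not values
# taken from this repo's configs, and SimpleNamespace merely stands in for
# whatever attribute-access config object (e.g. OmegaConf) the repo uses.
# Because of the relative import above, run this with
# `python -m <package>.<this_module>` rather than as a bare script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    conf = SimpleNamespace(
        text_embedder=SimpleNamespace(
            params=SimpleNamespace(
                checkpoint_path='google/flan-t5-base',  # hypothetical checkpoint
                tokenizer_path='google/flan-t5-base',   # hypothetical tokenizer
            ),
            tokens_lenght=77,  # matches the attribute name read in BaseEmbedder
        )
    )

    embedder = get_text_embedder(conf).freeze()
    embeddings = embedder(['a photo of an astronaut riding a horse'])
    print(embeddings.shape)  # (1, 77, hidden_size) under the assumed config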