Spaces:
Running
Running
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
class SemanticResponseGenerator:
    """Generate a short text response from retrieved documents.

    A tokenizer and a seq2seq model are loaded once at construction time
    with ``AutoTokenizer.from_pretrained`` and
    ``AutoModelForSeq2SeqLM.from_pretrained`` and reused for every call to
    :meth:`generate_response`.
    """

    # Number of leading retrieved documents that go into the prompt.
    _MAX_DOCS = 2
    # Amount subtracted from ``max_input_length`` to get the character budget
    # for the document text (presumably headroom for the instruction prefix;
    # TODO confirm the intent with the author).
    _PROMPT_RESERVE = 50

    def __init__(self, model_name="google/flan-t5-small", max_input_length=512, max_new_tokens=50):
        """Load the tokenizer and model and store the generation limits.

        Args:
            model_name: Identifier passed to both ``from_pretrained`` calls.
            max_input_length: Passed to the tokenizer as ``max_length``; also
                used (minus a fixed reserve) as a character budget for the
                document text.
            max_new_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.max_input_length = max_input_length
        self.max_new_tokens = max_new_tokens

    def generate_response(self, retrieved_docs):
        """Build a prompt from the first documents and return the decoded output.

        Args:
            retrieved_docs: Sequence of strings. Only the first ``_MAX_DOCS``
                entries are used; they are joined with single spaces.

        Returns:
            The first generated sequence, decoded with special tokens skipped.
        """
        context = " ".join(retrieved_docs[:self._MAX_DOCS])

        # Coarse pre-trim measured in characters (the tokenizer call below
        # applies its own truncation via ``max_length``). Clamp at zero: a
        # negative bound would make the slice drop characters from the END of
        # the text instead of limiting its length.
        char_budget = max(0, self.max_input_length - self._PROMPT_RESERVE)
        prompt = f"Based on the following information: {context[:char_budget]}"

        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
        )
        generated = self.model.generate(
            **encoded,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)