# NOTE(review): the original paste carried Hugging Face Spaces UI residue here
# ("Spaces:" / "Runtime error" x2). Converted to a comment so the file parses.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class SemanticResponseGenerator:
    """Generate a natural-language response conditioned on retrieved documents
    using a Hugging Face seq2seq model (default: google/flan-t5-small)."""

    def __init__(self, model_name="google/flan-t5-small", max_input_length=512, max_new_tokens=50):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face model id of a seq2seq checkpoint.
            max_input_length: maximum prompt length in *tokens* (enforced by
                the tokenizer's truncation, not by character slicing).
            max_new_tokens: maximum number of tokens to generate.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.max_input_length = max_input_length
        self.max_new_tokens = max_new_tokens

    def generate_response(self, retrieved_docs):
        """Build a prompt from the top retrieved docs and generate a response.

        Args:
            retrieved_docs: list of document strings, best-first; only the
                first two are used.

        Returns:
            The decoded model output as a plain string ("" when no docs
            were retrieved).
        """
        # Robustness: with no retrieval hits there is nothing to condition
        # on, so return early instead of prompting on an empty context.
        if not retrieved_docs:
            return ""

        combined_docs = " ".join(retrieved_docs[:2])
        input_text = f"Based on the following information: {combined_docs}"

        # BUG FIX: the original also sliced `combined_docs[:max_input_length - 50]`,
        # truncating by *characters* even though max_input_length is a *token*
        # budget (462 chars is far under 512 tokens). Token-level truncation
        # below is the correct mechanism, so the character slice is removed.
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
        )

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            # BUG FIX: use the tokenizer's real pad token (T5 defines one)
            # rather than repurposing EOS as padding, as the original did.
            pad_token_id=self.tokenizer.pad_token_id,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)