from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class SemanticResponseGenerator:
    """Generate a natural-language answer grounded in retrieved documents
    using a seq2seq language model (default: google/flan-t5-small).
    """

    def __init__(self, model_name="google/flan-t5-small", max_input_length=512, max_new_tokens=50):
        """
        Args:
            model_name: Hugging Face model id of a seq2seq LM.
            max_input_length: maximum number of input *tokens* fed to the model.
            max_new_tokens: maximum number of tokens to generate.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.max_input_length = max_input_length
        self.max_new_tokens = max_new_tokens

    def generate_response(self, retrieved_docs, num_docs=2):
        """Generate a response conditioned on the top retrieved documents.

        Args:
            retrieved_docs: list of document strings, ranked by relevance.
            num_docs: number of top documents to include in the prompt
                (default 2, matching previous behavior).

        Returns:
            The decoded model output as a string; an empty string when no
            documents are supplied.
        """
        if not retrieved_docs:
            # No grounding context — prompting the model with an empty
            # context just invites hallucination, so short-circuit.
            return ""
        combined_docs = " ".join(retrieved_docs[:num_docs])
        input_text = f"Based on the following information: {combined_docs}"
        # Truncate at the *token* level via the tokenizer. The previous
        # character-based pre-truncation (max_input_length - 50 chars)
        # conflated characters with tokens and could cut far more text
        # than necessary; token-level truncation here is sufficient.
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
        )
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            # Use the tokenizer's real pad token (T5 defines one); fall
            # back to EOS only for models with no pad token.
            pad_token_id=(
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            ),
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)