# qiyas-falcon-7b / handler.py
import time
from typing import Any, Dict

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class EndpointHandler:
    def __init__(self, path: str = "luxmorocco/qiyas-falcon-7b"):
        # Resolve the base model from the PEFT adapter config.
        config = PeftConfig.from_pretrained(path)

        # 4-bit NF4 quantization with double quantization and fp16 compute.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load the quantized base model on GPU 0; quantization_config already
        # requests 4-bit loading, so no separate load_in_4bit kwarg is needed.
        self.model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            device_map={"": 0},
            trust_remote_code=True,
            quantization_config=bnb_config,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Attach the LoRA adapter weights on top of the quantized base model.
        self.model = PeftModel.from_pretrained(self.model, path)

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`): payload whose "inputs" entry looks like
                {"context": "some text", "question": "some text"} containing:
                - "context": the passage to answer from
                - "question": the question to answer
        Return:
            A :obj:`dict` like {"answer": "some text", "time": "..."} containing:
            - "answer": the answer to the question, based on the context
            - "time": the wall-clock time taken by generation
        """
        inputs = data.pop("inputs", data)
        context = inputs.pop("context", inputs)
        question = inputs.pop("question", inputs)

        prompt = f"""Answer the question based on the context below. If the question cannot be answered using the information provided answer with 'No answer'. Stop response if end.
>>TITLE<<: Flawless answer.
>>CONTEXT<<: {context}
>>QUESTION<<: {question}
>>ANSWER<<:
""".strip()
        batch = self.tokenizer(
            prompt,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        batch = batch.to("cuda:0")

        generation_config = self.model.generation_config
        generation_config.top_p = 0.7
        generation_config.temperature = 0.7
        generation_config.max_new_tokens = 256
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        # Generate and time the forward pass.
        start = time.time()
        with torch.cuda.amp.autocast():
            output_tokens = self.model.generate(
                input_ids=batch.input_ids,
                generation_config=generation_config,
            )
        end = time.time()

        generated_text = self.tokenizer.decode(output_tokens[0])

        # The decoded text contains the prompt; keep only what follows
        # ">>ANSWER<<:" and drop anything after ">>END<<".
        answer = generated_text.split(">>END<<")[0].split(">>ANSWER<<:")[1].strip()
        return {"answer": answer, "time": f"{(end - start):.2f} s"}
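

# --- Local smoke test (illustrative sketch only, not part of the Inference
# Endpoints contract). Assumes a CUDA GPU is available and the adapter repo
# above can be downloaded; the context/question strings are made-up examples.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "context": "The Eiffel Tower was completed in 1889 and stands in Paris.",
            "question": "When was the Eiffel Tower completed?",
        }
    }
    # Expected shape of the result: {"answer": "...", "time": "X.XX s"}
    print(handler(payload))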