# gptj-6b-8bits / handler.py
import time
from typing import Any, Dict

from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
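
# The original file also imported torch's custom_fwd/custom_bwd and bitsandbytes'
# blockwise quantization helpers (the 8-bit GPT-J recipe) without using them.
# A minimal sketch of what those helpers do, assuming the bitsandbytes API of
# that era; illustrative only, never called by the handler:
def _blockwise_quantization_demo():
    import torch
    from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

    weights = torch.randn(1024, 1024)
    # Quantize to uint8 blockwise; returns the payload plus per-block absmax
    # scales and the shared codebook (tuple return assumed for old versions).
    quantized, (absmax, code) = quantize_blockwise(weights)
    # Dequantize back to float; the round-trip is lossy but close.
    restored = dequantize_blockwise(quantized, absmax=absmax, code=code)
    return (weights - restored).abs().max()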
# -----------------------------------------> API <---------------------------------------
name = "Kanpredict/gptj-6b-8bits"
model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(name)
class EndpointHandler:
    def __init__(self, path=""):
        # Create the inference pipeline, reusing the 8-bit model and tokenizer
        # loaded above instead of downloading the checkpoint a second time.
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        # Run the model and get the output (generated text). The defaults are
        # assumed; the original code crashed on requests without parameters.
        prompt = inputs
        temperature = float(parameters.get("temperature", 1.0))
        length = int(parameters.get("length", 50))
        logger.info("message input: {}", prompt)
        logger.info("temperature: {}", temperature)
        logger.info("length: {}", length)

        start = time.time()
        # The pipeline takes the raw prompt string; tokenization and device
        # placement are handled internally, so no manual .to(device) is needed.
        out = self.pipeline(prompt, min_length=length, max_length=length,
                            temperature=temperature, do_sample=True)
        generated_text = out[0]["generated_text"]
        logger.info("generated text: {}", generated_text)
        logger.info("time taken: {}s", time.time() - start)

        # Return a plain dict; the endpoint runtime serializes it to JSON.
        return {"output": generated_text}
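
# A minimal local smoke test; the payload shape is assumed from the parsing
# logic in __call__ above (an "inputs" string plus optional "parameters"):
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler(
        {
            "inputs": "Hello, my name is",
            "parameters": {"temperature": 0.8, "length": 32},
        }
    )
    print(response["output"])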