File size: 2,584 Bytes
fdbd697 11f10ff fdbd697 2b38b94 11f10ff 2b38b94 11f10ff 2b38b94 11f10ff 2b38b94 11f10ff 2b38b94 11f10ff 2b38b94 11f10ff 2b38b94 11f10ff 2b38b94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
# class EndpointHandler():
# def __init__(self, path=""):
# model = AutoModelForCausalLM.from_pretrained(path,
# torch_dtype=torch.bfloat16,
# trust_remote_code=True,
# device_map="auto")
# print(model.hf_device_map)
# tokenizer = AutoTokenizer.from_pretrained(path)
# #device = "cuda:0" if torch.cuda.is_available() else "cpu"
# self.pipeline = transformers.pipeline('text-generation',
# model=model,
# tokenizer=tokenizer)
# def __call__(self, data: Dict[str, Any]):
# inputs = data.pop("inputs", data)
# parameters = data.pop("parameters", {})
# with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
# outputs = self.pipeline(inputs,
# **parameters)
# return outputs
class EndpointHandler:
    """Custom text-generation handler: tokenize a prompt, call ``generate``, decode."""

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and causal-LM model from ``path``.

        Args:
            path: Local directory or hub id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Kept as a public attribute for callers that read it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half-precision kernels are missing for many ops on CPU builds of
        # torch, so only use float16 when a GPU is available.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        # NOTE(review): trust_remote_code=True executes Python code shipped
        # with the model repository; only load trusted checkpoints.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for the prompt contained in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt; if that
                key is absent, ``data`` itself is handed to the tokenizer.
                ``data["parameters"]`` (optional) holds keyword arguments
                forwarded to ``model.generate``, plus the handler-specific
                key ``return_full_text`` (default ``True``). Both ``data``
                and its ``parameters`` dict are mutated (keys are popped).

        Returns:
            ``[{"generated_text": text}]`` where ``text`` is the decoded first
            generated sequence, with the prompt tokens stripped when
            ``return_full_text`` is false.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        return_full_text = parameters.pop("return_full_text", True)

        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        # With device_map="auto" the input ids must sit where the model's
        # first parameters (the embeddings) live, not on a guessed device.
        encoded = encoded.to(self.model.device)
        # Prompt length in tokens. Integer-indexing a BatchEncoding only works
        # for fast tokenizers, so read the tensor's sequence dimension instead.
        prompt_len = encoded["input_ids"].shape[-1]

        output_ids = self.model.generate(**encoded, **parameters)[0]
        if not return_full_text:
            # Decoder-only generate() echoes the prompt; drop it.
            output_ids = output_ids[prompt_len:]

        prediction = self.tokenizer.decode(output_ids, skip_special_tokens=True)
        return [{"generated_text": prediction}]
|