OssamaLafhel committed
Commit: 35d9624
Parent(s): fd28244

Update handler.py

handler.py CHANGED (+52 -14)
@@ -1,3 +1,8 @@
+import time
+import json
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 import transformers
 from transformers import pipeline
 import torch
@@ -5,6 +10,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.cuda.amp import custom_fwd, custom_bwd
 from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+from loguru import logger
 from typing import Dict, List, Any
 
 
@@ -153,30 +159,62 @@ class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
 transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J
 
 
+class Message(BaseModel):
+    input: str = None
+    output: dict = None
+    length: str = None
+    temperature: str = None
+
+
+app = FastAPI()
+
+origins = [
+    "http://localhost:8000",
+    "http://localhost",
+    "http://localhost:3000",
+    "http://127.0.0.1:3000"
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["POST"],
+    allow_headers=["*"],
+)
+
 # -----------------------------------------> API <---------------------------------------
+tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+model = GPTJForCausalLM.from_pretrained("Kanpredict/gptj-6b-8bits", low_cpu_mem_usage=True)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 class EndpointHandler:
     def __init__(self, path=""):
         # load the model
-        tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-        model = GPTJForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
-
-        # check for GPU
-        device = 0 if torch.cuda.is_available() else -1
         model.to(device)
-
         # create inference pipeline
-        self.pipeline = pipeline(
+        self.pipeline = pipeline(model=model, tokenizer=tokenizer, device=device)
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
 
-        #
-
-
-
-
-
-
+        # run the model and get the output (generated text)
+        prompt = inputs
+        temperature = float(parameters["temperature"])
+        length = int(parameters["length"])
+        logger.info("message input: {}", prompt)
+        logger.info("temperature: {}", temperature)
+        logger.info("length: {}", length)
+        start = time.time()
+        prompt = tokenizer(prompt, return_tensors='pt')
+        prompt = {key: value.to(device) for key, value in prompt.items()}
+        out = model.generate(**prompt, min_length=length, max_length=length, temperature=temperature, do_sample=True)
+        generated_text = tokenizer.decode(out[0])
+        logger.info("generated text: {}", generated_text)
+        logger.info("time taken: {}", time.time() - start)
+        result = {"output": generated_text}
+        result = json.dumps(result)
+        return result
+
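
For context, a minimal way to exercise the updated handler locally is sketched below. This is not part of the commit: the module name (handler), the payload values, and the printed shape are illustrative, and it assumes the imported dependencies (transformers, bitsandbytes, fastapi, pydantic, loguru) and a GPU are available. It relies only on what __call__ reads: a string under "inputs" and a "parameters" dict with "temperature" and "length".

# Hypothetical local smoke test; module name and payload values are illustrative.
from handler import EndpointHandler

endpoint = EndpointHandler()
payload = {
    "inputs": "Hello, my name is",
    "parameters": {"temperature": "0.8", "length": "64"},
}
print(endpoint(payload))  # prints a JSON string such as {"output": "..."}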
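
The commit also creates a module-level FastAPI app, CORS middleware, and a Message schema, but no route is added in these hunks. Purely as a sketch of how those pieces could be wired together (the /generate path, the shared handler instance, and the response shape are assumptions, not from the commit), a separate module might look like this:

# Hypothetical route tying the Message schema to EndpointHandler;
# the path, the shared instance, and the response shape are assumptions.
import json
from handler import EndpointHandler, Message, app

endpoint = EndpointHandler()

@app.post("/generate", response_model=Message)
def generate(message: Message):
    result = endpoint({
        "inputs": message.input,
        "parameters": {"temperature": message.temperature, "length": message.length},
    })
    message.output = json.loads(result)  # __call__ returns a JSON string
    return message

Served with uvicorn, a POST to /generate from one of the allowed origins would then pass the CORS check and return the generated text in Message.output.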