OssamaLafhel committed
Commit 35d9624
1 Parent(s): fd28244

Update handler.py

Files changed (1):
  1. handler.py +52 -14
handler.py CHANGED
@@ -1,3 +1,8 @@
+import time
+import json
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 import transformers
 from transformers import pipeline
 import torch
@@ -5,6 +10,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.cuda.amp import custom_fwd, custom_bwd
 from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+from loguru import logger
 from typing import Dict, List, Any


@@ -153,30 +159,62 @@ class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
 transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J


+class Message(BaseModel):
+    input: str = None
+    output: dict = None
+    length: str = None
+    temperature: str = None
+
+
+app = FastAPI()
+
+origins = [
+    "http://localhost:8000",
+    "http://localhost",
+    "http://localhost:3000",
+    "http://127.0.0.1:3000"
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["POST"],
+    allow_headers=["*"],
+)
+
 # -----------------------------------------> API <---------------------------------------
+tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+model = GPTJForCausalLM.from_pretrained("Kanpredict/gptj-6b-8bits", low_cpu_mem_usage=True)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'


 class EndpointHandler:
     def __init__(self, path=""):
         # load the model
-        tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-        model = GPTJForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
-
-        # check for GPU
-        device = 0 if torch.cuda.is_available() else -1
         model.to(device)
-
         # create inference pipeline
-        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
+        self.pipeline = pipeline(model=model, tokenizer=tokenizer, device=device)

     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", None)

-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
-        # postprocess the prediction
-        return prediction
+        # run the model and get the output(generated text)
+        prompt = inputs
+        temperature = float(parameters.temperature)
+        length = int(parameters.length)
+        logger.info("message input: %s", prompt)
+        logger.info("tempereture: %s", parameters.temperature)
+        logger.info("length: %s", parameters.length)
+        start = time.time()
+        prompt = tokenizer(prompt, return_tensors='pt')
+        prompt = {key: value.to(device) for key, value in prompt.items()}
+        out = model.generate(**prompt, min_length=length, max_length=length, temperature=temperature, do_sample=True)
+        generated_text = tokenizer.decode(out[0])
+        logger.info("generated text: ", generated_text)
+        logger.info("time taken: %s", time.time() - start)
+        result = {"output": generated_text}
+        result = json.dumps(result)
+        return result
+
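
For reference, a minimal local smoke test of the new __call__ path could look like the sketch below. It is not part of the commit: the handler import path and the SimpleNamespace wrapper are assumptions, chosen because __call__ reads parameters.temperature and parameters.length as attributes (a plain dict would raise AttributeError), and note that the method returns a JSON string rather than the List[List[Dict[str, float]]] its annotation suggests.

# Hypothetical smoke test, not part of commit 35d9624.
from types import SimpleNamespace

from handler import EndpointHandler  # assumed import path

endpoint = EndpointHandler()
payload = {
    "inputs": "Once upon a time",
    # attribute-style access is what __call__ expects from "parameters"
    "parameters": SimpleNamespace(temperature="0.8", length="64"),
}
result = endpoint(payload)  # JSON string, e.g. '{"output": "Once upon a time ..."}'
print(result)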
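
The commit also adds a FastAPI app, CORS middleware, and a Message model, but the diff defines no route, so the app is not yet wired to the handler. One possible wiring is sketched below; the /generate path, the endpoint instance, and the import path are assumptions, not part of the commit.

# Hypothetical route, not part of commit 35d9624.
import json

from handler import EndpointHandler, Message, app  # assumed import path

endpoint = EndpointHandler()


@app.post("/generate")
def generate(message: Message) -> dict:
    # A Message instance exposes .temperature and .length, matching what
    # EndpointHandler.__call__ expects from the "parameters" object.
    result = endpoint({"inputs": message.input, "parameters": message})
    # __call__ returns a JSON string, so decode it before FastAPI re-serializes it.
    return json.loads(result)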