sooh-j committed on
Commit
3686b46
·
verified ·
1 Parent(s): 254bcc9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +57 -7
handler.py CHANGED
@@ -3,28 +3,78 @@ from PIL import Image
3
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
4
  from typing import Dict, List, Any
5
  import torch
 
 
 
 
 
6
 
7
  class EndpointHandler():
8
  def __init__(self, path=""):
9
- self.base_model_name = "Salesforce/blip2-opt-2.7b"
10
  self.model_name = "sooh-j/blip2-vizwizqa"
11
- self.base_model = Blip2ForConditionalGeneration.from_pretrained(self.base_model_name,
12
- load_in_8bit=True)
13
  self.processor = Blip2Processor.from_pretrained(self.base_model_name)
14
  self.model = PeftModel.from_pretrained(self.model_name, self.base_model_name)
15
 
16
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
  self.model.to(self.device)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
20
- data = data.pop("inputs", data)
 
 
 
 
 
21
 
22
- image = data.image
23
- question = data.question
24
 
25
  prompt = f"Question: {question}, Answer:"
26
- processed = self.processor(images=image, prompt, return_tensors="pt").to(self.device)
27
 
 
 
 
28
  out = self.model.generate(**processed)
29
 
30
  return self.processor.decode(out[0], skip_special_tokens=True)
 
3
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
4
  from typing import Dict, List, Any
5
  import torch
6
+ import sys
7
+ import base64
8
+ import logging
9
+ import copy
10
+ import numpy as np
11
 
12
  class EndpointHandler():
13
  def __init__(self, path=""):
14
+ self.model_base = "Salesforce/blip2-opt-2.7b"
15
  self.model_name = "sooh-j/blip2-vizwizqa"
16
+ self.base_model = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True)
17
+ self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)
18
  self.processor = Blip2Processor.from_pretrained(self.base_model_name)
19
  self.model = PeftModel.from_pretrained(self.model_name, self.base_model_name)
20
 
21
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
22
  self.model.to(self.device)
23
 
24
+ # def _generate_answer(
25
+ # self,
26
+ # model_path,
27
+ # prompt,
28
+ # # num_inference_steps=25,
29
+ # # guidance_scale=7.5,
30
+ # # num_images_per_prompt=1
31
+ # ):
32
+
33
+ # self.pipe.to(self.device)
34
+
35
+ # # pil_images = self.pipe(
36
+ # # prompt=prompt,
37
+ # # num_inference_steps=num_inference_steps,
38
+ # # guidance_scale=guidance_scale,
39
+ # # num_images_per_prompt=num_images_per_prompt).images
40
+
41
+ # # np_images = []
42
+ # # for i in range(len(pil_images)):
43
+ # # np_images.append(np.asarray(pil_images[i]))
44
+
45
+ # return np.stack(np_images, axis=0)
46
+
47
+ # inputs = data.get("inputs")
48
+ # imageBase64 = inputs.get("image")
49
+ # # imageURL = inputs.get("image")
50
+ # text = inputs.get("text")
51
+ # # print(imageURL)
52
+ # # print(text)
53
+ # # image = Image.open(requests.get(imageBase64, stream=True).raw)
54
+
55
+ # image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[1].encode())))
56
+ # inputs = self.processor(text=text, images=image, return_tensors="pt", padding=True)
57
+ # outputs = self.model(**inputs)
58
+ # embeddings = outputs.image_embeds.detach().numpy().flatten().tolist()
59
+ # return { "embeddings": embeddings }
60
+
61
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
62
+ inputs = data.get("inputs")
63
+ imageBase64 = inputs.get("image")
64
+ question = inputs.get("text")
65
+
66
+ # data = data.pop("inputs", data)
67
+ # data = data.pop("image", image)
68
 
69
+ # image = Image.open(requests.get(imageBase64, stream=True).raw)
70
+ image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[1].encode())))
71
 
72
  prompt = f"Question: {question}, Answer:"
73
+ processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
74
 
75
+ # answer = self._generate_answer(
76
+ # model_path, prompt, image,
77
+ # )
78
  out = self.model.generate(**processed)
79
 
80
  return self.processor.decode(out[0], skip_special_tokens=True)