Jaykintecblic commited on
Commit
9a42ebe
1 Parent(s): 09067ad

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +64 -0
handler.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoProcessor
5
+ from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
6
+ from transformers.image_transforms import resize, to_channel_dimension_format
7
+
8
+ class CustomPipeline:
9
+ def __init__(self, model_path: str, api_token: str):
10
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ self.processor = AutoProcessor.from_pretrained(
12
+ model_path,
13
+ token=api_token
14
+ )
15
+ self.model = AutoModelForCausalLM.from_pretrained(
16
+ model_path,
17
+ token=api_token,
18
+ trust_remote_code=True,
19
+ torch_dtype=torch.bfloat16,
20
+ ).to(self.device)
21
+ self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
22
+ self.bos_token = self.processor.tokenizer.bos_token
23
+ self.bad_words_ids = self.processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
24
+
25
+ def convert_to_rgb(self, image: Image.Image) -> Image.Image:
26
+ if image.mode == "RGB":
27
+ return image
28
+ image_rgba = image.convert("RGBA")
29
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
30
+ alpha_composite = Image.alpha_composite(background, image_rgba)
31
+ alpha_composite = alpha_composite.convert("RGB")
32
+ return alpha_composite
33
+
34
+ def custom_transform(self, image: Image.Image) -> torch.Tensor:
35
+ image = self.convert_to_rgb(image)
36
+ image = to_numpy_array(image)
37
+ image = resize(image, (960, 960), resample=PILImageResampling.BILINEAR)
38
+ image = self.processor.image_processor.rescale(image, scale=1 / 255)
39
+ image = self.processor.image_processor.normalize(
40
+ image,
41
+ mean=self.processor.image_processor.image_mean,
42
+ std=self.processor.image_processor.image_std
43
+ )
44
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST)
45
+ return torch.tensor(image)
46
+
47
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
48
+ image = data.get("inputs")
49
+
50
+ if isinstance(image, str):
51
+ image = Image.open(image)
52
+
53
+ inputs = self.processor.tokenizer(
54
+ f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
55
+ return_tensors="pt",
56
+ add_special_tokens=False,
57
+ )
58
+ inputs["pixel_values"] = self.processor.image_processor([image], transform=self.custom_transform)
59
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
60
+
61
+ generated_ids = self.model.generate(**inputs, bad_words_ids=self.bad_words_ids, max_length=4096)
62
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
63
+
64
+ return {"text": generated_text}