Jaykintecblic committed on
Commit 069cb3b
1 Parent(s): a8a8a76

Update handler.py

Files changed (1): handler.py (+27 −28)
handler.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Any
+from typing import Dict, Any, Generator
 from PIL import Image
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
@@ -8,16 +8,8 @@ from transformers.image_transforms import resize, to_channel_dimension_format
 class EndpointHandler:
     def __init__(self, model_path: str):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.processor = AutoProcessor.from_pretrained(
-            model_path,
-            # token=api_token
-        )
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            # token=api_token,
-            trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-        ).to(self.device)
+        self.processor = AutoProcessor.from_pretrained(model_path)
+        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device)
         self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
         self.bos_token = self.processor.tokenizer.bos_token
         self.bad_words_ids = self.processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
@@ -44,25 +36,32 @@ class EndpointHandler:
         image = to_channel_dimension_format(image, ChannelDimension.FIRST)
         return torch.tensor(image)
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    def stream_response(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
         image = data.get("inputs")
-
         if isinstance(image, str):
-            image = Image.open(image)
+            try:
+                image = Image.open(image)
+            except Exception as e:
+                yield {"error": f"Failed to open image: {e}"}
+                return
 
-        inputs = self.processor.tokenizer(
-            f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
-            return_tensors="pt",
-            add_special_tokens=False,
-        )
-        inputs["pixel_values"] = self.processor.image_processor([image], transform=self.custom_transform)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        try:
+            inputs = self.processor.tokenizer(
+                f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
+                return_tensors="pt",
+                add_special_tokens=False,
+            )
+            inputs["pixel_values"] = self.processor.image_processor([image], transform=self.custom_transform)
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-        generated_ids = self.model.generate(**inputs, bad_words_ids=self.bad_words_ids, max_length=2048, early_stopping=True)
-        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        # print(generated_text)
-        # return {"text": generated_text}
-        # Format the output as an array of dictionaries with 'label' and 'score'
-        output = [{"label": generated_text, "score": 1.0}]
+            for generated_ids in self.model.generate(**inputs, bad_words_ids=self.bad_words_ids, max_length=2048, early_stopping=True, return_dict_in_generate=True, output_scores=True):
+                generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+                yield {"label": generated_text, "score": 1.0}
+
+        except torch.cuda.CudaError as e:
+            yield {"error": f"CUDA error: {e}"}
+        except Exception as e:
+            yield {"error": f"Unexpected error: {e}"}
 
-        return output
+    def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
+        return self.stream_response(data)
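
A note on the new streaming loop: `generate()` does not yield intermediate results when iterated. With `return_dict_in_generate=True` it returns a single dict-like `GenerateDecoderOnlyOutput`, so `for generated_ids in self.model.generate(...)` walks over field names such as "sequences" and "scores" rather than token batches, and `batch_decode` fails on the first of them. `early_stopping=True` likewise only affects beam search. Incremental streaming with transformers is normally done with `TextIteratorStreamer` (available since transformers 4.28), running `generate()` in a worker thread. Below is a minimal sketch of how `stream_response` could be written that way, reusing the attribute names and prompt construction from the diff above; it is illustrative, not part of this commit.

from threading import Thread
from transformers import TextIteratorStreamer

def stream_response(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
    image = data.get("inputs")
    if isinstance(image, str):
        image = Image.open(image)

    # Build the prompt and pixel inputs exactly as in the diff above.
    inputs = self.processor.tokenizer(
        f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = self.processor.image_processor([image], transform=self.custom_transform)
    inputs = {k: v.to(self.device) for k, v in inputs.items()}

    # The streamer receives decoded text from generate() running in a worker
    # thread and exposes it as an ordinary iterator on this thread.
    streamer = TextIteratorStreamer(self.processor.tokenizer, skip_special_tokens=True)
    Thread(
        target=self.model.generate,
        kwargs=dict(**inputs, bad_words_ids=self.bad_words_ids, max_length=2048, streamer=streamer),
    ).start()

    for new_text in streamer:  # text chunks arrive as they are generated
        yield {"label": new_text, "score": 1.0}

Consumed through `__call__`, the endpoint then streams chunk by chunk:

handler = EndpointHandler("HuggingFaceM4/idefics-9b")  # hypothetical model path
for chunk in handler({"inputs": "example.jpg"}):
    print(chunk)

One further caveat: most CUDA failures surface as a plain `RuntimeError` rather than `torch.cuda.CudaError`, so in practice the generic `except Exception` branch in the committed code is the one that will catch them.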