davanstrien HF staff commited on
Commit
5fd9231
1 Parent(s): 2dce2dc

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +47 -0
handler.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
3
+ from PIL import Image
4
+ import requests
5
+ import torch
6
+
7
+
8
+ class EndpointHandler:
9
+ def __init__(self, path=""):
10
+ self.processor = AutoProcessor.from_pretrained(
11
+ path, trust_remote_code=True, torch_dtype="auto", device_map="auto"
12
+ )
13
+ self.model = AutoModelForCausalLM.from_pretrained(
14
+ path, trust_remote_code=True, torch_dtype="auto", device_map="auto"
15
+ )
16
+
17
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
18
+ # Extract inputs from the request data
19
+ image_url = data.get("image_url")
20
+ text_prompt = data.get("text_prompt", "Describe this image.")
21
+
22
+ # Download and process the image
23
+ image = Image.open(requests.get(image_url, stream=True).raw)
24
+ if image.mode != "RGB":
25
+ image = image.convert("RGB")
26
+
27
+ # Process the image and text
28
+ inputs = self.processor.process(images=[image], text=text_prompt)
29
+
30
+ # Move inputs to the correct device and make a batch of size 1
31
+ inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
32
+
33
+ # Generate output
34
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
35
+ output = self.model.generate_from_batch(
36
+ inputs,
37
+ GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
38
+ tokenizer=self.processor.tokenizer,
39
+ )
40
+
41
+ # Decode the generated tokens
42
+ generated_tokens = output[0, inputs["input_ids"].size(1) :]
43
+ generated_text = self.processor.tokenizer.decode(
44
+ generated_tokens, skip_special_tokens=True
45
+ )
46
+
47
+ return [{"generated_text": generated_text}]