Thomasboosinger commited on
Commit
2d58eea
1 Parent(s): d602a6c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +21 -15
handler.py CHANGED
@@ -1,26 +1,32 @@
1
- from typing import Dict, List, Any
2
  from PIL import Image
3
  from io import BytesIO
4
- from transformers import pipeline
5
  import base64
 
6
 
7
  class EndpointHandler():
8
- def __init__(self, path=""):
9
- self.pipeline = pipeline(task="zero-shot-object-detection",model=path, device = 0 ) #device = 0 to use GPU rather than -1 which would be CPU
 
10
 
11
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
  """
13
- data args:
14
- images (:obj:`string`)
15
- candiates (:obj:`list`)
16
- Return:
17
- A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
 
 
18
  """
19
- inputs = data.pop("inputs", data)
20
-
21
- # decode base64 image to PIL
 
22
  image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
23
 
24
- # run prediction one image wit provided candiates
25
- detector = self.pipeline(images=[image], candidate_labels=inputs["candidate_labels"])
26
- return detector
 
 
 
1
+ from transformers import pipeline
2
  from PIL import Image
3
  from io import BytesIO
 
4
  import base64
5
+ from typing import Dict, List, Any
6
 
7
  class EndpointHandler():
8
+ def __init__(self, model_path=""):
9
+ # Initialize the pipeline with the specified model and set the device to GPU
10
+ self.pipeline = pipeline(task="zero-shot-object-detection", model=model_path, device=0)
11
 
12
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
  """
14
+ Process an incoming request for zero-shot object detection.
15
+
16
+ Args:
17
+ data (Dict[str, Any]): The input data containing an encoded image and candidate labels.
18
+
19
+ Returns:
20
+ A list of dictionaries, each containing a label and its corresponding score.
21
  """
22
+ # Correctly accessing the 'inputs' key and fixing the typo in 'candidates'
23
+ inputs = data.get("inputs", {})
24
+
25
+ # Decode the base64 image to a PIL image
26
  image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
27
 
28
+ # Correctly passing the image and candidate labels to the pipeline
29
+ detection_results = self.pipeline(image=image, candidate_labels=inputs["candidates"])
30
+
31
+ # Adjusting the return statement to match the expected output structure
32
+ return detection_results