English
detection
open-world
open-set
Inference Endpoints
kelvinou01 committed on
Commit
841a649
1 Parent(s): b435ec9

Update handler

Browse files
Files changed (1) hide show
  1. handler.py +29 -7
handler.py CHANGED
@@ -1,9 +1,12 @@
1
 
 
 
2
  import os
3
  from typing import Dict, List, Any
 
4
  import groundingdino
5
  from groundingdino.util.inference import load_model, load_image, predict, annotate
6
- import subprocess
7
 
8
  # /app
9
  HOME = os.getcwd()
@@ -20,6 +23,9 @@ class EndpointHandler():
20
 
21
  self.model = load_model(CONFIG_PATH, os.path.join(path, "weights", "groundingdino_swint_ogc.pth"))
22
 
 
 
 
23
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
24
  """
25
  data args:
@@ -29,10 +35,26 @@ class EndpointHandler():
29
  A :obj:`list` | `dict`: will be serialized and returned
30
  """
31
  inputs = data.pop("inputs")
32
- image = inputs.pop("image")
33
  prompt = inputs.pop("prompt")
34
-
35
- return [{
36
- "image": image,
37
- "prompt": prompt,
38
- }]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import base64
3
+ from io import BytesIO
4
  import os
5
  from typing import Dict, List, Any
6
+ import cv2
7
  import groundingdino
8
  from groundingdino.util.inference import load_model, load_image, predict, annotate
9
+ import tempfile
10
 
11
  # /app
12
  HOME = os.getcwd()
 
23
 
24
  self.model = load_model(CONFIG_PATH, os.path.join(path, "weights", "groundingdino_swint_ogc.pth"))
25
 
26
+ self.box_threshold = 0.35
27
+ self.text_threshold = 0.25
28
+
29
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
30
  """
31
  data args:
 
35
  A :obj:`list` | `dict`: will be serialized and returned
36
  """
37
  inputs = data.pop("inputs")
38
+ image_base64 = inputs.pop("image")
39
  prompt = inputs.pop("prompt")
40
+
41
+ image_data = base64.b64decode(image_base64)
42
+
43
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=True) as f:
44
+ f.write(image_data)
45
+ image_source, image = load_image(f.name)
46
+ boxes, logits, phrases = predict(
47
+ model=self.model,
48
+ image=image,
49
+ caption=prompt,
50
+ box_threshold=self.box_threshold,
51
+ text_threshold=self.text_threshold
52
+ )
53
+ annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
54
+ _, annotated_image = cv2.imencode(".jpg", annotated_frame)
55
+ annotated_image_b64 = base64.b64encode(annotated_image).decode("utf-8")
56
+
57
+ return [{
58
+ "image": annotated_image_b64,
59
+ "prompt": prompt,
60
+ }]