English
Inference Endpoints
garg-aayush committed on
Commit
e0d40ee
1 Parent(s): 4b28421

Image URL as input, environment variable for with and without tiling

Browse files
Files changed (2) hide show
  1. handler.py +17 -5
  2. test_handler.ipynb +38 -22
handler.py CHANGED
@@ -13,16 +13,20 @@ import boto3
13
  import uuid, io
14
  import torch
15
  import base64
 
16
 
17
 
18
  class EndpointHandler:
19
  def __init__(self, path=""):
 
 
20
 
21
  # Initialize the Real-ESRGAN model with specified parameters
22
  self.model = RealESRGANer(
23
  scale=4, # Scale factor for the model
24
  # Path to the pre-trained model weights
25
- model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
 
26
  # Initialize the RRDBNet model architecture with specified parameters
27
  model= RRDBNet(num_in_ch=3,
28
  num_out_ch=3,
@@ -31,7 +35,7 @@ class EndpointHandler:
31
  num_grow_ch=32,
32
  scale=4
33
  ),
34
- tile=0,
35
  tile_pad=0,
36
  half=True,
37
  )
@@ -56,12 +60,13 @@ class EndpointHandler:
56
  outscale = float(inputs.pop("outscale", 3))
57
 
58
  # decode base64 image to PIL
59
- image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
60
  in_size, in_mode = image.size, image.mode
61
 
62
  # check image size and mode and return dict
63
  assert in_mode in ["RGB", "RGBA", "L"], f"Unsupported image mode: {in_mode}"
64
- assert in_size[0] * in_size[1] < 1400*1400, f"Image is too large: {in_size}: {in_size[0] * in_size[1]} is greater than {1400*1400}"
 
65
  assert outscale > 1 and outscale <=10, f"Outscale must be between 1 and 10: {outscale}"
66
 
67
  # debug
@@ -141,4 +146,11 @@ class EndpointHandler:
141
  image_url = "https://{0}.s3.amazonaws.com/{1}".format(self.bucket_name, key)
142
 
143
  # return the url and the key
144
- return image_url, key
 
 
 
 
 
 
 
 
13
  import uuid, io
14
  import torch
15
  import base64
16
+ import requests
17
 
18
 
19
  class EndpointHandler:
20
  def __init__(self, path=""):
21
+
22
+ self.tiling_size = int(os.environ["TILING_SIZE"])
23
 
24
  # Initialize the Real-ESRGAN model with specified parameters
25
  self.model = RealESRGANer(
26
  scale=4, # Scale factor for the model
27
  # Path to the pre-trained model weights
28
+ # model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
29
+ model_path=f"/workspace/real-esrgan/weights/Real-ESRGAN-x4plus.pth",
30
  # Initialize the RRDBNet model architecture with specified parameters
31
  model= RRDBNet(num_in_ch=3,
32
  num_out_ch=3,
 
35
  num_grow_ch=32,
36
  scale=4
37
  ),
38
+ tile=self.tiling_size,
39
  tile_pad=0,
40
  half=True,
41
  )
 
60
  outscale = float(inputs.pop("outscale", 3))
61
 
62
  # decode base64 image to PIL
63
+ image = self.download_image_url(inputs['image_url'])
64
  in_size, in_mode = image.size, image.mode
65
 
66
  # check image size and mode and return dict
67
  assert in_mode in ["RGB", "RGBA", "L"], f"Unsupported image mode: {in_mode}"
68
+ if self.tiling_size == 0:
69
+ assert in_size[0] * in_size[1] < 1400*1400, f"Image is too large: {in_size}: {in_size[0] * in_size[1]} is greater than {self.tiling_size*self.tiling_size}"
70
  assert outscale > 1 and outscale <=10, f"Outscale must be between 1 and 10: {outscale}"
71
 
72
  # debug
 
146
  image_url = "https://{0}.s3.amazonaws.com/{1}".format(self.bucket_name, key)
147
 
148
  # return the url and the key
149
+ return image_url, key
150
+
151
+ def download_image_url(self, image_url):
152
+ "Download the image from the url and return the image."
153
+
154
+ response = requests.get(image_url)
155
+ image = Image.open(BytesIO(response.content))
156
+ return image
test_handler.ipynb CHANGED
@@ -15,6 +15,7 @@
15
  }
16
  ],
17
  "source": [
 
18
  "from handler import EndpointHandler\n",
19
  "import base64\n",
20
  "from io import BytesIO\n",
@@ -28,6 +29,18 @@
28
  "execution_count": 2,
29
  "metadata": {},
30
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
31
  "source": [
32
  "# init handler\n",
33
  "my_handler = EndpointHandler(path=\".\")"
@@ -42,45 +55,48 @@
42
  "name": "stdout",
43
  "output_type": "stream",
44
  "text": [
45
- "image.size: (1200, 517), image.mode: RGBA, outscale: 10.0\n"
46
  ]
47
  },
48
  {
49
  "name": "stdout",
50
  "output_type": "stream",
51
  "text": [
52
- "output.shape: (5170, 12000, 4)\n",
53
- "https://upscale-process-results.s3.amazonaws.com/0da01291-5a39-40fd-b322-f5d83e6066f6.png 0da01291-5a39-40fd-b322-f5d83e6066f6.png\n",
54
- "image.size: (1056, 1068), image.mode: RGB, outscale: 3.0\n",
55
- "output.shape: (3204, 3168, 3)\n",
56
- "https://upscale-process-results.s3.amazonaws.com/c1ba714d-50e2-45d0-ac9f-8a4a4f218320.png c1ba714d-50e2-45d0-ac9f-8a4a4f218320.png\n",
57
- "image.size: (1056, 1068), image.mode: L, outscale: 5.49\n",
58
- "output.shape: (5863, 5797, 3)\n",
59
- "https://upscale-process-results.s3.amazonaws.com/e964c021-9169-49d8-9382-104f704a1d92.png e964c021-9169-49d8-9382-104f704a1d92.png\n"
60
  ]
61
  }
62
  ],
63
  "source": [
64
  "img_dir = \"test_data/\"\n",
65
- "img_names = [\"4121783.png\", \"FB_IMG_1725931665635.jpg\", \"FB_IMG_1725931665635_gray.jpg\"]\n",
66
- "out_scales = [10, 3, 5.49]\n",
67
- "for img_name, outscale in zip(img_names, out_scales):\n",
68
- " image_path = img_dir + img_name\n",
 
 
 
69
  " # create payload\n",
70
- " with open(image_path, \"rb\") as i:\n",
71
- " b64 = base64.b64encode(i.read())\n",
72
- " b64 = b64.decode(\"utf-8\")\n",
73
- " payload = {\n",
74
- " \"inputs\": {\"image\": b64, \n",
75
  " \"outscale\": outscale\n",
76
  " }\n",
77
  " }\n",
78
- "\n",
79
- "\n",
80
  " output_payload = my_handler(payload)\n",
81
- " print(output_payload[\"image_url\"], output_payload[\"image_key\"])\n",
82
- " "
83
  ]
 
 
 
 
 
 
 
84
  }
85
  ],
86
  "metadata": {
 
15
  }
16
  ],
17
  "source": [
18
+ "import os\n",
19
  "from handler import EndpointHandler\n",
20
  "import base64\n",
21
  "from io import BytesIO\n",
 
29
  "execution_count": 2,
30
  "metadata": {},
31
  "outputs": [],
32
+ "source": [
33
+ "os.environ[\"AWS_ACCESS_KEY_ID\"] = \"\"\n",
34
+ "os.environ[\"AWS_SECRET_ACCESS_KEY\"] = \"\"\n",
35
+ "os.environ[\"S3_BUCKET_NAME\"] = \"\"\n",
36
+ "os.environ[\"TILING_SIZE\"] = \"1000\""
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 3,
42
+ "metadata": {},
43
+ "outputs": [],
44
  "source": [
45
  "# init handler\n",
46
  "my_handler = EndpointHandler(path=\".\")"
 
55
  "name": "stdout",
56
  "output_type": "stream",
57
  "text": [
58
+ "image.size: (1024, 1024), image.mode: RGB, outscale: 4.0\n"
59
  ]
60
  },
61
  {
62
  "name": "stdout",
63
  "output_type": "stream",
64
  "text": [
65
+ "\tTile 1/4\n",
66
+ "\tTile 2/4\n",
67
+ "\tTile 3/4\n",
68
+ "\tTile 4/4\n",
69
+ "output.shape: (4096, 4096, 3)\n",
70
+ "https://jiffy-staging-upscaled-images.s3.amazonaws.com/d91323cb-0801-45b7-8109-9739212037ed.png d91323cb-0801-45b7-8109-9739212037ed.png\n"
 
 
71
  ]
72
  }
73
  ],
74
  "source": [
75
  "img_dir = \"test_data/\"\n",
76
+ "img_urls = [\"https://jiffy-transfers.imgix.net/2/attachments/r267odvvfmkp6c5lccj1y6f9trb0?ixlib=rb-0.3.5\",\n",
77
+ " # \"https://jiffy-staging-transfers.imgix.net/2/development/attachments/zo31eau0ykhbwoddrjtlbyz6w9mp?ixlib=rb-0.3.5\", # larger than > 1.96M pixels\n",
78
+ " # \"https://jiffy-staging-transfers.imgix.net/2/development/attachments/b8ecchms9rr9wk3g71kfpfprqg1v?ixlib=rb-0.3.5\" # larger than > 1.96M pixels\n",
79
+ " ]\n",
80
+ "\n",
81
+ "out_scales = [4, 3, 2]\n",
82
+ "for img_url, outscale in zip(img_urls, out_scales):\n",
83
  " # create payload\n",
84
+ " payload = {\n",
85
+ " \"inputs\": {\"image_url\": img_url, \n",
 
 
 
86
  " \"outscale\": outscale\n",
87
  " }\n",
88
  " }\n",
89
+ " \n",
 
90
  " output_payload = my_handler(payload)\n",
91
+ " print(output_payload[\"image_url\"], output_payload[\"image_key\"])\n"
 
92
  ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": []
100
  }
101
  ],
102
  "metadata": {