Samhita commited on
Commit
a1a2673
1 Parent(s): 34a0e77

remove json loads

Browse files

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

Files changed (1) hide show
  1. app.py +60 -27
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import base64
2
  import json
 
3
 
4
  # import mimetypes
5
  import os
@@ -9,18 +10,28 @@ from typing import Dict, Tuple, Union
9
 
10
  import banana_dev as banana
11
  import geopy.distance
12
- import gradio as gr
13
  import pandas as pd
14
  import plotly
15
  import plotly.express as px
16
 
17
  # import requests
18
  from dotenv import load_dotenv
 
 
 
19
 
20
  sys.path.append("..")
21
 
22
  from gantry_callback.gantry_util import GantryImageToTextLogger # noqa: E402
23
- from gantry_callback.s3_util import make_unique_bucket_name # noqa: E402
 
 
 
 
 
 
 
 
24
 
25
  load_dotenv()
26
 
@@ -155,20 +166,37 @@ def image_gradio(img_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
155
  with open(img_file, "rb") as image_file:
156
  image_bytes = BytesIO(image_file.read())
157
 
158
- data = json.loads(
159
- banana.run(
160
- BANANA_API_KEY,
161
- BANANA_MODEL_KEY,
162
- {
163
- "image": base64.b64encode(image_bytes.getvalue()).decode("utf-8"),
164
- "filename": os.path.basename(img_file),
165
- },
166
- )["modelOutputs"][0]
167
- )
168
 
169
  return get_outputs(data=data)
170
 
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
173
  # data = json.loads(
174
  # requests.post(
@@ -183,16 +211,23 @@ def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure
183
  # ).text
184
  # )
185
 
186
- data = json.loads(
187
- banana.run(
188
- BANANA_API_KEY,
189
- BANANA_MODEL_KEY,
190
- {
191
- "video": video_file,
192
- "filename": os.path.basename(video_file),
193
- },
194
- )["modelOutputs"][0]
195
- )
 
 
 
 
 
 
 
196
 
197
  return get_outputs(data=data)
198
 
@@ -205,11 +240,9 @@ def url_gradio(url: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
205
  # data=url,
206
  # ).text
207
  # )
208
- data = json.loads(
209
- banana.run(BANANA_API_KEY, BANANA_MODEL_KEY, {"url": url},)[
210
- "modelOutputs"
211
- ][0]
212
- )
213
 
214
  return get_outputs(data=data)
215
 
 
1
  import base64
2
  import json
3
+ import mimetypes
4
 
5
  # import mimetypes
6
  import os
 
10
 
11
  import banana_dev as banana
12
  import geopy.distance
 
13
  import pandas as pd
14
  import plotly
15
  import plotly.express as px
16
 
17
  # import requests
18
  from dotenv import load_dotenv
19
+ from smart_open import open as smartopen
20
+
21
+ import gradio as gr
22
 
23
  sys.path.append("..")
24
 
25
  from gantry_callback.gantry_util import GantryImageToTextLogger # noqa: E402
26
+ from gantry_callback.s3_util import (
27
+ add_access_policy, # noqa: E402
28
+ enable_bucket_versioning,
29
+ get_or_create_bucket,
30
+ get_uri_of,
31
+ make_key,
32
+ make_unique_bucket_name,
33
+ )
34
+ from gantry_callback.string_img_util import read_b64_string # noqa: E402
35
 
36
  load_dotenv()
37
 
 
166
  with open(img_file, "rb") as image_file:
167
  image_bytes = BytesIO(image_file.read())
168
 
169
+ data = banana.run(
170
+ BANANA_API_KEY,
171
+ BANANA_MODEL_KEY,
172
+ {
173
+ "image": base64.b64encode(image_bytes.getvalue()).decode("utf-8"),
174
+ "filename": os.path.basename(img_file),
175
+ },
176
+ )["modelOutputs"][0]
 
 
177
 
178
  return get_outputs(data=data)
179
 
180
 
181
+ def _upload_video_to_s3(video_b64_string):
182
+ bucket = get_or_create_bucket(
183
+ make_unique_bucket_name(prefix="geolocator-app", seed="420")
184
+ )
185
+ enable_bucket_versioning(bucket)
186
+ add_access_policy(bucket)
187
+
188
+ data_type, video_buffer = read_b64_string(video_b64_string, return_data_type=True)
189
+ video_bytes = video_buffer.read()
190
+ key = make_key(video_bytes, filetype=data_type)
191
+
192
+ s3_uri = get_uri_of(bucket, key)
193
+
194
+ with smartopen(s3_uri, "wb") as s3_object:
195
+ s3_object.write(video_bytes)
196
+
197
+ return s3_uri
198
+
199
+
200
  def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
201
  # data = json.loads(
202
  # requests.post(
 
211
  # ).text
212
  # )
213
 
214
+ with open(video_file, "rb") as video_file:
215
+ video_b64_string = base64.b64encode(
216
+ BytesIO(video_file.read()).getvalue()
217
+ ).decode("utf8")
218
+
219
+ video_mime = mimetypes.guess_type(video_file)[0]
220
+
221
+ s3_uri = _upload_video_to_s3(f"data:{video_mime};base64," + video_b64_string)
222
+
223
+ data = banana.run(
224
+ BANANA_API_KEY,
225
+ BANANA_MODEL_KEY,
226
+ {
227
+ "video": s3_uri,
228
+ "filename": os.path.basename(video_file),
229
+ },
230
+ )["modelOutputs"][0]
231
 
232
  return get_outputs(data=data)
233
 
 
240
  # data=url,
241
  # ).text
242
  # )
243
+ data = banana.run(BANANA_API_KEY, BANANA_MODEL_KEY, {"url": url},)[
244
+ "modelOutputs"
245
+ ][0]
 
 
246
 
247
  return get_outputs(data=data)
248