Samhita commited on
Commit
f1aa2f0
1 Parent(s): 3aded45

add banana

Browse files

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

Files changed (6) hide show
  1. .gitignore +1 -0
  2. app.py +141 -39
  3. data/test/greece.jpg +0 -0
  4. examples.json +2 -1
  5. requirements.in +2 -0
  6. requirements.txt +26 -15
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .python-version
app.py CHANGED
@@ -1,17 +1,26 @@
 
1
  import json
2
- import mimetypes
 
3
  import os
 
 
4
  from typing import Dict, Tuple, Union
5
 
 
 
6
  import gradio as gr
7
  import pandas as pd
8
  import plotly
9
  import plotly.express as px
10
- import requests
 
11
  from dotenv import load_dotenv
12
 
13
- from gantry_callback.gantry_util import GantryImageToTextLogger
14
- from gantry_callback.s3_util import make_unique_bucket_name
 
 
15
 
16
  load_dotenv()
17
 
@@ -19,15 +28,68 @@ URL = os.getenv("ENDPOINT")
19
  GANTRY_APP_NAME = os.getenv("GANTRY_APP_NAME")
20
  GANTRY_KEY = os.getenv("GANTRY_API_KEY")
21
  MAPBOX_TOKEN = os.getenv("MAPBOX_TOKEN")
 
 
22
 
23
  examples = json.load(open("examples.json"))
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def get_plotly_graph(
27
- latitude: float, longitude: float, location: str
28
  ) -> plotly.graph_objects.Figure:
29
- lat_long_data = [[latitude, longitude, location]]
30
- map_df = pd.DataFrame(lat_long_data, columns=["latitude", "longitude", "location"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  px.set_mapbox_access_token(MAPBOX_TOKEN)
33
  fig = px.scatter_mapbox(
@@ -35,10 +97,18 @@ def get_plotly_graph(
35
  lat="latitude",
36
  lon="longitude",
37
  hover_name="location",
38
- color_discrete_sequence=["fuchsia"],
39
- zoom=5,
40
- height=300,
 
 
 
 
 
 
 
41
  )
 
42
  fig.update_layout(mapbox_style="dark")
43
  fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
44
  return fig
@@ -49,64 +119,96 @@ def gradio_error():
49
 
50
 
51
  def get_outputs(
52
- data: Dict[str, Union[str, float, None]]
53
  ) -> Tuple[str, str, plotly.graph_objects.Figure]:
 
 
 
54
  location, latitude, longitude = (
55
- data["location"],
56
- data["latitude"],
57
- data["longitude"],
58
  )
59
  if location is None:
60
  gradio_error()
61
 
62
  return (
63
- data["location"],
64
  f"{latitude},{longitude}",
65
- get_plotly_graph(latitude=latitude, longitude=longitude, location=location),
66
  )
67
 
68
 
69
  def image_gradio(img_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  data = json.loads(
71
- requests.post(
72
- f"{URL}predict-image",
73
- files={
74
- "image": (
75
- img_file,
76
- open(img_file, "rb"),
77
- mimetypes.guess_type(img_file)[0],
78
- )
79
  },
80
- ).text
81
  )
82
 
83
  return get_outputs(data=data)
84
 
85
 
86
  def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  data = json.loads(
88
- requests.post(
89
- f"{URL}predict-video",
90
- files={
91
- "video": (
92
- video_file,
93
- open(video_file, "rb"),
94
- "application/octet-stream",
95
- )
96
  },
97
- ).text
98
  )
99
 
100
  return get_outputs(data=data)
101
 
102
 
103
  def url_gradio(url: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
 
 
 
 
 
 
 
104
  data = json.loads(
105
- requests.post(
106
- f"{URL}predict-url",
107
- headers={"content-type": "text/plain"},
108
- data=url,
109
- ).text
110
  )
111
 
112
  return get_outputs(data=data)
 
1
+ import base64
2
  import json
3
+
4
+ # import mimetypes
5
  import os
6
+ import sys
7
+ from io import BytesIO
8
  from typing import Dict, Tuple, Union
9
 
10
+ import banana_dev as banana
11
+ import geopy.distance
12
  import gradio as gr
13
  import pandas as pd
14
  import plotly
15
  import plotly.express as px
16
+
17
+ # import requests
18
  from dotenv import load_dotenv
19
 
20
+ sys.path.append("..")
21
+
22
+ from gantry_callback.gantry_util import GantryImageToTextLogger # noqa: E402
23
+ from gantry_callback.s3_util import make_unique_bucket_name # noqa: E402
24
 
25
  load_dotenv()
26
 
 
28
  GANTRY_APP_NAME = os.getenv("GANTRY_APP_NAME")
29
  GANTRY_KEY = os.getenv("GANTRY_API_KEY")
30
  MAPBOX_TOKEN = os.getenv("MAPBOX_TOKEN")
31
+ BANANA_API_KEY = os.getenv("BANANA_API_KEY")
32
+ BANANA_MODEL_KEY = os.getenv("BANANA_MODEL_KEY")
33
 
34
  examples = json.load(open("examples.json"))
35
 
36
 
37
+ def compute_distance(map_data: Dict[str, Dict[str, Union[str, float, None]]]):
38
+ hierarchy_lat, hierarchy_long = (
39
+ map_data["hierarchy"]["latitude"],
40
+ map_data["hierarchy"]["longitude"],
41
+ )
42
+
43
+ coarse_lat, coarse_long = (
44
+ map_data["coarse"]["latitude"],
45
+ map_data["coarse"]["longitude"],
46
+ )
47
+
48
+ fine_lat, fine_long = (
49
+ map_data["fine"]["latitude"],
50
+ map_data["fine"]["longitude"],
51
+ )
52
+
53
+ hierarchy_to_coarse = geopy.distance.geodesic(
54
+ (hierarchy_lat, hierarchy_long), (coarse_lat, coarse_long)
55
+ ).miles
56
+
57
+ hierarchy_to_fine = geopy.distance.geodesic(
58
+ (hierarchy_lat, hierarchy_long), (fine_lat, fine_long)
59
+ ).miles
60
+
61
+ return hierarchy_to_coarse, hierarchy_to_fine
62
+
63
+
64
  def get_plotly_graph(
65
+ map_data: Dict[str, Dict[str, Union[str, float, None]]]
66
  ) -> plotly.graph_objects.Figure:
67
+
68
+ hierarchy_to_coarse, hierarchy_to_fine = compute_distance(map_data)
69
+ what_to_consider = {"hierarchy"}
70
+ if hierarchy_to_coarse > 30:
71
+ what_to_consider.add("coarse")
72
+ if hierarchy_to_fine > 30:
73
+ what_to_consider.add("fine")
74
+
75
+ size_map = {"hierarchy": 3, "fine": 1, "coarse": 1}
76
+ lat_long_data = []
77
+ for subdivision, location_data in map_data.items():
78
+ if subdivision in what_to_consider:
79
+ lat_long_data.append(
80
+ [
81
+ subdivision,
82
+ float(location_data["latitude"]),
83
+ float(location_data["longitude"]),
84
+ location_data["location"],
85
+ size_map[subdivision],
86
+ ]
87
+ )
88
+
89
+ map_df = pd.DataFrame(
90
+ lat_long_data,
91
+ columns=["subdivision", "latitude", "longitude", "location", "size"],
92
+ )
93
 
94
  px.set_mapbox_access_token(MAPBOX_TOKEN)
95
  fig = px.scatter_mapbox(
 
97
  lat="latitude",
98
  lon="longitude",
99
  hover_name="location",
100
+ hover_data=["latitude", "longitude", "subdivision"],
101
+ color="subdivision",
102
+ color_discrete_map={
103
+ "hierarchy": "fuchsia",
104
+ "coarse": "blue",
105
+ "fine": "blue",
106
+ },
107
+ zoom=3,
108
+ height=500,
109
+ size="size",
110
  )
111
+
112
  fig.update_layout(mapbox_style="dark")
113
  fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
114
  return fig
 
119
 
120
 
121
  def get_outputs(
122
+ data: Dict[str, Dict[str, Union[str, float, None]]]
123
  ) -> Tuple[str, str, plotly.graph_objects.Figure]:
124
+ if data is None:
125
+ gradio_error()
126
+
127
  location, latitude, longitude = (
128
+ data["hierarchy"]["location"],
129
+ data["hierarchy"]["latitude"],
130
+ data["hierarchy"]["longitude"],
131
  )
132
  if location is None:
133
  gradio_error()
134
 
135
  return (
136
+ location,
137
  f"{latitude},{longitude}",
138
+ get_plotly_graph(map_data=data),
139
  )
140
 
141
 
142
  def image_gradio(img_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
143
+ # data = json.loads(
144
+ # requests.post(
145
+ # f"{URL}predict-image",
146
+ # files={
147
+ # "image": (
148
+ # img_file,
149
+ # open(img_file, "rb"),
150
+ # mimetypes.guess_type(img_file)[0],
151
+ # )
152
+ # },
153
+ # ).text
154
+ # )
155
+ with open(img_file, "rb") as image_file:
156
+ image_bytes = BytesIO(image_file.read())
157
+
158
  data = json.loads(
159
+ banana.run(
160
+ BANANA_API_KEY,
161
+ BANANA_MODEL_KEY,
162
+ {
163
+ "image": base64.b64encode(image_bytes.getvalue()).decode("utf-8"),
164
+ "filename": os.path.basename(img_file),
 
 
165
  },
166
+ )["modelOutputs"][0]
167
  )
168
 
169
  return get_outputs(data=data)
170
 
171
 
172
  def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
173
+ # data = json.loads(
174
+ # requests.post(
175
+ # f"{URL}predict-video",
176
+ # files={
177
+ # "video": (
178
+ # video_file,
179
+ # open(video_file, "rb"),
180
+ # "application/octet-stream",
181
+ # )
182
+ # },
183
+ # ).text
184
+ # )
185
+
186
  data = json.loads(
187
+ banana.run(
188
+ BANANA_API_KEY,
189
+ BANANA_MODEL_KEY,
190
+ {
191
+ "video": video_file,
192
+ "filename": os.path.basename(video_file),
 
 
193
  },
194
+ )["modelOutputs"][0]
195
  )
196
 
197
  return get_outputs(data=data)
198
 
199
 
200
  def url_gradio(url: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
201
+ # data = json.loads(
202
+ # requests.post(
203
+ # f"{URL}predict-url",
204
+ # headers={"content-type": "text/plain"},
205
+ # data=url,
206
+ # ).text
207
+ # )
208
  data = json.loads(
209
+ banana.run(BANANA_API_KEY, BANANA_MODEL_KEY, {"url": url},)[
210
+ "modelOutputs"
211
+ ][0]
 
 
212
  )
213
 
214
  return get_outputs(data=data)
data/test/greece.jpg DELETED
Binary file (103 kB)
 
examples.json CHANGED
@@ -17,6 +17,7 @@
17
  ],
18
  "video_urls": [
19
  "https://www.youtube.com/watch?v=wxeQkJTZrsw",
20
- "https://www.youtube.com/watch?v=_AznzZAlwVA"
 
21
  ]
22
  }
 
17
  ],
18
  "video_urls": [
19
  "https://www.youtube.com/watch?v=wxeQkJTZrsw",
20
+ "https://www.youtube.com/watch?v=_AznzZAlwVA",
21
+ "https://www.youtube.com/watch?v=Lk8r0Ig62zU"
22
  ]
23
  }
requirements.in CHANGED
@@ -7,3 +7,5 @@ boto3
7
  botocore
8
  gantry
9
  smart-open
 
 
 
7
  botocore
8
  gantry
9
  smart-open
10
+ banana-dev
11
+ geopy
requirements.txt CHANGED
@@ -2,7 +2,7 @@
2
  # This file is autogenerated by pip-compile with python 3.9
3
  # To update, run:
4
  #
5
- # pip-compile --output-file=gradio_requirements.txt gradio_requirements.in
6
  #
7
  aiohttp==3.8.3
8
  # via gradio
@@ -18,24 +18,27 @@ attrs==22.1.0
18
  # via aiohttp
19
  backoff==1.11.1
20
  # via gantry
 
 
21
  bcrypt==4.0.1
22
  # via paramiko
23
  boto3==1.24.89
24
  # via
25
- # -r gradio_requirements.in
26
  # boto3-extensions
27
  boto3-extensions==0.19.0
28
  # via gantry
29
  botocore==1.27.89
30
  # via
31
- # -r gradio_requirements.in
32
  # boto3
33
  # boto3-extensions
34
  # s3transfer
35
  cachetools==4.2.4
36
  # via gantry
37
- certifi==2022.9.24
38
  # via
 
39
  # httpcore
40
  # httpx
41
  # requests
@@ -43,9 +46,10 @@ cffi==1.15.1
43
  # via
44
  # cryptography
45
  # pynacl
46
- charset-normalizer==2.1.1
47
  # via
48
  # aiohttp
 
49
  # requests
50
  click==8.1.3
51
  # via
@@ -79,9 +83,13 @@ frozenlist==1.3.1
79
  fsspec==2022.8.2
80
  # via gradio
81
  gantry==0.4.13
82
- # via -r gradio_requirements.in
 
 
 
 
83
  gradio==3.4.1
84
- # via -r gradio_requirements.in
85
  h11==0.12.0
86
  # via
87
  # gradio
@@ -93,9 +101,10 @@ httpcore==0.15.0
93
  # via httpx
94
  httpx==0.23.0
95
  # via gradio
96
- idna==3.4
97
  # via
98
  # anyio
 
99
  # requests
100
  # rfc3986
101
  # yarl
@@ -152,7 +161,7 @@ packaging==21.3
152
  # matplotlib
153
  pandas==1.5.0
154
  # via
155
- # -r gradio_requirements.in
156
  # gantry
157
  # gradio
158
  paramiko==2.11.0
@@ -162,7 +171,7 @@ pillow==9.2.0
162
  # gradio
163
  # matplotlib
164
  plotly==5.10.0
165
- # via -r gradio_requirements.in
166
  pycparser==2.21
167
  # via cffi
168
  pycryptodome==3.15.0
@@ -187,7 +196,7 @@ python-dateutil==2.8.2
187
  # matplotlib
188
  # pandas
189
  python-dotenv==0.21.0
190
- # via -r gradio_requirements.in
191
  python-multipart==0.0.5
192
  # via gradio
193
  pytz==2022.4
@@ -202,9 +211,10 @@ pyyaml==6.0
202
  # gradio
203
  regex==2022.3.2
204
  # via dateparser
205
- requests==2.28.1
206
  # via
207
- # -r gradio_requirements.in
 
208
  # gantry
209
  # gradio
210
  rfc3986[idna2008]==1.5.0
@@ -219,7 +229,7 @@ six==1.16.0
219
  # python-dateutil
220
  # python-multipart
221
  smart-open==6.2.0
222
- # via -r gradio_requirements.in
223
  sniffio==1.3.0
224
  # via
225
  # anyio
@@ -250,8 +260,9 @@ tzlocal==4.2
250
  # via dateparser
251
  uc-micro-py==1.0.1
252
  # via linkify-it-py
253
- urllib3==1.26.12
254
  # via
 
255
  # botocore
256
  # requests
257
  uvicorn==0.18.3
 
2
  # This file is autogenerated by pip-compile with python 3.9
3
  # To update, run:
4
  #
5
+ # pip-compile requirements.in
6
  #
7
  aiohttp==3.8.3
8
  # via gradio
 
18
  # via aiohttp
19
  backoff==1.11.1
20
  # via gantry
21
+ banana-dev==4.0.0
22
+ # via -r requirements.in
23
  bcrypt==4.0.1
24
  # via paramiko
25
  boto3==1.24.89
26
  # via
27
+ # -r requirements.in
28
  # boto3-extensions
29
  boto3-extensions==0.19.0
30
  # via gantry
31
  botocore==1.27.89
32
  # via
33
+ # -r requirements.in
34
  # boto3
35
  # boto3-extensions
36
  # s3transfer
37
  cachetools==4.2.4
38
  # via gantry
39
+ certifi==2021.10.8
40
  # via
41
+ # banana-dev
42
  # httpcore
43
  # httpx
44
  # requests
 
46
  # via
47
  # cryptography
48
  # pynacl
49
+ charset-normalizer==2.0.7
50
  # via
51
  # aiohttp
52
+ # banana-dev
53
  # requests
54
  click==8.1.3
55
  # via
 
83
  fsspec==2022.8.2
84
  # via gradio
85
  gantry==0.4.13
86
+ # via -r requirements.in
87
+ geographiclib==1.52
88
+ # via geopy
89
+ geopy==2.2.0
90
+ # via -r requirements.in
91
  gradio==3.4.1
92
+ # via -r requirements.in
93
  h11==0.12.0
94
  # via
95
  # gradio
 
101
  # via httpx
102
  httpx==0.23.0
103
  # via gradio
104
+ idna==3.3
105
  # via
106
  # anyio
107
+ # banana-dev
108
  # requests
109
  # rfc3986
110
  # yarl
 
161
  # matplotlib
162
  pandas==1.5.0
163
  # via
164
+ # -r requirements.in
165
  # gantry
166
  # gradio
167
  paramiko==2.11.0
 
171
  # gradio
172
  # matplotlib
173
  plotly==5.10.0
174
+ # via -r requirements.in
175
  pycparser==2.21
176
  # via cffi
177
  pycryptodome==3.15.0
 
196
  # matplotlib
197
  # pandas
198
  python-dotenv==0.21.0
199
+ # via -r requirements.in
200
  python-multipart==0.0.5
201
  # via gradio
202
  pytz==2022.4
 
211
  # gradio
212
  regex==2022.3.2
213
  # via dateparser
214
+ requests==2.26.0
215
  # via
216
+ # -r requirements.in
217
+ # banana-dev
218
  # gantry
219
  # gradio
220
  rfc3986[idna2008]==1.5.0
 
229
  # python-dateutil
230
  # python-multipart
231
  smart-open==6.2.0
232
+ # via -r requirements.in
233
  sniffio==1.3.0
234
  # via
235
  # anyio
 
260
  # via dateparser
261
  uc-micro-py==1.0.1
262
  # via linkify-it-py
263
+ urllib3==1.26.7
264
  # via
265
+ # banana-dev
266
  # botocore
267
  # requests
268
  uvicorn==0.18.3