Ahsen Khaliq committed
Commit • 178d84c
Parent(s): d0b8090

Update app.py
app.py CHANGED

@@ -1,5 +1,6 @@
 import os
 os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT")
+os.system("gdown https://drive.google.com/uc?id=1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX")
 import clip
 import os
 from torch import nn
@@ -32,7 +33,6 @@ TA = Union[T, ARRAY]
 D = torch.device
 CPU = torch.device('cpu')

-model_path = 'conceptual_weights.pt'

 def get_device(device_id: int) -> D:
     if not torch.cuda.is_available():
@@ -234,14 +234,16 @@ prefix_length = 10

 model = ClipCaptionModel(prefix_length)

-model.load_state_dict(torch.load(model_path, map_location=CPU))
-
-model = model.eval()
-device = CUDA(0) if is_gpu else "cpu"
 model = model.to(device)


-def inference(img):
+def inference(img,model):
+    if model == "COCO":
+        model_path = 'coco_weights.pt'
+    else:
+        model_path = 'conceptual_weights.pt'
+    model.load_state_dict(torch.load(model_path, map_location=CPU))
+
     use_beam_search = False
     image = io.imread(img.name)
     pil_image = PIL.Image.fromarray(image)
@@ -262,7 +264,7 @@ article = "<p style='text-align: center'><a href='https://github.com/rmokady/CLI
 examples=[['water.jpeg']]
 gr.Interface(
     inference,
-    gr.inputs.Image(type="file", label="Input"),
+    [gr.inputs.Image(type="file", label="Input"),gr.inputs.Radio(choices=["COCO","Conceptual captions"], type="value", default="COCO", label="Model")],
     gr.outputs.Textbox(label="Output"),
     title=title,
     description=description,
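Note on the inference hunk: with gr.inputs.Radio(type="value"), Gradio passes the selected string as the second argument, so inside def inference(img,model): the name model is bound to that string and shadows the module-level ClipCaptionModel instance that model.load_state_dict is presumably meant to target. Below is a minimal, self-contained sketch of the same checkpoint-selection pattern with a non-shadowing parameter name; DemoNet, load_weights, and the locally saved checkpoint files are stand-ins for illustration, not the app's real classes or downloads.

import torch
from torch import nn

CPU = torch.device('cpu')

# Stand-in network; illustration only. The app builds
# ClipCaptionModel(prefix_length) instead.
class DemoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

# Save two checkpoints so the sketch runs end to end; the app downloads
# coco_weights.pt and conceptual_weights.pt from Google Drive via gdown.
torch.save(DemoNet().state_dict(), 'coco_weights.pt')
torch.save(DemoNet().state_dict(), 'conceptual_weights.pt')

model = DemoNet()

def load_weights(model_choice: str) -> nn.Module:
    # Same branch as the commit: map the radio value to a weights file.
    if model_choice == "COCO":
        model_path = 'coco_weights.pt'
    else:
        model_path = 'conceptual_weights.pt'
    # Load the chosen checkpoint into the shared network on CPU.
    model.load_state_dict(torch.load(model_path, map_location=CPU))
    return model.eval()

print(load_weights("COCO"))

As in the commit, torch.load runs on every request here; caching the two loaded state dicts in a dict would trade a little memory for lower per-request latency.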
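The interface hunk replaces the single Image input with a list of two inputs, and Gradio maps that list positionally onto the function's parameters. A runnable sketch of the same wiring, assuming the legacy gradio 2.x classes (gr.inputs.* / gr.outputs.*) that the diff uses; the caption function, labels, and title here are placeholders, not the app's code.

import gradio as gr

# Placeholder function; parameter order matches the inputs list below.
def caption(img, model_choice):
    return f"would caption {img.name} using the {model_choice} weights"

gr.Interface(
    caption,
    # Two inputs -> two function parameters (legacy gradio 2.x classes).
    [gr.inputs.Image(type="file", label="Input"),
     gr.inputs.Radio(choices=["COCO", "Conceptual captions"], type="value",
                     default="COCO", label="Model")],
    gr.outputs.Textbox(label="Output"),
    title="two-input demo",
).launch()

With type="file", the image arrives as a temp-file wrapper, which is why the app reads it via io.imread(img.name).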