TedYeh committed
Commit a38dfb6
Parent(s): 6d2ffd2
update predictor
Files changed: predictor.py (+3, -3)
predictor.py CHANGED

@@ -201,7 +201,7 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
     device = torch.device("cuda")
     translator= Translator(to_lang="zh-TW")
 
-    model = CUPredictor()
+    model = CUPredictor()
     model.load_state_dict(torch.load(f'models/model_{epoch}.pt'))
     # load image-to-text model
     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -218,13 +218,13 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
     image_tensor = trans(inp_img)
     image_tensor = image_tensor.unsqueeze(0)
     with torch.no_grad():
-        inputs = image_tensor
+        inputs = image_tensor
         outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
         _, preds = torch.max(outputs_c, 1)
         idx = preds.numpy()[0]
 
     # unconditional image captioning
-    inputs = processor(inp_img, return_tensors="pt")
+    inputs = processor(inp_img, return_tensors="pt")
     out = model_blip.generate(**inputs)
     description = processor.decode(out[0], skip_special_tokens=True)
     description_tw = translator.translate(description)
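Both hunks touch the captioning half of inference(), where a BLIP caption is generated for inp_img, decoded, and translated to Traditional Chinese. Below is a minimal, self-contained sketch of that path only, assuming the same Salesforce BLIP base checkpoint and the translate package's Translator seen in the diff; CUPredictor and the models/model_{epoch}.pt checkpoint are project-specific and are not reproduced, and example.jpg is a hypothetical input file.

from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from translate import Translator

# Load the captioning model and processor used in predictor.py (lines 207 onward).
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
translator = Translator(to_lang="zh-TW")

# Hypothetical input image standing in for inp_img.
inp_img = Image.open("example.jpg").convert("RGB")

# Unconditional image captioning, mirroring lines 227-230 of the diff.
inputs = processor(inp_img, return_tensors="pt")
out = model_blip.generate(**inputs)
description = processor.decode(out[0], skip_special_tokens=True)

# Translate the English caption to Traditional Chinese.
description_tw = translator.translate(description)
print(description, "->", description_tw)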