TedYeh committed on
Commit
a38dfb6
1 Parent(s): 6d2ffd2

update predictor

Browse files
Files changed (1) hide show
  1. predictor.py +3 -3
predictor.py CHANGED
@@ -201,7 +201,7 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
201
  device = torch.device("cuda")
202
  translator= Translator(to_lang="zh-TW")
203
 
204
- model = CUPredictor().to(device)
205
  model.load_state_dict(torch.load(f'models/model_{epoch}.pt'))
206
  # load image-to-text model
207
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -218,13 +218,13 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
218
  image_tensor = trans(inp_img)
219
  image_tensor = image_tensor.unsqueeze(0)
220
  with torch.no_grad():
221
- inputs = image_tensor.to(device)
222
  outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
223
  _, preds = torch.max(outputs_c, 1)
224
  idx = preds.numpy()[0]
225
 
226
  # unconditional image captioning
227
- inputs = processor(inp_img, return_tensors="pt").to(device)
228
  out = model_blip.generate(**inputs)
229
  description = processor.decode(out[0], skip_special_tokens=True)
230
  description_tw = translator.translate(description)
 
201
  device = torch.device("cuda")
202
  translator= Translator(to_lang="zh-TW")
203
 
204
+ model = CUPredictor()
205
  model.load_state_dict(torch.load(f'models/model_{epoch}.pt'))
206
  # load image-to-text model
207
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 
218
  image_tensor = trans(inp_img)
219
  image_tensor = image_tensor.unsqueeze(0)
220
  with torch.no_grad():
221
+ inputs = image_tensor
222
  outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
223
  _, preds = torch.max(outputs_c, 1)
224
  idx = preds.numpy()[0]
225
 
226
  # unconditional image captioning
227
+ inputs = processor(inp_img, return_tensors="pt")
228
  out = model_blip.generate(**inputs)
229
  description = processor.decode(out[0], skip_special_tokens=True)
230
  description_tw = translator.translate(description)