mx262 commited on
Commit
0624b31
1 Parent(s): 7f89b74

Update modeling_minimonkey_chat.py

Browse files
Files changed (1) hide show
  1. modeling_minimonkey_chat.py +2 -2
modeling_minimonkey_chat.py CHANGED
@@ -280,8 +280,8 @@ class MiniMonkeyChatModel(PreTrainedModel):
280
  query = query.replace('<image>', image_tokens, 1)
281
 
282
  model_inputs = tokenizer(query, return_tensors='pt')
283
- input_ids = model_inputs['input_ids'].cuda()
284
- attention_mask = model_inputs['attention_mask'].cuda()
285
  generation_config['eos_token_id'] = eos_token_id
286
  generation_output = self.generate(
287
  pixel_values=pixel_values,
 
280
  query = query.replace('<image>', image_tokens, 1)
281
 
282
  model_inputs = tokenizer(query, return_tensors='pt')
283
+ input_ids = model_inputs['input_ids'].to(self.device)
284
+ attention_mask = model_inputs['attention_mask'].to(self.device)
285
  generation_config['eos_token_id'] = eos_token_id
286
  generation_output = self.generate(
287
  pixel_values=pixel_values,