Update modeling_codegen.py

#6
Files changed (1) hide show
  1. modeling_codegen.py +1 -0
modeling_codegen.py CHANGED
@@ -713,6 +713,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
713
 
714
  loss = None
715
  if labels is not None:
 
716
  # Shift so that tokens < n predict n
717
  shift_logits = lm_logits[..., :-1, :].contiguous()
718
  shift_labels = labels[..., 1:].contiguous()
 
713
 
714
  loss = None
715
  if labels is not None:
716
+ labels = labels.to(lm_logits.device)
717
  # Shift so that tokens < n predict n
718
  shift_logits = lm_logits[..., :-1, :].contiguous()
719
  shift_labels = labels[..., 1:].contiguous()