Update modeling_prismatic.py to account for the case where `input_ids` is `None`

#5
by eliotj - opened

`input_ids` and the input embeds are both marked `Optional[torch.LongTensor] = None`. However, calling the `forward()` method without `input_ids` raises an error in the first block, because the code checks `input_ids.shape[1] == 1` without first checking that `input_ids` is not `None`.

This pull request updates the logic in the Generation with Cache and Multimodal Forward branches so that both handle the case where `input_ids` is `None`.
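For illustration, here is a minimal sketch of the kind of guard described above. It is not the actual diff from this pull request: the helper name `is_single_token_step` and the `FakeTensor` stand-in are hypothetical, and the stand-in only mimics the `.shape` attribute of a `torch.Tensor` so that the snippet runs without PyTorch installed.

```python
from typing import Any, Optional, Tuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; exposes only `.shape`."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


def is_single_token_step(input_ids: Optional[Any]) -> bool:
    """Return True only when input_ids is present and holds one token per sequence.

    The `is not None` test runs first, so `.shape` is never read on None.
    """
    return input_ids is not None and input_ids.shape[1] == 1


if __name__ == "__main__":
    # Unguarded access fails when input_ids is omitted.
    try:
        None.shape[1]  # type: ignore[attr-defined]
    except AttributeError as err:
        print("unguarded check fails:", err)

    # The guarded check handles all three cases without raising.
    print(is_single_token_step(None))
    print(is_single_token_step(FakeTensor((1, 1))))
    print(is_single_token_step(FakeTensor((1, 7))))
```

Because `and` short-circuits, the shape comparison is skipped whenever `input_ids` is `None`, so a caller that passes only input embeds would fall through to the later branches instead of raising an `AttributeError`.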

Ready to merge
This branch is ready to get merged automatically.
