Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -379,11 +379,8 @@ class StreamMultiDiffusion(nn.Module):
|
|
379 |
A single string of text prompt.
|
380 |
"""
|
381 |
question = 'Question: What are in the image? Answer:'
|
382 |
-
print(self.i2t_model.device)
|
383 |
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
384 |
-
|
385 |
-
out = self.i2t_model.generate(**inputs, max_new_tokens=77)
|
386 |
-
print(out[0].device)
|
387 |
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
388 |
return prompt
|
389 |
|
|
|
379 |
A single string of text prompt.
|
380 |
"""
|
381 |
question = 'Question: What are in the image? Answer:'
|
|
|
382 |
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
383 |
+
out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}, max_new_tokens=77)
|
|
|
|
|
384 |
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
385 |
return prompt
|
386 |
|