Spaces: Running on Zero
Update modeling_llava_qwen2.py
modeling_llava_qwen2.py  CHANGED  (+5 -5)
@@ -535,13 +535,13 @@ class SigLipVisionTower(nn.Module):
         if type(images) is list:
             image_features = []
             for image in images:
-                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
+                image_forward_out = self.vision_tower(image.to(device="cuda:0", dtype=self.dtype).unsqueeze(0),
                                                       output_hidden_states=True)
                 image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                 assert image_features.shape[-2] == 729
                 image_features.append(image_feature)
         else:
-            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
+            image_forward_outs = self.vision_tower(images.to(device="cuda:0", dtype=self.dtype),
                                                    output_hidden_states=True)
             image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
             assert image_features.shape[-2] == 729
@@ -550,7 +550,7 @@ class SigLipVisionTower(nn.Module):
 
     @property
     def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+        return torch.zeros(1, self.hidden_size, device="cuda:0", dtype=self.dtype)
 
     @property
     def dtype(self):
@@ -682,9 +682,9 @@ class LlavaMetaForCausalLM(ABC):
             image_features = self.encode_images(concat_images)
             split_sizes = [image.shape[0] for image in images]
             image_features = torch.split(image_features, split_sizes, dim=0)
-            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
+            image_features = [x.flatten(0, 1).to("cuda:0") for x in image_features]
         else:
-            image_features = self.encode_images(images).to(self.device)
+            image_features = self.encode_images(images).to("cuda:0")
 
         # Let's just add dummy tensors if they do not exist,
         # it is a headache to deal with None all the time.