Upload model
Browse files- config.json +2 -2
- configuration_mambavision.py +0 -2
- model.safetensors +1 -1
- modeling_mambavision.py +8 -6
config.json
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
{
|
2 |
"architectures": [
|
3 |
-
"
|
4 |
],
|
5 |
"auto_map": {
|
6 |
"AutoConfig": "configuration_mambavision.MambaVisionConfig",
|
7 |
-
"
|
8 |
},
|
9 |
"depths": [
|
10 |
1,
|
|
|
1 |
{
|
2 |
"architectures": [
|
3 |
+
"MambaVisionModel"
|
4 |
],
|
5 |
"auto_map": {
|
6 |
"AutoConfig": "configuration_mambavision.MambaVisionConfig",
|
7 |
+
"AutoModel": "modeling_mambavision.MambaVisionModel"
|
8 |
},
|
9 |
"depths": [
|
10 |
1,
|
configuration_mambavision.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1 |
from transformers import PretrainedConfig
|
2 |
-
from typing import List
|
3 |
-
|
4 |
|
5 |
class MambaVisionConfig(PretrainedConfig):
|
6 |
model_type = "mambavision"
|
|
|
1 |
from transformers import PretrainedConfig
|
|
|
|
|
2 |
|
3 |
class MambaVisionConfig(PretrainedConfig):
|
4 |
model_type = "mambavision"
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 127219000
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4f6987be0a2ca2222f386eb750d028a05203b047d3c8dfb664c27e2295d02fc0
|
3 |
size 127219000
|
modeling_mambavision.py
CHANGED
@@ -28,7 +28,7 @@ from einops import rearrange, repeat
|
|
28 |
|
29 |
from transformers import PreTrainedModel
|
30 |
|
31 |
-
from
|
32 |
|
33 |
|
34 |
def _cfg(url='', **kwargs):
|
@@ -602,8 +602,8 @@ class MambaVisionLayer(nn.Module):
|
|
602 |
if pad_r > 0 or pad_b > 0:
|
603 |
x = x[:, :, :H, :W].contiguous()
|
604 |
if self.downsample is None:
|
605 |
-
return x
|
606 |
-
return self.downsample(x)
|
607 |
|
608 |
|
609 |
class MambaVision(nn.Module):
|
@@ -697,15 +697,17 @@ class MambaVision(nn.Module):
|
|
697 |
|
698 |
def forward_features(self, x):
|
699 |
x = self.patch_embed(x)
|
|
|
700 |
for level in self.levels:
|
701 |
-
x = level(x)
|
|
|
702 |
x = self.norm(x)
|
703 |
x = self.avgpool(x)
|
704 |
x = torch.flatten(x, 1)
|
705 |
-
return x
|
706 |
|
707 |
def forward(self, x):
|
708 |
-
x = self.forward_features(x)
|
709 |
x = self.head(x)
|
710 |
return x
|
711 |
|
|
|
28 |
|
29 |
from transformers import PreTrainedModel
|
30 |
|
31 |
+
from configuration_mambavision import MambaVisionConfig
|
32 |
|
33 |
|
34 |
def _cfg(url='', **kwargs):
|
|
|
602 |
if pad_r > 0 or pad_b > 0:
|
603 |
x = x[:, :, :H, :W].contiguous()
|
604 |
if self.downsample is None:
|
605 |
+
return x, x
|
606 |
+
return self.downsample(x), x
|
607 |
|
608 |
|
609 |
class MambaVision(nn.Module):
|
|
|
697 |
|
698 |
def forward_features(self, x):
|
699 |
x = self.patch_embed(x)
|
700 |
+
outs = []
|
701 |
for level in self.levels:
|
702 |
+
x, xo = level(x)
|
703 |
+
outs.append(xo)
|
704 |
x = self.norm(x)
|
705 |
x = self.avgpool(x)
|
706 |
x = torch.flatten(x, 1)
|
707 |
+
return x, outs
|
708 |
|
709 |
def forward(self, x):
|
710 |
+
x, outs = self.forward_features(x)
|
711 |
x = self.head(x)
|
712 |
return x
|
713 |
|