sachin commited on
Commit
3b13f40
1 Parent(s): bd0d978

Refactoring models

Browse files
Files changed (3) hide show
  1. src/models.py +1 -3
  2. src/trainer.py +7 -0
  3. src/vision_model.py +11 -0
src/models.py CHANGED
@@ -70,9 +70,7 @@ class TinyCLIPVisionEncoder(PreTrainedModel):
70
 
71
  def __init__(self, config: TinyCLIPVisionConfig):
72
  super().__init__(config)
73
- self.base = timm.create_model(config.vision_model, num_classes=0)
74
- timm_config = data.resolve_data_config({}, model=self.base)
75
- self.transform = data.transforms_factory.create_transform(**timm_config)
76
  self.projection = projection_layers(
77
  self.base.num_features, config.embed_dims, config.projection_layers
78
  )
 
70
 
71
  def __init__(self, config: TinyCLIPVisionConfig):
72
  super().__init__(config)
73
+
 
 
74
  self.projection = projection_layers(
75
  self.base.num_features, config.embed_dims, config.projection_layers
76
  )
src/trainer.py CHANGED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from src import data
2
+ from src import config
3
+ from src import vision_model
4
+
5
+
6
+ def train(config: config.TrainerConfig):
7
+ train_dl, valid_dl = data.get_dataset()
src/vision_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ from timm import data
3
+
4
+ from src import config
5
+
6
+
7
+ def get_vision_base_and_transform(config: config.TrainerConfig):
8
+ base = timm.create_model(config._model_config.vision_model, num_classes=0)
9
+ timm_config = data.resolve_data_config({}, model=base)
10
+ transform = data.transforms_factory.create_transform(**timm_config)
11
+ return base, transform