successful local run
- .vscode/settings.json +9 -9
- src/config.py +10 -0
- src/data.py +1 -1
- src/lightning_module.py +5 -8
- src/models.py +2 -4
- src/trainer.py +21 -12
- src/utils.py +31 -0
.vscode/settings.json
CHANGED
@@ -2,7 +2,7 @@
     "files.insertFinalNewline": true,
     "jupyter.debugJustMyCode": false,
     "editor.formatOnSave": true,
-    "editor.formatOnPaste": true,
+    // "editor.formatOnPaste": true,
     "files.autoSave": "onFocusChange",
     "editor.defaultFormatter": "ms-python.black-formatter",
     "black-formatter.path": ["/opt/homebrew/bin/black"],
@@ -12,12 +12,12 @@
     "isort.check": true,
     "python.analysis.typeCheckingMode": "basic",
     "python.defaultInterpreterPath": "/opt/homebrew/bin/python3",
-    "[python]": {
-        "editor.defaultFormatter": "ms-python.black-formatter",
-        "editor.formatOnSave": true,
-        "editor.codeActionsOnSave": {
-            "source.organizeImports": "explicit"
-        },
-    },
-    "isort.args":["--profile", "black"],
+    // "[python]": {
+    //     "editor.defaultFormatter": "ms-python.black-formatter",
+    //     "editor.formatOnSave": true,
+    //     "editor.codeActionsOnSave": {
+    //         "source.organizeImports": "explicit"
+    //     },
+    // },
+    // "isort.args":["--profile", "black"],
 }
src/config.py
CHANGED
@@ -6,6 +6,14 @@ from transformers import PretrainedConfig
 MAX_DOWNLOAD_TIME = 0.2
 
 IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
+WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
+
+IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
+WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
+
+MODEL_NAME = "tiny_clip"
+
+WANDB_ENTITY = "sachinruk"
 
 
 class DataConfig(pydantic.BaseModel):
@@ -97,6 +105,8 @@ class TrainerConfig(pydantic.BaseModel):
     lambda_2: float = 1.0
 
     val_check_interval: int = 1000
+    log_every_n_steps: int = 100
+    debug: bool = False
 
     run_openai_clip: bool = False
 
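For orientation, a minimal sketch of how the new fields sit in the pydantic model; lambda_2 and val_check_interval come from the hunk above, everything else about the real TrainerConfig is elided, and the printed values are just the new defaults:

import pathlib

import pydantic

WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)  # side effect at import time, as in the diff


class TrainerConfig(pydantic.BaseModel):
    # ... other hyper-parameters from the real class elided ...
    lambda_2: float = 1.0
    val_check_interval: int = 1000
    log_every_n_steps: int = 100  # new: forwarded to pl.Trainer in src/utils.py
    debug: bool = False  # new: shrinks the run to a quick smoke test


cfg = TrainerConfig(debug=True)
print(cfg.log_every_n_steps, cfg.debug)  # 100 True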
src/data.py
CHANGED
@@ -37,7 +37,7 @@ class CollateFn:
         tokenized_text = self.tokenizer([item["caption"] for item in batch])
 
         return {
-            "
+            "images": stacked_images,
             **tokenized_text,
         }
 
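The one-line fix restores the "images" key that the broken old line dropped. For context, a sketch of what CollateFn plausibly looks like around this hunk; the constructor, the stacking line, and the item keys are assumptions, only the return dict is from the diff:

import torch


class CollateFn:
    def __init__(self, tokenizer, transform):
        self.tokenizer = tokenizer  # assumed: returns a dict of tensors (input_ids, attention_mask, ...)
        self.transform = transform  # assumed: PIL image -> tensor

    def __call__(self, batch: list[dict]) -> dict[str, torch.Tensor]:
        stacked_images = torch.stack([self.transform(item["image"]) for item in batch])
        tokenized_text = self.tokenizer([item["caption"] for item in batch])
        return {
            "images": stacked_images,  # the key common_step reads in src/lightning_module.py
            **tokenized_text,
        }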
src/lightning_module.py
CHANGED
@@ -24,10 +24,11 @@ class LightningModule(pl.LightningModule):
         self.hyper_parameters = hyper_parameters
         self.len_train_dl = len_train_dl
 
-    def common_step(self, batch:
-
-
-
+    def common_step(self, batch: dict[str, torch.Tensor], step_kind: str) -> torch.Tensor:
+        image_features = self.vision_encoder(batch["images"])
+        text_features = self.text_encoder(
+            {key: value for key, value in batch.items() if key != "images"}
+        )
         similarity_matrix = loss_utils.get_similarity_matrix(image_features, text_features)
 
         loss = self.loss_fn(similarity_matrix, image_features, text_features)
@@ -52,10 +53,6 @@ class LightningModule(pl.LightningModule):
                 "params": self.vision_encoder.projection.parameters(),
                 "lr": self.hyper_parameters.learning_rate,
             },
-            {
-                "params": self.vision_encoder.base.parameters(),
-                "lr": self.hyper_parameters.learning_rate / 2,
-            },
         ]
         caption_params = [
             {
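Two things happen here: common_step now routes batch["images"] to the vision encoder and everything else (the tokenizer output) to the text encoder, and the second hunk drops the vision_encoder.base parameter group from the optimizer, so, assuming those parameters are not registered elsewhere, the vision backbone is effectively frozen. A toy illustration of the batch split, with made-up shapes:

import torch

batch = {
    "images": torch.randn(4, 3, 224, 224),
    "input_ids": torch.randint(0, 1000, (4, 16)),
    "attention_mask": torch.ones(4, 16, dtype=torch.long),
}
images = batch["images"]  # -> vision encoder
text_inputs = {key: value for key, value in batch.items() if key != "images"}  # -> text encoder
assert set(text_inputs) == {"input_ids", "attention_mask"}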
src/models.py
CHANGED
@@ -77,10 +77,8 @@ class TinyCLIPVisionEncoder(PreTrainedModel):
             num_features, config.embed_dims, config.projection_layers
         )
 
-    def forward(self, images:
-
-
-        projected_vec = self.projection(self.base(x))
+    def forward(self, images: torch.Tensor):
+        projected_vec = self.projection(self.base(images))
         return F.normalize(projected_vec, dim=-1)
 
 
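The old body called self.base(x) while the (truncated) signature took images, a NameError waiting to happen; the fix names the argument consistently and drops the dead lines. A self-contained sketch of the fixed forward pass, with stand-ins for the real backbone and projection head (both assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F


class VisionEncoderSketch(nn.Module):
    def __init__(self, embed_dims: int = 128):
        super().__init__()
        self.base = nn.Flatten()  # stand-in for the real vision backbone
        self.projection = nn.Linear(3 * 32 * 32, embed_dims)  # stand-in projection head

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        projected_vec = self.projection(self.base(images))
        return F.normalize(projected_vec, dim=-1)  # unit-norm embeddings for the similarity matrix


emb = VisionEncoderSketch()(torch.randn(2, 3, 32, 32))
print(emb.shape)  # torch.Size([2, 128]), each row with L2 norm 1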
src/trainer.py
CHANGED
@@ -1,25 +1,34 @@
-from src import data
 from src import config
-from src import
-from src import tokenizer as tk
-from src.lightning_module import LightningModule
+from src import data
 from src import loss
 from src import models
+from src import tokenizer as tk
+from src import vision_model
+from src import utils
+from src.lightning_module import LightningModule
 
 
-def train(
-    transform = vision_model.get_vision_transform(
-    tokenizer = tk.Tokenizer(
+def train(trainer_config: config.TrainerConfig):
+    transform = vision_model.get_vision_transform(trainer_config._model_config.vision_config)
+    tokenizer = tk.Tokenizer(trainer_config._model_config.text_config)
     train_dl, valid_dl = data.get_dataset(
-        transform=transform, tokenizer=tokenizer, hyper_parameters=
+        transform=transform, tokenizer=tokenizer, hyper_parameters=trainer_config  # type: ignore
     )
-    vision_encoder = models.TinyCLIPVisionEncoder(config=
-    text_encoder = models.TinyCLIPTextEncoder(config=
+    vision_encoder = models.TinyCLIPVisionEncoder(config=trainer_config._model_config.vision_config)
+    text_encoder = models.TinyCLIPTextEncoder(config=trainer_config._model_config.text_config)
 
     lightning_module = LightningModule(
         vision_encoder=vision_encoder,
         text_encoder=text_encoder,
-        loss_fn=loss.get_loss(
-        hyper_parameters=
+        loss_fn=loss.get_loss(trainer_config._model_config.loss_type),
+        hyper_parameters=trainer_config,
        len_train_dl=len(train_dl),
     )
+
+    trainer = utils.get_trainer(trainer_config)
+    trainer.fit(lightning_module, train_dl, valid_dl)
+
+
+if __name__ == "__main__":
+    trainer_config = config.TrainerConfig(debug=True)
+    train(trainer_config)
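With the new __main__ block, the local run in the commit title is presumably just python -m src.trainer. The equivalent from a REPL or notebook, assuming the repo root is on sys.path:

from src import config, trainer

# debug=True caps the run at 1 epoch and 5 train/val batches (see src/utils.py below)
trainer.train(config.TrainerConfig(debug=True))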
src/utils.py
ADDED
@@ -0,0 +1,31 @@
+import datetime
+
+import pytorch_lightning as pl
+from pytorch_lightning import loggers
+
+from src import config
+
+
+def _get_wandb_logger(trainer_config: config.TrainerConfig):
+    name = f"{config.MODEL_NAME}-{datetime.datetime.now()}"
+    if trainer_config.debug:
+        name = "debug-" + name
+    return loggers.WandbLogger(
+        entity=config.WANDB_ENTITY,
+        save_dir=config.WANDB_LOG_PATH,
+        project=config.MODEL_NAME,
+        name=name,
+        config=trainer_config._model_config.to_dict(),
+    )
+
+
+def get_trainer(trainer_config: config.TrainerConfig):
+    return pl.Trainer(
+        max_epochs=trainer_config.epochs if not trainer_config.debug else 1,
+        logger=_get_wandb_logger(trainer_config),
+        log_every_n_steps=trainer_config.log_every_n_steps,
+        gradient_clip_val=1.0,
+        limit_train_batches=5 if trainer_config.debug else 1.0,
+        limit_val_batches=5 if trainer_config.debug else 1.0,
+        accelerator="auto",
+    )
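What the debug flag buys, collected in one place; note that for limit_train_batches / limit_val_batches an int means a batch count while a float means a fraction of the dataloader. A hedged sketch of the effective arguments (the epochs field and its default of 20 here are assumptions):

def effective_trainer_args(debug: bool, epochs: int = 20, log_every_n_steps: int = 100) -> dict:
    # mirrors the conditionals inside get_trainer above
    return {
        "max_epochs": epochs if not debug else 1,
        "limit_train_batches": 5 if debug else 1.0,  # int: batch count, float: fraction
        "limit_val_batches": 5 if debug else 1.0,
        "log_every_n_steps": log_every_n_steps,
    }


print(effective_trainer_args(debug=True))   # {'max_epochs': 1, 'limit_train_batches': 5, ...}
print(effective_trainer_args(debug=False))  # full run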