File size: 2,800 Bytes
69fda24
 
f872c60
18cb46c
 
 
3b13f40
c6fe3c5
24d96ab
 
c6fe3c5
 
 
 
3b13f40
 
69fda24
18cb46c
 
 
69fda24
 
18cb46c
69fda24
 
 
18cb46c
69fda24
 
 
18cb46c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69fda24
f872c60
 
 
 
 
 
 
 
69fda24
 
f872c60
c6fe3c5
 
24d96ab
c6fe3c5
24d96ab
c6fe3c5
 
24d96ab
 
 
 
c6fe3c5
 
24d96ab
 
c6fe3c5
 
 
 
18cb46c
69fda24
c6fe3c5
f872c60
 
c6fe3c5
f872c60
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os

import click
from huggingface_hub import HfApi
from loguru import logger

from src import config
from src import data
from src import loss
from src import models
from src import tokenizer as tk
from src import vision_model
from src import utils
from src.lightning_module import LightningModule


def _debug_repo_id(repo_id: str) -> str:
    """Return *repo_id* with ``debug-`` prefixed to its repository name.

    Anything before the first ``/`` (the namespace) is kept as-is, so
    ``"org/model"`` becomes ``"org/debug-model"``. An id without a namespace,
    e.g. ``"model"``, becomes ``"debug-model"`` instead of raising.
    """
    namespace, sep, name = repo_id.partition("/")
    if not sep:
        # No "/" at all: the whole id is the repository name.
        return f"debug-{repo_id}"
    return f"{namespace}{sep}debug-{name}"


def _upload_model_to_hub(
    vision_encoder: models.TinyCLIPVisionEncoder,
    text_encoder: models.TinyCLIPTextEncoder,
    debug: bool = False,
) -> None:
    """Save both encoders locally and upload the model folder to the Hub.

    Each encoder is written with ``save_pretrained`` (safetensors format) to
    ``config.VISION_MODEL_PATH`` / ``config.TEXT_MODEL_PATH``. Then the whole
    ``config.MODEL_PATH`` folder is uploaded to a Hugging Face model repo.
    The repo is created first if it does not exist yet.

    Args:
        vision_encoder: Vision encoder whose weights should be published.
        text_encoder: Text encoder whose weights should be published.
        debug: When true, upload to a sibling repo whose name is prefixed
            with ``debug-`` instead of ``config.REPO_ID``.
    """
    vision_encoder.save_pretrained(
        str(config.VISION_MODEL_PATH),
        safe_serialization=True,
    )
    text_encoder.save_pretrained(
        str(config.TEXT_MODEL_PATH),
        safe_serialization=True,
    )
    # NOTE(review): the upload below sends all of MODEL_PATH; presumably the
    # two paths above live inside it - confirm in src.config.

    api = HfApi()
    repo_id = _debug_repo_id(config.REPO_ID) if debug else config.REPO_ID
    common_hf_api_params = {
        "repo_id": repo_id,
        "repo_type": "model",
    }
    if not api.repo_exists(**common_hf_api_params):
        logger.info(f"Creating repo {repo_id} on Hugging Face Hub.")
        # exist_ok guards against the repo appearing between the check and
        # this call, which would otherwise abort the upload.
        api.create_repo(**common_hf_api_params, exist_ok=True)  # type: ignore
    logger.info(f"Uploading models in {config.MODEL_PATH} to {repo_id}.")
    api.upload_folder(
        folder_path=config.MODEL_PATH,
        **common_hf_api_params,  # type: ignore
    )


@click.group()
def cli() -> None:
    # Root click group for this script; sub-commands are registered on it
    # with cli.add_command. The body is intentionally empty.
    ...


@click.command()
@click.option(
    "--trainer-config-json",
    required=False,
    default="{}",
    type=str,
    help="JSON document validated into TrainerConfig (default: '{}').",
)
def train(trainer_config_json: str):
    """Train the TinyCLIP encoders and upload them to the Hugging Face Hub."""
    # Validate the credential before training, so a missing token is reported
    # up front rather than after a potentially long fit. An empty value is
    # treated the same as an unset variable.
    if not os.environ.get("HF_TOKEN"):
        raise ValueError("Please set the HF_TOKEN environment variable.")

    trainer_config = config.TrainerConfig.model_validate_json(trainer_config_json)
    # Every model-specific setting below is read from this one sub-config.
    model_config = trainer_config._model_config

    transform = vision_model.get_vision_transform(model_config.vision_config)
    tokenizer = tk.Tokenizer(model_config.text_config)
    train_dl, valid_dl = data.get_dataset(
        transform=transform, tokenizer=tokenizer, hyper_parameters=trainer_config  # type: ignore
    )
    vision_encoder = models.TinyCLIPVisionEncoder(config=model_config.vision_config)
    text_encoder = models.TinyCLIPTextEncoder(config=model_config.text_config)

    lightning_module = LightningModule(
        vision_encoder=vision_encoder,
        text_encoder=text_encoder,
        loss_fn=loss.get_loss(model_config.loss_type),
        hyper_parameters=trainer_config,
        len_train_dl=len(train_dl),
    )

    trainer = utils.get_trainer(trainer_config)
    trainer.fit(lightning_module, train_dl, valid_dl)

    # The same encoder objects were passed into the LightningModule; presumably
    # it trains them in place, so they carry the fitted weights here - confirm.
    _upload_model_to_hub(vision_encoder, text_encoder, trainer_config.debug)


# Expose the `train` command as a sub-command of the root group.
cli.add_command(train, name=train.name)

if __name__ == "__main__":
    # Script entry point: hand control to the click group.
    cli()