File size: 2,800 Bytes
69fda24 f872c60 18cb46c 3b13f40 c6fe3c5 24d96ab c6fe3c5 3b13f40 69fda24 18cb46c 69fda24 18cb46c 69fda24 18cb46c 69fda24 18cb46c 69fda24 f872c60 69fda24 f872c60 c6fe3c5 24d96ab c6fe3c5 24d96ab c6fe3c5 24d96ab c6fe3c5 24d96ab c6fe3c5 18cb46c 69fda24 c6fe3c5 f872c60 c6fe3c5 f872c60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import os
import click
from huggingface_hub import HfApi
from loguru import logger
from src import config
from src import data
from src import loss
from src import models
from src import tokenizer as tk
from src import vision_model
from src import utils
from src.lightning_module import LightningModule
def _debug_repo_id(repo_id: str) -> str:
    """Return ``repo_id`` with ``debug-`` prefixed to its repository name.

    ``"namespace/name"`` becomes ``"namespace/debug-name"`` (only the first
    ``/`` is treated as the separator). An id without any ``/`` becomes
    ``"debug-<id>"``.
    """
    namespace, slash, name = repo_id.partition("/")
    if not slash:
        # A bare name has no second component to prefix, so prefix the whole
        # id instead of failing on a missing namespace separator.
        return f"debug-{namespace}"
    return f"{namespace}/debug-{name}"


def _upload_model_to_hub(
    vision_encoder: models.TinyCLIPVisionEncoder,
    text_encoder: models.TinyCLIPTextEncoder,
    debug: bool = False,
):
    """Save both encoders locally and upload the model folder to the Hub.

    Args:
        vision_encoder: Encoder saved (safetensors serialization) to
            ``config.VISION_MODEL_PATH``.
        text_encoder: Encoder saved (safetensors serialization) to
            ``config.TEXT_MODEL_PATH``.
        debug: When true, upload to a ``debug-``-prefixed variant of
            ``config.REPO_ID`` instead of ``config.REPO_ID`` itself.

    The target model repo is created first if it does not exist yet, then the
    contents of ``config.MODEL_PATH`` are uploaded to it.
    NOTE(review): this presumably relies on both encoder paths living under
    ``config.MODEL_PATH`` -- confirm against ``src.config``.
    """
    vision_encoder.save_pretrained(
        str(config.VISION_MODEL_PATH),
        safe_serialization=True,
    )
    text_encoder.save_pretrained(
        str(config.TEXT_MODEL_PATH),
        safe_serialization=True,
    )

    api = HfApi()

    repo_id = _debug_repo_id(config.REPO_ID) if debug else config.REPO_ID

    # Shared by every Hub call below so they all target the same repo.
    common_hf_api_params = {
        "repo_id": repo_id,
        "repo_type": "model",
    }

    if not api.repo_exists(**common_hf_api_params):
        logger.info(f"Creating repo {repo_id} on Hugging Face Hub.")
        api.create_repo(**common_hf_api_params)  # type: ignore

    logger.info(f"Uploading models in {str(config.MODEL_PATH)} to {repo_id}.")
    api.upload_folder(
        folder_path=config.MODEL_PATH,
        **common_hf_api_params,  # type: ignore
    )  # type: ignore
# Root click group for this script. Subcommands are attached to it with
# `cli.add_command`. Documented with a comment rather than a docstring because
# click shows a docstring as the group's --help text.
@click.group()
def cli() -> None:
    return None
# `train` command: validate the JSON trainer config, build the data loaders and
# both encoders, fit them through the Lightning module, then upload the result.
# Documented with comments rather than a docstring because click shows a
# command's docstring as its --help text.
@click.command()
@click.option("--trainer-config-json", required=False, default="{}", type=str)
def train(trainer_config_json: str):
    # Fail before any training work starts: the last step of this command
    # uploads to the Hugging Face Hub.
    if os.environ.get("HF_TOKEN") is None:
        raise ValueError("Please set the HF_TOKEN environment variable.")

    hparams = config.TrainerConfig.model_validate_json(trainer_config_json)
    model_cfg = hparams._model_config

    # Input pipeline: image transform and tokenizer feed the dataset builder.
    image_transform = vision_model.get_vision_transform(model_cfg.vision_config)
    text_tokenizer = tk.Tokenizer(model_cfg.text_config)
    loaders = data.get_dataset(
        transform=image_transform,
        tokenizer=text_tokenizer,
        hyper_parameters=hparams,  # type: ignore
    )
    train_dl, valid_dl = loaders

    # The two encoders are kept in local names because they are handed to the
    # upload step after fitting.
    vision_encoder = models.TinyCLIPVisionEncoder(config=model_cfg.vision_config)
    text_encoder = models.TinyCLIPTextEncoder(config=model_cfg.text_config)

    loss_fn = loss.get_loss(model_cfg.loss_type)
    module = LightningModule(
        vision_encoder=vision_encoder,
        text_encoder=text_encoder,
        loss_fn=loss_fn,
        hyper_parameters=hparams,
        len_train_dl=len(train_dl),
    )

    utils.get_trainer(hparams).fit(module, train_dl, valid_dl)

    _upload_model_to_hub(vision_encoder, text_encoder, hparams.debug)
# `train` is declared with the standalone `click.command()` decorator, so it
# has to be attached to the `cli` group explicitly.
cli.add_command(train)

# Dispatch to the click group only when this file is executed as a script.
if __name__ == "__main__":
    cli()
|