File size: 4,751 Bytes
47af768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import logging
import os
from urllib.parse import urlparse

try:
    import comet_ml
except (ModuleNotFoundError, ImportError):
    comet_ml = None

import yaml

# Module-level logger used for the Comet info/warning/error messages below.
logger = logging.getLogger(__name__)

# Prefix that marks `opt.weights` / `opt.resume` as a Comet resource URL.
COMET_PREFIX = 'comet://'
# Model name passed to `experiment.get_model_asset_list` when looking up
# checkpoints; overridable via the COMET_MODEL_NAME environment variable.
COMET_MODEL_NAME = os.getenv('COMET_MODEL_NAME', 'yolov5')
# Checkpoint file name used when the Comet URL has no query component;
# overridable via the COMET_DEFAULT_CHECKPOINT_FILENAME environment variable.
COMET_DEFAULT_CHECKPOINT_FILENAME = os.getenv('COMET_DEFAULT_CHECKPOINT_FILENAME', 'last.pt')


def download_model_checkpoint(opt, experiment):
    """Download a logged model checkpoint from Comet and point ``opt.weights`` at it.

    The checkpoint file name is taken from the query component of the
    ``opt.weights`` URL (``comet://<experiment path>?<file name>``). When the
    URL has no query, ``COMET_DEFAULT_CHECKPOINT_FILENAME`` is used instead.
    The checkpoint is written to ``{opt.project}/{experiment.name}/`` and
    ``opt.weights`` is replaced by that local path.

    When several logged assets share the same file name, the one with the
    highest ``step`` is chosen. Every failure (no assets, unknown file name,
    download/write error) is logged and leaves ``opt.weights`` unchanged.

    Args:
        opt (argparse.Namespace): Namespace of command line options; reads
            ``opt.project`` and ``opt.weights`` and may overwrite
            ``opt.weights``.
        experiment (comet_ml.APIExperiment): Comet API Experiment object
            that the checkpoint was logged to.

    Returns:
        None
    """
    model_dir = f'{opt.project}/{experiment.name}'
    os.makedirs(model_dir, exist_ok=True)

    model_name = COMET_MODEL_NAME
    model_asset_list = experiment.get_model_asset_list(model_name)

    if not model_asset_list:
        logger.error(f'COMET ERROR: No checkpoints found for model name : {model_name}')
        return

    def _step_of(asset):
        # Assets logged without a step carry None, which cannot be compared
        # with integers; rank them as older than any real step.
        step = asset.get('step')
        return -1 if step is None else step

    newest_first = sorted(model_asset_list, key=_step_of, reverse=True)

    # Keep the FIRST asset seen per file name so that the newest checkpoint
    # wins. A plain dict comprehension would let older entries overwrite it.
    logged_checkpoint_map = {}
    for asset in newest_first:
        logged_checkpoint_map.setdefault(asset['fileName'], asset['assetId'])

    # The requested file name, if any, travels in the URL's query component.
    checkpoint_filename = urlparse(opt.weights).query or COMET_DEFAULT_CHECKPOINT_FILENAME
    asset_id = logged_checkpoint_map.get(checkpoint_filename)

    if asset_id is None:
        logger.error(f'COMET ERROR: Checkpoint {checkpoint_filename} not found in the given Experiment')
        return

    # Best effort: a failed download is reported but must not abort the caller.
    try:
        logger.info(f'COMET INFO: Downloading checkpoint {checkpoint_filename}')

        model_binary = experiment.get_asset(asset_id, return_type='binary', stream=False)
        model_download_path = f'{model_dir}/{checkpoint_filename}'
        with open(model_download_path, 'wb') as f:
            f.write(model_binary)

        opt.weights = model_download_path

    except Exception as e:
        logger.warning('COMET WARNING: Unable to download checkpoint from Comet')
        logger.exception(e)


def set_opt_parameters(opt, experiment):
    """Restore logged run options onto ``opt`` when resuming from Comet.

    Each experiment asset named ``opt.yaml`` is fetched, parsed with
    ``yaml.safe_load`` and its key/value pairs are set as attributes on
    ``opt``; the ``opt.resume`` value held on entry is put back afterwards.
    Finally ``opt.hyp`` is dumped to ``{opt.project}/{experiment.name}/hyp.yaml``
    and ``opt.hyp`` is replaced by the path of that file.

    Args:
        opt (argparse.Namespace): Namespace of command line options
        experiment (comet_ml.APIExperiment): Comet API Experiment object
    """
    logged_assets = experiment.get_asset_list()
    original_resume = opt.resume

    for entry in logged_assets:
        if entry['fileName'] != 'opt.yaml':
            continue

        raw_yaml = experiment.get_asset(entry['assetId'], return_type='binary', stream=False)
        restored = yaml.safe_load(raw_yaml)
        for name, setting in restored.items():
            setattr(opt, name, setting)
        # The logged options would otherwise clobber the resume request itself.
        opt.resume = original_resume

    # The training script expects the hyperparameters as a YAML file on disk,
    # so write them out and hand back the file path.
    out_dir = f'{opt.project}/{experiment.name}'
    os.makedirs(out_dir, exist_ok=True)

    hyp_file = f'{out_dir}/hyp.yaml'
    with open(hyp_file, 'w') as handle:
        yaml.dump(opt.hyp, handle)
    opt.hyp = hyp_file


def check_comet_weights(opt):
    """Fetch model weights from Comet when ``opt.weights`` is a Comet URL.

    If ``comet_ml`` is importable and ``opt.weights`` is a string starting
    with ``COMET_PREFIX``, the experiment named by the URL's netloc and path
    is looked up through ``comet_ml.API`` and ``download_model_checkpoint``
    is invoked for it.

    Args:
        opt (argparse.Namespace): Command Line arguments passed
            to YOLOv5 training script

    Returns:
        None/bool: True when a Comet download was attempted for
            ``opt.weights``; None otherwise
    """
    if comet_ml is None:
        return None

    weights = opt.weights
    is_comet_url = isinstance(weights, str) and weights.startswith(COMET_PREFIX)
    if not is_comet_url:
        return None

    client = comet_ml.API()
    parsed = urlparse(weights)
    experiment = client.get(f'{parsed.netloc}{parsed.path}')
    download_model_checkpoint(opt, experiment)
    return True


def check_comet_resume(opt):
    """Restore a run from Comet when ``opt.resume`` is a Comet URL.

    If ``comet_ml`` is importable and ``opt.resume`` is a string starting
    with ``COMET_PREFIX``, the experiment named by the URL's netloc and path
    is looked up through ``comet_ml.API``; its logged options are applied via
    ``set_opt_parameters`` and then ``download_model_checkpoint`` is invoked.

    Args:
        opt (argparse.Namespace): Command Line arguments passed
            to YOLOv5 training script

    Returns:
        None/bool: True when a Comet restore was carried out for
            ``opt.resume``; None otherwise
    """
    if comet_ml is None:
        return None

    resume = opt.resume
    is_comet_url = isinstance(resume, str) and resume.startswith(COMET_PREFIX)
    if not is_comet_url:
        return None

    client = comet_ml.API()
    parsed = urlparse(resume)
    experiment = client.get(f'{parsed.netloc}{parsed.path}')

    # Options first: the checkpoint download reads the restored `opt` values.
    set_opt_parameters(opt, experiment)
    download_model_checkpoint(opt, experiment)
    return True