# W&B Model Registry quickstart: simulate several training runs and log a
# model artifact from each one.
# Source notebook:
# https://colab.research.google.com/github/wandb/examples/blob/master/colabs/wandb-model-registry/W%26B_Model_Registry_Quickstart.ipynb#scrollTo=CFXVyKSaRtUw

#@title 1) Run this cell to set up `wandb` and define helper functions

# INSTALL W&B LIBRARY
# In a notebook cell run `!pip install wandb -qqq`; from a shell run
# `pip install wandb -qqq`. (The `!` prefix is IPython syntax, not Python.)

import math
import os
import random

# FORM VARIABLES
PROJECT = "Model_Registry_Quickstart"  #@param {type:"string"}
RUN_COUNT = 3  #@param {type:"integer"}
EPOCHS = 5  # number of simulated training epochs per run

# HELPER FUNCTIONS
# These create fake data to simulate training a model.


def set_config():
    """Simulate choosing hyperparameters.

    Returns:
        A dict of params to log as the W&B run config. ``learning_rate`` is
        drawn uniformly from [0.01, 0.11); the other values are fixed.
    """
    return {
        "learning_rate": 0.01 + 0.1 * random.random(),
        "batch_size": 128,
        "architecture": "CNN",
    }


def get_model(file_name="demo_model.h5"):
    """Simulate training a model by writing a small placeholder file.

    Args:
        file_name: Path of the file to write (default ``"demo_model.h5"``,
            relative to the current working directory). An existing file
            at that path is overwritten.

    Returns:
        ``file_name``, ready to be added to a W&B artifact.
    """
    # `with` guarantees the handle is closed even if the write fails.
    with open(file_name, "w", encoding="utf-8") as model_file:
        model_file.write(
            "Imagine this is a big model file! \n" + str(random.random())
        )
    return file_name


def get_metrics(epoch):
    """Simulate the metrics produced by one epoch of training.

    Args:
        epoch: Zero-based epoch index. Larger values push the accuracies up
            and the losses down (logarithmically in ``epoch``), with random
            noise added to every value.

    Returns:
        A dict with ``acc``, ``val_acc``, ``loss`` and ``val_loss`` to log
        to W&B.
    """
    return {
        "acc": .8 + 0.04 * (math.log(1 + epoch + random.random()) + (0.3 * random.random())),
        "val_acc": .75 + 0.04 * (math.log(1 + epoch + random.random()) - (0.3 * random.random())),
        "loss": .1 + 0.1 * (4 - math.log(1 + epoch + random.random()) + (0.3 * random.random())),
        "val_loss": .1 + 0.16 * (5 - math.log(1 + epoch + random.random()) - (0.3 * random.random())),
    }


def main():
    """Run RUN_COUNT simulated trainings, logging metrics and a model artifact for each."""
    # Imported here so the helpers above can be used (and tested) without
    # wandb installed; only the actual logging needs it.
    import wandb

    for _ in range(RUN_COUNT):
        # 1️⃣ Initialize a new W&B run to track this job. Using the run as a
        # context manager calls run.finish() on exit even if the body raises,
        # which marks the run as done (important in notebooks).
        with wandb.init(project=PROJECT, config=set_config()) as run:
            for epoch in range(EPOCHS):
                # 2️⃣ Log metrics to W&B for each epoch of training
                run.log(get_metrics(epoch))

            # 3️⃣ At the end of training, save the model artifact.
            # Name this artifact after the current run (run.id).
            model_artifact_name = "demo_model_" + run.id
            # Create a new artifact
            model = wandb.Artifact(model_artifact_name, type="model")
            # Add files to the artifact, in this case a simple text file.
            # NOTE(review): the pasted source ended with the stray tokens
            # `run.id saved_model_weights.pt`; presumably a real checkpoint
            # (e.g. `saved_model_weights.pt`) would be added here instead of
            # the demo file — confirm.
            model.add_file(get_model())
            # Log the model to W&B
            run.log_artifact(model)


if __name__ == "__main__":
    main()