# Scraped page header: "Spaces: No application file"
# Train RUN_COUNT simulated models, logging per-epoch metrics and one model
# artifact for each run.
# NOTE: relies on PROJECT, RUN_COUNT, set_config, get_metrics and get_model,
# which are defined further down in this file (the "1) Run this cell to set up"
# cell); that code must be executed before this loop.
for _ in range(RUN_COUNT):
    # 1️⃣ Initialize a new W&B run to track this job.
    # Using the run as a context manager calls run.finish() on exit, even when
    # training or artifact logging raises, so a failure in one iteration does
    # not leave that run open.
    with wandb.init(project=PROJECT, config=set_config()) as run:
        for epoch in range(5):
            # 2️⃣ Log metrics to W&B for each epoch of training
            run.log(get_metrics(epoch))

        # 3️⃣ At the end of training, save the model artifact.
        # Name this artifact after the current run so every run's model
        # artifact gets a distinct name.
        model_artifact_name = f"demo_model_{run.id}"
        # Create a new artifact of type 'model'
        model = wandb.Artifact(model_artifact_name, type='model')
        # Add files to the artifact, in this case a simple text file
        model.add_file(get_model())
        # Log the model to W&B
        run.log_artifact(model)
# Source notebook: https://colab.research.google.com/github/wandb/examples/blob/master/colabs/wandb-model-registry/W%26B_Model_Registry_Quickstart.ipynb#scrollTo=CFXVyKSaRtUw
#@title 1) Run this cell to set up `wandb` and define helper functions
# INSTALL W&B LIBRARY (only when it is not already importable).
# Calling pip through the running interpreter works both in a notebook and in a
# plain Python script, where the IPython `!pip` shell escape is a syntax error.
try:
    import wandb
except ImportError:
    import importlib
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "-m", "pip", "install", "wandb", "-qqq"],
        check=True,
    )
    # The import system caches directory listings; refresh them so the
    # freshly installed package can be found without restarting.
    importlib.invalidate_caches()
    import wandb
import os
import math
import random

# FORM VARIABLES
# (the `#@param` comments render these as editable fields in Colab)
PROJECT = "Model_Registry_Quickstart" #@param {type:"string"}
RUN_COUNT = 3 #@param {type:"integer"}
# HELPER FUNCTIONS
# These produce fake data to simulate training a model.

# Simulate choosing hyperparameters for one run.
def set_config():
    """Return a dict of simulated hyperparameters for one training run.

    ``learning_rate`` is drawn uniformly from [0.01, 0.11) on every call, so
    each run gets a different value. ``batch_size`` (128) and ``architecture``
    ("CNN") are fixed. The dict is meant to be logged to W&B as the run config.
    """
    lr = 0.01 + 0.1 * random.random()
    return {"learning_rate": lr, "batch_size": 128, "architecture": "CNN"}
# Simulate training a model
# Return: A model file to log as an artifact to W&B
def get_model(file_name="demo_model.h5"):
    """Write a small stand-in "model" file and return its path.

    The content is plain text (despite the ``.h5`` default extension). It ends
    with a random number, so each call writes different content.

    Args:
        file_name: Path of the file to create or overwrite. Defaults to
            ``"demo_model.h5"`` in the current working directory.

    Returns:
        str: ``file_name``, ready to be passed to ``wandb.Artifact.add_file``.
    """
    # `with` closes the handle even if the write raises; the previous manual
    # open/close leaked it on error.
    with open(file_name, 'w', encoding="utf-8") as model_file:
        model_file.write('Imagine this is a big model file! ' + str(random.random()))
    return file_name
# Simulate the metrics a training loop would report after each epoch.
def get_metrics(epoch):
    """Return fake training metrics for ``epoch`` as a dict for W&B logging.

    Each metric is built from log(1 + epoch + noise) plus a small extra jitter.
    As ``epoch`` grows, "acc" and "val_acc" trend upward and "loss" and
    "val_loss" trend downward, with fresh random noise on every call.
    """
    def noisy_log():
        # log-shaped progress term with its own uniform [0, 1) noise
        return math.log(1 + epoch + random.random())

    def jitter():
        # extra noise in [0, 0.3)
        return 0.3 * random.random()

    # Statement order matters: it fixes the sequence of random draws.
    acc = .8 + 0.04 * (noisy_log() + jitter())
    val_acc = .75 + 0.04 * (noisy_log() - jitter())
    loss = .1 + 0.1 * (4 - noisy_log() + jitter())
    val_loss = .1 + 0.16 * (5 - noisy_log() - jitter())
    return {"acc": acc, "val_acc": val_acc, "loss": loss, "val_loss": val_loss}
# Inline references scraped from the surrounding page: `run.id`, `saved_model_weights.pt`