Spaces:
Runtime error
Runtime error
File size: 1,028 Bytes
6742988 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import os
import tempfile
from pathlib import Path
import wandb
class PretrainedFromWandbMixin:
    """Mixin that lets ``from_pretrained`` also accept a wandb artifact reference.

    List it before a base class that implements ``from_pretrained`` so that
    this override runs first and delegates to the base via ``super()``.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact or delegates loading to the superclass.

        A ``str`` argument that contains ":" and is not an existing local
        directory is treated as a wandb artifact reference. The artifact is
        downloaded into a temporary directory, and the superclass loads from
        there. Any other argument, including ``os.PathLike`` and ``bytes``
        paths, is passed to the superclass unchanged.

        Args:
            pretrained_model_name_or_path: wandb artifact reference, local
                directory, or any identifier the superclass understands.
            *model_args: forwarded positionally to the superclass.
            **kwargs: forwarded as keyword arguments to the superclass.

        Returns:
            Whatever the superclass ``from_pretrained`` returns.
        """
        # Only a str can name a wandb artifact. Checking the type first also
        # avoids a TypeError from `":" in <Path/bytes>`.
        is_artifact_reference = (
            isinstance(pretrained_model_name_or_path, str)
            and ":" in pretrained_model_name_or_path
            and not os.path.isdir(pretrained_model_name_or_path)
        )
        if not is_artifact_reference:
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

        # Download into a throwaway directory to avoid multiple artifact copies.
        # NOTE(review): the directory is deleted on exit, so this assumes the
        # superclass reads all files eagerly before returning — confirm.
        with tempfile.TemporaryDirectory() as tmp_dir:
            if wandb.run is not None:
                # Go through the active run when there is one.
                artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
            else:
                artifact = wandb.Api().artifact(pretrained_model_name_or_path)
            local_dir = artifact.download(tmp_dir)
            return super().from_pretrained(local_dir, *model_args, **kwargs)
|