Spaces:
Runtime error
Runtime error
Jonathan Malott
commited on
Commit
•
abec4e8
1
Parent(s):
4831980
Revert "Updated model location"
Browse filesThis reverts commit 4831980d34025aa262e0bd238cdf4dfcc0cc8e0c.
- dalle/models/__init__.py +6 -7
dalle/models/__init__.py
CHANGED
@@ -43,21 +43,20 @@ class Dalle(nn.Module):
|
|
43 |
@classmethod
|
44 |
def from_pretrained(cls,
|
45 |
path: str) -> nn.Module:
|
46 |
-
|
47 |
-
|
48 |
-
path = ".cache/minDALL-E/1.3B/"
|
49 |
|
50 |
config_base = get_base_config()
|
51 |
-
config_new = OmegaConf.load(path
|
52 |
config_update = OmegaConf.merge(config_base, config_new)
|
53 |
|
54 |
model = cls(config_update)
|
55 |
-
model.tokenizer = build_tokenizer(path
|
56 |
context_length=model.config_dataset.context_length,
|
57 |
lowercase=True,
|
58 |
dropout=None)
|
59 |
-
model.stage1.from_ckpt(path
|
60 |
-
model.stage2.from_ckpt(path
|
61 |
return model
|
62 |
|
63 |
@torch.no_grad()
|
|
|
43 |
@classmethod
|
44 |
def from_pretrained(cls,
|
45 |
path: str) -> nn.Module:
|
46 |
+
path = _MODELS[path] if path in _MODELS else path
|
47 |
+
path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
|
|
|
48 |
|
49 |
config_base = get_base_config()
|
50 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
51 |
config_update = OmegaConf.merge(config_base, config_new)
|
52 |
|
53 |
model = cls(config_update)
|
54 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
55 |
context_length=model.config_dataset.context_length,
|
56 |
lowercase=True,
|
57 |
dropout=None)
|
58 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
59 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
60 |
return model
|
61 |
|
62 |
@torch.no_grad()
|