Jonathan Malott commited on
Commit
abec4e8
1 Parent(s): 4831980

Revert "Updated model location"

Browse files

This reverts commit 4831980d34025aa262e0bd238cdf4dfcc0cc8e0c.

Files changed (1) hide show
  1. dalle/models/__init__.py +6 -7
dalle/models/__init__.py CHANGED
@@ -43,21 +43,20 @@ class Dalle(nn.Module):
43
  @classmethod
44
  def from_pretrained(cls,
45
  path: str) -> nn.Module:
46
- #path = _MODELS[path] if path in _MODELS else path
47
- #path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
48
- path = ".cache/minDALL-E/1.3B/"
49
 
50
  config_base = get_base_config()
51
- config_new = OmegaConf.load(path+'config.yaml')
52
  config_update = OmegaConf.merge(config_base, config_new)
53
 
54
  model = cls(config_update)
55
- model.tokenizer = build_tokenizer(path+'tokenizer',
56
  context_length=model.config_dataset.context_length,
57
  lowercase=True,
58
  dropout=None)
59
- model.stage1.from_ckpt(path+'stage1_last.ckpt')
60
- model.stage2.from_ckpt(path+'stage2_last.ckpt')
61
  return model
62
 
63
  @torch.no_grad()
 
43
  @classmethod
44
  def from_pretrained(cls,
45
  path: str) -> nn.Module:
46
+ path = _MODELS[path] if path in _MODELS else path
47
+ path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
 
48
 
49
  config_base = get_base_config()
50
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
51
  config_update = OmegaConf.merge(config_base, config_new)
52
 
53
  model = cls(config_update)
54
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
55
  context_length=model.config_dataset.context_length,
56
  lowercase=True,
57
  dropout=None)
58
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
59
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
60
  return model
61
 
62
  @torch.no_grad()