Jonathan Malott committed on
Commit
ba3afcd
1 Parent(s): 1557730

updated dalle/models/__init__.py

Browse files
Files changed (1) hide show
  1. dalle/models/__init__.py +11 -7
dalle/models/__init__.py CHANGED
@@ -43,20 +43,24 @@ class Dalle(nn.Module):
43
  @classmethod
44
  def from_pretrained(cls,
45
  path: str) -> nn.Module:
46
- path = _MODELS[path] if path in _MODELS else path
47
- path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
 
48
 
49
  config_base = get_base_config()
50
- config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
51
  config_update = OmegaConf.merge(config_base, config_new)
52
 
53
  model = cls(config_update)
54
- model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
55
  context_length=model.config_dataset.context_length,
56
  lowercase=True,
57
  dropout=None)
58
- model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
59
- model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
 
 
 
60
  return model
61
 
62
  @torch.no_grad()
@@ -199,4 +203,4 @@ class ImageGPT(pl.LightningModule):
199
  self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
200
 
201
  def on_epoch_start(self):
202
- self.stage1.eval()
 
43
  @classmethod
44
  def from_pretrained(cls,
45
  path: str) -> nn.Module:
46
+ #path = _MODELS[path] if path in _MODELS else path
47
+ #path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
48
+ path = ''
49
 
50
  config_base = get_base_config()
51
+ config_new = OmegaConf.load(os.path.join(path, '.cache/minDALL-E/1.3B/config.yaml'))
52
  config_update = OmegaConf.merge(config_base, config_new)
53
 
54
  model = cls(config_update)
55
+ model.tokenizer = build_tokenizer('.cache/minDALL-E/1.3B/tokenizer',
56
  context_length=model.config_dataset.context_length,
57
  lowercase=True,
58
  dropout=None)
59
+ model.stage1.from_ckpt('.cache/minDALL-E/1.3B/stage1_last.ckpt')
60
+ model.stage2.from_ckpt('.cache/minDALL-E/1.3B/stage2_last.ckpt')
61
+ #model.stage1.from_ckpt('https://utexas.box.com/shared/static/rpt9miyj2kikogyekpqnkd6y115xp51i.ckpt')
62
+ #model.stage2.from_ckpt('https://utexas.box.com/shared/static/54jc9fw0bious5nx6wvayeqaskcrdgv4.ckpt')
63
+
64
  return model
65
 
66
  @torch.no_grad()
 
203
  self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
204
 
205
  def on_epoch_start(self):
206
+ self.stage1.eval()