kd5678 commited on
Commit
b7581b5
·
verified ·
1 Parent(s): 7939348

Update InpaintReward.py

Browse files
Files changed (1) hide show
  1. InpaintReward.py +4 -4
InpaintReward.py CHANGED
@@ -54,7 +54,7 @@ class ViTBlock(nn.Module):
54
  return x
55
 
56
 
57
- class ImageReward(nn.Module):
58
  def __init__(self, config, device='cpu'):
59
  super().__init__()
60
  self.config = config
@@ -62,7 +62,7 @@ class ImageReward(nn.Module):
62
 
63
  self.clip_model, self.preprocess = clip.load("ViT-B/32")
64
  self.clip_model = self.clip_model.float()
65
- self.mlp = MLP(self.config['ImageReward']['mlp_dim'])
66
  self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
67
 
68
  self.toImage = transforms.ToPILImage()
@@ -127,7 +127,7 @@ class ImageReward(nn.Module):
127
 
128
 
129
 
130
- class ImageRewardGroup(nn.Module):
131
  def __init__(self, config, device='cpu'):
132
  super().__init__()
133
  self.config = config
@@ -136,7 +136,7 @@ class ImageRewardGroup(nn.Module):
136
  self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
137
 
138
  self.clip_model = self.clip_model.float()
139
- self.mlp = MLP(config['ImageReward']['mlp_dim'])
140
  self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
141
 
142
  if self.config.fix_base:
 
54
  return x
55
 
56
 
57
+ class InpaintReward(nn.Module):
58
  def __init__(self, config, device='cpu'):
59
  super().__init__()
60
  self.config = config
 
62
 
63
  self.clip_model, self.preprocess = clip.load("ViT-B/32")
64
  self.clip_model = self.clip_model.float()
65
+ self.mlp = MLP(self.config['Reward']['mlp_dim'])
66
  self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
67
 
68
  self.toImage = transforms.ToPILImage()
 
127
 
128
 
129
 
130
+ class InpaintRewardGroup(nn.Module):
131
  def __init__(self, config, device='cpu'):
132
  super().__init__()
133
  self.config = config
 
136
  self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
137
 
138
  self.clip_model = self.clip_model.float()
139
+ self.mlp = MLP(config['Reward']['mlp_dim'])
140
  self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
141
 
142
  if self.config.fix_base: