Update InpaintReward.py
Browse files- InpaintReward.py +4 -4
InpaintReward.py
CHANGED
@@ -54,7 +54,7 @@ class ViTBlock(nn.Module):
|
|
54 |
return x
|
55 |
|
56 |
|
57 |
-
class
|
58 |
def __init__(self, config, device='cpu'):
|
59 |
super().__init__()
|
60 |
self.config = config
|
@@ -62,7 +62,7 @@ class ImageReward(nn.Module):
|
|
62 |
|
63 |
self.clip_model, self.preprocess = clip.load("ViT-B/32")
|
64 |
self.clip_model = self.clip_model.float()
|
65 |
-
self.mlp = MLP(self.config['
|
66 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
67 |
|
68 |
self.toImage = transforms.ToPILImage()
|
@@ -127,7 +127,7 @@ class ImageReward(nn.Module):
|
|
127 |
|
128 |
|
129 |
|
130 |
-
class
|
131 |
def __init__(self, config, device='cpu'):
|
132 |
super().__init__()
|
133 |
self.config = config
|
@@ -136,7 +136,7 @@ class ImageRewardGroup(nn.Module):
|
|
136 |
self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
|
137 |
|
138 |
self.clip_model = self.clip_model.float()
|
139 |
-
self.mlp = MLP(config['
|
140 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
141 |
|
142 |
if self.config.fix_base:
|
|
|
54 |
return x
|
55 |
|
56 |
|
57 |
+
class InpaintReward(nn.Module):
|
58 |
def __init__(self, config, device='cpu'):
|
59 |
super().__init__()
|
60 |
self.config = config
|
|
|
62 |
|
63 |
self.clip_model, self.preprocess = clip.load("ViT-B/32")
|
64 |
self.clip_model = self.clip_model.float()
|
65 |
+
self.mlp = MLP(self.config['Reward']['mlp_dim'])
|
66 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
67 |
|
68 |
self.toImage = transforms.ToPILImage()
|
|
|
127 |
|
128 |
|
129 |
|
130 |
+
class InpaintRewardGroup(nn.Module):
|
131 |
def __init__(self, config, device='cpu'):
|
132 |
super().__init__()
|
133 |
self.config = config
|
|
|
136 |
self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
|
137 |
|
138 |
self.clip_model = self.clip_model.float()
|
139 |
+
self.mlp = MLP(config['Reward']['mlp_dim'])
|
140 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
141 |
|
142 |
if self.config.fix_base:
|