Spaces:
Runtime error
Runtime error
reset fix
Browse files- models/region_diffusion.py +4 -2
- utils/attention_utils.py +1 -0
models/region_diffusion.py
CHANGED
@@ -285,8 +285,10 @@ class RegionDiffusion(nn.Module):
|
|
285 |
We reset attention maps because we append them while getting hooks
|
286 |
to visualize attention maps for every step.
|
287 |
"""
|
288 |
-
for key in self.
|
289 |
-
self.
|
|
|
|
|
290 |
|
291 |
def register_evaluation_hooks(self):
|
292 |
r"""Function for registering hooks during evaluation.
|
|
|
285 |
We reset attention maps because we append them while getting hooks
|
286 |
to visualize attention maps for every step.
|
287 |
"""
|
288 |
+
for key in self.selfattn_maps:
|
289 |
+
self.selfattn_maps[key] = []
|
290 |
+
for key in self.crossattn_maps:
|
291 |
+
self.crossattn_maps[key] = []
|
292 |
|
293 |
def register_evaluation_hooks(self):
|
294 |
r"""Function for registering hooks during evaluation.
|
utils/attention_utils.py
CHANGED
@@ -123,6 +123,7 @@ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=N
|
|
123 |
dtype='uint8').reshape((height, width, 3))
|
124 |
|
125 |
fig.tight_layout()
|
|
|
126 |
return img
|
127 |
|
128 |
|
|
|
123 |
dtype='uint8').reshape((height, width, 3))
|
124 |
|
125 |
fig.tight_layout()
|
126 |
+
plt.close()
|
127 |
return img
|
128 |
|
129 |
|