Spaces:
Sleeping
Sleeping
Update lxmert/src/ExplanationGenerator.py
Browse files
lxmert/src/ExplanationGenerator.py
CHANGED
@@ -163,7 +163,7 @@ class GeneratorOurs:
|
|
163 |
one_hot[0, index] = 1
|
164 |
one_hot_vector = one_hot
|
165 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
166 |
-
one_hot = torch.sum(one_hot
|
167 |
|
168 |
model.zero_grad()
|
169 |
one_hot.backward(retain_graph=True)
|
@@ -400,7 +400,7 @@ class GeneratorBaselines:
|
|
400 |
one_hot[0, index] = 1
|
401 |
one_hot_vector = one_hot
|
402 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
403 |
-
one_hot = torch.sum(one_hot
|
404 |
|
405 |
model.zero_grad()
|
406 |
one_hot.backward(retain_graph=True)
|
|
|
163 |
one_hot[0, index] = 1
|
164 |
one_hot_vector = one_hot
|
165 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
166 |
+
one_hot = torch.sum(one_hot * output)
|
167 |
|
168 |
model.zero_grad()
|
169 |
one_hot.backward(retain_graph=True)
|
|
|
400 |
one_hot[0, index] = 1
|
401 |
one_hot_vector = one_hot
|
402 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
403 |
+
one_hot = torch.sum(one_hot * output)
|
404 |
|
405 |
model.zero_grad()
|
406 |
one_hot.backward(retain_graph=True)
|