WwYc commited on
Commit
79059eb
·
verified ·
1 Parent(s): 4c9c9fe

Update lxmert/src/ExplanationGenerator.py

Browse files
Files changed (1) hide show
  1. lxmert/src/ExplanationGenerator.py +2 -2
lxmert/src/ExplanationGenerator.py CHANGED
@@ -163,7 +163,7 @@ class GeneratorOurs:
163
  one_hot[0, index] = 1
164
  one_hot_vector = one_hot
165
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
166
- one_hot = torch.sum(one_hot.cuda() * output)
167
 
168
  model.zero_grad()
169
  one_hot.backward(retain_graph=True)
@@ -400,7 +400,7 @@ class GeneratorBaselines:
400
  one_hot[0, index] = 1
401
  one_hot_vector = one_hot
402
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
403
- one_hot = torch.sum(one_hot.cuda() * output)
404
 
405
  model.zero_grad()
406
  one_hot.backward(retain_graph=True)
 
163
  one_hot[0, index] = 1
164
  one_hot_vector = one_hot
165
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
166
+ one_hot = torch.sum(one_hot * output)
167
 
168
  model.zero_grad()
169
  one_hot.backward(retain_graph=True)
 
400
  one_hot[0, index] = 1
401
  one_hot_vector = one_hot
402
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
403
+ one_hot = torch.sum(one_hot * output)
404
 
405
  model.zero_grad()
406
  one_hot.backward(retain_graph=True)