File size: 25,187 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 |
from typing import Union, Dict, Optional, Tuple
import torch
import torch.nn as nn
from ding.utils import squeeze, MODEL_REGISTRY, SequenceType
from ding.torch_utils import MLP
from ding.model.common import RegressionHead
class ATOCAttentionUnit(nn.Module):
    """
    Overview:
        Attention unit of the ATOC network. Every thought vector is passed through two ReLU-activated
        hidden linear layers and a final linear layer with a single sigmoid output, which is read as the
        probability that the corresponding agent becomes an initiator.
    Interface:
        ``__init__``, ``forward``

    .. note::
        "ATOC paper: We use two-layer MLP to implement the attention unit but it is also can be realized by RNN."
    """

    def __init__(self, thought_size: int, embedding_size: int) -> None:
        """
        Overview:
            Build the linear layers and activations of the attention unit.
        Arguments:
            - thought_size (:obj:`int`): size of the last dimension of the input thoughts
            - embedding_size (:obj:`int`): width of both hidden layers
        """
        super().__init__()
        self._thought_size = thought_size
        self._hidden_size = embedding_size
        # One scalar per thought; it is squeezed away in ``forward``.
        self._output_size = 1
        # The same ReLU instance is applied after both hidden layers.
        self._act1 = nn.ReLU()
        self._fc1 = nn.Linear(in_features=self._thought_size, out_features=self._hidden_size, bias=True)
        self._fc2 = nn.Linear(in_features=self._hidden_size, out_features=self._hidden_size, bias=True)
        self._fc3 = nn.Linear(in_features=self._hidden_size, out_features=self._output_size, bias=True)
        self._act2 = nn.Sigmoid()

    def forward(self, data: Union[Dict, torch.Tensor]) -> torch.Tensor:
        """
        Overview:
            Map agents' thoughts to their probability of being an initiator.
        Arguments:
            - data (:obj:`Union[Dict, torch.Tensor]`): the thoughts tensor itself, or a dict holding it
              under the key ``thought``
        Returns:
            - ret (:obj:`torch.Tensor`): initiator probability, i.e. the input shape without its last dimension
        Shapes:
            - data['thought']: :math:`(M, B, N)`, M is the num of thoughts to integrate,
              B is batch_size and N is thought size
        Examples:
            >>> attention_unit = ATOCAttentionUnit(64, 64)
            >>> thought = torch.randn(2, 3, 64)
            >>> attention_unit(thought)
        """
        thought = data['thought'] if isinstance(data, Dict) else data
        hidden = self._act1(self._fc1(thought))
        hidden = self._act1(self._fc2(hidden))
        prob = self._act2(self._fc3(hidden))
        # Drop the trailing singleton dimension produced by the last linear layer.
        return prob.squeeze(-1)
class ATOCCommunicationNet(nn.Module):
    """
    Overview:
        Communication channel of ATOC. A bi-directional LSTM runs over the thoughts of a group so that
        every output position integrates the thoughts of all group members.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(self, thought_size: int) -> None:
        """
        Overview:
            Build the bi-directional LSTM used to integrate thoughts.
        Arguments:
            - thought_size (:obj:`int`): size of a single thought; must be even

        .. note::
            Each LSTM direction uses ``thought_size // 2`` hidden units, so the concatenated
            bi-directional output has the same size as the input thought.
        """
        super().__init__()
        half_size, remainder = divmod(thought_size, 2)
        assert remainder == 0
        self._thought_size = thought_size
        self._comm_hidden_size = half_size
        self._bi_lstm = nn.LSTM(self._thought_size, self._comm_hidden_size, bidirectional=True)

    def forward(self, data: Union[Dict, torch.Tensor]):
        """
        Overview:
            Integrate the thoughts of a group with the bi-directional LSTM.
        Arguments:
            - data (:obj:`Union[Dict, torch.Tensor]`): the thoughts tensor itself, or a dict holding it
              under the key ``thoughts``
        Returns:
            - out (:obj:`torch.Tensor`): the integrated thoughts (the LSTM output sequence)
        Shapes:
            - data['thoughts']: :math:`(M, B, N)`, M is the num of thoughts to integrate,
              B is batch_size and N is thought size
        Examples:
            >>> comm_net = ATOCCommunicationNet(64)
            >>> thoughts = torch.randn(2, 3, 64)
            >>> comm_net(thoughts)
        """
        self._bi_lstm.flatten_parameters()
        thoughts = data['thoughts'] if isinstance(data, Dict) else data
        # Only the output sequence is needed; the final hidden/cell states are discarded.
        integrated, _final_state = self._bi_lstm(thoughts)
        return integrated
class ATOCActorNet(nn.Module):
    """
    Overview:
        The actor network of ATOC. ``actor_1`` encodes observations into thoughts, the optional
        communication part (attention unit + communication net) integrates the thoughts inside each
        group, and ``actor_2`` maps the concatenation of original and integrated thoughts to actions.
    Interface:
        ``__init__``, ``forward``

    .. note::
        "ATOC paper: The neural networks use ReLU and batch normalization for some hidden layers."
    """

    def __init__(
            self,
            obs_shape: Union[Tuple, int],
            thought_size: int,
            action_shape: int,
            n_agent: int,
            communication: bool = True,
            agent_per_group: int = 2,
            initiator_threshold: float = 0.5,
            attention_embedding_size: int = 64,
            actor_1_embedding_size: Union[int, None] = None,
            actor_2_embedding_size: Union[int, None] = None,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ):
        """
        Overview:
            Initialize the actor network of ATOC.
        Arguments:
            - obs_shape (:obj:`Union[Tuple, int]`): the observation size
            - thought_size (:obj:`int`): the size of thoughts
            - action_shape (:obj:`int`): the action size
            - n_agent (:obj:`int`): the num of agents
            - communication (:obj:`bool`): whether to build and use the attention unit and communication net
            - agent_per_group (:obj:`int`): the num of agents in each group
            - initiator_threshold (:obj:`float`): the threshold of becoming an initiator, default set to 0.5
            - attention_embedding_size (:obj:`int`): the embedding size of attention unit, default set to 64
            - actor_1_embedding_size (:obj:`Union[int, None]`): hidden size of actor network part 1,
              if None, then default set to thought size
            - actor_2_embedding_size (:obj:`Union[int, None]`): hidden size of actor network part 2,
              if None, then default set to thought size
            - activation (:obj:`Optional[nn.Module]`): activation module passed to both actor parts
            - norm_type (:obj:`Optional[str]`): normalization type passed to both actor parts
        """
        super().__init__()
        # Only observations of shape (O_dim, ) are supported for now.
        self._obs_shape = squeeze(obs_shape)
        self._thought_size = thought_size
        self._act_shape = action_shape
        self._n_agent = n_agent
        self._communication = communication
        self._agent_per_group = agent_per_group
        self._initiator_threshold = initiator_threshold
        # A falsy embedding size (``None`` or ``0``) falls back to the thought size.
        actor_1_embedding_size = actor_1_embedding_size or self._thought_size
        actor_2_embedding_size = actor_2_embedding_size or self._thought_size

        # Actor Net(I): observation -> thought.
        self.actor_1 = MLP(
            self._obs_shape,
            actor_1_embedding_size,
            self._thought_size,
            layer_num=2,
            activation=activation,
            norm_type=norm_type,
        )

        # Actor Net(II): [own thought, integrated thought] -> action.
        actor_2_input = nn.Linear(self._thought_size * 2, actor_2_embedding_size)
        actor_2_head = RegressionHead(
            actor_2_embedding_size,
            self._act_shape,
            2,
            final_tanh=True,
            activation=activation,
            norm_type=norm_type,
        )
        self.actor_2 = nn.Sequential(actor_2_input, activation, actor_2_head)

        # Communication modules exist only when communication is enabled.
        if self._communication:
            self.attention = ATOCAttentionUnit(self._thought_size, attention_embedding_size)
            self.comm_net = ATOCCommunicationNet(self._thought_size)

    def forward(self, obs: torch.Tensor) -> Dict:
        """
        Overview:
            Compute actions from observations. With communication enabled the groups, initiator
            probabilities and thoughts before/after communication are returned as well.
        Arguments:
            - obs (:obj:`torch.Tensor`): the batched observations of all agents
        Returns:
            - ret (:obj:`Dict`): the returned output, including action, group, initiator_prob, is_initiator,
              new_thoughts and old_thoughts
        ReturnsKeys:
            - necessary: ``action``
            - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, T)`, where T is thought size
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, T)`, where T is thought size
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> actor_net(obs)
        """
        assert len(obs.shape) == 3
        self._cur_batch_size = obs.shape[0]
        _, agent_num, obs_dim = obs.shape
        assert agent_num == self._n_agent
        assert obs_dim == self._obs_shape

        current_thoughts = self.actor_1(obs)  # B, A, thought size

        if not self._communication:
            # Without communication the "integrated" thought is the agent's own thought.
            action = self.actor_2(torch.cat([current_thoughts, current_thoughts], dim=-1))['pred']
            return {'action': action}

        # Grouping is decided on a detached copy, so it does not back-propagate into ``actor_1``.
        old_thoughts = current_thoughts.clone().detach()
        init_prob, is_initiator, group = self._get_initiate_group(old_thoughts)
        new_thoughts = self._get_new_thoughts(current_thoughts, group, is_initiator)
        action = self.actor_2(torch.cat([current_thoughts, new_thoughts], dim=-1))['pred']
        return {
            'action': action,
            'group': group,
            'initiator_prob': init_prob,
            'is_initiator': is_initiator,
            'new_thoughts': new_thoughts,
            'old_thoughts': old_thoughts,
        }

    def _get_initiate_group(self, current_thoughts):
        """
        Overview:
            Calculate the initiator probability, the initiator mask and the group of every initiator.
        Arguments:
            - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
        Returns:
            - init_prob (:obj:`torch.Tensor`): tensor of initiator probability
            - is_initiator (:obj:`torch.Tensor`): mask of agents whose probability exceeds the threshold
            - group (:obj:`torch.Tensor`): 0/1 tensor, ``group[b][i][j] == 1`` iff agent j belongs to the
              group initiated by agent i; rows of non-initiators stay all-zero
        Shapes:
            - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
            - init_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> current_thoughts = torch.randn(2, 3, 64)
            >>> actor_net._get_initiate_group(current_thoughts)
        """
        if not self._communication:
            raise NotImplementedError
        init_prob = self.attention(current_thoughts)  # B, A
        is_initiator = init_prob > self._initiator_threshold
        batch_size, agent_num = init_prob.shape[:2]

        # Pairwise squared distance between thoughts: |x|^2 - 2<x, y> + |y|^2.
        gram = current_thoughts.bmm(current_thoughts.transpose(1, 2))
        squared_norm = gram.diagonal(0, 1, 2)
        pair_dists = squared_norm.unsqueeze(1) - 2 * gram + squared_norm.unsqueeze(2)

        group = torch.zeros(batch_size, agent_num, agent_num).to(init_prob.device)
        # The paper picks collaborators by proximity with a priority order (unselected agents first,
        # then agents selected by other initiators, finally other initiators). This is a rough version:
        # every initiator simply takes the ``agent_per_group`` closest thoughts.
        for b, i in torch.nonzero(is_initiator).tolist():
            nearest = pair_dists[b][i].argsort()[:self._agent_per_group]
            group[b][i][nearest] = 1
        return init_prob, is_initiator, group

    def _get_new_thoughts(self, current_thoughts, group, is_initiator):
        """
        Overview:
            Calculate the new thoughts according to current thoughts, group and is_initiator. The
            thoughts of every group are integrated by the communication net and written back to the
            group members; agents outside every group keep their (detached) thought.
        Arguments:
            - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
            - group (:obj:`torch.Tensor`): tensor of group
            - is_initiator (:obj:`torch.Tensor`): tensor of is initiator
        Returns:
            - new_thoughts (:obj:`torch.Tensor`): tensor of new thoughts
        Shapes:
            - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
            - group: (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> current_thoughts = torch.randn(2, 3, 64)
            >>> group = torch.randn(2, 3, 3)
            >>> is_initiator = torch.randn(2, 3)
            >>> actor_net._get_new_thoughts(current_thoughts, group, is_initiator)
        """
        if not self._communication:
            raise NotImplementedError
        new_thoughts = current_thoughts.detach().clone()
        # (b, i) index pairs of all initiators, in row-major order.
        initiators = torch.nonzero(is_initiator).tolist()
        if not initiators:
            return new_thoughts

        # TODO(nyz) execute communication serially for shared agent in different group
        # Member indices (ascending) of the group led by each initiator.
        members = [torch.nonzero(group[b][i]).flatten().tolist() for b, i in initiators]
        # All groups read the thoughts before any write-back happens.
        gathered = [
            torch.stack([new_thoughts[b][j] for j in member_ids], dim=0)
            for (b, _), member_ids in zip(initiators, members)
        ]
        thoughts_to_commute = torch.stack(gathered, dim=1)  # agent_per_group, B_, N
        integrated_thoughts = self.comm_net(thoughts_to_commute)

        # Scatter the integrated thoughts back; an agent shared by several groups keeps the value
        # written by the last initiator in iteration order.
        for col, ((b, _), member_ids) in enumerate(zip(initiators, members)):
            for row, j in enumerate(member_ids):
                new_thoughts[b][j] = integrated_thoughts[row][col]
        return new_thoughts
@MODEL_REGISTRY.register('atoc')
class ATOC(nn.Module):
    """
    Overview:
        The QAC network of ATOC, a kind of extension of DDPG for MARL.
        Learning Attentional Communication for Multi-Agent Cooperation
        https://arxiv.org/abs/1805.07733
    Interface:
        ``__init__``, ``forward``, ``compute_critic``, ``compute_actor``, ``optimize_actor_attention``
    """
    # Names of the methods that ``forward`` is allowed to dispatch to.
    mode = ['compute_actor', 'compute_critic', 'optimize_actor_attention']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            thought_size: int,
            n_agent: int,
            communication: bool = True,
            agent_per_group: int = 2,
            actor_1_embedding_size: Union[int, None] = None,
            actor_2_embedding_size: Union[int, None] = None,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 2,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the ATOC QAC network (actor and critic).
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): the observation space shape
            - action_shape (:obj:`Union[int, SequenceType]`): the action space shape
            - thought_size (:obj:`int`): the size of thoughts
            - n_agent (:obj:`int`): the num of agents
            - communication (:obj:`bool`): whether the actor uses the communication modules
            - agent_per_group (:obj:`int`): the num of agents in each group
            - actor_1_embedding_size (:obj:`Union[int, None]`): hidden size of actor part 1, forwarded to the actor
            - actor_2_embedding_size (:obj:`Union[int, None]`): hidden size of actor part 2, forwarded to the actor
            - critic_head_hidden_size (:obj:`int`): hidden size of the critic
            - critic_head_layer_num (:obj:`int`): layer num of the critic regression head
            - activation (:obj:`Optional[nn.Module]`): activation module used by the critic
            - norm_type (:obj:`Optional[str]`): normalization type used by the critic head
        """
        super(ATOC, self).__init__()
        self._communication = communication
        # Normalize single-element shapes such as ``(N, )`` to a plain size, the same way the actor
        # treats ``obs_shape``; otherwise ``obs_shape + action_shape`` below breaks for sequence inputs.
        obs_dim = squeeze(obs_shape)
        action_dim = squeeze(action_shape)
        # NOTE(review): ``activation`` and ``norm_type`` are only applied to the critic; the actor is
        # built with its own defaults. Confirm whether that is intended before forwarding them.
        self.actor = ATOCActorNet(
            obs_dim,
            thought_size,
            action_dim,
            n_agent,
            communication,
            agent_per_group,
            actor_1_embedding_size=actor_1_embedding_size,
            actor_2_embedding_size=actor_2_embedding_size
        )
        # The critic scores one (observation, action) pair per agent.
        self.critic = nn.Sequential(
            nn.Linear(obs_dim + action_dim, critic_head_hidden_size), activation,
            RegressionHead(
                critic_head_hidden_size,
                1,
                critic_head_layer_num,
                final_tanh=False,
                activation=activation,
                norm_type=norm_type,
            )
        )

    def _compute_delta_q(self, obs: torch.Tensor, actor_outputs: Dict) -> torch.Tensor:
        """
        Overview:
            Calculate delta_q for every initiator: the mean critic value of its group members when they
            act on the integrated thoughts minus the mean value when they act on their own thoughts.
            Entries of non-initiators stay zero. The whole computation runs without gradient tracking.
        Arguments:
            - obs (:obj:`torch.Tensor`): the observations
            - actor_outputs (:obj:`dict`): the output of actors
        Returns:
            - delta_q (:obj:`torch.Tensor`): the calculated delta_q
        ArgumentsKeys:
            - necessary: ``new_thoughts``, ``old_thoughts``, ``group``, ``is_initiator``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - actor_outputs (:obj:`Dict`): the output of actor network, including ``action``, ``new_thoughts``,
              ``old_thoughts``, ``group``, ``initiator_prob``, ``is_initiator``
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is action size
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> actor_outputs = net.compute_actor(obs)
            >>> net._compute_delta_q(obs, actor_outputs)
        """
        if not self._communication:
            raise NotImplementedError
        assert len(obs.shape) == 3
        new_thoughts = actor_outputs['new_thoughts']
        old_thoughts = actor_outputs['old_thoughts']
        group = actor_outputs['group']
        is_initiator = actor_outputs['is_initiator']
        B, A = new_thoughts.shape[:2]
        curr_delta_q = torch.zeros(B, A).to(new_thoughts.device)
        with torch.no_grad():
            for b in range(B):
                for i in range(A):
                    if not is_initiator[b][i]:
                        continue
                    q_group = []  # Q-values of members acting without communication
                    actual_q_group = []  # Q-values of members acting with communication
                    for j in range(A):
                        if not group[b][i][j]:
                            continue
                        # Without communication the second half of the input is the agent's own thought.
                        before_update_action_j = self.actor.actor_2(
                            torch.cat([old_thoughts[b][j], old_thoughts[b][j]], dim=-1)
                        )
                        after_update_action_j = self.actor.actor_2(
                            torch.cat([old_thoughts[b][j], new_thoughts[b][j]], dim=-1)
                        )
                        before_update_input = torch.cat([obs[b][j], before_update_action_j['pred']], dim=-1)
                        before_update_Q_j = self.critic(before_update_input)['pred']
                        after_update_input = torch.cat([obs[b][j], after_update_action_j['pred']], dim=-1)
                        after_update_Q_j = self.critic(after_update_input)['pred']
                        q_group.append(before_update_Q_j)
                        actual_q_group.append(after_update_Q_j)
                    q_group = torch.stack(q_group)
                    actual_q_group = torch.stack(actual_q_group)
                    curr_delta_q[b][i] = actual_q_group.mean() - q_group.mean()
        return curr_delta_q

    def compute_actor(self, obs: torch.Tensor, get_delta_q: bool = False) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Compute the action according to inputs; optionally call ``_compute_delta_q`` to add delta_q.
        Arguments:
            - obs (:obj:`torch.Tensor`): observation
            - get_delta_q (:obj:`bool`): whether need to get delta_q; ignored when communication is disabled
        Returns:
            - outputs (:obj:`Dict`): the output of actor network and delta_q
        ReturnsKeys:
            - necessary: ``action``
            - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``, ``delta_q``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> net.compute_actor(obs)
        """
        outputs = self.actor(obs)
        if get_delta_q and self._communication:
            outputs['delta_q'] = self._compute_delta_q(obs, outputs)
        return outputs

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Compute the q_value according to inputs.
        Arguments:
            - inputs (:obj:`Dict`): the inputs contain the obs and action
        Returns:
            - outputs (:obj:`Dict`): the output of critic network
        ArgumentsKeys:
            - necessary: ``obs``, ``action``
        ReturnsKeys:
            - necessary: ``q_value``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - q_value (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> action = torch.randn(2, 3, 64)
            >>> net.compute_critic({'obs': obs, 'action': action})
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 2:  # (B, A) -> (B, A, 1)
            action = action.unsqueeze(2)
        critic_input = torch.cat([obs, action], dim=-1)
        q_value = self.critic(critic_input)['pred']
        return {'q_value': q_value}

    def optimize_actor_attention(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Return the actor attention loss: a cross-entropy between delta_q (used as the target) and the
            initiator probability, averaged over all initiator entries.
        Arguments:
            - inputs (:obj:`Dict`): the inputs contain the delta_q, initiator_prob, and is_initiator
        Returns:
            - loss (:obj:`Dict`): the loss of actor attention unit
        ArgumentsKeys:
            - necessary: ``delta_q``, ``initiator_prob``, ``is_initiator``
        ReturnsKeys:
            - necessary: ``loss``
        Shapes:
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - loss (:obj:`torch.Tensor`): :math:`(1)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> delta_q = torch.randn(2, 3)
            >>> initiator_prob = torch.randn(2, 3)
            >>> is_initiator = torch.randn(2, 3)
            >>> net.optimize_actor_attention(
            >>>     {'delta_q': delta_q,
            >>>      'initiator_prob': initiator_prob,
            >>>      'is_initiator': is_initiator})
        """
        if not self._communication:
            raise NotImplementedError
        delta_q = inputs['delta_q'].reshape(-1)
        init_prob = inputs['initiator_prob'].reshape(-1)
        is_init = inputs['is_initiator'].reshape(-1)
        # ``nonzero`` yields a (K, 1) index tensor, so the selected tensors below are (K, 1) as well.
        initiator_idx = is_init.nonzero()
        delta_q = delta_q[initiator_idx]
        init_prob = init_prob[initiator_idx]
        # Squash the probability into [0.05, 0.95] so that both logarithms below stay finite.
        init_prob = 0.9 * init_prob + 0.05

        # judge to avoid nan
        if init_prob.shape == (0, 1):
            # No initiator at all: return a zero loss that can still be back-propagated.
            actor_attention_loss = torch.FloatTensor([-0.0]).to(delta_q.device)
            actor_attention_loss.requires_grad = True
        else:
            actor_attention_loss = -delta_q * \
                torch.log(init_prob) - (1 - delta_q) * torch.log(1 - init_prob)
        return {'loss': actor_attention_loss.mean()}

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str, **kwargs) -> Dict:
        """
        Overview:
            Dispatch to the method named by ``mode``, passing ``inputs`` and any extra keyword arguments.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict]`): the input of the selected method
            - mode (:obj:`str`): one of the names listed in ``self.mode``
        Returns:
            - outputs (:obj:`Dict`): the output of the selected method
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs, **kwargs)
|