|
from typing import Union, Dict, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ding.utils import squeeze, MODEL_REGISTRY, SequenceType |
|
from ding.torch_utils import MLP |
|
from ding.model.common import RegressionHead |
|
|
|
|
|
class ATOCAttentionUnit(nn.Module):
    """
    Overview:
        Attention unit of the ATOC network. It is a small fully-connected network
        (three linear layers, ReLU after the first two, sigmoid after the last) that
        turns every thought vector into a single value in (0, 1).
    Interface:
        ``__init__``, ``forward``

    .. note::
        "ATOC paper: We use two-layer MLP to implement the attention unit but it is also can be realized by RNN."
    """

    def __init__(self, thought_size: int, embedding_size: int) -> None:
        """
        Overview:
            Build the layers of the attention unit.
        Arguments:
            - thought_size (:obj:`int`): size of the last dimension of the input thoughts
            - embedding_size (:obj:`int`): width of the two hidden layers
        """
        super().__init__()
        self._thought_size = thought_size
        self._hidden_size = embedding_size
        # one scalar per thought: the initiator probability
        self._output_size = 1
        # attribute names and creation order are kept stable so that state-dict keys
        # and seeded initialisation do not change
        self._act1 = nn.ReLU()
        self._fc1 = nn.Linear(self._thought_size, self._hidden_size, bias=True)
        self._fc2 = nn.Linear(self._hidden_size, self._hidden_size, bias=True)
        self._fc3 = nn.Linear(self._hidden_size, self._output_size, bias=True)
        self._act2 = nn.Sigmoid()

    def forward(self, data: Union[Dict, torch.Tensor]) -> torch.Tensor:
        """
        Overview:
            Map the thoughts of the agents to their probability of being an initiator.
        Arguments:
            - data (:obj:`Union[Dict, torch.Tensor]`): the thoughts tensor, or a dict holding it \
                under the key ``thought``
        Returns:
            - ret (:obj:`torch.Tensor`): the initiator probability, i.e. the sigmoid output with \
                its trailing size-1 dimension removed
        Shapes:
            - data['thought']: :math:`(M, B, N)`, M is the num of thoughts to integrate, \
                B is batch_size and N is thought size
            - ret: :math:`(M, B)`
        Examples:
            >>> attention_unit = ATOCAttentionUnit(64, 64)
            >>> thought = torch.randn(2, 3, 64)
            >>> attention_unit(thought)
        """
        thought = data['thought'] if isinstance(data, dict) else data
        hidden = self._act1(self._fc1(thought))
        hidden = self._act1(self._fc2(hidden))
        prob = self._act2(self._fc3(hidden))
        return prob.squeeze(-1)
|
|
|
|
|
class ATOCCommunicationNet(nn.Module):
    """
    Overview:
        Communication channel of ATOC: a bidirectional LSTM that runs over the thoughts of a
        group, so that every output position can depend on all thoughts in the group.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(self, thought_size: int) -> None:
        """
        Overview:
            Build the bidirectional LSTM used for communication.
        Arguments:
            - thought_size (:obj:`int`): size of the input thoughts, must be even

        .. note::

            The LSTM hidden size is half of ``thought_size``: the two directions are
            concatenated, so the output has the same size as the input thought.
        """
        super().__init__()
        assert thought_size % 2 == 0
        self._thought_size = thought_size
        self._comm_hidden_size = thought_size // 2
        self._bi_lstm = nn.LSTM(self._thought_size, self._comm_hidden_size, bidirectional=True)

    def forward(self, data: Union[Dict, torch.Tensor]) -> torch.Tensor:
        """
        Overview:
            Integrate the thoughts of a group with the bidirectional LSTM.
        Arguments:
            - data (:obj:`Union[Dict, torch.Tensor]`): the thoughts tensor, or a dict holding it \
                under the key ``thoughts``
        Returns:
            - out (:obj:`torch.Tensor`): the integrated thoughts (the LSTM output sequence; the \
                final hidden and cell states are discarded)
        Shapes:
            - data['thoughts']: :math:`(M, B, N)`, M is the num of thoughts to integrate, \
                B is batch_size and N is thought size
            - out: :math:`(M, B, N)`
        Examples:
            >>> comm_net = ATOCCommunicationNet(64)
            >>> thoughts = torch.randn(2, 3, 64)
            >>> comm_net(thoughts)
        """
        self._bi_lstm.flatten_parameters()
        thoughts = data['thoughts'] if isinstance(data, dict) else data
        integrated, _ = self._bi_lstm(thoughts)
        return integrated
|
|
|
|
|
class ATOCActorNet(nn.Module):
    """
    Overview:
        The actor network of ATOC. ``actor_1`` encodes observations into thoughts, the optional
        attention unit and communication net exchange thoughts inside groups, and ``actor_2``
        maps the concatenation of original and communicated thoughts to actions.
    Interface:
        ``__init__``, ``forward``

    .. note::
        "ATOC paper: The neural networks use ReLU and batch normalization for some hidden layers."
    """

    def __init__(
            self,
            obs_shape: Union[Tuple, int],
            thought_size: int,
            action_shape: int,
            n_agent: int,
            communication: bool = True,
            agent_per_group: int = 2,
            initiator_threshold: float = 0.5,
            attention_embedding_size: int = 64,
            actor_1_embedding_size: Union[int, None] = None,
            actor_2_embedding_size: Union[int, None] = None,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ):
        """
        Overview:
            Initialize the actor network of ATOC
        Arguments:
            - obs_shape(:obj:`Union[Tuple, int]`): the observation size
            - thought_size (:obj:`int`): the size of thoughts
            - action_shape (:obj:`int`): the action size
            - n_agent (:obj:`int`): the num of agents
            - communication (:obj:`bool`): whether to build and use the attention unit and the \
                communication net, default set to True
            - agent_per_group (:obj:`int`): the num of agent in each group
            - initiator_threshold (:obj:`float`): the threshold of becoming an initiator, default set to 0.5
            - attention_embedding_size (:obj:`int`): the embedding size of attention unit, default set to 64
            - actor_1_embedding_size (:obj:`Union[int, None]`): the embedding size of actor network part1, \
                if None (or 0), then default set to thought size
            - actor_2_embedding_size (:obj:`Union[int, None]`): the embedding size of actor network part2, \
                if None (or 0), then default set to thought size
            - activation (:obj:`Optional[nn.Module]`): activation module passed to both actor parts
            - norm_type (:obj:`Optional[str]`): normalization type passed to both actor parts
        """
        super().__init__()

        self._obs_shape = squeeze(obs_shape)
        self._thought_size = thought_size
        self._act_shape = action_shape
        self._n_agent = n_agent
        self._communication = communication
        self._agent_per_group = agent_per_group
        self._initiator_threshold = initiator_threshold
        # any falsy embedding size falls back to the thought size
        actor_1_embedding_size = actor_1_embedding_size or self._thought_size
        actor_2_embedding_size = actor_2_embedding_size or self._thought_size

        # Part 1: observation -> thought
        self.actor_1 = MLP(
            self._obs_shape,
            actor_1_embedding_size,
            self._thought_size,
            layer_num=2,
            activation=activation,
            norm_type=norm_type
        )

        # Part 2: [own thought, communicated thought] -> action
        self.actor_2 = nn.Sequential(
            nn.Linear(self._thought_size * 2, actor_2_embedding_size),
            activation,
            RegressionHead(
                actor_2_embedding_size, self._act_shape, 2, final_tanh=True, activation=activation, norm_type=norm_type
            ),
        )

        # Communication modules only exist when communication is enabled
        if self._communication:
            self.attention = ATOCAttentionUnit(self._thought_size, attention_embedding_size)
            self.comm_net = ATOCCommunicationNet(self._thought_size)

    def forward(self, obs: torch.Tensor) -> Dict:
        """
        Overview:
            Take the input obs, and calculate the corresponding action, group, initiator_prob, thoughts, etc...
        Arguments:
            - obs (:obj:`torch.Tensor`): the observations of all agents
        Returns:
            - ret (:obj:`Dict`): the returned output, including action, and with communication enabled also \
                group, initiator_prob, is_initiator, new_thoughts and old_thoughts
        ReturnsKeys:
            - necessary: ``action``
            - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, T)`, where T is thought size
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, T)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> actor_net(obs)
        """
        assert len(obs.shape) == 3
        batch_size, n_agent, obs_size = obs.shape
        self._cur_batch_size = batch_size
        assert n_agent == self._n_agent
        assert obs_size == self._obs_shape

        current_thoughts = self.actor_1(obs)

        if not self._communication:
            # without communication the agent's own thought is used twice
            action = self.actor_2(torch.cat([current_thoughts, current_thoughts], dim=-1))['pred']
            return {'action': action}

        # grouping is decided on a detached copy of the thoughts
        old_thoughts = current_thoughts.clone().detach()
        init_prob, is_initiator, group = self._get_initiate_group(old_thoughts)
        new_thoughts = self._get_new_thoughts(current_thoughts, group, is_initiator)
        action = self.actor_2(torch.cat([current_thoughts, new_thoughts], dim=-1))['pred']
        return {
            'action': action,
            'group': group,
            'initiator_prob': init_prob,
            'is_initiator': is_initiator,
            'new_thoughts': new_thoughts,
            'old_thoughts': old_thoughts,
        }

    def _get_initiate_group(self, current_thoughts):
        """
        Overview:
            Calculate the initiator probability, group and is_initiator
        Arguments:
            - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
        Returns:
            - init_prob (:obj:`torch.Tensor`): tensor of initiator probability
            - is_initiator (:obj:`torch.Tensor`): boolean tensor, True where the probability exceeds \
                the initiator threshold
            - group (:obj:`torch.Tensor`): ``group[b][i][j]`` is 1 when agent ``i`` is an initiator and \
                agent ``j`` is among its ``agent_per_group`` nearest agents in thought space, else 0
        Shapes:
            - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
            - init_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> current_thoughts = torch.randn(2, 3, 64)
            >>> actor_net._get_initiate_group(current_thoughts)
        """
        if not self._communication:
            raise NotImplementedError
        init_prob = self.attention(current_thoughts)
        is_initiator = init_prob > self._initiator_threshold
        batch_size, n_agent = init_prob.shape[:2]

        # pairwise squared distances via |x|^2 - 2 x.y + |y|^2
        gram = torch.bmm(current_thoughts, current_thoughts.transpose(1, 2))
        sq_norm = gram.diagonal(0, 1, 2)
        thought_dists = sq_norm.unsqueeze(1) - 2 * gram + sq_norm.unsqueeze(2)

        group = torch.zeros(batch_size, n_agent, n_agent, device=init_prob.device)
        # nonzero() yields the (b, i) pairs of initiators in row-major order
        for b, i in is_initiator.nonzero().tolist():
            nearest = thought_dists[b][i].argsort()[:self._agent_per_group]
            group[b][i][nearest] = 1
        return init_prob, is_initiator, group

    def _get_new_thoughts(self, current_thoughts, group, is_initiator):
        """
        Overview:
            Calculate the new thoughts according to current thoughts, group and is_initiator
        Arguments:
            - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
            - group (:obj:`torch.Tensor`): tensor of group
            - is_initiator (:obj:`torch.Tensor`): tensor of is initiator
        Returns:
            - new_thoughts (:obj:`torch.Tensor`): a detached copy of the current thoughts in which the \
                thoughts of grouped agents are replaced by the communication net output
        Shapes:
            - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
            - group: (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> current_thoughts = torch.randn(2, 3, 64)
            >>> group = torch.randn(2, 3, 3)
            >>> is_initiator = torch.randn(2, 3)
            >>> actor_net._get_new_thoughts(current_thoughts, group, is_initiator)
        """
        if not self._communication:
            raise NotImplementedError
        new_thoughts = current_thoughts.detach().clone()
        initiators = torch.nonzero(is_initiator).tolist()
        if not initiators:
            # nobody initiates communication, thoughts stay as they are
            return new_thoughts

        # member indices of every initiator's group, in ascending agent order
        members_per_group = [torch.nonzero(group[b][i]).flatten().tolist() for b, i in initiators]

        # gather: (group member, initiator, thought); all groups must have the same size
        thoughts_to_commute = torch.stack(
            [
                torch.stack([new_thoughts[b][j] for j in members], dim=0)
                for (b, _), members in zip(initiators, members_per_group)
            ],
            dim=1,
        )
        integrated_thoughts = self.comm_net(thoughts_to_commute)

        # scatter back; an agent that belongs to several groups keeps the result of the last one
        for col, ((b, _), members) in enumerate(zip(initiators, members_per_group)):
            for row, j in enumerate(members):
                new_thoughts[b][j] = integrated_thoughts[row][col]
        return new_thoughts
|
|
|
|
|
@MODEL_REGISTRY.register('atoc')
class ATOC(nn.Module):
    """
    Overview:
        The QAC network of ATOC, a kind of extension of DDPG for MARL.
        Learning Attentional Communication for Multi-Agent Cooperation
        https://arxiv.org/abs/1805.07733
    Interface:
        ``__init__``, ``forward``, ``compute_critic``, ``compute_actor``, ``optimize_actor_attention``
    """
    # names of the methods that ``forward`` is allowed to dispatch to
    mode = ['compute_actor', 'compute_critic', 'optimize_actor_attention']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            thought_size: int,
            n_agent: int,
            communication: bool = True,
            agent_per_group: int = 2,
            actor_1_embedding_size: Union[int, None] = None,
            actor_2_embedding_size: Union[int, None] = None,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 2,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the ATOC QAC network
        Arguments:
            - obs_shape(:obj:`Union[int, SequenceType]`): the observation space shape
            - action_shape (:obj:`Union[int, SequenceType]`): the action space shape
            - thought_size (:obj:`int`): the size of thoughts
            - n_agent (:obj:`int`): the num of agents
            - communication (:obj:`bool`): whether the actor uses attentional communication
            - agent_per_group (:obj:`int`): the num of agent in each group
            - actor_1_embedding_size (:obj:`Union[int, None]`): embedding size of actor part 1
            - actor_2_embedding_size (:obj:`Union[int, None]`): embedding size of actor part 2
            - critic_head_hidden_size (:obj:`int`): hidden size of the critic
            - critic_head_layer_num (:obj:`int`): layer num of the critic regression head
            - activation (:obj:`Optional[nn.Module]`): activation module used in the critic
            - norm_type (:obj:`Optional[str]`): normalization type used in the critic head
        """
        super(ATOC, self).__init__()
        self._communication = communication

        # The signature allows shapes given as sequences, but the critic input width needs plain
        # numbers. Normalise both shapes once with ``squeeze`` (the same helper the actor applies
        # to ``obs_shape``) and use the result everywhere; int inputs are expected to pass through
        # unchanged.
        obs_size = squeeze(obs_shape)
        action_size = squeeze(action_shape)

        # NOTE(review): ``activation`` and ``norm_type`` are not forwarded to the actor, which
        # therefore always uses its own defaults. Kept as is, because forwarding them would change
        # the architecture (and checkpoints) of existing configurations - confirm the intent.
        self.actor = ATOCActorNet(
            obs_size,
            thought_size,
            action_size,
            n_agent,
            communication,
            agent_per_group,
            actor_1_embedding_size=actor_1_embedding_size,
            actor_2_embedding_size=actor_2_embedding_size
        )
        self.critic = nn.Sequential(
            nn.Linear(obs_size + action_size, critic_head_hidden_size),
            activation,
            RegressionHead(
                critic_head_hidden_size,
                1,
                critic_head_layer_num,
                final_tanh=False,
                activation=activation,
                norm_type=norm_type,
            ),
        )

    def _compute_delta_q(self, obs: torch.Tensor, actor_outputs: Dict) -> torch.Tensor:
        """
        Overview:
            Calculate, for every initiator, how much the mean Q value of its group changes when the
            group members act on the communicated thoughts instead of their own thoughts only.
            Entries of agents that are not initiators stay 0. Runs without gradient tracking.
        Arguments:
            - obs (:obj:`torch.Tensor`): the observations
            - actor_outputs (:obj:`dict`): the output of actors
        Returns:
            - delta_q (:obj:`torch.Tensor`): the calculated delta_q
        ArgumentsKeys:
            - necessary: ``new_thoughts``, ``old_thoughts``, ``group``, ``is_initiator``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - actor_outputs (:obj:`Dict`): the output of actor network, including ``action``, ``new_thoughts``, \
                ``old_thoughts``, ``group``, ``initiator_prob``, ``is_initiator``
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is action size
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> actor_outputs = net.compute_actor(obs)
            >>> net._compute_delta_q(obs, actor_outputs)
        """
        if not self._communication:
            raise NotImplementedError
        assert len(obs.shape) == 3
        new_thoughts = actor_outputs['new_thoughts']
        old_thoughts = actor_outputs['old_thoughts']
        group = actor_outputs['group']
        is_initiator = actor_outputs['is_initiator']
        B, A = new_thoughts.shape[:2]
        curr_delta_q = torch.zeros(B, A, device=new_thoughts.device)
        with torch.no_grad():
            # nonzero() lists the (b, i) initiator coordinates in row-major order
            for b, i in torch.nonzero(is_initiator).tolist():
                q_group = []
                actual_q_group = []
                for j in torch.nonzero(group[b][i]).flatten().tolist():
                    old_j, new_j, obs_j = old_thoughts[b][j], new_thoughts[b][j], obs[b][j]
                    # action agent j would take without / with the communicated thought
                    before_action = self.actor.actor_2(torch.cat([old_j, old_j], dim=-1))['pred']
                    after_action = self.actor.actor_2(torch.cat([old_j, new_j], dim=-1))['pred']
                    q_group.append(self.critic(torch.cat([obs_j, before_action], dim=-1))['pred'])
                    actual_q_group.append(self.critic(torch.cat([obs_j, after_action], dim=-1))['pred'])
                q_group = torch.stack(q_group)
                actual_q_group = torch.stack(actual_q_group)
                curr_delta_q[b][i] = actual_q_group.mean() - q_group.mean()
        return curr_delta_q

    def compute_actor(self, obs: torch.Tensor, get_delta_q: bool = False) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Compute the action according to inputs; optionally call ``_compute_delta_q`` and add
            its result to the outputs.
        Arguments:
            - obs (:obj:`torch.Tensor`): observation
            - get_delta_q (:obj:`bool`) : whether need to get delta_q (ignored when communication is disabled)
        Returns:
            - outputs (:obj:`Dict`): the output of actor network and delta_q
        ReturnsKeys:
            - necessary: ``action``
            - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``, ``delta_q``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> net.compute_actor(obs)
        """
        outputs = self.actor(obs)
        if get_delta_q and self._communication:
            outputs['delta_q'] = self._compute_delta_q(obs, outputs)
        return outputs

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Compute the q_value according to inputs
        Arguments:
            - inputs (:obj:`Dict`): the inputs contain the obs and action
        Returns:
            - outputs (:obj:`Dict`): the output of critic network
        ArgumentsKeys:
            - necessary: ``obs``, ``action``
        ReturnsKeys:
            - necessary: ``q_value``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - q_value (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> action = torch.randn(2, 3, 64)
            >>> net.compute_critic({'obs': obs, 'action': action})
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 2:
            # a 2-dim action gets a trailing dimension so it can be concatenated with obs
            action = action.unsqueeze(2)
        q_value = self.critic(torch.cat([obs, action], dim=-1))['pred']
        return {'q_value': q_value}

    def optimize_actor_attention(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Return the actor attention loss: a cross-entropy style loss between ``delta_q`` and the
            (rescaled) initiator probability, averaged over the initiators only.
        Arguments:
            - inputs (:obj:`Dict`): the inputs contain the delta_q, initiator_prob, and is_initiator
        Returns:
            - loss (:obj:`Dict`): the loss of actor attention unit
        ArgumentsKeys:
            - necessary: ``delta_q``, ``initiator_prob``, ``is_initiator``
        ReturnsKeys:
            - necessary: ``loss``
        Shapes:
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - loss (:obj:`torch.Tensor`): scalar
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> delta_q = torch.rand(2, 3)
            >>> initiator_prob = torch.rand(2, 3)
            >>> is_initiator = initiator_prob > 0.5
            >>> net.optimize_actor_attention(
            >>>     {'delta_q': delta_q,
            >>>      'initiator_prob': initiator_prob,
            >>>      'is_initiator': is_initiator})
        """
        if not self._communication:
            raise NotImplementedError
        delta_q = inputs['delta_q'].reshape(-1)
        init_prob = inputs['initiator_prob'].reshape(-1)
        is_init = inputs['is_initiator'].reshape(-1)

        # keep only the entries that belong to initiators; the index has shape (K, 1)
        selected = is_init.nonzero()
        delta_q = delta_q[selected]
        init_prob = init_prob[selected]
        # a probability in [0, 1] is mapped into [0.05, 0.95], which keeps both logs finite
        init_prob = 0.9 * init_prob + 0.05

        if init_prob.numel() == 0:
            # no initiator at all: return a zero loss that still accepts backward()
            actor_attention_loss = torch.tensor(
                [-0.0], dtype=torch.float32, device=delta_q.device, requires_grad=True
            )
        else:
            actor_attention_loss = -delta_q * torch.log(init_prob) - (1 - delta_q) * torch.log(1 - init_prob)
        return {'loss': actor_attention_loss.mean()}

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str, **kwargs) -> Dict:
        """
        Overview:
            Dispatch to the method named by ``mode``.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict]`): passed to the selected method as its first argument
            - mode (:obj:`str`): one of the names listed in ``self.mode``
            - kwargs: extra keyword arguments passed to the selected method
        Returns:
            - outputs (:obj:`Dict`): whatever the selected method returns
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs, **kwargs)
|
|