from typing import Union, Dict, Optional
from easydict import EasyDict
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder


@MODEL_REGISTRY.register('discrete_maqac')
class DiscreteMAQAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to discrete action Multi-Agent Q-value \
        Actor-CritiC (MAQAC) model. The model is composed of actor and critic, where actor is a MLP network and \
        critic is a MLP network. The actor network is used to predict the action probability distribution, and the \
        critic network is used to predict the Q value of the state-action pair.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            twin_critic: bool = False,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the DiscreteMAQAC Model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation's space.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space.
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
            - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
            - twin_critic (:obj:`bool`): Whether include twin critic.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for actor's nn.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for critic's nn.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` the after \
                ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
                for more details.
        """
        super(DiscreteMAQAC, self).__init__()
        agent_obs_shape: int = squeeze(agent_obs_shape)
        action_shape: int = squeeze(action_shape)
        self.actor = nn.Sequential(
            nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
            DiscreteHead(
                actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
            )
        )

        self.twin_critic = twin_critic
        if self.twin_critic:
            self.critic = nn.ModuleList()
            for _ in range(2):
                self.critic.append(
                    nn.Sequential(
                        nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
                        DiscreteHead(
                            critic_head_hidden_size,
                            action_shape,
                            critic_head_layer_num,
                            activation=activation,
                            norm_type=norm_type
                        )
                    )
                )
        else:
            self.critic = nn.Sequential(
                nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
                DiscreteHead(
                    critic_head_hidden_size,
                    action_shape,
                    critic_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Use observation tensor to predict output, with ``compute_actor`` or ``compute_critic`` mode.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.

            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                whose key-values vary in different forward modes.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> logit = model(data, mode='compute_actor')['logit']
            >>> value = model(data, mode='compute_critic')['q_value']
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict action logits.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                whose key-values vary in different forward modes.
                - logit (:obj:`torch.Tensor`): Action's output logit (real value range), whose shape is \
                    :math:`(B, A, N2)`, where N2 corresponds to ``action_shape``.
                - action_mask (:obj:`torch.Tensor`): Action mask tensor with same size as ``action_shape``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> logit = model.compute_actor(data)['logit']
        """
        action_mask = inputs['obs']['action_mask']
        x = self.actor(inputs['obs']['agent_state'])
        return {'logit': x['logit'], 'action_mask': action_mask}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            use observation tensor to predict Q value.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                whose key-values vary in different values of ``twin_critic``.
                - q_value (:obj:`list`): If ``twin_critic=True``, q_value should be 2 elements, each is the shape of \
                    :math:`(B, A, N2)`, where B is batch size and A is agent num. N2 corresponds to ``action_shape``. \
                    Otherwise, q_value should be ``torch.Tensor``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> value = model.compute_critic(data)['q_value']
        """

        if self.twin_critic:
            x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
        else:
            x = self.critic(inputs['obs']['global_state'])['logit']
        return {'q_value': x}


@MODEL_REGISTRY.register('continuous_maqac')
class ContinuousMAQAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to continuous action Multi-Agent Q-value \
        Actor-CritiC (MAQAC) model. The model is composed of actor and critic, where actor is a MLP network and \
        critic is a MLP network. The actor network is used to predict the action probability distribution, and the \
        critic network is used to predict the Q value of the state-action pair.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str,
            twin_critic: bool = False,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the QAC Model according to arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, )
            - action_space (:obj:`str`): Whether choose ``regression`` or ``reparameterization``.
            - twin_critic (:obj:`bool`): Whether include twin critic.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for actor's nn.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for critic's nn.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` the after \
                ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
                for more details.
        """
        super(ContinuousMAQAC, self).__init__()
        obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization'], self.action_space
        if self.action_space == 'regression':  # DDPG, TD3
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        else:  # SAC
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )
        self.twin_critic = twin_critic
        critic_input_size = global_obs_shape + action_shape
        if self.twin_critic:
            self.critic = nn.ModuleList()
            for _ in range(2):
                self.critic.append(
                    nn.Sequential(
                        nn.Linear(critic_input_size, critic_head_hidden_size), activation,
                        RegressionHead(
                            critic_head_hidden_size,
                            1,
                            critic_head_layer_num,
                            final_tanh=False,
                            activation=activation,
                            norm_type=norm_type
                        )
                    )
                )
        else:
            self.critic = nn.Sequential(
                nn.Linear(critic_input_size, critic_head_hidden_size), activation,
                RegressionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    final_tanh=False,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Use observation and action tensor to predict output in ``compute_actor`` or ``compute_critic`` mode.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.

                - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
                    with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \
                    N3 corresponds to ``action_shape``.
            - mode (:obj:`str`): Name of the forward mode.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward, whose key-values will be different for different \
                ``mode``, ``twin_critic``, ``action_space``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # regression
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     },
            >>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> if action_space == 'regression':
            >>>     action = model(data['obs'], mode='compute_actor')['action']
            >>> elif action_space == 'reparameterization':
            >>>     (mu, sigma) = model(data['obs'], mode='compute_actor')['logit']
            >>> value = model(data, mode='compute_critic')['q_value']
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict action logits.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                    with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                    N0 corresponds to ``agent_obs_shape``.

        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
        ReturnKeys (``action_space == 'regression'``):
            - action (:obj:`torch.Tensor`): Action tensor with same size as ``action_shape``.
        ReturnKeys (``action_space == 'reparameterization'``):
            - logit (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N3)`, where B is batch size and \
                A is agent num. N3 corresponds to ``action_shape``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # 'regression'
            >>> data = {
            >>>     'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> if action_space == 'regression':
            >>>     action = model.compute_actor(data)['action']
            >>> elif action_space == 'reparameterization':
            >>>     (mu, sigma) = model.compute_actor(data)['logit']
        """
        inputs = inputs['agent_state']
        if self.action_space == 'regression':
            x = self.actor(inputs)
            return {'action': x['pred']}
        else:
            x = self.actor(inputs)
            return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor and action tensor to predict Q value.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.

                - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
                    with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \
                    N3 corresponds to ``action_shape``.

        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
        ReturnKeys (``twin_critic=True``):
            - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A)`, where B is batch size and \
                A is agent num.
        ReturnKeys (``twin_critic=False``):
            - q_value (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is agent num.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # 'regression'
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     },
            >>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> value = model.compute_critic(data)['q_value']
        """

        obs, action = inputs['obs']['global_state'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        if self.twin_critic:
            x = [m(x)['pred'] for m in self.critic]
        else:
            x = self.critic(x)['pred']
        return {'q_value': x}