import torch.nn as nn
import torch

import tools.utils as utils
import agent.dreamer_utils as common
from collections import OrderedDict
import numpy as np

from tools.genrl_utils import *

def stop_gradient(x):
  return x.detach()

Module = nn.Module 

def env_reward(agent, seq):
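  """Default acting reward: the mean of the world model's reward head on imagined features."""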
  return agent.wm.heads['reward'](seq['feat']).mean

class DreamerAgent(Module):
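  """Dreamer-style agent: a learned world model plus an actor-critic trained on imagined rollouts."""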

  def __init__(self, 
                name, cfg, obs_space, act_spec, **kwargs):
    super().__init__()
    self.name = name
    self.cfg = cfg
    self.cfg.update(**kwargs)
    self.obs_space = obs_space
    self.act_spec = act_spec
    self._use_amp = (cfg.precision == 16)
    self.device = cfg.device
    self.act_dim = act_spec.shape[0]
    self.wm = WorldModel(cfg, obs_space, self.act_dim,)
    self.instantiate_acting_behavior()

    self.to(cfg.device)
    self.requires_grad_(requires_grad=False)

  def instantiate_acting_behavior(self,):
    self._acting_behavior = ActorCritic(self.cfg, self.act_spec, self.wm.inp_size).to(self.device)
    
  def act(self, obs, meta, step, eval_mode, state):
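    """Compute an action for one environment step.

    Encodes the observation, advances the recurrent latent state with an
    observation step, and queries the acting behavior's actor: the mean action
    in eval mode, a sampled action otherwise.
    """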
    if self.cfg.only_random_actions:
      return np.random.uniform(-1, 1, self.act_dim,).astype(self.act_spec.dtype), (None, None)
    obs = {k : torch.as_tensor(np.copy(v), device=self.device).unsqueeze(0) for k, v in obs.items()}
    if state is None:
      latent = self.wm.rssm.initial(len(obs['reward']))
      action = torch.zeros((len(obs['reward']),) + self.act_spec.shape, device=self.device)
    else:
      latent, action = state
    embed = self.wm.encoder(self.wm.preprocess(obs))
    should_sample = (not eval_mode) or (not self.cfg.eval_state_mean)
    latent, _ = self.wm.rssm.obs_step(latent, action, embed, obs['is_first'], should_sample)
    feat = self.wm.rssm.get_feat(latent)
    if eval_mode:
      actor = self._acting_behavior.actor(feat)
      try:
        action = actor.mean
      except (AttributeError, NotImplementedError):
        action = actor._mean
    else:
      actor = self._acting_behavior.actor(feat)
      action = actor.sample()
    new_state = (latent, action)
    return action.cpu().numpy()[0], new_state

  def update_wm(self, data, step):
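    """Update the world model on a batch of replay data; returns its state, outputs, and metrics."""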
    metrics = {}
    state, outputs, mets = self.wm.update(data, state=None)
    outputs['is_terminal'] = data['is_terminal']
    metrics.update(mets)
    return state, outputs, metrics

  def update_acting_behavior(self, state=None, outputs=None, metrics=None, data=None, reward_fn=None):
    """Update the acting behavior on imagined rollouts starting from posterior states."""
    metrics = {} if metrics is None else metrics
    if self.cfg.only_random_actions:
      return {}, metrics
    if outputs is not None:
      post = outputs['post']
      is_terminal = outputs['is_terminal']
    else:
      data = self.wm.preprocess(data)
      embed = self.wm.encoder(data)
      post, _ = self.wm.rssm.observe(
          embed, data['action'], data['is_first'])
      is_terminal = data['is_terminal']
    #
    start = {k: stop_gradient(v) for k,v in post.items()}
    if reward_fn is None:
      acting_reward_fn = lambda seq: globals()[self.cfg.acting_reward_fn](self, seq)
    else:
      acting_reward_fn = lambda seq: reward_fn(self, seq)
    metrics.update(self._acting_behavior.update(self.wm, start, is_terminal, acting_reward_fn))
    return start, metrics

  def update(self, data, step):
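    """Joint update: first the world model, then the acting behavior from the resulting posteriors."""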
    state, outputs, metrics = self.update_wm(data, step)
    start, metrics = self.update_acting_behavior(state, outputs, metrics, data)
    return state, metrics

  def report(self, data):
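    """Build logging artifacts: open-loop video predictions per image key, plus any
    additional report functions named in the config."""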
    report = {}
    data = self.wm.preprocess(data)
    for key in self.wm.heads['decoder'].cnn_keys:
      name = key.replace('/', '_')
      report[f'openl_{name}'] = self.wm.video_pred(data, key)
    for fn in getattr(self.cfg, 'additional_report_fns', []):
      call_fn = globals()[fn]
      additional_report = call_fn(self, data)
      report.update(additional_report)
    return report

  def get_meta_specs(self):
    return tuple()

  def init_meta(self):
    return OrderedDict()

  def update_meta(self, meta, global_step, time_step, finetune=False):
    return meta

class WorldModel(Module):
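  """World model: encoder, ensemble RSSM dynamics, and decoder/reward/(optional) discount heads trained jointly."""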
  def __init__(self, config, obs_space, act_dim,):
    super().__init__()
    shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
    self.shapes = shapes
    self.cfg = config
    self.device = config.device
    self.encoder = common.Encoder(shapes, **config.encoder)
    # Computing embed dim
    with torch.no_grad():
      zeros = {k: torch.zeros( (1,) + v) for k, v in shapes.items()}
      outs = self.encoder(zeros)
      embed_dim = outs.shape[1]
    self.embed_dim = embed_dim
    self.rssm = common.EnsembleRSSM(**config.rssm, action_dim=act_dim, embed_dim=embed_dim, device=self.device,)
    self.heads = {}
    self._use_amp = (config.precision == 16)
    self.inp_size = self.rssm.get_feat_size()
    self.decoder_input_fn = getattr(self.rssm, f'get_{config.decoder_inputs}')
    self.decoder_input_size = getattr(self.rssm, f'get_{config.decoder_inputs}_size')()
    self.heads['decoder'] = common.Decoder(shapes, **config.decoder, embed_dim=self.decoder_input_size, image_dist=config.image_dist)
    self.heads['reward'] = common.MLP(self.inp_size, (1,), **config.reward_head)
    # zero init
    with torch.no_grad():
      for p in self.heads['reward']._out.parameters():
        p.data = p.data * 0
    #
    if config.pred_discount:
      self.heads['discount'] = common.MLP(self.inp_size, (1,), **config.discount_head)
    for name in config.grad_heads:
      assert name in self.heads, name
    self.grad_heads = config.grad_heads
    self.heads = nn.ModuleDict(self.heads)
    self.model_opt = common.Optimizer('model', self.parameters(), **config.model_opt, use_amp=self._use_amp)
    self.e2e_update_fns = {}
    self.detached_update_fns = {}
    self.eval()

  def add_module_to_update(self, name, module, update_fn, detached=False):
    self.add_module(name, module)
    if detached:
      self.detached_update_fns[name] = update_fn
    else:
      self.e2e_update_fns[name] = update_fn
    self.model_opt = common.Optimizer('model', self.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)

  def update(self, data, state=None):
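    """One gradient step on the world-model loss.

    Optionally freezes the decoder or the full posterior as configured, then
    runs any additional end-to-end and detached module updates.
    """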
    self.train()
    with common.RequiresGrad(self):
      with torch.cuda.amp.autocast(enabled=self._use_amp):
        if getattr(self.cfg, "freeze_decoder", False):
          self.heads['decoder'].requires_grad_(False)
        if getattr(self.cfg, "freeze_post", False) or getattr(self.cfg, "freeze_model", False):
          self.heads['decoder'].requires_grad_(False)
          self.encoder.requires_grad_(False)
          # Updating only prior
          self.grad_heads = []
          self.rssm.requires_grad_(False)
          if not getattr(self.cfg, "freeze_model", False):
            self.rssm._ensemble_img_out.requires_grad_(True)
            self.rssm._ensemble_img_dist.requires_grad_(True)
        model_loss, state, outputs, metrics = self.loss(data, state)
        model_loss, metrics = self.update_additional_e2e_modules(data, outputs, model_loss, metrics)
      metrics.update(self.model_opt(model_loss, self.parameters())) 
    if len(self.detached_update_fns) > 0:
      detached_loss, metrics = self.update_additional_detached_modules(data, outputs, metrics)
    self.eval()
    return state, outputs, metrics

  def update_additional_detached_modules(self, data, outputs, metrics):
    # additional detached losses
    detached_loss = 0
    for k in self.detached_update_fns:
      detached_module = getattr(self, k)
      with common.RequiresGrad(detached_module):
        with torch.cuda.amp.autocast(enabled=self._use_amp):
          add_loss, add_metrics = self.detached_update_fns[k](self, k, data, outputs, metrics)
          metrics.update(add_metrics)
          opt_metrics = self.model_opt(add_loss, detached_module.parameters())
          metrics.update({ f'{k}_{m}' : opt_metrics[m] for m in opt_metrics})
    return detached_loss, metrics

  def update_additional_e2e_modules(self, data, outputs, model_loss, metrics):
    # additional e2e losses
    for k in self.e2e_update_fns:
      add_loss, add_metrics = self.e2e_update_fns[k](self, k, data, outputs, metrics)
      model_loss += add_loss
      metrics.update(add_metrics)
    return model_loss, metrics

  def observe_data(self, data, state=None):
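    """Encode a batch and run the RSSM posterior over it; returns embeddings, posterior/prior states, and the mean KL."""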
    data = self.preprocess(data)
    embed = self.encoder(data)
    post, prior = self.rssm.observe(
        embed, data['action'], data['is_first'], state)
    kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
    outs = dict(embed=embed, post=post, prior=prior, is_terminal=data['is_terminal'])
    return outs, { 'model_kl' : kl_value.mean() }

  def loss(self, data, state=None):
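    """World-model loss: posterior/prior KL plus negative log-likelihood of each head, scaled by cfg.loss_scales."""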
    data = self.preprocess(data)
    embed = self.encoder(data)
    post, prior = self.rssm.observe(
        embed, data['action'], data['is_first'], state)
    kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
    assert len(kl_loss.shape) == 0 or (len(kl_loss.shape) == 1 and kl_loss.shape[0] == 1), kl_loss.shape
    likes = {}
    losses = {'kl': kl_loss}
    feat = self.rssm.get_feat(post)
    for name, head in self.heads.items():
      grad_head = (name in self.grad_heads)
      if name == 'decoder':
        inp = self.decoder_input_fn(post)
      else:
        inp = feat
      inp = inp if grad_head else stop_gradient(inp)
      out = head(inp)
      dists = out if isinstance(out, dict) else {name: out}
      for key, dist in dists.items():
        like = dist.log_prob(data[key]) 
        likes[key] = like
        losses[key] = -like.mean()
    model_loss = sum(
        self.cfg.loss_scales.get(k, 1.0) * v for k, v in losses.items())
    outs = dict(
        embed=embed, feat=feat, post=post,
        prior=prior, likes=likes, kl=kl_value)
    metrics = {f'{name}_loss': value for name, value in losses.items()}
    metrics['model_kl'] = kl_value.mean()
    metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
    metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
    last_state = {k: v[:, -1] for k, v in post.items()}
    return model_loss, last_state, outs, metrics

  def imagine(self, policy, start, is_terminal, horizon, task_cond=None, eval_policy=False):
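    """Roll out the policy in latent space for `horizon` steps from the flattened start states.

    Returns a sequence dict of stacked states, features, actions, discounts, and
    cumulative weights with shape (T, B, ...).
    """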
    flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    start['feat'] = self.rssm.get_feat(start)
    inp = start['feat'] if task_cond is None else torch.cat([start['feat'], task_cond], dim=-1)
    policy_dist = policy(inp)
    start['action'] = torch.zeros_like(policy_dist.sample(), device=self.device)
    seq = {k: [v] for k, v in start.items()}
    if task_cond is not None: seq['task'] = [task_cond]
    for _ in range(horizon):
      inp = seq['feat'][-1] if task_cond is None else torch.cat([seq['feat'][-1], task_cond], dim=-1)
      policy_dist = policy(stop_gradient(inp))
      action = policy_dist.sample() if not eval_policy else policy_dist.mean
      state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
      feat = self.rssm.get_feat(state)
      for key, value in {**state, 'action': action, 'feat': feat}.items():
        seq[key].append(value)
      if task_cond is not None: seq['task'].append(task_cond)
    # shape will be (T, B, *DIMS)
    seq = {k: torch.stack(v, 0) for k, v in seq.items()}
    if 'discount' in self.heads:
      disc = self.heads['discount'](seq['feat']).mean
      if is_terminal is not None:
        # Override discount prediction for the first step with the true
        # discount factor from the replay buffer.
        true_first = 1.0 - flatten(is_terminal) 
        disc = torch.cat([true_first[None], disc[1:]], 0)
    else:
      disc = torch.ones(list(seq['feat'].shape[:-1]) + [1], device=self.device)
    seq['discount'] = disc * self.cfg.discount
    # Shift discount factors because they imply whether the following state
    # will be valid, not whether the current state is valid.
    seq['weight'] = torch.cumprod(torch.cat([torch.ones_like(disc[:1], device=self.device), disc[:-1]], 0), 0)
    return seq

  def preprocess(self, obs):
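    """Normalize image observations to [-0.5, 0.5], transform rewards as configured, and derive discounts from terminal flags."""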
    obs = obs.copy()
    for key, value in obs.items():
      if key.startswith('log_'):
        continue
      if value.dtype in [np.uint8, torch.uint8]:
        value = value / 255.0 - 0.5 
      obs[key] = value
    obs['reward'] = {
        'identity': nn.Identity(),
        'sign': torch.sign,
        'tanh': torch.tanh,
    }[self.cfg.clip_rewards](obs['reward'])
    obs['discount'] = (1.0 - obs['is_terminal'].float())
    if len(obs['discount'].shape) < len(obs['reward'].shape):
      obs['discount'] = obs['discount'].unsqueeze(-1)
    return obs

  def video_pred(self, data, key, nvid=8):
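    """Open-loop video prediction for logging: reconstruct the first 5 frames from posterior states,
    imagine the remainder from the recorded actions, and stack truth, model, and error frames."""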
    decoder = self.heads['decoder'] # B, T, C, H, W
    truth = data[key][:nvid] + 0.5
    embed = self.encoder(data)
    states, _ = self.rssm.observe(
        embed[:nvid, :5], data['action'][:nvid, :5], data['is_first'][:nvid, :5])
    recon = decoder(self.decoder_input_fn(states))[key].mean[:nvid] # mode
    init = {k: v[:, -1] for k, v in states.items()}
    prior = self.rssm.imagine(data['action'][:nvid, 5:], init)
    prior_recon = decoder(self.decoder_input_fn(prior))[key].mean # mode
    model = torch.clip(torch.cat([recon[:, :5] + 0.5, prior_recon + 0.5], 1), 0, 1)
    error = (model - truth + 1) / 2
    video = torch.cat([truth, model, error], 3)
    return video

class ActorCritic(Module):
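  """Actor-critic trained on imagined latent trajectories, with optional slow target critic and reward normalization/EMA scaling."""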
  def __init__(self, config, act_spec, feat_size, name=''):
    super().__init__()
    self.name = name
    self.cfg = config
    self.act_spec = act_spec
    self._use_amp = (config.precision == 16)
    self.device = config.device
    
    if getattr(self.cfg, 'discrete_actions', False):
      self.cfg.actor.dist = 'onehot'

    self.actor_grad = getattr(self.cfg, f'{self.name}_actor_grad'.strip('_'))
    
    inp_size = feat_size
    self.actor = common.MLP(inp_size, act_spec.shape[0], **self.cfg.actor)
    self.critic = common.MLP(inp_size, (1,), **self.cfg.critic)
    if self.cfg.slow_target:
      self._target_critic = common.MLP(inp_size, (1,), **self.cfg.critic)
      self._updates = 0
    else:
      self._target_critic = self.critic
    self.actor_opt = common.Optimizer('actor', self.actor.parameters(), **self.cfg.actor_opt, use_amp=self._use_amp)
    self.critic_opt = common.Optimizer('critic', self.critic.parameters(), **self.cfg.critic_opt, use_amp=self._use_amp)
    
    if self.cfg.reward_ema:
      # register ema_vals as a buffer so it is saved and restored with the module's state dict
      self.register_buffer("ema_vals", torch.zeros((2,)).to(self.device))
      self.reward_ema = common.RewardEMA(device=self.device)
      self.rewnorm = common.StreamNorm(momentum=1, scale=1.0, device=self.device)
    else:
      self.rewnorm = common.StreamNorm(**self.cfg.reward_norm, device=self.device)

    # zero init
    with torch.no_grad():
      for p in self.critic._out.parameters():
        p.data = p.data * 0
    # hard copy critic initial params
    for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
      d.data = s.data
    #


  def update(self, world_model, start, is_terminal, reward_fn):
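    """One actor update and one critic update on an imagined rollout scored by `reward_fn`."""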
    metrics = {}
    hor = self.cfg.imag_horizon
    # The weights are derived from the is_terminal flags of the imagination start states.
    # Technically, they should multiply the losses from the second trajectory
    # step onwards, which is the first imagined step. However, we are not
    # training the action that led into the first step anyway, so we can use
    # them to scale the whole sequence.
    with common.RequiresGrad(self.actor):
      with torch.cuda.amp.autocast(enabled=self._use_amp):
        seq = world_model.imagine(self.actor, start, is_terminal, hor)
        reward = reward_fn(seq)
        seq['reward'], mets1 = self.rewnorm(reward)
        mets1 = {f'reward_{k}': v for k, v in mets1.items()}
        target, mets2, baseline = self.target(seq)
        actor_loss, mets3 = self.actor_loss(seq, target, baseline)
      metrics.update(self.actor_opt(actor_loss, self.actor.parameters()))
    with common.RequiresGrad(self.critic):
      with torch.cuda.amp.autocast(enabled=self._use_amp):
        seq = {k: stop_gradient(v) for k,v in seq.items()}
        critic_loss, mets4 = self.critic_loss(seq, target)
      metrics.update(self.critic_opt(critic_loss, self.critic.parameters()))
    metrics.update(**mets1, **mets2, **mets3, **mets4)
    self.update_slow_target()  # Variables exist after first forward pass.
    return { f'{self.name}_{k}'.strip('_') : v for k,v in metrics.items() }

  def actor_loss(self, seq, target, baseline):
    # Two state-actions are lost at the end of the trajectory, one for the bootstrap
    # value prediction and one because the corresponding action does not lead
    # anywhere anymore. One target is lost at the start of the trajectory
    # because the initial state comes from the replay buffer.
    policy = self.actor(stop_gradient(seq['feat'][:-2]))  # actions are the ones in [1:-1]

    metrics = {}
    if self.cfg.reward_ema:
      offset, scale = self.reward_ema(target, self.ema_vals)
      normed_target = (target - offset) / scale
      normed_baseline = (baseline - offset) / scale
      # adv = normed_target - normed_baseline
      metrics['normed_target_mean'] = normed_target.mean()
      metrics['normed_target_std'] = normed_target.std()
      metrics["reward_ema_005"] = self.ema_vals[0]
      metrics["reward_ema_095"] = self.ema_vals[1]
    else:
      normed_target = target
      normed_baseline = baseline
    
    if self.actor_grad == 'dynamics':
      objective = normed_target[1:]
    elif self.actor_grad == 'reinforce':
      advantage = normed_target[1:] - normed_baseline[1:]
      objective = policy.log_prob(stop_gradient(seq['action'][1:-1]))[:,:,None] * advantage
    else:
      raise NotImplementedError(self.actor_grad)
    
    ent = policy.entropy()[:,:,None]
    ent_scale = self.cfg.actor_ent
    objective += ent_scale * ent
    metrics['actor_ent'] = ent.mean()
    metrics['actor_ent_scale'] = ent_scale
    
    weight = stop_gradient(seq['weight'])
    actor_loss = -(weight[:-2] * objective).mean() 
    return actor_loss, metrics

  def critic_loss(self, seq, target):
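    """Negative log-likelihood of the lambda-return targets under the critic, weighted by the trajectory weights."""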
    feat = seq['feat'][:-1]
    target = stop_gradient(target)
    weight = stop_gradient(seq['weight'])
    dist = self.critic(feat)
    critic_loss = -(dist.log_prob(target)[:,:,None] * weight[:-1]).mean()
    metrics = {'critic': dist.mean.mean() } 
    return critic_loss, metrics

  def target(self, seq):
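    """Compute lambda-returns from predicted rewards, discounts, and the (slow) target critic; also returns the baseline values."""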
    reward = seq['reward'] 
    disc = seq['discount'] 
    value = self._target_critic(seq['feat']).mean 
    # Skipping last time step because it is used for bootstrapping.
    target = common.lambda_return(
        reward[:-1], value[:-1], disc[:-1],
        bootstrap=value[-1],
        lambda_=self.cfg.discount_lambda,
        axis=0)
    metrics = {}
    metrics['critic_slow'] = value.mean()
    metrics['critic_target'] = target.mean()
    return target, metrics, value[:-1]

  def update_slow_target(self):
    if self.cfg.slow_target:
      if self._updates % self.cfg.slow_target_update == 0:
        mix = 1.0 if self._updates == 0 else float(
            self.cfg.slow_target_fraction)
        for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
          d.data = mix * s.data + (1 - mix) * d.data
      self._updates += 1