import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from PIL import Image
from einops import rearrange
import math 
class GaussianDiffusion:
  """
  DDPM-style Gaussian diffusion around a noise-prediction model.

  Holds the beta / alpha / alpha_hat schedules, produces noisy training inputs
  (`apply_noise`), runs ancestral sampling (`sample`) and computes a validation
  loss (`validate`).
  """

  def __init__(self, model, noise_steps, beta_0, beta_T, image_size, channels=3, schedule="linear"):
    """
    suggested betas for:
      * linear schedule: 1e-4, 0.02

    model: the noise-prediction network (nn.Module); called as model(x, t) or
      model(x, t, cond). `sample` with cond=None also reads model.num_classes.
    noise_steps: the number of diffusion steps T (int)
    beta_0: the initial value of beta (float); only used by the "linear" schedule
    beta_T: the final value of beta (float); only used by the "linear" schedule
    image_size: spatial size of sampled images, unpacked into torch.randn (int, int)
    channels: number of channels of sampled images (int)
    schedule: "linear", "cosine" or "sigmoid"
    """
    self.device = 'cpu'
    self.channels = channels

    self.model = model
    self.noise_steps = noise_steps
    self.beta_0 = beta_0
    self.beta_T = beta_T
    self.image_size = image_size

    self.betas = self.beta_schedule(schedule=schedule)
    self.alphas = 1.0 - self.betas
    # cumulative product of alphas, so x_t can be drawn from x_0 in one step
    # instead of applying t sequential noising steps
    self.alpha_hat = torch.cumprod(self.alphas, dim=0)

  def beta_schedule(self, schedule="cosine"):
    """
    Return a 1-D tensor of `noise_steps` betas for the named schedule.

    Raises ValueError if `schedule` is not "linear", "cosine" or "sigmoid".
    """
    if schedule == "linear":
      return torch.linspace(self.beta_0, self.beta_T, self.noise_steps).to(self.device)
    if schedule == "cosine":
      return self.betas_for_cosine(self.noise_steps)
    if schedule == "sigmoid":
      return self.betas_for_sigmoid(self.noise_steps)
    raise ValueError(f"unknown beta schedule {schedule!r}; expected 'linear', 'cosine' or 'sigmoid'")

  @staticmethod
  def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)) (numpy-based, works on scalars and arrays)."""
    return 1 / (1 + np.exp(-x))

  def betas_for_sigmoid(self, num_diffusion_timesteps, start=-3, end=3, tau=1.0, clip_min=1e-9):
    """
    Betas derived from a sigmoid curve evaluated on [start, end] / tau.

    Each step's curve value is min-max normalized against the curve at `start`
    and `end`, scaled by 0.2 and clipped to [clip_min, 0.2]; the list is then
    reversed. Returns a float32 tensor on self.device.
    """
    betas = []
    v_start = self.sigmoid(start / tau)
    v_end = self.sigmoid(end / tau)
    for t in range(num_diffusion_timesteps):
      t_float = float(t / num_diffusion_timesteps)
      output0 = self.sigmoid((t_float * (end - start) + start) / tau)
      output = (v_end - output0) / (v_end - v_start)
      betas.append(np.clip(output * .2, clip_min, .2))
    return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()

  def betas_for_cosine(self, num_steps, start=0, end=1, tau=1, clip_min=1e-9):
    """
    Betas derived from a cos(.)^(2*tau) curve over [start, end] (in units of pi/2).

    Each step's curve value is min-max normalized against the curve at `start`
    and `end`, scaled by 0.2 and clipped to [clip_min, 0.2]; the list is then
    reversed (for the default start/end this yields betas that grow with t).
    Returns a float32 tensor on self.device.
    """
    v_start = math.cos(start * math.pi / 2) ** (2 * tau)
    # the exponent must be parenthesized exactly like v_start's: `** 2 * tau`
    # squares and then multiplies by tau, which only agrees when tau == 1
    v_end = math.cos(end * math.pi / 2) ** (2 * tau)
    betas = []
    for t in range(num_steps):
      t_float = float(t) / num_steps
      output = math.cos((t_float * (end - start) + start) * math.pi / 2) ** (2 * tau)
      output = (v_end - output) / (v_end - v_start)
      betas.append(np.clip(output * .2, clip_min, .2))
    return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()

  def sample_time_steps(self, batch_size=1):
    """Draw `batch_size` timesteps uniformly from [0, noise_steps), placed on self.device."""
    return torch.randint(0, self.noise_steps, (batch_size,)).to(self.device)

  def to(self, device):
    """
    Move the schedule tensors to `device` and return self.

    Note: the model itself is not moved.
    """
    self.device = device
    self.betas = self.betas.to(device)
    self.alphas = self.alphas.to(device)
    self.alpha_hat = self.alpha_hat.to(device)
    return self

  def q(self, x, t):
    """
    Forward process (placeholder; not implemented, returns None).
    The closed-form forward noising used in practice is `apply_noise`.
    """
    pass

  def p(self, x, t):
    """
    Backward process (placeholder; not implemented, returns None).
    One reverse step is implemented by `sample_step`.
    """
    pass

  def apply_noise(self, x, t):
    """
    Draw x_t from x_0 in closed form:
      x_t = sqrt(alpha_hat_t) * x + sqrt(1 - alpha_hat_t) * eps,  eps ~ N(0, I)

    x: tensor whose first dimension is the batch; a 3-D tensor is treated as a
      single image and given a leading batch dimension. Only the batch axis is
      interpreted, so channel-first and channel-last layouts both work.
    t: int (one step for the whole batch) or an integer tensor / sequence with
      one step per sample; a length-1 t broadcasts over the batch.

    Returns (noisy_image, epsilon), both on self.device.
    """
    if x.dim() == 3:
      x = x.unsqueeze(0)
    x = x.to(self.device)
    # accept int / sequence / tensor; reshape(-1) turns a scalar into a length-1 index
    t = torch.as_tensor(t, dtype=torch.long, device=self.alpha_hat.device).reshape(-1)
    alpha_hat_t = self.alpha_hat[t]
    # (B,) -> (B, 1, ..., 1) so each coefficient scales its whole sample
    coef_shape = (-1,) + (1,) * (x.dim() - 1)
    sqrt_alpha_hat = torch.sqrt(alpha_hat_t).to(self.device).view(coef_shape)
    sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - alpha_hat_t).to(self.device).view(coef_shape)
    # standard normal noise; also returned as the regression target
    epsilon = torch.randn_like(x)
    noisy_image = sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon
    return noisy_image, epsilon

  @staticmethod
  def normalize_image(x):
    # map pixel values from [0, 255] to [-1, 1]
    return x / 255.0 * 2.0 - 1.0

  @staticmethod
  def denormalize_image(x):
    # map values from [-1, 1] back to [0, 255]
    return (x + 1.0) / 2.0 * 255.0

  def sample_step(self, x, t, cond):
    """
    One reverse step x_t -> x_{t-1}:
      x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_hat_t) * eps_theta) + sqrt(beta_t) * z

    x: current batch; t: current timestep (int); cond: conditioning passed to
    the model, or None to call model(x, t).
    No fresh noise is added at t == 0, so the last step returns the mean.
    """
    z = torch.randn_like(x) if t >= 1 else torch.zeros_like(x)
    alpha = self.alphas[t]
    one_over_sqrt_alpha = 1.0 / torch.sqrt(alpha)
    one_minus_alpha = 1.0 - alpha
    sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
    # sigma_t^2 = beta_t; the posterior variance beta_hat_t is the usual alternative
    std = torch.sqrt(self.betas[t])

    t_batch = torch.full((x.shape[0],), t, dtype=torch.long, device=x.device)
    if cond is not None:
      predicted_noise = self.model(x, t_batch, cond)
    else:
      predicted_noise = self.model(x, t_batch)
    mean = one_over_sqrt_alpha * (x - one_minus_alpha / sqrt_one_minus_alpha_hat * predicted_noise)
    return mean + std * z

  def sample(self, num_samples, show_progress=True, cond=None, x0=None, cb=None, inpaint_until=80):
    """
    Generate images by ancestral sampling from t = noise_steps - 1 down to 0.

    num_samples: number of images to generate
    show_progress: wrap the timestep loop in tqdm
    cond: conditioning passed to the model; when None, labels
      0..num_samples-1 are used as a (num_samples, 1) tensor, which requires
      num_samples <= model.num_classes
    x0: optional inpainting template in [-1, 1]; entries != -1 are "known".
      The known region starts from a fully-noised copy of x0 and, while
      t > inpaint_until, is overwritten each step with x0 noised to level t.
    cb: optional callback cb(image, progress), called before every step with
      a copy of the current image clipped and mapped to [0, 255]
    inpaint_until: see x0

    Returns (images, image_versions): the final batch mapped to [0, 255] and a
    list holding one such snapshot (batch dim squeezed if 1) per step.
    The model's previous train/eval mode is restored afterwards.
    """
    if cond is None:
      assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
      cond = torch.arange(self.model.num_classes)[:num_samples].to(self.device)
      cond = rearrange(cond, 'i -> i ()')

    was_training = self.model.training
    self.model.eval()
    image_versions = []
    try:
      with torch.no_grad():
        x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)

        if x0 is not None:
          x0 = x0.to(self.device)
          mask = x0 != -1
          x_noised = self.apply_noise(x0, self.noise_steps - 1)[0].to(self.device)
          # torch.where builds a new tensor (no aliasing of x) and broadcasts a batch-1 x0
          x = torch.where(mask, x_noised, x)

        # include t == 0: it is the final, noise-free step of sample_step
        steps = range(self.noise_steps - 1, -1, -1)
        if show_progress:
          steps = tqdm(steps)
        for t in steps:
          frame = self.denormalize_image(torch.clip(x, -1, 1)).squeeze(0)
          image_versions.append(frame)
          if cb is not None:
            # give the callback its own copy so it cannot alter the stored frame
            cb(frame.clone(), 1 - t / (self.noise_steps + 1))

          if x0 is not None and t > inpaint_until:
            x_noised = self.apply_noise(x0, t)[0]
            x = torch.where(mask, x_noised, x)

          x = self.sample_step(x, t, cond)
    finally:
      self.model.train(was_training)
    x = torch.clip(x, -1.0, 1.0)
    return self.denormalize_image(x), image_versions

  def validate(self, dataloader):
    """
    Mean per-batch noise-prediction MSE over `dataloader`.

    dataloader: iterable of (image, cond) batches.
    Returns float('nan') when the dataloader yields no batches.
    The model's previous train/eval mode is restored afterwards.
    """
    was_training = self.model.training
    self.model.eval()
    total_loss = 0.0
    n_batches = 0
    try:
      with torch.no_grad():
        for image, cond in dataloader:
          t = self.sample_time_steps(batch_size=image.shape[0])
          noisy_image, added_noise = self.apply_noise(image, t)
          predicted_noise = self.model(noisy_image.to(self.device), t, cond.to(self.device))
          total_loss += nn.functional.mse_loss(predicted_noise, added_noise.to(self.device)).item()
          n_batches += 1
    finally:
      self.model.train(was_training)
    return total_loss / n_batches if n_batches else float("nan")

class DiffusionImageAPI:
  """Helpers for turning PIL images / tensors into (noisy) PIL images."""

  def __init__(self, diffusion_model):
    # must provide normalize_image, apply_noise and denormalize_image
    self.diffusion_model = diffusion_model

  @staticmethod
  def _to_uint8_array(tensor):
    """Detach, move to CPU, round and clamp to [0, 255], and return a uint8 ndarray."""
    # clamping before the cast prevents uint8 wrap-around (e.g. 300 -> 44) on
    # noisy values; rounding keeps float error like 199.99998 from becoming 199
    return tensor.detach().cpu().to(torch.float32).clamp(0, 255).round().to(torch.uint8).numpy()

  def get_noisy_image(self, image, t):
    """
    Return a PIL image of `image` after t forward-diffusion steps.

    image: the image to be processed (PIL.Image); converted via np.array
    t: the timestep (int)
    """
    x = torch.tensor(np.array(image))
    x = self.diffusion_model.normalize_image(x)
    y, _ = self.diffusion_model.apply_noise(x, t)
    y = self.diffusion_model.denormalize_image(y)
    # drop the batch dimension added by apply_noise
    return Image.fromarray(self._to_uint8_array(y.squeeze(0)))

  def get_noisy_images(self, image, time_steps):
    """
    image: the image to be processed PIL.Image
    time_steps: iterable of timesteps; each is converted with int()

    Returns a list with one noisy PIL image per timestep.
    """
    return [self.get_noisy_image(image, int(t)) for t in time_steps]

  def tensor_to_image(self, tensor):
    """Convert a channel-first (C, H, W) tensor with values in [0, 255] to a PIL image."""
    return Image.fromarray(self._to_uint8_array(tensor.permute(1, 2, 0)))