import torch
from torch import nn
from einops import rearrange
import numpy as np
from typing import List
from models.id_embedding.helpers import get_rep_pos, shift_tensor_dim0
from models.id_embedding.meta_net import StyleVectorizer
from models.celeb_embeddings import _get_celeb_embeddings_basis

from functools import partial
import torch.nn.functional as F
import torch.nn.init as init


DEFAULT_PLACEHOLDER_TOKEN = ["*"]

PROGRESSIVE_SCALE = 2000

def get_clip_token_for_string(tokenizer, string):
    # Tokenize `string` with the CLIP tokenizer and return the raw input_ids
    # (BOS/EOS included); padding and truncation follow the tokenizer defaults.
    batch_encoding = tokenizer(string, return_length=True, padding=True, truncation=True, return_overflowing_tokens=False, return_tensors="pt")
    tokens = batch_encoding["input_ids"]

    return tokens


def get_embedding_for_clip_token(embedder, token):
    # Look up the per-token CLIP embeddings (token + position, before the transformer blocks).
    return embedder(token.unsqueeze(0))
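
# Minimal usage sketch for the two helpers above (illustration only, never called
# in this module). It assumes a Hugging Face CLIP tokenizer / text encoder pair,
# as used elsewhere in this file; the prompt and function name are placeholders.
def _demo_clip_token_embedding(tokenizer, text_encoder, prompt="a photo of *"):
    tokens = get_clip_token_for_string(tokenizer, prompt)  # (1, seq_len) input_ids
    # Per-token embeddings from CLIP's embedding layer, shape (1, seq_len, token_dim).
    embeds = get_embedding_for_clip_token(text_encoder.text_model.embeddings, tokens)[0]
    return tokens, embeds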


class EmbeddingManagerId_adain(nn.Module):
    def __init__(
            self,
            tokenizer,
            text_encoder,
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),  
            experiment_name = "normal_GAN",                      
            num_embeds_per_token: int = 2,  
            loss_type: str = None,
            mlp_depth: int = 2,    
            token_dim: int = 1024,   
            input_dim: int = 1024, 
            **kwargs
    ):
        super().__init__()
        self.device = device
        self.num_es = num_embeds_per_token

        self.get_token_for_string = partial(get_clip_token_for_string, tokenizer)        
        self.get_embedding_for_tkn = partial(get_embedding_for_clip_token, text_encoder.text_model.embeddings)  
        

        self.token_dim = token_dim

        ''' 1. Placeholder mapping dicts '''
        # Token id of the placeholder "*" (index 1 skips the BOS token).
        self.placeholder_token = self.get_token_for_string("*")[0][1]
        
        if experiment_name == "normal_GAN":
            self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(tokenizer, text_encoder, "datasets_face/good_names.txt")
        elif experiment_name == "man_GAN":
            self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(tokenizer, text_encoder, "datasets_face/good_names_man.txt")
        elif experiment_name == "woman_GAN":            
            self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(tokenizer, text_encoder, "datasets_face/good_names_woman.txt")
        else:
            raise ValueError(f"Unknown experiment_name: {experiment_name!r} "
                             "(expected 'normal_GAN', 'man_GAN' or 'woman_GAN')")
        print("experiment_name:", experiment_name)
        
        self.celeb_embeddings_mean = self.celeb_embeddings_mean.to(device)   
        self.celeb_embeddings_std = self.celeb_embeddings_std.to(device)  

        self.name_projection_layer = StyleVectorizer(input_dim, self.token_dim * self.num_es, depth=mlp_depth, lr_mul=0.1) 
        self.embedding_discriminator = Embedding_discriminator(self.token_dim * self.num_es, dropout_rate = 0.2)

        self.adain_mode = 0
        
    def forward(
            self,
            tokenized_text, 
            embedded_text, 
            name_batch,
            random_embeddings = None,
            timesteps = None,
    ):
        
        if tokenized_text is not None:
            batch_size, n, device = *tokenized_text.shape, tokenized_text.device
        other_return_dict = {}
        
        if random_embeddings is not None:
            mlp_output_embedding = self.name_projection_layer(random_embeddings)
            # (B, num_es * token_dim) -> (B, num_es, token_dim)
            total_embedding = mlp_output_embedding.view(mlp_output_embedding.shape[0], self.num_es, self.token_dim)

            if self.adain_mode == 0:
                # AdaIN-style rescaling into the celeb-name embedding distribution.
                adained_total_embedding = total_embedding * self.celeb_embeddings_std + self.celeb_embeddings_mean
            else:
                adained_total_embedding = total_embedding
                
            other_return_dict["total_embedding"] = total_embedding
            other_return_dict["adained_total_embedding"] = adained_total_embedding

        if name_batch is not None:
            if isinstance(name_batch, list):
                # Keep the two name tokens right after BOS (positions 1 and 2).
                name_tokens = self.get_token_for_string(name_batch)[:, 1:3]
                name_embeddings = self.get_embedding_for_tkn(name_tokens.to(self.device))[0]
                
                other_return_dict["name_embeddings"] = name_embeddings
            else:
                assert 0

        if tokenized_text is not None:
            # Write the two projected identity embeddings over each "*" token and
            # the position right after it (requires `random_embeddings` above).
            placeholder_pos = np.array(get_rep_pos(tokenized_text, [self.placeholder_token]))
            if len(placeholder_pos) != 0:
                batch_size = adained_total_embedding.shape[0]
                end_index = min(batch_size, placeholder_pos.shape[0])
                pos = placeholder_pos[:end_index]
                embedded_text[pos[:, 0], pos[:, 1]] = adained_total_embedding[:end_index, 0, :]
                embedded_text[pos[:, 0], pos[:, 1] + 1] = adained_total_embedding[:end_index, 1, :]

        return embedded_text, other_return_dict



    def load(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=self.device)
        if ckpt.get("name_projection_layer") is not None:
            self.name_projection_layer = ckpt.get("name_projection_layer").float()

        print('[Embedding Manager] weights loaded.')



    def save(self, ckpt_path):
        save_dict = {}
        save_dict["name_projection_layer"] = self.name_projection_layer
        
        torch.save(save_dict, ckpt_path)


    def trainable_projection_parameters(self):  
        trainable_list = []
        trainable_list.extend(list(self.name_projection_layer.parameters())) 

        return trainable_list
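
# Usage sketch for EmbeddingManagerId_adain (illustration only, never called in
# this module; `manager`, `tokenizer`, `text_encoder` are assumed to come from
# the training script, and the prompt below is a placeholder, not the project's
# exact template).
def _demo_inject_identity(manager, tokenizer, text_encoder):
    device = manager.device  # assumes tokenizer/text_encoder live on the same device
    prompts = ["a photo of * person"]
    tokenized = tokenizer(prompts, padding="max_length", truncation=True,
                          max_length=tokenizer.model_max_length,
                          return_tensors="pt").input_ids.to(device)
    # Pre-transformer token embeddings, i.e. the `embedded_text` forward() expects.
    embedded = text_encoder.text_model.embeddings(tokenized)
    # One random code per prompt; its width must match the MLP's input_dim (1024 by default).
    z = torch.randn(len(prompts), 1024, device=device)
    embedded, extras = manager(tokenized, embedded, name_batch=None, random_embeddings=z)
    # extras["adained_total_embedding"] holds the (B, num_es, token_dim) identity
    # tokens that were written over the "*" position and the slot right after it.
    return embedded, extras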
   


class Embedding_discriminator(nn.Module):
    """Small MLP that maps a flattened embedding of size `input_size` to a single real/fake logit."""
    def __init__(self, input_size, dropout_rate):
        super(Embedding_discriminator, self).__init__()
        self.input_size = input_size

        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)

        self.LayerNorm1 = nn.LayerNorm(512)
        self.LayerNorm2 = nn.LayerNorm(256)

        self.leaky_relu = nn.LeakyReLU(0.2)

        self.dropout_rate = dropout_rate
        if self.dropout_rate > 0:
            self.dropout1 = nn.Dropout(dropout_rate)
            self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, input):
        x = input.view(-1, self.input_size)

        if self.dropout_rate > 0:
            x = self.leaky_relu(self.dropout1(self.fc1(x)))
        else:
            x = self.leaky_relu(self.fc1(x))
        
        if self.dropout_rate > 0:
            x = self.leaky_relu(self.dropout2(self.fc2(x)))
        else:
            x = self.leaky_relu(self.fc2(x))

        x = self.fc3(x)

        return x
    
    
    def save(self, ckpt_path):
        save_dict = {}

        save_dict["fc1"] = self.fc1
        save_dict["fc2"] = self.fc2
        save_dict["fc3"] = self.fc3
        save_dict["LayerNorm1"] = self.LayerNorm1
        save_dict["LayerNorm2"] = self.LayerNorm2
        save_dict["leaky_relu"] = self.leaky_relu
        # Dropout layers only exist when dropout_rate > 0.
        if self.dropout_rate > 0:
            save_dict["dropout1"] = self.dropout1
            save_dict["dropout2"] = self.dropout2

        torch.save(save_dict, ckpt_path)
    
    def load(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cuda')

        # The checkpoint is keyed by the same names used in save() above.
        if ckpt.get("fc1") is not None:
            self.fc1 = ckpt.get("fc1").float()
            self.fc2 = ckpt.get("fc2").float()
            self.fc3 = ckpt.get("fc3").float()
            self.LayerNorm1 = ckpt.get("LayerNorm1").float()
            self.LayerNorm2 = ckpt.get("LayerNorm2").float()
            self.leaky_relu = ckpt.get("leaky_relu").float()
            if ckpt.get("dropout1") is not None:
                self.dropout1 = ckpt.get("dropout1").float()
                self.dropout2 = ckpt.get("dropout2").float()

        print('[Embedding D] weights loaded.')
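
# A minimal adversarial-update sketch for Embedding_discriminator (illustration
# only: the BCE-with-logits losses, the choice of "real" = celeb-name embeddings
# vs. "fake" = projected identity embeddings, and the optimizer handling are
# assumptions, not necessarily what the project's training loop does).
def _demo_discriminator_step(disc, real_embeddings, fake_embeddings, opt_d):
    """One discriminator update on embeddings flattened to (B, token_dim * num_es)."""
    opt_d.zero_grad()
    real_logits = disc(real_embeddings)
    fake_logits = disc(fake_embeddings.detach())  # do not backprop into the generator
    loss_d = (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
              + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))
    loss_d.backward()
    opt_d.step()
    return loss_d.item()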