"""Tokenizers for SMILES strings and atom-coordinate sequences (token-, node-, and character-level)."""
import os
import json
import random
import numpy as np
from SmilesPE.pretokenizer import atomwise_tokenizer

PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'
UNK = '<unk>'
MASK = '<mask>'
PAD_ID = 0
SOS_ID = 1
EOS_ID = 2
UNK_ID = 3
MASK_ID = 4


class Tokenizer(object):
    """Maps space-separated SMILES tokens to integer id sequences and back."""

    def __init__(self, path=None):
        self.stoi = {}
        self.itos = {}
        if path:
            self.load(path)

    def __len__(self):
        return len(self.stoi)

    @property
    def output_constraint(self):
        return False

    def save(self, path):
        with open(path, 'w') as f:
            json.dump(self.stoi, f)

    def load(self, path):
        with open(path) as f:
            self.stoi = json.load(f)
        self.itos = {v: k for k, v in self.stoi.items()}

    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))
        vocab = [PAD, SOS, EOS, UNK] + list(vocab)
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {v: k for k, v in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID

    def text_to_sequence(self, text, tokenized=True):
        sequence = [self.stoi[SOS]]
        if tokenized:
            tokens = text.split(' ')
        else:
            tokens = atomwise_tokenizer(text)
        for s in tokens:
            if s not in self.stoi:
                s = UNK
            sequence.append(self.stoi[s])
        sequence.append(self.stoi[EOS])
        return sequence

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            sequence = self.text_to_sequence(text)
            sequences.append(sequence)
        return sequences

    def sequence_to_text(self, sequence):
        return ''.join(list(map(lambda i: self.itos[i], sequence)))

    def sequences_to_texts(self, sequences):
        texts = []
        for sequence in sequences:
            text = self.sequence_to_text(sequence)
            texts.append(text)
        return texts

    def predict_caption(self, sequence):
        caption = ''
        for i in sequence:
            if i == self.stoi[EOS] or i == self.stoi[PAD]:
                break
            caption += self.itos[i]
        return caption

    def predict_captions(self, sequences):
        captions = []
        for sequence in sequences:
            caption = self.predict_caption(sequence)
            captions.append(caption)
        return captions

    def sequence_to_smiles(self, sequence):
        return {'smiles': self.predict_caption(sequence)}
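
# A minimal usage sketch of the base Tokenizer (illustrative only; the strings
# below are hypothetical, not part of any shipped vocabulary):
#
#   tokenizer = Tokenizer()
#   tokenizer.fit_on_texts(['C C ( = O ) O', 'c 1 c c c c c 1'])
#   seq = tokenizer.text_to_sequence('C C ( = O ) O')   # [SOS_ID, ..., EOS_ID]
#   tokenizer.predict_caption(seq[1:])                   # -> 'CC(=O)O'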


class NodeTokenizer(Tokenizer):
    """Tokenizer that additionally encodes atom coordinates as discretized x/y ids."""

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(path)
        self.maxx = input_size  # height
        self.maxy = input_size  # width
        self.sep_xy = sep_xy
        self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
        self.continuous_coords = continuous_coords
        self.debug = debug

    def __len__(self):
        if self.sep_xy:
            return self.offset + self.maxx + self.maxy
        else:
            return self.offset + max(self.maxx, self.maxy)

    @property
    def offset(self):
        return len(self.stoi)

    @property
    def output_constraint(self):
        return not self.continuous_coords

    def len_symbols(self):
        return len(self.stoi)

    def fit_atom_symbols(self, atoms):
        vocab = self.special_tokens + list(set(atoms))
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {v: k for k, v in self.stoi.items()}

    def is_x(self, x):
        return self.offset <= x < self.offset + self.maxx

    def is_y(self, y):
        if self.sep_xy:
            return self.offset + self.maxx <= y
        return self.offset <= y

    def is_symbol(self, s):
        return len(self.special_tokens) <= s < self.offset or s == UNK_ID

    def is_atom(self, id):
        if self.is_symbol(id):
            return self.is_atom_token(self.itos[id])
        return False

    def is_atom_token(self, token):
        return token.isalpha() or token.startswith("[") or token == '*' or token == UNK

    def x_to_id(self, x):
        return self.offset + round(x * (self.maxx - 1))

    def y_to_id(self, y):
        if self.sep_xy:
            return self.offset + self.maxx + round(y * (self.maxy - 1))
        return self.offset + round(y * (self.maxy - 1))

    def id_to_x(self, id):
        return (id - self.offset) / (self.maxx - 1)

    def id_to_y(self, id):
        if self.sep_xy:
            return (id - self.offset - self.maxx) / (self.maxy - 1)
        return (id - self.offset) / (self.maxy - 1)
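
    # Worked example (illustrative numbers): with input_size=100 and a symbol
    # vocabulary of 50 entries (offset == 50), x = 0.25 maps to
    # x_to_id(0.25) == 50 + round(0.25 * 99) == 75, and id_to_x(75) == 25/99 ≈ 0.2525.
    # With sep_xy=True, the y ids occupy a separate block starting at offset + maxx.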
    
    def get_output_mask(self, id):
        """Mask over the vocabulary for the token following `id`: after an atom only
        x ids are left unmasked (False), after an x id only y ids, after a y id only symbols."""
        mask = [False] * len(self)
        if self.continuous_coords:
            return mask
        if self.is_atom(id):
            return [True] * self.offset + [False] * self.maxx + [True] * self.maxy
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return mask

    def symbol_to_id(self, symbol):
        if symbol not in self.stoi:
            return UNK_ID
        return self.stoi[symbol]

    def symbols_to_labels(self, symbols):
        labels = []
        for symbol in symbols:
            labels.append(self.symbol_to_id(symbol))
        return labels

    def labels_to_symbols(self, labels):
        symbols = []
        for label in labels:
            symbols.append(self.itos[label])
        return symbols

    def nodes_to_grid(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        grid = np.zeros((self.maxx, self.maxy), dtype=int)
        for [x, y], symbol in zip(coords, symbols):
            x = round(x * (self.maxx - 1))
            y = round(y * (self.maxy - 1))
            grid[x][y] = self.symbol_to_id(symbol)
        return grid

    def grid_to_nodes(self, grid):
        coords, symbols, indices = [], [], []
        for i in range(self.maxx):
            for j in range(self.maxy):
                if grid[i][j] != 0:
                    x = i / (self.maxx - 1)
                    y = j / (self.maxy - 1)
                    coords.append([x, y])
                    symbols.append(self.itos[grid[i][j]])
                    indices.append([i, j])
        return {'coords': coords, 'symbols': symbols, 'indices': indices}
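
    # Grid round-trip sketch (hypothetical values): nodes_to_grid writes each
    # symbol id into its quantized (x, y) cell; grid_to_nodes scans the grid back.
    #
    #   nodes = {'coords': [[0.0, 0.0], [0.5, 1.0]], 'symbols': ['C', 'O']}
    #   grid = tokenizer.nodes_to_grid(nodes)          # (maxx, maxy) int array
    #   tokenizer.grid_to_nodes(grid)['symbols']       # -> ['C', 'O'] (grid-scan order)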

    def nodes_to_sequence(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            labels.append(self.symbol_to_id(symbol))
        labels.append(EOS_ID)
        return labels

    def sequence_to_nodes(self, sequence):
        coords, symbols = [], []
        i = 0
        if sequence[0] == SOS_ID:
            i += 1
        while i + 2 < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
                x = self.id_to_x(sequence[i])
                y = self.id_to_y(sequence[i+1])
                symbol = self.itos[sequence[i+2]]
                coords.append([x, y])
                symbols.append(symbol)
            i += 3
        return {'coords': coords, 'symbols': symbols}

    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            if token in self.stoi:
                labels.append(self.stoi[token])
            else:
                if self.debug:
                    print(f'{token} not in vocab')
                labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices
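
    # Layout sketch (hypothetical ids): with discrete coordinates, each atom token
    # is followed by its x id and y id, while bond/ring tokens carry no coordinates:
    #
    #   'CO' with coords [[0.1, 0.2], [0.3, 0.4]] encodes roughly as
    #   [SOS_ID, id('C'), x_id(0.1), y_id(0.2), id('O'), x_id(0.3), y_id(0.4), EOS_ID]
    #
    # `indices` records, for each atom, the position of the last label written for
    # it (the y id when coordinates are present, otherwise the atom token itself).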

    def sequence_to_smiles(self, sequence):
        has_coords = not self.continuous_coords
        smiles = ''
        coords, symbols, indices = [], [], []
        for i, label in enumerate(sequence):
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                continue
            token = self.itos[label]
            smiles += token
            if self.is_atom_token(token):
                if has_coords:
                    if i+3 < len(sequence) and self.is_x(sequence[i+1]) and self.is_y(sequence[i+2]):
                        x = self.id_to_x(sequence[i+1])
                        y = self.id_to_y(sequence[i+2])
                        coords.append([x, y])
                        symbols.append(token)
                        indices.append(i+3)
                else:
                    if i+1 < len(sequence):
                        symbols.append(token)
                        indices.append(i+1)
        results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results
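
    # Decoding sketch: given predicted labels laid out as above, strip the leading
    # SOS before decoding, since this method does not filter it out:
    #
    #   out = tokenizer.sequence_to_smiles(labels[1:])
    #   out['smiles'], out['symbols'], out.get('coords')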


class CharTokenizer(NodeTokenizer):
    """NodeTokenizer variant that operates on single characters instead of whole tokens."""

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(input_size, path, sep_xy, continuous_coords, debug)

    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(list(text))
        if ' ' in vocab:
            vocab.remove(' ')
        vocab = [PAD, SOS, EOS, UNK] + list(vocab)
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {v: k for k, v in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID

    def text_to_sequence(self, text, tokenized=True):
        sequence = [self.stoi[SOS]]
        if tokenized:
            tokens = text.split(' ')
            assert all(len(s) == 1 for s in tokens)
        else:
            tokens = list(text)
        for s in tokens:
            if s not in self.stoi:
                s = UNK
            sequence.append(self.stoi[s])
        sequence.append(self.stoi[EOS])
        return sequence

    def fit_atom_symbols(self, atoms):
        atoms = list(set(atoms))
        chars = []
        for atom in atoms:
            chars.extend(list(atom))
        vocab = self.special_tokens + chars
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {v: k for k, v in self.stoi.items()}

    def get_output_mask(self, id):
        ''' TO FIX '''
        mask = [False] * len(self)
        if self.continuous_coords:
            return mask
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return mask

    def nodes_to_sequence(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            for char in symbol:
                labels.append(self.symbol_to_id(char))
        labels.append(EOS_ID)
        return labels

    def sequence_to_nodes(self, sequence):
        coords, symbols = [], []
        i = 0
        if sequence[0] == SOS_ID:
            i += 1
        while i < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if i+2 < len(sequence) and self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
                x = self.id_to_x(sequence[i])
                y = self.id_to_y(sequence[i+1])
                for j in range(i+2, len(sequence)):
                    if not self.is_symbol(sequence[j]):
                        break
                else:
                    # no break: every remaining id is a symbol
                    j = len(sequence)
                symbol = ''.join(self.itos[sequence[k]] for k in range(i+2, j))
                coords.append([x, y])
                symbols.append(symbol)
                i = j
            else:
                i += 1
        return {'coords': coords, 'symbols': symbols}

    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            for c in token:
                if c in self.stoi:
                    labels.append(self.stoi[c])
                else:
                    if self.debug:
                        print(f'{c} not in vocab')
                    labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices

    def sequence_to_smiles(self, sequence):
        has_coords = not self.continuous_coords
        smiles = ''
        coords, symbols, indices = [], [], []
        i = 0
        while i < len(sequence):
            label = sequence[i]
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                i += 1
                continue
            if not self.is_atom(label):
                smiles += self.itos[label]
                i += 1
                continue
            if self.itos[label] == '[':
                j = i + 1
                while j < len(sequence):
                    if not self.is_symbol(sequence[j]):
                        break
                    if self.itos[sequence[j]] == ']':
                        j += 1
                        break
                    j += 1
            else:
                # two-character elements (Cl, Br) span two consecutive ids
                two_char = (i + 1 < len(sequence) and self.is_symbol(sequence[i+1])
                            and ((self.itos[label] == 'C' and self.itos[sequence[i+1]] == 'l')
                                 or (self.itos[label] == 'B' and self.itos[sequence[i+1]] == 'r')))
                j = i + 2 if two_char else i + 1
            token = ''.join(self.itos[sequence[k]] for k in range(i, j))
            smiles += token
            if has_coords:
                if j+2 < len(sequence) and self.is_x(sequence[j]) and self.is_y(sequence[j+1]):
                    x = self.id_to_x(sequence[j])
                    y = self.id_to_y(sequence[j+1])
                    coords.append([x, y])
                    symbols.append(token)
                    indices.append(j+2)
                    i = j+2
                else:
                    i = j
            else:
                if j < len(sequence):
                    symbols.append(token)
                    indices.append(j)
                i = j
        results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results


def get_tokenizer(args):
    """Build a dict mapping each requested output format to its tokenizer instance."""
    tokenizer = {}
    for format_ in args.formats:
        if format_ == 'atomtok':
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
            tokenizer['atomtok'] = Tokenizer(args.vocab_file)
        elif format_ == "atomtok_coords":
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
            tokenizer["atomtok_coords"] = NodeTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
                                                        continuous_coords=args.continuous_coords)
        elif format_ == "chartok_coords":
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_chars.json')
            tokenizer["chartok_coords"] = CharTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
                                                        continuous_coords=args.continuous_coords)
    return tokenizer
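
# Hedged usage sketch: get_tokenizer expects an argparse-style namespace; the field
# values below are hypothetical examples, not project defaults.
#
#   from argparse import Namespace
#   args = Namespace(formats=['atomtok_coords'], vocab_file=None,
#                    coord_bins=64, sep_xy=True, continuous_coords=False)
#   tokenizers = get_tokenizer(args)
#   node_tokenizer = tokenizers['atomtok_coords']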