from typing import Iterator, Optional

import torch
from torch.utils.data import Sampler, ConcatDataset


class RandomConcatSampler(Sampler):
    """Random sampler for a ``ConcatDataset``.

    On every iteration, ``n_samples_per_subset`` indices are drawn from each
    subset of the ``ConcatDataset``. If ``subset_replacement`` is ``True``,
    sampling within each subset is done with replacement. Otherwise indices are
    drawn without replacement; when a subset holds fewer than
    ``n_samples_per_subset`` items, every item is used once and the shortfall
    is padded by sampling with replacement.

    It is impossible to sample data without replacement *between* epochs
    unless a stateful sampler is kept alive for the entire training phase.

    Args:
        data_source (ConcatDataset): the concatenated dataset to sample from.
        n_samples_per_subset (int): number of indices drawn from each subset
            per iteration (before ``repeat`` is applied).
        subset_replacement (bool): sample within each subset with replacement.
        shuffle (bool): shuffle the sampled indices across all sub-datasets.
            With ``repeat > 1`` each repeated copy gets its own permutation.
        repeat (int): repeatedly use the sampled indices multiple times for
            training. Must be ``>= 1``.
            [arXiv:1902.05509, arXiv:1901.09335]
        seed (int, optional): if given, ``torch.manual_seed(seed)`` is called,
            which reseeds the *global* torch RNG, and the generator it returns
            is used for sampling. If ``None``, the global default generator is
            used in its current state, without reseeding.

    Raises:
        TypeError: if ``data_source`` is not a ``ConcatDataset``.

    NOTE: The sampler draws from the global default torch generator, so other
          torch random calls and later ``torch.manual_seed()`` calls affect the
          indices it produces.
    NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
    NOTE: This sampler behaves differently with DistributedSampler.
          It assumes the dataset is split across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to fulfill sampling without replacement across epochs.
          ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    """

    def __init__(self,
                 data_source: ConcatDataset,
                 n_samples_per_subset: int,
                 subset_replacement: bool = True,
                 shuffle: bool = True,
                 repeat: int = 1,
                 seed: Optional[int] = None):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")

        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle
        if seed is None:
            # torch.manual_seed() rejects None, so fall back to the global
            # generator as-is instead of crashing on the default argument.
            self.generator = torch.default_generator
        else:
            # Reseeds the global RNG and returns the global default generator.
            self.generator = torch.manual_seed(seed)
        assert self.repeat >= 1, "repeat must be >= 1"

    def __len__(self) -> int:
        """Return the total number of indices yielded per iteration."""
        return self.n_samples

    def _sample_subset(self, d_idx: int) -> torch.Tensor:
        """Draw ``n_samples_per_subset`` global indices from subset ``d_idx``."""
        # Global index range [low, high) occupied by this subset.
        low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
        high = self.data_source.cumulative_sizes[d_idx]
        n_draw = self.n_samples_per_subset

        if self.subset_replacement:
            return torch.randint(low, high, (n_draw,),
                                 generator=self.generator, dtype=torch.int64)

        # Sample without replacement: permute the subset's local indices and
        # shift them into the global index range.
        len_subset = len(self.data_source.datasets[d_idx])
        permuted = torch.randperm(len_subset, generator=self.generator) + low
        if len_subset >= n_draw:
            return permuted[:n_draw]

        # Subset is too small: keep every item once, pad with replacement.
        padding = torch.randint(low, high, (n_draw - len_subset,),
                                generator=self.generator, dtype=torch.int64)
        return torch.cat([permuted, padding])

    def __iter__(self) -> Iterator[int]:
        """Yield the sampled dataset indices as Python ints."""
        # Sample from each sub-dataset, in subset order.
        indices = torch.cat([self._sample_subset(d_idx)
                             for d_idx in range(self.n_subset)])
        if self.shuffle:  # shuffle the sampled dataset (from multiple subsets)
            indices = indices[torch.randperm(len(indices), generator=self.generator)]

        # Repeat the sampled indices (can be used for RepeatAugmentation or
        # pure RepeatSampling).
        if self.repeat > 1:
            if self.shuffle:
                # Each extra copy is an independent permutation of the same set.
                extra = [indices[torch.randperm(len(indices), generator=self.generator)]
                         for _ in range(self.repeat - 1)]
            else:
                extra = [indices] * (self.repeat - 1)
            indices = torch.cat([indices, *extra], 0)

        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())