import json
import os
from importlib import import_module

import numpy as np
import requests
import torch
from PIL import Image

class AstroSleuth:
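    """
    Tiled image upscaler: splits an image into tiles, runs each tile through a
    super-resolution model, and stitches the results back together.

    Minimal usage sketch (assumes models.json is present and lists the requested model,
    and that its weights can be downloaded or already exist under wrk_dir):

        from PIL import Image
        sleuth = AstroSleuth(model_name="astrosleuthv2")
        upscaled = sleuth.enhance(Image.open("img.png"))
        upscaled.save("upscaled.png")
    """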
    def __init__(self, tile_size=256, tile_pad=16, wrk_dir="models/", model_name="astrosleuthv2",
                 force_cpu=False, on_download=None, off_download=None):
        # Device selection: prefer CUDA, then Apple Silicon (MPS), otherwise fall back to CPU
        self.device = "cpu" if force_cpu else ("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
        
        # Check that the requested model is known
        with open("models.json") as f:
            model_src: dict = json.load(f)["data"]
        assert model_name in model_src, f"Model {model_name} not found! Available models: {list(model_src.keys())}"
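        # Illustrative shape of models.json, inferred from the lookups in this class (values are placeholders):
        # {
        #   "data": {
        #     "astrosleuthv2": {
        #       "scale": <int upscale factor>,
        #       "src": {"module": "<module name>/<network class name>", "url": "<download url, or null if not downloadable>"}
        #     }
        #   }
        # }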

        # Load model module
        module_path = model_src[model_name]["src"]["module"]

        # module_path has the form "module/ClassName": import the module and grab the network class
        module_name, class_name = module_path.split("/")
        self.model_module = getattr(import_module(module_name), class_name)

        # Download model if not available
        self.model_pth = os.path.join(wrk_dir, model_name, "model.pth")
        self.download(model_src[model_name]["src"]["url"], self.model_pth, model_name, on_download, off_download)
            
        self.wrk_dir = wrk_dir
        self.progress = None

        # Set tile processing parameters
        self.scale = model_src[model_name]["scale"]
        self.tile_size = tile_size
        self.tile_pad = tile_pad
        
    def download(self, src, dst, model_name, on_download=None, off_download=None):
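        # Fetch the model weights to dst if they are not already cached on disk;
        # on_download/off_download are optional hooks invoked just before and after the download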
        if not os.path.exists(dst):
            assert src is not None, "That model is not available for downloading - Are you on experimental?"
            os.makedirs(os.path.dirname(dst), exist_ok=True)

            if on_download is not None:
                on_download(model_name)

            response = requests.get(src, allow_redirects=True, headers={"User-Agent": ""})
            response.raise_for_status()  # fail loudly instead of writing an HTML error page to the weights file
            with open(dst, 'wb') as f:
                f.write(response.content)

            if off_download is not None:
                off_download()

    def model_inference(self, x: np.ndarray, args: dict = {}):
        x = torch.from_numpy(x).to(self.device)
        # Inference only: no_grad avoids building the autograd graph and saves memory
        with torch.no_grad():
            if args:
                return self.model(x=x, **args).cpu().numpy()
            return self.model(x=x).cpu().numpy()

    def tile_generator(self, data: np.ndarray, yield_extra_details=False, args={}):
        """
        Split data of shape [height, width, channel] into tiles of [tile_size, tile_size, channel],
        feed them one by one through the model, and yield each upscaled output tile, followed by a
        final None sentinel. With yield_extra_details=True each item is a tuple of
        (tile, input_start_x, input_start_y, input_tile_width, input_tile_height, progress).
        """

        # [height, width, channel] -> [1, channel, height, width]
        data = np.rollaxis(data, 2, 0)
        data = np.expand_dims(data, axis=0)
        data = np.clip(data, 0, 255)

        batch, channel, height, width = data.shape

        tiles_x = width // self.tile_size
        tiles_y = height // self.tile_size
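        # enhance_with_progress resizes the input to a multiple of tile_size beforehand,
        # so this integer division covers the whole image with no leftover strip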

        for i in range(tiles_y * tiles_x):
            # Walk the tile grid: tile_x is the column index, tile_y the row index
            tile_x = i // tiles_y
            tile_y = i % tiles_y

            input_start_x = tile_x * self.tile_size
            input_start_y = tile_y * self.tile_size

            input_end_x = min(input_start_x + self.tile_size, width)
            input_end_y = min(input_start_y + self.tile_size, height)

            # Pad each tile by tile_pad pixels on every side (clamped to the image bounds) so the model
            # sees context past the tile edge; the padded border is cropped off again after inference
            input_start_x_pad = max(input_start_x - self.tile_pad, 0)
            input_end_x_pad = min(input_end_x + self.tile_pad, width)
            input_start_y_pad = max(input_start_y - self.tile_pad, 0)
            input_end_y_pad = min(input_end_y + self.tile_pad, height)

            input_tile_width = input_end_x - input_start_x
            input_tile_height = input_end_y - input_start_y

            # Extract the padded tile and normalise uint8 [0, 255] -> float32 [0, 1]
            input_tile = data[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad].astype(np.float32) / 255

            output_tile = self.model_inference(input_tile, args)
            self.progress = (i+1) / (tiles_y * tiles_x)
            
            # Map the unpadded tile origin into output coordinates and crop the padded border away
            output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
            output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
            output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
            output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

            output_tile = output_tile[:, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile]

            # [1, channel, height, width] float in [0, 1] -> [height, width, channel] uint8 in [0, 255]
            output_tile = (np.rollaxis(output_tile, 1, 4).squeeze(0).clip(0, 1) * 255).astype(np.uint8)
            
            if yield_extra_details:
                yield (output_tile, input_start_x, input_start_y, input_tile_width, input_tile_height, self.progress)
            else:
                yield output_tile
        
        # Sentinel so that consumers (see enhance_with_progress) know the tiles are exhausted
        yield None
    
    def load_model(self):
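        # Instantiate the network architecture, then load the weights from self.model_pth onto the selected device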
        model:torch.nn.Module = self.model_module().to(self.device)
        model.load_state_dict(torch.load(self.model_pth, map_location=torch.device(self.device)), strict=False)
        model.eval()
        return model

    def enhance_with_progress(self, image: Image.Image, args: dict = {}):
        """
        Take a PIL image and enhance it with the model, yielding progress updates
        (floats from 0 to 1) and finally the enhanced image itself.
        """

        # Load the model only now: when using streamlit, multiple users spawn multiple instances
        # of this class, so we only pay the loading cost when needed. The App() class is
        # responsible for queuing requests to this class.
        self.model = self.load_model()
        original_width, original_height = image.size

        # Tiles may not fit perfectly, so round the size down to a multiple of tile_size (but keep at least one tile)
        image = image.resize(
            (max(original_width // self.tile_size * self.tile_size, self.tile_size),
             max(original_height // self.tile_size * self.tile_size, self.tile_size)), resample=Image.Resampling.BICUBIC)
        image = np.array(image)

        # Output canvas (scale x larger than the resized input) that the upscaled tiles are pasted into
        result = Image.new("RGB", (image.shape[1]*self.scale, image.shape[0]*self.scale))
        
        for tile in self.tile_generator(image, yield_extra_details=True, args=args):
            
            if tile is None:
                break
            
            tile_data, x, y, w, h, p = tile
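            # The paste coordinates are in input (resized) pixels, so scale them up to the output canvas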
            result.paste(Image.fromarray(tile_data), (x*self.scale, y*self.scale))
            yield p
        
        # Resize back to the expected size
        yield result.resize((original_width * self.scale, original_height * self.scale), resample=Image.Resampling.BICUBIC)
        
    def enhance(self, image: Image.Image) -> Image.Image:
        """
        Skips the progress reporting and just returns the final image.
        """
        return list(self.enhance_with_progress(image))[-1]
    
if __name__ == '__main__':
    import sys

    # Require exactly three arguments: source image, destination image, model name
    if len(sys.argv) != 4:
        print("Use main.py with a source, destination file, and model, eg: 'python3 main.py img.png upscaled.png astrosleuthv2'")
        print("You might also be interested in using the streamlit interface with: 'streamlit run app.py'")
        sys.exit(1)

    src = sys.argv[1]
    dst = sys.argv[2]
    model_name = sys.argv[3]

    a = AstroSleuth(model_name=model_name)
    img = Image.open(src)
    r = a.enhance(img)
    r.save(dst)