File size: 2,075 Bytes
7a13af2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c8cf1a
7a13af2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93e9c59
7a13af2
 
 
 
 
7c8cf1a
7a13af2
 
 
93e9c59
 
 
 
 
 
 
 
 
 
7a13af2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import pickle
import numpy as np
import opensr_test
import onnxruntime as ort
from typing import List, Union

def load_evoland(
    weights_path: str = "evoland/weights/carn_3x3x64g4sw_bootstrap.onnx",
    num_threads: int = 10,
) -> List:
    """Load the EvoLand ONNX super-resolution model for CPU inference.

    Args:
        weights_path: Path to the ONNX weights file. Defaults to the
            ``carn_3x3x64g4sw_bootstrap.onnx`` checkpoint under
            ``evoland/weights``.
        num_threads: Number of intra-op and inter-op threads given to
            ONNX Runtime.

    Returns:
        A two-element list ``[ort_session, run_options]``, the ``model``
        argument expected by ``run_evoland``.
    """
    # ONNX inference session options
    so = ort.SessionOptions()
    so.intra_op_num_threads = num_threads
    so.inter_op_num_threads = num_threads
    # Ask ONNX Runtime for deterministic compute so repeated runs match.
    so.use_deterministic_compute = True

    # Execute on CPU only: request just the CPU provider up front instead of
    # initialising a CUDA session that would be switched back to CPU anyway.
    ort_session = ort.InferenceSession(
        weights_path,
        sess_options=so,
        providers=["CPUExecutionProvider"],
    )
    ro = ort.RunOptions()

    return [ort_session, ro]


def run_evoland(
    model: List,
    lr: np.ndarray,
    hr: np.ndarray,
) -> dict:
    """Super-resolve ``lr`` with the EvoLand ONNX model and package the result.

    Args:
        model: ``[ort_session, run_options]`` as returned by ``load_evoland``.
        lr: Channel-first low-resolution image ``(bands, H, W)``. Bands
            ``[1, 2, 3, 7, 4, 5, 6, 8, 10, 11]`` along axis 0 are fed to
            the model.
        hr: Channel-first reference image ``(C, H_hr, W_hr)``. Its spatial
            size is the target size for the super-resolved output.

    Returns:
        A dict with keys ``"lr"``, ``"sr"`` and ``"hr"``. ``"lr"`` and
        ``"sr"`` hold channels ``[2, 1, 0]`` of the selected bands (input
        bands 3, 2, 1 — presumably red/green/blue; confirm against the data
        source). ``"hr"`` holds the first three channels of ``hr``.
    """
    ort_session, ro = model

    def _infer(x: np.ndarray) -> np.ndarray:
        # Run one (bands, H, W) image through the session as a float32
        # batch of size 1 and drop the singleton axes of the output.
        return ort_session.run(
            None,
            {"input": x[None].astype(np.float32)},
            run_options=ro,
        )[0].squeeze()

    # Bands fed to the model (note the non-sequential order).
    bands = [1, 2, 3, 7, 4, 5, 6, 8, 10, 11]
    lr = lr[bands]

    if lr.shape[1] == 121:
        # Reflect-pad 121 px to 128 px (3 before, 4 after) on both spatial
        # axes. Padding a float32 copy in NumPy avoids torch.from_numpy,
        # which rejects uint16 arrays on older torch releases.
        pad_before, pad_after = 3, 4
        lr_padded = np.pad(
            lr.astype(np.float32),
            ((0, 0), (pad_before, pad_after), (pad_before, pad_after)),
            mode="reflect",
        )

        sr = _infer(lr_padded)

        # Remove the padding. In the output it is multiplied by the model's
        # upscaling factor, which is derived from the actual output size.
        scale = sr.shape[-1] // lr_padded.shape[-1]
        start, stop = pad_before * scale, pad_after * scale
        sr = sr[:, start:-stop, start:-stop].astype(np.uint16)
        lr = lr_padded[
            :, pad_before:-pad_after, pad_before:-pad_after
        ].astype(np.uint16)
    else:
        # NOTE(review): on this path ``sr`` keeps the model's float dtype
        # unless the resize below runs, whereas the 121-px path is uint16.
        sr = _infer(lr)

    # Use nearest-neighbour interpolation to match the reference grid
    # without blending pixel values during metric calculation. Both spatial
    # axes are compared so width-only mismatches are resized too.
    if sr.shape[-2:] != hr.shape[-2:]:
        sr = torch.nn.functional.interpolate(
            torch.from_numpy(sr.astype(np.float32))[None],
            size=tuple(hr.shape[-2:]),
            mode="nearest",
        ).squeeze().numpy().astype(np.uint16)

    return {
        "lr": lr[[2, 1, 0]],
        "sr": sr[[2, 1, 0]],
        "hr": hr[0:3],
    }