|
import torch |
|
import pickle |
|
import numpy as np |
|
import opensr_test |
|
import onnxruntime as ort |
|
from typing import List, Union |
|
|
|
def load_evoland(
    weights_path: str = "evoland/weights/carn_3x3x64g4sw_bootstrap.onnx",
    num_threads: int = 10,
    providers: Union[List[str], None] = None,
) -> List:
    """Create the ONNX Runtime session used to run the EVOLAND model.

    Args:
        weights_path: Path to the ONNX weights file. The default is the
            path that was previously hard-coded (resolved relative to the
            current working directory).
        num_threads: Thread count applied to both the intra-op and the
            inter-op thread pools.
        providers: Execution providers, in priority order. Defaults to
            ``["CPUExecutionProvider"]``. Previously the session was built
            with CUDA and then immediately switched to CPU-only; the end
            state was a CPU session, so it is now built that way directly.

    Returns:
        A two-element list ``[session, run_options]``. ``run_evoland``
        unpacks a model argument of exactly this shape.
    """
    if providers is None:
        providers = ["CPUExecutionProvider"]

    so = ort.SessionOptions()
    so.intra_op_num_threads = num_threads
    so.inter_op_num_threads = num_threads
    # Ask ONNX Runtime for deterministic kernels so repeated runs match.
    so.use_deterministic_compute = True

    ort_session = ort.InferenceSession(
        weights_path,
        sess_options=so,
        providers=list(providers),
    )
    ro = ort.RunOptions()

    return [ort_session, ro]
|
|
|
|
|
def _to_uint16(array: np.ndarray) -> np.ndarray:
    """Saturate ``array`` to the uint16 range, then cast it to uint16.

    A bare ``astype(np.uint16)`` does not saturate: negative or too-large
    floats (e.g. slightly negative network outputs) wrap around instead of
    becoming 0 / 65535. Fractional parts are still truncated, exactly as
    with the plain cast.
    """
    limits = np.iinfo(np.uint16)
    return np.clip(array, limits.min, limits.max).astype(np.uint16)


def run_evoland(
    model: List,
    lr: np.ndarray,
    hr: np.ndarray
) -> dict:
    """Super-resolve one low-resolution tile with the EVOLAND ONNX model.

    Args:
        model: ``[session, run_options]``; the session must accept a
            float32 tensor under the input name ``"input"``.
        lr: Band-first low-resolution stack. Axis 0 needs at least 12
            entries because band indices up to 11 are selected. Presumably
            (bands, H, W) in digital-number scale (the outputs are cast to
            uint16); TODO confirm with the caller.
        hr: Reference stack. Its spatial shape (axes 1 onward) is the
            target size for ``sr``, and its first three channels are
            returned.

    Returns:
        Dict with keys ``"lr"``, ``"sr"``, ``"hr"``. ``lr`` and ``sr`` hold
        channels 2, 1, 0 of the selected bands (presumably an RGB
        composite; confirm) and are always uint16, clipped to [0, 65535].
        ``hr`` is ``hr[0:3]`` unchanged.
    """
    ort_session, ro = model

    # Channels fed to the network, in the order it expects them.
    bands = [1, 2, 3, 7, 4, 5, 6, 8, 10, 11]
    lr = lr[bands]

    if lr.shape[1] == 121:
        # 121-px tiles are reflect-padded by 3/4 px per side (to 128 px)
        # before inference and cropped back afterwards. Presumably the
        # network needs this size; confirm against the model.
        pad_before, pad_after = 3, 4
        # Convert to float32 in NumPy first: torch.from_numpy rejects
        # uint16 arrays on older PyTorch releases.
        padded = torch.nn.functional.pad(
            torch.from_numpy(lr[None].astype(np.float32)),
            pad=(pad_before, pad_after, pad_before, pad_after),
            mode="reflect",
        )[0].cpu().numpy()

        sr = ort_session.run(
            None, {"input": padded[None]}, run_options=ro
        )[0].squeeze()

        # Upscaling factor inferred from the output size, so the crop stays
        # correct for any integer scale (2 for the current weights).
        scale = sr.shape[-1] // padded.shape[-1]
        sr = sr[
            :,
            pad_before * scale:-pad_after * scale,
            pad_before * scale:-pad_after * scale,
        ]
        lr = padded[:, pad_before:-pad_after, pad_before:-pad_after]
    else:
        sr = ort_session.run(
            None, {"input": lr[None].astype(np.float32)}, run_options=ro
        )[0].squeeze()

    # Nearest-neighbour resize to the reference grid when the spatial shape
    # differs (both height and width are compared).
    if sr.shape[1:] != hr.shape[1:]:
        sr = torch.nn.functional.interpolate(
            torch.from_numpy(np.ascontiguousarray(sr, dtype=np.float32))[None],
            size=tuple(hr.shape[1:]),
            mode="nearest",
        )[0].numpy()

    # Both code paths return the same dtype.
    sr = _to_uint16(sr)
    lr = _to_uint16(lr)

    return {
        "lr": lr[[2, 1, 0]],
        "sr": sr[[2, 1, 0]],
        "hr": hr[0:3],
    }