csaybar commited on
Commit
9c55c41
1 Parent(s): 7a13af2

Upload 5 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sr4rs/weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
sr4rs/run.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import opensr_test
3
+ import matplotlib.pyplot as plt
4
+ from utils import load_cesbio_sr, run_sr4rs
5
+
6
+
7
+ # Load the model
8
+ model = load_cesbio_sr()
9
+
10
+
11
+ # Load the dataset
12
+ dataset = opensr_test.load("naip")
13
+ lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
14
+
15
+ # Predict a image
16
+ results = run_sr4rs(
17
+ model=model,
18
+ lr=lr_dataset[2],
19
+ hr=hr_dataset[2],
20
+ )
21
+
22
+ # Display the results
23
+ fig, ax = plt.subplots(1, 3, figsize=(10, 5))
24
+ ax[0].imshow(results["lr"].transpose(1, 2, 0)/3000)
25
+ ax[0].set_title("LR")
26
+ ax[0].axis("off")
27
+ ax[1].imshow(results["sr"].transpose(1, 2, 0)/3000)
28
+ ax[1].set_title("SR")
29
+ ax[1].axis("off")
30
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
31
+ ax[2].set_title("HR")
32
+ plt.show()
sr4rs/utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import torch
3
+
4
+ def load_cesbio_sr() -> tf.function:
5
+ """Prepare the CESBIO model
6
+
7
+ Returns:
8
+ tf.function: A tf.function to get the SR image
9
+ """
10
+
11
+ # read the model
12
+ model = tf.saved_model.load("weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel")
13
+
14
+ # get the signature
15
+ signature = list(model.signatures.keys())[0]
16
+
17
+ # get the function
18
+ func = model.signatures[signature]
19
+
20
+ return func
21
+
22
+ def run_sr4rs(
23
+ model: tf.function,
24
+ lr: tf.Tensor,
25
+ hr: tf.Tensor,
26
+ ) -> dict:
27
+ """Run the SR4RS model
28
+
29
+ Args:
30
+ model (tf.function): The model to use
31
+ lr (tf.Tensor): The low resolution image
32
+ hr (tf.Tensor): The high resolution image
33
+ cropsize (int, optional): The cropsize. Defaults to 32.
34
+ overlap (int, optional): The overlap. Defaults to 0.
35
+
36
+ Returns:
37
+ dict: The results
38
+ """
39
+ # Run inference
40
+ Xnp = torch.from_numpy(lr[[3, 2, 1, 7]][None]).permute(0, 2, 3, 1)
41
+ Xtf = tf.convert_to_tensor(Xnp, dtype=tf.float32)
42
+ pred = model(Xtf)
43
+
44
+ # Save the results
45
+ pred_np = pred['output_32:0'].numpy()
46
+ pred_torch = torch.from_numpy(pred_np).permute(0, 3, 1, 2)
47
+ pred_torch_padded = torch.nn.functional.pad(
48
+ pred_torch,
49
+ (32, 32, 32, 32),
50
+ mode='constant',
51
+ value=0,
52
+ ).squeeze().numpy().astype('uint16')
53
+
54
+ results = {
55
+ "lr": lr[[3, 2, 1]],
56
+ "sr": pred_torch_padded[0:3],
57
+ "hr": hr[0:3],
58
+ }
59
+
60
+ return results
sr4rs/weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31562e5cce3bc52576d4cbdb066bb6552b7ffc846f03022d4f9b7a5e6dd6b727
3
+ size 486539011
sr4rs/weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4045311226af9e908b6c741301399a556b2705f20963e9d80213ce1a1fac81a3
3
+ size 297612052
sr4rs/weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel/variables/variables.index ADDED
Binary file (21.7 kB). View file