Spaces:
Runtime error
Runtime error
jeremyLE-Ekimetrics
commited on
Commit
•
4debc65
1
Parent(s):
86735e0
fix water
Browse files- biomap/inference.py +8 -1
- biomap/streamlit_app.py +1 -4
biomap/inference.py
CHANGED
@@ -13,6 +13,7 @@ preprocess = T.Compose(
|
|
13 |
]
|
14 |
)
|
15 |
|
|
|
16 |
def inference(images, model):
|
17 |
logging.info("Inference on Images")
|
18 |
x = torch.stack([preprocess(image) for image in images]).cpu()
|
@@ -25,6 +26,10 @@ def inference(images, model):
|
|
25 |
"img": x[i].detach().cpu(),
|
26 |
"linear_preds": linear_pred[i].detach().cpu(),
|
27 |
} for i in range(x.shape[0])]
|
|
|
|
|
|
|
|
|
28 |
return outputs
|
29 |
|
30 |
|
@@ -32,6 +37,7 @@ if __name__ == "__main__":
|
|
32 |
import hydra
|
33 |
from model import LitUnsupervisedSegmenter
|
34 |
from utils_gee import extract_img, transform_ee_img
|
|
|
35 |
latitude = 2.98
|
36 |
longitude = 48.81
|
37 |
start_date = '2020-03-20'
|
@@ -49,7 +55,8 @@ if __name__ == "__main__":
|
|
49 |
cfg = hydra.compose(config_name="my_train_config.yml")
|
50 |
|
51 |
# Load the model
|
52 |
-
|
|
|
53 |
saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
|
54 |
|
55 |
nbclasses = cfg.dir_dataset_n_classes
|
|
|
13 |
]
|
14 |
)
|
15 |
|
16 |
+
import numpy as np
|
17 |
def inference(images, model):
|
18 |
logging.info("Inference on Images")
|
19 |
x = torch.stack([preprocess(image) for image in images]).cpu()
|
|
|
26 |
"img": x[i].detach().cpu(),
|
27 |
"linear_preds": linear_pred[i].detach().cpu(),
|
28 |
} for i in range(x.shape[0])]
|
29 |
+
|
30 |
+
# water to natural green
|
31 |
+
for output in outputs:
|
32 |
+
output["linear_preds"] = torch.where(output["linear_preds"] == 5, 3, output["linear_preds"])
|
33 |
return outputs
|
34 |
|
35 |
|
|
|
37 |
import hydra
|
38 |
from model import LitUnsupervisedSegmenter
|
39 |
from utils_gee import extract_img, transform_ee_img
|
40 |
+
import os
|
41 |
latitude = 2.98
|
42 |
longitude = 48.81
|
43 |
start_date = '2020-03-20'
|
|
|
55 |
cfg = hydra.compose(config_name="my_train_config.yml")
|
56 |
|
57 |
# Load the model
|
58 |
+
|
59 |
+
model_path = os.path.join(os.path.dirname(__file__), "checkpoint/model/model.pt")
|
60 |
saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
|
61 |
|
62 |
nbclasses = cfg.dir_dataset_n_classes
|
biomap/streamlit_app.py
CHANGED
@@ -64,6 +64,7 @@ def app(model):
|
|
64 |
st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
|
65 |
st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
|
66 |
st.markdown("<p style='text-align: center;'>The score is then averaged on the full image.</p>", unsafe_allow_html=True)
|
|
|
67 |
if st.session_state["submit"]:
|
68 |
fig = inference_on_location(model, st.session_state["lat"], st.session_state["long"], st.session_state["start_date"], st.session_state["end_date"], st.session_state["segment_interval"])
|
69 |
st.session_state["infered"] = True
|
@@ -76,10 +77,6 @@ def app(model):
|
|
76 |
|
77 |
if st.session_state["infered"]:
|
78 |
st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
|
84 |
col_1, col_2 = st.columns([0.5, 0.5])
|
85 |
with col_1:
|
|
|
64 |
st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
|
65 |
st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
|
66 |
st.markdown("<p style='text-align: center;'>The score is then averaged on the full image.</p>", unsafe_allow_html=True)
|
67 |
+
|
68 |
if st.session_state["submit"]:
|
69 |
fig = inference_on_location(model, st.session_state["lat"], st.session_state["long"], st.session_state["start_date"], st.session_state["end_date"], st.session_state["segment_interval"])
|
70 |
st.session_state["infered"] = True
|
|
|
77 |
|
78 |
if st.session_state["infered"]:
|
79 |
st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
|
|
|
|
|
|
|
|
|
80 |
|
81 |
col_1, col_2 = st.columns([0.5, 0.5])
|
82 |
with col_1:
|