jegilj commited on
Commit
56662df
verified
1 Parent(s): ed70c3b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_fastai
2
+
3
+ import gradio as gr
4
+
5
+ from fastai.vision.all import *
6
+
7
+ import torchvision.transforms as transforms
8
+ import torchvision.transforms as transforms
9
+
10
+ from fastai.basics import *
11
+ from fastai.vision import models
12
+ from fastai.vision.all import *
13
+ from fastai.metrics import *
14
+ from fastai.data.all import *
15
+ from fastai.callback import *
16
+ from pathlib import Path
17
+
18
+ import random
19
+ import PIL
20
+
21
+ #Definimos las funciones de transformacion que hemos creado en la practica para poder tratar los datos de entrada y que funcione bien
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ def transform_image(image):
24
+ my_transforms = transforms.Compose([transforms.ToTensor(),
25
+ transforms.Normalize(
26
+ [0.485, 0.456, 0.406],
27
+ [0.229, 0.224, 0.225])])
28
+ image_aux = image
29
+ return my_transforms(image_aux).unsqueeze(0).to(device)
30
+
31
+ class TargetMaskConvertTransform(ItemTransform):
32
+ def __init__(self):
33
+ pass
34
+ def encodes(self, x):
35
+ img,mask = x
36
+
37
+ #Convertimos a array
38
+ mask = np.array(mask)
39
+
40
+ mask[(mask!=255) & (mask!=150) & (mask!=76) & (mask!=74) & (mask!=29) & (mask!=25)]=0
41
+ mask[mask==255]=1
42
+ mask[mask==150]=2
43
+ mask[mask==76]=4
44
+ mask[mask==74]=4
45
+ mask[mask==29]=3
46
+ mask[mask==25]=3
47
+
48
+ # Back to PILMask
49
+ mask = PILMask.create(mask)
50
+ return img, mask
51
+
52
+ from albumentations import (
53
+ Compose,
54
+ OneOf,
55
+ ElasticTransform,
56
+ GridDistortion,
57
+ OpticalDistortion,
58
+ HorizontalFlip,
59
+ Rotate,
60
+ Transpose,
61
+ CLAHE,
62
+ ShiftScaleRotate
63
+ )
64
+
65
+ def get_y_fn (x):
66
+ return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
67
+
68
+ class SegmentationAlbumentationsTransform(ItemTransform):
69
+ split_idx = 0
70
+
71
+ def __init__(self, aug):
72
+ self.aug = aug
73
+
74
+ def encodes(self, x):
75
+ img,mask = x
76
+ aug = self.aug(image=np.array(img), mask=np.array(mask))
77
+ return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
78
+
79
+ #Cargamos el modelo
80
+ repo_id = "luisvarona/Practica3"
81
+ learn = from_pretrained_fastai(repo_id)
82
+ model = learn.model
83
+ model = model.cpu()
84
+
85
+
86
+ # Funcion de predicci贸n
87
+ def predict(img_ruta):
88
+ img = PIL.Image.fromarray(img_ruta)
89
+ image = transforms.Resize((480,640))(img)
90
+ tensor = transform_image(image=image)
91
+ model.to(device)
92
+ with torch.no_grad():
93
+ outputs = model(tensor)
94
+
95
+ outputs = torch.argmax(outputs,1)
96
+ mask = np.array(outputs.cpu())
97
+ mask[mask==1]=255
98
+ mask[mask==2]=150
99
+ mask[mask==3]=29
100
+ mask[mask==4]=74
101
+ mask = np.reshape(mask,(480,640))
102
+ return Image.fromarray(mask.astype('uint8'))
103
+
104
+
105
+ # Creamos la interfaz y la lanzamos.
106
+ gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(480, 640)), outputs=gr.inputs.Image(shape=(480, 640)), examples=['color_184.jpg','color_189.jpg']).launch(share=False)