Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,16 +5,16 @@ from PIL import Image
|
|
5 |
import numpy as np
|
6 |
import streamlit as st
|
7 |
|
8 |
-
#
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
|
11 |
-
#
|
12 |
model = AutoModel.from_pretrained('dhhd255/EfficientNet_ParkinsonsPred')
|
13 |
|
14 |
-
#
|
15 |
model = model.to(device)
|
16 |
|
17 |
-
#
|
18 |
st.markdown("""
|
19 |
<style>
|
20 |
@import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
|
@@ -43,29 +43,29 @@ st.markdown("""
|
|
43 |
</style>
|
44 |
""", unsafe_allow_html=True)
|
45 |
|
46 |
-
st.title("
|
47 |
|
48 |
-
uploaded_file = st.file_uploader("
|
49 |
st.empty()
|
50 |
if uploaded_file is not None:
|
51 |
col1, col2 = st.columns(2)
|
52 |
|
53 |
-
#
|
54 |
image_size = (224, 224)
|
55 |
new_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
|
56 |
col1.image(new_image, width=255)
|
57 |
new_image = np.array(new_image)
|
58 |
new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)
|
59 |
|
60 |
-
#
|
61 |
new_image = new_image.to(device)
|
62 |
|
63 |
-
#
|
64 |
with torch.no_grad():
|
65 |
predictions = model(new_image)
|
66 |
logits = predictions.last_hidden_state
|
67 |
logits = logits.view(logits.shape[0], -1)
|
68 |
-
num_classes=2
|
69 |
feature_reducer = nn.Linear(logits.shape[1], num_classes)
|
70 |
|
71 |
logits = logits.to(device)
|
@@ -74,19 +74,21 @@ if uploaded_file is not None:
|
|
74 |
logits = feature_reducer(logits)
|
75 |
predicted_class = torch.argmax(logits, dim=1).item()
|
76 |
confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
80 |
else:
|
81 |
-
col2.markdown('<span class="result healthy">
|
82 |
-
col2.caption(f'{confidence*100:.0f}%
|
83 |
|
84 |
-
uploaded_file = st.file_uploader("
|
85 |
st.divider()
|
86 |
st.markdown("""
|
87 |
<div class="social-links">
|
88 |
-
<a href="https://twitter.com/
|
89 |
-
<a href="https://facebook.com/
|
90 |
-
<a href="https://instagram.com/
|
91 |
</div>
|
92 |
""", unsafe_allow_html=True)
|
|
|
5 |
import numpy as np
|
6 |
import streamlit as st
|
7 |
|
8 |
+
# Definir o dispositivo
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
|
11 |
+
# Carregar o modelo treinado do Hugging Face Hub
|
12 |
model = AutoModel.from_pretrained('dhhd255/EfficientNet_ParkinsonsPred')
|
13 |
|
14 |
+
# Mover o modelo para o dispositivo
|
15 |
model = model.to(device)
|
16 |
|
17 |
+
# Adicionar CSS personalizado para usar a fonte Inter, definir classes personalizadas para resultados de saída e estilos de rodapé
|
18 |
st.markdown("""
|
19 |
<style>
|
20 |
@import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
|
|
|
43 |
</style>
|
44 |
""", unsafe_allow_html=True)
|
45 |
|
46 |
+
st.title("Predição de Doença de Parkinson")
|
47 |
|
48 |
+
uploaded_file = st.file_uploader("Envie seu desenho em espiral aqui", type=["png", "jpg", "jpeg"])
|
49 |
st.empty()
|
50 |
if uploaded_file is not None:
|
51 |
col1, col2 = st.columns(2)
|
52 |
|
53 |
+
# Carregar e redimensionar a imagem
|
54 |
image_size = (224, 224)
|
55 |
new_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
|
56 |
col1.image(new_image, width=255)
|
57 |
new_image = np.array(new_image)
|
58 |
new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)
|
59 |
|
60 |
+
# Mover os dados para o dispositivo
|
61 |
new_image = new_image.to(device)
|
62 |
|
63 |
+
# Fazer previsões usando o modelo treinado
|
64 |
with torch.no_grad():
|
65 |
predictions = model(new_image)
|
66 |
logits = predictions.last_hidden_state
|
67 |
logits = logits.view(logits.shape[0], -1)
|
68 |
+
num_classes = 2
|
69 |
feature_reducer = nn.Linear(logits.shape[1], num_classes)
|
70 |
|
71 |
logits = logits.to(device)
|
|
|
74 |
logits = feature_reducer(logits)
|
75 |
predicted_class = torch.argmax(logits, dim=1).item()
|
76 |
confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
|
77 |
+
|
78 |
+
# Traduzir as mensagens de saída
|
79 |
+
if predicted_class == 0:
|
80 |
+
col2.markdown('<span class="result parkinsons">Classe prevista: Parkinson</span>', unsafe_allow_html=True)
|
81 |
+
col2.caption(f'{confidence*100:.0f}% de certeza')
|
82 |
else:
|
83 |
+
col2.markdown('<span class="result healthy">Classe prevista: Saudável</span>', unsafe_allow_html=True)
|
84 |
+
col2.caption(f'{confidence*100:.0f}% de certeza')
|
85 |
|
86 |
+
uploaded_file = st.file_uploader("Envie seu desenho de onda aqui", type=["png", "jpg", "jpeg"])
|
87 |
st.divider()
|
88 |
st.markdown("""
|
89 |
<div class="social-links">
|
90 |
+
<a href="https://twitter.com/seu_twitter">Twitter</a>
|
91 |
+
<a href="https://facebook.com/sua_pagina_facebook">Facebook</a>
|
92 |
+
<a href="https://instagram.com/seu_instagram">Instagram</a>
|
93 |
</div>
|
94 |
""", unsafe_allow_html=True)
|