DHEIVER commited on
Commit
5607695
1 Parent(s): cddfb0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -5,16 +5,16 @@ from PIL import Image
5
  import numpy as np
6
  import streamlit as st
7
 
8
- # Set the device
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
- # Load the trained model from the Hugging Face Hub
12
  model = AutoModel.from_pretrained('dhhd255/EfficientNet_ParkinsonsPred')
13
 
14
- # Move the model to the device
15
  model = model.to(device)
16
 
17
- # Add custom CSS to use the Inter font, define custom classes for healthy and parkinsons results, increase the font size, make the text bold, and define the footer styles
18
  st.markdown("""
19
  <style>
20
  @import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
@@ -43,29 +43,29 @@ st.markdown("""
43
  </style>
44
  """, unsafe_allow_html=True)
45
 
46
- st.title("Parkinson's Disease Prediction")
47
 
48
- uploaded_file = st.file_uploader("Upload your :blue[Spiral] drawing here", type=["png", "jpg", "jpeg"])
49
  st.empty()
50
  if uploaded_file is not None:
51
  col1, col2 = st.columns(2)
52
 
53
- # Load and resize the image
54
  image_size = (224, 224)
55
  new_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
56
  col1.image(new_image, width=255)
57
  new_image = np.array(new_image)
58
  new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)
59
 
60
- # Move the data to the device
61
  new_image = new_image.to(device)
62
 
63
- # Make predictions using the trained model
64
  with torch.no_grad():
65
  predictions = model(new_image)
66
  logits = predictions.last_hidden_state
67
  logits = logits.view(logits.shape[0], -1)
68
- num_classes=2
69
  feature_reducer = nn.Linear(logits.shape[1], num_classes)
70
 
71
  logits = logits.to(device)
@@ -74,19 +74,21 @@ if uploaded_file is not None:
74
  logits = feature_reducer(logits)
75
  predicted_class = torch.argmax(logits, dim=1).item()
76
  confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
77
- if(predicted_class == 0):
78
- col2.markdown('<span class="result parkinsons">Predicted class: Parkinson\'s</span>', unsafe_allow_html=True)
79
- col2.caption(f'{confidence*100:.0f}% sure')
 
 
80
  else:
81
- col2.markdown('<span class="result healthy">Predicted class: Healthy</span>', unsafe_allow_html=True)
82
- col2.caption(f'{confidence*100:.0f}% sure')
83
 
84
- uploaded_file = st.file_uploader("Upload your :blue[Wave] drawing here", type=["png", "jpg", "jpeg"])
85
  st.divider()
86
  st.markdown("""
87
  <div class="social-links">
88
- <a href="https://twitter.com/your_twitter_handle">Twitter</a>
89
- <a href="https://facebook.com/your_facebook_page">Facebook</a>
90
- <a href="https://instagram.com/your_instagram_handle">Instagram</a>
91
  </div>
92
  """, unsafe_allow_html=True)
 
5
  import numpy as np
6
  import streamlit as st
7
 
8
+ # Definir o dispositivo
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
+ # Carregar o modelo treinado do Hugging Face Hub
12
  model = AutoModel.from_pretrained('dhhd255/EfficientNet_ParkinsonsPred')
13
 
14
+ # Mover o modelo para o dispositivo
15
  model = model.to(device)
16
 
17
+ # Adicionar CSS personalizado para usar a fonte Inter, definir classes personalizadas para resultados de saída e estilos de rodapé
18
  st.markdown("""
19
  <style>
20
  @import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
 
43
  </style>
44
  """, unsafe_allow_html=True)
45
 
46
+ st.title("Predição de Doença de Parkinson")
47
 
48
+ uploaded_file = st.file_uploader("Envie seu desenho em espiral aqui", type=["png", "jpg", "jpeg"])
49
  st.empty()
50
  if uploaded_file is not None:
51
  col1, col2 = st.columns(2)
52
 
53
+ # Carregar e redimensionar a imagem
54
  image_size = (224, 224)
55
  new_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
56
  col1.image(new_image, width=255)
57
  new_image = np.array(new_image)
58
  new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)
59
 
60
+ # Mover os dados para o dispositivo
61
  new_image = new_image.to(device)
62
 
63
+ # Fazer previsões usando o modelo treinado
64
  with torch.no_grad():
65
  predictions = model(new_image)
66
  logits = predictions.last_hidden_state
67
  logits = logits.view(logits.shape[0], -1)
68
+ num_classes = 2
69
  feature_reducer = nn.Linear(logits.shape[1], num_classes)
70
 
71
  logits = logits.to(device)
 
74
  logits = feature_reducer(logits)
75
  predicted_class = torch.argmax(logits, dim=1).item()
76
  confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
77
+
78
+ # Traduzir as mensagens de saída
79
+ if predicted_class == 0:
80
+ col2.markdown('<span class="result parkinsons">Classe prevista: Parkinson</span>', unsafe_allow_html=True)
81
+ col2.caption(f'{confidence*100:.0f}% de certeza')
82
  else:
83
+ col2.markdown('<span class="result healthy">Classe prevista: Saudável</span>', unsafe_allow_html=True)
84
+ col2.caption(f'{confidence*100:.0f}% de certeza')
85
 
86
+ uploaded_file = st.file_uploader("Envie seu desenho de onda aqui", type=["png", "jpg", "jpeg"])
87
  st.divider()
88
  st.markdown("""
89
  <div class="social-links">
90
+ <a href="https://twitter.com/seu_twitter">Twitter</a>
91
+ <a href="https://facebook.com/sua_pagina_facebook">Facebook</a>
92
+ <a href="https://instagram.com/seu_instagram">Instagram</a>
93
  </div>
94
  """, unsafe_allow_html=True)