import torch from transformers import AutoModel import torch.nn as nn from PIL import Image import numpy as np import streamlit as st # Definir o dispositivo device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Carregar o modelo treinado do Hugging Face Hub model = AutoModel.from_pretrained('dhhd255/EfficientNet_ParkinsonsPred') # Mover o modelo para o dispositivo model = model.to(device) # Adicionar CSS personalizado para usar a fonte Inter, definir classes personalizadas para resultados de saída e estilos de rodapé st.markdown(""" """, unsafe_allow_html=True) st.title("Predição de Doença de Parkinson") uploaded_file = st.file_uploader("Envie seu desenho em espiral aqui", type=["png", "jpg", "jpeg"]) st.empty() if uploaded_file is not None: col1, col2 = st.columns(2) # Carregar e redimensionar a imagem image_size = (224, 224) new_image = Image.open(uploaded_file).convert('RGB').resize(image_size) col1.image(new_image, width=255) new_image = np.array(new_image) new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0) # Mover os dados para o dispositivo new_image = new_image.to(device) # Fazer previsões usando o modelo treinado with torch.no_grad(): predictions = model(new_image) logits = predictions.last_hidden_state logits = logits.view(logits.shape[0], -1) num_classes = 2 feature_reducer = nn.Linear(logits.shape[1], num_classes) logits = logits.to(device) feature_reducer = feature_reducer.to(device) logits = feature_reducer(logits) predicted_class = torch.argmax(logits, dim=1).item() confidence = torch.softmax(logits, dim=1)[0][predicted_class].item() # Traduzir as mensagens de saída if predicted_class == 0: col2.markdown('Classe prevista: Parkinson', unsafe_allow_html=True) col2.caption(f'{confidence*100:.0f}% de certeza') else: col2.markdown('Classe prevista: Saudável', unsafe_allow_html=True) col2.caption(f'{confidence*100:.0f}% de certeza') uploaded_file = st.file_uploader("Envie seu desenho de onda aqui", type=["png", "jpg", "jpeg"]) st.divider() st.markdown(""" """, unsafe_allow_html=True)