# DHEIVER — "Update app.py" (commit 5607695)
import torch
from transformers import AutoModel
import torch.nn as nn
from PIL import Image
import numpy as np
import streamlit as st
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model(model_id, device_name):
    """Load *model_id* from the Hugging Face Hub and move it to *device_name*.

    Streamlit re-executes this whole script on every widget interaction; the
    result is cached with ``st.cache_resource`` so the checkpoint is loaded
    once per server process instead of on every rerun. The cached object is
    shared by all sessions, so it must not be mutated per request.

    Args:
        model_id: Hugging Face Hub repository id of the checkpoint.
        device_name: torch device string (e.g. ``'cuda'`` or ``'cpu'``),
            passed as a plain string so Streamlit can hash it as a cache key.

    Returns:
        The model returned by ``AutoModel.from_pretrained``, moved to the device.
    """
    return AutoModel.from_pretrained(model_id).to(torch.device(device_name))


# Load the trained model from the Hugging Face Hub (cached across reruns)
# and place it on the selected device.
model = _load_model('dhhd255/EfficientNet_ParkinsonsPred', str(device))
# Page-wide custom CSS injected as raw HTML:
#   - loads the Inter web font and applies it to <body>;
#   - `.result`, `.healthy`, `.parkinsons` style the prediction label;
#   - `.social-links` lays out the footer links as a centered flex row.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
body {
font-family: 'Inter', sans-serif;
}
.result {
font-size: 24px;
font-weight: bold;
}
.healthy {
color: #007E3F;
}
.parkinsons {
color: #C30000;
}
.social-links {
display: flex;
text-decoration:none;
justify-content: center;
}
.social-links a {
text-decoration:none;
padding: 0 10px;
}
</style>
"""

st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
st.title("Predição de Doença de Parkinson")

# Spiral-drawing upload; None until the user provides a file.
uploaded_file = st.file_uploader("Envie seu desenho em espiral aqui", type=["png", "jpg", "jpeg"])
# The returned placeholder is discarded, so it is never filled.
st.empty()

if uploaded_file is not None:
    col1, col2 = st.columns(2)

    # Decode the upload as 3-channel RGB and resize it to 224x224.
    image_size = (224, 224)
    pil_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
    col1.image(pil_image, width=255)

    # np.array(pil_image) is (height, width, 3); the model input is built as
    # (batch, channels, height, width). permute(2, 0, 1) keeps the spatial
    # axes in order, whereas the previous transpose(0, 2) produced
    # (3, width, height), i.e. the drawing mirrored across its diagonal.
    # NOTE(review): pixels are passed as raw 0-255 floats with no
    # normalization — confirm this matches the preprocessing used in training.
    input_tensor = (
        torch.from_numpy(np.array(pil_image))
        .permute(2, 0, 1)
        .float()
        .unsqueeze(0)
        .to(device)
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_tensor)
        # Flatten the backbone's last_hidden_state into one feature vector
        # per image (flatten also works on non-contiguous tensors).
        features = outputs.last_hidden_state.flatten(start_dim=1)

        # WARNING(review): this 2-class Linear head is created with fresh
        # random weights on every rerun and is never loaded from a checkpoint,
        # so the class and confidence shown below do not come from trained
        # weights and may change for the same image. Replace it with the
        # trained classification head before relying on these predictions.
        num_classes = 2
        feature_reducer = nn.Linear(features.shape[1], num_classes).to(device)
        logits = feature_reducer(features)

        predicted_class = torch.argmax(logits, dim=1).item()
        confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()

    # Class index 0 is displayed as Parkinson's; any other index as healthy.
    if predicted_class == 0:
        col2.markdown('<span class="result parkinsons">Classe prevista: Parkinson</span>', unsafe_allow_html=True)
    else:
        col2.markdown('<span class="result healthy">Classe prevista: Saudável</span>', unsafe_allow_html=True)
    col2.caption(f'{confidence*100:.0f}% de certeza')
# Wave-drawing upload.
# NOTE(review): the returned file is bound to `uploaded_file`, but nothing
# after this call in the visible file reads it, so a wave drawing is accepted
# and then ignored. Confirm whether a wave prediction step is intended here.
uploaded_file = st.file_uploader(
    "Envie seu desenho de onda aqui",
    type=["png", "jpg", "jpeg"],
)

# Footer markup; the `social-links` class is styled in the page CSS.
# NOTE(review): these URLs look like placeholders ("seu_twitter",
# "sua_pagina_facebook", "seu_instagram"); replace them with real profiles.
_FOOTER_HTML = """
<div class="social-links">
<a href="https://twitter.com/seu_twitter">Twitter</a>
<a href="https://facebook.com/sua_pagina_facebook">Facebook</a>
<a href="https://instagram.com/seu_instagram">Instagram</a>
</div>
"""

# Horizontal rule, then the raw-HTML footer.
st.divider()
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)