Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoModel | |
import torch.nn as nn | |
from PIL import Image | |
import numpy as np | |
import streamlit as st | |
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model():
    """Load the pretrained checkpoint from the Hugging Face Hub and move it to `device`.

    Streamlit re-executes this script on every widget interaction; caching the
    loaded model as a resource avoids re-loading it on each rerun.
    """
    loaded_model = AutoModel.from_pretrained('dhhd255/EfficientNet_ParkinsonsPred')
    return loaded_model.to(device)


# Module-level handle used by the prediction code further down.
model = _load_model()
# Custom stylesheet injected into the page: loads the Inter font and defines
# the classes used by the result labels (.result/.healthy/.parkinsons) and by
# the footer links (.social-links).
# The CSS is intentionally left unindented inside the literal so that the
# markdown renderer never sees indented lines.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
body {
font-family: 'Inter', sans-serif;
}
.result {
font-size: 24px;
font-weight: bold;
}
.healthy {
color: #007E3F;
}
.parkinsons {
color: #C30000;
}
.social-links {
display: flex;
text-decoration:none;
justify-content: center;
}
.social-links a {
text-decoration:none;
padding: 0 10px;
}
</style>
"""

# unsafe_allow_html is required for the <style> tag to be emitted as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page title and the spiral-drawing upload widget.
st.title("Predição de Doença de Parkinson")
uploaded_file = st.file_uploader("Envie seu desenho em espiral aqui", type=["png", "jpg", "jpeg"])
st.empty()

if uploaded_file is not None:
    # Left column shows the uploaded image, right column shows the result.
    col1, col2 = st.columns(2)

    # Load the upload as RGB and resize it to the fixed input size.
    image_size = (224, 224)
    try:
        new_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
    except OSError:
        # Pillow raises OSError (including UnidentifiedImageError) for files it
        # cannot decode; report it instead of crashing the app with a traceback.
        col2.error("Não foi possível ler a imagem enviada. Envie um arquivo PNG ou JPEG válido.")
    else:
        col1.image(new_image, width=255)

        # PIL/NumPy give (H, W, C); reorder to channels-first (C, H, W) and add
        # a batch dimension. permute(2, 0, 1) keeps height and width in place,
        # whereas transpose(0, 2) would swap them and feed a transposed image.
        # NOTE(review): pixel values are passed as raw 0-255 floats with no
        # normalisation — confirm this matches how the checkpoint was trained.
        new_image = np.array(new_image)
        new_image = torch.from_numpy(new_image).permute(2, 0, 1).float().unsqueeze(0)

        # Move the input to the same device as the model.
        new_image = new_image.to(device)

        # Run the model without tracking gradients.
        with torch.no_grad():
            predictions = model(new_image)
            # Flatten the backbone's last hidden state to one vector per image.
            logits = predictions.last_hidden_state
            logits = logits.view(logits.shape[0], -1)

            num_classes = 2
            # NOTE(review): this linear layer is created with freshly
            # initialised (random) weights on every run, so the class and
            # confidence below are not produced by trained classifier weights.
            # A trained classification head should be loaded here instead;
            # behaviour is left as-is until one is available.
            feature_reducer = nn.Linear(logits.shape[1], num_classes).to(device)
            logits = feature_reducer(logits)

        predicted_class = torch.argmax(logits, dim=1).item()
        confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()

        # Class index 0 is reported as Parkinson's, anything else as healthy.
        if predicted_class == 0:
            col2.markdown('<span class="result parkinsons">Classe prevista: Parkinson</span>', unsafe_allow_html=True)
        else:
            col2.markdown('<span class="result healthy">Classe prevista: Saudável</span>', unsafe_allow_html=True)
        # The confidence caption is identical for both outcomes.
        col2.caption(f'{confidence*100:.0f}% de certeza')
# Footer markup: a centred row of social links styled by the .social-links
# CSS class. Left unindented inside the literal on purpose.
_SOCIAL_LINKS_HTML = """
<div class="social-links">
<a href="https://twitter.com/seu_twitter">Twitter</a>
<a href="https://facebook.com/sua_pagina_facebook">Facebook</a>
<a href="https://instagram.com/seu_instagram">Instagram</a>
</div>
"""

# Second upload widget, for wave drawings.
# NOTE(review): nothing in the visible source reads this upload after it is
# assigned, and it rebinds `uploaded_file` — confirm whether wave drawings
# are meant to be classified as well.
uploaded_file = st.file_uploader(
    "Envie seu desenho de onda aqui",
    type=["png", "jpg", "jpeg"],
)

# Horizontal rule separating the app body from the footer links.
st.divider()
st.markdown(_SOCIAL_LINKS_HTML, unsafe_allow_html=True)