import os
import urllib.request
import zipfile

import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

import nibabel as nib
import streamlit as st
from PIL import Image
from transformers import ViTForImageClassification
# 1. Extract a ZIP archive into a target directory
def extract_zip(zip_file, extract_to):
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
# 2. Preprocess a scan into a (3, 224, 224) float tensor in [0, 1]
def preprocess_image(image_path):
    path_lower = image_path.lower()
    if path_lower.endswith(('.nii', '.nii.gz')):
        nii_image = nib.load(image_path)
        image_data = nii_image.get_fdata()
        if image_data.ndim == 3:
            # Take the middle slice of a 3D volume so every sample is 2D
            image_data = image_data[:, :, image_data.shape[2] // 2]
        # NIfTI intensities are not in [0, 255], so min-max normalise first
        image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min() + 1e-8)
        img = Image.fromarray((image_data * 255).astype(np.uint8)).convert('RGB').resize((224, 224))
    elif path_lower.endswith(('.jpg', '.jpeg')):
        img = Image.open(image_path).convert('RGB').resize((224, 224))
    else:
        raise ValueError(f"Unsupported format: {image_path}")
    image_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float()
    image_tensor /= 255.0  # Normalize to [0, 1]
    return image_tensor
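# Illustrative check (hypothetical file names, not part of the app flow):
# both branches should yield a float tensor of shape (3, 224, 224) in [0, 1],
# matching the 224x224 input size the ViT checkpoint below expects.
#   preprocess_image('sample_scan.nii.gz').shape   # -> torch.Size([3, 224, 224])
#   preprocess_image('sample_slice.jpg').shape     # -> torch.Size([3, 224, 224])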
# 3. Collect image paths and assign a label per disease folder
def prepare_dataset(extracted_folder):
    image_paths = []
    labels = []
    label_map = {'alzheimers': 0, 'parkinsons': 1, 'ms': 2}
    for disease_folder, label in label_map.items():
        folder_path = os.path.join(extracted_folder, disease_folder)
        if not os.path.isdir(folder_path):
            continue
        for img_file in os.listdir(folder_path):
            if img_file.lower().endswith(('.nii', '.nii.gz', '.jpg', '.jpeg')):
                image_paths.append(os.path.join(folder_path, img_file))
                labels.append(label)
    return image_paths, labels
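# Assumed folder layout inside the extraction directory (inferred from the
# label map above; the actual archive contents may differ):
#   extracted_files/
#       alzheimers/   scan_001.nii.gz, slice_002.jpg, ...
#       parkinsons/   ...
#       ms/           ...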
# 4. Custom Dataset wrapping the preprocessed images and labels
class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = preprocess_image(self.image_paths[idx])
        label = self.labels[idx]
        return image, label
# 5. Training function: fine-tune a pretrained ViT for 3-way classification
def fine_tune_model(train_loader, num_epochs=10):
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224-in21k', num_labels=3
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()
    avg_loss = 0.0
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(pixel_values=images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(train_loader)
    # Return the fine-tuned model along with the last epoch's average loss
    return model, avg_loss
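# Optional persistence sketch (the directory name 'vit-mri-finetuned' is an
# assumption, not part of the app): since fine_tune_model returns the trained
# model, it can be saved and reloaded with the standard Hugging Face methods:
#   model.save_pretrained('vit-mri-finetuned')
#   model = ViTForImageClassification.from_pretrained('vit-mri-finetuned')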
# Streamlit UI
st.title("Fine-tune ViT on MRI Scans")

# Inputs for the two dataset ZIP archives (local path or URL)
zip_file_1 = st.text_input(
    "ZIP 1 (Alzheimer's / Parkinson's scans)",
    "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/archive%20(5).zip"
)
zip_file_2 = st.text_input(
    "ZIP 2 (MS scans)",
    "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/MS.zip"
)

if st.button("Start Training"):
    # Define a local extraction directory
    extraction_dir = 'extracted_files'
    os.makedirs(extraction_dir, exist_ok=True)

    # Download remote archives if needed, then extract both
    for zip_path in (zip_file_1, zip_file_2):
        if zip_path.startswith(('http://', 'https://')):
            local_path = os.path.join(extraction_dir, os.path.basename(zip_path))
            urllib.request.urlretrieve(zip_path, local_path)
            zip_path = local_path
        extract_zip(zip_path, extraction_dir)

    # Prepare dataset
    image_paths, labels = prepare_dataset(extraction_dir)
    dataset = CustomImageDataset(image_paths, labels)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Fine-tune the model
    model, final_loss = fine_tune_model(train_loader)
    st.write(f"Training complete with final loss: {final_loss:.4f}")