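# Fine-tune google/vit-base-patch16-224-in21k for garbage classification on the
# Kaggle "mostafaabla/garbage-classification" dataset.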
import math
import os

import kagglehub
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification, get_scheduler
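# Report GPU availability and basic device info before training.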
gpu_available = torch.cuda.is_available()
print("GPU available:", gpu_available)
if gpu_available:
    print("GPU:", torch.cuda.get_device_name(0))
    print("GPU count:", torch.cuda.device_count())
    print("Memory available:", torch.cuda.get_device_properties(0).total_memory / (1024 ** 3), "GB")
else:
    print("No GPU available.")
kaggle_path = kagglehub.dataset_download("mostafaabla/garbage-classification")
print("Path to dataset files:", kaggle_path)
folder_root_path = os.path.join(kaggle_path, "garbage_classification")
ds = ImageFolder(folder_root_path)
indices = torch.randperm(len(ds)).tolist()
n_val = math.floor(len(indices) * .20)
train_ds = torch.utils.data.Subset(ds, indices[:-n_val])
val_ds = torch.utils.data.Subset(ds, indices[-n_val:])
print(ds.classes)
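# Build the label <-> id mappings stored in the model config.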
label2id = {}
id2label = {}
for i, class_name in enumerate(ds.classes):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name
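# Collator: preprocess a list of (PIL image, label) pairs into a model-ready batch.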
class ImageClassificationCollator:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings
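# Load the pretrained ViT backbone with a freshly initialized classification head.
# Note: ViTFeatureExtractor is deprecated in newer transformers releases in favor
# of ViTImageProcessor; it still works here.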
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
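# Wire the collator into the DataLoaders; shuffle only the training split.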
collator = ImageClassificationCollator(feature_extractor)
train_loader = DataLoader(train_ds, batch_size=16, collate_fn=collator, num_workers=2, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16, collate_fn=collator, num_workers=2)
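# AdamW with a linear learning-rate decay over all training steps (no warmup).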
optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 10
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
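# Train on GPU when available.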
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
progress_bar = tqdm(range(num_training_steps))
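# Training loop: forward pass, backprop, optimizer and LR scheduler step per batch.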
model.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
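# Evaluation: compute accuracy on the held-out validation split.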
model.eval()
correct = 0
for batch in val_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    predictions = torch.argmax(outputs.logits, dim=-1)
    correct += (predictions == batch['labels']).sum().item()
print(f"Validation accuracy: {correct / len(val_ds):.4f}")
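# Save the fine-tuned weights and preprocessing config for reuse with from_pretrained.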
save_directory = "./saved_model"
os.makedirs(save_directory, exist_ok=True)
model.save_pretrained(save_directory)
feature_extractor.save_pretrained(save_directory)
print(f"Model saved: {save_directory}")