# train.py: fine-tune google/vit-base-patch16-224-in21k on the Kaggle
# "garbage classification" dataset.
import math
import os

import torch
import kagglehub
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification, get_scheduler
gpu_available = torch.cuda.is_available()
print("GPU available:", gpu_available)
if gpu_available:
    print("GPU:", torch.cuda.get_device_name(0))
    print("GPU count:", torch.cuda.device_count())
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.1f} GB")
else:
    print("No GPU available.")
kaggle_path = kagglehub.dataset_download("mostafaabla/garbage-classification")
folder_root_path = os.path.join(kaggle_path, "garbage_classification")
print("Path to dataset files:", kaggle_path)
# No torchvision transform here: the feature extractor resizes and
# normalizes the raw PIL images inside the collator.
ds = ImageFolder(folder_root_path)

# Shuffle the indices and hold out 20% for validation.
indices = torch.randperm(len(ds)).tolist()
n_val = math.floor(len(indices) * 0.20)
train_ds = torch.utils.data.Subset(ds, indices[:-n_val])
val_ds = torch.utils.data.Subset(ds, indices[-n_val:])
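# Quick sanity check (an addition, not in the upstream script): confirm the
# 80/20 split sizes add up to the full dataset.
print(f"train samples: {len(train_ds)}, val samples: {len(val_ds)}, total: {len(ds)}")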
print(ds.classes)
# Map class names to string ids and back; the string keys/values match what
# the Hugging Face config serializes.
label2id = {class_name: str(i) for i, class_name in enumerate(ds.classes)}
id2label = {str(i): class_name for i, class_name in enumerate(ds.classes)}
class ImageClassificationCollator:
    """Batch raw (PIL image, label) pairs into model-ready tensors."""

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings
# Note: ViTFeatureExtractor is deprecated in recent transformers releases in
# favor of ViTImageProcessor, but it still works and keeps the original API.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
collator = ImageClassificationCollator(feature_extractor)
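# Optional smoke test (an illustrative addition, not in the original script):
# run the collator on two raw samples to verify the tensor shapes the model
# expects, i.e. pixel_values of [batch, 3, 224, 224].
sample_batch = collator([ds[i] for i in range(2)])
print("pixel_values shape:", sample_batch['pixel_values'].shape)
print("labels:", sample_batch['labels'])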
train_loader = DataLoader(train_ds, batch_size=16, collate_fn=collator, num_workers=2, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16, collate_fn=collator, num_workers=2)
optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 10
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
progress_bar = tqdm(range(num_training_steps))
model.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
# Evaluate on the held-out split and report overall accuracy.
model.eval()
correct = 0
total = 0
for batch in val_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    correct += (predictions == batch['labels']).sum().item()
    total += batch['labels'].size(0)
print(f"Validation accuracy: {correct / total:.4f}")
# Persist the fine-tuned weights and the preprocessing config together so the
# checkpoint can be reloaded with from_pretrained().
save_directory = "./saved_model"
os.makedirs(save_directory, exist_ok=True)
model.save_pretrained(save_directory)
feature_extractor.save_pretrained(save_directory)
print(f"Model saved: {save_directory}")