notRaphael committed on
Commit
17bc682
·
verified ·
1 Parent(s): b115203

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +1185 -0
train.py ADDED
@@ -0,0 +1,1185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CSIRO Image2Biomass Prediction - Training Pipeline
4
+ ====================================================
5
+ Multi-output regression: predicting 5 biomass targets from pasture images.
6
+
7
+ Targets: Dry_Green_g, Dry_Dead_g, Dry_Clover_g, GDM_g, Dry_Total_g
8
+ Metric: Weighted R² (weights: 0.1, 0.1, 0.1, 0.2, 0.5)
9
+
10
+ Architecture:
11
+ - Backbone: DINOv2 / ConvNeXt / EfficientNet (via timm)
12
+ - Head: MLP with LayerNorm, GELU, Dropout
13
+ - Loss: SmoothL1 + optional weighted MSE + consistency regularizer
14
+ - Training: Mixed precision, gradient checkpointing, differential LR
15
+
16
+ Usage:
17
+ python train.py --data_dir /path/to/competition/data --backbone dinov2_base --epochs 50
18
+ """
19
+
20
+ import os
21
+ import sys
22
+ import json
23
+ import time
24
+ import random
25
+ import argparse
26
+ import logging
27
+ from pathlib import Path
28
+ from typing import Dict, List, Optional, Tuple
29
+
30
+ import numpy as np
31
+ import pandas as pd
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
36
+ from torch.cuda.amp import GradScaler, autocast
37
+ import timm
38
+
39
+ try:
40
+ import albumentations as A
41
+ from albumentations.pytorch import ToTensorV2
42
+ HAS_ALBUMENTATIONS = True
43
+ except ImportError:
44
+ HAS_ALBUMENTATIONS = False
45
+ from torchvision import transforms
46
+
47
+ from PIL import Image
48
+ from sklearn.model_selection import KFold, StratifiedKFold
49
+
50
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
51
+ logger = logging.getLogger(__name__)
52
+
53
+ # ============================================================
54
+ # Constants
55
+ # ============================================================
56
# Regression targets, in the column order used by every loss and metric below.
TARGET_COLS = [
    'Dry_Green_g',
    'Dry_Dead_g',
    'Dry_Clover_g',
    'GDM_g',
    'Dry_Total_g',
]
# Per-target weights, aligned position-by-position with TARGET_COLS.
TARGET_WEIGHTS = [0.1, 0.1, 0.1, 0.2, 0.5]
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Backbone configurations
# One row per supported backbone:
#   (alias, timm model name, feat_dim, native_size, default_size)
_BACKBONE_TABLE = (
    ('dinov2_small', 'vit_small_patch14_dinov2.lvd142m', 384, 518, 224),
    ('dinov2_base', 'vit_base_patch14_dinov2.lvd142m', 768, 518, 224),
    ('dinov2_large', 'vit_large_patch14_dinov2.lvd142m', 1024, 518, 224),
    ('dinov2_base_reg', 'vit_base_patch14_reg4_dinov2.lvd142m', 768, 518, 224),
    ('convnext_large', 'convnext_large.fb_in22k_ft_in1k', 1536, 224, 224),
    ('convnextv2_large', 'convnextv2_large.fcmae_ft_in22k_in1k', 1536, 224, 224),
    ('efficientnet_b4', 'efficientnet_b4.ra2_in1k', 1792, 380, 320),
    ('swin_large', 'swin_large_patch4_window7_224.ms_in22k_ft_in1k', 1536, 224, 224),
    ('eva02_large', 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', 1024, 448, 448),
)

# Maps the short alias to its configuration dict
# (keys: 'name', 'feat_dim', 'native_size', 'default_size').
BACKBONE_CONFIGS = {
    alias: {
        'name': timm_name,
        'feat_dim': feat_dim,
        'native_size': native_size,
        'default_size': default_size,
    }
    for alias, timm_name, feat_dim, native_size, default_size in _BACKBONE_TABLE
}
100
+
101
+
102
+ # ============================================================
103
+ # Dataset
104
+ # ============================================================
105
class BiomassDataset(Dataset):
    """Dataset for pasture biomass regression from images.

    Each item is a dict containing:
      * ``'image'``    - the transformed image tensor,
      * ``'image_id'`` - the image identifier as ``str``,
      * ``'ndvi'``     - scalar float32 tensor, only when ``use_ndvi`` is set
                         and ``df`` has an ``'NDVI'`` column,
      * ``'targets'``  - float32 tensor of the TARGET_COLS values, only when
                         ``targets`` is given (log1p-transformed when
                         ``log_transform`` is True).

    Args:
        image_dir: Directory holding ``<image_id>.jpg`` / ``.png`` (or any
            other extension) files.
        df: One row per sample. The ``'image_id'`` column is used when
            present, otherwise the (reset) positional index.
        targets: Optional frame with the TARGET_COLS columns, positionally
            aligned with ``df``.
        transform: Albumentations pipeline (called as ``transform(image=...)``)
            when albumentations is installed, otherwise a callable taking a
            PIL image. ``None`` yields a plain CHW float tensor in [0, 1].
        img_size, is_test: Stored on the instance; not used by this class.
        use_ndvi: Whether to emit the ``'ndvi'`` entry.
        log_transform: Whether targets are returned as ``log1p(target)``.
    """

    def __init__(
        self,
        image_dir: str,
        df: pd.DataFrame,
        targets: Optional[pd.DataFrame] = None,
        transform=None,
        img_size: int = 224,
        use_ndvi: bool = False,
        log_transform: bool = True,
        is_test: bool = False,
    ):
        self.image_dir = Path(image_dir)
        self.df = df.reset_index(drop=True)
        self.targets = targets
        self.transform = transform
        self.img_size = img_size
        self.use_ndvi = use_ndvi
        self.log_transform = log_transform
        self.is_test = is_test

        # Image identifiers, resolved once instead of per __getitem__ call.
        if 'image_id' in self.df.columns:
            self.image_ids = self.df['image_id'].values
        else:
            self.image_ids = self.df.index.values

        # image_id -> resolved file path, so the filesystem is probed at most
        # once per image (per worker process).
        self._path_cache = {}

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, img_id) -> Path:
        """Return the image file for ``img_id``; try .jpg, .png, then any extension.

        Raises:
            FileNotFoundError: if no file named ``<img_id>.*`` exists.
        """
        key = str(img_id)
        cached = self._path_cache.get(key)
        if cached is not None:
            return cached

        for ext in ('.jpg', '.png'):
            found = self.image_dir / f"{img_id}{ext}"
            if found.exists():
                break
        else:
            # Fall back to any extension; sorted so the pick is deterministic.
            candidates = sorted(self.image_dir.glob(f"{img_id}.*"))
            if not candidates:
                raise FileNotFoundError(f"Image not found: {img_id}")
            found = candidates[0]

        self._path_cache[key] = found
        return found

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = self._resolve_path(img_id)

        # Load image; the context manager releases the file handle promptly.
        with Image.open(img_path) as handle:
            img = np.array(handle.convert('RGB'))

        # Apply transforms
        if self.transform is None:
            # HWC uint8 -> CHW float32 in [0, 1]. Done with torch directly
            # because torchvision's `transforms` is only imported when
            # albumentations is missing, so ToTensor() is not always in scope.
            img_tensor = (
                torch.from_numpy(img).permute(2, 0, 1).contiguous().float().div(255.0)
            )
        elif HAS_ALBUMENTATIONS:
            img_tensor = self.transform(image=img)['image']
        else:
            img_tensor = self.transform(Image.fromarray(img))

        result = {'image': img_tensor, 'image_id': str(img_id)}

        # Add NDVI if available
        if self.use_ndvi and 'NDVI' in self.df.columns:
            ndvi_value = self.df.iloc[idx]['NDVI']
            result['ndvi'] = torch.tensor(ndvi_value, dtype=torch.float32)

        # Add targets if training
        if self.targets is not None:
            target_values = self.targets.iloc[idx][TARGET_COLS].values.astype(np.float32)
            if self.log_transform:
                target_values = np.log1p(target_values)
            result['targets'] = torch.tensor(target_values, dtype=torch.float32)

        return result
180
+
181
+
182
+ # ============================================================
183
+ # Augmentations
184
+ # ============================================================
185
def get_train_transforms(img_size: int = 224, aug_strength: str = 'medium'):
    """Build the training augmentation pipeline.

    With albumentations installed, ``aug_strength`` selects the recipe:
    ``'light'``, ``'medium'``, or anything else (treated as heavy). Without
    albumentations a single torchvision pipeline is returned and
    ``aug_strength`` is ignored.

    Args:
        img_size: Side length of the square output crop.
        aug_strength: Augmentation recipe name.

    Returns:
        An ``A.Compose`` (albumentations) or ``transforms.Compose`` (torchvision).
    """
    if not HAS_ALBUMENTATIONS:
        return transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    def crop(min_scale):
        return A.RandomResizedCrop(size=(img_size, img_size), scale=(min_scale, 1.0))

    def hole_range(max_frac):
        # Dropout hole side length range, as a fraction of the image side.
        return (int(img_size * 0.05), int(img_size * max_frac))

    flips = [A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]
    # Every recipe ends with normalisation + tensor conversion.
    tail = [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2()]

    if aug_strength == 'light':
        steps = [crop(0.7), *flips]
    elif aug_strength == 'medium':
        steps = [
            crop(0.5),
            *flips,
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=25, val_shift_limit=15, p=0.4),
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 5)),
                A.MotionBlur(blur_limit=5),
            ], p=0.15),
            A.CoarseDropout(
                num_holes_range=(1, 4),
                hole_height_range=hole_range(0.15),
                hole_width_range=hole_range(0.15),
                p=0.2,
            ),
        ]
    else:  # heavy
        steps = [
            crop(0.4),
            *flips,
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            A.RandomGamma(gamma_limit=(80, 120), p=0.3),
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 7)),
                A.MotionBlur(blur_limit=7),
            ], p=0.2),
            A.OneOf([
                A.GaussNoise(p=1.0),
                A.ISONoise(p=1.0),
            ], p=0.2),
            A.CoarseDropout(
                num_holes_range=(1, 8),
                hole_height_range=hole_range(0.2),
                hole_width_range=hole_range(0.2),
                p=0.3,
            ),
        ]

    return A.Compose(steps + tail)
254
+
255
+
256
def get_val_transforms(img_size: int = 224):
    """Build the deterministic validation pipeline.

    Resizes to ``int(img_size * 1.14)``, centre-crops to ``img_size``, then
    normalises with the ImageNet statistics and converts to a tensor.

    Args:
        img_size: Side length of the square output crop.

    Returns:
        An ``A.Compose`` when albumentations is installed, otherwise a
        torchvision ``transforms.Compose``.
    """
    resize_to = int(img_size * 1.14)

    if not HAS_ALBUMENTATIONS:
        return transforms.Compose([
            transforms.Resize(resize_to),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    return A.Compose([
        A.Resize(height=resize_to, width=resize_to),
        A.CenterCrop(height=img_size, width=img_size),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ])
272
+
273
+
274
def get_tta_transforms(img_size: int = 224, n_augments: int = 8):
    """Build the list of test-time-augmentation (TTA) pipelines.

    The first entry is always the plain validation pipeline. With
    albumentations installed, six more views follow: horizontal flip,
    vertical flip, both flips, and three re-scaled centre crops
    (scale 0.9, 1.0, 1.2). Without albumentations only the validation
    pipeline is returned.

    NOTE(review): at most 7 pipelines are built, so the default
    ``n_augments=8`` never truncates; the scale-1.0 crop is constructed with
    the same parameters as the first (validation) view.

    Args:
        img_size: Side length of the square output crop.
        n_augments: Maximum number of pipelines to return.

    Returns:
        The first ``n_augments`` pipelines.
    """
    views = [get_val_transforms(img_size)]  # Original

    if HAS_ALBUMENTATIONS:
        def make_view(resize_to, extra=()):
            # resize -> centre crop -> optional forced flips -> normalise -> tensor
            return A.Compose([
                A.Resize(height=resize_to, width=resize_to),
                A.CenterCrop(height=img_size, width=img_size),
                *extra,
                A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
                ToTensorV2(),
            ])

        base_size = int(img_size * 1.14)

        # Flipped versions
        views.append(make_view(base_size, [A.HorizontalFlip(p=1.0)]))
        views.append(make_view(base_size, [A.VerticalFlip(p=1.0)]))
        views.append(make_view(base_size, [A.HorizontalFlip(p=1.0), A.VerticalFlip(p=1.0)]))

        # Slightly different crops
        for scale in (0.9, 1.0, 1.2):
            views.append(make_view(int(img_size * scale * 1.14)))

    return views[:n_augments]
310
+
311
+
312
+ # ============================================================
313
+ # Model
314
+ # ============================================================
315
class BiomassModel(nn.Module):
    """
    Multi-output regression model for biomass prediction.

    Architecture:
    - timm backbone (DINOv2, ConvNeXt, etc.) created with ``num_classes=0``
    - Optional auxiliary feature (NDVI) embedded to 64 dims and concatenated
    - MLP regression head with LayerNorm + GELU + Dropout
    - Optional: separate heads per target for better specialization

    Args:
        backbone_name: timm model name.
        num_targets: Number of regression outputs.
        hidden_dim: Width of the first hidden layer of the head(s).
        dropout: Base dropout rate; deeper head layers use 0.5x / 0.3x of it.
        pretrained: Whether timm should load pretrained weights.
        img_size: Forwarded to timm only when the model name contains
            ``'vit'`` or ``'dinov2'``.
        use_ndvi: Add the NDVI embedding branch; ``forward`` then requires ``ndvi``.
        separate_heads: Use one small MLP per target instead of a shared head.
        grad_checkpointing: Ask the backbone to enable gradient checkpointing.
    """

    def __init__(
        self,
        backbone_name: str = 'vit_base_patch14_dinov2.lvd142m',
        num_targets: int = 5,
        hidden_dim: int = 512,
        dropout: float = 0.3,
        pretrained: bool = True,
        img_size: int = 224,
        use_ndvi: bool = False,
        separate_heads: bool = False,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.use_ndvi = use_ndvi
        self.separate_heads = separate_heads
        self.num_targets = num_targets

        # Create backbone
        kwargs = {'pretrained': pretrained, 'num_classes': 0}
        if 'vit' in backbone_name or 'dinov2' in backbone_name:
            kwargs['img_size'] = img_size

        self.backbone = timm.create_model(backbone_name, **kwargs)
        feat_dim = self.backbone.num_features

        # Enable gradient checkpointing for memory efficiency
        if grad_checkpointing:
            if hasattr(self.backbone, 'set_grad_checkpointing'):
                self.backbone.set_grad_checkpointing(True)
                logger.info("Gradient checkpointing enabled")
            else:
                # Previously ignored silently; tell the user the flag had no effect.
                logger.warning(
                    "Gradient checkpointing requested but backbone %s does not support it",
                    backbone_name,
                )

        # NDVI embedding
        if use_ndvi:
            self.ndvi_embed = nn.Sequential(
                nn.Linear(1, 32),
                nn.GELU(),
                nn.Linear(32, 64),
            )
            feat_dim += 64

        # Regression head(s)
        if separate_heads:
            # Separate MLP head per target - better specialization
            self.heads = nn.ModuleList([
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Dropout(dropout),
                    nn.Linear(feat_dim, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout * 0.5),
                    nn.Linear(hidden_dim, 1),
                )
                for _ in range(num_targets)
            ])
        else:
            # Shared head - better when data is limited
            self.head = nn.Sequential(
                nn.LayerNorm(feat_dim),
                nn.Dropout(dropout),
                nn.Linear(feat_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout * 0.3),
                nn.Linear(hidden_dim // 2, num_targets),
            )

    def forward(self, x, ndvi=None):
        """Return predictions of shape ``[batch, num_targets]``.

        Args:
            x: Image batch passed to the backbone.
            ndvi: Per-sample NDVI values (one scalar per sample); required
                when the model was built with ``use_ndvi=True``, ignored otherwise.

        Raises:
            ValueError: if the model uses NDVI but ``ndvi`` is ``None``. The
                head was sized for backbone + 64 NDVI features, so continuing
                would only fail later with an opaque shape error.
        """
        if self.use_ndvi and ndvi is None:
            raise ValueError("Model was built with use_ndvi=True but no 'ndvi' input was provided")

        features = self.backbone(x)

        if self.use_ndvi:
            ndvi_feats = self.ndvi_embed(ndvi.unsqueeze(-1))
            features = torch.cat([features, ndvi_feats], dim=-1)

        if self.separate_heads:
            outputs = [head(features) for head in self.heads]
            return torch.cat(outputs, dim=-1)
        return self.head(features)

    def get_param_groups(self, backbone_lr: float = 5e-5, head_lr: float = 1e-3):
        """Get parameter groups with differential learning rates.

        Returns:
            Two optimizer param-group dicts: backbone parameters at
            ``backbone_lr``, and every parameter whose name does not contain
            ``'backbone'`` (NDVI embedding, head(s)) at ``head_lr``.
        """
        backbone_params = list(self.backbone.parameters())
        head_params = [p for n, p in self.named_parameters() if 'backbone' not in n]

        return [
            {'params': backbone_params, 'lr': backbone_lr},
            {'params': head_params, 'lr': head_lr},
        ]
416
+
417
+
418
+ # ============================================================
419
+ # Loss Functions
420
+ # ============================================================
421
class WeightedSmoothL1Loss(nn.Module):
    """Element-wise SmoothL1 loss scaled per target column, then averaged.

    Args:
        target_weights: One weight per target column; defaults to TARGET_WEIGHTS.
        beta: Threshold passed to ``F.smooth_l1_loss``.
    """

    def __init__(self, target_weights=None, beta=1.0):
        super().__init__()
        self.beta = beta
        chosen = TARGET_WEIGHTS if target_weights is None else target_weights
        # Buffer (not a parameter) so it follows the module across devices.
        self.register_buffer('weights', torch.tensor(chosen, dtype=torch.float32))

    def forward(self, pred, target):
        # Un-reduced loss, one value per (sample, target).
        per_element = F.smooth_l1_loss(pred, target, beta=self.beta, reduction='none')
        scaled = per_element * self.weights[None, :]
        return scaled.mean()
435
+
436
+
437
class WeightedMSELoss(nn.Module):
    """Squared error scaled per target column, then averaged over all elements.

    Args:
        target_weights: One weight per target column; defaults to TARGET_WEIGHTS.
    """

    def __init__(self, target_weights=None):
        super().__init__()
        chosen = TARGET_WEIGHTS if target_weights is None else target_weights
        # Buffer (not a parameter) so it follows the module across devices.
        self.register_buffer('weights', torch.tensor(chosen, dtype=torch.float32))

    def forward(self, pred, target):
        squared_error = (pred - target) ** 2  # one value per (sample, target)
        scaled = squared_error * self.weights[None, :]
        return scaled.mean()
450
+
451
+
452
class ConsistencyLoss(nn.Module):
    """
    Soft structural constraint: Dry_Total_g ≈ Dry_Green_g + Dry_Dead_g + Dry_Clover_g.

    Penalises the mean squared gap between prediction column 4 and the sum of
    prediction columns 0-2, scaled by ``weight``. The GDM column (3) is not used.

    NOTE(review): the penalty is applied to the raw predictions. If the model
    is trained on log1p-transformed targets, the sum of the component outputs
    is not expected to equal the total output - confirm this is intended.
    """

    def __init__(self, weight=0.1):
        super().__init__()
        self.weight = weight

    def forward(self, pred):
        # pred columns: [Green, Dead, Clover, GDM, Total]
        green, dead, clover = pred[:, 0], pred[:, 1], pred[:, 2]
        predicted_total = pred[:, 4]
        gap = F.mse_loss(green + dead + clover, predicted_total)
        return self.weight * gap
467
+
468
+
469
class CombinedLoss(nn.Module):
    """Weighted SmoothL1 plus optional weighted-MSE and consistency terms.

    Args:
        smoothl1_weight: Multiplier for the SmoothL1 term (always computed).
        mse_weight: Multiplier for the MSE term; the term is skipped when <= 0.
        consistency_weight: Weight handed to ConsistencyLoss; skipped when <= 0.
        target_weights: Per-target weights forwarded to the SmoothL1/MSE losses.
    """

    def __init__(self, smoothl1_weight=1.0, mse_weight=0.0, consistency_weight=0.1,
                 target_weights=None):
        super().__init__()
        self.smoothl1 = WeightedSmoothL1Loss(target_weights)

        # Optional terms are stored as None when disabled.
        self.mse = None
        if mse_weight > 0:
            self.mse = WeightedMSELoss(target_weights)

        self.consistency = None
        if consistency_weight > 0:
            self.consistency = ConsistencyLoss(consistency_weight)

        self.smoothl1_weight = smoothl1_weight
        self.mse_weight = mse_weight

    def forward(self, pred, target):
        total = self.smoothl1_weight * self.smoothl1(pred, target)
        if self.mse is not None:
            total = total + self.mse_weight * self.mse(pred, target)
        if self.consistency is not None:
            # ConsistencyLoss already applies its own weight.
            total = total + self.consistency(pred)
        return total
488
+
489
+
490
+ # ============================================================
491
+ # Label Distribution Smoothing (LDS)
492
+ # ============================================================
493
def get_lds_weights(labels: np.ndarray, bins: int = 100, kernel_size: int = 5, sigma: float = 2.0):
    """
    Compute Label Distribution Smoothing (LDS) sample weights.
    From "Delving into Deep Imbalanced Regression" (ICML 2021).

    The label histogram is smoothed with a normalised Gaussian-shaped kernel
    (``kernel_size`` taps evaluated on ``linspace(-3, 3)``); each sample is
    weighted by the inverse of the smoothed density interpolated at its label.

    Args:
        labels: 1-D labels, or a 2-D array whose last column is used.
        bins: Number of histogram bins.
        kernel_size: Number of kernel taps.
        sigma: Kernel width parameter.

    Returns:
        One weight per sample, normalised to mean 1.
    """
    from scipy.ndimage import convolve1d

    # Use the most important target (Dry_Total_g) for weighting
    if labels.ndim > 1:
        labels = labels[:, -1]  # Last column = Dry_Total_g

    counts, edges = np.histogram(labels, bins=bins)

    taps = np.linspace(-3, 3, kernel_size)
    kernel = np.exp(-taps**2 / (2 * sigma**2))
    kernel /= kernel.sum()

    density = convolve1d(counts.astype(float), kernel, mode='reflect')
    centers = (edges[:-1] + edges[1:]) / 2

    # Small epsilon guards against division by an empty smoothed bin.
    inverse_density = 1.0 / (np.interp(labels, centers, density) + 1e-8)
    return inverse_density / inverse_density.mean()  # Normalize to mean=1
514
+
515
+
516
+ # ============================================================
517
+ # Metrics
518
+ # ============================================================
519
def compute_weighted_r2(preds: np.ndarray, targets: np.ndarray,
                        target_weights: Optional[List[float]] = None) -> float:
    """
    Compute the globally weighted R² (competition metric).

    All (sample, target) cells are pooled into a single weighted R²: each
    cell carries the weight of its target column, and the baseline is the
    weighted mean over every cell.

    Args:
        preds: [N, T] predictions
        targets: [N, T] ground truth
        target_weights: per-target weights (default: competition weights);
            only the first T entries are used

    Returns:
        Weighted R² score

    Raises:
        ValueError: if ``preds`` and ``targets`` differ in shape, are not
            2-D, or fewer than T weights are supplied.
    """
    if target_weights is None:
        target_weights = TARGET_WEIGHTS

    # float64 throughout so half/single-precision inputs do not degrade the metric.
    preds = np.asarray(preds, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    if preds.ndim != 2 or preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must be 2-D with equal shapes, got {preds.shape} and {targets.shape}"
        )

    n_samples, n_targets = preds.shape
    col_weights = np.asarray(target_weights, dtype=np.float64)[:n_targets]
    if col_weights.shape[0] < n_targets:
        raise ValueError(f"Need {n_targets} target weights, got {col_weights.shape[0]}")

    # Broadcasting the per-column weights over rows replaces the explicit
    # long-format expansion (N*T Python-level appends).
    w = col_weights[None, :]

    # Weighted mean over all cells
    weighted_mean = np.sum(w * targets) / (col_weights.sum() * n_samples)

    # SS_res and SS_tot
    ss_res = np.sum(w * (targets - preds) ** 2)
    ss_tot = np.sum(w * (targets - weighted_mean) ** 2)

    return float(1.0 - ss_res / (ss_tot + 1e-8))
561
+
562
+
563
def compute_per_target_r2(preds: np.ndarray, targets: np.ndarray) -> Dict[str, float]:
    """Return the unweighted R² of each target column, keyed by its TARGET_COLS name.

    Column ``i`` of ``preds`` / ``targets`` is matched with ``TARGET_COLS[i]``.
    """
    def column_r2(col: int):
        truth = targets[:, col]
        residual = np.sum((truth - preds[:, col]) ** 2)
        spread = np.sum((truth - truth.mean()) ** 2)
        # Epsilon keeps a constant target column from dividing by zero.
        return 1.0 - residual / (spread + 1e-8)

    return {name: column_r2(col) for col, name in enumerate(TARGET_COLS)}
572
+
573
+
574
+ # ============================================================
575
+ # Training Engine
576
+ # ============================================================
577
+ class Trainer:
578
+ """Training engine with mixed precision, gradient accumulation, and k-fold."""
579
+
580
def __init__(self, args):
    """Store the run configuration and pick the compute device.

    Args:
        args: Parsed command-line namespace; kept as ``self.args``.

    Sets ``self.device`` to CUDA when available (CPU otherwise) and creates
    a ``GradScaler`` for mixed precision only on CUDA (``self.scaler`` is
    ``None`` on CPU, which the other methods use as the "no AMP" signal).
    """
    self.args = args
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.scaler = GradScaler() if self.device.type == 'cuda' else None

    logger.info(f"Device: {self.device}")
    if self.device.type == 'cuda':
        props = torch.cuda.get_device_properties(0)
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties attribute is `total_memory` (bytes);
        # `total_mem` does not exist and raised AttributeError on GPU hosts.
        logger.info(f"GPU Memory: {props.total_memory / 1e9:.1f} GB")
589
+
590
def build_model(self):
    """Instantiate a pretrained BiomassModel from ``self.args`` and move it to the device.

    The image size is ``args.img_size`` when truthy, otherwise the selected
    backbone's ``default_size`` from BACKBONE_CONFIGS.
    """
    cfg = self.args
    backbone_cfg = BACKBONE_CONFIGS[cfg.backbone]
    img_size = cfg.img_size or backbone_cfg['default_size']

    model_kwargs = dict(
        backbone_name=backbone_cfg['name'],
        num_targets=5,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        pretrained=True,
        img_size=img_size,
        use_ndvi=cfg.use_ndvi,
        separate_heads=cfg.separate_heads,
        grad_checkpointing=cfg.grad_checkpointing,
    )
    return BiomassModel(**model_kwargs).to(self.device)
607
+
608
def build_optimizer(self, model):
    """Create the optimizer over the model's differential-LR parameter groups.

    Args:
        model: Object exposing ``get_param_groups(backbone_lr=, head_lr=)``.

    Returns:
        ``AdamW`` for ``args.optimizer == 'adamw'``, ``SGD`` (momentum 0.9)
        for ``'sgd'``; both use ``args.weight_decay``.

    Raises:
        ValueError: for any other optimizer name.
    """
    cfg = self.args
    groups = model.get_param_groups(
        backbone_lr=cfg.backbone_lr,
        head_lr=cfg.head_lr,
    )

    kind = cfg.optimizer
    if kind == 'adamw':
        return torch.optim.AdamW(groups, weight_decay=cfg.weight_decay)
    if kind == 'sgd':
        return torch.optim.SGD(groups, momentum=0.9, weight_decay=cfg.weight_decay)
    raise ValueError(f"Unknown optimizer: {self.args.optimizer}")
623
+
624
def build_scheduler(self, optimizer, num_training_steps):
    """Build the learning-rate scheduler selected by ``args.scheduler``.

    * ``'cosine'``  - linear warm-up (from 1% of the base LR) for
      ``num_training_steps * args.warmup_ratio`` steps, then cosine
      annealing down to ``args.min_lr``. Stepped once per optimizer step.
    * ``'plateau'`` - ``ReduceLROnPlateau`` in ``'max'`` mode (halve after 5
      epochs without improvement); stepped by the caller with the metric.
    * anything else - ``None`` (constant LR).

    Args:
        optimizer: Optimizer whose param groups are scheduled.
        num_training_steps: Total number of optimizer steps planned.

    Returns:
        The scheduler instance, or ``None``.
    """
    kind = self.args.scheduler

    if kind == 'cosine':
        from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

        warmup_steps = max(0, int(num_training_steps * self.args.warmup_ratio))
        # T_max must be >= 1: CosineAnnealingLR divides by it.
        cosine_steps = max(1, num_training_steps - warmup_steps)

        if warmup_steps == 0:
            # Without warm-up, skip SequentialLR entirely: with a milestone
            # of 0 the hand-over to the cosine phase never triggers, and the
            # LR stays at the warm-up start factor (1% scale) all run.
            return CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=self.args.min_lr)

        warmup = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps)
        cosine = CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=self.args.min_lr)
        return SequentialLR(optimizer, [warmup, cosine], milestones=[warmup_steps])

    if kind == 'plateau':
        # `verbose` is deprecated on PyTorch LR schedulers, so it is no
        # longer passed; the training loop already logs the current LR.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5)

    return None
641
+
642
def train_one_epoch(self, model, loader, optimizer, scheduler, loss_fn, epoch):
    """Train for one epoch and return the mean per-sample loss.

    Uses fp16 autocast + ``self.scaler`` when a scaler exists (CUDA), plain
    fp32 otherwise. Gradients are accumulated over ``args.grad_accum_steps``
    batches, clipped to ``args.max_grad_norm``, then applied. Any scheduler
    other than ``ReduceLROnPlateau`` is stepped after every optimizer step.

    Args:
        model: Called as ``model(images, ndvi)``.
        loader: Iterable of batch dicts with ``'image'``, ``'targets'`` and
            optionally ``'ndvi'``.
        optimizer: Optimizer to step.
        scheduler: LR scheduler or ``None``.
        loss_fn: Called as ``loss_fn(preds, targets)``.
        epoch: Epoch number, used for logging only.

    Returns:
        Average loss over the samples seen, or ``nan`` if the loader was empty.
    """
    from contextlib import nullcontext

    model.train()
    cfg = self.args
    accum_steps = cfg.grad_accum_steps
    use_amp = self.scaler is not None
    step_scheduler = scheduler is not None and not isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    # If the previous epoch ended mid-accumulation (len(loader) not a
    # multiple of accum_steps), its leftover gradients would be added to
    # this epoch's first update. Start every epoch from clean gradients.
    optimizer.zero_grad()

    running_loss = 0.0
    num_samples = 0

    for batch_idx, batch in enumerate(loader):
        images = batch['image'].to(self.device)
        targets = batch['targets'].to(self.device)
        ndvi = batch.get('ndvi', None)
        if ndvi is not None:
            ndvi = ndvi.to(self.device)

        # Forward pass (mixed precision when a scaler is available)
        amp_ctx = autocast(dtype=torch.float16) if use_amp else nullcontext()
        with amp_ctx:
            preds = model(images, ndvi)
            loss = loss_fn(preds, targets)

        # Gradient accumulation: scale so the summed gradient is an average
        loss = loss / accum_steps
        if use_amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        if (batch_idx + 1) % accum_steps == 0:
            if use_amp:
                # Unscale first so clipping sees true gradient magnitudes.
                self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            if use_amp:
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

            if step_scheduler:
                scheduler.step()

        batch_size = images.size(0)
        # Undo the accumulation scaling so the logged loss is per-sample.
        running_loss += loss.item() * accum_steps * batch_size
        num_samples += batch_size

        if (batch_idx + 1) % cfg.log_interval == 0:
            avg_loss = running_loss / num_samples
            lr = optimizer.param_groups[0]['lr']
            logger.info(f"Epoch {epoch} [{batch_idx+1}/{len(loader)}] loss={avg_loss:.4f} lr={lr:.2e}")

    if num_samples == 0:
        # Empty loader (e.g. drop_last with fewer samples than one batch).
        return float('nan')
    return running_loss / num_samples
697
+
698
@torch.no_grad()
def validate(self, model, loader, loss_fn, log_transform=True):
    """Run the model over ``loader`` and compute loss and R² metrics.

    Args:
        model: Called as ``model(images, ndvi)``; switched to eval mode.
        loader: Iterable of batch dicts with ``'image'``, ``'targets'`` and
            optionally ``'ndvi'``.
        loss_fn: Called as ``loss_fn(preds, targets)`` on the model-space
            (possibly log1p) values.
        log_transform: When True, predictions and targets are mapped back
            with ``expm1`` before the metrics are computed.

    Returns:
        Dict with ``'loss'`` (mean per-sample loss), ``'weighted_r2'``,
        ``'per_target_r2'``, and the original-scale ``'preds'`` (clipped at 0)
        and ``'targets'`` arrays.
    """
    from contextlib import nullcontext

    model.eval()
    use_amp = self.scaler is not None

    pred_chunks = []
    target_chunks = []
    running_loss = 0.0
    num_samples = 0

    for batch in loader:
        images = batch['image'].to(self.device)
        targets = batch['targets'].to(self.device)
        ndvi = batch.get('ndvi', None)
        if ndvi is not None:
            ndvi = ndvi.to(self.device)

        amp_ctx = autocast(dtype=torch.float16) if use_amp else nullcontext()
        with amp_ctx:
            preds = model(images, ndvi)
            loss = loss_fn(preds, targets)

        running_loss += loss.item() * images.size(0)
        num_samples += images.size(0)

        # Under autocast the outputs can be float16; upcast before leaving
        # torch so expm1 and the R² sums are not computed in half precision.
        pred_chunks.append(preds.float().cpu().numpy())
        target_chunks.append(targets.float().cpu().numpy())

    all_preds = np.concatenate(pred_chunks, axis=0)
    all_targets = np.concatenate(target_chunks, axis=0)

    # Inverse log transform for metric computation
    if log_transform:
        all_preds_orig = np.expm1(all_preds)
        all_targets_orig = np.expm1(all_targets)
    else:
        all_preds_orig = all_preds
        all_targets_orig = all_targets

    # Clip negative predictions
    all_preds_orig = np.clip(all_preds_orig, 0, None)

    # Compute metrics
    weighted_r2 = compute_weighted_r2(all_preds_orig, all_targets_orig)
    per_target_r2 = compute_per_target_r2(all_preds_orig, all_targets_orig)

    avg_loss = running_loss / num_samples

    return {
        'loss': avg_loss,
        'weighted_r2': weighted_r2,
        'per_target_r2': per_target_r2,
        'preds': all_preds_orig,
        'targets': all_targets_orig,
    }
755
+
756
@torch.no_grad()
def predict(self, model, loader, log_transform=True, tta_transforms=None):
    """Generate predictions (inference).

    Args:
        model: Called as ``model(images, ndvi)``; switched to eval mode.
        loader: Iterable of batch dicts with ``'image'``, ``'image_id'`` and
            optionally ``'ndvi'``.
        log_transform: When True, outputs are mapped back with ``expm1``.
        tta_transforms: Accepted for interface compatibility but not used
            by this method; no test-time augmentation is applied here.

    Returns:
        ``(preds, ids)``: float32 array of predictions clipped at 0, and the
        list of image ids in loader order.
    """
    from contextlib import nullcontext

    model.eval()
    use_amp = self.scaler is not None

    pred_chunks = []
    all_ids = []

    for batch in loader:
        images = batch['image'].to(self.device)
        ndvi = batch.get('ndvi', None)
        if ndvi is not None:
            ndvi = ndvi.to(self.device)

        amp_ctx = autocast(dtype=torch.float16) if use_amp else nullcontext()
        with amp_ctx:
            preds = model(images, ndvi)

        # Under autocast the outputs can be float16; upcast so expm1 below
        # is not evaluated in half precision.
        pred_chunks.append(preds.float().cpu().numpy())
        all_ids.extend(batch['image_id'])

    all_preds = np.concatenate(pred_chunks, axis=0)

    if log_transform:
        all_preds = np.expm1(all_preds)

    all_preds = np.clip(all_preds, 0, None)

    return all_preds, all_ids
786
+
787
def train_fold(self, fold: int, train_df: pd.DataFrame, val_df: pd.DataFrame,
               train_targets: pd.DataFrame, val_targets: pd.DataFrame,
               image_dir: str):
    """Train a single fold and score its best checkpoint on the validation split.

    Builds the train/validation datasets and loaders, then trains for up to
    ``args.epochs`` epochs with early stopping on the validation weighted R².
    Each time the score improves, a checkpoint is written to
    ``<output_dir>/fold_<fold>/best_model.pth``. After training, that
    checkpoint is reloaded and the validation set is scored once more.

    Args:
        fold: Fold index; used for logging and for the checkpoint directory name.
        train_df: Rows passed to the training ``BiomassDataset``.
        val_df: Rows passed to the validation ``BiomassDataset``.
        train_targets: Targets for ``train_df``; also the input to the LDS
            weights when ``args.use_lds`` is set.
        val_targets: Targets for ``val_df``.
        image_dir: Directory handed to both datasets as ``image_dir``.

    Returns:
        ``(model, val_metrics)`` where ``val_metrics`` is the dict returned by
        ``self.validate`` on the final validation pass.
    """
    backbone_cfg = BACKBONE_CONFIGS[self.args.backbone]
    img_size = self.args.img_size or backbone_cfg['default_size']

    # Datasets
    train_dataset = BiomassDataset(
        image_dir=image_dir,
        df=train_df,
        targets=train_targets,
        transform=get_train_transforms(img_size, self.args.aug_strength),
        img_size=img_size,
        use_ndvi=self.args.use_ndvi,
        log_transform=self.args.log_transform,
    )
    val_dataset = BiomassDataset(
        image_dir=image_dir,
        df=val_df,
        targets=val_targets,
        transform=get_val_transforms(img_size),
        img_size=img_size,
        use_ndvi=self.args.use_ndvi,
        log_transform=self.args.log_transform,
    )

    # Optional: LDS sample weights. A sampler and shuffle=True are mutually
    # exclusive in DataLoader, so exactly one of them is passed.
    if self.args.use_lds:
        sample_weights = get_lds_weights(
            train_targets[TARGET_COLS].values,
            bins=self.args.lds_bins,
            kernel_size=self.args.lds_kernel_size,
            sigma=self.args.lds_sigma,
        )
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(train_dataset),
            replacement=True,
        )
        sampling_kwargs = {'sampler': sampler}
    else:
        sampling_kwargs = {'shuffle': True}

    train_loader = DataLoader(
        train_dataset, batch_size=self.args.batch_size,
        num_workers=self.args.num_workers,
        pin_memory=True, drop_last=True,
        **sampling_kwargs,
    )

    val_loader = DataLoader(
        val_dataset, batch_size=self.args.batch_size * 2,
        shuffle=False, num_workers=self.args.num_workers,
        pin_memory=True,
    )

    # Model, optimizer, scheduler
    model = self.build_model()
    optimizer = self.build_optimizer(model)

    num_training_steps = len(train_loader) * self.args.epochs // self.args.grad_accum_steps
    scheduler = self.build_scheduler(optimizer, num_training_steps)

    # Loss
    loss_fn = CombinedLoss(
        smoothl1_weight=1.0,
        mse_weight=self.args.mse_weight,
        consistency_weight=self.args.consistency_weight,
        target_weights=TARGET_WEIGHTS,
    ).to(self.device)

    # Training loop state
    best_r2 = -float('inf')
    best_epoch = 0  # stays 0 until this run writes a checkpoint
    patience_counter = 0
    save_dir = Path(self.args.output_dir) / f"fold_{fold}"
    save_dir.mkdir(parents=True, exist_ok=True)
    best_path = save_dir / 'best_model.pth'

    logger.info(f"\n{'='*60}")
    logger.info(f"FOLD {fold}")
    logger.info(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")
    logger.info(f"Backbone: {backbone_cfg['name']}, img_size: {img_size}")
    logger.info(f"{'='*60}")

    for epoch in range(1, self.args.epochs + 1):
        t0 = time.time()

        # Train
        train_loss = self.train_one_epoch(model, train_loader, optimizer, scheduler, loss_fn, epoch)

        # Validate
        val_metrics = self.validate(model, val_loader, loss_fn, self.args.log_transform)

        # ReduceLROnPlateau needs the monitored metric, so it is stepped here.
        # NOTE(review): other schedulers are presumably stepped inside
        # train_one_epoch - confirm.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_metrics['weighted_r2'])

        elapsed = time.time() - t0

        # Logging
        logger.info(
            f"Epoch {epoch}/{self.args.epochs} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_metrics['loss']:.4f} | "
            f"val_R²={val_metrics['weighted_r2']:.4f} | "
            f"time={elapsed:.1f}s"
        )
        for name, r2 in val_metrics['per_target_r2'].items():
            logger.info(f"  {name}: R²={r2:.4f}")

        # Save best model. A NaN score never compares greater, so it counts
        # as "no improvement" and only advances the patience counter.
        if val_metrics['weighted_r2'] > best_r2:
            best_r2 = val_metrics['weighted_r2']
            best_epoch = epoch
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'weighted_r2': best_r2,
                'per_target_r2': val_metrics['per_target_r2'],
                'args': vars(self.args),
            }, best_path)

            logger.info(f"  *** New best R²={best_r2:.4f} (epoch {epoch}) ***")
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= self.args.patience:
            logger.info(f"Early stopping at epoch {epoch}. Best R²={best_r2:.4f} (epoch {best_epoch})")
            break

    # Load best model for final predictions. Only do so when this run wrote
    # the checkpoint: if no epoch improved the score (e.g. NaN metrics, or
    # epochs == 0) the file is either missing or left over from an earlier
    # run, and loading it would crash or silently report stale weights.
    if best_epoch > 0:
        checkpoint = torch.load(best_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        logger.warning(
            f"Fold {fold}: no checkpoint was saved during this run; "
            f"evaluating the current model weights instead."
        )

    # OOF predictions
    val_metrics = self.validate(model, val_loader, loss_fn, self.args.log_transform)

    logger.info(f"\nFold {fold} Final: R²={val_metrics['weighted_r2']:.4f}")

    return model, val_metrics
933
+
934
def train_kfold(self, df: pd.DataFrame, targets: pd.DataFrame, image_dir: str):
    """Train every fold of a stratified K-fold split and collect OOF predictions.

    Folds are stratified on quantile bins of ``targets['Dry_Total_g']``. Each
    fold is trained with ``self.train_fold`` and its validation predictions
    are written into the matching rows of an out-of-fold array. That array is
    scored with ``compute_weighted_r2`` and saved to
    ``<output_dir>/oof_predictions.csv``.

    Args:
        df: Training metadata; must contain an ``image_id`` column.
        targets: Target values row-aligned with ``df``; must contain
            ``Dry_Total_g`` and every column in ``TARGET_COLS``.
        image_dir: Image directory forwarded to ``train_fold``.

    Returns:
        ``(overall_r2, fold_scores)``: the score over all OOF predictions and
        the list of per-fold validation weighted R² values.
    """
    n_folds = self.args.n_folds

    # Stratified bins based on Dry_Total_g
    bins = pd.qcut(targets['Dry_Total_g'], q=min(10, n_folds), labels=False, duplicates='drop')

    kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=self.args.seed)

    # One column per target; sized from TARGET_COLS so it cannot drift from
    # the columns written to the CSV below.
    oof_preds = np.zeros((len(df), len(TARGET_COLS)))
    fold_scores = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(df, bins)):
        train_df = df.iloc[train_idx]
        val_df = df.iloc[val_idx]
        train_targets = targets.iloc[train_idx]
        val_targets = targets.iloc[val_idx]

        # Keep only the metrics. Binding the returned model to a local would
        # keep the previous fold's network alive while the next fold trains.
        val_metrics = self.train_fold(
            fold, train_df, val_df, train_targets, val_targets, image_dir
        )[1]
        if torch.cuda.is_available():
            # Return cached blocks of the released model to the device.
            torch.cuda.empty_cache()

        oof_preds[val_idx] = val_metrics['preds']
        fold_scores.append(val_metrics['weighted_r2'])

        logger.info(f"Fold {fold} R²: {val_metrics['weighted_r2']:.4f}")

    # Overall OOF score
    targets_arr = targets[TARGET_COLS].values
    overall_r2 = compute_weighted_r2(oof_preds, targets_arr)

    logger.info(f"\n{'='*60}")
    logger.info(f"Overall OOF R²: {overall_r2:.4f}")
    logger.info(f"Per-fold R²: {[f'{s:.4f}' for s in fold_scores]}")
    logger.info(f"Mean fold R²: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")
    logger.info(f"{'='*60}")

    # Save OOF predictions
    oof_df = df[['image_id']].copy()
    for i, col in enumerate(TARGET_COLS):
        oof_df[col] = oof_preds[:, i]
    oof_df.to_csv(Path(self.args.output_dir) / 'oof_predictions.csv', index=False)

    return overall_r2, fold_scores
978
+
979
+
980
+ # ============================================================
981
+ # Data Loading Utilities
982
+ # ============================================================
983
def load_competition_data(data_dir: str):
    """
    Read the competition CSVs and locate the image folders.

    Expected layout:
        data_dir/
            train.csv
            test.csv
            train_images/           (falls back to train/ when absent)
            test_images/            (falls back to test/ when absent)
            sample_submission.csv   (optional)

    Returns:
        (train_df, test_df, sample_sub, train_img_dir, test_img_dir), where
        sample_sub is None when the file does not exist and both image
        directories are returned as strings.
    """
    root = Path(data_dir)

    def _pick_dir(preferred: str, fallback: str) -> Path:
        # Use the preferred folder name unless it is missing on disk.
        candidate = root / preferred
        return candidate if candidate.exists() else root / fallback

    # CSV tables
    train_df = pd.read_csv(root / 'train.csv')
    test_df = pd.read_csv(root / 'test.csv')

    submission_path = root / 'sample_submission.csv'
    sample_sub = pd.read_csv(submission_path) if submission_path.exists() else None

    # Image directories
    train_img_dir = _pick_dir('train_images', 'train')
    test_img_dir = _pick_dir('test_images', 'test')

    logger.info(f"Train samples: {len(train_df)}")
    logger.info(f"Test samples: {len(test_df)}")
    logger.info(f"Train columns: {list(train_df.columns)}")
    logger.info(f"Test columns: {list(test_df.columns)}")

    # Target columns actually present in the training table
    available_targets = [name for name in TARGET_COLS if name in train_df.columns]
    logger.info(f"Available targets: {available_targets}")

    # Summary statistics, one log line per available target
    if available_targets:
        logger.info("\nTarget statistics:")
        for col in available_targets:
            values = train_df[col]
            logger.info(f"  {col}: mean={values.mean():.2f}, "
                        f"median={values.median():.2f}, "
                        f"std={values.std():.2f}, "
                        f"min={values.min():.2f}, "
                        f"max={values.max():.2f}")

    return train_df, test_df, sample_sub, str(train_img_dir), str(test_img_dir)
1033
+
1034
+
1035
def create_submission(preds: np.ndarray, image_ids: List[str], output_path: str):
    """
    Write predictions as a long-format submission CSV and return the frame.

    Each (image, target) pair becomes one row whose ``sample_id`` is
    ``"<image_id>__<target_name>"``.

    Args:
        preds: [N, 5] predictions, indexed as preds[image, target]
        image_ids: list of image IDs, in the same order as the rows of preds
        output_path: path to save CSV

    Returns:
        The DataFrame that was written.
    """
    records = [
        {
            'sample_id': f"{img_id}__{target_name}",
            # Negative predictions are floored at zero.
            'target': max(0, preds[row, col]),
        }
        for row, img_id in enumerate(image_ids)
        for col, target_name in enumerate(TARGET_COLS)
    ]

    sub_df = pd.DataFrame(records)
    sub_df.to_csv(output_path, index=False)
    logger.info(f"Submission saved to {output_path} ({len(sub_df)} rows)")
    return sub_df
1056
+
1057
+
1058
+ # ============================================================
1059
+ # Seed and Reproducibility
1060
+ # ============================================================
1061
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs; pin cuDNN settings when CUDA is present."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    # CUDA-only settings: seed every device and disable cuDNN autotuning.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
1070
+
1071
+
1072
+ # ============================================================
1073
+ # Main
1074
+ # ============================================================
1075
def get_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            the previous behaviour.

    Returns:
        The parsed ``argparse.Namespace``. ``log_transform`` is forced to
        ``False`` when ``--no_log_transform`` is given, and
        ``mixed_precision`` is ``False`` when ``--no_mixed_precision`` is given.

    Exits via ``parser.error`` when ``--fold`` is outside ``[0, n_folds)``.
    """
    parser = argparse.ArgumentParser(description='CSIRO Image2Biomass Training')

    # Data
    parser.add_argument('--data_dir', type=str, required=True, help='Competition data directory')
    parser.add_argument('--output_dir', type=str, default='./output', help='Output directory')

    # Model
    parser.add_argument('--backbone', type=str, default='dinov2_base',
                        choices=list(BACKBONE_CONFIGS.keys()), help='Backbone architecture')
    parser.add_argument('--img_size', type=int, default=None, help='Image size (default: backbone native)')
    parser.add_argument('--hidden_dim', type=int, default=512, help='Hidden dim in MLP head')
    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout rate')
    parser.add_argument('--separate_heads', action='store_true', help='Use separate heads per target')
    parser.add_argument('--grad_checkpointing', action='store_true', help='Enable gradient checkpointing')
    parser.add_argument('--use_ndvi', action='store_true', help='Use NDVI features')

    # Training
    parser.add_argument('--epochs', type=int, default=50, help='Max epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--backbone_lr', type=float, default=5e-5, help='Backbone learning rate')
    parser.add_argument('--head_lr', type=float, default=1e-3, help='Head learning rate')
    parser.add_argument('--min_lr', type=float, default=1e-7, help='Min learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adamw', 'sgd'])
    parser.add_argument('--scheduler', type=str, default='cosine', choices=['cosine', 'plateau', 'none'])
    parser.add_argument('--warmup_ratio', type=float, default=0.05, help='Warmup ratio')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm')
    parser.add_argument('--grad_accum_steps', type=int, default=1, help='Gradient accumulation steps')
    parser.add_argument('--patience', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--log_interval', type=int, default=10, help='Log every N batches')

    # Augmentation
    parser.add_argument('--aug_strength', type=str, default='medium', choices=['light', 'medium', 'heavy'])
    # --log_transform defaults to True, so passing it is a no-op; the switch
    # that actually changes behaviour is --no_log_transform.
    parser.add_argument('--log_transform', action='store_true', default=True, help='Log-transform targets')
    parser.add_argument('--no_log_transform', action='store_true', help='Disable log-transform')

    # LDS
    parser.add_argument('--use_lds', action='store_true', help='Use Label Distribution Smoothing')
    parser.add_argument('--lds_bins', type=int, default=100)
    parser.add_argument('--lds_kernel_size', type=int, default=5)
    parser.add_argument('--lds_sigma', type=float, default=2.0)

    # Loss
    parser.add_argument('--mse_weight', type=float, default=0.0, help='MSE loss weight')
    parser.add_argument('--consistency_weight', type=float, default=0.1, help='Consistency loss weight')

    # CV
    parser.add_argument('--n_folds', type=int, default=5, help='Number of CV folds')
    parser.add_argument('--fold', type=int, default=None, help='Train single fold (None=all)')

    # Misc
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=4)
    # store_true with default=True can never yield False on its own, so a
    # store_false companion writing to the same dest provides the off switch.
    parser.add_argument('--mixed_precision', action='store_true', default=True)
    parser.add_argument('--no_mixed_precision', dest='mixed_precision', action='store_false',
                        help='Disable mixed precision')

    args = parser.parse_args(argv)

    if args.no_log_transform:
        args.log_transform = False

    # A fold index outside the split would match no fold and train nothing.
    if args.fold is not None and not 0 <= args.fold < args.n_folds:
        parser.error(f"--fold must be in [0, {args.n_folds - 1}], got {args.fold}")

    return args
1137
+
1138
+
1139
def main():
    """Entry point: parse options, load data, and train one fold or all folds.

    Writes the parsed arguments to ``<output_dir>/args.json`` before training.
    With ``--fold`` set, only that fold of the stratified split is trained;
    otherwise the full K-fold run is executed.

    Raises:
        ValueError: if ``--fold`` is not in ``[0, n_folds)``.
    """
    args = get_args()

    # Fail fast: an out-of-range fold matches no split, so the run would train
    # nothing and still report success.
    if args.fold is not None and not 0 <= args.fold < args.n_folds:
        raise ValueError(
            f"--fold must be in [0, {args.n_folds - 1}], got {args.fold}"
        )

    set_seed(args.seed)

    # Load data (the test table, sample submission and test image dir are not
    # used during training)
    train_df, _test_df, _sample_sub, train_img_dir, _test_img_dir = load_competition_data(args.data_dir)

    # Separate features and targets
    targets = train_df[TARGET_COLS].copy()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save args
    with open(output_dir / 'args.json', 'w') as f:
        json.dump(vars(args), f, indent=2)

    # Train
    trainer = Trainer(args)

    if args.fold is not None:
        # Single fold training: rebuild the split with the same binning and
        # seed as the K-fold path, then train only the requested fold.
        from sklearn.model_selection import StratifiedKFold
        bins = pd.qcut(targets['Dry_Total_g'], q=min(10, args.n_folds), labels=False, duplicates='drop')
        kf = StratifiedKFold(n_splits=args.n_folds, shuffle=True, random_state=args.seed)

        for fold_idx, (train_idx, val_idx) in enumerate(kf.split(train_df, bins)):
            if fold_idx != args.fold:
                continue

            trainer.train_fold(
                args.fold, train_df.iloc[train_idx], train_df.iloc[val_idx],
                targets.iloc[train_idx], targets.iloc[val_idx], train_img_dir
            )
            break
    else:
        # Full K-fold training
        trainer.train_kfold(train_df, targets, train_img_dir)

    logger.info("Training complete!")
1182
+
1183
+
1184
+ if __name__ == '__main__':
1185
+ main()