# basel-weather / train.py
# jonwiese - "add train.py" (c1f29cc), 4.94 kB
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from datetime import datetime
class WeatherPredictor:
    """Small feed-forward network that maps a calendar date to weather values.

    The constructor loads a CSV of historical observations, derives cyclic
    day/month encodings plus the raw year as input features, min-max scales
    inputs and targets, and builds an (untrained) multi-output network.
    """

    # Date layout of the CSV's ``datetime`` column and of predict()'s input.
    DATE_FORMAT = '%d/%m/%y'
    # Input feature columns, in the order the network receives them.
    FEATURES = ('day_sin', 'day_cos', 'month_sin', 'month_cos', 'year')
    # Target columns, in the order the network emits them.
    TARGET_COLUMNS = ('temp', 'precip', 'snow', 'windspeed')
    # Checkpoint that train() writes by default and predict() falls back to.
    DEFAULT_MODEL_PATH = 'weather_predictor.pth'
    # Cycle lengths used for the sin/cos encodings.
    DAY_PERIOD = 31
    MONTH_PERIOD = 12

    # Becomes True on an instance once it has trained or loaded weights, so
    # predict() does not have to re-read the checkpoint on every call.
    _weights_ready = False

    def __init__(self, data_path):
        """Load the training data, fit the scalers and build the network.

        Args:
            data_path: Path to a CSV file with a ``datetime`` column in
                ``dd/mm/yy`` format and one column per ``TARGET_COLUMNS``
                entry.
        """
        self.df = pd.read_csv(data_path)
        # Parse after reading rather than through read_csv's ``date_parser``
        # argument, which pandas has deprecated since 2.0.
        self.df['datetime'] = pd.to_datetime(
            self.df['datetime'], format=self.DATE_FORMAT
        )

        when = self.df['datetime'].dt
        self.df['day'] = when.day
        self.df['month'] = when.month
        self.df['year'] = when.year
        self.df['day_sin'], self.df['day_cos'] = self._cyclic(
            self.df['day'], self.DAY_PERIOD
        )
        self.df['month_sin'], self.df['month_cos'] = self._cyclic(
            self.df['month'], self.MONTH_PERIOD
        )

        # Scale features and targets. The scalers are fitted on plain arrays
        # because predict() later passes a plain list, not a DataFrame.
        self.feature_scaler = MinMaxScaler()
        self.target_scaler = MinMaxScaler()
        X = self.feature_scaler.fit_transform(
            self.df[list(self.FEATURES)].to_numpy()
        )
        Y = self.target_scaler.fit_transform(
            self.df[list(self.TARGET_COLUMNS)].to_numpy()
        )
        self.X_tensor = torch.tensor(X, dtype=torch.float32)
        self.Y_tensor = torch.tensor(Y, dtype=torch.float32)

        # Single model for all targets; layer layout is unchanged so that
        # previously saved state dicts still load.
        self.model = nn.Sequential(
            nn.Linear(len(self.FEATURES), 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, len(self.TARGET_COLUMNS)),
        )

    @staticmethod
    def _cyclic(value, period):
        """Return the ``(sin, cos)`` encoding of *value* on a cycle of *period*."""
        angle = 2 * np.pi * value / period
        return np.sin(angle), np.cos(angle)

    def train(self, epochs=1000, model_path=None):
        """Fit the network with full-batch Adam on MSE loss, then save it.

        Args:
            epochs: Number of full-batch optimisation steps.
            model_path: Where to write the trained ``state_dict``. Defaults
                to ``DEFAULT_MODEL_PATH``.
        """
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=0.01)

        # load_model() leaves the network in eval mode; switch back.
        self.model.train()
        for epoch in range(epochs):
            # Forward pass: one prediction row per sample, one column per target.
            outputs = self.model(self.X_tensor)
            loss = criterion(outputs, self.Y_tensor)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch % 100 == 0:
                print(f'Epoch [{epoch}/{epochs}], Loss: {loss.item():.4f}')

        self._weights_ready = True
        self.save_model(
            self.DEFAULT_MODEL_PATH if model_path is None else model_path
        )

    def predict(self, input_date):
        """Predict the target values for a single date.

        If this instance has neither trained nor loaded weights yet, the
        weights are read from ``DEFAULT_MODEL_PATH`` first; otherwise the
        in-memory weights are used as they are.

        Args:
            input_date: Date string in ``dd/mm/yy`` format.

        Returns:
            Dict mapping each ``TARGET_COLUMNS`` name to its prediction,
            converted back to the original (unscaled) units.

        Raises:
            ValueError: If *input_date* does not match ``DATE_FORMAT``.
        """
        date = datetime.strptime(input_date, self.DATE_FORMAT)
        day_sin, day_cos = self._cyclic(date.day, self.DAY_PERIOD)
        month_sin, month_cos = self._cyclic(date.month, self.MONTH_PERIOD)
        features = [day_sin, day_cos, month_sin, month_cos, date.year]

        # Transform features to match the training scale.
        scaled_features = self.feature_scaler.transform([features])
        input_tensor = torch.tensor(scaled_features, dtype=torch.float32)

        if not self._weights_ready:
            self.load_model(self.DEFAULT_MODEL_PATH)
        self.model.eval()

        with torch.no_grad():
            scaled_predictions = self.model(input_tensor).numpy()
        predictions = self.target_scaler.inverse_transform(
            scaled_predictions.reshape(1, -1)
        ).flatten()

        return dict(zip(self.TARGET_COLUMNS, predictions))

    def save_model(self, file_path):
        """Write the network's ``state_dict`` to *file_path*."""
        torch.save(self.model.state_dict(), file_path)

    def load_model(self, file_path):
        """Load a ``state_dict`` from *file_path* and switch to eval mode."""
        self.model.load_state_dict(torch.load(file_path))
        self.model.eval()
        self._weights_ready = True
def main():
    """Train a predictor on the Basel CSV and print one sample forecast."""
    model = WeatherPredictor('basel-weather.csv')
    model.train()

    # Forecast a single example date (dd/mm/yy) and show the result.
    print("Predictions:", model.predict('01/02/23'))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()