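"""Training script: fits a CNN voice classifier on mel spectrograms.

Loads audio from TRAIN_FILE via VoiceDataset, trains CNNetwork with
cross-entropy loss and the Adam optimizer, then saves the trained weights
under models/.
"""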
from datetime import datetime
from tqdm import tqdm
# torch
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
# internal
from dataset import VoiceDataset
from cnn import CNNetwork
BATCH_SIZE = 128
EPOCHS = 100
LEARNING_RATE = 0.001
TRAIN_FILE = "data/train"
SAMPLE_RATE = 16000

def train(model, dataloader, loss_fn, optimizer, device, epochs):
    for i in tqdm(range(epochs), "Training model..."):
        print(f"Epoch {i + 1}")
        train_epoch(model, dataloader, loss_fn, optimizer, device)
        print("-----------------------------------\n")
    print("---- Finished Training ----")

def train_epoch(model, dataloader, loss_fn, optimizer, device):
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)

        # calculate loss
        pred = model(x)
        loss = loss_fn(pred, y)

        # backprop and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # report the loss of the last batch in this epoch
    print(f"Loss: {loss.item()}")

if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using {device} device.")

    # instantiate our dataset object and create the data loader
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=1024,
        hop_length=512,
        n_mels=64
    )
    train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # construct model
    model = CNNetwork().to(device)
    print(model)

    # init loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # train model
    train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)

    # save model with a timestamped filename
    now = datetime.now()
    now = now.strftime("%Y%m%d_%H%M%S")
    model_filename = f"models/void_{now}.pth"
    torch.save(model.state_dict(), model_filename)
    print(f"Trained CNN saved at {model_filename}")