File size: 2,076 Bytes
a38e25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from datetime import datetime
from tqdm import tqdm

# torch
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader

# internal
from dataset import VoiceDataset
from cnn import CNNetwork

# Training hyper-parameters.
BATCH_SIZE = 128
EPOCHS = 100
LEARNING_RATE = 0.001

# Location of the training data and the rate audio is resampled to.
TRAIN_FILE = "data/train"
SAMPLE_RATE = 16000

def train(model, dataloader, loss_fn, optimizer, device, epochs):
  """Train `model` for `epochs` full passes over `dataloader`.

  Args:
    model: the network to optimize (already moved to `device` by the caller).
    dataloader: yields (input, target) batches.
    loss_fn: criterion comparing model output to targets.
    optimizer: optimizer updating `model`'s parameters.
    device: device string ("cuda"/"cpu") batches are moved to per epoch.
    epochs: number of passes over the data.
  """
  # Make sure training-mode behavior (dropout, batch-norm statistics) is
  # active even if the caller previously put the model in eval mode.
  model.train()

  for i in tqdm(range(epochs), "Training model..."):
    print(f"Epoch {i + 1}")

    train_epoch(model, dataloader, loss_fn, optimizer, device)

    print(f"----------------------------------- \n")

  print("---- Finished Training ----")
  

def train_epoch(model, dataloader, loss_fn, optimizer, device):
  """Run one optimization pass over `dataloader` and print the mean loss.

  Args:
    model: network being trained.
    dataloader: yields (input, target) batches.
    loss_fn: criterion comparing model output to targets.
    optimizer: optimizer updating `model`'s parameters.
    device: device string each batch is moved to before the forward pass.
  """
  total_loss = 0.0
  num_batches = 0

  for x, y in dataloader:
    x, y = x.to(device), y.to(device)

    # calculate loss
    pred = model(x)
    loss = loss_fn(pred, y)

    # backprop and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    num_batches += 1

  # The original printed only the final batch's loss and raised NameError
  # on an empty dataloader; report the epoch mean and guard the empty case.
  if num_batches:
    print(f"Loss: {total_loss / num_batches}")
  else:
    print("Loss: n/a (empty dataloader)")

if __name__ == "__main__":
    import os

    # Prefer the GPU when one is available; the model and every batch are
    # moved to this device before use.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device.")

    # instantiating our dataset object and create data loader
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=1024,
        hop_length=512,
        n_mels=64
    )

    train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # construct model
    model = CNNetwork().to(device)
    print(model)

    # init loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # train model
    train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)

    # save model — create the output directory first so torch.save does not
    # fail with FileNotFoundError after a long training run.
    os.makedirs("models", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # NOTE(review): "void_" looks like a typo for "voice_" — kept as-is so
    # anything that globs for the existing prefix keeps working; confirm.
    model_filename = f"models/void_{timestamp}.pth"
    torch.save(model.state_dict(), model_filename)
    print(f"Trained feed forward net saved at {model_filename}")