import os
import glob
import torch
import random
from tqdm import tqdm

# from deepafx_st.plugins.channel import Channel
from deepafx_st.processors.processor import Processor
from deepafx_st.data.audio import AudioFile
import deepafx_st.utils as utils


class DSPProxyDataset(torch.utils.data.Dataset):
    """Class for generating input-output audio from Python DSP effects.

    Args:
        input_dir (str): Path to the directory containing input audio files.
        processor (Processor): Processor object to create a proxy of.
        processor_type (str): Processor name. One of ["channel", "peq", "comp"].
        subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
        length (int, optional): Number of samples to load for each example. Default: 65536
        buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 1.0
            Note: This is the buffer size PER DataLoader worker, so total RAM = buffer_size_gb * num_workers.
        buffer_reload_rate (int, optional): Number of items to generate before loading the next chunk of the dataset. Default: 1000
        half (bool, optional): Load audio as float16 to halve the buffer footprint. Default: False
        num_examples_per_epoch (int, optional): Define an epoch as a fixed number of random audio examples. Default: 10000
        ext (str, optional): Expected audio file extension. Default: "wav"
        soft_clip (bool, optional): Soft clip outputs to (-2, 2) with a scaled tanh. Default: True
    """

    def __init__(
        self,
        input_dir: str,
        processor: Processor,
        processor_type: str,
        subset="train",
        length=65536,
        buffer_size_gb=1.0,
        buffer_reload_rate=1000,
        half=False,
        num_examples_per_epoch=10000,
        ext="wav",
        soft_clip=True,
    ):
        super().__init__()
        self.input_dir = input_dir
        self.processor = processor
        self.processor_type = processor_type
        self.subset = subset
        self.length = length
        self.buffer_size_gb = buffer_size_gb
        self.buffer_reload_rate = buffer_reload_rate
        self.half = half
        self.num_examples_per_epoch = num_examples_per_epoch
        self.ext = ext
        self.soft_clip = soft_clip

        search_path = os.path.join(input_dir, f"*.{ext}")
        self.input_filepaths = glob.glob(search_path)
        self.input_filepaths = sorted(self.input_filepaths)

        if len(self.input_filepaths) < 1:
            raise RuntimeError(f"No files found in {input_dir}.")

        # get training split
        self.input_filepaths = utils.split_dataset(
            self.input_filepaths, self.subset, 0.9
        )

        # get details about audio files (scan at most ~1000 files, skip files
        # shorter than one example; all files are assumed to share a sample rate)
        cnt = 0
        self.input_files = {}
        for input_filepath in tqdm(self.input_filepaths, ncols=80):
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=False,
                half=half,
            )
            if audio_file.num_frames < self.length:
                continue
            self.input_files[file_id] = audio_file
            self.sample_rate = self.input_files[file_id].sample_rate
            cnt += 1
            if cnt > 1000:
                break

        # some setup for iterable loading of the dataset into RAM;
        # start at the threshold so the first __getitem__ call fills the buffer
        self.items_since_load = self.buffer_reload_rate

    def __len__(self):
        # epoch length is a fixed example count, independent of the file count
        return self.num_examples_per_epoch

    def load_audio_buffer(self):
        self.input_files_loaded = {}  # clear audio buffer
        self.items_since_load = 0  # reset iteration counter
        nbytes_loaded = 0  # counter for data in RAM

        # shuffle so each refill loads a different subset of the files
        random.shuffle(self.input_filepaths)

        # load files into RAM
        for input_filepath in self.input_filepaths:
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=True,
                half=self.half,
            )

            if audio_file.num_frames < self.length:
                continue

            self.input_files_loaded[file_id] = audio_file

            # count bytes of audio resident in RAM (1 GB treated as 1e9 bytes)
            nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
            nbytes_loaded += nbytes

            if nbytes_loaded > self.buffer_size_gb * 1e9:
                break
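
        # note: the buffer may overshoot buffer_size_gb by up to one file, since
        # the size check runs only after each file has been loaded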

    def __getitem__(self, _):
        """Generate one random (input, target, params) training example.

        Returns:
            input_audio (Tensor): Random input patch with shape (channels, length).
            target_audio (Tensor): Processed patch with shape (1, length).
            params (Tensor): Normalized parameters with shape (num_control_params,).
        """

        # increment counter
        self.items_since_load += 1

        # load next chunk into buffer if needed
        if self.items_since_load > self.buffer_reload_rate:
            self.load_audio_buffer()

        rand_input_file_id = utils.get_random_file_id(self.input_files_loaded.keys())
        # use this random key to retrieve an input file
        input_file = self.input_files_loaded[rand_input_file_id]

        # load the audio data if needed
        if not input_file.loaded:
            input_file.load()

        # get a random patch of size `self.length`
        # start_idx, stop_idx = utils.get_random_patch(input_file, self.sample_rate, self.length)
        start_idx, stop_idx = utils.get_random_patch(input_file, self.length)
        input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()

        # peak-normalize, then attenuate by a random 12 to 24 dB
        input_audio /= input_audio.abs().max()
        scale_dB = (torch.rand(1).squeeze().numpy() * 12) + 12
        input_audio *= 10 ** (-scale_dB / 20.0)
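        # e.g. scale_dB = 18 gives gain = 10 ** (-18 / 20) ~= 0.126, so the patch
        # peaks near -18 dBFS before processing (illustrative values)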

        # generate random parameters (uniform) over 0 to 1
        params = torch.rand(self.processor.num_control_params)

        # apply the processor with random parameters (processors expect a batch dim)
        if self.processor_type == "channel":
            params[-1] = 0.5  # set makeup gain to 0dB
            target_audio = self.processor(
                input_audio.view(1, 1, -1),
                params.view(1, -1),
            )
            target_audio = target_audio.view(1, -1)
        elif self.processor_type == "peq":
            target_audio = self.processor(
                input_audio.view(1, 1, -1).numpy(),
                params.view(1, -1).numpy(),
            )
            target_audio = torch.tensor(target_audio).view(1, -1)
        elif self.processor_type == "comp":
            params[-1] = 0.5  # set makeup gain to 0 dB
            target_audio = self.processor(
                input_audio.view(1, 1, -1).numpy(),
                params.view(1, -1).numpy(),
            )
            target_audio = torch.tensor(target_audio).view(1, -1)
        else:
            raise ValueError(f"Invalid processor_type: {self.processor_type}")

        # soft clip to (-2, 2) with a scaled tanh
        if self.soft_clip:
            # target_audio = target_audio.clamp(-2.0, 2.0)
            target_audio = torch.tanh(target_audio / 2.0) * 2.0

        return input_audio, target_audio, params
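
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. Assumptions:
    # a local "audio/" directory of wav files, and a ParametricEQ DSP processor
    # importable as below (the import path and constructor args are assumptions
    # based on this repo's layout, so adjust them to match your checkout).
    from torch.utils.data import DataLoader

    from deepafx_st.processors.dsp.peq import ParametricEQ

    processor = ParametricEQ(sample_rate=44100)
    dataset = DSPProxyDataset(
        input_dir="audio/",
        processor=processor,
        processor_type="peq",
        subset="train",
    )

    # each DataLoader worker keeps its own ~buffer_size_gb audio buffer in RAM
    loader = DataLoader(dataset, batch_size=8, num_workers=4)
    input_audio, target_audio, params = next(iter(loader))
    print(input_audio.shape, target_audio.shape, params.shape)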