File size: 3,293 Bytes
7900c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import math
import numbers
from typing import Optional
import numpy as np


class SpecAugment():
    """SpecAugment (https://arxiv.org/abs/1904.08779)"""
    def __init__(self, args):
        self.time_warp_w, self.freq_mask_n, self.freq_mask_f = 0, 0, 0
        self.time_mask_n, self.time_mask_t = 0, 0
        self.time_mask_p, self.mask_value = 0.0, None
        if "specaugment" in args:
            self.time_warp_w = args.specaugment["time_warp_W"]
            self.freq_mask_n = args.specaugment["freq_mask_N"]
            self.freq_mask_f = args.specaugment["freq_mask_F"]
            self.time_mask_n = args.specaugment["time_mask_N"]
            self.time_mask_t = args.specaugment["time_mask_T"]
            self.time_mask_p = args.specaugment["time_mask_p"]
            if "mask_value" in args.specaugment:
                self.mask_value = args.specaugment["mask_value"]

        if self.freq_mask_n > 0:
            assert self.freq_mask_f > 0, (
                f"freq_mask_F ({self.freq_mask_f}) must be larger than 0 when doing freq masking."
            )
        if self.time_mask_n > 0:
            assert self.time_mask_t > 0, (
                f"time_mask_T ({self.time_mask_t}) must be larger than 0 when doing time masking."
            )

    def __call__(self, spectrogram):
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        distorted = spectrogram.copy()  # make a copy of input spectrogram.
        num_frames = spectrogram.shape[0]  # or 'tau' in the paper.
        num_freqs = spectrogram.shape[1]  # or 'miu' in the paper.
        mask_value = self.mask_value

        if mask_value is None:  # if no value was specified, use local mean.
            mask_value = spectrogram.mean()

        if num_frames == 0:
            return spectrogram

        if num_freqs < self.freq_mask_f:
            return spectrogram

        if self.time_warp_w > 0:
            if 2 * self.time_warp_w < num_frames:
                import cv2

                w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
                w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
                upper, lower = distorted[:w0, :], distorted[w0:, :]
                upper = cv2.resize(
                    upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
                )
                lower = cv2.resize(
                    lower,
                    dsize=(num_freqs, num_frames - w0 - w),
                    interpolation=cv2.INTER_LINEAR,
                )
                distorted = np.concatenate((upper, lower), axis=0)

        for _i in range(self.freq_mask_n):
            f = np.random.randint(0, self.freq_mask_f)
            f0 = np.random.randint(0, num_freqs - f)
            if f != 0:
                distorted[:, f0 : f0 + f] = mask_value

        max_time_mask_t = min(
            self.time_mask_t, math.floor(num_frames * self.time_mask_p)
        )
        if max_time_mask_t < 1:
            return distorted

        for _i in range(self.time_mask_n):
            t = np.random.randint(0, max_time_mask_t)
            t0 = np.random.randint(0, num_frames - t)
            if t != 0:
                distorted[t0 : t0 + t, :] = mask_value

        return distorted