# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
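
"""Cython helpers for grouping a sequence of dataset indices into batches
under token- and sentence-count constraints."""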

import numpy as np

cimport cython
cimport numpy as np

from libc.stdint cimport int32_t, int64_t
from libcpp cimport bool as bool_t

ctypedef int64_t DTYPE_t

@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
    np.ndarray[int64_t, ndim=1] indices,
    np.ndarray[int64_t, ndim=1] num_tokens_vec,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
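    """Greedily group ``indices`` into batches.

    A running batch is closed when adding the next index would exceed
    ``max_sentences`` sentences (if > 0) or ``max_tokens`` tokens (if > 0),
    where the token cost of a batch is ``num_sentences * longest_sentence``,
    i.e. padding to the longest sentence is accounted for. Batch sizes are
    additionally kept at a multiple of ``bsz_mult`` whenever possible.

    Returns a list of numpy arrays that partition ``indices`` in order.

    Illustrative example (values invented for this docstring)::

        indices        = np.arange(5, dtype=np.int64)
        num_tokens_vec = np.array([3, 7, 4, 9, 2], dtype=np.int64)
        batch_by_size_vec(indices, num_tokens_vec,
                          max_tokens=16, max_sentences=4, bsz_mult=1)
        # -> [array([0, 1]), array([2]), array([3]), array([4])]
    """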
    if indices.shape[0] == 0:
        return []

    assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
        f"Sentence lengths should not exceed max_tokens={max_tokens}"
    )

    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int32_t, ndim=1] batches_ends = \
            np.zeros(indices_len, dtype=np.int32)
    cdef int32_t[:] batches_ends_view = batches_ends
    cdef int64_t[:] num_tokens_view = num_tokens_vec

    cdef int32_t pos = 0
    cdef int32_t new_batch_end = 0

    cdef int64_t new_batch_max_tokens = 0
    cdef int32_t new_batch_sentences = 0
    cdef int64_t new_batch_num_tokens = 0

    cdef bool_t overflow = False
    cdef bool_t size_matches_with_bsz_mult = False

    cdef int32_t batches_count = 0
    cdef int32_t batch_start = 0
    cdef int64_t tail_max_tokens = 0
    cdef int64_t batch_max_tokens = 0
    cdef int64_t tail_num_tokens = 0
    cdef bool_t tail_overflow = False

    for pos in range(indices_len):
        # At every pos we keep stats about the last complete batch [batch_start:batch_end)
        #      and the tail [batch_end:pos].
        # 1) Whenever (batch + tail) forms a valid batch
        #      (according to max_tokens, max_sentences and bsz_mult) we append the tail to the batch.
        # 2) When (batch + tail) violates the max_tokens or max_sentences constraints,
        #      we finalize the running batch and the tail becomes the new batch.
        # 3) There is a corner case in which the tail itself also violates the constraints.
        #      In that situation [batch_end:pos) (the tail without the current pos)
        #      is added to the finalized batches, while [pos:pos+1) becomes the new running batch.
        #
        # Important: for the sake of performance, avoid function calls within this loop.

        tail_max_tokens = tail_max_tokens \
                            if tail_max_tokens > num_tokens_view[pos] \
                            else num_tokens_view[pos]
        new_batch_end = pos + 1
        new_batch_max_tokens = batch_max_tokens \
                                if batch_max_tokens > tail_max_tokens \
                                else tail_max_tokens
        new_batch_sentences = new_batch_end - batch_start
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens

        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
            tail_num_tokens = tail_max_tokens * \
                    (new_batch_end - batches_ends_view[batches_count])
            tail_overflow = tail_num_tokens > max_tokens > 0
            # In case of a tail overflow finalize two batches
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1
    # Memory and time-efficient split
    return np.split(indices, batches_ends[:batches_count])


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
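    """Same as :func:`batch_by_size_vec`, but obtains the token count of each
    index by calling ``num_tokens_fn(index)`` instead of reading it from a
    precomputed vector.
    """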
    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
                                                               dtype=np.int64)
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
    cdef int64_t pos
    for pos in range(indices_len):
        num_tokens_vec_view[pos] = num_tokens_fn(indices_view[pos])
    return batch_by_size_vec(
        indices, num_tokens_vec, max_tokens, max_sentences, bsz_mult,
    )


cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return index of first valid shape of -1 if none is found."""
    for i in range(shapes_view.shape[0]):
        if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
            return i
    return -1


@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
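    """Group ``indices`` into batches that each fit one of the allowed fixed
    shapes.

    ``fixed_shapes_sorted`` is a 2-D array whose rows are
    ``(batch_size, num_tokens)`` pairs; it is assumed to be pre-sorted so that
    the prefix-skipping optimization below remains valid. The running batch is
    closed as soon as no remaining shape can hold one more sentence at the
    current maximum sentence length.

    Returns a list of lists of indices.
    """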
    cdef int64_t sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # No remaining shape can hold the extended batch: finalize the
            # current batch and start a new one seeded with the current sample.
            batches.append(batch)
            batch = []
            sample_lens = [num_tokens]
            sample_len = num_tokens
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # small optimization for the next call to _find_valid_shape
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches