|
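/* Helper methods for fast index-mapping builds, exposed to Python via pybind11. */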
|
#include <math.h> |
|
#include <pybind11/numpy.h> |
|
#include <pybind11/pybind11.h> |
|
#include <algorithm>

#include <cassert>
|
#include <iostream> |
|
#include <limits> |
|
#include <random> |
|
#include <stdexcept>

#include <vector>
|
|
|
namespace py = pybind11; |
|
using namespace std; |
|
|
|
const int32_t LONG_SENTENCE_LEN = 512; |
|
|
|
void build_blending_indices(py::array_t<uint8_t>& dataset_index, |
|
py::array_t<int64_t>& dataset_sample_index, |
|
const py::array_t<double>& weights, |
|
const int32_t num_datasets, |
|
const int64_t size, |
|
const bool verbose) |
|
{ |
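    /* Given `num_datasets` datasets and their sampling `weights`, fill
       `dataset_index` with the dataset each blended sample is drawn from and
       `dataset_sample_index` with that sample's index within the chosen
       dataset, so the achieved ratios track the requested weights. */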
|
|
|
|
|
|
|
if (verbose) { std::cout << "> building indices for blendable datasets ..." << std::endl; } |
|
|
|
|
|
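    // Access the arrays without bounds checking.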
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); |
|
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); |
|
auto weights_ptr = weights.unchecked<1>(); |
|
|
|
|
|
    // Number of samples drawn from each dataset so far.
    std::vector<int64_t> current_samples(num_datasets, 0);
|
|
|
|
|
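    // For each output sample, pick the dataset that is furthest below its
    // target share, i.e. the one with the largest error
    // weight * samples_so_far - samples_drawn_from_it.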
for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { |
|
|
|
double sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0); |
|
int64_t max_error_index = 0; |
|
double max_error = |
|
weights_ptr[0] * sample_idx_double - static_cast<double>(current_samples[0]); |
|
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { |
|
double error = weights_ptr[dataset_idx] * sample_idx_double - |
|
static_cast<double>(current_samples[dataset_idx]); |
|
if (error > max_error) { |
|
max_error = error; |
|
max_error_index = dataset_idx; |
|
} |
|
} |
|
|
|
|
|
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index); |
|
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; |
|
|
|
|
|
current_samples[max_error_index] += 1; |
|
} |
|
|
|
|
|
if (verbose) { |
|
std::cout << " > sample ratios:" << std::endl; |
|
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { |
|
auto ratio = |
|
static_cast<double>(current_samples[dataset_idx]) / static_cast<double>(size); |
|
std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] |
|
<< ", achieved: " << ratio << std::endl; |
|
} |
|
} |
|
} |
|
|
|
py::array build_sample_idx_int32(const py::array_t<int32_t>& sizes_, |
|
const py::array_t<int32_t>& doc_idx_, |
|
const int32_t seq_length, |
|
const int32_t num_epochs, |
|
const int64_t tokens_per_epoch) |
|
{ |
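    /* Build the sample index for datasets whose documents are flattened into
       a single token stream and sliced into samples of `seq_length` tokens.
       Returns a [num_samples + 1, 2] int32 array where row i holds the index
       into `doc_idx` and the token offset within that document at which
       sample i begins. */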
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert(seq_length > 1); |
|
assert(num_epochs > 0); |
|
assert(tokens_per_epoch > 1); |
|
|
|
|
|
auto sizes = sizes_.unchecked<1>(); |
|
auto doc_idx = doc_idx_.unchecked<1>(); |
|
|
|
|
|
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; |
|
int32_t* sample_idx = new int32_t[2 * (num_samples + 1)]; |
|
|
|
cout << " using:" << endl << std::flush; |
|
cout << " number of documents: " << doc_idx_.shape(0) / num_epochs << endl |
|
<< std::flush; |
|
cout << " number of epochs: " << num_epochs << endl << std::flush; |
|
cout << " sequence length: " << seq_length << endl << std::flush; |
|
cout << " total number of samples: " << num_samples << endl << std::flush; |
|
|
|
|
|
    // Index into sample_idx.
    int64_t sample_index = 0;
    // Index into doc_idx.
    int64_t doc_idx_index = 0;
    // Token offset within the current document.
    int32_t doc_offset = 0;

    // The first sample starts at the first document with no offset.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
|
++sample_index; |
|
|
|
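    // Walk the flattened token stream and record where each sample begins.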
while (sample_index <= num_samples) { |
|
|
|
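        // Begin a new sample; it spans seq_length + 1 tokens.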
int32_t remaining_seq_length = seq_length + 1; |
|
while (remaining_seq_length != 0) { |
|
|
|
auto doc_id = doc_idx[doc_idx_index]; |
|
auto doc_length = sizes[doc_id] - doc_offset; |
|
|
|
remaining_seq_length -= doc_length; |
|
|
|
|
|
|
|
|
|
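            // This document completes the sample: advance the offset and stop.
            // Since each sample spans seq_length + 1 tokens, the -1 makes
            // consecutive samples overlap by one token.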
if (remaining_seq_length <= 0) { |
|
doc_offset += (remaining_seq_length + doc_length - 1); |
|
remaining_seq_length = 0; |
|
} else { |
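                // Otherwise this document is exhausted; continue from the
                // beginning of the next one.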
|
|
|
++doc_idx_index; |
|
doc_offset = 0; |
|
} |
|
} |
|
|
|
sample_idx[2 * sample_index] = doc_idx_index; |
|
sample_idx[2 * sample_index + 1] = doc_offset; |
|
++sample_index; |
|
} |
|
|
|
|
|
py::capsule free_when_done(sample_idx, [](void* mem_) { |
|
int32_t* mem = reinterpret_cast<int32_t*>(mem_); |
|
delete[] mem; |
|
}); |
|
|
|
|
|
const auto byte_size = sizeof(int32_t); |
|
return py::array(std::vector<int64_t>{num_samples + 1, 2}, |
|
{2 * byte_size, byte_size}, |
|
sample_idx, |
|
free_when_done); |
|
} |
|
|
|
py::array build_sample_idx_int64(const py::array_t<int32_t>& sizes_, |
|
const py::array_t<int32_t>& doc_idx_, |
|
const int32_t seq_length, |
|
const int32_t num_epochs, |
|
const int64_t tokens_per_epoch) |
|
{ |
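    /* Same as build_sample_idx_int32, except the returned array holds int64
       entries so the document index can exceed the int32 range (very large
       corpora or many epochs). */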
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert(seq_length > 1); |
|
assert(num_epochs > 0); |
|
assert(tokens_per_epoch > 1); |
|
|
|
|
|
auto sizes = sizes_.unchecked<1>(); |
|
auto doc_idx = doc_idx_.unchecked<1>(); |
|
|
|
|
|
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; |
|
int64_t* sample_idx = new int64_t[2 * (num_samples + 1)]; |
|
|
|
cout << " using:" << endl << std::flush; |
|
cout << " number of documents: " << doc_idx_.shape(0) / num_epochs << endl |
|
<< std::flush; |
|
cout << " number of epochs: " << num_epochs << endl << std::flush; |
|
cout << " sequence length: " << seq_length << endl << std::flush; |
|
cout << " total number of samples: " << num_samples << endl << std::flush; |
|
|
|
|
|
int64_t sample_index = 0; |
|
|
|
int64_t doc_idx_index = 0; |
|
|
|
int32_t doc_offset = 0; |
|
|
|
sample_idx[2 * sample_index] = doc_idx_index; |
|
sample_idx[2 * sample_index + 1] = doc_offset; |
|
++sample_index; |
|
|
|
while (sample_index <= num_samples) { |
|
|
|
int32_t remaining_seq_length = seq_length + 1; |
|
while (remaining_seq_length != 0) { |
|
|
|
auto doc_id = doc_idx[doc_idx_index]; |
|
auto doc_length = sizes[doc_id] - doc_offset; |
|
|
|
remaining_seq_length -= doc_length; |
|
|
|
|
|
|
|
|
|
if (remaining_seq_length <= 0) { |
|
doc_offset += (remaining_seq_length + doc_length - 1); |
|
remaining_seq_length = 0; |
|
} else { |
|
|
|
++doc_idx_index; |
|
doc_offset = 0; |
|
} |
|
} |
|
|
|
sample_idx[2 * sample_index] = doc_idx_index; |
|
sample_idx[2 * sample_index + 1] = doc_offset; |
|
++sample_index; |
|
} |
|
|
|
|
|
py::capsule free_when_done(sample_idx, [](void* mem_) { |
|
int64_t* mem = reinterpret_cast<int64_t*>(mem_); |
|
delete[] mem; |
|
}); |
|
|
|
|
|
const auto byte_size = sizeof(int64_t); |
|
return py::array(std::vector<int64_t>{num_samples + 1, 2}, |
|
{2 * byte_size, byte_size}, |
|
sample_idx, |
|
free_when_done); |
|
} |
|
|
|
inline int32_t get_target_sample_len(const int32_t short_seq_ratio, |
|
const int32_t max_length, |
|
std::mt19937& rand32_gen) |
|
{ |
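    // With probability ~1/short_seq_ratio return a random target length in
    // [2, max_length]; otherwise return max_length.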
|
|
|
const auto random_number = rand32_gen(); |
|
if ((random_number % short_seq_ratio) == 0) { return 2 + random_number % (max_length - 1); } |
|
return max_length; |
|
} |
|
|
|
template <typename DocIdx> |
|
py::array build_mapping_impl(const py::array_t<int64_t>& docs_, |
|
const py::array_t<int32_t>& sizes_, |
|
const int32_t num_epochs, |
|
const uint64_t max_num_samples, |
|
const int32_t max_seq_length, |
|
const double short_seq_prob, |
|
const int32_t seed, |
|
const bool verbose) |
|
{ |
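    /* Build a [num_samples, 3] mapping of (start-sentence-index,
       end-sentence-index, target-sequence-length), where the start/end
       indices refer to sentences in the flattened sentence arrays described
       by `docs_` and `sizes_`. */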
|
|
|
|
|
|
|
|
|
|
|
|
|
assert(num_epochs > 0); |
|
assert(max_seq_length > 1); |
|
assert(short_seq_prob > 0.0); |
|
assert(short_seq_prob <= 1.0); |
|
assert(seed > 0); |
|
|
|
|
|
auto docs = docs_.unchecked<1>(); |
|
auto sizes = sizes_.unchecked<1>(); |
|
|
|
|
|
const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob)); |
|
|
|
if (verbose) { |
|
const auto sent_start_index = docs[0]; |
|
const auto sent_end_index = docs[docs_.shape(0) - 1]; |
|
const auto num_sentences = sent_end_index - sent_start_index; |
|
cout << " using:" << endl << std::flush; |
|
cout << " number of documents: " << docs_.shape(0) - 1 << endl << std::flush; |
|
cout << " sentences range: [" << sent_start_index << ", " |
|
<< sent_end_index << ")" << endl |
|
<< std::flush; |
|
cout << " total number of sentences: " << num_sentences << endl << std::flush; |
|
cout << " number of epochs: " << num_epochs << endl << std::flush; |
|
cout << " maximum number of samples: " << max_num_samples << endl << std::flush; |
|
cout << " maximum sequence length: " << max_seq_length << endl << std::flush; |
|
cout << " short sequence probability: " << short_seq_prob << endl << std::flush; |
|
        cout << "    short sequence ratio (1/prob): " << short_seq_ratio << endl << std::flush;
|
cout << " seed: " << seed << endl << std::flush; |
|
} |
|
|
|
|
|
int64_t num_samples = -1; |
|
DocIdx* maps = NULL; |
|
|
|
|
|
|
|
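    // Two passes over the data: the first pass only counts the samples so
    // `maps` can be allocated, the second pass populates it.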
bool second = false; |
|
for (int32_t iteration = 0; iteration < 2; ++iteration) { |
|
|
|
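        // Re-seed so both passes draw the same sequence of random lengths.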
std::mt19937 rand32_gen(seed); |
|
|
|
|
|
second = (iteration == 1); |
|
|
|
|
|
uint64_t empty_docs = 0; |
|
uint64_t one_sent_docs = 0; |
|
uint64_t long_sent_docs = 0; |
|
|
|
|
|
uint64_t map_index = 0; |
|
|
|
|
|
for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { |
|
if (map_index >= max_num_samples) { |
|
if (verbose && (!second)) { |
|
cout << " reached " << max_num_samples << " samples after " << epoch |
|
<< " epochs ..." << endl |
|
<< std::flush; |
|
} |
|
break; |
|
} |
|
|
|
for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { |
|
|
|
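                // Sentences of this document are [sent_index_first, sent_index_last).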
const auto sent_index_first = docs[doc]; |
|
const auto sent_index_last = docs[doc + 1]; |
|
|
|
|
|
|
|
auto prev_start_index = sent_index_first; |
|
|
|
|
|
auto num_remain_sent = sent_index_last - sent_index_first; |
|
|
|
|
|
if ((epoch == 0) && (!second)) { |
|
if (num_remain_sent == 0) { ++empty_docs; } |
|
if (num_remain_sent == 1) { ++one_sent_docs; } |
|
} |
|
|
|
|
|
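                // Detect documents containing a sentence longer than
                // LONG_SENTENCE_LEN; such documents are skipped below.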
bool contains_long_sentence = false; |
|
if (num_remain_sent > 1) { |
|
for (auto sent_index = sent_index_first; sent_index < sent_index_last; |
|
++sent_index) { |
|
if (sizes[sent_index] > LONG_SENTENCE_LEN) { |
|
if ((epoch == 0) && (!second)) { ++long_sent_docs; } |
|
contains_long_sentence = true; |
|
break; |
|
} |
|
} |
|
} |
|
|
|
|
|
if ((num_remain_sent > 1) && (!contains_long_sentence)) { |
|
|
|
auto seq_len = int32_t{0}; |
|
auto num_sent = int32_t{0}; |
|
auto target_seq_len = |
|
get_target_sample_len(short_seq_ratio, max_seq_length, rand32_gen); |
|
|
|
|
|
for (auto sent_index = sent_index_first; sent_index < sent_index_last; |
|
++sent_index) { |
|
|
|
seq_len += sizes[sent_index]; |
|
++num_sent; |
|
--num_remain_sent; |
|
|
|
|
|
|
|
|
|
|
|
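                        // Emit a sample once the target length is reached,
                        // provided the sample has at least two sentences and
                        // at least two sentences remain in the document, or
                        // when the document is exhausted.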
if (((seq_len >= target_seq_len) && (num_remain_sent > 1) && |
|
(num_sent > 1)) || |
|
(num_remain_sent == 0)) { |
|
|
|
if ((3 * map_index + 2) > std::numeric_limits<int64_t>::max()) { |
|
cout << "number of samples exceeded maximum " |
|
<< "allowed by type int64: " |
|
<< std::numeric_limits<int64_t>::max() << endl; |
|
throw std::overflow_error("Number of samples"); |
|
} |
|
|
|
|
|
if (second) { |
|
const auto map_index_0 = 3 * map_index; |
|
maps[map_index_0] = static_cast<DocIdx>(prev_start_index); |
|
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1); |
|
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len); |
|
} |
|
|
|
|
|
++map_index; |
|
prev_start_index = sent_index + 1; |
|
target_seq_len = |
|
get_target_sample_len(short_seq_ratio, max_seq_length, rand32_gen); |
|
seq_len = 0; |
|
num_sent = 0; |
|
} |
|
|
|
} |
|
} |
|
} |
|
} |
|
|
|
if (!second) { |
|
if (verbose) { |
|
cout << " number of empty documents: " << empty_docs << endl << std::flush; |
|
cout << " number of documents with one sentence: " << one_sent_docs << endl |
|
<< std::flush; |
|
cout << " number of documents with long sentences: " << long_sent_docs << endl |
|
<< std::flush; |
|
cout << " will create mapping for " << map_index << " samples" << endl |
|
<< std::flush; |
|
} |
|
assert(maps == NULL); |
|
assert(num_samples < 0); |
|
maps = new DocIdx[3 * map_index]; |
|
num_samples = static_cast<int64_t>(map_index); |
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
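    // Shuffle the samples in place (Fisher-Yates); a 64-bit generator is used
    // because the number of samples can exceed the 32-bit range.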
std::mt19937_64 rand64_gen(seed + 1); |
|
for (auto i = (num_samples - 1); i > 0; --i) { |
|
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1)); |
|
const auto i0 = 3 * i; |
|
const auto j0 = 3 * j; |
|
|
|
swap(maps[i0], maps[j0]); |
|
swap(maps[i0 + 1], maps[j0 + 1]); |
|
swap(maps[i0 + 2], maps[j0 + 2]); |
|
} |
|
|
|
|
|
py::capsule free_when_done(maps, [](void* mem_) { |
|
DocIdx* mem = reinterpret_cast<DocIdx*>(mem_); |
|
delete[] mem; |
|
}); |
|
|
|
|
|
const auto byte_size = sizeof(DocIdx); |
|
return py::array(std::vector<int64_t>{num_samples, 3}, |
|
{3 * byte_size, byte_size}, |
|
maps, |
|
free_when_done); |
|
} |
|
|
|
py::array build_mapping(const py::array_t<int64_t>& docs_, |
|
const py::array_t<int>& sizes_, |
|
const int num_epochs, |
|
const uint64_t max_num_samples, |
|
const int max_seq_length, |
|
const double short_seq_prob, |
|
const int seed, |
|
const bool verbose) |
|
{ |
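    // Use 64-bit indices when the number of sentences does not fit in uint32;
    // otherwise 32-bit indices suffice.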
|
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { |
|
if (verbose) { cout << " using uint64 for data mapping..." << endl << std::flush; } |
|
return build_mapping_impl<uint64_t>(docs_, |
|
sizes_, |
|
num_epochs, |
|
max_num_samples, |
|
max_seq_length, |
|
short_seq_prob, |
|
seed, |
|
verbose); |
|
} else { |
|
if (verbose) { cout << " using uint32 for data mapping..." << endl << std::flush; } |
|
return build_mapping_impl<uint32_t>(docs_, |
|
sizes_, |
|
num_epochs, |
|
max_num_samples, |
|
max_seq_length, |
|
short_seq_prob, |
|
seed, |
|
verbose); |
|
} |
|
} |
|
|
|
template <typename DocIdx> |
|
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, |
|
const py::array_t<int32_t>& sizes_, |
|
const py::array_t<int32_t>& titles_sizes_, |
|
const int32_t num_epochs, |
|
const uint64_t max_num_samples, |
|
const int32_t max_seq_length, |
|
const int32_t seed, |
|
const bool verbose, |
|
const bool use_one_sent_blocks) |
|
{ |
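    /* Build a [num_samples, 4] mapping of (start-sentence-index,
       end-sentence-index, document-index, block-id). Each block's target
       length is max_seq_length minus the length of its document's title,
       taken from `titles_sizes_`. */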
|
|
|
|
|
|
|
|
|
|
|
|
|
assert(num_epochs > 0); |
|
assert(max_seq_length > 1); |
|
assert(seed > 0); |
|
|
|
|
|
auto docs = docs_.unchecked<1>(); |
|
auto sizes = sizes_.unchecked<1>(); |
|
auto titles_sizes = titles_sizes_.unchecked<1>(); |
|
|
|
if (verbose) { |
|
const auto sent_start_index = docs[0]; |
|
const auto sent_end_index = docs[docs_.shape(0) - 1]; |
|
const auto num_sentences = sent_end_index - sent_start_index; |
|
cout << " using:" << endl << std::flush; |
|
cout << " number of documents: " << docs_.shape(0) - 1 << endl << std::flush; |
|
cout << " sentences range: [" << sent_start_index << ", " |
|
<< sent_end_index << ")" << endl |
|
<< std::flush; |
|
cout << " total number of sentences: " << num_sentences << endl << std::flush; |
|
cout << " number of epochs: " << num_epochs << endl << std::flush; |
|
cout << " maximum number of samples: " << max_num_samples << endl << std::flush; |
|
cout << " maximum sequence length: " << max_seq_length << endl << std::flush; |
|
cout << " seed: " << seed << endl << std::flush; |
|
} |
|
|
|
|
|
int64_t num_samples = -1; |
|
DocIdx* maps = NULL; |
|
|
|
|
|
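    // Minimum number of sentences a block must contain.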
int min_num_sent = 2; |
|
if (use_one_sent_blocks) { min_num_sent = 1; } |
|
|
|
|
|
|
|
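    // Two passes: the first counts the blocks so `maps` can be allocated,
    // the second populates it.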
bool second = false; |
|
for (int32_t iteration = 0; iteration < 2; ++iteration) { |
|
|
|
second = (iteration == 1); |
|
|
|
|
|
uint64_t map_index = 0; |
|
|
|
uint64_t empty_docs = 0; |
|
uint64_t one_sent_docs = 0; |
|
uint64_t long_sent_docs = 0; |
|
|
|
for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { |
|
|
|
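            // Assign every block within this epoch a unique id.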
int32_t block_id = 0; |
|
|
|
if (map_index >= max_num_samples) { |
|
if (verbose && (!second)) { |
|
cout << " reached " << max_num_samples << " samples after " << epoch |
|
<< " epochs ..." << endl |
|
<< std::flush; |
|
} |
|
break; |
|
} |
|
|
|
for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { |
|
|
|
const auto sent_index_first = docs[doc]; |
|
const auto sent_index_last = docs[doc + 1]; |
|
const auto target_seq_len = max_seq_length - titles_sizes[doc]; |
|
|
|
|
|
|
|
auto prev_start_index = sent_index_first; |
|
|
|
|
|
auto num_remain_sent = sent_index_last - sent_index_first; |
|
|
|
|
|
if ((epoch == 0) && (!second)) { |
|
if (num_remain_sent == 0) { ++empty_docs; } |
|
if (num_remain_sent == 1) { ++one_sent_docs; } |
|
} |
|
|
|
bool contains_long_sentence = false; |
|
if (num_remain_sent >= min_num_sent) { |
|
for (auto sent_index = sent_index_first; sent_index < sent_index_last; |
|
++sent_index) { |
|
if (sizes[sent_index] > LONG_SENTENCE_LEN) { |
|
if ((epoch == 0) && (!second)) { ++long_sent_docs; } |
|
contains_long_sentence = true; |
|
break; |
|
} |
|
} |
|
} |
|
|
|
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { |
|
|
|
auto seq_len = int32_t{0}; |
|
auto num_sent = int32_t{0}; |
|
|
|
|
|
for (auto sent_index = sent_index_first; sent_index < sent_index_last; |
|
++sent_index) { |
|
|
|
seq_len += sizes[sent_index]; |
|
++num_sent; |
|
--num_remain_sent; |
|
|
|
|
|
|
|
|
|
|
|
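                        // Emit a block once the target length is reached,
                        // provided at least min_num_sent sentences are used
                        // and at least min_num_sent remain, or when the
                        // document is exhausted.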
if (((seq_len >= target_seq_len) && (num_remain_sent >= min_num_sent) && |
|
(num_sent >= min_num_sent)) || |
|
(num_remain_sent == 0)) { |
|
|
|
if (second) { |
|
const auto map_index_0 = 4 * map_index; |
|
|
|
|
|
|
|
|
|
|
|
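                            // Each sample stores 4 items: start sentence
                            // index, end sentence index, the document the
                            // block comes from (used to fetch its title),
                            // and the block id.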
maps[map_index_0] = static_cast<DocIdx>(prev_start_index); |
|
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1); |
|
maps[map_index_0 + 2] = static_cast<DocIdx>(doc); |
|
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id); |
|
} |
|
|
|
|
|
++map_index; |
|
++block_id; |
|
prev_start_index = sent_index + 1; |
|
seq_len = 0; |
|
num_sent = 0; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
if (!second) { |
|
if (verbose) { |
|
cout << " number of empty documents: " << empty_docs << endl << std::flush; |
|
cout << " number of documents with one sentence: " << one_sent_docs << endl |
|
<< std::flush; |
|
cout << " number of documents with long sentences: " << long_sent_docs << endl |
|
<< std::flush; |
|
cout << " will create mapping for " << map_index << " samples" << endl |
|
<< std::flush; |
|
} |
|
assert(maps == NULL); |
|
assert(num_samples < 0); |
|
maps = new DocIdx[4 * map_index]; |
|
num_samples = static_cast<int64_t>(map_index); |
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
std::mt19937_64 rand64_gen(seed + 1); |
|
for (auto i = (num_samples - 1); i > 0; --i) { |
|
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1)); |
|
const auto i0 = 4 * i; |
|
const auto j0 = 4 * j; |
|
|
|
swap(maps[i0], maps[j0]); |
|
swap(maps[i0 + 1], maps[j0 + 1]); |
|
swap(maps[i0 + 2], maps[j0 + 2]); |
|
swap(maps[i0 + 3], maps[j0 + 3]); |
|
} |
|
|
|
|
|
py::capsule free_when_done(maps, [](void* mem_) { |
|
DocIdx* mem = reinterpret_cast<DocIdx*>(mem_); |
|
delete[] mem; |
|
}); |
|
|
|
|
|
const auto byte_size = sizeof(DocIdx); |
|
return py::array(std::vector<int64_t>{num_samples, 4}, |
|
{4 * byte_size, byte_size}, |
|
maps, |
|
free_when_done); |
|
} |
|
|
|
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_, |
|
const py::array_t<int>& sizes_, |
|
const py::array_t<int>& titles_sizes_, |
|
const int num_epochs, |
|
const uint64_t max_num_samples, |
|
const int max_seq_length, |
|
const int seed, |
|
const bool verbose, |
|
const bool use_one_sent_blocks) |
|
{ |
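    // Use 64-bit indices when the number of sentences does not fit in uint32;
    // otherwise 32-bit indices suffice.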
|
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { |
|
if (verbose) { cout << " using uint64 for data mapping..." << endl << std::flush; } |
|
return build_blocks_mapping_impl<uint64_t>(docs_, |
|
sizes_, |
|
titles_sizes_, |
|
num_epochs, |
|
max_num_samples, |
|
max_seq_length, |
|
seed, |
|
verbose, |
|
use_one_sent_blocks); |
|
} else { |
|
if (verbose) { cout << " using uint32 for data mapping..." << endl << std::flush; } |
|
return build_blocks_mapping_impl<uint32_t>(docs_, |
|
sizes_, |
|
titles_sizes_, |
|
num_epochs, |
|
max_num_samples, |
|
max_seq_length, |
|
seed, |
|
verbose, |
|
use_one_sent_blocks); |
|
} |
|
} |
|
|
|
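// Python entry points. Assuming the compiled extension is importable as
// `helpers`, a call looks like (hypothetical usage):
//   import helpers
//   helpers.build_blending_indices(dataset_index, dataset_sample_index,
//                                  weights, num_datasets, size, True)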
PYBIND11_MODULE(helpers, m) |
|
{ |
|
m.def("build_mapping", &build_mapping); |
|
m.def("build_blocks_mapping", &build_blocks_mapping); |
|
m.def("build_sample_idx_int32", &build_sample_idx_int32); |
|
m.def("build_sample_idx_int64", &build_sample_idx_int64); |
|
m.def("build_blending_indices", &build_blending_indices); |
|
} |
|
|