|
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code. |
|
warnings.warn( |
|
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code. |
|
warnings.warn( |
|
Epoch ... (1/3): 0%| | 0/3 [00:00<?, ?it/s]2021-07-14 19:10:58.993859: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Resource exhausted: Attempting to allocate 17.0K. That was not possible. There are 48.0K free. Due to fragmentation, the largest contiguous region of free memory is 16.0K.; (0x0x0_HBM0) |
|
Epoch ... (1/3): 0%| | 0/3 [15:02<?, ?it/s] |
|
Traceback (most recent call last): |
|
File "./run_mlm_flax_no_accum.py", line 686, in <module> |
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) |
|
File "./run_mlm_flax_no_accum.py", line 255, in generate_batch_splits |
|
batch_idx = np.split(samples_idx, sections_split) |
|
File "<__array_function__ internals>", line 5, in split |
|
File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 874, in split |
|
return array_split(ary, indices_or_sections, axis) |
|
File "<__array_function__ internals>", line 5, in array_split |
|
File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 790, in array_split |
|
sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0)) |
|
File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5009, in _rewriting_take |
|
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, |
|
File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5028, in _gather |
|
y = lax.gather( |
|
File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 984, in gather |
|
return gather_p.bind( |
|
File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 264, in bind |
|
out = top_trace.process_primitive(self, tracers, params) |
|
File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 603, in process_primitive |
|
return primitive.impl(*tracers, **params) |
|
File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 249, in apply_primitive |
|
return compiled_fun(*args) |
|
File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 365, in _execute_compiled_primitive |
|
out_bufs = compiled.execute(input_bufs) |
|
RuntimeError: Resource exhausted: Attempting to allocate 17.0K. That was not possible. There are 48.0K free. Due to fragmentation, the largest contiguous region of free memory is 16.0K.; (0x0x0_HBM0) |