/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/5):   0%|          | 0/5 [00:00<?, ?it/s]
2021-07-14 23:26:04.701487: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Resource exhausted: Attempting to allocate 17.0K. That was not possible. There are 48.0K free. Due to fragmentation, the largest contiguous region of free memory is 16.0K.; (0x0x0_HBM0)
Epoch ... (1/5):   0%|          | 0/5 [14:02<?, ?it/s]
Traceback (most recent call last):
  File "./run_mlm_flax.py", line 806, in <module>
    train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
  File "./run_mlm_flax.py", line 263, in generate_batch_splits
    batch_idx = np.split(samples_idx, sections_split)
  File "<__array_function__ internals>", line 5, in split
  File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 874, in split
    return array_split(ary, indices_or_sections, axis)
  File "<__array_function__ internals>", line 5, in array_split
  File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 790, in array_split
    sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0))
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5009, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5028, in _gather
    y = lax.gather(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 984, in gather
    return gather_p.bind(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 603, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 249, in apply_primitive
    return compiled_fun(*args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 365, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Resource exhausted: Attempting to allocate 17.0K. That was not possible. There are 48.0K free. Due to fragmentation, the largest contiguous region of free memory is 16.0K.; (0x0x0_HBM0)