[21:01:12] - INFO - absl - A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3132: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https:
  lax._check_user_dtype_supported(dtype, "zeros")
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/3): 0%| | 0/3 [00:00<?, ?it/s][21:01:13] - INFO - __main__ - Skipping to epoch 0 step 0
Epoch ... (1/3): 0%| | 0/3 [01:17<?, ?it/s]
Traceback (most recent call last):
  File "./run_mlm_flax.py", line 790, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
  File "/home/dat/pino/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 906, in parallel_callable
    compiled = xla.backend_compile(backend, built, compile_options)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 360, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 17.79G of 15.48G hbm. Exceeded hbm capacity by 2.31G.

Total hbm usage >= 18.31G:
    reserved 530.00M
    program 17.79G
    arguments 0B

Output size 0B; shares 0B with arguments.

Program hbm requirement 17.79G:
    global 884.0K
    scoped 253.0K
    HLO temp 17.79G (97.6% utilization: Unpadded (17.27G) Padded (17.68G), 0.6% fragmentation (106.34M))
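The summary figures are consistent with each other if you read "G" and "M" as binary units (GiB and MiB). The log does not state this; it is an assumption. Under that reading, three checks hold. The 530M reservation plus the 15.48G usable capacity comes to about 16 GiB of device HBM. The program plus the reservation gives the reported total. The program alone exceeds the usable capacity by exactly the reported 2.31G. The short check below uses only numbers copied from the log:

```python
# All figures are copied from the OOM summary above.
# Assumption: "G" and "M" in the XLA report are binary units (GiB, MiB).
reserved_gib = 530.00 / 1024  # "reserved 530.00M"
program_gib = 17.79           # "program 17.79G" (almost all of it HLO temp)
capacity_gib = 15.48          # "Used 17.79G of 15.48G hbm"

device_hbm_gib = capacity_gib + reserved_gib  # usable capacity plus the reservation
total_gib = program_gib + reserved_gib        # compare with "Total hbm usage >= 18.31G"
excess_gib = program_gib - capacity_gib       # compare with "Exceeded hbm capacity by 2.31G"

print(f"device HBM ~ {device_hbm_gib:.2f}G")
print(f"total      ~ {total_gib:.2f}G")
print(f"excess     ~ {excess_gib:.2f}G")
```

In the report, HLO temp accounts for essentially the entire 17.79G program requirement. The overshoot is therefore in intermediate buffers of the compiled training step.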

Largest program allocations in hbm:

  1. Size: 3.07G
     Operator: op_type="dot_general" op_name="pmap(train_step)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/pino/lib/python3.8/site-packages/flax/linen/linear.py" source_line=175
     Shape: f32[4,4096,50358]{1,2,0:T(8,128)}
     Unpadded size: 3.07G
     Extra memory due to padding: 128.0K (1.0x expansion)
     XLA label: %fusion.1233.remat4 = f32[4,4096,50358]{1,2,0:T(8,128)} fusion(f32[50358]{0:T(1024)} %get-tuple-element.21733, f32[768,50358,1]{0,1,2:T(8,128)} %bitcast.4927, f32[768]{0:T(1024)} %get-tuple-element.21734, f32[768]{0:T(1024)} %get-tuple-element.21735, f32[4...
     Allocation type: HLO temp
     ==========================

  2. Size: 336.00M
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.12188 = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1904, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8899, f32[4,12,28,128,128]{3,4,2,1,0:T(8,1...
     Allocation type: HLO temp
     ==========================

  3. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1304.remat6 = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1906, f32[4,12,28,128]{3,2,1,0:T(8,1...
     Allocation type: HLO temp
     ==========================

  4. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1304.remat6 = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1906, f32[4,12,28,128]{3,2,1,0:T(8,1...
     Allocation type: HLO temp
     ==========================

  5. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1306.remat = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1908, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8903, f32[4,12,28,128,128]{3,4,2,1,0:...
     Allocation type: HLO temp
     ==========================

  6. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1307.remat = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1909, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8904, f32[4,12,28,128,128]{3,4,2,1,0:...
     Allocation type: HLO temp
     ==========================

  7. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1308.remat = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1910, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8905, f32[4,12,28,128,128]{3,4,2,1,0:...
     Allocation type: HLO temp
     ==========================

  8. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1309.remat = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1911, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8906, f32[4,12,28,128,128]{3,4,2,1,0:...
     Allocation type: HLO temp
     ==========================

  9. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1310.remat = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1912, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8907, f32[4,12,28,128,128]{3,4,2,1,0:...
     Allocation type: HLO temp
     ==========================

  10. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1311.remat = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1913, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8908, f32[4,12,28,128,128]{3,4,2,1,0:...
     Allocation type: HLO temp
     ==========================

  11. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1312.remat = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1914, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8909, f32[4,12,28,128,128]{3,4,2,1,0:...
     Allocation type: HLO temp
     ==========================

  12. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1305 = bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)} fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1907, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8902, f32[4,12,28,128,128]{3,4,2,1,0:T(8,128)} %get-tuple-element.19534, f32[4,12,28,128,384]{...
     Allocation type: HLO temp
     ==========================

  13. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1301.remat6 = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1903, f32[4,12,28,128]{3,2,1,0:T(8,1...
     Allocation type: HLO temp
     ==========================

  14. Size: 336.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.1301.remat6 = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1903, f32[4,12,28,128]{3,2,1,0:T(8,1...
     Allocation type: HLO temp
     ==========================

  15. Size: 336.00M
     Shape: bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 336.00M
     XLA label: %fusion.12187 = (bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.1905, f32[4,12,28,128]{3,2,1,0:T(8,128)} %fusion.8900, f32[4,12,28,128,128]{3,4,2,1,0:T(8,1...
     Allocation type: HLO temp
     ==========================

  16. Size: 252.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
     Shape: f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}
     Unpadded size: 252.00M
     XLA label: %fusion.10998 = (f32[4,12,28,128]{3,2,1,0:T(8,128)}, f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[4,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.23248, bf16[4,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28551, f32[4,12,32,128,64]{3,2,4,1...
     Allocation type: HLO temp
     ==========================

  17. Size: 252.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
     Shape: f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}
     Unpadded size: 252.00M
     XLA label: %fusion.11022 = (f32[4,12,28,128]{3,2,1,0:T(8,128)}, f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[4,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.23245, bf16[4,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28656.remat_uncompressed, f32[4,12...
     Allocation type: HLO temp
     ==========================

  18. Size: 252.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
     Shape: f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}
     Unpadded size: 252.00M
     XLA label: %fusion.11014 = (f32[4,12,28,128]{3,2,1,0:T(8,128)}, f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[4,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.23246, bf16[4,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28621.remat_uncompressed, f32[4,12...
     Allocation type: HLO temp
     ==========================

  19. Size: 252.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
     Shape: f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}
     Unpadded size: 252.00M
     XLA label: %fusion.11006 = (f32[4,12,28,128]{3,2,1,0:T(8,128)}, f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[4,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.23247, bf16[4,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28586.remat_uncompressed, f32[4,12...
     Allocation type: HLO temp
     ==========================

  20. Size: 252.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
     Shape: f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}
     Unpadded size: 252.00M
     XLA label: %fusion.10934 = (f32[4,12,28,128]{3,2,1,0:T(8,128)}, f32[4,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[4,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.19864, bf16[4,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28270, f32[4,12,32,128,64]{3,2,4,1...
     Allocation type: HLO temp
     ==========================
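Each listed size follows directly from its shape and dtype (4 bytes per f32 element, 2 per bf16 element). The sketch below recomputes the three distinct sizes in the list and their combined total.

The dimension readings in the comments are assumptions inferred from the shapes and the script's context; the report does not state them. They are:
- 4 as the per-device batch.
- 4096 as the sequence length.
- 50358 as the vocabulary size.
- 12 attention heads.
- BigBird blocks of 128 tokens, which gives 4096 / 128 = 32 blocks, of which 28 appear in the block-sparse tensors.

The `xla_style` helper is hypothetical. It only mimics the report's "G"/"M" formatting, and its 1 GiB threshold is an assumption.

```python
import math

# Bytes per element for the dtypes that appear in the report.
BYTES_PER_ELEMENT = {"f32": 4, "bf16": 2}
GIB, MIB = 2**30, 2**20


def tensor_bytes(dtype: str, shape: tuple) -> int:
    """Dense tensor size, ignoring the small layout padding XLA reports separately."""
    return BYTES_PER_ELEMENT[dtype] * math.prod(shape)


def xla_style(num_bytes: int) -> str:
    """Hypothetical formatter: 'G' (GiB) at or above 1 GiB, otherwise 'M' (MiB)."""
    if num_bytes >= GIB:
        return f"{num_bytes / GIB:.2f}G"
    return f"{num_bytes / MIB:.2f}M"


# (label, dtype, shape, number of top-20 entries with this shape)
# Dimension readings are assumptions, see the text above.
ALLOCATIONS = [
    # entry 1: presumably MLM logits (batch, seq_len, vocab) from the flax Dense layer
    ("logits", "f32", (4, 4096, 50358), 1),
    # entries 2-15: presumably (batch, heads, 28 blocks, 128 queries, 8 * 128 keys)
    ("bigbird div", "bf16", (4, 12, 28, 128, 1024), 14),
    # entries 16-20: presumably (batch, heads, 28 blocks, 128 queries, 3 * 128 keys)
    ("bigbird einsum", "f32", (4, 12, 28, 128, 384), 5),
]

total = 0
for label, dtype, shape, count in ALLOCATIONS:
    size = tensor_bytes(dtype, shape)
    total += size * count
    print(f"{label:15s} {xla_style(size):>8s} x {count:2d}")

print(f"top-20 total    {total / GIB:.2f}G of the 17.79G program")
```

By this count, the 20 largest buffers cover roughly half of the 17.79G program requirement, and the single f32 logits tensor is the largest of them. The rest is spread over smaller allocations that the report does not list. Every shape here scales linearly with its leading dimension. If that dimension is indeed the per-device batch, halving it would roughly halve each of these buffers. Whether that alone brings the step under 15.48G cannot be read off the report.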

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./run_mlm_flax.py", line 790, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 360, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 17.79G of 15.48G hbm. Exceeded hbm capacity by 2.31G.