dat
Saving weights and logs at step 1252
f291f93
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
Epoch ... (1/5): 0%| | 0/5 [02:13<?, ?it/s]
Traceback (most recent call last):
  File "./run_mlm_flax.py", line 725, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1647, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
  File "/home/dat/pino/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 899, in parallel_callable
    compiled = xla.backend_compile(backend, built, compile_options)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 360, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 35.77G of 15.48G hbm. Exceeded hbm capacity by 20.29G.
Total hbm usage >= 36.29G:
    reserved 530.00M
    program 35.77G
    arguments 0B
Output size 0B; shares 0B with arguments.
Program hbm requirement 35.77G:
    global 692.0K
    scoped 253.0K
    HLO temp 35.77G (97.6% utilization: Unpadded (34.82G) Padded (35.67G), 0.3% fragmentation (105.77M))
Largest program allocations in hbm:
1. Size: 6.15G
Operator: op_type="dot_general" op_name="pmap(train_step)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/pino/lib/python3.8/site-packages/flax/linen/linear.py" source_line=175
Shape: f32[8,4096,50358]{1,2,0:T(8,128)}
Unpadded size: 6.15G
Extra memory due to padding: 256.0K (1.0x expansion)
XLA label: %fusion.1737.remat4 = f32[8,4096,50358]{1,2,0:T(8,128)} fusion(f32[50358]{0:T(1024)} %get-tuple-element.23321, f32[768,50358,1]{0,1,2:T(8,128)} %bitcast.5512, f32[768]{0:T(1024)} %get-tuple-element.23322, f32[768]{0:T(1024)} %get-tuple-element.23323, f32[8...
Allocation type: HLO temp
==========================
2. Size: 672.00M
Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.1805.remat6 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2407, f32[8,12,28,128]{3,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
3. Size: 672.00M
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.13201 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2412, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9402, f32[8,12,28,128,128]{3,4,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
4. Size: 672.00M
Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.1805.remat6 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2407, f32[8,12,28,128]{3,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
5. Size: 672.00M
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.13202 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2411, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9401, f32[8,12,28,128,128]{3,4,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
6. Size: 672.00M
Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.1814 = bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)} fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2416, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9406, f32[8,12,28,128,128]{3,4,2,1,0:T(8,128)} %get-tuple-element.20627, f32[8,12,28,128,384]{...
Allocation type: HLO temp
==========================
7. Size: 672.00M
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.13199 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2414, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9404, f32[8,12,28,128,128]{3,4,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
8. Size: 672.00M
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.13200 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2413, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9403, f32[8,12,28,128,128]{3,4,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
9. Size: 672.00M
Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.1816.remat = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2418, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9408, f32[8,12,28,128,128]{3,4,2,1,0:...
Allocation type: HLO temp
==========================
10. Size: 672.00M
Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.1815.remat = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2417, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9407, f32[8,12,28,128,128]{3,4,2,1,0:...
Allocation type: HLO temp
==========================
11. Size: 672.00M
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.13203 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2410, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9400, f32[8,12,28,128,128]{3,4,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
12. Size: 672.00M
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.13204 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2409, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9399, f32[8,12,28,128,128]{3,4,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
13. Size: 672.00M
Shape: bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}
Unpadded size: 672.00M
XLA label: %fusion.13205 = (bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}, bf16[8,12,28,128,1024]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.2408, f32[8,12,28,128]{3,2,1,0:T(8,128)} %fusion.9398, f32[8,12,28,128,128]{3,4,2,1,0:T(8,1...
Allocation type: HLO temp
==========================
14. Size: 504.00M
Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
Shape: f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}
Unpadded size: 504.00M
XLA label: %fusion.11557 = (f32[8,12,28,128]{3,2,1,0:T(8,128)}, f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[8,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.25239, bf16[8,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28505.remat_uncompressed, f32[8,12...
Allocation type: HLO temp
==========================
15. Size: 504.00M
Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
Shape: f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}
Unpadded size: 504.00M
XLA label: %fusion.11549 = (f32[8,12,28,128]{3,2,1,0:T(8,128)}, f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[8,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.25240, bf16[8,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28470.remat_uncompressed.remat, f3...
Allocation type: HLO temp
==========================
16. Size: 504.00M
Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
Shape: f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}
Unpadded size: 504.00M
XLA label: %fusion.11469 = (f32[8,12,28,128]{3,2,1,0:T(8,128)}, f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[8,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.20990, bf16[8,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28115, f32[8,12,32,128,64]{3,2,4,1...
Allocation type: HLO temp
==========================
17. Size: 504.00M
Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
Shape: f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}
Unpadded size: 504.00M
XLA label: %fusion.11477 = (f32[8,12,28,128]{3,2,1,0:T(8,128)}, f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[8,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.20989, bf16[8,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28151, f32[8,12,32,128,64]{3,2,4,1...
Allocation type: HLO temp
==========================
18. Size: 504.00M
Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
Shape: f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}
Unpadded size: 504.00M
XLA label: %fusion.11541 = (f32[8,12,28,128]{3,2,1,0:T(8,128)}, f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[8,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.25236, bf16[8,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28435.remat_uncompressed, f32[8,12...
Allocation type: HLO temp
==========================
19. Size: 504.00M
Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
Shape: f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}
Unpadded size: 504.00M
XLA label: %fusion.2085.remat5.1.remat = f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)} fusion(f32[8,28,128,384]{2,3,1,0:T(8,128)} %get-tuple-element.20992, bf16[8,12,28,384,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.2473.remat_uncompressed, f32[8,12,32,128,64]{3,2,4,1,0:T(8,128...
Allocation type: HLO temp
==========================
20. Size: 504.00M
Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
Shape: f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}
Unpadded size: 504.00M
XLA label: %fusion.11533 = (f32[8,12,28,128]{3,2,1,0:T(8,128)}, f32[8,12,28,128,384]{3,4,2,1,0:T(8,128)}) fusion(s32[8,12,30,128,384]{3,4,2,1,0:T(8,128)} %get-tuple-element.25238, bf16[8,12,28,384,64]{3,2,1,0,4:T(8,128)(2,1)} %slice.28400.remat_uncompressed, f32[8,12...
Allocation type: HLO temp
==========================
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "./run_mlm_flax.py", line 725, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 360, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 35.77G of 15.48G hbm. Exceeded hbm capacity by 20.29G.
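Each allocation size in the report follows from the buffer's shape and element type, and the figures match binary units (G = GiB, M = MiB). The Python sketch below is only an arithmetic check of three reported sizes. `buffer_bytes` and the dtype table are hypothetical helpers written for illustration; they are not JAX or XLA APIs. The sketch also assumes that the leading dimension of each shape (8) is the per-device batch size, which the log does not state explicitly. Under that assumption, it shows that the largest allocation, the f32[8,4096,50358] output of flax/linen/linear.py, scales linearly with that batch size.

```python
from math import prod

# Bytes per element for the dtypes that appear in the XLA report.
DTYPE_BYTES = {"f32": 4, "bf16": 2, "s32": 4}

GIB = 1024 ** 3
MIB = 1024 ** 2


def buffer_bytes(dtype: str, shape: tuple) -> int:
    """Unpadded size of a dense buffer (hypothetical helper for illustration)."""
    return DTYPE_BYTES[dtype] * prod(shape)


# Allocation 1: output of flax/linen/linear.py, reported as 6.15G.
logits = buffer_bytes("f32", (8, 4096, 50358))
# Allocations 2-13: bf16 BigBird temporaries, reported as 672.00M each.
attn_bf16 = buffer_bytes("bf16", (8, 12, 28, 128, 1024))
# Allocations 14-20: f32 einsum outputs, reported as 504.00M each.
attn_f32 = buffer_bytes("f32", (8, 12, 28, 128, 384))

print(f"logits:    {logits / GIB:.2f}G")
print(f"attn bf16: {attn_bf16 / MIB:.2f}M")
print(f"attn f32:  {attn_f32 / MIB:.2f}M")

# Assumption: the leading dimension is the per-device batch size,
# so this one buffer shrinks linearly as the batch shrinks.
for batch in (8, 4, 2, 1):
    size = buffer_bytes("f32", (batch, 4096, 50358)) / GIB
    print(f"per-device batch {batch}: logits {size:.2f}G")
```

Under that assumption, reducing the per-device batch from 8 to 1 shrinks this single buffer from about 6.15G to about 0.77G. The log alone does not show whether the remaining allocations would then fit in the 15.48G of available HBM.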