[14:11:58] - INFO - absl - Restoring checkpoint from ./checkpoint_90000
tcmalloc: large alloc 1530273792 bytes == 0x9aa8a000 @  0x7f287c292680 0x7f287c2b3824 0x5b9a14 0x50b2ae 0x50cb1b 0x5a6f17 0x5f3010 0x56fd36 0x568d9a 0x5f5b33 0x56aadf 0x568d9a 0x68cdc7 0x67e161 0x67e1df 0x67e281 0x67e627 0x6b6e62 0x6b71ed 0x7f287c0a70b3 0x5f96de
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/2):   0%|          | 0/2 [00:00
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
  File "/home/dat/pino/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 906, in parallel_callable
    compiled = xla.backend_compile(backend, built, compile_options)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 360, in backend_compile
    return backend.compile(built_c, compile_options=options)
KeyboardInterrupt