# t5-mini-nl8-finnish / flax_model_to_pytorch.py
# Author: aapot
# Commit 69bb3e4: "Add 500k step pytorch model"
#
# Convert the Flax checkpoint in this directory to a PyTorch checkpoint.
from transformers import AutoModelForSeq2SeqLM, FlaxAutoModelForSeq2SeqLM, AutoTokenizer
import torch
import numpy as np
import jax
import jax.numpy as jnp
def to_f32(t):
    """Cast every bfloat16 leaf of a pytree to float32.

    Leaves of any other dtype are returned unchanged, so integer, float16 or
    float32 arrays pass through as-is.

    Args:
        t: A JAX pytree (e.g. a Flax parameter dict) whose leaves expose
            ``.dtype`` and ``.astype``.

    Returns:
        A pytree with the same structure in which bfloat16 leaves are float32.
    """
    # jax.tree_map is deprecated (and removed in recent JAX releases);
    # jax.tree_util.tree_map is the stable spelling with identical semantics.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )
# Run JAX on CPU so the conversion works on machines without a GPU/TPU.
jax.config.update('jax_platform_name', 'cpu')

MODEL_PATH = "./"

# Load the Flax checkpoint, upcast any bfloat16 weights to float32, and write
# it back so the PyTorch loader below reads float32 weights.
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
model.params = to_f32(model.params)
model.save_pretrained(MODEL_PATH)

# Build the PyTorch model from the Flax weights just saved to MODEL_PATH.
pt_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_PATH, from_flax=True).to('cpu')

# Sanity-check the conversion: feed identical dummy inputs (batch of 2,
# length 128, every token id 0) to both models and compare their logits.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

# Inference only: skip building an autograd graph.
with torch.no_grad():
    logits_pt = pt_model(input_ids=input_ids_pt, decoder_input_ids=input_ids_pt).logits
print(logits_pt)

logits_fx = model(input_ids=input_ids, decoder_input_ids=input_ids).logits
print(logits_fx)

# The printed tensors are truncated and hard to compare by eye; report the
# largest element-wise difference so a bad conversion is obvious.
max_abs_diff = np.max(np.abs(logits_pt.numpy() - np.asarray(logits_fx)))
print(f"max |logits_pt - logits_fx| = {max_abs_diff:.3e}")

pt_model.save_pretrained(MODEL_PATH)