File size: 836 Bytes
76b4cd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPTNeoForCausalLM
from transformers import GPTNeoForCausalLM
# Load the tokenizer and reuse EOS as the padding token
# (the tokenizer is not otherwise used below).
tokenizer = AutoTokenizer.from_pretrained(".")
tokenizer.pad_token = tokenizer.eos_token

# Load the Flax checkpoint from the current directory.
model_fx = FlaxGPTNeoForCausalLM.from_pretrained(".")

# Optional: upcast any bfloat16 Flax parameters to float32 and save them to ./fx.
# def to_f32(t):
#     return jax.tree_util.tree_map(
#         lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
#     )
# model_fx.params = to_f32(model_fx.params)
# model_fx.save_pretrained("./fx")

# Convert the Flax weights to PyTorch and write them into the current directory.
model_pt = GPTNeoForCausalLM.from_pretrained(".", from_flax=True)
model_pt.eval()  # make inference mode explicit for the comparison below
model_pt.save_pretrained(".")

# Sanity check: run both models on the same dummy batch
# (2 sequences of 128 copies of token id 0) and compare their logits.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)
with torch.no_grad():  # inference only; do not build an autograd graph
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)
logits_fx = model_fx(input_ids).logits
print(logits_fx)

# Report the largest element-wise discrepancy between the two implementations.
max_abs_diff = np.max(
    np.abs(logits_pt.float().numpy() - np.asarray(logits_fx, dtype=np.float32))
)
print(f"max |logits_pt - logits_fx| = {max_abs_diff:.6g}")