# Provenance (from the Hugging Face file viewer): uploaded by Doron Adler,
# commit 6e1c9c6 ("Updated model card"), file size 714 Bytes.
import argparse
import logging
import numpy as np
import torch
import os
from transformers import AutoConfig, FlaxAutoModelForCausalLM
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Local PyTorch checkpoint to convert, and where to write the Flax weights.
model_path = "./distilgpt2-base-pretrained-he"
save_directory = "./tmp/flax/"
config_path = os.path.join(model_path, 'config.json')

# Convert the PyTorch checkpoint to Flax: `from_pt=True` tells transformers to
# load the PyTorch weights and port them into the Flax model (slower than
# loading a native Flax checkpoint).
config = AutoConfig.from_pretrained(config_path)
logger.info("Loading PyTorch checkpoint from %s", model_path)
model = FlaxAutoModelForCausalLM.from_pretrained(model_path, from_pt=True, config=config)
logger.info("Saving converted Flax model to %s", save_directory)
model.save_pretrained(save_directory)