# File size: 714 Bytes
# Revision: 6e1c9c6
import argparse
import logging
import numpy as np
import torch
import os
from transformers import AutoConfig, FlaxAutoModelForCausalLM
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Configure the root logger: INFO and above, with each record rendered as
# "<timestamp> - <level> - <logger name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Convert a locally stored PyTorch causal-LM checkpoint into Flax weights.
#
# model_path     -- directory holding the PyTorch checkpoint and its config.json
# save_directory -- directory the converted Flax model is written to
model_path = "./distilgpt2-base-pretrained-he"
save_directory = "./tmp/flax/"
config_path = os.path.join(model_path, 'config.json')

# Read the model configuration from the checkpoint's config.json.
logger.info("Loading config from %s", config_path)
config = AutoConfig.from_pretrained(config_path)

# from_pt=True: load the weights from the PyTorch checkpoint and convert them
# into a Flax model (slower than loading native Flax weights).
logger.info("Loading PyTorch checkpoint from %s", model_path)
model = FlaxAutoModelForCausalLM.from_pretrained(model_path, from_pt=True, config=config)

# Write the converted Flax model to save_directory.
logger.info("Saving Flax model to %s", save_directory)
model.save_pretrained(save_directory)