|
import argparse |
|
import logging |
|
|
|
import numpy as np |
|
import torch |
|
import os |
|
from transformers import AutoConfig, FlaxAutoModelForCausalLM |
|
|
|
# Configure the root logger once for this script: INFO and above, each line
# tagged with a timestamp, the level name and the emitting logger's name.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
_DATE_FORMAT = "%m/%d/%Y %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_DATE_FORMAT)

# Logger named after this module, used for the script's progress messages.
logger = logging.getLogger(name=__name__)
|
|
|
# Convert a local PyTorch causal-LM checkpoint into a Flax model and save it.
# Both paths are relative to the current working directory.
model_path = "./distilgpt2-base-pretrained-he"  # source checkpoint directory
save_directory = "./tmp/flax/"  # where the converted Flax model is written

config_path = os.path.join(model_path, 'config.json')

# Fail early with an explicit message if the checkpoint's config is missing,
# rather than relying on whatever error from_pretrained raises for a bad path.
# FileNotFoundError subclasses OSError, so callers catching OSError still work.
if not os.path.isfile(config_path):
    raise FileNotFoundError(f"Model config not found at {config_path!r}")

logger.info("Loading config from %s", config_path)
config = AutoConfig.from_pretrained(config_path)

# from_pt=True asks transformers to load the weights from the PyTorch
# checkpoint in model_path and build the Flax model from them.
logger.info("Loading PyTorch weights from %s and converting to Flax", model_path)
model = FlaxAutoModelForCausalLM.from_pretrained(model_path, from_pt=True, config=config)

model.save_pretrained(save_directory)
logger.info("Saved Flax model to %s", save_directory)