"""Reshard a pre-trained causal LM checkpoint.

Loads the model from a local path, then re-saves it in place with a
maximum shard size of 4.7 GB.
"""

import argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="/scratch/project_465000144/dasamuel/normistral/normistral-11b-masked-post-hf-60000",
        help="Path to the pre-trained model",
    )
    args = parser.parse_args()
    return args


def load_model(model_path: str):
    # Load the pre-trained tokenizer and model (bfloat16, on GPU, in eval mode)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=".",
        token="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",
    )
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_path,
            cache_dir=".",
            token="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",
            torch_dtype=torch.bfloat16,
        )
        .cuda()
        .eval()
    )

    # Collect every token id whose decoded form contains a newline; these can
    # serve as end-of-sequence ids for line-delimited generation.
    eos_token_ids = [
        token_id
        for token_id in range(tokenizer.vocab_size)
        if "\n" in tokenizer.decode([token_id])
    ]

    # Determine the maximum context length from whichever config attribute
    # this model architecture exposes.
    if hasattr(model.config, "n_positions"):
        max_length = model.config.n_positions
    elif hasattr(model.config, "max_position_embeddings"):
        max_length = model.config.max_position_embeddings
    elif hasattr(model.config, "max_length"):
        max_length = model.config.max_length
    elif hasattr(model.config, "n_ctx"):
        max_length = model.config.n_ctx
    else:
        max_length = 4096  # Default value

    return {
        "name": model_path.split("/")[-1],
        "tokenizer": tokenizer,
        "model": model,
        "eos_token_ids": eos_token_ids,
        "max_length": max_length,
    }


def main():
    args = parse_args()
    model = load_model(args.model_name_or_path)

    # Re-save the checkpoint in place, splitting the weights into shards of
    # at most 4.7 GB each.
    model["model"].save_pretrained(
        args.model_name_or_path,
        max_shard_size="4.7GB",
    )


if __name__ == "__main__":
    main()
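
# Example invocation (a sketch; the script filename and checkpoint path below are
# hypothetical, and omitting the flag falls back to the default in parse_args()):
#
#   python reshard_checkpoint.py --model_name_or_path /path/to/local/checkpoint
#
# After the script finishes, the checkpoint directory should contain weight shards
# no larger than 4.7 GB plus an updated shard index file; the exact file names
# depend on the installed transformers version and its serialization defaults.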