NEOX / tests /neox_args /test_neoxargs_load.py
akswelh's picture
Upload 251 files
d90b3a8 verified
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
load all confings in neox/configs in order to perform validations implemented in NeoXArgs
"""
import pytest
import yaml
from ..common import get_configs_with_path
def run_neox_args_load_test(yaml_files):
from megatron.neox_arguments import NeoXArgs
yaml_list = get_configs_with_path(yaml_files)
args_loaded = NeoXArgs.from_ymls(yaml_list)
assert isinstance(args_loaded, NeoXArgs)
# initialize an empty config dictionary to be filled by yamls
config = dict()
# iterate of all to be loaded yaml files
for conf_file_name in yaml_list:
# load file
with open(conf_file_name) as conf_file:
conf = yaml.load(conf_file, Loader=yaml.FullLoader)
# check for key duplicates and load values
for conf_key, conf_value in conf.items():
if conf_key in config:
raise ValueError(
f"Conf file {conf_file_name} has the following duplicate keys with previously loaded file: {conf_key}"
)
conf_key_converted = conf_key.replace(
"-", "_"
) # TODO remove replace and update configuration files?
config[conf_key_converted] = conf_value
# validate that neox args has the same value as specified in the config (if specified in the config)
for k, v in config.items():
neox_args_value = getattr(args_loaded, k)
assert v == neox_args_value, (
"loaded neox args value "
+ str(k)
+ " == "
+ str(neox_args_value)
+ " different from config file "
+ str(v)
)
@pytest.mark.cpu
def test_neoxargs_load_arguments_125M_local_setup():
"""
verify 125M.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["125M.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_load_arguments_125M_local_setup_text_generation():
"""
verify 125M.yml can be loaded together with text generation without raising validation errors
"""
run_neox_args_load_test(
["125M.yml", "local_setup.yml", "text_generation.yml", "cpu_mock_config.yml"]
)
@pytest.mark.cpu
def test_neoxargs_load_arguments_350M_local_setup():
"""
verify 350M.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["350M.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_load_arguments_760M_local_setup():
"""
verify 760M.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["760M.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_load_arguments_2_7B_local_setup():
"""
verify 2-7B.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["2-7B.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_load_arguments_6_7B_local_setup():
"""
verify 6-7B.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["6-7B.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_load_arguments_13B_local_setup():
"""
verify 13B.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["13B.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_load_arguments_1_3B_local_setup():
"""
verify 1-3B.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["1-3B.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_load_arguments_175B_local_setup():
"""
verify 13B.yml can be loaded without raising validation errors
"""
run_neox_args_load_test(["175B.yml", "local_setup.yml", "cpu_mock_config.yml"])
@pytest.mark.cpu
def test_neoxargs_fail_instantiate_without_required_params():
"""
verify assertion error if required arguments are not provided
"""
try:
run_neox_args_load_test(["local_setup.yml"])
assert False
except Exception as e:
assert True
@pytest.mark.cpu
def test_neoxargs_fail_instantiate_without_any_params():
"""
verify assertion error if required arguments are not provided
"""
from megatron.neox_arguments import NeoXArgs
try:
args_loaded = NeoXArgs()
assert False
except Exception as e:
assert True