File size: 2,703 Bytes
d90b3a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
plausibility check for the usage of neox_args in the megatron codebase
"""
import pytest
import re
from ..common import get_root_directory
@pytest.mark.cpu
def test_neoxargs_usage() -> None:
    """
    Check that every ``neox_args.<name>`` used under ``megatron/`` is declared.

    Every ``*.py`` file below the ``megatron`` directory (except the files in
    ``skipped_files``) is scanned for the text ``neox_args.<name>``. Each
    ``<name>`` has to be either a dataclass field of ``NeoXArgs`` or an entry
    of ``exclude``. Every unknown name is printed, and the test fails once all
    files have been scanned, listing the offending usages.
    """
    from megatron.neox_arguments import NeoXArgs

    neox_args_attributes = set(NeoXArgs.__dataclass_fields__)

    # we exclude a number of properties (implemented with the @property
    # decorator) or functions that we know exist.
    # NOTE: the entries containing "[" or "]" cannot be produced by the pattern
    # below (it stops in front of a bracket); they are kept unchanged.
    exclude = {
        "params_dtype",
        "deepspeed_config",
        "get",
        "pop",
        "get_deepspeed_main_args",
        'optimizer["params"]',
        "attention_config[layer_number]",
        "adlr_autoresume_object",
        "update_value",
        "all_config",
        "tensorboard_writer",
        "tokenizer",
        "train_batch_size]",
        "items",
        "configure_distributed_args",
        "build_tokenizer",
        "attention_config[i]",
        "print",
        "update",
    }

    # files that are not checked
    skipped_files = {"text_generation_utils.py", "train_tokenizer.py"}

    # Shortest run of at least two characters that follows "neox_args." and is
    # directly followed by a terminator. The terminators are written out one by
    # one: the former class contained the range "+-/", which implicitly stood
    # for "+", ",", "-", "." and "/". The "." terminator is what cuts
    # "neox_args.a.b" down to "a", so it has to stay in the set.
    usage_pattern = re.compile(
        r"(?<=neox_args\.).{2,}?(?=[\s(){}+,\-./*;:=\[\]])"
    )

    undeclared = []

    # test file by file; sorted so that the report order is reproducible
    for filename in sorted((get_root_directory() / "megatron").glob("**/*.py")):
        if filename.name in skipped_files:
            continue

        # explicit encoding: do not depend on the locale's default codec
        file_contents = filename.read_text(encoding="utf-8")

        # compare every usage against the declared and the excluded names
        for match in usage_pattern.findall(file_contents):
            if match in neox_args_attributes or match in exclude:
                continue
            print(
                f"(arguments used not found in neox args): {filename.name}: {match}",
                flush=True,
            )
            undeclared.append(f"{filename.name}: {match}")

    assert not undeclared, (
        "arguments used in code but not defined in NeoXArgs: "
        + ", ".join(undeclared)
    )
|