# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import tempfile
import unittest
from pathlib import Path

import transformers
from transformers.commands.add_new_model_like import (
    ModelPatterns,
    _re_class_func,
    add_content_to_file,
    add_content_to_text,
    clean_frameworks_in_init,
    duplicate_doc_file,
    duplicate_module,
    filter_framework_files,
    find_base_model_checkpoint,
    get_model_files,
    get_module_from_file,
    parse_module_content,
    replace_model_patterns,
    retrieve_info_for_model,
    retrieve_model_classes,
    simplify_replacements,
)
from transformers.testing_utils import require_flax, require_tf, require_torch


# Expected files of each model folder, relative to the repository root. Each set is built by
# prefixing the model folder to the bare file names, so the lists below only spell out file names.
BERT_MODEL_FILES = {
    "src/transformers/models/bert/" + file_name
    for file_name in (
        "__init__.py",
        "configuration_bert.py",
        "tokenization_bert.py",
        "tokenization_bert_fast.py",
        "tokenization_bert_tf.py",
        "modeling_bert.py",
        "modeling_flax_bert.py",
        "modeling_tf_bert.py",
        "convert_bert_original_tf_checkpoint_to_pytorch.py",
        "convert_bert_original_tf2_checkpoint_to_pytorch.py",
        "convert_bert_pytorch_checkpoint_to_original_tf.py",
        "convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py",
    )
}

VIT_MODEL_FILES = {
    "src/transformers/models/vit/" + file_name
    for file_name in (
        "__init__.py",
        "configuration_vit.py",
        "convert_dino_to_pytorch.py",
        "convert_vit_timm_to_pytorch.py",
        "feature_extraction_vit.py",
        "image_processing_vit.py",
        "modeling_vit.py",
        "modeling_tf_vit.py",
        "modeling_flax_vit.py",
    )
}

WAV2VEC2_MODEL_FILES = {
    "src/transformers/models/wav2vec2/" + file_name
    for file_name in (
        "__init__.py",
        "configuration_wav2vec2.py",
        "convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py",
        "convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py",
        "feature_extraction_wav2vec2.py",
        "modeling_wav2vec2.py",
        "modeling_tf_wav2vec2.py",
        "modeling_flax_wav2vec2.py",
        "processing_wav2vec2.py",
        "tokenization_wav2vec2.py",
    )
}

# Repository root: two levels above the imported `transformers` package directory (i.e. the folder
# that contains `src/transformers`). NOTE(review): this presumably requires transformers to be
# imported from a source checkout (e.g. an editable install); the `relative_to(REPO_PATH)` checks
# on `docs/` and `tests/` paths in this file would not hold for a site-packages install — confirm.
REPO_PATH = Path(transformers.__path__[0]).parent.parent


@require_torch
@require_tf
@require_flax
class TestAddNewModelLike(unittest.TestCase):
    def init_file(self, file_name, content):
        """Overwrite `file_name` with the text `content`, encoded as UTF-8."""
        Path(file_name).write_text(content, encoding="utf-8")

    def check_result(self, file_name, expected_result):
        """Assert that the UTF-8 text stored in `file_name` is exactly `expected_result`."""
        actual = Path(file_name).read_text(encoding="utf-8")
        self.assertEqual(actual, expected_result)

    def test_re_class_func(self):
        # Each declaration line is paired with the name the first group of `_re_class_func` must capture.
        cases = [
            ("def my_function(x, y):", "my_function"),
            ("class MyClass:", "MyClass"),
            ("class MyClass(SuperClass):", "MyClass"),
        ]
        for declaration, expected_name in cases:
            match = _re_class_func.search(declaration)
            self.assertEqual(match.groups()[0], expected_name)

    def test_model_patterns_defaults(self):
        # Only a model name and a checkpoint are given: the naming patterns are derived from the name,
        # while the preprocessing classes are left unset.
        patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")

        derived = {
            "model_type": "gpt-new-new",
            "model_lower_cased": "gpt_new_new",
            "model_camel_cased": "GPTNewNew",
            "model_upper_cased": "GPT_NEW_NEW",
            "config_class": "GPTNewNewConfig",
        }
        for attribute, expected in derived.items():
            self.assertEqual(getattr(patterns, attribute), expected)

        for attribute in ("tokenizer_class", "feature_extractor_class", "processor_class"):
            self.assertIsNone(getattr(patterns, attribute))

    def test_parse_module_content(self):
        # A small module with a one-line constant, a multi-line constant, a function and a class preceded
        # by a "# Copied from" comment, which is expected to stay in the same part as the class below it.
        source = """SOME_CONSTANT = a constant

CONSTANT_DEFINED_ON_SEVERAL_LINES = [
    first_item,
    second_item
]

def function(args):
    some code

# Copied from transformers.some_module
class SomeClass:
    some code
"""

        parts = parse_module_content(source)
        self.assertEqual(
            parts,
            [
                "SOME_CONSTANT = a constant\n",
                "CONSTANT_DEFINED_ON_SEVERAL_LINES = [\n    first_item,\n    second_item\n]",
                "",
                "def function(args):\n    some code\n",
                "# Copied from transformers.some_module\nclass SomeClass:\n    some code\n",
            ],
        )

    def test_add_content_to_text(self):
        # `add_content_to_text` inserts `line` right before (`add_before`) or right after (`add_after`) the
        # line matching the given pattern. The pattern can be a substring, a complete line (with
        # `exact_match=True`) or a compiled regex. When nothing matches, the text comes back unchanged.
        test_text = """all_configs = {
    "gpt": "GPTConfig",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""

        expected = """all_configs = {
    "gpt": "GPTConfig",
    "gpt2": "GPT2Config",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""
        line = '    "gpt2": "GPT2Config",'

        # With `exact_match=True`, a bare "bert"/"gpt" is not a complete line, so nothing is inserted.
        # The regex patterns are raw strings: "\s" is an invalid escape sequence in a regular string literal
        # (DeprecationWarning, and SyntaxWarning since Python 3.12).
        self.assertEqual(add_content_to_text(test_text, line, add_before="bert"), expected)
        self.assertEqual(add_content_to_text(test_text, line, add_before="bert", exact_match=True), test_text)
        self.assertEqual(
            add_content_to_text(test_text, line, add_before='    "bert": "BertConfig",', exact_match=True), expected
        )
        self.assertEqual(add_content_to_text(test_text, line, add_before=re.compile(r'^\s*"bert":')), expected)

        self.assertEqual(add_content_to_text(test_text, line, add_after="gpt"), expected)
        self.assertEqual(add_content_to_text(test_text, line, add_after="gpt", exact_match=True), test_text)
        self.assertEqual(
            add_content_to_text(test_text, line, add_after='    "gpt": "GPTConfig",', exact_match=True), expected
        )
        self.assertEqual(add_content_to_text(test_text, line, add_after=re.compile(r'^\s*"gpt":')), expected)

    def test_add_content_to_file(self):
        # Same scenarios as for `add_content_to_text`, but going through a file on disk: every case rewrites
        # the file from `test_text`, applies `add_content_to_file` and checks the resulting file content.
        test_text = """all_configs = {
    "gpt": "GPTConfig",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""

        expected = """all_configs = {
    "gpt": "GPTConfig",
    "gpt2": "GPT2Config",
    "bert": "BertConfig",
    "t5": "T5Config",
}"""
        line = '    "gpt2": "GPT2Config",'

        # (keyword arguments for `add_content_to_file`, expected file content afterwards).
        # With `exact_match=True`, a bare "bert"/"gpt" is not a complete line, so the file is left unchanged.
        # Regex patterns are raw strings since "\s" is an invalid escape sequence in a regular string literal.
        cases = [
            ({"add_before": "bert"}, expected),
            ({"add_before": "bert", "exact_match": True}, test_text),
            ({"add_before": '    "bert": "BertConfig",', "exact_match": True}, expected),
            ({"add_before": re.compile(r'^\s*"bert":')}, expected),
            ({"add_after": "gpt"}, expected),
            ({"add_after": "gpt", "exact_match": True}, test_text),
            ({"add_after": '    "gpt": "GPTConfig",', "exact_match": True}, expected),
            ({"add_after": re.compile(r'^\s*"gpt":')}, expected),
        ]

        with tempfile.TemporaryDirectory() as tmp_dir:
            file_name = os.path.join(tmp_dir, "code.py")

            for kwargs, expected_content in cases:
                # `subTest` reports which set of arguments failed instead of a bare line number.
                with self.subTest(**kwargs):
                    self.init_file(file_name, test_text)
                    add_content_to_file(file_name, line, **kwargs)
                    self.check_result(file_name, expected_content)

    def test_simplify_replacements(self):
        # (replacements given, replacements expected back). A replacement that is implied by a shorter one
        # ("BertConfig"->"NewBertConfig" follows from "Bert"->"NewBert") is dropped.
        cases = [
            ([("Bert", "NewBert")], [("Bert", "NewBert")]),
            (
                [("Bert", "NewBert"), ("bert", "new-bert")],
                [("Bert", "NewBert"), ("bert", "new-bert")],
            ),
            (
                [("BertConfig", "NewBertConfig"), ("Bert", "NewBert"), ("bert", "new-bert")],
                [("Bert", "NewBert"), ("bert", "new-bert")],
            ),
        ]
        for replacements, simplified in cases:
            self.assertEqual(simplify_replacements(replacements), simplified)

    def test_replace_model_patterns(self):
        # `replace_model_patterns` rewrites every spelling of the old model (camel-cased, lower-cased,
        # upper-cased, model type, checkpoint) into the new one. It also returns the applied replacements as a
        # string like "Bert->NewBert,bert->new_bert", which is empty when they cannot be expressed that way.
        bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
        new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
        bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    model_type = "bert"

BERT_CONSTANT = "value"
'''
        bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NewBertConfig
    load_tf_weights = load_tf_weights_in_new_bert
    base_model_prefix = "new_bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    model_type = "new-bert"

NEW_BERT_CONSTANT = "value"
'''

        bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
        self.assertEqual(bert_converted, bert_expected)
        # Replacements are empty here since bert has been replaced by new_bert in some instances and by new-bert
        # (the model type) in others.
        self.assertEqual(replacements, "")

        # If we remove the model type, we will get replacements
        bert_test = bert_test.replace('    model_type = "bert"\n', "")
        bert_expected = bert_expected.replace('    model_type = "new-bert"\n', "")
        bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
        self.assertEqual(bert_converted, bert_expected)
        self.assertEqual(replacements, "BERT->NEW_BERT,Bert->NewBert,bert->new_bert")

        gpt_model_patterns = ModelPatterns("GPT2", "gpt2")
        new_gpt_model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
        gpt_test = '''class GPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

GPT2_CONSTANT = "value"
'''

        gpt_expected = '''class GPTNewNewPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTNewNewConfig
    load_tf_weights = load_tf_weights_in_gpt_new_new
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

GPT_NEW_NEW_CONSTANT = "value"
'''

        gpt_converted, replacements = replace_model_patterns(gpt_test, gpt_model_patterns, new_gpt_model_patterns)
        self.assertEqual(gpt_converted, gpt_expected)
        # Replacements are empty here since GPT2 has been replaced by GPTNewNew in some instances and GPT_NEW_NEW
        # in others.
        self.assertEqual(replacements, "")

        # A custom camel-cased name ("Roberta" instead of the default derived from "RoBERTa") is honored, and
        # the existing "Copied from" comment is adapted to the new name.
        roberta_model_patterns = ModelPatterns("RoBERTa", "roberta-base", model_camel_cased="Roberta")
        new_roberta_model_patterns = ModelPatterns(
            "RoBERTa-New", "huggingface/roberta-new-base", model_camel_cased="RobertaNew"
        )
        roberta_test = '''# Copied from transformers.models.bert.BertModel with Bert->Roberta
class RobertaModel(RobertaPreTrainedModel):
    """ The base RoBERTa model. """
    checkpoint = roberta-base
    base_model_prefix = "roberta"
        '''
        roberta_expected = '''# Copied from transformers.models.bert.BertModel with Bert->RobertaNew
class RobertaNewModel(RobertaNewPreTrainedModel):
    """ The base RoBERTa-New model. """
    checkpoint = huggingface/roberta-new-base
    base_model_prefix = "roberta_new"
        '''
        # Only the converted text is checked for this case; the replacements string is not asserted.
        roberta_converted, _ = replace_model_patterns(
            roberta_test, roberta_model_patterns, new_roberta_model_patterns
        )
        self.assertEqual(roberta_converted, roberta_expected)

    def test_get_module_from_file(self):
        # The dotted module name is built from the (innermost) `transformers` directory of the path onwards.
        cases = [
            (
                "/git/transformers/src/transformers/models/bert/modeling_tf_bert.py",
                "transformers.models.bert.modeling_tf_bert",
            ),
            (
                "/transformers/models/gpt2/modeling_gpt2.py",
                "transformers.models.gpt2.modeling_gpt2",
            ),
        ]
        for path, module in cases:
            self.assertEqual(get_module_from_file(path), module)

        # A path with no `transformers` directory cannot be mapped to a module.
        with self.assertRaises(ValueError):
            get_module_from_file("/models/gpt2/modeling_gpt2.py")

    def test_duplicate_module(self):
        # `duplicate_module` writes `new_bert_module.py` next to `bert_module.py` with the model patterns replaced.
        # By default the copied class is prefixed with a "# Copied from" comment; `add_copied_from=False` skips it.
        old_patterns = ModelPatterns("Bert", "bert-base-cased")
        new_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
        bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

BERT_CONSTANT = "value"
'''
        bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NewBertConfig
    load_tf_weights = load_tf_weights_in_new_bert
    base_model_prefix = "new_bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

NEW_BERT_CONSTANT = "value"
'''
        bert_expected_with_copied_from = (
            "# Copied from transformers.bert_module.TFBertPreTrainedModel with Bert->NewBert,bert->new_bert\n"
            + bert_expected
        )
        # (extra keyword arguments for `duplicate_module`, content expected in the duplicated file).
        scenarios = (
            ({}, bert_expected_with_copied_from),
            ({"add_copied_from": False}, bert_expected),
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            # The folder is named `transformers` so the duplicated code can refer to `transformers.bert_module`,
            # as seen in the expected "Copied from" line.
            work_dir = os.path.join(tmp_dir, "transformers")
            os.makedirs(work_dir)
            source_file = os.path.join(work_dir, "bert_module.py")
            duplicated_file = os.path.join(work_dir, "new_bert_module.py")

            for extra_kwargs, expected_content in scenarios:
                self.init_file(source_file, bert_test)
                duplicate_module(source_file, old_patterns, new_patterns, **extra_kwargs)
                self.check_result(duplicated_file, expected_content)

    def test_duplicate_module_with_copied_from(self):
        # When the source class already carries a "Copied from" comment, `duplicate_module` adapts that comment
        # (`Xxx->Bert` becomes `Xxx->NewBert`) and adds no second one, whatever the value of `add_copied_from`.
        bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
        new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
        bert_test = '''# Copied from transformers.models.xxx.XxxModel with Xxx->Bert
class TFBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

BERT_CONSTANT = "value"
'''
        bert_expected = '''# Copied from transformers.models.xxx.XxxModel with Xxx->NewBert
class TFNewBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NewBertConfig
    load_tf_weights = load_tf_weights_in_new_bert
    base_model_prefix = "new_bert"
    is_parallelizable = True
    supports_gradient_checkpointing = True

NEW_BERT_CONSTANT = "value"
'''
        with tempfile.TemporaryDirectory() as tmp_dir:
            work_dir = os.path.join(tmp_dir, "transformers")
            os.makedirs(work_dir)
            file_name = os.path.join(work_dir, "bert_module.py")
            dest_file_name = os.path.join(work_dir, "new_bert_module.py")

            self.init_file(file_name, bert_test)
            duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns)
            # There should not be a new "Copied from" statement; the existing one should be adapted.
            self.check_result(dest_file_name, bert_expected)

            self.init_file(file_name, bert_test)
            duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False)
            self.check_result(dest_file_name, bert_expected)

    def test_filter_framework_files(self):
        pt_file = "modeling_bert.py"
        tf_file = "modeling_tf_bert.py"
        flax_file = "modeling_flax_bert.py"
        config_file = "configuration_bert.py"
        files = [pt_file, tf_file, flax_file, config_file]

        # Without frameworks the list comes back equal (order included); with all of them, nothing is dropped.
        self.assertEqual(filter_framework_files(files), files)
        self.assertEqual(set(filter_framework_files(files, ["pt", "tf", "flax"])), set(files))

        # The framework-agnostic configuration file is kept for every selection.
        per_framework = {"pt": pt_file, "tf": tf_file, "flax": flax_file}
        for frameworks in (["pt"], ["tf"], ["flax"], ["pt", "tf"], ["tf", "flax"], ["pt", "flax"]):
            expected = {config_file} | {per_framework[framework] for framework in frameworks}
            self.assertEqual(set(filter_framework_files(files, frameworks)), expected)

    def test_get_model_files(self):
        def relative(path):
            # Returned paths live under the repository; compare them relative to its root.
            return str(Path(path).relative_to(REPO_PATH))

        # (model type, doc file, model files, test files) expected when no framework filter is given.
        expectations = [
            (
                "bert",
                "docs/source/en/model_doc/bert.mdx",
                BERT_MODEL_FILES,
                {
                    "tests/models/bert/test_tokenization_bert.py",
                    "tests/models/bert/test_modeling_bert.py",
                    "tests/models/bert/test_modeling_tf_bert.py",
                    "tests/models/bert/test_modeling_flax_bert.py",
                },
            ),
            (
                "vit",
                "docs/source/en/model_doc/vit.mdx",
                VIT_MODEL_FILES,
                {
                    "tests/models/vit/test_image_processing_vit.py",
                    "tests/models/vit/test_modeling_vit.py",
                    "tests/models/vit/test_modeling_tf_vit.py",
                    "tests/models/vit/test_modeling_flax_vit.py",
                },
            ),
            (
                "wav2vec2",
                "docs/source/en/model_doc/wav2vec2.mdx",
                WAV2VEC2_MODEL_FILES,
                {
                    "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
                    "tests/models/wav2vec2/test_modeling_wav2vec2.py",
                    "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
                    "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
                    "tests/models/wav2vec2/test_processor_wav2vec2.py",
                    "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
                },
            ),
        ]

        for model_type, doc_file, model_files, test_files in expectations:
            files = get_model_files(model_type)
            self.assertEqual(relative(files["doc_file"]), doc_file)
            self.assertEqual({relative(f) for f in files["model_files"]}, model_files)
            self.assertEqual(files["module_name"], model_type)
            self.assertEqual({relative(f) for f in files["test_files"]}, test_files)

    def test_get_model_files_only_pt(self):
        def relative(path):
            # Returned paths live under the repository; compare them relative to its root.
            return str(Path(path).relative_to(REPO_PATH))

        # (model type, doc file, model files, test files) expected when only PyTorch is selected: the TF and
        # Flax modeling files and tests are filtered out, everything else is kept.
        expectations = [
            (
                "bert",
                "docs/source/en/model_doc/bert.mdx",
                BERT_MODEL_FILES
                - {
                    "src/transformers/models/bert/modeling_tf_bert.py",
                    "src/transformers/models/bert/modeling_flax_bert.py",
                },
                {
                    "tests/models/bert/test_tokenization_bert.py",
                    "tests/models/bert/test_modeling_bert.py",
                },
            ),
            (
                "vit",
                "docs/source/en/model_doc/vit.mdx",
                VIT_MODEL_FILES
                - {
                    "src/transformers/models/vit/modeling_tf_vit.py",
                    "src/transformers/models/vit/modeling_flax_vit.py",
                },
                {
                    "tests/models/vit/test_image_processing_vit.py",
                    "tests/models/vit/test_modeling_vit.py",
                },
            ),
            (
                "wav2vec2",
                "docs/source/en/model_doc/wav2vec2.mdx",
                WAV2VEC2_MODEL_FILES
                - {
                    "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py",
                    "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py",
                },
                {
                    "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
                    "tests/models/wav2vec2/test_modeling_wav2vec2.py",
                    "tests/models/wav2vec2/test_processor_wav2vec2.py",
                    "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
                },
            ),
        ]

        for model_type, doc_file, model_files, test_files in expectations:
            files = get_model_files(model_type, frameworks=["pt"])
            self.assertEqual(relative(files["doc_file"]), doc_file)
            self.assertEqual({relative(f) for f in files["model_files"]}, model_files)
            self.assertEqual(files["module_name"], model_type)
            self.assertEqual({relative(f) for f in files["test_files"]}, test_files)

    def test_get_model_files_tf_and_flax(self):
        def relative(path):
            # Returned paths live under the repository; compare them relative to its root.
            return str(Path(path).relative_to(REPO_PATH))

        # (model type, doc file, model files, test files) expected when TF and Flax are selected: only the
        # PyTorch modeling file and its test are filtered out.
        expectations = [
            (
                "bert",
                "docs/source/en/model_doc/bert.mdx",
                BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_bert.py"},
                {
                    "tests/models/bert/test_tokenization_bert.py",
                    "tests/models/bert/test_modeling_tf_bert.py",
                    "tests/models/bert/test_modeling_flax_bert.py",
                },
            ),
            (
                "vit",
                "docs/source/en/model_doc/vit.mdx",
                VIT_MODEL_FILES - {"src/transformers/models/vit/modeling_vit.py"},
                {
                    "tests/models/vit/test_image_processing_vit.py",
                    "tests/models/vit/test_modeling_tf_vit.py",
                    "tests/models/vit/test_modeling_flax_vit.py",
                },
            ),
            (
                "wav2vec2",
                "docs/source/en/model_doc/wav2vec2.mdx",
                WAV2VEC2_MODEL_FILES - {"src/transformers/models/wav2vec2/modeling_wav2vec2.py"},
                {
                    "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
                    "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
                    "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
                    "tests/models/wav2vec2/test_processor_wav2vec2.py",
                    "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
                },
            ),
        ]

        for model_type, doc_file, model_files, test_files in expectations:
            files = get_model_files(model_type, frameworks=["tf", "flax"])
            self.assertEqual(relative(files["doc_file"]), doc_file)
            self.assertEqual({relative(f) for f in files["model_files"]}, model_files)
            self.assertEqual(files["module_name"], model_type)
            self.assertEqual({relative(f) for f in files["test_files"]}, test_files)

    def test_find_base_model_checkpoint(self):
        # Model type -> checkpoint expected to be picked as the base one.
        for model_type, checkpoint in (("bert", "bert-base-uncased"), ("gpt2", "gpt2")):
            self.assertEqual(find_base_model_checkpoint(model_type), checkpoint)

    def test_retrieve_model_classes(self):
        all_expected = {
            "pt": {"GPT2ForTokenClassification", "GPT2Model", "GPT2LMHeadModel", "GPT2ForSequenceClassification"},
            "tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"},
            "flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"},
        }
        # (keyword arguments for `retrieve_model_classes`, frameworks expected in the result). The default call
        # returns classes for all three frameworks; a `frameworks` list restricts the result to those keys.
        calls = [
            ({}, ("pt", "tf", "flax")),
            ({"frameworks": ["pt", "tf"]}, ("pt", "tf")),
            ({"frameworks": ["tf"]}, ("tf",)),
        ]
        for kwargs, kept_frameworks in calls:
            classes = retrieve_model_classes("gpt2", **kwargs)
            self.assertEqual(
                {framework: set(names) for framework, names in classes.items()},
                {framework: all_expected[framework] for framework in kept_frameworks},
            )

    def test_retrieve_info_for_model_with_bert(self):
        def relative(path):
            # Returned paths live under the repository; compare them relative to its root.
            return str(Path(path).relative_to(REPO_PATH))

        info = retrieve_info_for_model("bert")

        pt_classes = [
            "BertForTokenClassification",
            "BertForQuestionAnswering",
            "BertForNextSentencePrediction",
            "BertForSequenceClassification",
            "BertForMaskedLM",
            "BertForMultipleChoice",
            "BertModel",
            "BertForPreTraining",
            "BertLMHeadModel",
        ]
        # Flax has `BertForCausalLM` in place of the last PyTorch class, `BertLMHeadModel`.
        flax_classes = pt_classes[:-1] + ["BertForCausalLM"]

        self.assertEqual(set(info["frameworks"]), {"pt", "tf", "flax"})
        self.assertEqual(
            {framework: set(names) for framework, names in info["model_classes"].items()},
            {
                "pt": set(pt_classes),
                "tf": {f"TF{name}" for name in pt_classes},
                "flax": {f"Flax{name}" for name in flax_classes},
            },
        )

        files = info["model_files"]
        self.assertEqual({relative(f) for f in files["model_files"]}, BERT_MODEL_FILES)
        self.assertEqual(
            {relative(f) for f in files["test_files"]},
            {
                "tests/models/bert/test_tokenization_bert.py",
                "tests/models/bert/test_modeling_bert.py",
                "tests/models/bert/test_modeling_tf_bert.py",
                "tests/models/bert/test_modeling_flax_bert.py",
            },
        )
        self.assertEqual(relative(files["doc_file"]), "docs/source/en/model_doc/bert.mdx")
        self.assertEqual(files["module_name"], "bert")

        patterns = info["model_patterns"]
        expected_patterns = {
            "model_name": "BERT",
            "checkpoint": "bert-base-uncased",
            "model_type": "bert",
            "model_lower_cased": "bert",
            "model_camel_cased": "Bert",
            "model_upper_cased": "BERT",
            "config_class": "BertConfig",
            "tokenizer_class": "BertTokenizer",
        }
        for attribute, value in expected_patterns.items():
            self.assertEqual(getattr(patterns, attribute), value)
        self.assertIsNone(patterns.feature_extractor_class)
        self.assertIsNone(patterns.processor_class)

    def test_retrieve_info_for_model_pt_tf_with_bert(self):
        """Restricting BERT to PyTorch + TensorFlow must drop the Flax classes, model file and test file."""
        info = retrieve_info_for_model("bert", frameworks=["pt", "tf"])

        def relative(paths):
            # Express absolute repository paths relative to the repo root so they compare to fixed strings.
            return {str(Path(p).relative_to(REPO_PATH)) for p in paths}

        pt_classes = {
            "BertForTokenClassification",
            "BertForQuestionAnswering",
            "BertForNextSentencePrediction",
            "BertForSequenceClassification",
            "BertForMaskedLM",
            "BertForMultipleChoice",
            "BertModel",
            "BertForPreTraining",
            "BertLMHeadModel",
        }
        expected_classes = {"pt": pt_classes, "tf": {"TF" + name for name in pt_classes}}

        self.assertEqual(set(info["frameworks"]), {"pt", "tf"})
        self.assertEqual({fw: set(names) for fw, names in info["model_classes"].items()}, expected_classes)

        files = info["model_files"]
        expected_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_flax_bert.py"}
        self.assertEqual(relative(files["model_files"]), expected_model_files)

        expected_test_files = {
            "tests/models/bert/test_tokenization_bert.py",
            "tests/models/bert/test_modeling_bert.py",
            "tests/models/bert/test_modeling_tf_bert.py",
        }
        self.assertEqual(relative(files["test_files"]), expected_test_files)

        self.assertEqual(str(Path(files["doc_file"]).relative_to(REPO_PATH)), "docs/source/en/model_doc/bert.mdx")
        self.assertEqual(files["module_name"], "bert")

        patterns = info["model_patterns"]
        expected_patterns = {
            "model_name": "BERT",
            "checkpoint": "bert-base-uncased",
            "model_type": "bert",
            "model_lower_cased": "bert",
            "model_camel_cased": "Bert",
            "model_upper_cased": "BERT",
            "config_class": "BertConfig",
            "tokenizer_class": "BertTokenizer",
        }
        for attribute, value in expected_patterns.items():
            self.assertEqual(getattr(patterns, attribute), value)
        # BERT is text-only: it has neither a feature extractor nor a processor.
        for attribute in ("feature_extractor_class", "processor_class"):
            self.assertIsNone(getattr(patterns, attribute))

    def test_retrieve_info_for_model_with_vit(self):
        """ViT exposes all three frameworks; only PyTorch has the masked-image-modeling head."""
        info = retrieve_info_for_model("vit")

        def relative(paths):
            # Express absolute repository paths relative to the repo root so they compare to fixed strings.
            return {str(Path(p).relative_to(REPO_PATH)) for p in paths}

        shared_classes = {"ViTForImageClassification", "ViTModel"}
        expected_classes = {
            "pt": shared_classes | {"ViTForMaskedImageModeling"},
            "tf": {"TF" + name for name in shared_classes},
            "flax": {"Flax" + name for name in shared_classes},
        }

        self.assertEqual(set(info["frameworks"]), {"pt", "tf", "flax"})
        self.assertEqual({fw: set(names) for fw, names in info["model_classes"].items()}, expected_classes)

        files = info["model_files"]
        self.assertEqual(relative(files["model_files"]), VIT_MODEL_FILES)

        expected_test_files = {
            "tests/models/vit/test_image_processing_vit.py",
            "tests/models/vit/test_modeling_vit.py",
            "tests/models/vit/test_modeling_tf_vit.py",
            "tests/models/vit/test_modeling_flax_vit.py",
        }
        self.assertEqual(relative(files["test_files"]), expected_test_files)

        self.assertEqual(str(Path(files["doc_file"]).relative_to(REPO_PATH)), "docs/source/en/model_doc/vit.mdx")
        self.assertEqual(files["module_name"], "vit")

        patterns = info["model_patterns"]
        expected_patterns = {
            "model_name": "ViT",
            "checkpoint": "google/vit-base-patch16-224-in21k",
            "model_type": "vit",
            "model_lower_cased": "vit",
            "model_camel_cased": "ViT",
            "model_upper_cased": "VIT",
            "config_class": "ViTConfig",
            "feature_extractor_class": "ViTFeatureExtractor",
            "image_processor_class": "ViTImageProcessor",
        }
        for attribute, value in expected_patterns.items():
            self.assertEqual(getattr(patterns, attribute), value)
        # ViT is vision-only: no tokenizer and no combined processor.
        for attribute in ("tokenizer_class", "processor_class"):
            self.assertIsNone(getattr(patterns, attribute))

    def test_retrieve_info_for_model_with_wav2vec2(self):
        """Wav2Vec2 has a full PyTorch head set, but only a subset of classes in TF and Flax."""
        info = retrieve_info_for_model("wav2vec2")

        def relative(paths):
            # Express absolute repository paths relative to the repo root so they compare to fixed strings.
            return {str(Path(p).relative_to(REPO_PATH)) for p in paths}

        expected_classes = {
            "pt": {
                "Wav2Vec2Model",
                "Wav2Vec2ForPreTraining",
                "Wav2Vec2ForAudioFrameClassification",
                "Wav2Vec2ForCTC",
                "Wav2Vec2ForMaskedLM",
                "Wav2Vec2ForSequenceClassification",
                "Wav2Vec2ForXVector",
            },
            "tf": {"TFWav2Vec2Model"},
            "flax": {"FlaxWav2Vec2Model", "FlaxWav2Vec2ForPreTraining"},
        }

        self.assertEqual(set(info["frameworks"]), {"pt", "tf", "flax"})
        self.assertEqual({fw: set(names) for fw, names in info["model_classes"].items()}, expected_classes)

        files = info["model_files"]
        self.assertEqual(relative(files["model_files"]), WAV2VEC2_MODEL_FILES)

        expected_test_files = {
            "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
            "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
            "tests/models/wav2vec2/test_processor_wav2vec2.py",
            "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
        }
        self.assertEqual(relative(files["test_files"]), expected_test_files)

        self.assertEqual(
            str(Path(files["doc_file"]).relative_to(REPO_PATH)), "docs/source/en/model_doc/wav2vec2.mdx"
        )
        self.assertEqual(files["module_name"], "wav2vec2")

        patterns = info["model_patterns"]
        expected_patterns = {
            "model_name": "Wav2Vec2",
            "checkpoint": "facebook/wav2vec2-base-960h",
            "model_type": "wav2vec2",
            "model_lower_cased": "wav2vec2",
            "model_camel_cased": "Wav2Vec2",
            "model_upper_cased": "WAV_2_VEC_2",
            "config_class": "Wav2Vec2Config",
            "feature_extractor_class": "Wav2Vec2FeatureExtractor",
            "processor_class": "Wav2Vec2Processor",
            "tokenizer_class": "Wav2Vec2CTCTokenizer",
        }
        for attribute, value in expected_patterns.items():
            self.assertEqual(getattr(patterns, attribute), value)

    def test_clean_frameworks_in_init_with_gpt(self):
        test_init = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available

_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
    "tokenization_gpt2": ["GPT2Tokenizer"],
}

try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]

if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
    from .tokenization_gpt2 import GPT2Tokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_gpt2_fast import GPT2TokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_gpt2 import TFGPT2Model

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_gpt2 import FlaxGPT2Model

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        init_no_tokenizer = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available

_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]

if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_gpt2 import TFGPT2Model

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_gpt2 import FlaxGPT2Model

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        init_pt_only = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_tokenizers_available, is_torch_available

_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
    "tokenization_gpt2": ["GPT2Tokenizer"],
}

try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]

if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
    from .tokenization_gpt2 import GPT2Tokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_gpt2_fast import GPT2TokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        init_pt_only_no_tokenizer = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_torch_available

_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gpt2"] = ["GPT2Model"]

if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gpt2 import GPT2Model

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        with tempfile.TemporaryDirectory() as tmp_dir:
            file_name = os.path.join(tmp_dir, "../__init__.py")

            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, keep_processing=False)
            self.check_result(file_name, init_no_tokenizer)

            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, frameworks=["pt"])
            self.check_result(file_name, init_pt_only)

            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False)
            self.check_result(file_name, init_pt_only_no_tokenizer)

    def test_clean_frameworks_in_init_with_vit(self):
        test_init = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available

_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_vit"] = ["ViTImageProcessor"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_vit"] = ["TFViTModel"]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]

if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_vit import ViTImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_vit import TFViTModel

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_vit import FlaxViTModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        init_no_feature_extractor = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available

_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_vit"] = ["TFViTModel"]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]

if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_vit import TFViTModel

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_vit import FlaxViTModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        init_pt_only = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_torch_available, is_vision_available

_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_vit"] = ["ViTImageProcessor"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]

if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_vit import ViTImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        init_pt_only_no_feature_extractor = """
from typing import TYPE_CHECKING

from ...utils import _LazyModule, is_torch_available

_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vit"] = ["ViTModel"]

if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit import ViTModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
"""

        with tempfile.TemporaryDirectory() as tmp_dir:
            file_name = os.path.join(tmp_dir, "../__init__.py")

            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, keep_processing=False)
            self.check_result(file_name, init_no_feature_extractor)

            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, frameworks=["pt"])
            self.check_result(file_name, init_pt_only)

            self.init_file(file_name, test_init)
            clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False)
            self.check_result(file_name, init_pt_only_no_feature_extractor)

    def test_duplicate_doc_file(self):
        test_doc = """
# GPT2

## Overview

Overview of the model.

## GPT2Config

[[autodoc]] GPT2Config

## GPT2Tokenizer

[[autodoc]] GPT2Tokenizer
    - save_vocabulary

## GPT2TokenizerFast

[[autodoc]] GPT2TokenizerFast

## GPT2 specific outputs

[[autodoc]] models.gpt2.modeling_gpt2.GPT2DoubleHeadsModelOutput

[[autodoc]] models.gpt2.modeling_tf_gpt2.TFGPT2DoubleHeadsModelOutput

## GPT2Model

[[autodoc]] GPT2Model
    - forward

## TFGPT2Model

[[autodoc]] TFGPT2Model
    - call

## FlaxGPT2Model

[[autodoc]] FlaxGPT2Model
    - __call__

"""
        test_new_doc = """
# GPT-New New

## Overview

The GPT-New New model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).


## GPTNewNewConfig

[[autodoc]] GPTNewNewConfig

## GPTNewNewTokenizer

[[autodoc]] GPTNewNewTokenizer
    - save_vocabulary

## GPTNewNewTokenizerFast

[[autodoc]] GPTNewNewTokenizerFast

## GPTNewNew specific outputs

[[autodoc]] models.gpt_new_new.modeling_gpt_new_new.GPTNewNewDoubleHeadsModelOutput

[[autodoc]] models.gpt_new_new.modeling_tf_gpt_new_new.TFGPTNewNewDoubleHeadsModelOutput

## GPTNewNewModel

[[autodoc]] GPTNewNewModel
    - forward

## TFGPTNewNewModel

[[autodoc]] TFGPTNewNewModel
    - call

## FlaxGPTNewNewModel

[[autodoc]] FlaxGPTNewNewModel
    - __call__

"""

        with tempfile.TemporaryDirectory() as tmp_dir:
            doc_file = os.path.join(tmp_dir, "gpt2.mdx")
            new_doc_file = os.path.join(tmp_dir, "gpt-new-new.mdx")

            gpt2_model_patterns = ModelPatterns("GPT2", "gpt2", tokenizer_class="GPT2Tokenizer")
            new_model_patterns = ModelPatterns(
                "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPTNewNewTokenizer"
            )

            self.init_file(doc_file, test_doc)
            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
            self.check_result(new_doc_file, test_new_doc)

            test_new_doc_pt_only = test_new_doc.replace(
                """
## TFGPTNewNewModel

[[autodoc]] TFGPTNewNewModel
    - call

## FlaxGPTNewNewModel

[[autodoc]] FlaxGPTNewNewModel
    - __call__

""",
                "",
            )
            self.init_file(doc_file, test_doc)
            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
            self.check_result(new_doc_file, test_new_doc_pt_only)

            test_new_doc_no_tok = test_new_doc.replace(
                """
## GPTNewNewTokenizer

[[autodoc]] GPTNewNewTokenizer
    - save_vocabulary

## GPTNewNewTokenizerFast

[[autodoc]] GPTNewNewTokenizerFast
""",
                "",
            )
            new_model_patterns = ModelPatterns(
                "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPT2Tokenizer"
            )
            self.init_file(doc_file, test_doc)
            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
            print(test_new_doc_no_tok)
            self.check_result(new_doc_file, test_new_doc_no_tok)

            test_new_doc_pt_only_no_tok = test_new_doc_no_tok.replace(
                """
## TFGPTNewNewModel

[[autodoc]] TFGPTNewNewModel
    - call

## FlaxGPTNewNewModel

[[autodoc]] FlaxGPTNewNewModel
    - __call__

""",
                "",
            )
            self.init_file(doc_file, test_doc)
            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
            self.check_result(new_doc_file, test_new_doc_pt_only_no_tok)