#!/usr/bin/env python3
"""Delete a marked block of lines from Python source files, in place.

Usage: python <this script> FILE [FILE ...]

In each file, deletion starts at a line whose stripped text equals
MATCH_PATTERN_1 when the very next line's stripped text equals
MATCH_PATTERN_2. Deletion stops at the first blank line whose next line,
stripped, equals END_MATCH_PATTERN_2; that blank line is kept. With
END_MATCH_PATTERN_2 == "" this means the first pair of consecutive blank
lines after the start.
"""
import sys

# Active target: the copied `prepare_4d_attention_mask` helper.
MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart.prepare_4d_attention_mask"
MATCH_PATTERN_2 = "def prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):"
END_MATCH_PATTERN_2 = ""

# Alternative targets; swap one in to reuse the script. Every pattern is
# compared with `==` against a whole stripped line (not a prefix match).
#
# _make_causal_mask helper:
# MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart._make_causal_mask"
# MATCH_PATTERN_2 = "def _make_causal_mask("
# END_MATCH_PATTERN_2 = ""
#
# _prepare_decoder_attention_mask body:
# MATCH_PATTERN_1 = "def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):"
# MATCH_PATTERN_2 = "# create causal mask"
# END_MATCH_PATTERN_2 = "def forward("


def remove_marked_blocks(lines, start_marker=None, start_signature=None, end_marker=None):
    """Return ``lines`` with every marked block removed.

    Args:
        lines: Lines as returned by ``readlines()``; line endings are kept
            untouched in the output.
        start_marker: Stripped text of a block's first line
            (default: MATCH_PATTERN_1).
        start_signature: Stripped text required on the line right after
            ``start_marker`` (default: MATCH_PATTERN_2).
        end_marker: Deletion stops at a blank line whose next stripped line
            equals this (default: END_MATCH_PATTERN_2). The blank line itself
            is kept.

    Returns:
        A tuple ``(kept, starts, unterminated)``:
        - ``kept`` is the new list of lines.
        - ``starts`` holds the 1-based line numbers where deletions began.
        - ``unterminated`` is True when deletion was still active at the
          end of input, i.e. everything after the last start was dropped.
    """
    if start_marker is None:
        start_marker = MATCH_PATTERN_1
    if start_signature is None:
        start_signature = MATCH_PATTERN_2
    if end_marker is None:
        end_marker = END_MATCH_PATTERN_2

    kept = []
    starts = []
    deleting = False
    last_index = len(lines) - 1
    for i, line in enumerate(lines):
        stripped = line.strip()
        # Both start and end detection look one line ahead; the final line
        # has no successor, so it can neither start nor end a deletion.
        if i < last_index:
            next_stripped = lines[i + 1].strip()
            if stripped == start_marker and next_stripped == start_signature:
                deleting = True
                starts.append(i + 1)
            elif stripped == "" and next_stripped == end_marker:
                deleting = False
        if not deleting:
            kept.append(line)
    return kept, starts, deleting


def process_file(filename):
    """Remove the marked blocks from ``filename``, rewriting it only if changed."""
    # newline="" keeps the file's original line endings (LF or CRLF) intact
    # on both read and write; an explicit encoding avoids locale defaults.
    with open(filename, "r", encoding="utf-8", newline="") as f:
        lines = f.readlines()

    kept, starts, unterminated = remove_marked_blocks(lines)
    for lineno in starts:
        print(f"{filename}:{lineno}: removing marked block")
    if unterminated:
        print(
            f"{filename}: warning: end marker not found; removed through end of file",
            file=sys.stderr,
        )

    if starts:
        with open(filename, "w", encoding="utf-8", newline="") as f:
            f.writelines(kept)


def main(argv=None):
    """Process every file named in ``argv`` (defaults to ``sys.argv[1:]``)."""
    filenames = sys.argv[1:] if argv is None else argv
    for filename in filenames:
        process_file(filename)


if __name__ == "__main__":
    main()