| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| | import gradio as gr |
| | import torch |
| | import autopep8 |
| | import glob |
| | import re |
| | import os |
| | from huggingface_hub import hf_hub_download |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def normalize_indentation(code):
    """
    Normalize indentation in example code by removing one level of
    excessive indentation inside function bodies.

    Backslash characters are stripped first, then every non-blank line
    that follows a ``def`` line has one indent level removed: two tabs
    collapse to one tab and eight spaces collapse to four.

    Args:
        code: Source code string, possibly over-indented.

    Returns:
        The code with backslashes removed and indentation normalized.
    """
    code = code.replace("\\", "")

    lines = code.split("\n")
    if not lines:
        return ""

    fixed_lines = []
    indent_fix_mode = False  # becomes True once a "def" line is seen

    for line in lines:
        if line.strip().startswith("def "):
            fixed_lines.append(line)
            indent_fix_mode = True
        elif indent_fix_mode and line.strip():
            if line.startswith("\t\t"):
                # Two tabs -> one tab (drop one indent level).
                fixed_lines.append("\t" + line[2:])
            elif line.startswith(" " * 8):
                # Eight spaces -> four spaces (drop one indent level).
                # The prefix test must cover all eight characters that
                # the slice removes, otherwise content is truncated.
                fixed_lines.append(" " * 4 + line[8:])
            else:
                fixed_lines.append(line)
        else:
            fixed_lines.append(line)

    return "\n".join(fixed_lines)
| |
|
| |
|
def clear_text(text):
    """
    Decode escape sequences in *text* while preserving its formatting.

    Literal ``\\n`` / ``\\t`` sequences become real newline and tab
    characters; every remaining backslash is discarded.
    """
    # Shield the two meaningful escapes behind sentinel markers so the
    # blanket backslash removal below cannot destroy them.
    sentinels = (
        ("\\n", "TEMP_NEWLINE_PLACEHOLDER", "\n"),
        ("\\t", "TEMP_TAB_PLACEHOLDER", "\t"),
    )

    for escape, marker, _ in sentinels:
        text = text.replace(escape, marker)

    text = text.replace("\\", "")

    for _, marker, control_char in sentinels:
        text = text.replace(marker, control_char)

    return text
| |
|
| |
|
def encode_text(text):
    """
    Encode control characters as textual escape sequences.

    Newlines become ``\\n`` and tabs become ``\\t`` so code can travel
    to the model as a single line.
    """
    for control_char, escape in (("\n", "\\n"), ("\t", "\\t")):
        text = text.replace(control_char, escape)
    return text
| |
|
| |
|
def format_code(code):
    """
    Format Python code using autopep8 with aggressive settings.

    Applies autopep8, then tightens spacing inside parentheses and
    between a name and its call parenthesis. Falls back to returning
    *code* unchanged if anything raises.

    Args:
        code: Python source to format.

    Returns:
        The formatted source, or the input unchanged on failure.
    """
    try:
        formatted_code = autopep8.fix_code(
            code,
            options={
                "aggressive": 2,
                "max_line_length": 88,
                "indent_size": 4,
            },
        )

        # Tighten spacing just inside parentheses: "( x )" -> "(x)".
        formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")

        # NOTE: the previous per-operator replace loop was a no-op
        # (it replaced f"{op} " with the identical string op + " ")
        # and has been removed.

        # Remove stray spaces between a name and its call parenthesis,
        # e.g. "foo (x)" -> "foo(x)".
        formatted_code = re.sub(r"(\w+)\s+\(", r"\1(", formatted_code)

        return formatted_code
    except Exception as e:
        # Best-effort formatting: never let a formatter failure break
        # the generation pipeline.
        print(f"Error formatting code: {str(e)}")
        return code
| |
|
| |
|
def fix_common_syntax_issues(code):
    """
    Fix common syntax issues in generated code without modifying indentation.

    Three best-effort repairs are applied, in order:
      1. append a missing ``:`` to block-opening statements,
      2. balance an odd number of quote characters,
      3. close an obviously unterminated call parenthesis.

    Args:
        code: Python source (possibly malformed) as a string.

    Returns:
        The repaired source string.
    """
    block_keywords = ("if ", "elif ", "for ", "while ", "def ", "class ")

    fixed_lines = []
    for line in code.split("\n"):
        stripped = line.strip()
        # "else" must be matched as a whole word: a bare prefix test
        # would also catch identifiers such as "elsewhere = 5".
        is_block_opener = stripped.startswith(block_keywords) or re.match(
            r"else\b", stripped
        )
        if is_block_opener:
            if not stripped.endswith(":") and not stripped.endswith("\\"):
                line = line.rstrip() + ":"
        fixed_lines.append(line)

    code = "\n".join(fixed_lines)

    # Balance a single dangling quote by closing it at the end of the
    # first offending line.
    for quote in ('"', "'"):
        if code.count(quote) % 2 != 0:
            lines = code.split("\n")
            for i, line in enumerate(lines):
                if line.count(quote) % 2 != 0:
                    lines[i] = line.rstrip() + quote
                    break
            code = "\n".join(lines)

    # Close a call parenthesis left open at end-of-line when no closing
    # ")" appears within the next two lines.
    pattern = r"(\w+)\s*\([^)]*$"
    if re.search(pattern, code):
        lines = code.split("\n")
        for i, line in enumerate(lines):
            if re.search(pattern, line) and not any(
                lines[j].strip().startswith(")")
                for j in range(i + 1, min(i + 3, len(lines)))
            ):
                lines[i] = line.rstrip() + ")"
        code = "\n".join(lines)

    return code
| |
|
| |
|
def load_example_from_file(example_path):
    """
    Load an example from a file with format ``description_BREAK_code``,
    where the code part uses ``\\n`` and ``\\t`` escapes for formatting.

    Args:
        example_path: Path to a ``raw.in`` example file.

    Returns:
        A ``(description, code)`` tuple; both empty strings on any error
        or if the file does not contain exactly one ``_BREAK_`` marker.
    """
    try:
        # Read as UTF-8 explicitly so behavior does not depend on the
        # platform's default locale encoding.
        with open(example_path, "r", encoding="utf-8") as f:
            content = f.read()

        parts = content.split("_BREAK_")
        if len(parts) == 2:
            description = parts[0].strip()
            code = parts[1].strip()

            # Decode escape sequences, then clean up indentation.
            code = code.replace("\\n", "\n").replace("\\t", "\t")
            code = normalize_indentation(code)

            return description, code
        else:
            print(f"Invalid format in example file: {example_path}")
            return "", ""
    except Exception as e:
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""
| |
|
| |
|
def find_example_files():
    """
    Return the paths of every ``raw.in`` example file under ``examples/``.
    """
    pattern = os.path.join("examples", "*", "raw.in")
    return glob.glob(pattern)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | BASE_MODEL_ID = "Salesforce/codet5p-770m" |
| | FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs" |
| | FINETUNED_FILENAME = "pytorch_model.bin" |
| |
|
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
|
| | print(f"Loading tokenizer from base model: {BASE_MODEL_ID}") |
| | tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID) |
| |
|
| | print(f"Loading base model: {BASE_MODEL_ID}") |
| | model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID) |
| | model.to(device) |
| |
|
| | print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}") |
| | ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME) |
| |
|
| | print(f"Loading state_dict from: {ckpt_path}") |
| | state_dict = torch.load(ckpt_path, map_location="cpu") |
| |
|
| | if "model_state_dict" in state_dict: |
| | state_dict = state_dict["model_state_dict"] |
| |
|
| | missing, unexpected = model.load_state_dict(state_dict, strict=False) |
| | print(f"Loaded fine-tuned weights. Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}") |
| |
|
| | model.eval() |
| |
|
| |
|
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | current_code = None |
| | bug_counter = 0 |
| |
|
| |
|
def generate_bugged_code(description, code, chat_history, is_first_time):
    """
    Inject a bug into the given code and append the exchange to the chat.

    On the first round the user-supplied code is the model input; on every
    later round the previously generated bugged code is mutated again, so
    bugs accumulate across the session.

    Args:
        description: Natural-language description of the bug to inject.
        code: Original Python code from the UI code editor.
        chat_history: List of {"role", "content"} chat messages, or None.
        is_first_time: True when this submission starts a fresh session.

    Returns:
        Tuple of (updated chat history, update clearing the description
        textbox, False to mark subsequent submissions as non-first).
    """
    global current_code, bug_counter

    if chat_history is None:
        chat_history = []

    if is_first_time:
        # Fresh session: drop all state from any previous conversation.
        bug_counter = 0
        current_code = None
        chat_history = []

    bug_counter += 1

    if bug_counter == 1:
        input_for_model = code
        input_type = "original"
    else:
        if current_code is None:
            # Nothing to mutate; return unchanged state.
            return chat_history, gr.update(value=""), False
        input_for_model = current_code
        input_type = "previous bugged code"

    print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")

    # The model expects single-line input with escaped control characters.
    encoded_code = encode_text(input_for_model)
    combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"

    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).input_ids.to(device)

    try:
        print("Starting generation...")
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=256,
                num_beams=1,
                do_sample=False,
                early_stopping=True,
            )
        print("Generation done.")
    except Exception as e:
        print("Generation error:", repr(e))
        # Bare raise preserves the original traceback.
        raise

    bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Decode escapes, then best-effort repair and reformat the output.
    bugged_code = clear_text(bugged_code_escaped)
    bugged_code = fix_common_syntax_issues(bugged_code)
    bugged_code = format_code(bugged_code)

    # Remember the result so the next round can mutate it further.
    current_code = bugged_code

    user_message = f"**Description**: {description}"
    if input_type == "original":
        user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
    else:
        user_message += (
            f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
        )

    ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"

    chat_history = chat_history + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ai_message},
    ]

    return chat_history, gr.update(value=""), False
| |
|
| |
|
| |
|
| |
|
def reset_interface():
    """
    Clear the conversation state and reset the UI widgets.

    Returns:
        An empty chat history, an update clearing the description
        textbox, and True so the next submission starts a fresh session.
    """
    global current_code, bug_counter
    bug_counter = 0
    current_code = None
    return [], gr.update(value=""), True
| |
|
| |
|
# Discover examples on disk and build a human-readable label per button.
example_files = find_example_files()
example_names = []
for index, path in enumerate(example_files, start=1):
    folder = os.path.basename(os.path.dirname(path))
    example_names.append(f"Example {index}: {folder}")
| |
|
| |
|
def load_example(example_index):
    """
    Return (description, code) for the example at *example_index*,
    or two empty strings when the index is out of range.
    """
    try:
        path = example_files[example_index]
    except IndexError:
        return "", ""
    return load_example_from_file(path)
| |
|
| |
|
| | with gr.Blocks(title="Software-Fault Injection from NL") as demo: |
| | gr.Markdown("# 🐞 Software-Fault Injection from Natural Language") |
| | gr.Markdown( |
| | "Generate Python code with specific bugs based on a description and original code. " |
| | "The model used is **BugGen (CodeT5+ 770M, PyResBugs)**." |
| | ) |
| |
|
| | with gr.Row(): |
| | with gr.Column(scale=2): |
| | description_input = gr.Textbox( |
| | label="Bug Description", |
| | placeholder="Describe the type of bug to introduce...", |
| | lines=3, |
| | ) |
| | code_input = gr.Code( |
| | label="Original Code", |
| | language="python", |
| | lines=12, |
| | ) |
| |
|
| | is_first = gr.State(True) |
| |
|
| | submit_btn = gr.Button("Generate Bugged Code") |
| | reset_btn = gr.Button("Start Over") |
| |
|
| | gr.Markdown("### Examples") |
| | example_buttons = [gr.Button(name) for name in example_names] |
| |
|
| | with gr.Column(scale=3): |
| | chat_output = gr.Chatbot( |
| | label="Conversation", |
| | height=500, |
| | ) |
| |
|
| | for i, btn in enumerate(example_buttons): |
| | btn.click( |
| | fn=lambda i=i: load_example(i), |
| | outputs=[description_input, code_input], |
| | ) |
| |
|
| | submit_btn.click( |
| | fn=generate_bugged_code, |
| | inputs=[description_input, code_input, chat_output, is_first], |
| | outputs=[chat_output, description_input, is_first], |
| | ) |
| |
|
| | reset_btn.click( |
| | fn=reset_interface, |
| | outputs=[chat_output, description_input, is_first], |
| | ) |
| |
|
| | print("Launching Gradio interface...") |
| | demo.queue(max_size=10).launch() |