# Spaces: Runtime error
import re | |
import gradio as gr | |
from transformers import RobertaTokenizer, pipeline, AutoModelForMaskedLM | |
# Hub checkpoint that supplies both the RoBERTa tokenizer and the masked-LM weights.
_CHECKPOINT = "mamiksik/CommitPredictor"
tokenizer = RobertaTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT)
# Fill-mask pipeline built from the tokenizer and model loaded above.
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer)
def parse_files(patch) -> str:
    """Convert a unified ``git diff`` patch into the model's tagged text format.

    Every emitted line carries a tag:

    * ``<ide>`` for context lines, and ``<ide><path>P`` when a file's path is
      the same before and after the change;
    * ``<add>`` for added lines, and ``<add><path>P`` for the new path when
      the paths differ;
    * ``<del>`` for removed lines, and ``<del><path>P`` for the old path when
      the paths differ.

    ``diff``/``index`` header lines are dropped. The ``@@ ... @@`` range of a
    hunk header is removed; any text after it (e.g. function context) is kept
    as an ``<ide>`` line. File paths drop their first character, so
    ``a/foo.py`` and ``b/foo.py`` both become ``/foo.py`` and compare equal.

    Args:
        patch: Text produced by ``git diff`` (a bare hunk body also works).

    Returns:
        The tagged lines joined with ``"\\n"``; ``""`` for an empty patch.
    """
    lines = patch.splitlines()
    accumulator = []
    i = 0
    while i < len(lines):
        line = lines[i]
        i += 1
        if line.startswith(("index", "diff")):
            continue
        # Git always emits "--- old" immediately followed by "+++ new".
        # Requiring that pair (with the space after the markers) keeps content
        # lines from being mistaken for file headers. For example, a removed
        # "-- comment" appears as "--- comment", and a removed "--" appears as
        # "---", which previously raised IndexError.
        if (
            line.startswith("--- ")
            and i < len(lines)
            and lines[i].startswith("+++ ")
        ):
            filename_before = line.split(" ", 1)[1][1:]
            filename_after = lines[i].split(" ", 1)[1][1:]
            i += 1  # the "+++" line has been consumed as well
            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue
        # Strip the "@@ -a,b +c,d @@" range only from real hunk headers, so
        # "@@" sequences inside code lines are left untouched.
        if line.startswith("@@"):
            line = re.sub(r"^@@[^@]*@@", "", line)
        if not line:
            continue
        if line[0] == "+":
            line = "<add>" + line[1:]
        elif line[0] == "-":
            line = "<del>" + line[1:]
        else:
            line = f"<ide>{line}"
        accumulator.append(line)
    return "\n".join(accumulator)
def predict(patch, commit_message):
    """Predict the masked word of a commit message for the given patch.

    The patch is converted with ``parse_files`` and followed by the commit
    message under a ``<msg>`` marker. The result is the exact text fed to the
    fill-mask pipeline.

    Args:
        patch: Raw ``git diff`` text.
        commit_message: Commit message containing exactly one mask token
            (``tokenizer.mask_token``; the UI label calls it ``<mask>``).

    Returns:
        A tuple ``(token_count, input_text, predictions)``, where
        ``predictions`` maps each candidate token string to its score rounded
        to 3 decimals (the shape ``gr.Label`` expects).

    Raises:
        gr.Error: If the model input does not contain exactly one mask token,
            or is longer than ``tokenizer.model_max_length``. Gradio then
            shows a readable message instead of a stack trace.
    """
    input_text = parse_files(patch) + "\n<msg> " + commit_message
    # Count tokens without building a framework tensor; the length is the same.
    token_count = len(tokenizer(input_text)["input_ids"])

    # With zero masks the fill-mask pipeline raises. With several masks it
    # returns a list of lists, which the dict comprehension below cannot read.
    mask_count = input_text.count(tokenizer.mask_token)
    if mask_count != 1:
        raise gr.Error(
            f"Expected exactly one {tokenizer.mask_token} token in the input, "
            f"found {mask_count}."
        )
    # The pipeline does not truncate by default. Over-long inputs would
    # presumably fail inside the model's position embeddings, so report the
    # length instead. If the checkpoint sets no limit, model_max_length is a
    # very large sentinel and this check never fires.
    if token_count > tokenizer.model_max_length:
        raise gr.Error(
            f"Input is {token_count} tokens long; the model accepts at most "
            f"{tokenizer.model_max_length}."
        )

    # Call the pipeline directly; Pipeline.predict is only a scikit-learn
    # compatibility shim that forwards to __call__.
    result = pipe(input_text)
    return token_count, input_text, {
        pred["token_str"]: round(pred["score"], 3) for pred in result
    }
# Input widgets, in the order predict() receives its arguments.
_inputs = [
    gr.Textbox(label="Patch (as generated by git diff)"),
    gr.Textbox(label="Commit message (with one <mask> token)"),
]
# Output widgets, in the order predict() returns its values.
_outputs = [
    gr.Textbox(label="Token count"),
    gr.Textbox(label="Parsed patch"),
    gr.Label(label="Predictions"),
]
# Sample patch offered as a clickable example in the UI.
_example_patch = (
    "\n"
    "def main():\n"
    '- name = "John"\n'
    'print("Hello World")\n'
)
iface = gr.Interface(
    fn=predict,
    inputs=_inputs,
    outputs=_outputs,
    examples=[[_example_patch, "Remove <mask> variable"]],
)
def _launch_app():
    """Start the Gradio server for the commit-message demo."""
    iface.launch()


if __name__ == "__main__":
    _launch_app()