"""Streamlit app: generate Python from natural language, lint it, reformat it.

The module exposes two helpers:

* ``CodeGenerator`` wraps a seq2seq model that turns a natural-language
  prompt into code, and runs ``flake8`` over the result.
* ``PythonCodeFormatter`` pipes a code string through isort, black,
  autoimport, autopep8 and yapf.

``main()`` wires both into a small Streamlit page.
"""

import re
import subprocess
from collections import Counter

import autoimport
import autopep8
import black
import isort
import pandas as pd
import plotly.express as px
import pylint.lint
import streamlit as st
from prettytable import PrettyTable
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from yapf.yapflib.yapf_api import FormatCode

# Hugging Face model id used for both the tokenizer and the model.
MODEL_NAME = "AhmedSSoliman/MarianCG-CoNaLa-Large"

# flake8 is run against this scratch file in the current working directory.
# The file name is part of flake8's report lines, so it must stay in sync
# with ``_FLAKE8_CODE_RE`` below.
TEMP_FILE = "temp.py"

# Matches report lines such as ``temp.py:3:1: E302 ...`` and captures the
# error code (one word character followed by digits).
_FLAKE8_CODE_RE = re.compile(r"temp\.py:\d+:\d+: (\w\d+)")


class CodeGenerator:
    """Generate code from natural language and lint it with flake8."""

    def __init__(self):
        """Load the tokenizer and seq2seq model named by ``MODEL_NAME``."""
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

    def generate_code(self, nl_input):
        """Return the model's decoded output for the prompt ``nl_input``."""
        input_ids = self.tokenizer.encode(nl_input, return_tensors="pt")
        output_ids = self.model.generate(input_ids)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

    def _run_flake8(self, code):
        """Write ``code`` to ``TEMP_FILE`` and lint it.

        Returns:
            A ``(stdout, stderr)`` tuple of decoded flake8 output.
        """
        # Explicit UTF-8 so generated code containing non-ASCII characters
        # does not depend on the platform's default encoding.
        with open(TEMP_FILE, "w", encoding="utf-8") as f:
            f.write(code)
        result = subprocess.run(
            ["flake8", "--count", TEMP_FILE],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return result.stdout.decode(), result.stderr.decode()

    def check_code(self, code):
        """Lint a single code string and return flake8's raw stdout."""
        output, _ = self._run_flake8(code)
        return output

    def check_code_list(self, code_list):
        """Lint every snippet in ``code_list`` and report the combined result.

        The stdout and stderr of all runs are concatenated and handed to
        ``_process_output``, whose formatted summary is returned.
        """
        outputs = []
        errors = []
        for code in code_list:
            out, err = self._run_flake8(code)
            outputs.append(out)
            errors.append(err)
        return self._process_output("".join(outputs), "".join(errors))

    def _process_output(self, output, error):
        """Summarise flake8 output: print a table, plot, and return counts.

        ``output`` (stdout) is used when it is non-empty; otherwise the
        summary is built from ``error`` (stderr).

        Returns:
            One ``"<code>: <count>"`` line per distinct error code.
        """
        # The original stderr branch referenced the stdout-only variable
        # ``output_counts`` and crashed; both branches now share one path.
        text = output if output else error
        counts = self._get_error_counts(text)
        self.show_variables_in_table(counts, text)
        self.visualize_all_errors(counts)
        self.visualize_error_types(counts)
        return self._format_error_counts(counts)

    def _get_error_counts(self, output):
        """Return ``{error_code: occurrences}`` parsed from flake8 text.

        Codes appear in the dict in order of first occurrence.
        """
        return dict(Counter(_FLAKE8_CODE_RE.findall(output)))

    def _format_error_counts(self, error_counts):
        """Render ``error_counts`` as newline-separated ``code: count`` lines."""
        return "\n".join(
            f"{error_type}: {count}" for error_type, count in error_counts.items()
        )

    def visualize_all_errors(self, error_counts):
        """Print each error code and its count to stdout."""
        for error_type, count in error_counts.items():
            print(f"{error_type}: {count}\n")

    def visualize_error_types(self, error_counts):
        """Show a bar chart of occurrences per error code."""
        df = pd.DataFrame(
            {
                "Error Type": list(error_counts.keys()),
                "Count": list(error_counts.values()),
            }
        )
        fig = px.bar(
            df,
            x="Count",
            y="Error Type",
            title="Error Occurrences in The Generated Code",
        )
        fig.update_layout(
            title={
                "text": "Error Occurrences in The Generated Code",
                "x": 0.5,
                "y": 0.96,
                "xanchor": "center",
                "yanchor": "top",
            },
            xaxis_title="Error Counts",
            yaxis_title="Error Codes",
        )
        fig.show()

    def show_variables_in_table(self, output_counts, output):
        """Print a one-row table holding the counts dict and the raw text."""
        table = PrettyTable()
        table.field_names = ["Error Code", "Message"]
        table.add_row([output_counts, output])
        print(table)

    def display_variables(self, output, error):
        """Show ``output`` and ``error`` side by side as a one-row DataFrame."""
        combined = pd.concat(
            [pd.DataFrame({"Output": [output]}), pd.DataFrame({"Error": [error]})],
            axis=1,
        )
        # ``display`` is not imported anywhere in this file; it only exists
        # when the host environment (e.g. a notebook) provides it. Fall back
        # to ``print`` instead of raising NameError elsewhere.
        try:
            show = display  # noqa: F821
        except NameError:
            show = print
        show(combined)


class PythonCodeFormatter:
    """Reformat a string of Python code with several formatters in sequence."""

    def __init__(self, code):
        """Store ``code`` with every ``'▁'`` (U+2581) replaced by a space.

        NOTE(review): U+2581 is presumably the tokenizer's word-boundary
        marker leaking into the generated text -- confirm.
        """
        self.code = code.replace("▁", " ").strip()

    def load_code_from_file(self, filename):
        """Replace ``self.code`` with the contents of ``filename``."""
        with open(filename, "r", encoding="utf-8") as f:
            self.code = f.read()

    def format(self):
        """Return ``self.code`` run through isort, black, autoimport,
        autopep8 and yapf, in that order.

        If a step raises ``ValueError``, or ``RuntimeError`` with the message
        ``'Project root not found.'``, the result of the last successful step
        is returned (the unmodified ``self.code`` if the very first step
        failed). Any other ``RuntimeError`` propagates.

        ``self.code`` itself is not modified.
        """
        # Start from the input so the handlers below always have a value to
        # return, even when the first formatter is the one that fails.
        formatted_code = self.code
        try:
            # Sort and organize the imports.
            formatted_code = isort.code(formatted_code)
            # Apply black's formatting.
            formatted_code = black.format_str(formatted_code, mode=black.Mode())
            # Add missing import statements.
            formatted_code = autoimport.fix_code(formatted_code)
            # Fix any remaining PEP 8 issues.
            formatted_code = autopep8.fix_code(formatted_code)
            # FormatCode returns (code, changed); only the code is needed.
            formatted_code, _ = FormatCode(formatted_code)
        except RuntimeError as error:
            if str(error) != "Project root not found.":
                raise  # not the error we are prepared to tolerate
        except ValueError:
            pass  # best effort: keep whatever was formatted so far
        return formatted_code

    def save(self, filename):
        """Write ``self.code`` to ``filename``.

        NOTE(review): ``format()`` returns the formatted text without
        updating ``self.code``, so this writes the code as stored by
        ``__init__`` / ``load_code_from_file`` -- confirm that is intended.
        """
        with open(filename, "w", encoding="utf-8") as f:
            f.write(self.code)


code_generator = CodeGenerator()


def main():
    """Render the Streamlit page: prompt -> generated code -> lint -> format."""
    st.title("Code Generator and Error Checker")

    nl_input = st.text_area("Enter natural language input for code generation")

    if st.button("Generate Code"):
        # Generate code.
        output_code = code_generator.generate_code(nl_input)
        st.subheader("Generated Code")
        st.code(output_code, language="python")

        # Check code for errors.
        st.subheader("Error Check")
        error_message = code_generator.check_code(output_code)
        st.write("Error Counts:")
        st.write(error_message)

        # Run the formatter pipeline over the generated code.
        st.subheader("Error Correction")
        formatter = PythonCodeFormatter(output_code)
        formatted_code = formatter.format()
        st.write("Code after correction:")
        st.write(formatted_code)


if __name__ == "__main__":
    main()