# Remi Serra 202407
"""Gradio app that creates or modifies SVG images with watsonx.ai models."""

import html

from env_utils import load_credentials
import gradio as gr
from random import randrange
from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string
from watsonx_utils import wxEngine
from prompts import (
    list_prompts,
    get_prompt_template,
    get_prompt_example,
    get_prompt_primer,
    get_prompt_uploadmode,
)
from data_images import svg_three_dots
from ibm_watsonx_ai.wml_client_error import WMLClientError

# Prefix of a base64-encoded SVG data URI.
SVG_DATA_URI_PREFIX = "data:image/svg+xml;base64,"
# Closing tag of an SVG document: generation is asked to stop once it is emitted.
# NOTE(review): assumes wxEngine.generate_text keeps the matched stop sequence
# in its output (watsonx.ai's default) - verify in watsonx_utils.
SVG_END_TAG = "</svg>"


# Functions - input


def read_file(uploaded_file):
    """Load an uploaded SVG file into the input widgets.

    Args:
        uploaded_file: path of the uploaded file as given by ``gr.File``;
            falsy when there is no file.

    Returns:
        ``(data_uri, svg_xml, preview_html)`` for the encoded-string box, the
        XML box and the preview. Without a file, three no-op ``gr.update()``
        values, because the event is wired to three outputs.
    """
    if not uploaded_file:
        return gr.update(), gr.update(), gr.update()
    # Close the handle deterministically. Read as UTF-8, XML's default
    # encoding, rather than the platform's locale encoding.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    encoded_data_string = xml_string_to_data_string(svg_xml)
    return (
        encoded_data_string,
        svg_xml,
        html_img_preview(encoded_data_string),
    )


def input_encoded_string_box_change(data_string: str):
    """Refresh the input XML box and preview after the image string is edited."""
    svg_xml = decode_b64_string_to_pretty_xml(data_string)
    return svg_xml, html_img_preview(data_string)


def input_xml_string_box_change(svg_xml: str):
    """Refresh the input image string and preview after the SVG XML is edited."""
    data_string = xml_string_to_data_string(svg_xml)
    return data_string, html_img_preview(data_string)


def xml_string_to_data_string(svg_xml: str) -> str:
    """Return *svg_xml* as a ``data:image/svg+xml;base64,...`` URI."""
    b64 = encode_svg_xml_to_b64_string(svg_xml)
    return SVG_DATA_URI_PREFIX + b64


def html_img_preview(data_string) -> str:
    """Return an HTML ``<img>`` element displaying the image at *data_string*.

    The source is HTML-escaped because it can come straight from the editable
    image-string box. A base64 data URI contains no characters that need
    escaping, so valid inputs are unchanged.
    """
    src = html.escape(data_string or "", quote=True)
    return f'<img src="{src}" alt="SVG preview">'


# Preview shown before any image is loaded or generated.
image_placeholder = html_img_preview(svg_three_dots)


# Functions - watsonx


def wx_prompt_drop_change(prompt_template_name):
    """Load the selected prompt template and toggle the upload widgets.

    Returns a dict keyed by the components listed as outputs of the
    ``wx_prompt_drop.input`` event.
    """
    show_upload = get_prompt_uploadmode(prompt_template_name)
    return {
        wx_prompt_box: get_prompt_template(prompt_template_name),
        wx_instructions_box: get_prompt_example(prompt_template_name),
        wx_primer_box: get_prompt_primer(prompt_template_name),
        upload_row: gr.Row(visible=show_upload),
        upload_accordeon: gr.Accordion(visible=show_upload),
    }


def wx_models_dropdown(wx_engine_state):
    """Build the model dropdown for a watsonx engine.

    Args:
        wx_engine_state: a connected ``wxEngine``, or ``None`` before connecting.

    Returns:
        A ``gr.Dropdown`` listing the engine's models. It preselects the
        recommended model when available, otherwise the first model. When
        there is no engine or no model, it is empty with no value.
    """
    wx_engine = wx_engine_state
    model_list = []
    default_value = None
    recommended_model = "ibm/granite-20b-code-instruct"
    if wx_engine is not None:
        model_list = wx_engine.list_models() or []
        if recommended_model in model_list:
            default_value = recommended_model
        elif model_list:
            # Guarded: indexing an empty model list would raise IndexError.
            default_value = model_list[0]
    return gr.Dropdown(
        label="Model",
        info=recommended_model + " recommended",
        choices=model_list,
        value=default_value,
    )


def wx_connect_click(wx_engine_state, apiendpoint, apikey, projectid):
    """Connect to watsonx.ai with the given credentials and update the UI.

    Returns:
        A 5-tuple matching the outputs wired in the layout: engine state,
        model dropdown, credentials accordion, generate button, status text.
        On success the accordion collapses and generation is enabled. On a
        ``WMLClientError``, the model list is emptied, the accordion stays open
        and generation is disabled.
    """
    try:
        wx_engine_state = wxEngine(apiendpoint, apikey, projectid)
        models_drop = wx_models_dropdown(wx_engine_state)
    except WMLClientError as ex:
        msg = f"Exception {type(ex).__name__} occurred: {ex.args!r}"
        print(msg)
        return (
            wx_engine_state,
            # Always send a valid (empty) dropdown rather than a bare list.
            wx_models_dropdown(None),
            gr.Accordion(open=True),
            gr.Button(interactive=False),
            # Status output is a gr.Markdown: send the text itself.
            msg,
        )
    msg = "watsonx.ai successfully activated"
    print(msg)
    return (
        wx_engine_state,
        models_drop,
        gr.Accordion(open=False),
        gr.Button(interactive=True),
        msg,
    )


def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Build the final prompt text and the generation budget.

    *wx_prompt* is formatted with ``svg=xml_string`` and
    ``instructions=wx_instructions``. Then *wx_primer* is appended so the model
    continues from it.

    Returns:
        ``(status, max_new_tokens, prompt)``. *status* is ``"Done."``, or a
        warning when the prompt exceeds the model's max tokens.
    """
    wx_status = "Done."
    wx_engine = wx_engine_state
    model_max_tokens = wx_engine.get_model_max_tokens(wx_model)
    prompt = wx_prompt.format(svg=xml_string, instructions=wx_instructions) + wx_primer
    prompt_nb_tokens = wx_engine.get_prompt_nb_tokens(prompt, wx_model)
    if prompt_nb_tokens > model_max_tokens:
        wx_status = f"Warning: prompt length ({prompt_nb_tokens}) is more than the model max tokens ({model_max_tokens}), and will be truncated. Please review your instructions."
        print(wx_status)
    # Budget: the input SVG's character count, with a floor of 500 for small or
    # empty inputs. An over-long prompt is truncated by generate() via
    # GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS.
    max_new_tokens = max(500, len(xml_string))
    return wx_status, max_new_tokens, prompt


def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate an SVG in one call (non-streaming).

    Returns:
        ``(status, data_uri, svg_xml, preview_html)``. *svg_xml* is the primer
        followed by the model output.
    """
    wx_engine = wx_engine_state
    wx_status, max_new_tokens, prompt = prepare_prompt(
        wx_engine, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    wx_result = wx_primer + wx_engine.generate_text(
        modelid=wx_model,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        stop_sequences=[SVG_END_TAG],
    )
    print(f"wx_generate:wx_result:{wx_result}")
    data_string = xml_string_to_data_string(wx_result)
    return wx_status, data_string, wx_result, html_img_preview(data_string)


def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate an SVG, streaming partial output to the UI.

    Each chunk yields a "Processing..." status, which also fills the
    encoded-string box. It also yields the partial XML and a cleared preview.
    The last yield carries the final status, data URI, XML and preview. See
    https://www.gradio.app/guides/streaming-outputs
    """
    wx_engine = wx_engine_state
    wx_status, max_new_tokens, prompt = prepare_prompt(
        wx_engine, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    wx_result = wx_primer
    wx_result_generator = wx_engine.generate_text(
        modelid=wx_model,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        stop_sequences=[SVG_END_TAG],
        stream=True,
    )
    for chunk in wx_result_generator:
        wx_result += chunk
        # Varying dot count gives visible progress feedback.
        status = f"Processing.{'.' * randrange(3)}"
        yield status, status, wx_result, None
    print(f"wx_stream:wx_result:{wx_result}")
    data_string = xml_string_to_data_string(wx_result)
    yield wx_status, data_string, wx_result, html_img_preview(data_string)


# Functions - output


def output_xml_string_box_change(xml_string):
    """Refresh the result image string and preview after the result XML is edited."""
    data_string = xml_string_to_data_string(xml_string)
    return (
        data_string,
        xml_string,
        html_img_preview(data_string),
    )


# APP layout
with gr.Blocks(theme="Zarkel/IBM_Carbon_Theme") as demo:
    gr.Markdown("# SVG editor")
    gr.Markdown(
        """### Get started:
- Create a new SVG: Enter a description in the 'Instructions' box and click 'Generate'
- Modify an existing SVG: Upload an SVG file, or paste an image string or SVG XML, then select the prompt template 'Modify SVG', enter a change instruction in the 'Instructions' box and click 'Generate'"""
    )
    # Starts empty: a gr.State initial value must be deep-copyable, and a
    # wxEngine instance is not. The engine is stored on "Connect".
    wx_engine_state = gr.State(None)
    with gr.Row():
        # Left column: watsonx setup
        with gr.Column(scale=0):
            # prompt template selection
            prompt_template_names = list_prompts()
            default_prompt_template_name = prompt_template_names[0]
            wx_prompt_drop = gr.Dropdown(
                label="Action",
                choices=prompt_template_names,
                value=default_prompt_template_name,
            )
            # credentials, prefilled from environment variables
            status_unused, env_apiendpoint, env_apikey, env_projectid = (
                load_credentials()
            )
            with gr.Accordion("Credentials", open=True) as credentials_accordeon:
                wx_creds_endpoint = gr.Textbox(
                    label="Endpoint",
                    value=env_apiendpoint or "https://us-south.ml.cloud.ibm.com",
                    max_lines=1,
                )
                wx_creds_apikey = gr.Textbox(
                    label="API key", value=env_apikey, max_lines=1
                )
                wx_creds_projectid = gr.Textbox(
                    label="Project id", value=env_projectid, max_lines=1
                )
                wx_connect_btn = gr.Button("Connect")
            # model (empty until connected)
            wx_models_drop = wx_models_dropdown(None)
            # prompt text and primer
            wx_prompt_box = gr.Textbox(
                info="Text",
                show_label=False,
                max_lines=5,
                value=get_prompt_template(default_prompt_template_name),
            )
            wx_primer_box = gr.Textbox(
                info="Primer",
                show_label=False,
                max_lines=2,
                value=get_prompt_primer(default_prompt_template_name),
            )
        # Main pane
        with gr.Column():
            # Upload (only shown for templates in upload mode)
            with gr.Row(
                visible=get_prompt_uploadmode(default_prompt_template_name)
            ) as upload_row:
                input_file = gr.File(scale=0, label="Upload an SVG file")
                input_svg_preview = gr.HTML(image_placeholder)
                input_xml_string_box = gr.Textbox(
                    label="Input SVG XML",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
            with gr.Accordion(
                label="Input encoded string",
                open=False,
                visible=get_prompt_uploadmode(default_prompt_template_name),
            ) as upload_accordeon:
                input_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )
            # modification
            with gr.Row():
                wx_instructions_box = gr.Textbox(
                    label="Instructions",
                    scale=3,
                    value=get_prompt_example(default_prompt_template_name),
                    show_copy_button=True,
                )
                # Disabled until watsonx.ai is connected.
                wx_generate_btn = gr.Button("↓Generate↓", scale=0, interactive=False)
            output_svg_preview = gr.HTML(image_placeholder)
            output_xml_string_box = gr.Textbox(
                label="Result SVG XML",
                lines=7,
                max_lines=7,
                scale=3,
                show_copy_button=True,
            )
            with gr.Accordion(label="Result encoded string", open=False):
                output_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )
            wx_status_box = gr.Markdown("Status")

    # Map controls to functions
    wx_prompt_drop.input(
        fn=wx_prompt_drop_change,
        inputs=wx_prompt_drop,
        outputs=[
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            upload_row,
            upload_accordeon,
        ],
    )
    wx_connect_btn.click(
        fn=wx_connect_click,
        inputs=[
            wx_engine_state,
            wx_creds_endpoint,
            wx_creds_apikey,
            wx_creds_projectid,
        ],
        outputs=[
            wx_engine_state,
            wx_models_drop,
            credentials_accordeon,
            wx_generate_btn,
            wx_status_box,
        ],
    )
    input_file.upload(
        fn=read_file,
        inputs=input_file,
        outputs=[
            input_encoded_string_box,
            input_xml_string_box,
            input_svg_preview,
        ],
    )
    input_encoded_string_box.input(
        fn=input_encoded_string_box_change,
        inputs=[input_encoded_string_box],
        outputs=[input_xml_string_box, input_svg_preview],
    )
    input_xml_string_box.input(
        fn=input_xml_string_box_change,
        inputs=[input_xml_string_box],
        outputs=[input_encoded_string_box, input_svg_preview],
    )
    wx_generate_btn.click(
        fn=wx_stream,
        inputs=[
            wx_engine_state,
            wx_models_drop,
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            input_xml_string_box,
        ],
        outputs=[
            wx_status_box,
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
        api_name="wx_generate",
    )
    output_xml_string_box.input(
        fn=output_xml_string_box_change,
        inputs=[output_xml_string_box],
        outputs=[
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
    )

# Main
if __name__ == "__main__":
    demo.launch()