# svg-editor / svg_editor_gradio.py
# Author: remiserra — last commit 86da7c8 ("set image width")
# Remi Serra 202407
from env_utils import load_credentials
import gradio as gr
from random import randrange
from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string
from watsonx_utils import wxEngine
from prompts import (
list_prompts,
get_prompt_template,
get_prompt_example,
get_prompt_primer,
get_prompt_uploadmode,
)
from data_images import svg_three_dots
from ibm_watsonx_ai.wml_client_error import WMLClientError
# Functions - input
def read_file(uploaded_file):
    """Load an uploaded SVG file into the three input widgets.

    Args:
        uploaded_file: value supplied by the ``gr.File`` upload event; it is
            passed to ``open`` (presumably a temp-file path). Falsy when
            nothing was uploaded.

    Returns:
        ``(encoded_data_string, svg_xml, preview_html)`` for the encoded
        string box, the XML box and the preview pane, or ``None`` when
        ``uploaded_file`` is falsy.
    """
    if not uploaded_file:
        return None
    # The context manager closes the handle (the previous bare open() leaked
    # it). Explicit UTF-8 (XML's default encoding) avoids depending on the
    # platform's locale encoding, e.g. cp1252 on Windows.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    encoded_data_string = xml_string_to_data_string(svg_xml)
    return (
        encoded_data_string,
        svg_xml,
        html_img_preview(encoded_data_string),
    )
def input_encoded_string_box_change(data_string: str):
    """Sync the input XML box and preview after the encoded string is edited.

    Returns:
        ``(pretty_xml, preview_html)``: the result of
        ``decode_b64_string_to_pretty_xml(data_string)`` and an ``<img>``
        preview of ``data_string``.
    """
    pretty_xml = decode_b64_string_to_pretty_xml(data_string)
    preview_html = html_img_preview(data_string)
    return pretty_xml, preview_html
def input_xml_string_box_change(svg_xml: str):
    """Re-encode edited input SVG XML and refresh the input preview.

    Returns:
        ``(data_string, preview_html)`` for the encoded string box and the
        preview pane.
    """
    encoded = xml_string_to_data_string(svg_xml)
    preview_html = html_img_preview(encoded)
    return encoded, preview_html
def xml_string_to_data_string(svg_xml: str):
    """Wrap SVG XML into a ``data:image/svg+xml;base64,`` URI.

    The payload is produced by ``encode_svg_xml_to_b64_string``; the result
    can be used directly as an ``<img src>`` value.
    """
    prefix = "data:image/svg+xml;base64,"
    return prefix + encode_svg_xml_to_b64_string(svg_xml)
def html_img_preview(data_string, width="100px"):
    """Return a horizontally centred HTML ``<img>`` tag previewing an image.

    Args:
        data_string: image source, typically a
            ``data:image/svg+xml;base64,...`` URI. It can come straight from
            the user-editable "Image string" textbox, so it is HTML-escaped
            before being placed inside the attribute.
        width: value of the ``width`` attribute (default ``"100px"``).

    Returns:
        The ``<img .../>`` markup as a string.
    """
    import html

    # Escaping leaves well-formed base64 data URIs unchanged, but stops a
    # pasted value containing quotes or angle brackets from breaking out of
    # src="..." and injecting markup/attributes into the rendered HTML.
    src = html.escape(str(data_string), quote=True)
    width_attr = html.escape(str(width), quote=True)
    return (
        f'<img src="{src}" width="{width_attr}" '
        'style="display: block; margin-left: auto; margin-right: auto;"/>'
    )
# Preview markup shown in the preview panes before any image is loaded or generated.
image_placeholder = html_img_preview(data_string=svg_three_dots)
# Functions - watsonx
def wx_prompt_drop_change(prompt_template_name):
    """Refresh the prompt widgets when another action (prompt template) is chosen.

    Returns:
        A dict mapping components to their new values/updates: the template
        text, example instructions and primer for ``prompt_template_name``,
        plus the visibility of the upload row and of the encoded-string
        accordion, both driven by the template's upload mode.
    """
    upload_visible = get_prompt_uploadmode(prompt_template_name)
    updates = {}
    updates[wx_prompt_box] = get_prompt_template(prompt_template_name)
    updates[wx_instructions_box] = get_prompt_example(prompt_template_name)
    updates[wx_primer_box] = get_prompt_primer(prompt_template_name)
    updates[upload_row] = gr.Row(visible=upload_visible)
    updates[upload_accordeon] = gr.Accordion(visible=upload_visible)
    return updates
def wx_models_dropdown(wx_engine_state):
    """Build the model-selection dropdown for the given engine.

    Args:
        wx_engine_state: the engine held in ``gr.State``, or ``None`` before a
            successful connection.

    Returns:
        A ``gr.Dropdown`` whose choices are ``wx_engine_state.list_models()``
        (empty when there is no engine). The recommended model is preselected
        when available, otherwise the first listed model, otherwise nothing.
    """
    recommended_model = "ibm/granite-20b-code-instruct"
    model_list = []
    default_value = None
    if wx_engine_state is not None:
        model_list = wx_engine_state.list_models()
        if recommended_model in model_list:
            default_value = recommended_model
        elif model_list:
            # Guard: an empty model list previously raised IndexError here.
            default_value = model_list[0]
    return gr.Dropdown(
        label="Model",
        info=recommended_model + " recommended",
        choices=model_list,
        value=default_value,
    )
def wx_connect_click(wx_engine_state, apiendpoint, apikey, projectid):
    """Connect to watsonx.ai with the given credentials and update the UI.

    Args:
        wx_engine_state: engine currently held in ``gr.State`` (``None`` until
            a successful connection).
        apiendpoint: watsonx.ai endpoint URL passed to ``wxEngine``.
        apikey: API key passed to ``wxEngine``.
        projectid: project id passed to ``wxEngine``.

    Returns:
        A 5-tuple for (engine state, model dropdown, credentials accordion,
        generate button, status box).

        On success it contains:
        - the new engine;
        - a dropdown listing its models;
        - a collapsed credentials accordion;
        - an enabled generate button;
        - a success message.

        On ``WMLClientError`` it contains:
        - the unchanged previous state;
        - an empty model dropdown;
        - an expanded credentials accordion;
        - a disabled generate button;
        - the error text.
    """
    try:
        new_engine = wxEngine(apiendpoint, apikey, projectid)
        # wx_models_dropdown calls list_models(), which may also raise; keep
        # it inside the try so a partial failure does not leak new_engine
        # into the state.
        models_drop = wx_models_dropdown(new_engine)
    except WMLClientError as ex:
        msg = "Exception {0} occurred: {1!r}".format(type(ex).__name__, ex.args)
        print(msg)
        return (
            wx_engine_state,
            # Return an (empty) dropdown, as the success branch does, rather
            # than a bare [] value for a single-select Dropdown.
            wx_models_dropdown(None),
            gr.Accordion(open=True),
            gr.Button(interactive=False),
            # The status box is a gr.Markdown: send the text as its value.
            msg,
        )
    msg = "watsonx.ai successfully activated"
    print(msg)
    return (
        new_engine,
        models_drop,
        gr.Accordion(open=False),
        gr.Button(interactive=True),
        msg,
    )
def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Assemble the final prompt text and the generation budget.

    The template ``wx_prompt`` is filled with ``svg=xml_string`` and
    ``instructions=wx_instructions``, then the primer is appended so the
    model continues from it.

    Args:
        wx_engine_state: connected engine; must provide
            ``get_model_max_tokens(model)`` and
            ``get_prompt_nb_tokens(prompt, model)``.
        wx_model: model id.
        wx_prompt: ``str.format`` template with ``{svg}`` / ``{instructions}``.
        wx_instructions: user instructions.
        wx_primer: text appended at the end of the prompt.
        xml_string: SVG XML being modified (may be empty).

    Returns:
        ``(status, max_new_tokens, prompt)``. ``status`` is ``"Done."``, or a
        warning (also printed) when the prompt has more tokens than the
        model's limit. ``max_new_tokens`` is ``max(500, len(xml_string))``.
    """
    engine = wx_engine_state
    token_limit = engine.get_model_max_tokens(wx_model)
    filled_template = wx_prompt.format(svg=xml_string, instructions=wx_instructions)
    prompt = filled_template + wx_primer
    prompt_tokens = engine.get_prompt_nb_tokens(prompt, wx_model)
    if prompt_tokens > token_limit:
        status = f"Warning: prompt length ({prompt_tokens}) is more than the model max tokens ({token_limit}), and will be truncated. Please review your instructions."
        print(status)
    else:
        status = "Done."
    # Budget new tokens from the character length of the SVG being modified,
    # with a floor of 500. Per the original note, over-long input is expected
    # to be truncated in generate() (GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS).
    budget = max(500, len(xml_string))
    return status, budget, prompt
def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Run one non-streaming generation and render the result.

    Returns:
        ``(status, data_string, svg_xml, preview_html)``. ``status`` comes
        from ``prepare_prompt``. ``svg_xml`` is the primer followed by the
        model output, generated with ``</svg>`` as the stop sequence.
    """
    status, budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    completion = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=budget,
        stop_sequences=["</svg>"],
    )
    # The model continues after the primer, so prepend it to get the full SVG.
    svg_xml = wx_primer + completion
    print(f"wx_generate:wx_result:{svg_xml}")
    encoded = xml_string_to_data_string(svg_xml)
    return status, encoded, svg_xml, html_img_preview(encoded)
def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Stream a generation, yielding UI updates as chunks arrive.

    Generator for four outputs: status, encoded string, XML and preview.

    While streaming:
    - the status slot and the encoded-string slot both show "Processing."
      followed by 0-2 random extra dots;
    - the XML slot shows the text accumulated so far, starting from the primer;
    - the preview slot is cleared (``None``).

    Once the stream ends, the status from ``prepare_prompt`` and the rendered
    result are yielded.
    See https://www.gradio.app/guides/streaming-outputs
    """
    final_status, budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    chunks = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=budget,
        stop_sequences=["</svg>"],
        stream=True,
    )
    # The model continues after the primer, so the result starts with it.
    svg_xml = wx_primer
    for piece in chunks:
        svg_xml += piece
        progress = "Processing." + "." * randrange(3)
        yield progress, progress, svg_xml, None
    print(f"wx_stream:wx_result:{svg_xml}")
    encoded = xml_string_to_data_string(svg_xml)
    yield final_status, encoded, svg_xml, html_img_preview(encoded)
# Functions - output
def output_xml_string_box_change(xml_string):
    """Re-encode edited result XML and refresh the result preview.

    Returns:
        ``(data_string, xml_string, preview_html)``. The XML is passed back
        unchanged because the result XML box is also wired as an output.
    """
    encoded = xml_string_to_data_string(xml_string)
    preview_html = html_img_preview(encoded)
    return encoded, xml_string, preview_html
# APP layout
# APP layout: builds the Gradio UI as a row with two columns, then wires the
# events. The left column (scale=0) holds the action dropdown, credentials,
# model and prompt settings. The right column is the main pane: input SVG,
# instructions + generate button, result.
# Component creation order below defines the on-screen layout.
with gr.Blocks(theme = "Zarkel/IBM_Carbon_Theme") as demo:
    gr.Markdown("# SVG editor")
    # NOTE(review): this help text refers to a 'Submit' button, but the button
    # created below is labelled "↓Generate↓" — confirm the wording.
    gr.Markdown(
        """### Get started:
- Create a new SVG: Enter a description in the 'Instructions' box and click 'Submit'
- Modify an existing SVG: Upload an SVG file, or paste an image string or SVG XML, then Select the prompt template 'Modify SVG', enter a change instruction in the 'Instructions' box and click 'Submit'"""
    )
    # init state - note gr.State() initial value must be deep-copyable - my wx_engine class is not
    # Hence it starts as None; wx_connect_click returns the engine into this state.
    wx_engine_state = gr.State(None)
    with gr.Row():  # main UI
        with gr.Column(scale=0):  # watsonx setup
            # prompt template selection; the first listed template is the default
            prompt_template_names = list_prompts()
            default_prompt_template_name = prompt_template_names[0]
            wx_prompt_drop = gr.Dropdown(
                label="Action",
                choices=prompt_template_names,
                value=default_prompt_template_name,
            )
            # credentials
            # load env variables (used as pre-filled values; the first item
            # returned by load_credentials is not used here)
            status_unused, env_apiendpoint, env_apikey, env_projectid = (
                load_credentials()
            )
            # Opened initially; wx_connect_click collapses it on success.
            with gr.Accordion("Credentials", open=True) as credentials_accordeon:
                wx_creds_endpoint = gr.Textbox(
                    label="Endpoint",
                    value=env_apiendpoint or "https://us-south.ml.cloud.ibm.com",
                    max_lines=1,
                )
                wx_creds_apikey = gr.Textbox(
                    label="API key", value=env_apikey, max_lines=1
                )
                wx_creds_projectid = gr.Textbox(
                    label="Project id", value=env_projectid, max_lines=1
                )
                wx_connect_btn = gr.Button("Connect")
            # model: empty dropdown until wx_connect_click fills it
            wx_models_drop = wx_models_dropdown(None)
            # prompt text and primer (initialised from the default template)
            wx_prompt_box = gr.Textbox(
                info="Text",
                show_label=False,
                max_lines=5,
                value=get_prompt_template(prompt_template_names[0]),
            )
            wx_primer_box = gr.Textbox(
                info="Primer",
                show_label=False,
                max_lines=2,
                value=get_prompt_primer(prompt_template_names[0]),
            )
        with gr.Column():  # main pane
            # Upload: visible only for templates whose upload mode is on
            # (toggled by wx_prompt_drop_change)
            with gr.Row(
                visible=get_prompt_uploadmode(default_prompt_template_name)
            ) as upload_row:
                # Upload an .svg file
                input_file = gr.File(scale=0, label="Upload an SVG file")
                # original preview
                input_svg_preview = gr.HTML(image_placeholder)
                # decoded SVG XML
                input_xml_string_box = gr.Textbox(
                    label="Input SVG XML",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
            # Same visibility rule as upload_row, toggled separately.
            with gr.Accordion(
                label="Input encoded string",
                open=False,
                visible=get_prompt_uploadmode(default_prompt_template_name),
            ) as upload_accordeon:
                # Encoded image string
                input_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )
            # modification
            with gr.Row():
                wx_instructions_box = gr.Textbox(
                    label="Instructions",
                    scale=3,
                    value=get_prompt_example(prompt_template_names[0]),
                    show_copy_button=True,
                )
                # Disabled until wx_connect_click succeeds.
                wx_generate_btn = gr.Button("↓Generate↓", scale=0, interactive=False)
            output_svg_preview = gr.HTML(image_placeholder)
            output_xml_string_box = gr.Textbox(
                label="Result SVG XML",
                lines=7,
                max_lines=7,
                scale=3,
                show_copy_button=True,
            )
            with gr.Accordion(label="Result encoded string", open=False):
                output_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )
            # Status area, written by wx_connect_click and wx_stream.
            wx_status_box = gr.Markdown("Status")
    # Map controls to functions
    # Changing the action swaps prompt/instructions/primer and upload visibility.
    wx_prompt_drop.input(
        fn=wx_prompt_drop_change,
        inputs=wx_prompt_drop,
        outputs=[
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            upload_row,
            upload_accordeon,
        ],
    )
    wx_connect_btn.click(
        fn=wx_connect_click,
        inputs=[
            wx_engine_state,
            wx_creds_endpoint,
            wx_creds_apikey,
            wx_creds_projectid,
        ],
        outputs=[
            wx_engine_state,
            wx_models_drop,
            credentials_accordeon,
            wx_generate_btn,
            wx_status_box,
        ],
    )
    input_file.upload(
        fn=read_file,
        inputs=input_file,
        outputs=[
            input_encoded_string_box,
            input_xml_string_box,
            input_svg_preview,
        ],
    )
    # The encoded string and XML input boxes keep each other (and the
    # preview) in sync as the user edits either one.
    input_encoded_string_box.input(
        fn=input_encoded_string_box_change,
        inputs=[input_encoded_string_box],
        outputs=[input_xml_string_box, input_svg_preview],
    )
    input_xml_string_box.input(
        fn=input_xml_string_box_change,
        inputs=[input_xml_string_box],
        outputs=[input_encoded_string_box, input_svg_preview],
    )
    # The button uses the streaming handler wx_stream (wx_generate is not
    # wired here); api_name still exposes it as "wx_generate".
    wx_generate_btn.click(
        fn=wx_stream,
        inputs=[
            wx_engine_state,
            wx_models_drop,
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            input_xml_string_box,
        ],
        outputs=[
            wx_status_box,
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
        api_name="wx_generate",
    )
    # NOTE(review): output_xml_string_box is both the input and an output of
    # this handler, so each edit echoes the same text back into the box being
    # typed in — confirm this does not disturb typing.
    output_xml_string_box.input(
        fn=output_xml_string_box_change,
        inputs=[output_xml_string_box],
        outputs=[
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
    )
# Main
def _main():
    """Start the Gradio server for the ``demo`` app built above."""
    demo.launch()


if __name__ == "__main__":
    _main()