# Hugging Face Spaces page header (scrape residue): "Spaces: Running"
# Remi Serra 202407 | |
from env_utils import load_credentials | |
import gradio as gr | |
from random import randrange | |
from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string | |
from watsonx_utils import wxEngine | |
from prompts import ( | |
list_prompts, | |
get_prompt_template, | |
get_prompt_example, | |
get_prompt_primer, | |
get_prompt_uploadmode, | |
) | |
from data_images import svg_three_dots | |
from ibm_watsonx_ai.wml_client_error import WMLClientError | |
# Functions - input | |
def read_file(uploaded_file):
    """Load an uploaded SVG file and derive its encoded string and preview.

    Args:
        uploaded_file: Path of the uploaded file, as handed over by the
            upload event. Falsy when nothing was uploaded.

    Returns:
        A tuple ``(encoded_data_string, svg_xml, preview_html)``, or ``None``
        when ``uploaded_file`` is falsy.
    """
    if not uploaded_file:
        return None
    # The context manager closes the handle deterministically (the previous
    # bare open().read() leaked it); UTF-8 is stated explicitly so decoding
    # does not depend on the platform's default locale.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    encoded_data_string = xml_string_to_data_string(svg_xml)
    return (
        encoded_data_string,
        svg_xml,
        html_img_preview(encoded_data_string),
    )
def input_encoded_string_box_change(data_string: str):
    """React to an edit of the encoded image string.

    Returns:
        A pair ``(svg_xml, preview_html)``: the XML decoded from
        ``data_string`` and the ``<img>`` preview built from ``data_string``.
    """
    # Tuple items are evaluated left to right: decode first, then preview.
    return (
        decode_b64_string_to_pretty_xml(data_string),
        html_img_preview(data_string),
    )
def input_xml_string_box_change(svg_xml: str):
    """React to an edit of the SVG XML text.

    Returns:
        A pair ``(data_string, preview_html)``: the encoded data string for
        ``svg_xml`` and the ``<img>`` preview built from that string.
    """
    encoded = xml_string_to_data_string(svg_xml)
    preview = html_img_preview(encoded)
    return encoded, preview
def xml_string_to_data_string(svg_xml: str):
    """Turn SVG XML into a ``data:image/svg+xml;base64,...`` string.

    The payload comes from ``encode_svg_xml_to_b64_string`` and is appended
    to the fixed data-URL prefix.
    """
    prefix = "data:image/svg+xml;base64,"
    payload = encode_svg_xml_to_b64_string(svg_xml)
    return prefix + payload
def html_img_preview(data_string):
    """Build the HTML ``<img>`` snippet used to preview an image data string.

    Args:
        data_string: Value placed in the ``src`` attribute, typically a
            ``data:image/svg+xml;base64,...`` string.

    Returns:
        The ``<img>`` markup as a string (100px wide, horizontally centred).
    """
    # Local import keeps this block self-contained.
    from html import escape

    # The value can come straight from a user-editable text box, so it is
    # escaped before being embedded in the attribute. Base64 data strings
    # contain no escapable characters and are therefore emitted unchanged.
    safe_src = escape(str(data_string), quote=True)
    return f'<img src="{safe_src}" width="100px" style="display: block; margin-left: auto; margin-right: auto;"/>'
# Default preview markup, built from svg_three_dots (imported from
# data_images); used as the initial content of both preview panes.
image_placeholder = html_img_preview(svg_three_dots)
# Functions - watsonx | |
def wx_prompt_drop_change(prompt_template_name):
    """Refresh the prompt widgets after a prompt template is selected.

    Returns:
        A dict keyed by component: the prompt, instructions and primer boxes
        receive the template's text, example and primer; the upload row and
        upload accordion are shown or hidden according to the template's
        upload mode.
    """
    upload_visible = get_prompt_uploadmode(prompt_template_name)
    updates = {}
    updates[wx_prompt_box] = get_prompt_template(prompt_template_name)
    updates[wx_instructions_box] = get_prompt_example(prompt_template_name)
    updates[wx_primer_box] = get_prompt_primer(prompt_template_name)
    updates[upload_row] = gr.Row(visible=upload_visible)
    updates[upload_accordeon] = gr.Accordion(visible=upload_visible)
    return updates
def wx_models_dropdown(wx_engine_state):
    """Build the "Model" dropdown from the models listed by the engine.

    Args:
        wx_engine_state: Engine object providing ``list_models()``, or
            ``None`` when no engine is connected yet.

    Returns:
        A ``gr.Dropdown`` whose choices are the engine's models. The
        recommended model is preselected when it is listed, otherwise the
        first listed model. With no engine or no models, the dropdown has no
        choices and no preselected value.
    """
    recommended_model = "ibm/granite-20b-code-instruct"
    model_list = []
    default_value = None
    if wx_engine_state is not None:
        # Tolerate an engine that reports no models at all.
        model_list = wx_engine_state.list_models() or []
        if recommended_model in model_list:
            default_value = recommended_model
        elif model_list:
            # Only index when there is something to index: an empty listing
            # previously raised IndexError here.
            default_value = model_list[0]
    return gr.Dropdown(
        label="Model",
        info=recommended_model + " recommended",
        choices=model_list,
        value=default_value,
    )
def wx_connect_click(wx_engine_state, apiendpoint, apikey, projectid):
    """Connect to watsonx.ai and update the UI according to the outcome.

    Args:
        wx_engine_state: Current engine held in the session state; returned
            unchanged when the connection attempt fails.
        apiendpoint: Endpoint passed to ``wxEngine``.
        apikey: API key passed to ``wxEngine``.
        projectid: Project id passed to ``wxEngine``.

    Returns:
        A 5-tuple matching the outputs wired to the Connect button:
        ``(engine_state, models_dropdown, credentials_accordion,
        generate_button, status_message)``.
        On success the accordion is closed and the Generate button enabled;
        on ``WMLClientError`` the accordion stays open, the button is
        disabled and the model dropdown is emptied.
    """
    try:
        wx_engine_state = wxEngine(apiendpoint, apikey, projectid)
        msg = "watsonx.ai sucessfully activated"
        print(msg)
        return (
            wx_engine_state,
            wx_models_dropdown(wx_engine_state),
            gr.Accordion(open=False),
            gr.Button(interactive=True),
            # The status output is a Markdown component: return its value
            # directly instead of wrapping it in an unrelated Textbox.
            msg,
        )
    except WMLClientError as ex:
        msg = f"Exception {type(ex).__name__} occurred: {ex.args!r}"
        print(msg)
        return (
            wx_engine_state,
            # Return a proper (empty) dropdown rather than a bare list, which
            # is not a valid value for the single-select model dropdown.
            wx_models_dropdown(None),
            gr.Accordion(open=True),
            gr.Button(interactive=False),
            msg,
        )
def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Assemble the final prompt and the generation budget.

    Args:
        wx_engine_state: Engine object providing ``get_model_max_tokens`` and
            ``get_prompt_nb_tokens``.
        wx_model: Model identifier passed to both engine calls.
        wx_prompt: Template formatted with the ``svg`` and ``instructions``
            fields.
        wx_instructions: Text substituted for ``{instructions}``.
        wx_primer: Text appended after the formatted template.
        xml_string: SVG XML substituted for ``{svg}``.

    Returns:
        A tuple ``(wx_status, max_new_tokens, prompt)``. ``wx_status`` is
        ``"Done."`` or a warning when the prompt exceeds the model's maximum
        token count; ``max_new_tokens`` is ``max(500, len(xml_string))``.
    """
    # Optional text inputs may arrive as None (e.g. an empty field sent by an
    # API client). Treat them as empty strings so that len() does not crash
    # and the literal text "None" does not end up in the prompt.
    wx_instructions = "" if wx_instructions is None else wx_instructions
    wx_primer = "" if wx_primer is None else wx_primer
    xml_string = "" if xml_string is None else xml_string

    wx_status = "Done."
    wx_engine = wx_engine_state
    # get model specs
    model_max_tokens = wx_engine.get_model_max_tokens(wx_model)
    # Add "primer" at the end of the prompt
    prompt = wx_prompt.format(svg=xml_string, instructions=wx_instructions) + wx_primer
    # Test and alert if prompt is too long
    prompt_nb_tokens = wx_engine.get_prompt_nb_tokens(prompt, wx_model)
    if prompt_nb_tokens > model_max_tokens:
        wx_status = f"Warning: prompt length ({prompt_nb_tokens}) is more than the model max tokens ({model_max_tokens}), and will be truncated. Please review your instructions."
        print(wx_status)
    # calculate max new token based on xml_string - or 500 when original string is too small
    # note: prompt will be truncated if too long with GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS in generate()
    max_new_tokens = max(500, len(xml_string))
    return wx_status, max_new_tokens, prompt
def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate the SVG in a single, non-streaming call.

    Returns:
        A tuple ``(wx_status, data_string, wx_result, preview_html)`` where
        ``wx_result`` is the primer followed by the generated text.
    """
    status_msg, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    completion = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
    )
    # The primer is part of the expected output, so it is put back in front.
    wx_result = wx_primer + completion
    print(f"wx_generate:wx_result:{wx_result}")
    encoded = xml_string_to_data_string(wx_result)
    return status_msg, encoded, wx_result, html_img_preview(encoded)
def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate the SVG incrementally, yielding intermediate results.

    Yields:
        While chunks arrive: ``(status, status, partial_result, None)`` with a
        "Processing." status followed by 0-2 extra dots.
        Finally: ``(wx_status, data_string, wx_result, preview_html)``.
    """
    status_msg, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    # https://www.gradio.app/guides/streaming-outputs
    chunks = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
        stream=True,
    )
    # The result starts with the primer and grows with every chunk.
    accumulated = wx_primer
    for piece in chunks:
        accumulated = accumulated + piece
        progress = "Processing." + "." * int(randrange(3))
        yield progress, progress, accumulated, None
    print(f"wx_stream:wx_result:{accumulated}")
    encoded = xml_string_to_data_string(accumulated)
    yield status_msg, encoded, accumulated, html_img_preview(encoded)
# Functions - output | |
def output_xml_string_box_change(xml_string):
    """React to an edit of the result SVG XML.

    Returns:
        A tuple ``(data_string, xml_string, preview_html)``; ``xml_string``
        is passed back unchanged.
    """
    encoded = xml_string_to_data_string(xml_string)
    preview = html_img_preview(encoded)
    return encoded, xml_string, preview
# APP layout | |
with gr.Blocks(theme = "Zarkel/IBM_Carbon_Theme") as demo:
    # NOTE(review): the container nesting below was reconstructed from a copy
    # of this file whose indentation had been lost. Confirm which components
    # belong to which Row/Column/Accordion (notably the Connect button, the
    # result preview/XML boxes and the status line).
    gr.Markdown("# SVG editor")
    gr.Markdown(
        """### Get started:
- Create a new SVG: Enter a description in the 'Instructions' box and click 'Submit'
- Modify an existing SVG: Upload an SVG file, or paste an image string or SVG XML, then Select the prompt template 'Modify SVG', enter a change instruction in the 'Instructions' box and click 'Submit'"""
    )
    # init state - note gr.State() initial value must be deep-copyable - my wx_engine class is not
    # The state therefore starts as None and is filled in by wx_connect_click.
    wx_engine_state = gr.State(None)
    with gr.Row():  # main UI
        with gr.Column(scale=0):  # watsonx setup
            # prompt template selection
            prompt_template_names = list_prompts()
            # The first listed template is the default selection.
            default_prompt_template_name = prompt_template_names[0]
            wx_prompt_drop = gr.Dropdown(
                label="Action",
                choices=prompt_template_names,
                value=default_prompt_template_name,
            )
            # credentials
            # load env variables
            # The first returned item is not used by the layout.
            status_unused, env_apiendpoint, env_apikey, env_projectid = (
                load_credentials()
            )
            with gr.Accordion("Credentials", open=True) as credentials_accordeon:
                wx_creds_endpoint = gr.Textbox(
                    label="Endpoint",
                    # Falls back to the us-south endpoint when none is loaded.
                    value=env_apiendpoint or "https://us-south.ml.cloud.ibm.com",
                    max_lines=1,
                )
                wx_creds_apikey = gr.Textbox(
                    label="API key", value=env_apikey, max_lines=1
                )
                wx_creds_projectid = gr.Textbox(
                    label="Project id", value=env_projectid, max_lines=1
                )
                wx_connect_btn = gr.Button("Connect")
            # model
            # Built without an engine: no choices until Connect succeeds.
            wx_models_drop = wx_models_dropdown(None)
            # prompt text and primer
            wx_prompt_box = gr.Textbox(
                info="Text",
                show_label=False,
                max_lines=5,
                value=get_prompt_template(prompt_template_names[0]),
            )
            wx_primer_box = gr.Textbox(
                info="Primer",
                show_label=False,
                max_lines=2,
                value=get_prompt_primer(prompt_template_names[0]),
            )
        with gr.Column():  # main pane
            # Upload
            # Initial visibility follows the default template's upload mode.
            with gr.Row(
                visible=get_prompt_uploadmode(default_prompt_template_name)
            ) as upload_row:
                # Upload an .svg file
                input_file = gr.File(scale=0, label="Upload an SVG file")
                # original preview
                input_svg_preview = gr.HTML(image_placeholder)
                # decoded SVG XML
                input_xml_string_box = gr.Textbox(
                    label="Input SVG XML",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
            with gr.Accordion(
                label="Input encoded string",
                open=False,
                visible=get_prompt_uploadmode(default_prompt_template_name),
            ) as upload_accordeon:
                # Encoded image string
                input_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )
            # modification
            with gr.Row():
                wx_instructions_box = gr.Textbox(
                    label="Instructions",
                    scale=3,
                    value=get_prompt_example(prompt_template_names[0]),
                    show_copy_button=True,
                )
                # Starts disabled; wx_connect_click returns the button update
                # that enables it.
                wx_generate_btn = gr.Button("↓Generate↓", scale=0, interactive=False)
            output_svg_preview = gr.HTML(image_placeholder)
            output_xml_string_box = gr.Textbox(
                label="Result SVG XML",
                lines=7,
                max_lines=7,
                scale=3,
                show_copy_button=True,
            )
            with gr.Accordion(label="Result encoded string", open=False):
                output_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )
            wx_status_box = gr.Markdown("Status")
    # Map controls to functions
    # Template selection: refresh prompt/instructions/primer and toggle the
    # upload widgets.
    wx_prompt_drop.input(
        fn=wx_prompt_drop_change,
        inputs=wx_prompt_drop,
        outputs=[
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            upload_row,
            upload_accordeon,
        ],
    )
    # Connect: the five outputs match, in order, the tuple returned by
    # wx_connect_click.
    wx_connect_btn.click(
        fn=wx_connect_click,
        inputs=[
            wx_engine_state,
            wx_creds_endpoint,
            wx_creds_apikey,
            wx_creds_projectid,
        ],
        outputs=[
            wx_engine_state,
            wx_models_drop,
            credentials_accordeon,
            wx_generate_btn,
            wx_status_box,
        ],
    )
    # File upload: fills the encoded string, the XML box and the preview.
    input_file.upload(
        fn=read_file,
        inputs=input_file,
        outputs=[
            input_encoded_string_box,
            input_xml_string_box,
            input_svg_preview,
        ],
    )
    # The encoded-string box and the XML box update each other.
    # NOTE(review): presumably `.input` (rather than `.change`) is used so
    # that these programmatic updates do not re-trigger one another - confirm
    # against the Gradio event listener documentation.
    input_encoded_string_box.input(
        fn=input_encoded_string_box_change,
        inputs=[input_encoded_string_box],
        outputs=[input_xml_string_box, input_svg_preview],
    )
    input_xml_string_box.input(
        fn=input_xml_string_box_change,
        inputs=[input_xml_string_box],
        outputs=[input_encoded_string_box, input_svg_preview],
    )
    # Generate: bound to the streaming handler wx_stream, while the endpoint
    # is exposed under the API name "wx_generate".
    wx_generate_btn.click(
        fn=wx_stream,
        inputs=[
            wx_engine_state,
            wx_models_drop,
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            input_xml_string_box,
        ],
        outputs=[
            wx_status_box,
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
        api_name="wx_generate",
    )
    # Manual edits of the result XML refresh the encoded string and preview;
    # the XML box itself is also listed as an output.
    output_xml_string_box.input(
        fn=output_xml_string_box_change,
        inputs=[output_xml_string_box],
        outputs=[
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
    )
# Main | |
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()