# kohya_ss / kohya_gui / manual_caption_gui.py
# zengxi123's picture
# Upload folder using huggingface_hub
# fb83c5b verified
import gradio as gr
from easygui import msgbox, boolbox
from .common_gui import get_folder_path, scriptdir, list_dirs
from math import ceil
import os
import re
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
# Number of image rows built in the tab and shown per page; also the divisor
# used when computing the page count
IMAGES_TO_SHOW = 5
# File name suffixes treated as images; compared against the lower-cased name
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
# NOTE(review): every function in this file that uses "auto_save" takes it as
# a parameter or rebinds it locally, so this module-level value looks unused
# here - confirm no other module imports it before removing
auto_save = True
def _get_caption_path(image_file, images_dir, caption_ext):
    """
    Returns the expected path of a caption file for a given image path

    image_file: bare image file name, or an image path that already
        contains its folder
    images_dir: folder the image lives in; only used for bare file names
    caption_ext: text appended in place of the image extension (e.g. ".txt")
    """
    caption_file_name = os.path.splitext(image_file)[0] + caption_ext
    # The caption UI hands back the full image path (images_dir already
    # joined in). Joining a relative images_dir a second time would yield
    # "dir/dir/name.txt", so a path that carries its own folder is used as is.
    if os.path.dirname(caption_file_name):
        return caption_file_name
    return os.path.join(images_dir, caption_file_name)
def _get_quick_tags(quick_tags_text):
    """
    Splits the quick tags text box on commas and returns the non-empty,
    trimmed tags both as an ordered list and as a set for lookups
    """
    ordered_tags = []
    for piece in quick_tags_text.split(","):
        cleaned = piece.strip()
        if cleaned:
            ordered_tags.append(cleaned)
    return ordered_tags, set(ordered_tags)
def _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set):
    """
    Builds the checkbox group update for one caption: the choices are the
    quick tags followed by any caption tags that are not quick tags, and the
    checked values are the tags found in the caption
    """
    present_tags = []
    for part in caption.split(","):
        part = part.strip()
        if part:
            present_tags.append(part)
    extra_tags = [tag for tag in present_tags if tag not in quick_tags_set]
    return gr.CheckboxGroup(choices=quick_tags + extra_tags, value=present_tags)
def paginate_go(page, max_page):
    """
    Jumps to the page typed into the "Goto page" box

    page: raw text from the goto box
    max_page: highest valid page number

    Returns the requested page clamped by paginate. When the text is not a
    usable number a message box is shown and an argument-less gr.Number() is
    returned so the current page is left as it is.
    """
    try:
        requested = float(page)
    except (TypeError, ValueError):
        requested = None
    # "nan" parses as a float but cannot be turned into a page number
    # (NaN is the only value that is not equal to itself)
    if requested is None or requested != requested:
        msgbox(f"Invalid page num: {page}")
        # Same "no update" convention import_tags_from_captions uses with
        # gr.Text(); returning None would blank the hidden page number and
        # break the listeners that call int(page)
        return gr.Number()
    return paginate(requested, max_page, 0)
def paginate(page, max_page, page_change):
    """
    Moves the page by page_change and returns it as an int, capped at
    max_page first and then raised to at least 1
    """
    target = page + page_change
    capped = min(target, max_page)
    floored = max(capped, 1)
    return int(floored)
def save_caption(caption, caption_ext, image_file, images_dir):
    """
    Writes the caption text (UTF-8) to the caption file that belongs to the
    given image, replacing any previous content, and logs the path written
    """
    target_path = _get_caption_path(image_file, images_dir, caption_ext)
    with open(target_path, "w+", encoding="utf-8") as caption_file:
        caption_file.write(caption)
    log.info(f"Wrote captions to {target_path}")
def update_quick_tags(quick_tags_text, *image_caption_texts):
    """
    Rebuilds the tag checkbox group of every caption after the quick tags
    text changed; returns one update per caption, in the order given
    """
    tag_list, tag_lookup = _get_quick_tags(quick_tags_text)
    checkbox_updates = []
    for caption_text in image_caption_texts:
        checkbox_updates.append(
            _get_tag_checkbox_updates(caption_text, tag_list, tag_lookup)
        )
    return checkbox_updates
def update_image_caption(
    quick_tags_text, caption, image_file, images_dir, caption_ext, auto_save
):
    """
    Handles an edit of a caption text box: writes the caption to disk when
    auto_save is on, then returns the matching tag checkbox group update
    """
    if auto_save:
        save_caption(caption, caption_ext, image_file, images_dir)
    tag_list, tag_lookup = _get_quick_tags(quick_tags_text)
    checkbox_update = _get_tag_checkbox_updates(caption, tag_list, tag_lookup)
    return checkbox_update
def update_image_tags(
    quick_tags_text,
    selected_tags,
    image_file,
    images_dir,
    caption_ext,
    auto_save,
):
    """
    Handles a change of the tag checkboxes: rebuilds the caption from the
    checked tags, writes it to disk when auto_save is on, and returns it
    """
    quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text)
    checked = set(selected_tags)

    # Checked quick tags come first, in quick tag order ...
    ordered_tags = []
    for tag in quick_tags:
        if tag in checked:
            ordered_tags.append(tag)
    # ... followed by the remaining checked tags in the order received
    for tag in selected_tags:
        if tag not in quick_tags_set:
            ordered_tags.append(tag)

    caption = ", ".join(ordered_tags)
    if auto_save:
        save_caption(caption, caption_ext, image_file, images_dir)
    return caption
def import_tags_from_captions(
    images_dir, caption_ext, quick_tags_text, ignore_load_tags_word_count
):
    """
    Scans images directory for all available captions and loads all tags
    under a specified word count into the quick tags box

    images_dir: folder holding the images and their caption files
    caption_ext: extension of the caption files (e.g. ".txt")
    quick_tags_text: current quick tags; when not empty the user is asked
        to confirm before they are overwritten
    ignore_load_tags_word_count: tags with more words than this are skipped

    Returns the imported tags joined with ", ", or an argument-less
    gr.Text() (no update) when validation fails or the user declines.
    """

    def empty_return():
        return gr.Text()

    # Check for images_dir
    if not images_dir:
        msgbox("Image folder is missing...")
        return empty_return()
    # isdir rather than exists: a path to a regular file would pass an
    # existence check and then make os.listdir raise
    if not os.path.isdir(images_dir):
        msgbox("Image folder does not exist...")
        return empty_return()
    if not caption_ext:
        msgbox("Please provide an extension for the caption files.")
        return empty_return()
    if quick_tags_text and not boolbox(
        "Are you sure you wish to overwrite the current quick tags?",
        choices=("Yes", "No"),
    ):
        return empty_return()

    images_list = os.listdir(images_dir)
    image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]

    # Use a set for lookup but store order with list
    tags = []
    tags_set = set()
    for image_file in image_files:
        caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
        if not os.path.exists(caption_file_path):
            continue
        with open(caption_file_path, "r", encoding="utf-8") as f:
            caption = f.read()
        for tag in caption.split(","):
            tag = tag.strip()
            # An empty caption file, a trailing comma or a double comma
            # produces an empty piece; it must not become a blank quick tag
            if not tag:
                continue
            # Tags that differ only in letter case count as duplicates
            tag_key = tag.lower()
            if tag_key in tags_set:
                continue
            # Ignore extra spaces
            total_words = len(re.findall(r"\s+", tag)) + 1
            if total_words <= ignore_load_tags_word_count:
                tags.append(tag)
                tags_set.add(tag_key)
    return ", ".join(tags)
def load_images(images_dir, caption_ext, loaded_images_dir, page, max_page):
    """
    Triggered to load a new set of images from the folder to caption
    This loads in the total expected image counts to be used by pagination
    before running update_images

    images_dir: folder selected in the UI
    caption_ext: extension of the caption files; only checked for presence
    loaded_images_dir, page, max_page: current state, handed back unchanged
        when validation fails

    Returns [loaded images dir, page, max page]; on success the page is
    reset to 1 and max page is the image count divided by IMAGES_TO_SHOW,
    rounded up.
    """

    def empty_return():
        return [loaded_images_dir, page, max_page]

    # Check for images_dir
    if not images_dir:
        msgbox("Image folder is missing...")
        return empty_return()
    # isdir rather than exists: a path to a regular file would pass an
    # existence check and then make os.listdir raise
    if not os.path.isdir(images_dir):
        msgbox("Image folder does not exist...")
        return empty_return()
    if not caption_ext:
        msgbox("Please provide an extension for the caption files.")
        return empty_return()

    # Load Images
    total_images = sum(
        1 for f in os.listdir(images_dir) if f.lower().endswith(IMAGE_EXTENSIONS)
    )
    return [images_dir, 1, ceil(total_images / IMAGES_TO_SHOW)]
def update_images(
    images_dir,
    caption_ext,
    quick_tags_text,
    page,
):
    """
    Updates the displayed images and captions from the current page and
    image directory

    images_dir: folder the images were loaded from
    caption_ext: extension of the caption files
    quick_tags_text: quick tags text box content (may be empty or None)
    page: 1-based page number to display

    Returns one flat list: IMAGES_TO_SHOW row updates, the image paths
    twice (the tab binds them to the hidden file text boxes and to the image
    components), the captions, the tag checkbox updates, and finally two
    visible rows for the pagination bars.
    """
    # Load Images
    images_list = os.listdir(images_dir)
    image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]

    # Quick tags
    quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or "")

    # Display Images
    rows = []
    image_paths = []
    captions = []
    tag_checkbox_groups = []

    start_index = (int(page) - 1) * IMAGES_TO_SHOW
    for i in range(IMAGES_TO_SHOW):
        image_index = start_index + i
        # The lower bound matters when page is below 1: a negative index
        # would otherwise wrap around to the end of the list or raise
        show_row = 0 <= image_index < len(image_files)

        image_path = None
        caption = ""
        if show_row:
            image_file = image_files[image_index]
            image_path = os.path.join(images_dir, image_file)

            caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
            if os.path.exists(caption_file_path):
                with open(caption_file_path, "r", encoding="utf-8") as f:
                    caption = f.read()

        # Hidden rows get an update built from an empty caption, so their
        # choices still follow the quick tags
        tag_checkboxes = _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set)

        rows.append(gr.Row(visible=show_row))
        image_paths.append(image_path)
        captions.append(caption)
        tag_checkbox_groups.append(tag_checkboxes)

    return (
        rows
        + image_paths
        + image_paths
        + captions
        + tag_checkbox_groups
        + [gr.Row(visible=True), gr.Row(visible=True)]
    )
# Gradio UI
def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
    """
    Builds the "Manual Captioning" tab and wires up all of its events

    headless: when True the folder picker button is hidden
    default_images_dir: folder whose sub-folders seed the image folder
        dropdown; defaults to the "data" folder under scriptdir
    """
    from .common_gui import create_refresh_button

    default_images_dir = (
        default_images_dir
        if default_images_dir is not None
        else os.path.join(scriptdir, "data")
    )
    current_images_dir = default_images_dir

    # Function to list directories
    def list_images_dirs(path):
        # Allows list_images_dirs to modify current_images_dir outside of this function
        nonlocal current_images_dir
        current_images_dir = path
        return list(list_dirs(path))

    with gr.Tab("Manual Captioning"):
        gr.Markdown("This utility allows quick captioning and tagging of images.")
        # Hidden state: current page, page count and the folder that was
        # actually loaded (which can differ from the dropdown value)
        page = gr.Number(value=-1, visible=False)
        max_page = gr.Number(value=1, visible=False)
        loaded_images_dir = gr.Text(visible=False)
        with gr.Group(), gr.Row():
            images_dir = gr.Dropdown(
                label="Image folder to caption (containing the images to caption)",
                choices=[""] + list_images_dirs(default_images_dir),
                value="",
                interactive=True,
                allow_custom_value=True,
            )
            create_refresh_button(
                images_dir,
                lambda: None,
                lambda: {"choices": list_images_dirs(current_images_dir)},
                "open_folder_small",
            )
            folder_button = gr.Button(
                "📂",
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            folder_button.click(
                get_folder_path,
                outputs=images_dir,
                show_progress=False,
            )
            load_images_button = gr.Button("Load", elem_id="open_folder")
            caption_ext = gr.Dropdown(
                label="Caption file extension",
                choices=[".cap", ".caption", ".txt"],
                value=".txt",
                interactive=True,
                allow_custom_value=True,
            )
            auto_save = gr.Checkbox(
                label="Autosave", info="Options", value=True, interactive=True
            )

            # Refresh the dropdown choices whenever the selected folder changes
            images_dir.change(
                fn=lambda path: gr.Dropdown(choices=[""] + list_images_dirs(path)),
                inputs=images_dir,
                outputs=images_dir,
                show_progress=False,
            )

        # Caption Section
        with gr.Group(), gr.Row():
            quick_tags_text = gr.Textbox(
                label="Quick Tags",
                placeholder="Comma separated list of tags",
                interactive=True,
            )
            import_tags_button = gr.Button("Import", elem_id="open_folder")
            ignore_load_tags_word_count = gr.Slider(
                minimum=1,
                maximum=100,
                value=3,
                step=1,
                label="Ignore Imported Tags Above Word Count",
                interactive=True,
            )

        # Next/Prev section generator
        def render_pagination():
            # The hidden Number components carry the constant page step
            # (-1 / +1) that paginate receives as its third input
            gr.Button("< Prev", elem_id="open_folder").click(
                paginate,
                inputs=[page, max_page, gr.Number(value=-1, visible=False)],
                outputs=[page],
            )
            page_count = gr.Label("Page 1", label="Page")
            page_goto_text = gr.Textbox(
                label="Goto page",
                placeholder="Page Number",
                interactive=True,
            )
            gr.Button("Go >", elem_id="open_folder").click(
                paginate_go,
                inputs=[page_goto_text, max_page],
                outputs=[page],
            )
            gr.Button("Next >", elem_id="open_folder").click(
                paginate,
                inputs=[page, max_page, gr.Number(value=1, visible=False)],
                outputs=[page],
            )
            return page_count

        with gr.Row(visible=False) as pagination_row1:
            page_count1 = render_pagination()

        # Images section
        image_rows = []
        image_files = []
        image_images = []
        image_caption_texts = []
        image_tag_checks = []
        save_buttons = []
        for _ in range(IMAGES_TO_SHOW):
            with gr.Row(visible=False) as row:
                image_file = gr.Text(visible=False)
                image_files.append(image_file)
                image_image = gr.Image(type="filepath")
                image_images.append(image_image)
                image_caption_text = gr.TextArea(
                    label="Captions",
                    placeholder="Input captions",
                    interactive=True,
                )
                image_caption_texts.append(image_caption_text)
                tag_checkboxes = gr.CheckboxGroup([], label="Tags", interactive=True)
                save_button = gr.Button(
                    "💾",
                    elem_id="open_folder_small",
                    elem_classes=["tool"],
                    visible=False,
                )
                save_buttons.append(save_button)

                # Caption text change
                image_caption_text.input(
                    update_image_caption,
                    inputs=[
                        quick_tags_text,
                        image_caption_text,
                        image_file,
                        loaded_images_dir,
                        caption_ext,
                        auto_save,
                    ],
                    outputs=tag_checkboxes,
                )

                # Quick tag check
                tag_checkboxes.input(
                    update_image_tags,
                    inputs=[
                        quick_tags_text,
                        tag_checkboxes,
                        image_file,
                        loaded_images_dir,
                        caption_ext,
                        auto_save,
                    ],
                    outputs=[image_caption_text],
                )

                # Save Button
                # Uses loaded_images_dir like the auto-save handlers above:
                # the dropdown value may have changed since the images were
                # loaded, which would send the caption to another folder
                save_button.click(
                    save_caption,
                    inputs=[
                        image_caption_text,
                        caption_ext,
                        image_file,
                        loaded_images_dir,
                    ],
                )

                image_tag_checks.append(tag_checkboxes)
                image_rows.append(row)

        # Next/Prev Section
        with gr.Row(visible=False) as pagination_row2:
            page_count2 = render_pagination()

        # Quick tag text update
        quick_tags_text.change(
            update_quick_tags,
            inputs=[quick_tags_text] + image_caption_texts,
            outputs=image_tag_checks,
        )

        # Import tags button
        import_tags_button.click(
            import_tags_from_captions,
            inputs=[
                loaded_images_dir,
                caption_ext,
                quick_tags_text,
                ignore_load_tags_word_count,
            ],
            outputs=quick_tags_text,
        )

        # Load Images button
        load_images_button.click(
            load_images,
            inputs=[
                images_dir,
                caption_ext,
                loaded_images_dir,
                page,
                max_page,
            ],
            outputs=[loaded_images_dir, page, max_page],
        )

        # Update images shown when the update key changes
        # This allows us to trigger a change from multiple
        # sources (page, image_dir)
        image_update_key = gr.Text(visible=False)
        image_update_key.change(
            update_images,
            inputs=[loaded_images_dir, caption_ext, quick_tags_text, page],
            outputs=image_rows
            + image_files
            + image_images
            + image_caption_texts
            + image_tag_checks
            + [pagination_row1, pagination_row2],
            show_progress=False,
        )

        # Update the key on page and image dir change
        listener_kwargs = {
            "fn": lambda p, i: f"{p}-{i}",
            "inputs": [page, loaded_images_dir],
            "outputs": image_update_key,
        }
        page.change(**listener_kwargs)
        loaded_images_dir.change(**listener_kwargs)

        # Save buttons visibility
        # (on auto-save on/off)
        auto_save.change(
            lambda auto_save: [gr.Button(visible=not auto_save)] * IMAGES_TO_SHOW,
            inputs=auto_save,
            outputs=save_buttons,
        )

        # Page Count
        page.change(
            lambda page, max_page: [f"Page {int(page)} / {int(max_page)}"] * 2,
            inputs=[page, max_page],
            outputs=[page_count1, page_count2],
            show_progress=False,
        )