# kohya_ss/kohya_gui/blip2_caption_gui.py
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import gradio as gr
import os
from .common_gui import get_folder_path, scriptdir, list_dirs
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()


def load_model():
# Set the device to GPU if available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize the BLIP2 processor
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # Initialize the BLIP2 model; use float16 only on GPU, since several
    # float16 ops are unsupported or very slow on CPU
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    # Move the model to the selected device
    model.to(device)
return processor, model, device
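

# Usage sketch (hypothetical image name; assumes the checkpoint can be
# fetched from the Hugging Face Hub):
#   processor, model, device = load_model()
#   image = Image.open("example.jpg").convert("RGB")
#   inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
#   ids = model.generate(**inputs, max_new_tokens=40)
#   print(processor.batch_decode(ids, skip_special_tokens=True)[0].strip())

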
def get_images_in_directory(directory_path):
"""
Returns a list of image file paths found in the provided directory path.
Parameters:
- directory_path: A string representing the path to the directory to search for images.
Returns:
- A list of strings, where each string is the full path to an image file found in the specified directory.
"""
    # Common image file extensions to look for
    image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
    # Build the full path for every file in the directory whose
    # extension (lowercased) matches one of the extensions above
    image_files = [
        os.path.join(directory_path, file)
        for file in os.listdir(directory_path)
        if os.path.splitext(file)[1].lower() in image_extensions
    ]
# Return the list of image file paths
return image_files
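

# Example (hypothetical directory): get_images_in_directory("data/pics")
# might return ["data/pics/cat.jpg", "data/pics/dog.png"]; matching is by
# extension only, so subdirectories and non-image files are skipped.

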
def generate_caption(
file_list,
processor,
model,
device,
caption_file_ext=".txt",
num_beams=5,
repetition_penalty=1.5,
length_penalty=1.2,
max_new_tokens=40,
min_new_tokens=20,
do_sample=True,
temperature=1.0,
top_p=0.0,
):
"""
Fetches and processes each image in file_list, generates captions based on the image, and writes the generated captions to a file.
Parameters:
- file_list: A list of file paths pointing to the images to be captioned.
- processor: The preprocessor for the BLIP2 model.
- model: The BLIP2 model to be used for generating captions.
- device: The device on which the computation is performed.
- extension: The extension for the output text files.
- num_beams: Number of beams for beam search. Default: 5.
- repetition_penalty: Penalty for repeating tokens. Default: 1.5.
- length_penalty: Penalty for sentence length. Default: 1.2.
- max_new_tokens: Maximum number of new tokens to generate. Default: 40.
- min_new_tokens: Minimum number of new tokens to generate. Default: 20.
"""
    for file_path in file_list:
        # Convert to RGB so grayscale and RGBA images are handled uniformly
        image = Image.open(file_path).convert("RGB")
        # Match the input dtype to the model (float16 on GPU, float32 on CPU)
        inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
        # top_p == 0.0 selects the beam-search branch; anything else samples
        if top_p == 0.0:
generated_ids = model.generate(
**inputs,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
)
else:
generated_ids = model.generate(
**inputs,
do_sample=do_sample,
top_p=top_p,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
temperature=temperature,
)
generated_text = processor.batch_decode(
generated_ids, skip_special_tokens=True
)[0].strip()
# Construct the output file path by replacing the original file extension with the specified extension
output_file_path = os.path.splitext(file_path)[0] + caption_file_ext
# Write the generated text to the output file
with open(output_file_path, "w", encoding="utf-8") as output_file:
output_file.write(generated_text)
        # Log that a caption was written for this image
        log.info(f"{file_path} caption was generated")
def caption_images_beam_search(
directory_path,
num_beams,
repetition_penalty,
length_penalty,
min_new_tokens,
max_new_tokens,
caption_file_ext,
):
"""
Captions all images in the specified directory using the provided prompt.
Parameters:
- directory_path: A string representing the path to the directory containing the images to be captioned.
"""
log.info("BLIP2 captionning beam...")
if not os.path.isdir(directory_path):
log.error(f"Directory {directory_path} does not exist.")
return
processor, model, device = load_model()
image_files = get_images_in_directory(directory_path)
generate_caption(
file_list=image_files,
processor=processor,
model=model,
device=device,
num_beams=int(num_beams),
repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
min_new_tokens=int(min_new_tokens),
max_new_tokens=int(max_new_tokens),
caption_file_ext=caption_file_ext,
)


def caption_images_nucleus(
directory_path,
do_sample,
temperature,
top_p,
min_new_tokens,
max_new_tokens,
caption_file_ext,
):
"""
Captions all images in the specified directory using the provided prompt.
Parameters:
- directory_path: A string representing the path to the directory containing the images to be captioned.
"""
log.info("BLIP2 captionning nucleus...")
if not os.path.isdir(directory_path):
log.error(f"Directory {directory_path} does not exist.")
return
processor, model, device = load_model()
image_files = get_images_in_directory(directory_path)
generate_caption(
file_list=image_files,
processor=processor,
model=model,
device=device,
        do_sample=do_sample,
        temperature=float(temperature),
        top_p=float(top_p),
min_new_tokens=int(min_new_tokens),
max_new_tokens=int(max_new_tokens),
caption_file_ext=caption_file_ext,
)
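

# Both wrappers can also be called outside the GUI (hypothetical path):
#   caption_images_beam_search("data/pics", 5, 1.5, 1.2, 20, 40, ".txt")
#   caption_images_nucleus("data/pics", True, 1.0, 0.9, 20, 40, ".txt")

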
def gradio_blip2_caption_gui_tab(headless=False, directory_path=None):
from .common_gui import create_refresh_button
directory_path = (
directory_path
if directory_path is not None
else os.path.join(scriptdir, "data")
)
current_train_dir = directory_path
def list_train_dirs(path):
nonlocal current_train_dir
current_train_dir = path
return list(list_dirs(path))
with gr.Tab("BLIP2 Captioning"):
        gr.Markdown(
            "This utility uses BLIP2 to generate a caption file for each image in a folder."
        )
with gr.Group(), gr.Row():
directory_path_dir = gr.Dropdown(
label="Image folder to caption (containing the images to caption)",
choices=[""] + list_train_dirs(directory_path),
value="",
interactive=True,
allow_custom_value=True,
)
create_refresh_button(
directory_path_dir,
lambda: None,
lambda: {"choices": list_train_dirs(current_train_dir)},
"open_folder_small",
)
button_directory_path_dir_input = gr.Button(
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_directory_path_dir_input.click(
get_folder_path,
outputs=directory_path_dir,
show_progress=False,
)
with gr.Group(), gr.Row():
min_new_tokens = gr.Number(
value=20,
label="Min new tokens",
interactive=True,
step=1,
minimum=5,
maximum=300,
)
max_new_tokens = gr.Number(
value=40,
label="Max new tokens",
interactive=True,
step=1,
minimum=5,
maximum=300,
)
caption_file_ext = gr.Textbox(
label="Caption file extension",
placeholder="Extension for caption file (e.g., .caption, .txt)",
value=".txt",
interactive=True,
)
with gr.Row():
with gr.Tab("Beam search"):
with gr.Row():
num_beams = gr.Slider(
minimum=1,
maximum=16,
value=16,
step=1,
interactive=True,
label="Number of beams",
)
len_penalty = gr.Slider(
minimum=-1.0,
maximum=2.0,
value=1.0,
step=0.2,
interactive=True,
label="Length Penalty",
info="increase for longer sequence",
)
rep_penalty = gr.Slider(
minimum=1.0,
maximum=5.0,
value=1.5,
step=0.5,
interactive=True,
label="Repeat Penalty",
info="larger value prevents repetition",
)
caption_button_beam = gr.Button(
value="Caption images", interactive=True, variant="primary"
)
caption_button_beam.click(
caption_images_beam_search,
inputs=[
directory_path_dir,
num_beams,
rep_penalty,
len_penalty,
min_new_tokens,
max_new_tokens,
caption_file_ext,
],
)
with gr.Tab("Nucleus sampling"):
with gr.Row():
do_sample = gr.Checkbox(label="Sample", value=True)
temperature = gr.Slider(
minimum=0.5,
maximum=1.0,
value=1.0,
step=0.1,
interactive=True,
label="Temperature",
info="used with nucleus sampling",
)
top_p = gr.Slider(
minimum=0,
maximum=1,
value=0.9,
step=0.1,
interactive=True,
label="Top_p",
)
caption_button_nucleus = gr.Button(
value="Caption images", interactive=True, variant="primary"
)
caption_button_nucleus.click(
caption_images_nucleus,
inputs=[
directory_path_dir,
do_sample,
temperature,
top_p,
min_new_tokens,
max_new_tokens,
caption_file_ext,
],
)
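

if __name__ == "__main__":
    # Standalone sketch (assumption: run as a module from the repo root with
    # `python -m kohya_gui.blip2_caption_gui`, since the relative imports above
    # prevent running this file directly). Mounts only this tab in a bare app.
    with gr.Blocks() as demo:
        gradio_blip2_caption_gui_tab(headless=True)
    demo.launch()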