Spaces:
Runtime error
Runtime error
File size: 4,542 Bytes
31fd16b a4d7b31 31fd16b a4d7b31 31fd16b a4d7b31 31fd16b a4d7b31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
import pathlib
import shutil
import subprocess
import sys
import zipfile

import gradio as gr
import pandas as pd
import PIL.Image
import tensorflow as tf
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad *image* with borders so that it becomes square.

    The original content is centred on a canvas whose side equals the longer
    of the two input dimensions. An image that is already square is returned
    unchanged (the same object, not a copy).

    Args:
        image: Any PIL image.

    Returns:
        A square image in the same mode as *image*.
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    # Single-band modes ('L', '1', 'P', 'I', 'F', ...) only accept a scalar
    # fill colour; passing an RGB tuple to them makes Image.new raise
    # TypeError. Multi-band modes keep the original black (0, 0, 0) fill.
    fill = 0 if len(image.getbands()) == 1 else (0, 0, 0)
    new_image = PIL.Image.new(image.mode, (side, side), fill)
    if image.mode == 'P':
        # paste() copies palette indices, so the canvas must carry the source
        # palette for the colours to survive. Note the padding then uses
        # palette entry 0, which is not necessarily black.
        palette = image.getpalette()
        if palette is not None:
            new_image.putpalette(palette)
    # Along the longer axis the offset is 0; along the shorter axis the
    # content is centred.
    new_image.paste(image, ((side - w) // 2, (side - h) // 2))
    return new_image
class ModelTrainer:
def __init__(self):
self.training_pictures = []
self.training_model = None
def unzip_file(self, zip_file_path):
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
extracted_path = zip_file_path.replace('.zip', '')
zip_ref.extractall(extracted_path)
file_names = zip_ref.namelist()
for file_name in file_names:
if file_name.endswith(('.jpeg', '.jpg', '.png')):
self.training_pictures.append(f'{extracted_path}/{file_name}')
def train(self, pretrained_model_name_or_path: str, instance_images: list | None):
output_model_name = 'a-xyz-model'
resolution = 512
repo_dir = pathlib.Path(__file__).parent
subdirs = ['train-instance', 'train-class', 'experiments']
dir_paths = []
for subdir in subdirs:
dir_path = repo_dir / subdir / output_model_name
dir_paths.append(dir_path)
shutil.rmtree(dir_path, ignore_errors=True)
os.makedirs(dir_path, exist_ok=True)
instance_data_dir, class_data_dir, output_dir = dir_paths
for i, temp_path in enumerate(instance_images):
image = PIL.Image.open(temp_path.name)
image = pad_image(image)
image = image.resize((resolution, resolution))
image = image.convert('RGB')
out_path = instance_data_dir / f'{i:03d}.jpg'
image.save(out_path, format='JPEG', quality=100)
command = [
'python', '-u',
'train_dreambooth_cloneofsimo_lora.py',
'--pretrained_model_name_or_path', pretrained_model_name_or_path,
'--instance_data_dir', instance_data_dir,
'--class_data_dir', class_data_dir,
'--resolution', '768',
'--output_dir', output_dir,
'--instance_prompt', 'a photo of a pwsm dog',
'--with_prior_preservation',
'--class_prompt', 'a dog',
'--prior_loss_weight', '1.0',
'--num_class_images', '100',
'--learning_rate', '0.0004',
'--train_batch_size', '1',
'--sample_batch_size', '1',
'--max_train_steps', '400',
'--gradient_accumulation_steps', '1',
'--gradient_checkpointing',
'--train_text_encoder',
'--learning_rate_text', '5e-6',
'--save_steps', '100',
'--seed', '1337',
'--lr_scheduler', 'constant',
'--lr_warmup_steps', '0'
]
result = subprocess.run(command)
return result
def generate_picture(self, row):
num_of_training_steps, learning_rate, checkpoint_steps, abc = row
return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}'
def generate_pictures(self, csv_input):
csv = pd.read_csv(csv_input.name)
result = []
for index, row in csv.iterrows():
result.append(self.generate_picture(row))
return "\n".join(str(item) for item in result)
loader = ModelTrainer()


def _train_and_report(pretrained_model_name_or_path, instance_images):
    """Run ``loader.train`` and turn the finished subprocess into a message.

    ``train`` returns a ``subprocess.CompletedProcess``, but the Markdown
    output component expects a string, so report the exit status instead.
    """
    result = loader.train(pretrained_model_name_or_path, instance_images)
    if result.returncode == 0:
        return 'Training finished successfully.'
    return f'Training failed (exit code {result.returncode}).'


with gr.Blocks() as demo:
    # gr.Group / gr.Textbox(value=...) / gr.File replace the deprecated
    # gr.Box and gr.inputs.* components (gr.inputs used `default=`).
    with gr.Group():
        instance_images = gr.Files(label='Instance images')
        pretrained_model_name_or_path = gr.Textbox(
            lines=1,
            label='pretrained_model_name_or_path',
            value='stabilityai/stable-diffusion-2-1',
        )
        output_message = gr.Markdown()
        train_button = gr.Button('Train')
        train_button.click(
            fn=_train_and_report,
            inputs=[pretrained_model_name_or_path, instance_images],
            outputs=[output_message],
        )
    with gr.Group():
        csv_input = gr.File(label='CSV File')
        output_message2 = gr.Markdown()
        generate_button = gr.Button('Generate Pictures from CSV')
        generate_button.click(
            fn=loader.generate_pictures,
            inputs=[csv_input],
            outputs=[output_message2],
        )

demo.launch()
|