File size: 4,542 Bytes
31fd16b
a4d7b31
 
 
31fd16b
a4d7b31
 
 
 
31fd16b
a4d7b31
 
 
 
 
 
 
 
 
 
 
 
31fd16b
 
a4d7b31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import shutil
import zipfile
import tensorflow as tf
import pandas as pd
import pathlib
import PIL.Image
import os
import subprocess

def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad *image* with a black border so that it becomes square.

    The shorter dimension is extended to match the longer one and the original
    picture is centred along that axis.

    Args:
        image: The picture to pad.

    Returns:
        ``image`` itself (the same object, not a copy) when it is already
        square; otherwise a new image of the same mode whose side length is
        ``max(width, height)``.
    """
    width, height = image.size
    if width == height:
        return image

    side = max(width, height)
    # Pillow rejects a 3-tuple fill colour for images with fewer than three
    # bands (e.g. "L", "LA", "P", "1"), so use the scalar 0 for those modes.
    fill = (0, 0, 0) if len(image.getbands()) >= 3 else 0
    canvas = PIL.Image.new(image.mode, (side, side), fill)
    if image.mode == 'P':
        # A fresh "P" image has no palette of its own; reuse the source palette
        # so the pasted indices keep their colours. The border is palette
        # index 0, which is only black if the palette defines it that way.
        palette = image.getpalette()
        if palette is not None:
            canvas.putpalette(palette)
    # One of the two offsets is always 0: only the shorter axis gets centred.
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas


class ModelTrainer:
    def __init__(self):
        self.training_pictures = []
        self.training_model = None

    def unzip_file(self, zip_file_path):
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            extracted_path = zip_file_path.replace('.zip', '')
            zip_ref.extractall(extracted_path)
            file_names = zip_ref.namelist()
            for file_name in file_names:
                if file_name.endswith(('.jpeg', '.jpg', '.png')):
                    self.training_pictures.append(f'{extracted_path}/{file_name}')

    def train(self, pretrained_model_name_or_path: str, instance_images: list | None):
        output_model_name = 'a-xyz-model'
        resolution = 512
        repo_dir = pathlib.Path(__file__).parent
        subdirs = ['train-instance', 'train-class', 'experiments']
        dir_paths = []

        for subdir in subdirs:
            dir_path = repo_dir / subdir / output_model_name
            dir_paths.append(dir_path)
            shutil.rmtree(dir_path, ignore_errors=True)
            os.makedirs(dir_path, exist_ok=True)

        instance_data_dir, class_data_dir, output_dir = dir_paths

        for i, temp_path in enumerate(instance_images):
            image = PIL.Image.open(temp_path.name)
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            image = image.convert('RGB')
            out_path = instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

        command = [
            'python', '-u',
            'train_dreambooth_cloneofsimo_lora.py',
            '--pretrained_model_name_or_path', pretrained_model_name_or_path,
            '--instance_data_dir', instance_data_dir,
            '--class_data_dir', class_data_dir,
            '--resolution', '768',
            '--output_dir', output_dir,
            '--instance_prompt', 'a photo of a pwsm dog',
            '--with_prior_preservation',
            '--class_prompt', 'a dog',
            '--prior_loss_weight', '1.0',
            '--num_class_images', '100',
            '--learning_rate', '0.0004',
            '--train_batch_size', '1',
            '--sample_batch_size', '1',
            '--max_train_steps', '400',
            '--gradient_accumulation_steps', '1',
            '--gradient_checkpointing',
            '--train_text_encoder',
            '--learning_rate_text', '5e-6',
            '--save_steps', '100',
            '--seed', '1337',
            '--lr_scheduler', 'constant',
            '--lr_warmup_steps', '0'
        ]

        result = subprocess.run(command)
        return result

    def generate_picture(self, row):
        num_of_training_steps, learning_rate, checkpoint_steps, abc = row
        return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}'

    def generate_pictures(self, csv_input):
        csv = pd.read_csv(csv_input.name)
        result = []
        for index, row in csv.iterrows():
            result.append(self.generate_picture(row))
        return "\n".join(str(item) for item in result)

# Single trainer instance shared by all UI callbacks.
loader = ModelTrainer()


def _train_and_report(pretrained_model_name_or_path, instance_images):
    """Run training and describe the outcome as text for the Markdown output.

    ``loader.train`` returns a ``subprocess.CompletedProcess``, which a
    Markdown component cannot render, so it is translated into a message here.
    """
    result = loader.train(pretrained_model_name_or_path, instance_images)
    if result.returncode == 0:
        return 'Training finished successfully.'
    return f'Training failed with exit code {result.returncode}.'


with gr.Blocks() as demo:
    # Training panel: upload instance images, pick the base model, start training.
    with gr.Box():
        instance_images = gr.Files(label='Instance images')
        pretrained_model_name_or_path = gr.Textbox(
            lines=1,
            label='pretrained_model_name_or_path',
            value='stabilityai/stable-diffusion-2-1',
        )
        output_message = gr.Markdown()
        train_button = gr.Button('Train')
        train_button.click(
            fn=_train_and_report,
            inputs=[pretrained_model_name_or_path, instance_images],
            outputs=[output_message],
        )
    # Generation panel: one status line per row of the uploaded CSV.
    with gr.Box():
        csv_input = gr.File(label='CSV File')
        output_message2 = gr.Markdown()
        generate_button = gr.Button('Generate Pictures from CSV')
        generate_button.click(
            fn=loader.generate_pictures,
            inputs=[csv_input],
            outputs=[output_message2],
        )

demo.launch()