File size: 2,078 Bytes
c3a1897
 
 
 
 
 
 
 
 
 
 
 
eb902b3
 
c3a1897
 
 
44a0c32
 
 
 
c3a1897
 
44a0c32
 
 
c3a1897
 
 
 
44a0c32
 
c3a1897
 
 
 
eb902b3
44a0c32
 
c3a1897
 
 
 
 
 
 
 
 
 
 
 
 
eb902b3
 
c3a1897
 
eb902b3
 
c3a1897
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import cv2
import torch
import numpy as np
from PIL import Image
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
)


class TextToImage:
    """Generate images from a text prompt guided by Canny edges of an input image.

    Wraps a ``StableDiffusionControlNetPipeline`` that is loaded once at
    construction time and reused for every call to :meth:`text_to_image`.

    Attributes set by this class:
        device: the device identifier passed to the constructor. The code
            compares it against the string ``'cpu'``, so a plain string such
            as ``'cpu'`` or ``'cuda'`` is expected.
        data_type: ``torch.float32`` on CPU, ``torch.float16`` otherwise
            (set by :meth:`initialize_model`).
        model: the loaded pipeline returned by :meth:`initialize_model`.
    """

    def __init__(self, device):
        """Store *device* and eagerly load the pipeline onto it."""
        self.device = device
        self.model = self.initialize_model()

    def initialize_model(self):
        """Load the ControlNet (canny) model and the Stable Diffusion pipeline.

        Also sets ``self.data_type`` as a side effect.

        Returns:
            The configured ``StableDiffusionControlNetPipeline`` using a
            ``UniPCMultistepScheduler`` and no safety checker.
        """
        on_cpu = self.device == 'cpu'
        # Half precision is only used off-CPU; CPU runs in full precision.
        self.data_type = torch.float32 if on_cpu else torch.float16

        # NOTE: ``map_location`` is a ``torch.load`` argument, not a
        # ``from_pretrained`` one, so it is intentionally not passed here.
        # Device placement is handled explicitly below.
        controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-canny",
            torch_dtype=self.data_type,
        )
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=self.data_type,
        )
        pipeline.scheduler = UniPCMultistepScheduler.from_config(
            pipeline.scheduler.config
        )

        if on_cpu:
            pipeline.to(self.device)
        else:
            # Model CPU offload manages device placement itself; moving the
            # whole pipeline to the accelerator first would be redundant and
            # would defeat the memory savings.
            pipeline.enable_model_cpu_offload()
        return pipeline

    @staticmethod
    def preprocess_image(image, low_threshold=100, high_threshold=200):
        """Convert *image* into a 3-channel Canny edge map.

        Args:
            image: anything ``np.array`` can turn into an array that
                ``cv2.Canny`` accepts (presumably a PIL image or uint8
                array -- confirm against callers).
            low_threshold: lower hysteresis threshold passed to ``cv2.Canny``.
            high_threshold: upper hysteresis threshold passed to ``cv2.Canny``.

        Returns:
            A ``PIL.Image.Image`` whose three channels all hold the edge map.
        """
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # Replicate the single-channel edge map across three channels.
        stacked = np.stack([edges, edges, edges], axis=2)
        return Image.fromarray(stacked)

    def text_to_image(self, text, image, num_inference_steps=20):
        """Generate an image from *text*, conditioned on the edges of *image*.

        Args:
            text: the prompt passed to the pipeline.
            image: the source image; it is run through
                :meth:`preprocess_image` before being given to the pipeline.
            num_inference_steps: number of denoising steps for the pipeline.

        Returns:
            The first image from the pipeline output's ``images``.
        """
        banner = '\033[1;35m' + '*' * 100 + '\033[0m'
        print(banner)
        print('\nStep5, Text to Image:')
        edge_image = self.preprocess_image(image)
        generated_image = self.model(
            text, edge_image, num_inference_steps=num_inference_steps
        ).images[0]
        print("Generated image has been saved.")
        print(banner)
        return generated_image

    def text_to_image_debug(self, text, image):
        """Debug stand-in for :meth:`text_to_image`: return *image* unchanged.

        *text* is accepted for signature parity and is not used.
        """
        print("text_to_image_debug")
        return image