File size: 3,967 Bytes
a3f8f46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Iterator, Tuple

import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch

from sgm.inference.api import (
    model_specs,
    SamplingParams,
    SamplingPipeline,
    Sampler,
    ModelArchitecture,
)
import sgm.inference.helpers as helpers


@pytest.mark.inference
class TestInference:
    """Smoke tests that drive ``SamplingPipeline`` end to end.

    Every test runs a short sampling job and only checks that a result came
    back; the content of the generated images is not inspected.
    """

    # Shared sampling inputs, hoisted so all tests use identical settings.
    PROMPT = "A professional photograph of an astronaut riding a pig"
    NEGATIVE_PROMPT = ""
    STEPS = 10
    SAMPLES = 1

    @fixture(scope="class", params=model_specs.keys())
    def pipeline(self, request) -> Iterator[SamplingPipeline]:
        """Yield one ``SamplingPipeline`` per key of ``model_specs``.

        The fixture is class-scoped, so each pipeline is built once and
        reused by every test parametrization that requests it.
        """
        pipeline = SamplingPipeline(request.param)
        yield pipeline
        # Teardown: drop our reference first so the CUDA cache flush below
        # has a chance to actually release the model's memory.
        del pipeline
        torch.cuda.empty_cache()

    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
    )
    def sdxl_pipelines(
        self, request
    ) -> Iterator[Tuple[SamplingPipeline, SamplingPipeline]]:
        """Yield a ``(base, refiner)`` pipeline pair for each SDXL version."""
        base_arch, refiner_arch = request.param
        base_pipeline = SamplingPipeline(base_arch)
        refiner_pipeline = SamplingPipeline(refiner_arch)
        yield base_pipeline, refiner_pipeline
        # Teardown mirrors the ``pipeline`` fixture: release references, then
        # flush the CUDA cache.
        del base_pipeline
        del refiner_pipeline
        torch.cuda.empty_cache()

    def _sampling_params(self, sampler_enum) -> SamplingParams:
        """Build the ``SamplingParams`` used by every test for ``sampler_enum``."""
        return SamplingParams(sampler=sampler_enum.value, steps=self.STEPS)

    def create_init_image(self, h, w):
        """Return a random ``h`` x ``w`` RGB image as an input tensor.

        The pixel data is uniform noise; the conversion to a tensor is
        delegated to ``helpers.get_input_image_tensor``.
        """
        # rand() yields floats in [0, 1); scale to the 8-bit range before casting.
        image_array = numpy.random.rand(h, w, 3) * 255
        image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
        return helpers.get_input_image_tensor(image)

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Text-to-image returns a result for every model/sampler pair."""
        output = pipeline.text_to_image(
            params=self._sampling_params(sampler_enum),
            prompt=self.PROMPT,
            negative_prompt=self.NEGATIVE_PROMPT,
            samples=self.SAMPLES,
        )

        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Image-to-image returns a result for every model/sampler pair."""
        output = pipeline.image_to_image(
            params=self._sampling_params(sampler_enum),
            image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
            prompt=self.PROMPT,
            negative_prompt=self.NEGATIVE_PROMPT,
            samples=self.SAMPLES,
        )
        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
        sampler_enum,
        use_init_image,
    ):
        """Run the base pipeline, then feed its latents through the refiner.

        ``use_init_image`` selects whether the base stage is image-to-image
        (seeded with a random image) or plain text-to-image.
        """
        base_pipeline, refiner_pipeline = sdxl_pipelines

        # Arguments common to both base-stage entry points.
        base_kwargs = dict(
            params=self._sampling_params(sampler_enum),
            prompt=self.PROMPT,
            negative_prompt=self.NEGATIVE_PROMPT,
            samples=self.SAMPLES,
            return_latents=True,
        )
        if use_init_image:
            output = base_pipeline.image_to_image(
                image=self.create_init_image(
                    base_pipeline.specs.height, base_pipeline.specs.width
                ),
                **base_kwargs,
            )
        else:
            output = base_pipeline.text_to_image(**base_kwargs)

        # With return_latents=True the base stage is expected to hand back a
        # (samples, latents) pair.
        assert isinstance(output, (tuple, list))
        samples, samples_z = output
        assert samples is not None
        assert samples_z is not None

        refined = refiner_pipeline.refiner(
            params=self._sampling_params(sampler_enum),
            image=samples_z,
            prompt=self.PROMPT,
            negative_prompt=self.NEGATIVE_PROMPT,
            samples=self.SAMPLES,
        )
        # NOTE(review): assumes refiner() returns its samples like the other
        # pipeline entry points - confirm against SamplingPipeline.refiner.
        assert refined is not None