Spaces:
Runtime error
Runtime error
File size: 3,967 Bytes
a3f8f46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from typing import Iterator, Tuple

import numpy
import pytest
import torch
from PIL import Image
from pytest import fixture

import sgm.inference.helpers as helpers
from sgm.inference.api import (
    ModelArchitecture,
    Sampler,
    SamplingParams,
    SamplingPipeline,
    model_specs,
)
@pytest.mark.inference
class TestInference:
    """Smoke tests that drive ``SamplingPipeline`` end to end.

    Each test runs a short sampling job and only checks that a result comes
    back; the content of the generated images is not inspected.
    """

    # Shared request inputs, kept in one place so every test issues the same
    # request and the call sites cannot drift apart.
    _PROMPT = "A professional photograph of an astronaut riding a pig"
    _NEGATIVE_PROMPT = ""
    _STEPS = 10
    _SAMPLES = 1

    @fixture(scope="class", params=model_specs.keys())
    def pipeline(self, request) -> Iterator[SamplingPipeline]:
        """Yield a ``SamplingPipeline`` for each key in ``model_specs``.

        The fixture is class-scoped, so one pipeline is built per parameter
        and reused by every test in this class that requests it.
        """
        pipeline = SamplingPipeline(request.param)
        yield pipeline
        # Teardown: drop our reference, then ask torch to release the CUDA
        # memory it is caching before the next parameter builds its pipeline.
        del pipeline
        torch.cuda.empty_cache()

    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
    )
    def sdxl_pipelines(
        self, request
    ) -> Iterator[Tuple[SamplingPipeline, SamplingPipeline]]:
        """Yield a ``(base, refiner)`` pipeline pair for each SDXL version."""
        base_pipeline = SamplingPipeline(request.param[0])
        refiner_pipeline = SamplingPipeline(request.param[1])
        yield base_pipeline, refiner_pipeline
        # Teardown mirrors ``pipeline``: release both references, then the
        # cached CUDA memory.
        del base_pipeline
        del refiner_pipeline
        torch.cuda.empty_cache()

    def create_init_image(self, h, w):
        """Create a random ``h`` x ``w`` RGB image to use as an init image.

        The image is passed through ``helpers.get_input_image_tensor`` and
        that function's result is returned unchanged.
        """
        image_array = numpy.random.rand(h, w, 3) * 255
        image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
        return helpers.get_input_image_tensor(image)

    def _sampling_params(self, sampler_enum) -> SamplingParams:
        """Build the ``SamplingParams`` shared by all tests for one sampler."""
        return SamplingParams(sampler=sampler_enum.value, steps=self._STEPS)

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Text-to-image returns a result for every model and sampler."""
        output = pipeline.text_to_image(
            params=self._sampling_params(sampler_enum),
            prompt=self._PROMPT,
            negative_prompt=self._NEGATIVE_PROMPT,
            samples=self._SAMPLES,
        )
        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Image-to-image returns a result for every model and sampler."""
        output = pipeline.image_to_image(
            params=self._sampling_params(sampler_enum),
            image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
            prompt=self._PROMPT,
            negative_prompt=self._NEGATIVE_PROMPT,
            samples=self._SAMPLES,
        )
        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
        sampler_enum,
        use_init_image,
    ):
        """Run the base pipeline, then feed its latents to the refiner."""
        base_pipeline, refiner_pipeline = sdxl_pipelines

        # Both base entry points take the same request; only the img2img
        # path additionally needs an init image.
        base_request = dict(
            params=self._sampling_params(sampler_enum),
            prompt=self._PROMPT,
            negative_prompt=self._NEGATIVE_PROMPT,
            samples=self._SAMPLES,
            return_latents=True,
        )
        if use_init_image:
            output = base_pipeline.image_to_image(
                image=self.create_init_image(
                    base_pipeline.specs.height, base_pipeline.specs.width
                ),
                **base_request,
            )
        else:
            output = base_pipeline.text_to_image(**base_request)

        # With return_latents=True the base stage must hand back a pair.
        assert isinstance(output, (tuple, list))
        samples, samples_z = output
        assert samples is not None
        assert samples_z is not None

        refined = refiner_pipeline.refiner(
            params=self._sampling_params(sampler_enum),
            image=samples_z,
            prompt=self._PROMPT,
            negative_prompt=self._NEGATIVE_PROMPT,
            samples=self._SAMPLES,
        )
        # Check the refiner stage as well; without this a refiner that
        # produced nothing would go unnoticed.
        # NOTE(review): assumes ``refiner`` returns its samples the same way
        # ``text_to_image`` / ``image_to_image`` do -- confirm against
        # sgm.inference.api.
        assert refined is not None
|