# Spaces: Running on Zero
from typing import Tuple

import numpy
import pytest
import torch
from PIL import Image
from pytest import fixture

import sgm.inference.helpers as helpers
from sgm.inference.api import (
    model_specs,
    SamplingParams,
    SamplingPipeline,
    Sampler,
    ModelArchitecture,
)
class TestInference:
    """Smoke tests for the ``SamplingPipeline`` inference API.

    Every test runs a short (10-step) sampling job and only checks that a
    result is returned; the content of the generated images is not inspected.

    NOTE(review): the ``@fixture`` / ``@pytest.mark.parametrize`` decorators
    in this class are inferred from the names this module imports but does
    not otherwise use (``fixture``, ``model_specs``, ``Sampler``,
    ``ModelArchitecture``) and from the arguments the tests request
    (``request.param``, ``sampler_enum``, ``use_init_image``). Confirm the
    parameter sets against the project before relying on them.
    """

    # Shared by every test so all pipelines are exercised with the same input.
    PROMPT = "A professional photograph of an astronaut riding a pig"

    @staticmethod
    def _params(sampler_enum) -> SamplingParams:
        """Build the 10-step ``SamplingParams`` used by every test."""
        return SamplingParams(sampler=sampler_enum.value, steps=10)

    # Class scope so one pipeline instance is reused by all tests that share
    # the same parameter instead of being rebuilt for each test.
    # NOTE(review): assumes ``model_specs`` is a mapping whose keys are valid
    # ``SamplingPipeline`` constructor arguments -- TODO confirm.
    @fixture(scope="class", params=model_specs.keys())
    def pipeline(self, request) -> SamplingPipeline:
        """Yield a ``SamplingPipeline`` built from ``request.param``.

        On teardown the local reference is dropped and
        ``torch.cuda.empty_cache()`` is called.
        """
        pipeline = SamplingPipeline(request.param)
        yield pipeline
        del pipeline
        torch.cuda.empty_cache()

    # NOTE(review): the ``ModelArchitecture`` member names below are assumed
    # (base/refiner pairs); verify them against ``sgm.inference.api``.
    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
    )
    def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
        """Yield a ``(base, refiner)`` pipeline pair.

        ``request.param[0]`` selects the base pipeline and
        ``request.param[1]`` the refiner pipeline. On teardown both local
        references are dropped and ``torch.cuda.empty_cache()`` is called.
        """
        base_pipeline = SamplingPipeline(request.param[0])
        refiner_pipeline = SamplingPipeline(request.param[1])
        yield base_pipeline, refiner_pipeline
        del base_pipeline
        del refiner_pipeline
        torch.cuda.empty_cache()

    def create_init_image(self, h, w):
        """Create a random ``h`` x ``w`` RGB image as pipeline input.

        The image is uniform noise scaled to 0-255 and is returned as
        whatever ``helpers.get_input_image_tensor`` produces for it.
        """
        image_array = numpy.random.rand(h, w, 3) * 255
        image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
        return helpers.get_input_image_tensor(image)

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Text-to-image returns a result for the given sampler."""
        output = pipeline.text_to_image(
            params=self._params(sampler_enum),
            prompt=self.PROMPT,
            negative_prompt="",
            samples=1,
        )

        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Image-to-image returns a result when started from a random image."""
        output = pipeline.image_to_image(
            params=self._params(sampler_enum),
            image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
            prompt=self.PROMPT,
            negative_prompt="",
            samples=1,
        )

        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
        sampler_enum,
        use_init_image,
    ):
        """Run the base pipeline, then feed its latents to the refiner.

        ``use_init_image`` selects whether the base stage runs
        image-to-image (from a random image) or text-to-image.
        """
        base_pipeline, refiner_pipeline = sdxl_pipelines

        if use_init_image:
            output = base_pipeline.image_to_image(
                params=self._params(sampler_enum),
                image=self.create_init_image(
                    base_pipeline.specs.height, base_pipeline.specs.width
                ),
                prompt=self.PROMPT,
                negative_prompt="",
                samples=1,
                return_latents=True,
            )
        else:
            output = base_pipeline.text_to_image(
                params=self._params(sampler_enum),
                prompt=self.PROMPT,
                negative_prompt="",
                samples=1,
                return_latents=True,
            )

        # The base stage is called with return_latents=True, and the result is
        # unpacked into (samples, latents) below.
        assert isinstance(output, (tuple, list))
        samples, samples_z = output
        assert samples is not None
        assert samples_z is not None

        refined = refiner_pipeline.refiner(
            params=self._params(sampler_enum),
            image=samples_z,
            prompt=self.PROMPT,
            negative_prompt="",
            samples=1,
        )

        # Check the refiner stage the same way the other tests check theirs.
        assert refined is not None