|
import logging
from typing import Any, Dict, List

import diffusers
import torch
from diffusers import DiffusionPipeline
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that runs a diffusers text-to-image pipeline.

    The pipeline is loaded once in ``__init__`` and reused for every request
    handled by ``__call__``.
    """

    # Hub model ID loaded when no explicit ``model_id`` is given.
    DEFAULT_MODEL_ID = "remg1997/dynabench-sdxl10"

    # Number of denoising steps used when the request does not specify one.
    DEFAULT_STEPS = 30

    _logger = logging.getLogger(__name__)

    def __init__(self, path="", *, model_id=DEFAULT_MODEL_ID, device="cuda"):
        """Load the pipeline in fp16 safetensors form and move it to ``device``.

        Args:
            path: Stored on ``self.path``; presumably the local repository path
                supplied by the endpoint runtime. The weights are NOT loaded
                from here but from ``model_id``.
            model_id: Model ID or path passed to
                ``DiffusionPipeline.from_pretrained``; also stored on
                ``self.model``.
            device: Torch device the pipeline is moved to.
        """
        self.path = path
        self.model = model_id

        # Log environment versions once at load time rather than on every request.
        self._logger.info("Torch version is %s", torch.__version__)
        self._logger.info("Diffusers version is %s", diffusers.__version__)

        self.pipeline = DiffusionPipeline.from_pretrained(
            self.model,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        # Weights are already fp16 (torch_dtype above), so only the device changes.
        self.pipeline = self.pipeline.to(device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the pipeline on the prompt in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` is passed to the pipeline
                as the prompt. The step count is taken from
                ``data["parameters"]["num_inference_steps"]`` if present, else
                from ``data["steps"]``, else ``DEFAULT_STEPS``. The payload is
                not modified.

        Returns:
            One ``{"image": image}`` dict per entry in the pipeline output's
            ``images`` list.

        Raises:
            ValueError: If ``inputs`` is missing, or the step count is not a
                positive integer.
        """
        if "inputs" not in data:
            raise ValueError("Request payload must contain an 'inputs' prompt")
        prompt = data["inputs"]
        self._logger.debug("inputs %s", prompt)

        parameters = data.get("parameters") or {}
        raw_steps = parameters.get(
            "num_inference_steps", data.get("steps", self.DEFAULT_STEPS)
        )
        try:
            # JSON clients frequently send numbers as strings; normalise here.
            steps = int(raw_steps)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"num_inference_steps must be an integer, got {raw_steps!r}"
            ) from err
        if steps < 1:
            raise ValueError(f"num_inference_steps must be positive, got {steps}")

        output = self.pipeline(prompt, num_inference_steps=steps)
        # The pipeline returns an output wrapper, not an image: unpack the
        # generated images so the response carries the images themselves.
        return [{"image": image} for image in output.images]
|
|
|
|
|
|