File size: 840 Bytes
52a5033
 
be0c166
42aabcc
52a5033
 
 
 
 
 
9a1c01d
52a5033
 
42aabcc
 
52a5033
42aabcc
52a5033
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import logging
from typing import Dict, List, Any

import diffusers
import torch
from diffusers import DiffusionPipeline

class EndpointHandler():
    """Custom inference handler that runs a text-to-image diffusion pipeline.

    The pipeline is loaded once at construction time from the Hub model
    ``"remg1997/dynabench-sdxl10"`` (fp16 safetensors variant) and moved to
    CUDA. Each call runs the pipeline on the request prompt and returns the
    generated images.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = ""):
        """Load the diffusion pipeline onto the GPU.

        Args:
            path: Stored on ``self.path``; NOT used for loading — the model is
                always fetched by its Hub id stored in ``self.model``.
        """
        self.path = path
        self.model = "remg1997/dynabench-sdxl10"
        # Log library versions once at startup instead of on every request.
        self._logger.info("Torch version is %s", torch.__version__)
        self._logger.info("Diffusers version is %s", diffusers.__version__)
        self.pipeline = DiffusionPipeline.from_pretrained(
            self.model,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        # Weights are already float16 from from_pretrained; only the device
        # needs to change. Passing a dtype positionally to .to() is not
        # portable across diffusers releases.
        self.pipeline = self.pipeline.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate image(s) for the prompt in ``data``.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt; if absent,
                the payload itself is passed as the prompt (the original
                handler convention). Optional ``"steps"`` (default 30) sets
                ``num_inference_steps``. Both keys are popped from ``data``.

        Returns:
            One ``{"image": ...}`` dict per image produced by the pipeline
            (a single entry for a single string prompt).

        NOTE(review): the image objects are whatever the pipeline puts in
        ``.images`` (presumably PIL images by default); encoding them for the
        HTTP response is left to the serving toolkit — confirm this works for
        the deployment.
        """
        inputs = data.pop("inputs", data)
        # Prompts can contain user data; keep them out of default-level logs.
        self._logger.debug("inputs %s", inputs)
        # JSON clients may send the step count as a string.
        steps = int(data.pop("steps", 30))
        output = self.pipeline(inputs, num_inference_steps=steps)
        # Diffusers pipelines return an output object; the images live in
        # its .images list.
        return [{"image": img} for img in output.images]