# Spaces: Running
import fal | |
class FalModel:
    """Callable wrapper around a hosted ``fal-ai/<model_name>`` application.

    Attributes:
        model_name: Name appended to ``"fal-ai/"`` to form the application id.
        model_type: Either ``"text2image"`` or ``"image2image"``; selects which
            keyword arguments ``__call__`` requires and forwards.
    """

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        # Must be spelled ``model_type``: ``__call__`` dispatches on this
        # attribute, and a misspelled name makes every call fail.
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Submit a request for this model and return the handler's result.

        Positional arguments are accepted but ignored; inputs are read from
        keyword arguments only.

        Keyword Args:
            prompt: Required when ``model_type`` is ``"text2image"``.
            image_url: Input image URL for ``"image2image"``.
            image: Accepted for ``"image2image"``, but uploading a raw image
                is not implemented; supply ``image_url`` instead.

        Returns:
            Whatever ``handler.get()`` returns for the submitted request.

        Raises:
            ValueError: If a required keyword argument is missing, or if
                ``model_type`` is neither ``"text2image"`` nor
                ``"image2image"``.
            NotImplementedError: If an ``"image2image"`` call provides no
                usable ``image_url``.
        """
        if self.model_type == "text2image":
            # Explicit raise instead of ``assert`` so the check survives -O.
            if "prompt" not in kwargs:
                raise ValueError("prompt is required for text2image model")
            return self._run({"prompt": kwargs["prompt"]})

        if self.model_type == "image2image":
            if "image" not in kwargs and "image_url" not in kwargs:
                raise ValueError(
                    "image or image_url is required for image2image model"
                )
            image_url = kwargs.get("image_url")
            if image_url is None:
                # Only a URL can be forwarded; fail loudly rather than
                # submitting a request whose image_url is None.
                raise NotImplementedError(
                    "uploading a raw image is not supported yet; "
                    "pass image_url instead"
                )
            return self._run({"image_url": image_url})

        raise ValueError("model_type must be text2image or image2image")

    def _run(self, arguments):
        """Submit ``arguments`` to the app, print progress logs, return the result."""
        handler = fal.apps.submit(
            f"fal-ai/{self.model_name}",
            arguments=arguments,
        )
        for event in handler.iter_events():
            if isinstance(event, fal.apps.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()
def load_fal_model(model_name, model_type):
    """Build and return a ``FalModel`` for the given model name and type."""
    model = FalModel(model_name=model_name, model_type=model_type)
    return model