File size: 3,798 Bytes
944dd2b
 
 
 
 
94bd22c
944dd2b
94bd22c
09a289b
06726d0
e368cec
 
 
 
944dd2b
 
e368cec
 
94bd22c
 
 
 
 
 
 
 
 
 
 
 
 
 
e368cec
 
944dd2b
94bd22c
e368cec
 
 
 
944dd2b
 
e368cec
 
 
94bd22c
944dd2b
07e4294
94bd22c
 
 
 
 
e368cec
 
944dd2b
 
 
26dad4e
 
 
 
765fb5e
7d60da4
e548ada
06726d0
26dad4e
 
944dd2b
26dad4e
e368cec
944dd2b
e368cec
 
944dd2b
 
 
e368cec
 
 
 
944dd2b
 
 
 
e368cec
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import fal_client
from PIL import Image
import requests
import io
import os
import base64

# Maps the model names used by this module to the fal endpoint suffix that is
# appended after "fal-ai/" when a request is submitted.
FAL_MODEl_NAME_MAP = dict([
    ("SDXL", "fast-sdxl"),
    ("SDXLTurbo", "fast-turbo-diffusion"),
    ("SDXLLightning", "fast-lightning-sdxl"),
    ("LCM(v1.5/XL)", "fast-lcm-diffusion"),
    ("PixArtSigma", "pixart-sigma"),
    ("StableCascade", "stable-cascade"),
    ("AuraFlow", "aura-flow"),
    ("FLUX1schnell", "flux/schnell"),
    ("FLUX1dev", "flux/dev"),
    ("T2VTurbo", "t2v-turbo"),
])

class FalModel():
    """Callable wrapper that sends a prompt to a fal-hosted model.

    Calling an instance with ``prompt=...`` submits the request through
    ``fal_client``, prints progress logs while waiting, and returns:

    * ``model_type == "text2image"``: a ``PIL.Image`` opened from the first
      entry of the result's ``images`` list;
    * ``model_type == "text2video"``: the URL string found under
      ``result['video']['url']``.

    ``model_type == "image2image"`` is recognised but not implemented.
    """

    # Endpoint suffix (appended after "fal-ai/") for each text2video model.
    # NOTE(review): the module-level FAL_MODEl_NAME_MAP lists "T2VTurbo", but
    # there is no entry for it here, so it is rejected for text2video --
    # confirm whether that is intended.
    _TEXT2VIDEO_ENDPOINTS = {
        'AnimateDiff': 'fast-animatediff/text-to-video',
        'AnimateDiffTurbo': 'fast-animatediff/turbo/text-to-video',
        'StableVideoDiffusion': 'fast-svd/text-to-video',
        'CogVideoX-5B': 'cogvideox-5b',
    }

    # Seconds to wait for the generated image download before giving up, so
    # a stalled connection cannot block the caller forever.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, model_name, model_type):
        """Store the model selection and expose the API key as ``FAL_KEY``.

        Args:
            model_name: Name of the model; looked up in FAL_MODEl_NAME_MAP
                for text2image and in the text2video endpoint table.
            model_type: One of "text2image", "image2image" or "text2video".

        Raises:
            KeyError: If neither the ``FalAPI`` nor the ``FAL_KEY``
                environment variable is set.
        """
        self.model_name = model_name
        self.model_type = model_type
        api_key = os.environ.get('FalAPI')
        if api_key is not None:
            # Presumably FAL_KEY is the variable fal_client reads its
            # credentials from -- verify against the fal_client docs.
            os.environ['FAL_KEY'] = api_key
        elif 'FAL_KEY' not in os.environ:
            # Same exception the direct os.environ['FalAPI'] lookup raised.
            raise KeyError('FalAPI')

    def __call__(self, *args, **kwargs):
        """Run the configured model on ``kwargs["prompt"]``.

        Returns:
            A ``PIL.Image`` for text2image, or the video URL string for
            text2video.

        Raises:
            AssertionError: If ``prompt`` is missing from ``kwargs``.
            KeyError: If a text2image ``model_name`` is not in
                FAL_MODEl_NAME_MAP.
            NotImplementedError: For image2image, or for a text2video
                ``model_name`` without a known endpoint.
            ValueError: If ``model_type`` is not a recognised type, or a
                returned ``data:`` URL contains no payload separator.
        """
        if self.model_type == "text2image":
            prompt = self._require_prompt(kwargs)
            endpoint = FAL_MODEl_NAME_MAP[self.model_name]
            result = self._submit_and_wait(endpoint, prompt)
            print(result)
            return self._load_image(result['images'][0]['url'])

        if self.model_type == "image2image":
            raise NotImplementedError("image2image model is not implemented yet")

        if self.model_type == "text2video":
            prompt = self._require_prompt(kwargs)
            if self.model_name not in self._TEXT2VIDEO_ENDPOINTS:
                raise NotImplementedError(f"text2video model of {self.model_name} in fal is not implemented yet")
            endpoint = self._TEXT2VIDEO_ENDPOINTS[self.model_name]
            result = self._submit_and_wait(endpoint, prompt)
            print("result video: ====")
            print(result)
            return result['video']['url']

        # The message now names every type this method recognises; it
        # previously omitted text2video.
        raise ValueError("model_type must be text2image, image2image or text2video")

    def _require_prompt(self, kwargs):
        """Return ``kwargs["prompt"]`` or raise if it is missing."""
        if "prompt" not in kwargs:
            # Raised explicitly (rather than via ``assert``) so the check
            # survives ``python -O``; the exception type and message are kept
            # identical for callers that already catch AssertionError.
            raise AssertionError(f"prompt is required for {self.model_type} model")
        return kwargs["prompt"]

    @staticmethod
    def _submit_and_wait(endpoint, prompt):
        """Submit ``prompt`` to ``fal-ai/<endpoint>``, print logs, return the result."""
        handler = fal_client.submit(
            f"fal-ai/{endpoint}",
            arguments={
                "prompt": prompt
            },
        )
        for event in handler.iter_events(with_logs=True):
            if isinstance(event, fal_client.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()

    @classmethod
    def _load_image(cls, url):
        """Open the image referenced by ``url`` (inline ``data:`` URL or remote URL)."""
        if url.startswith("data:"):
            # Decide by URL scheme instead of a hard-coded list of model
            # names, so any endpoint that returns the image inline works.
            raw = cls._decode_data_url(url)
        else:
            response = requests.get(url, timeout=cls._DOWNLOAD_TIMEOUT)
            # Fail on an HTTP error instead of handing an error page to PIL.
            response.raise_for_status()
            raw = response.content
        return Image.open(io.BytesIO(raw))

    @staticmethod
    def _decode_data_url(data_url):
        """Return the Base64-decoded bytes that follow the first comma of ``data_url``."""
        _header, separator, payload = data_url.partition(",")
        if not separator:
            raise ValueError("Invalid data URL provided")
        return base64.b64decode(payload)

def load_fal_model(model_name, model_type):
    """Build and return a FalModel for the given model name and model type."""
    model = FalModel(model_name=model_name, model_type=model_type)
    return model