Safetensors
aredden committed on
Commit
fb3cdc4
1 Parent(s): 71170f1

Make lora loading api endpoint functional

Browse files
Files changed (2) hide show
  1. api.py +92 -3
  2. flux_pipeline.py +8 -5
api.py CHANGED
@@ -1,17 +1,38 @@
1
- from typing import Optional
2
 
3
  import numpy as np
4
  from fastapi import FastAPI
5
- from fastapi.responses import StreamingResponse
6
  from pydantic import BaseModel, Field
7
  from platform import system
8
 
 
 
 
9
  if system() == "Windows":
10
  MAX_RAND = 2**16 - 1
11
  else:
12
  MAX_RAND = 2**32 - 1
13
 
14
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  class GenerateArgs(BaseModel):
@@ -27,7 +48,75 @@ class GenerateArgs(BaseModel):
27
  init_image: Optional[str] = None
28
 
29
 
 
 
 
30
  @app.post("/generate")
31
  def generate(args: GenerateArgs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  result = app.state.model.generate(**args.model_dump())
33
  return StreamingResponse(result, media_type="image/jpeg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, TYPE_CHECKING
2
 
3
  import numpy as np
4
  from fastapi import FastAPI
5
+ from fastapi.responses import StreamingResponse, JSONResponse
6
  from pydantic import BaseModel, Field
7
  from platform import system
8
 
9
+ if TYPE_CHECKING:
10
+ from flux_pipeline import FluxPipeline
11
+
12
  if system() == "Windows":
13
  MAX_RAND = 2**16 - 1
14
  else:
15
  MAX_RAND = 2**32 - 1
16
 
17
+
18
class AppState:
    """Typed view of ``app.state``.

    Holds the pipeline instance the endpoints read from; presumably assigned
    by startup code elsewhere in the project — TODO confirm.
    """

    # The loaded Flux pipeline used by /generate and /lora.
    model: "FluxPipeline"
20
+
21
+
22
class FastAPIApp(FastAPI):
    """FastAPI subclass that narrows ``state`` to :class:`AppState`.

    Exists purely so static type checkers know ``app.state.model`` is a
    FluxPipeline; adds no runtime behavior.
    """

    state: AppState
24
+
25
+
26
class LoraArgs(BaseModel):
    """Request body for the ``/lora`` endpoint (load or unload a LoRA checkpoint)."""

    # Scaling factor applied to the LoRA weights when loading.
    scale: Optional[float] = 1.0
    # Path to the LoRA checkpoint; required in practice when action is "load".
    path: Optional[str] = None
    # Optional identifier for the checkpoint; unload falls back to `path` when absent.
    name: Optional[str] = None
    # Which operation to perform; pydantic rejects anything but "load"/"unload".
    action: Optional[Literal["load", "unload"]] = "load"
31
+
32
+
33
class LoraLoadResponse(BaseModel):
    """Response schema for the ``/lora`` endpoint."""

    # "success" when the load/unload completed, "error" otherwise.
    status: Literal["success", "error"]
    # Human-readable detail; populated on error, omitted on success.
    message: Optional[str] = None
36
 
37
 
38
  class GenerateArgs(BaseModel):
 
48
  init_image: Optional[str] = None
49
 
50
 
51
+ app = FastAPIApp()
52
+
53
+
54
@app.post("/generate")
def generate(args: GenerateArgs):
    """
    Run the Flux flow transformer and stream the generated image back.

    Args:
        args (GenerateArgs): Image-generation parameters:

            - `prompt`: text prompt driving the generation.

            - `width`: output image width.

            - `height`: output image height.

            - `num_steps`: number of diffusion steps to run.

            - `guidance`: guidance strength — how strongly the prompt
              influences the generation.

            - `seed`: RNG seed for the generation.

            - `strength`: 0.0 - 1.0; the fraction of diffusion steps to run,
              with `init_image` used as the noised latent at the corresponding
              step count.

            - `init_image`: base64-encoded image, or a path to an image, used
              as the init image.

    Returns:
        StreamingResponse: The generated image as streaming jpeg bytes.
    """
    # Forward the validated request fields straight to the pipeline.
    params = args.model_dump()
    image_stream = app.state.model.generate(**params)
    return StreamingResponse(image_stream, media_type="image/jpeg")
87
+
88
+
89
@app.post("/lora", response_model=LoraLoadResponse)
def lora_action(args: LoraArgs):
    """
    Loads or unloads a LoRA checkpoint into / from the Flux flow transformer.

    Args:
        args (LoraArgs): Arguments for the LoRA action:

            - `scale`: The scaling factor for the LoRA weights.
            - `path`: The path to the LoRA checkpoint.
            - `name`: The name of the LoRA checkpoint.
            - `action`: The action to perform, either "load" or "unload".

    Returns:
        LoraLoadResponse: The status of the LoRA action.
    """
    try:
        if args.action == "load":
            # A load needs a checkpoint path; reject with a 400 up front
            # instead of letting load_lora raise and surface as a 500.
            if args.path is None:
                return JSONResponse(
                    content={
                        "status": "error",
                        "message": "A `path` is required when action is 'load'",
                    },
                    status_code=400,
                )
            app.state.model.load_lora(args.path, args.scale, args.name)
        elif args.action == "unload":
            # The checkpoint is identified by its name when one was given,
            # otherwise by the path it was loaded from.
            identifier = args.name or args.path
            if identifier is None:
                return JSONResponse(
                    content={
                        "status": "error",
                        "message": "A `name` or `path` is required when action is 'unload'",
                    },
                    status_code=400,
                )
            app.state.model.unload_lora(identifier)
        else:
            # Defensive: pydantic's Literal validation already rejects other
            # values, so this branch should be unreachable.
            return JSONResponse(
                content={
                    "status": "error",
                    "message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
                },
                status_code=400,
            )
    except Exception as e:
        # Surface model-side failures (bad checkpoint, missing file, ...) as a 500.
        return JSONResponse(
            status_code=500, content={"status": "error", "message": str(e)}
        )
    return JSONResponse(status_code=200, content={"status": "success"})
flux_pipeline.py CHANGED
@@ -2,7 +2,7 @@ import io
2
  import math
3
  import random
4
  import warnings
5
- from typing import TYPE_CHECKING, Callable, List, OrderedDict, Union
6
 
7
  import numpy as np
8
  from PIL import Image
@@ -149,7 +149,10 @@ class FluxPipeline:
149
  return cuda_generator, seed
150
 
151
  def load_lora(
152
- self, lora_path: Union[str, OrderedDict[str, torch.Tensor]], scale: float
 
 
 
153
  ):
154
  """
155
  Loads a LoRA checkpoint into the Flux flow transformer.
@@ -160,9 +163,9 @@ class FluxPipeline:
160
  Args:
161
  lora_path (str | OrderedDict[str, torch.Tensor]): Path to the LoRA checkpoint or an ordered dictionary containing the LoRA weights.
162
  scale (float): Scaling factor for the LoRA weights.
163
-
164
  """
165
- self.model.load_lora(lora_path, scale)
166
 
167
  def unload_lora(self, path_or_identifier: str):
168
  """
@@ -171,7 +174,7 @@ class FluxPipeline:
171
  Args:
172
  path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded.
173
  """
174
- self.model.unload_lora(path_or_identifier)
175
 
176
  @torch.inference_mode()
177
  def compile(self):
 
2
  import math
3
  import random
4
  import warnings
5
+ from typing import TYPE_CHECKING, Callable, List, Optional, OrderedDict, Union
6
 
7
  import numpy as np
8
  from PIL import Image
 
149
  return cuda_generator, seed
150
 
151
  def load_lora(
152
+ self,
153
+ lora_path: Union[str, OrderedDict[str, torch.Tensor]],
154
+ scale: float,
155
+ name: Optional[str] = None,
156
  ):
157
  """
158
  Loads a LoRA checkpoint into the Flux flow transformer.
 
163
  Args:
164
  lora_path (str | OrderedDict[str, torch.Tensor]): Path to the LoRA checkpoint or an ordered dictionary containing the LoRA weights.
165
  scale (float): Scaling factor for the LoRA weights.
166
+ name (str): Name of the LoRA checkpoint, optionally can be left as None, since it only acts as an identifier.
167
  """
168
+ self.model.load_lora(path=lora_path, scale=scale, name=name)
169
 
170
  def unload_lora(self, path_or_identifier: str):
171
  """
 
174
  Args:
175
  path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded.
176
  """
177
+ self.model.unload_lora(path_or_identifier=path_or_identifier)
178
 
179
  @torch.inference_mode()
180
  def compile(self):