|
import torch |
|
import io |
|
from fireworks.flumina import FluminaModule, main as flumina_main |
|
from fireworks.flumina.route import post |
|
import pydantic |
|
from pydantic import BaseModel |
|
from fastapi import File, Form, Header, UploadFile, HTTPException |
|
from fastapi.responses import Response |
|
import math |
|
import os |
|
import re |
|
import PIL.Image as Image |
|
from typing import Dict, Optional, Set, Tuple |
|
|
|
from diffusers import FluxPipeline, FluxControlNetPipeline, FluxControlNetModel |
|
from diffusers.models import FluxMultiControlNetModel |
|
|
|
|
|
|
|
def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]: |
|
""" |
|
Convert specified aspect ratio to a height/width pair. |
|
""" |
|
if ":" not in aspect_ratio: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9" |
|
) |
|
|
|
w, h = aspect_ratio.split(":") |
|
try: |
|
w, h = int(w), int(h) |
|
except ValueError: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9" |
|
) |
|
|
|
valid_aspect_ratios = [ |
|
(1, 1), |
|
(21, 9), |
|
(16, 9), |
|
(3, 2), |
|
(4, 3), |
|
(5, 4), |
|
(4, 5), |
|
(3, 4), |
|
(2, 3), |
|
(9, 16), |
|
(9, 21), |
|
] |
|
if (w, h) not in valid_aspect_ratios: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be one of {valid_aspect_ratios}" |
|
) |
|
|
|
|
|
TARGET_SIZE_MP = 1 |
|
target_size = TARGET_SIZE_MP * 2**20 |
|
|
|
width = math.sqrt(target_size / (w * h)) * w |
|
height = math.sqrt(target_size / (w * h)) * h |
|
|
|
PAD_MULTIPLE = 64 |
|
|
|
if PAD_MULTIPLE: |
|
width = width // PAD_MULTIPLE * PAD_MULTIPLE |
|
height = height // PAD_MULTIPLE * PAD_MULTIPLE |
|
|
|
return int(width), int(height) |
|
|
|
|
|
def encode_image(
    image: Image.Image, mime_type: str, jpeg_quality: int = 95
) -> bytes:
    """
    Serialize a PIL image to bytes in the requested format.

    Args:
        image: image to encode.
        mime_type: "image/jpeg" or "image/png".
        jpeg_quality: JPEG quality in [0, 100]; ignored for PNG.

    Returns:
        The encoded image bytes.

    Raises:
        ValueError: if ``mime_type`` is unsupported or ``jpeg_quality`` is
            outside [0, 100].
    """
    # Modes Pillow's JPEG encoder can write directly; anything else (e.g.
    # RGBA, LA, P) makes Pillow raise OSError, so it is converted to RGB.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    with io.BytesIO() as buffered:
        if mime_type == "image/jpeg":
            if jpeg_quality < 0 or jpeg_quality > 100:
                raise ValueError(
                    f"jpeg_quality must be between 0 and 100, not {jpeg_quality}"
                )
            if image.mode not in jpeg_writable_modes:
                # JPEG has no alpha channel or palette; alpha is dropped here.
                image = image.convert("RGB")
            image.save(buffered, format="JPEG", quality=jpeg_quality)
        elif mime_type == "image/png":
            image.save(buffered, format="PNG")
        else:
            raise ValueError(f"invalid mime_type {mime_type}")
        return buffered.getvalue()
|
|
|
|
|
def parse_accept_header(accept: str) -> str:
    """
    Choose the response image MIME type from an HTTP ``Accept`` header value.

    The header is a comma-separated list of media ranges, each optionally
    followed by ``;``-separated parameters, of which only ``q`` (the quality
    weight, default 1.0) is used. Ranges are tried from highest to lowest
    ``q``; ties keep header order. The first range that names a supported
    type, or is the ``*/*`` / ``image/*`` wildcard, decides the result.

    Media types are compared case-insensitively, whitespace around ``;`` is
    tolerated, empty list elements are skipped, and ``q=0`` marks a type as
    not acceptable (also excluding it from wildcard matches).

    Args:
        accept: raw ``Accept`` header value, e.g.
            ``"image/png;q=0.9, */*;q=0.1"``.

    Returns:
        ``"image/jpeg"`` or ``"image/png"``.

    Raises:
        ValueError: if a list element has an empty media type, or no
            acceptable supported type is listed.
    """
    supported_types = ["image/jpeg", "image/png"]

    weighted_types = []
    for element in accept.split(","):
        element = element.strip()
        if not element:
            # RFC 7230 section 7: recipients must tolerate empty list elements.
            continue

        media_type, *params = (piece.strip() for piece in element.split(";"))
        if not media_type:
            raise ValueError(f"Malformed Accept header value: {element}")

        q_factor = 1.0
        for param in params:
            name, _, value = param.partition("=")
            if name.strip().lower() == "q":
                # Leniently take the leading numeric part of the value and
                # fall back to 1.0 when there is none.
                q_match = re.match(r"\d+(\.\d+)?", value.strip())
                if q_match:
                    q_factor = float(q_match.group(0))
                break

        weighted_types.append((media_type.lower(), q_factor))

    # Types explicitly listed with q=0 must not be chosen through a wildcard.
    rejected = {media_type for media_type, q in weighted_types if q <= 0}

    # sorted() is stable, so equal weights keep their header order.
    for media_type, q_factor in sorted(weighted_types, key=lambda x: x[1], reverse=True):
        if q_factor <= 0:
            continue
        if media_type in supported_types:
            return media_type
        if media_type in ("*/*", "image/*"):
            for candidate in supported_types:
                if candidate not in rejected:
                    return candidate

    raise ValueError(f"Accept header did not include any supported MIME types: {supported_types}")
|
|
|
|
|
|
|
# Request body for the /text_to_image route. Field names, types and defaults
# define the accepted JSON schema.
class Text2ImageRequest(BaseModel):
    # Text prompt passed to the pipeline.
    prompt: str
    # "w:h" ratio; converted to pixel dimensions by
    # _aspect_ratio_to_width_height, which rejects unsupported ratios.
    aspect_ratio: str = "16:9"
    # Passed through as the pipeline's guidance_scale.
    guidance_scale: float = 3.5
    # Number of denoising steps; also reported as `steps` in BillingInfo.
    num_inference_steps: int = 4
    # Seed for the CUDA torch.Generator used for this request.
    seed: int = 0
|
|
|
|
|
# Error payload nested under "error" in ErrorResponse. Field order is the
# key order of the serialized JSON.
class Error(BaseModel):
    # Fixed object tag for error payloads.
    object: str = "error"
    # Error category; _error_response only sets `message`, so this default
    # is what it emits.
    type: str = "invalid_request_error"
    # Human-readable description of the failure.
    message: str
|
|
|
|
|
# Top-level error envelope serialized as {"error": {...}} by _error_response.
class ErrorResponse(BaseModel):
    # Required on purpose: Error.message has no default, so no meaningful
    # default Error can be constructed here (a default_factory=Error would
    # itself fail validation whenever it was used).
    error: Error
|
|
|
|
|
# Per-request usage properties, serialized to JSON into the
# "Fireworks-Billing-Properties" response header (presumably consumed by the
# platform for billing — confirm with the platform contract).
class BillingInfo(BaseModel):
    # Number of inference steps requested.
    steps: int
    # Generated image height in pixels.
    height: int
    # Generated image width in pixels.
    width: int
    # True for images produced by the /control_net route.
    is_control_net: bool
|
|
|
|
|
class FluminaModule(FluminaModule):
    """
    Flumina serving module wrapping a FLUX text-to-image pipeline.

    Exposes the ``/text_to_image`` and ``/control_net`` routes and supports two
    add-on types: ControlNet-Union models (used by ``/control_net``) and LoRA
    adapters (applied to the base pipeline). At most one add-on of each type
    can be active at a time.

    NOTE: this subclass deliberately reuses the imported base-class name; the
    module's entry points instantiate it as ``FluminaModule()``.
    """

    def __init__(self):
        super().__init__()
        # Base pipeline: weights from ./data, run in bfloat16 on CUDA.
        self.hf_model = FluxPipeline.from_pretrained('./data', torch_dtype=torch.bfloat16)
        self.hf_model.to(device='cuda', dtype=torch.bfloat16)

        # Loaded ControlNet-Union pipelines keyed by add-on qualified name
        # ("accounts/<account>/models/<model>"). Each one reuses every
        # component of self.hf_model except the ControlNet itself.
        self.cnet_union_pipes: Dict[str, FluxControlNetPipeline] = {}
        # Qualified name of the ControlNet-Union used by /control_net, if any.
        self.active_cnet_union: Optional[str] = None

        # Qualified names of LoRA adapters loaded into self.hf_model.
        self.lora_adapters: Set[str] = set()
        # Qualified name of the LoRA adapter currently applied, if any.
        self.active_lora_adapter: Optional[str] = None

        # When True, handlers return raw bytes / JSON strings instead of
        # Response objects so they can be driven directly from a script.
        self._test_return_sync_response = False

    @staticmethod
    def _addon_qualname(addon_account_id: str, addon_model_id: str) -> str:
        """Return the fully-qualified add-on name used as the registry key."""
        return f"accounts/{addon_account_id}/models/{addon_model_id}"

    def _sync_lora_state(self) -> None:
        """
        Make the pipeline's LoRA state match ``self.active_lora_adapter``.

        With no active adapter all LoRA layers are disabled; otherwise LoRA is
        enabled and only the active adapter is selected. Callers only invoke
        this while at least one adapter is loaded into the pipeline.
        """
        if self.active_lora_adapter is None:
            self.hf_model.disable_lora()
        else:
            self.hf_model.enable_lora()
            self.hf_model.set_adapters([self.active_lora_adapter])

    def _error_response(self, code: int, message: str) -> Response:
        """
        Build a JSON error response with HTTP status ``code``.

        In test mode (``_test_return_sync_response``) the serialized JSON
        string is returned instead of a Response.
        """
        response_json = ErrorResponse(
            error=Error(message=message),
        ).json()
        if self._test_return_sync_response:
            return response_json
        return Response(
            response_json,
            status_code=code,
            media_type="application/json",
        )

    def _image_response(self, img: Image.Image, mime_type: str, billing_info: BillingInfo):
        """
        Encode ``img`` as ``mime_type`` and wrap it in a 200 response carrying
        the billing properties header.

        In test mode (``_test_return_sync_response``) the raw encoded bytes are
        returned instead of a Response.
        """
        image_bytes = encode_image(img, mime_type)
        if self._test_return_sync_response:
            return image_bytes
        headers = {'Fireworks-Billing-Properties': billing_info.json()}
        return Response(image_bytes, status_code=200, media_type=mime_type, headers=headers)

    @post('/text_to_image')
    async def text_to_image(
        self,
        body: Text2ImageRequest,
        accept: str = Header("image/jpeg"),
    ):
        """
        Generate one image from a text prompt with the base pipeline (plus the
        active LoRA adapter, if any).

        Returns 406 when the Accept header lists no supported image type and
        400 for an unsupported aspect ratio.
        """
        try:
            mime_type = parse_accept_header(accept)
        except ValueError as e:
            return self._error_response(406, str(e))
        try:
            width, height = _aspect_ratio_to_width_height(body.aspect_ratio)
        except ValueError as e:
            return self._error_response(400, str(e))

        # NOTE: the pipeline call is synchronous, so it blocks the event loop
        # for the whole generation.
        img = self.hf_model(
            prompt=body.prompt,
            height=height,
            width=width,
            guidance_scale=body.guidance_scale,
            num_inference_steps=body.num_inference_steps,
            generator=torch.Generator('cuda').manual_seed(body.seed),
        )
        assert len(img.images) == 1, len(img.images)

        billing_info = BillingInfo(
            steps=body.num_inference_steps,
            height=height,
            width=width,
            is_control_net=False,
        )
        return self._image_response(img.images[0], mime_type, billing_info)

    @post('/control_net')
    async def control_net(
        self,
        prompt: str = Form(...),
        control_image: UploadFile = File(...),
        control_mode: int = Form(...),
        aspect_ratio: str = Form("16:9"),
        guidance_scale: float = Form(3.5),
        num_inference_steps: int = Form(4),
        seed: int = Form(0),
        controlnet_conditioning_scale: Optional[float] = Form(1.0),
        accept: str = Header("image/jpeg"),
    ):
        """
        Generate one image conditioned on ``control_image`` with the active
        ControlNet-Union add-on.

        Returns 406 for an unsupported Accept header, and 400 when no
        ControlNet-Union is active, the control image cannot be decoded, or
        the aspect ratio is unsupported.
        """
        try:
            mime_type = parse_accept_header(accept)
        except ValueError as e:
            return self._error_response(406, str(e))

        if self.active_cnet_union is None:
            return self._error_response(400, "Must call `/control_net` endpoint with a ControlNet model specified in the URI")

        if control_mode is None:
            return self._error_response(400, "control_mode must be specified when calling a ControlNet model")

        try:
            image_data = await control_image.read()
            pil_image = Image.open(io.BytesIO(image_data))
            # Image.open is lazy; force decoding here so corrupt or truncated
            # uploads are reported as 400 instead of failing in the pipeline.
            pil_image.load()
        except Exception as e:
            return self._error_response(400, f"Invalid image format: {e}")

        try:
            width, height = _aspect_ratio_to_width_height(aspect_ratio)
        except ValueError as e:
            return self._error_response(400, str(e))

        # NOTE: synchronous pipeline call; blocks the event loop while running.
        img = self.cnet_union_pipes[self.active_cnet_union](
            prompt=prompt,
            # The pipeline wraps a FluxMultiControlNetModel (see load_addon),
            # hence single-element lists for the per-ControlNet arguments.
            control_image=[pil_image],
            control_mode=[control_mode],
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=[controlnet_conditioning_scale],
            generator=torch.Generator('cuda').manual_seed(seed),
        )
        assert len(img.images) == 1, len(img.images)

        billing_info = BillingInfo(
            steps=num_inference_steps,
            height=height,
            width=width,
            is_control_net=True,
        )
        return self._image_response(img.images[0], mime_type, billing_info)

    @property
    def supported_addon_types(self):
        """Add-on type strings accepted by ``load_addon``."""
        return ['controlnet_union', 'lora']

    def load_addon(
        self, addon_account_id: str, addon_model_id: str, addon_type: str, addon_data_path: os.PathLike
    ):
        """
        Load an add-on from ``addon_data_path`` without activating it.

        Raises:
            ValueError: if ``addon_type`` is not in ``supported_addon_types``.
        """
        if addon_type not in self.supported_addon_types:
            raise ValueError(f"Invalid addon type {addon_type}. Supported types: {self.supported_addon_types}")

        qualname = self._addon_qualname(addon_account_id, addon_model_id)

        if addon_type == 'controlnet_union':
            cnet_model = FluxControlNetModel.from_pretrained(addon_data_path)
            multi_cnet_model = FluxMultiControlNetModel([cnet_model])
            multi_cnet_model.to(device='cuda', dtype=torch.bfloat16)
            # Reuse the base pipeline's components so no extra copies of the
            # transformer / text encoders / VAE are created.
            self.cnet_union_pipes[qualname] = FluxControlNetPipeline(
                scheduler=self.hf_model.scheduler,
                vae=self.hf_model.vae,
                text_encoder=self.hf_model.text_encoder,
                tokenizer=self.hf_model.tokenizer,
                text_encoder_2=self.hf_model.text_encoder_2,
                tokenizer_2=self.hf_model.tokenizer_2,
                transformer=self.hf_model.transformer,
                controlnet=multi_cnet_model,
            )
        elif addon_type == 'lora':
            self.hf_model.load_lora_weights(addon_data_path, adapter_name=qualname)
            self.lora_adapters.add(qualname)
            # Loading may change which adapters the pipeline applies; restore
            # the module's own state so a LoRA only takes effect once active.
            self._sync_lora_state()
        else:
            raise NotImplementedError(f'Addon support for type {addon_type} not implemented')

    def unload_addon(
        self, addon_account_id: str, addon_model_id: str, addon_type: str
    ):
        """
        Unload a previously loaded add-on, deactivating it first if active.

        Raises:
            AssertionError: if the add-on is not loaded.
            NotImplementedError: for an unknown ``addon_type``.
        """
        qualname = self._addon_qualname(addon_account_id, addon_model_id)

        if addon_type == 'controlnet_union':
            # Explicit check instead of `assert` so it survives `python -O`.
            if qualname not in self.cnet_union_pipes:
                raise AssertionError(f'Addon {qualname} not loaded!')
            if self.active_cnet_union == qualname:
                # Otherwise /control_net would look up a removed pipeline.
                self.active_cnet_union = None
            del self.cnet_union_pipes[qualname]
        elif addon_type == 'lora':
            if qualname not in self.lora_adapters:
                raise AssertionError(f'Addon {qualname} not loaded!')
            if self.active_lora_adapter == qualname:
                self.active_lora_adapter = None
                # Disable while the adapter is still loaded.
                self._sync_lora_state()
            self.hf_model.delete_adapters([qualname])
            self.lora_adapters.remove(qualname)
        else:
            raise NotImplementedError(f'Addon support for type {addon_type} not implemented')

    def activate_addon(self, addon_account_id: str, addon_model_id: str):
        """
        Make a loaded add-on the active one of its type.

        Raises:
            ValueError: if an add-on of the same type is already active or the
                add-on is not loaded.
        """
        qualname = self._addon_qualname(addon_account_id, addon_model_id)

        if qualname in self.cnet_union_pipes:
            if self.active_cnet_union is not None:
                raise ValueError(f"ControlNet Union {self.active_cnet_union} already active. Multi-controlnet union not supported!")
            self.active_cnet_union = qualname
            return

        if qualname in self.lora_adapters:
            if self.active_lora_adapter is not None:
                raise ValueError(f"LoRA adapter {self.active_lora_adapter} already active. Multi-LoRA not yet supported")
            self.active_lora_adapter = qualname
            # Actually apply the adapter to the pipeline.
            self._sync_lora_state()
            return

        raise ValueError(f"Unknown addon {qualname}")

    def deactivate_addon(self, addon_account_id: str, addon_model_id: str):
        """
        Deactivate the given add-on; it stays loaded.

        Raises:
            AssertionError: if the add-on is not currently active.
        """
        qualname = self._addon_qualname(addon_account_id, addon_model_id)

        if self.active_cnet_union == qualname:
            self.active_cnet_union = None
        elif self.active_lora_adapter == qualname:
            self.active_lora_adapter = None
            # Stop applying the adapter to subsequent generations.
            self._sync_lora_state()
        else:
            raise AssertionError(f'Addon {qualname} not active!')
|
|
|
|
|
if __name__ == "__flumina_main__":
    # Entry point when the module is launched by the Flumina runtime.
    f = FluminaModule()
    flumina_main(f)

if __name__ == "__main__":
    # Local smoke test: text-to-image, then optional ControlNet and LoRA runs
    # (enabled via environment variables), writing outputs to the CWD.
    f = FluminaModule()
    f._test_return_sync_response = True

    import asyncio

    # "*/*" negotiates the first supported type (JPEG), so save as .jpg.
    t2i_out = asyncio.run(f.text_to_image(
        Text2ImageRequest(
            prompt="A quick brown fox",
            aspect_ratio="4:3",
            guidance_scale=3.5,
            num_inference_steps=4,
            seed=0,
        ),
        accept="*/*",
    ))
    assert isinstance(t2i_out, bytes), t2i_out
    with open('output.jpg', 'wb') as out_file:
        out_file.write(t2i_out)

    cn_adapter_path = os.environ.get('CONTROLNET_ADAPTER_PATH', None)
    if cn_adapter_path is not None:
        addon_acct_id, addon_model_id = "fireworks", "test_controlnet"
        f.load_addon(addon_acct_id, addon_model_id, "controlnet_union", cn_adapter_path)
        f.activate_addon(addon_acct_id, addon_model_id)

        import cv2

        class FakeFile:
            """UploadFile stand-in whose read() returns a Canny edge map of an
            image on disk, encoded as PNG bytes."""

            def __init__(self, filename):
                self.filename = filename

            async def read(self):
                image = cv2.imread(self.filename)
                if image is None:
                    raise ValueError("Image not found or unable to open.")
                # Blur before edge detection to suppress noise edges.
                blurred_image = cv2.GaussianBlur(image, (5, 5), 1.4)
                edges = cv2.Canny(blurred_image, threshold1=0, threshold2=50)
                control_image = Image.fromarray(edges).convert("RGB")
                bio = io.BytesIO()
                control_image.save(bio, format="PNG")
                return bio.getvalue()

        cn_out = asyncio.run(f.control_net(
            prompt="Cyberpunk future fox nighttime purple and green",
            control_image=FakeFile('output.jpg'),
            control_mode=0,
            aspect_ratio="4:3",
            guidance_scale=3.5,
            num_inference_steps=4,
            seed=0,
            controlnet_conditioning_scale=1.0,
            accept="image/png",
        ))
        assert isinstance(cn_out, bytes), cn_out
        f.deactivate_addon(addon_acct_id, addon_model_id)
        f.unload_addon(addon_acct_id, addon_model_id, "controlnet_union")

        with open('output_cn.png', 'wb') as cn_out_file:
            cn_out_file.write(cn_out)
    else:
        print('Skipping ControlNet test. Set CONTROLNET_ADAPTER_PATH to enable it.')

    lora_adapter_path = os.environ.get('LORA_ADAPTER_PATH', None)
    if lora_adapter_path is not None:
        addon_acct_id, addon_model_id = "fireworks", "test_lora"
        f.load_addon(addon_acct_id, addon_model_id, "lora", lora_adapter_path)
        f.activate_addon(addon_acct_id, addon_model_id)

        # "image/jpeg;image/png" parses as a single JPEG entry, so save as .jpg.
        lora_out = asyncio.run(f.text_to_image(
            Text2ImageRequest(
                prompt="A quick brown fox",
                aspect_ratio="4:3",
                guidance_scale=3.5,
                num_inference_steps=4,
                seed=0,
            ),
            accept="image/jpeg;image/png",
        ))
        assert isinstance(lora_out, bytes), lora_out

        f.deactivate_addon(addon_acct_id, addon_model_id)
        f.unload_addon(addon_acct_id, addon_model_id, "lora")

        with open('output_lora.jpg', 'wb') as lora_out_file:
            lora_out_file.write(lora_out)
    else:
        print('Skipping LoRA test. Set LORA_ADAPTER_PATH to enable it.')
|
|