# aifeifei798's picture
# Update app.py
# a62e14b verified
from PIL import Image
import base64
from io import BytesIO
import os
from mistralai import Mistral
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
from huggingface_hub import hf_hub_download
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
from openai import OpenAI
import config
from extras.expansion import FooocusExpansion
import re
# Prompt expander used by infer() when the "FooocusExpansion" option is enabled.
expansion = FooocusExpansion()

# Mistral client: used by predict() for messages that include an image.
api_key = os.getenv("MISTRAL_API_KEY")
client = Mistral(api_key=api_key)
# OpenAI-compatible client pointed at the Hugging Face Inference API: used by
# predict() for text-only messages.
client_open_ai = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=os.getenv('HF_TOKEN')
)

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tiny autoencoder (TAEF1); it is the VAE the pipeline below is built with.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
# NOTE(review): this full VAE is loaded onto the device but never used in this
# file. Either wire it in for final decoding or drop it to save memory/startup.
good_vae = AutoencoderKL.from_pretrained("shuttleai/shuttle-3-diffusion", subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained("shuttleai/shuttle-3-diffusion", torch_dtype=dtype, vae=taef1).to(device)

# Load three LoRAs from one repo, weight them, and fuse them into the base
# weights so inference runs without adapter overhead.
_LORA_REPO = "aifeifei798/feifei-flux-lora-v1"
_LORA_ADAPTERS = ["feifei", "FLUX-dev-lora-add_details", "Shadow-Projection"]
pipe.load_lora_weights(hf_hub_download(_LORA_REPO, "feifei.safetensors"), adapter_name = "feifei")
pipe.load_lora_weights(hf_hub_download(_LORA_REPO, "FLUX-dev-lora-add_details.safetensors"), adapter_name = "FLUX-dev-lora-add_details")
pipe.load_lora_weights(hf_hub_download(_LORA_REPO, "Shadow-Projection.safetensors"), adapter_name = "Shadow-Projection")
pipe.set_adapters(_LORA_ADAPTERS, adapter_weights=[0.65, 0.35, 0.35])
# The keyword is `adapter_names` (plural); a misspelled keyword would be
# absorbed by **kwargs and silently ignored.
pipe.fuse_lora(adapter_names=_LORA_ADAPTERS, lora_scale=1.0)
# The LoRA deltas are merged into the base weights now, so the separate
# adapter modules can be removed.
pipe.unload_lora_weights()
torch.cuda.empty_cache()

# Upper bound for random seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 4096
# Styles the chat column (elem_id="col-container").
css="""
#col-container {
width: auto;
height: 750px;
}
"""
@spaces.GPU()
def infer(prompt, quality_select, styles_Radio, FooocusExpansion_select, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True), guidance_scale=3.5):
    """Generate one image with the fused FLUX pipeline.

    Args:
        prompt: User prompt; a built-in default prompt is used when empty.
        quality_select: If truthy, append generic quality tags to the prompt.
        styles_Radio: Iterable of selected style names; each matching entry of
            ``config.style_list`` contributes its template text, in selection order.
        FooocusExpansion_select: If truthy, rewrite the prompt with ``expansion``.
        seed: RNG seed, replaced by a random one when ``randomize_seed`` is set.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (mirrors tqdm bars); not used directly.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        ``(image, seed)``: the PIL image and the seed actually used, so the UI
        can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver numbers as floats; seed the generator (and report
    # back) with a plain int.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    if not prompt:
        prompt = "the photo is a 18 yo jpop girl is looking absolutely adorable and gorgeous, with a playful and mischievous grin, her eyes twinkling with joy."
    if quality_select:
        prompt += ", masterpiece, best quality, very aesthetic, absurdres"
    if styles_Radio:
        for style_name in styles_Radio:
            for style in config.style_list:
                if style["name"] == style_name:
                    # The template's "{prompt}" placeholder is filled with "the "
                    # (the user text is already in `prompt`), and the result is
                    # appended as extra style keywords.
                    fragment = style["prompt"].replace("{prompt}", "the ")
                    prompt = _append_prompt_fragment(prompt, fragment)
    if FooocusExpansion_select:
        prompt = expansion(prompt, seed)

    # NOTE(review): `prompt` is passed empty and the full text goes only to
    # `prompt_2`; in diffusers' Flux pipeline these presumably feed the CLIP and
    # T5 encoders respectively. Confirm an empty CLIP prompt is intended.
    image = pipe(
        prompt = "",
        prompt_2 = prompt,
        width = width,
        height = height,
        num_inference_steps = num_inference_steps,
        generator = generator,
        guidance_scale=guidance_scale,
        output_type="pil",
    ).images[0]
    return image, seed


def _append_prompt_fragment(prompt, fragment):
    """Append *fragment* to *prompt*, inserting ", " when neither side supplies a separator."""
    if not fragment:
        return prompt
    if prompt.endswith((" ", ",")) or fragment.startswith((" ", ",")):
        return prompt + fragment
    # Plain concatenation would glue the last prompt word onto the style text.
    return f"{prompt}, {fragment}"
def encode_image(image_path):
    """Load an image file and return it as a base64-encoded JPEG string.

    The image is converted to RGB (JPEG cannot store an alpha channel),
    re-encoded as JPEG in memory, and base64-encoded to UTF-8 text.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The base64 text, or ``None`` if the file is missing or the image
        cannot be decoded/encoded. The error is printed, not raised.
    """
    try:
        # Context manager guarantees the file handle is closed, including for
        # lazily loaded or multi-frame formats that keep it open after load().
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Re-encode as JPEG into an in-memory buffer.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:  # Best-effort: any other failure yields None for the caller to report.
        print(f"Error: {e}")
        return None
def predict(message, history, additional_dropdown):
    """Stream a chat reply for the multimodal ChatInterface.

    If the message has attached files, the first one is base64-encoded and sent
    with the text to Mistral's ``pixtral-large-2411`` model. Otherwise the text
    alone goes to the Hugging Face Inference API model selected in
    ``additional_dropdown``.

    Args:
        message: Dict with optional ``"text"`` and ``"files"`` keys; entries of
            ``"files"`` are treated as image paths.
        history: Previous turns. Currently ignored, so each turn is stateless.
        additional_dropdown: Model id used for text-only messages.

    Yields:
        The reply accumulated so far, growing with each streamed chunk, or a
        single error string if the image could not be encoded.
    """
    message_text = message.get("text", "")
    message_files = message.get("files", [])

    if message_files:
        # Only the first attachment is sent.
        base64_image = encode_image(message_files[0])
        if base64_image is None:
            yield "Error: Failed to encode the image."
            return
        model = "pixtral-large-2411"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": message_text},
                    {
                        "type": "image_url",
                        "image_url": f"data:image/jpeg;base64,{base64_image}",
                    },
                ],
            }
        ]
        partial_message = ""
        for chunk in client.chat.stream(model=model, messages=messages):
            choices = chunk.data.choices
            # Guard against chunks with no choices or no text content.
            if choices and choices[0].delta.content is not None:
                partial_message += choices[0].delta.content
                yield partial_message
    else:
        stream = client_open_ai.chat.completions.create(
            model=additional_dropdown,
            messages=[{"role": "user", "content": str(message_text)}],
            temperature=0.5,
            max_tokens=1024,
            top_p=0.7,
            stream=True
        )
        partial_message = ""
        for chunk in stream:
            # Guard against chunks with an empty choices list (e.g. usage-only chunks).
            if chunk.choices and chunk.choices[0].delta.content is not None:
                partial_message += chunk.choices[0].delta.content
                yield partial_message
# Build the UI: a narrow left column with the image generator (plus a style
# tab), and a wider right column hosting the multimodal chat assistant.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Generator"):
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    placeholder="Enter your prompt",
                    max_lines = 12,
                    container=False
                )
                run_button = gr.Button("Run")
                # Output only; infer() fills it with the generated image.
                result = gr.Image(label="Result", show_label=False, interactive=False)
                with gr.Accordion("Advanced Settings", open=False):
                    # Also an output: infer() writes back the seed it actually used.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        # Pixel sizes in multiples of 64 (UI default 896x1152).
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=64,
                            value=896,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=64,
                            value=1152,
                        )
                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=4,
                        )
                        guidancescale = gr.Slider(
                            label="Guidance scale",
                            minimum=0,
                            maximum=10,
                            step=0.1,
                            value=3.5,
                        )
            # Prompt-modifier options consumed by infer().
            with gr.Tab("Styles"):
                quality_select = gr.Checkbox(label="high quality")
                FooocusExpansion_select = gr.Checkbox(label="FooocusExpansion",value=True)
                # Choices are the names defined in config.style_list.
                styles_name = [style["name"] for style in config.style_list]
                styles_Radio = gr.Dropdown(styles_name,label="Styles",multiselect=True)
        # Chat column; its size comes from the #col-container rule in `css`.
        with gr.Column(scale=3,elem_id="col-container"):
            gr.ChatInterface(
                predict,
                type="messages",
                multimodal=True,
                # Passed to predict() as `additional_dropdown`: the model used
                # for text-only messages.
                additional_inputs =[gr.Dropdown(
                    ["CohereForAI/c4ai-command-r-plus-08-2024",
                    "meta-llama/Meta-Llama-3.1-70B-Instruct",
                    "Qwen/Qwen2.5-72B-Instruct",
                    "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
                    "NousResearch/Hermes-3-Llama-3.1-8B",
                    "mistralai/Mistral-Nemo-Instruct-2407",
                    "microsoft/Phi-3.5-mini-instruct"],
                    value="meta-llama/Meta-Llama-3.1-70B-Instruct",
                    show_label=False,
                )]
            )
    # Generate on button click or Enter in the prompt box. infer() returns
    # (image, seed), shown in `result` and written back to the `seed` slider.
    # NOTE(review): infer() has 11 parameters but 10 inputs are listed; Gradio
    # presumably injects the gr.Progress `progress` argument itself, so
    # `guidancescale` lands in `guidance_scale`. Confirm.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [prompt, quality_select, styles_Radio, FooocusExpansion_select, seed, randomize_seed, width, height, num_inference_steps, guidancescale],
        outputs = [result, seed]
    )
if __name__ == "__main__":
    # Turn on Gradio's request queue for the app, then start the web server.
    queued_app = demo.queue()
    queued_app.launch()