# Spaces:
# Runtime error
# Runtime error
from fastapi import FastAPI, File, UploadFile, Form, Body | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from typing import List | |
import io | |
from PIL import Image, ImageOps | |
import numpy as np | |
import compColors | |
import dominantColors | |
import recolorReinhardAlgo | |
import recolorOTAlgo | |
import recolorTransferAlgo | |
import recolorLumaConverterAlgo | |
import recolorPaletteBasedTransfer | |
import recolorReinhardV2Algo | |
import recolorLinearColorTransfer | |
import recolorStrictTransfer | |
import matchCollection | |
import ColorReplacer | |
import ColorMask | |
from typing import Optional | |
import random | |
# FastAPI application shared by the request handlers in this module.
app = FastAPI()

# Fully permissive CORS: every origin, method and header is accepted, and
# credentials (cookies / auth headers) are allowed as well.
# NOTE(review): wildcard origins together with allow_credentials=True deserve a
# second look. Browsers reject a literal "*" origin on credentialed requests,
# so confirm how CORSMiddleware resolves this combination and that it is intended.
# NOTE(review): in this copy of the file none of the handler functions below
# carry a route decorator (only the commented-out create_collection shows
# `@app.post("/collection/")`), so they would not be registered on `app`.
# Confirm against the deployed source.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
async def dominant_color(file: UploadFile = File(...), num_colors: int = Form(...), ordered: bool = Form(False)):
    """
    Extract the dominant colors of an uploaded image and return them as hex strings.

    Args:
        file: the uploaded image (multipart file field).
        num_colors: number of dominant colors to request from
            dominantColors.findDominantColors.
        ordered: accepted from the form but not used below. The helper is always
            called with a literal False as its third argument.
            NOTE(review): confirm whether this flag was meant to be forwarded.

    Returns:
        {"dominantColors": [<hex string>, ...]}, in the order the helper returned them.
    """
    print('num_colors: ', num_colors)
    upload = io.BytesIO(await file.read())
    # The opened image is never used below; the call only makes PIL parse the
    # upload first. NOTE(review): possibly leftover, or intended as input
    # validation. Confirm before removing.
    _preview = Image.open(upload)
    rgb_colors = dominantColors.findDominantColors(upload, num_colors, False)
    hex_colors = list(map(dominantColors.rgb_to_hex, rgb_colors))
    return {"dominantColors": hex_colors}
def _chunk_rows(items, size=5):
    """Split `items` into consecutive rows of at most `size` elements ([] when empty)."""
    return [items[start:start + size] for start in range(0, len(items), size)]


async def color_palettes(colors: str = Form(...)):
    """
    Receive a comma-separated list of colors and return palettes derived from them.

    Args:
        colors: comma-separated color strings (form field). Whitespace around each
            entry is ignored. Empty entries (e.g. from a trailing comma or an empty
            field) are skipped instead of being handed to the color helpers.

    Returns:
        A dict with:
          - "inputColor": the parsed input colors, grouped into rows of at most 5;
          - "complementaryColors": compColors.complementary_colors(c) for each input
            color, grouped into rows of at most 5;
          - "adjacentColors": the de-duplicated union of compColors.adjacent_colors(c)
            over all inputs (first-seen order), grouped into rows of at most 5;
          - "gradientColors": compColors.gradient_colors(a, b) for each pair of
            consecutive input colors (not grouped).
    """
    # Parse the form field, dropping empty tokens such as the one produced by "#fff,#000,".
    color_list = [token for token in (part.strip() for part in colors.split(',')) if token]

    # Palette 1: the complementary color of each input color.
    complementary = [compColors.complementary_colors(color) for color in color_list]

    # Palette 2: adjacent colors of every input, de-duplicated in first-seen order.
    # A list membership test is kept because the element type returned by
    # compColors is not known to be hashable.
    adjacent = []
    for color in color_list:
        for candidate in compColors.adjacent_colors(color):
            if candidate not in adjacent:
                adjacent.append(candidate)

    # Palette 3: a gradient between each pair of consecutive input colors.
    gradients = [
        compColors.gradient_colors(start, end)
        for start, end in zip(color_list, color_list[1:])
    ]

    return {
        "inputColor": _chunk_rows(color_list),
        "complementaryColors": _chunk_rows(complementary),
        "adjacentColors": _chunk_rows(adjacent),
        "gradientColors": gradients,
    }
async def recolor(file: UploadFile = File(...), colors: str = Form(...), new_colors: Optional[str] = Form(None), random_colors: bool = Form(False), model: str = Form(...), mask: Optional[UploadFile] = File(None)):
    """
    Recolor an uploaded image with the selected algorithm and return it as JPEG.

    Args:
        file: the image to recolor.
        colors: comma-separated colors of the selected palette.
        new_colors: optional comma-separated replacement colors; only forwarded
            to the "StrictTransfer*" models.
        random_colors: for "StrictTransfer*" models, shuffle `colors` (and
            `new_colors`, when given) before recoloring.
        model: one of "CCA", "OTA", "KMEANS", "LUMA", "palette", "Reinhardv2",
            "LinearColorTransfer", "ColorMask", or any name starting with
            "StrictTransfer".
        mask: optional mask image. Where a mask channel is 0 the recolored pixel
            is kept; everywhere else the original pixel is restored.

    Returns:
        A StreamingResponse with the bytes of ./result.jpg.

    Raises:
        HTTPException: 400 when `model` is not a supported name. Previously the
        request fell through and served whatever ./result.jpg was left on disk.

    NOTE(review): the algorithm return values are ignored and the result is read
    back from ./result.jpg. This presumably relies on each algorithm writing that
    file; confirm against the recolor* / ColorMask modules. Because the path is
    fixed, concurrent requests can overwrite each other's result.
    """
    from fastapi import HTTPException  # local import keeps the module import block unchanged

    colors = [color.strip() for color in colors.split(',')]
    file_content = await file.read()
    image = Image.open(io.BytesIO(file_content))
    image_np = np.array(image)
    if new_colors is not None:
        new_colors = [new_color.strip() for new_color in new_colors.split(',')]

    if model == "CCA":
        print('CCA generated')
        # Characteristic Color Analysis
        recolorReinhardAlgo.recolor(image_np, colors)
    elif model == "OTA":
        print('OTA generated')
        # Optimal Transport Algorithm transfer
        recolorOTAlgo.recolor(image_np, colors)
    elif model == "KMEANS":
        print('KMEANS generated')
        # K-means clustering transfer
        recolorTransferAlgo.recolor(image_np, colors)
    elif model == "LUMA":
        print('Luma generated')
        # Luma converter transfer
        recolorLumaConverterAlgo.remap_image_colors(image_np, colors)
    elif model == "palette":
        # Palette transfer
        print('palette generated')
        recolorPaletteBasedTransfer.recolor(image_np, colors)
    elif model == "Reinhardv2":
        print('Reinhardv2 generated')
        recolorReinhardV2Algo.recolor(image_np, colors)
    elif model == "LinearColorTransfer":
        print('LinearColorTransfer generated')
        recolorLinearColorTransfer.recolor(image_np, colors)
    elif model.startswith("StrictTransfer"):
        print('StrictTransfer started', colors, new_colors)
        if random_colors:
            random.shuffle(colors)
            # new_colors is optional; shuffling None would raise TypeError.
            if new_colors is not None:
                random.shuffle(new_colors)
            print('StrictTransfer random', colors, new_colors)
        recolorStrictTransfer.recolor(image_np, colors, new_colors)
    elif model == "ColorMask":
        print('ColorMask started')
        ColorMask.create_mask(image_np, colors[0])
    else:
        # Without this, an unknown model would fall through and return a stale
        # ./result.jpg from an earlier request (or crash if none exists).
        raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Optional mask: blend the recolored result with the original image.
    if mask is not None:
        mask_content = await mask.read()
        mask_image = Image.open(io.BytesIO(mask_content))
        # Bring result, mask and original to RGB so the three arrays line up
        # channel-for-channel. RGBA, palette or grayscale uploads would otherwise
        # fail to broadcast in np.where.
        with Image.open('./result.jpg') as stored_result:
            result_image = stored_result.convert('RGB')
        result_np = np.array(result_image)
        # Taken from the PIL image rather than image_np, so it is unaffected by
        # anything the algorithm might have done to image_np in place.
        original_np = np.array(image.convert('RGB'))
        print('result_np', result_np.size)
        print('image_np', image_np.size)
        # Ensure mask_image is the same size as result_image
        if mask_image.size != result_image.size:
            mask_image = mask_image.resize(result_image.size)
        mask_np = np.array(mask_image.convert('RGB'))
        # Keep recolored pixels where the mask is 0; restore the original elsewhere.
        new_image_np = np.where(mask_np == 0, result_np, original_np)
        Image.fromarray(new_image_np).save('./result.jpg')

    # Read the result into memory and close the file immediately. The previous
    # code handed an open file object to the response and never closed it.
    with open("./result.jpg", "rb") as result_file:
        payload = result_file.read()
    return StreamingResponse(io.BytesIO(payload), media_type="image/jpeg")
# @app.post("/collection/") | |
# async def create_collection(collection: str = Body(...), colors: List[str] = Body(...)): | |
# """ | |
# Endpoint to create a collection with items. | |
# """ | |
# result = matchCollection.predict_palette(collection, colors[0]) | |
# print(result) | |
# # prepare the data to be sent in the response
# return {"collection": result} | |
async def create_collection(collection: str = Body(...), colors: List[str] = Body(...)):
    """
    Predict one palette per requested color for the named collection.

    Body (JSON):
        collection: collection name passed to matchCollection.predict_palette.
        colors: colors to match; each one yields one palette.

    Returns:
        {"collection": <collection>, "palettes": [<palette per color>, ...]},
        with palettes in the same order as `colors`.
    """
    return {
        "collection": collection,
        "palettes": [
            matchCollection.predict_palette(collection, requested)
            for requested in colors
        ],
    }
async def test():
    """Health check: respond with a fixed message so callers can tell the server is up."""
    status = {"message": "Server is running!"}
    return status
if __name__ == "__main__":
    # Direct execution (`python <this file>`): serve `app` with uvicorn on all
    # interfaces, port 7860. uvicorn is imported here because only this entry
    # path needs it. The message is printed before uvicorn starts.
    from uvicorn import run as serve_app

    print("Server is running")
    serve_app(app, port=7860, host="0.0.0.0")
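# How to run:
#   source env/bin/activate
#   uvicorn server:app --reload
#   (or `python server.py`, which serves on port 7860 via the __main__ block above)
# Example request (adjust the port to match how the server was started):
#   curl -X POST http://0.0.0.0:4201/collection/ -H "Content-Type: application/json" -d '{"collection": "FLORAL", "colors": ["#1f3b4a", "#597375", "#7f623e", "#5c453c"]}'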