Fifa_avatar_gen / app.py
Scezui's picture
updated the system
c2e4746
raw
history blame
2.02 kB
import base64
import os
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
from flask import Flask, render_template, request, jsonify
from matplotlib.figure import Figure
from numpy.random import randn
from tensorflow.keras.models import load_model
app = Flask(__name__)
# Load the generator model once at import time so every request reuses it.
# The path is resolved against Flask's application root (normally the
# directory containing this module) instead of the current working directory,
# matching how Flask locates templates and letting the app start from any cwd.
model = load_model(os.path.join(app.root_path, 'gan.h5'))
def generate_latent_points(latent_dim, n_samples):
    """Draw standard-normal latent vectors for the generator.

    Args:
        latent_dim: Length of each latent vector.
        n_samples: How many vectors to draw.

    Returns:
        A float ndarray of shape ``(n_samples, latent_dim)`` filled from
        ``numpy.random.randn`` (so it follows the global NumPy seed).
    """
    # randn with a shape fills the array in C order from the same stream,
    # i.e. identical to drawing a flat vector and reshaping it.
    return randn(n_samples, latent_dim)
def generate_images(model, latent_points):
    """Run the generator on a batch of latent vectors.

    Args:
        model: Any object exposing ``predict`` (in this app, the Keras model
            loaded at import time).
        latent_points: Batch of latent vectors, passed straight to ``predict``.

    Returns:
        Exactly what ``model.predict`` returns for the batch.
    """
    predict = model.predict
    return predict(latent_points)
def plot_generated(examples, n_rows, n_cols, image_size=(80, 80)):
    """Render images into an ``n_rows`` x ``n_cols`` grid and return it as PNG.

    Images are placed row-major (image ``k`` goes to row ``k // n_cols``,
    column ``k % n_cols``); grid cells beyond ``len(examples)`` stay blank.

    Args:
        examples: Sequence/array of images indexed along the first axis.
            Each ``examples[k]`` is handed to ``imshow``; a trailing channel
            axis of size 1 is dropped so single-channel output renders too.
        n_rows: Number of grid rows (>= 1).
        n_cols: Number of grid columns (>= 1).
        image_size: Unused; kept only for backward compatibility.

    Returns:
        The PNG bytes of the whole grid, base64-encoded as an ASCII ``str``.
    """
    # Use the object-oriented Figure API rather than pyplot: pyplot keeps
    # global figure state and its GUI backends must not be driven from Flask
    # request threads. squeeze=False guarantees a 2-D axes array even for a
    # 1xN or 1x1 grid (plt.subplots would collapse those to 1-D / a scalar,
    # breaking axes[i, j] indexing).
    fig = Figure(figsize=(15, 10))
    axes = fig.subplots(n_rows, n_cols, squeeze=False)
    for index, ax in enumerate(axes.flat):
        ax.axis('off')
        if index >= len(examples):
            continue
        image = np.asarray(examples[index])
        # imshow rejects (H, W, 1); show single-channel images as 2-D grayscale.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
    buf = BytesIO()
    fig.savefig(buf, format='png')
    return base64.b64encode(buf.getvalue()).decode('utf-8')
@app.route('/')
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = 'index.html'
    return render_template(page)
import math
@app.route('/generate', methods=['POST'])
def generate():
    """Generate a grid of images from random latent vectors and return it as JSON.

    Reads the optional form field ``n_samples`` (default 4) and clamps it to
    ``[1, max_samples]``. It then samples that many latent vectors, runs the
    model, and renders the outputs into one PNG grid.

    Returns:
        On success: ``{'success': True, 'generated_image': <base64 PNG>}``.
        If ``n_samples`` is not an integer: ``{'success': False,
        'error': ...}`` with HTTP status 400.
    """
    latent_dim = 100
    # n_samples comes straight from the client; cap it so a single request
    # cannot force arbitrarily large latent/image batches and figures.
    max_samples = 100
    try:
        requested = int(request.form.get('n_samples', 4))
    except ValueError:
        return jsonify({'success': False, 'error': 'n_samples must be an integer'}), 400
    n_samples = min(max(requested, 1), max_samples)
    # Near-square grid: floor(sqrt(n)) rows, then just enough columns for n.
    n_rows = max(int(math.sqrt(n_samples)), 1)
    n_cols = (n_samples + n_rows - 1) // n_rows
    latent_points = generate_latent_points(latent_dim, n_samples)
    generated_images = generate_images(model, latent_points)
    # Rescale to [0, 1] for imshow. This assumes the generator emits values
    # in [-1, 1] (e.g. a tanh output) -- TODO confirm. Clip explicitly so
    # slight overshoot is handled here instead of by imshow's warning path.
    generated_images = np.clip((generated_images + 1) / 2.0, 0.0, 1.0)
    img_data = plot_generated(generated_images, n_rows, n_cols)
    return jsonify({'success': True, 'generated_image': img_data})
if __name__ == '__main__':
    # Local development entry point using Flask's built-in server.
    # Debug mode (interactive debugger and reloader) stays on by default, as
    # before. Setting FLASK_DEBUG to "0", "false" or "no" now switches it off
    # without editing code; do that for any shared or exposed deployment,
    # since the Werkzeug debugger permits arbitrary code execution.
    debug = os.environ.get('FLASK_DEBUG', '1').lower() not in ('0', 'false', 'no')
    app.run(debug=debug)