# Spaces: Sleeping  (web-page capture residue; not part of the program)
import argparse, subprocess, sys, time | |
def setup():
    """Install the runtime dependencies and clone the clip-interrogator repo.

    Each command runs in order and its captured stdout is echoed. The step is
    best-effort: a failing command is reported on stderr but never raised, so
    the remaining commands still run.
    """
    # Go through the running interpreter so packages are installed into the
    # environment that will import them, not whichever ``pip`` is first on PATH.
    pip_install = [sys.executable, '-m', 'pip', 'install']
    install_cmds = [
        pip_install + [
            'ftfy', 'gradio', 'regex', 'tqdm', 'stability-sdk',
            'transformers==4.21.2', 'timm', 'fairscale', 'requests',
        ],
        pip_install + ['-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],
        pip_install + ['-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],
        ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git'],
    ]
    for cmd in install_cmds:
        result = subprocess.run(cmd, stdout=subprocess.PIPE)
        # Tool output is not guaranteed to be valid UTF-8; never crash on it.
        print(result.stdout.decode('utf-8', errors='replace'))
        if result.returncode != 0:
            print(
                f"warning: {' '.join(cmd)} exited with status {result.returncode}",
                file=sys.stderr,
            )
# Provision dependencies before anything else: the third-party imports that
# follow only resolve once setup() has run.
setup()
# Make the fetched sources importable (presumably 'src/*' is where the
# editable pip installs land and 'clip-interrogator' is the cloned repo).
sys.path.extend(['src/blip', 'src/clip', 'clip-interrogator'])
import clip | |
import torch | |
import gradio as gr | |
from clip_interrogator import Interrogator, Config | |
# Single module-level interrogator built from a default Config; it is created
# once at import time and used by inferAndRebuild for every request.
_interrogator_config = Config()
ci = Interrogator(_interrogator_config)
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation | |
from stability_sdk import client | |
import os | |
# Stability inference client. The API key is mandatory and read from the
# STABILITY_KEY environment variable; if it is unset the lookup raises
# KeyError here, at import time.
stability_api = client.StabilityInference(key=os.environ['STABILITY_KEY'], verbose=True)
import clip | |
import torch | |
from PIL import Image | |
import warnings | |
import random | |
from io import BytesIO | |
def inferAndRebuild(image, mode):
    """Interrogate *image* for a prompt, then generate new images from it.

    Args:
        image: Input picture; it is converted with ``convert('RGB')`` before
            interrogation.
        mode: ``'best'`` or ``'classic'`` select ``ci.interrogate`` or
            ``ci.interrogate_classic``; any other value uses
            ``ci.interrogate_fast``.

    Returns:
        A two-item list ``[images, prompt]``: the images decoded from the
        generation response, and the interrogator's output.
    """
    rgb_image = image.convert('RGB')

    # Pick the interrogator method by name; unmatched modes fall back to fast.
    known_modes = (('best', 'interrogate'), ('classic', 'interrogate_classic'))
    method_name = next(
        (attr for key, attr in known_modes if mode == key),
        'interrogate_fast',
    )
    output = getattr(ci, method_name)(rgb_image)

    answers = stability_api.generate(prompt=str(output), seed=34567, steps=30, samples=5)

    imglist = []
    # Walk every artifact of every response, in arrival order.
    all_artifacts = (item for resp in answers for item in resp.artifacts)
    for artifact in all_artifacts:
        if artifact.finish_reason == generation.FILTER:
            warnings.warn(
                "Your request activated the API's safety filters and could not be processed. Please modify the prompt and try again.")
        if artifact.type == generation.ARTIFACT_IMAGE:
            imglist.append(Image.open(BytesIO(artifact.binary)))
    return [imglist, output]
# Gradio UI: (image, interrogation mode) in -> (generated gallery, prompt) out.
# All components use the top-level ``gr.*`` API, matching gr.Radio/gr.Gallery
# below; the legacy ``gr.inputs`` / ``gr.outputs`` namespaces are not present
# in current gradio releases, and gradio is installed unpinned.
inputs = [
    gr.Image(type='pil'),
    gr.Radio(['best', 'classic', 'fast'], label='Models', value='fast'),
]
outputs = [
    gr.Gallery(),
    gr.Textbox(label='Prompt'),
]
# NOTE: the name ``io`` is kept for compatibility with the original module
# namespace; the file only imports BytesIO from the stdlib ``io`` module.
io = gr.Interface(
    inferAndRebuild,
    inputs,
    outputs,
    # String form of the original ``False``; the boolean is deprecated.
    allow_flagging='never',
)
io.launch(debug=True)