Spaces:
Runtime error
Runtime error
import os | |
import re | |
import torch | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel | |
# Generated plots are cut at the first run of two or more consecutive full
# stops; everything from that point on is discarded by post_process.
regex_pattern = r"[.]{2,}"
def post_process(text, pattern=None):
    """Trim generated text at the first run of two or more full stops.

    Args:
        text: Decoded model output.
        pattern: Regular expression marking the cut point. Defaults to the
            module-level ``regex_pattern``.

    Returns:
        The stripped text up to (not including) the first match of
        *pattern*, with trailing whitespace before the cut removed.
        Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        # Best effort: hand non-string values back untouched instead of
        # raising (previously this happened via a catch-all except).
        return text
    if pattern is None:
        pattern = regex_pattern
    # Only the part before the first match is kept, so stop after one split.
    head = re.split(pattern, text.strip(), maxsplit=1)[0]
    # Drop whitespace left dangling in front of the cut, e.g. "plot ... more".
    return head.rstrip()
def set_example_image(example: list) -> dict:
    """Copy a clicked example's image path into the input image component.

    Args:
        example: The selected ``gr.Dataset`` sample; its first element is
            the example image path.

    Returns:
        A Gradio update dict that sets the output component's value.
    """
    # gr.update() is available in Gradio 3.x and 4.x, whereas the
    # component-specific gr.Image.update() was removed in Gradio 4.0.
    return gr.update(value=example[0])
def predict(image, max_length=64, num_beams=4):
    """Generate a plot description for a poster image.

    Args:
        image: Poster image from the Gradio image input (presumably a NumPy
            array), or ``None`` when nothing has been uploaded.
        max_length: ``max_length`` passed to ``model.generate``.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The post-processed generated text, or ``""`` when *image* is ``None``.
    """
    if image is None:
        # Submit was clicked with an empty image slot: there is nothing to
        # encode, so return an empty plot instead of erroring.
        return ""
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        ).sequences
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return post_process(preds[0])
model_name_or_path = "deepklarity/poster2plot"
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the encoder-decoder checkpoint; nn.Module.to() moves it in place and
# returns the same module, so the call can be chained.
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path).to(device)
print("Loaded model")

# The image preprocessor is looked up by the encoder's recorded name.
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

# The tokenizer is looked up by the decoder's recorded name. For a plain
# "gpt2" decoder, EOS is reused as the pad token (presumably because GPT-2's
# tokenizer defines none).
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    tokenizer.pad_token = tokenizer.eos_token
print("Loaded tokenizer")
# One single-element sample per file directly inside ./examples (a missing
# directory yields no samples). Names are sorted so the gallery order is the
# same on every filesystem; dotfiles such as .gitkeep are not images and are
# skipped.
examples = [
    [f"examples/{filename}"]
    for filename in sorted(next(os.walk("examples"), (None, None, []))[2])
    if not filename.startswith(".")
]
print(f"Loaded {len(examples)} example images")
with gr.Blocks(css="#title { margin: 0 auto; padding: 25px 25px 25px 25px }") as poster2plot:
    with gr.Column():
        # Header row with the centred title (styled via the #title CSS rule).
        with gr.Row():
            gr.Markdown(
                "# Poster2Plot: Upload a Movie/T.V show poster to generate a plot",
                elem_id="title",
            )
        # Main row: poster input and submit button on the left,
        # generated plot text on the right.
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label="Input Image", type="numpy")
                with gr.Row():
                    submit_button = gr.Button(value="Submit", variant="primary")
            with gr.Column():
                plot = gr.Textbox(label="Plot")
        # Clickable gallery of example posters feeding the image input.
        with gr.Row():
            example_images = gr.Dataset(components=[input_image], samples=examples)
        # Footer credits.
        with gr.Row():
            gr.Markdown(
                "Made by: [dk-crazydiv](https://twitter.com/kartik_godawat) and [dsr](https://twitter.com/dsr_ai)"
            )

    # Event wiring: Submit runs the model; clicking an example loads it
    # into the image input.
    submit_button.click(predict, inputs=[input_image], outputs=[plot])
    example_images.click(
        fn=set_example_image,
        inputs=[example_images],
        outputs=example_images.components,
    )

poster2plot.launch()