# image-captioning / model.py
# ydshieh
# update UI and samples
# 6f0178d
# raw
# history blame
# 2.04 kB
import json
import os, shutil
import random
from PIL import Image
import jax
from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from huggingface_hub import hf_hub_download
# Local directory that receives a copy of every checkpoint file.
model_dir = "./models/"
os.makedirs(model_dir, exist_ok=True)

# Names of the files fetched from the checkpoint folder on the hub.
files_to_download = [
    "config.json",
    "flax_model.msgpack",
    "merges.txt",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json",
    "preprocessor_config.json",
]

# Download each file from the hub, then copy it into `model_dir` under its
# bare file name (without the checkpoint sub-folder prefix).
for fn in files_to_download:
    file_path = hf_hub_download(
        repo_id="ydshieh/vit-gpt2-coco-en",
        filename=f"ckpt_epoch_3_step_6900/{fn}",
    )
    shutil.copyfile(file_path, os.path.join(model_dir, fn))

# Load the model, the image feature extractor and the tokenizer from the
# local copy created above.
model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# Generation settings forwarded as keyword arguments to `model.generate`.
max_length = 16
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
@jax.jit
def generate(pixel_values):
    """Return the generated token-id sequences for `pixel_values`.

    Calls `model.generate` with the module-level `gen_kwargs` and returns
    the `sequences` attribute of its result. The function is wrapped in
    `jax.jit`.
    """
    return model.generate(pixel_values, **gen_kwargs).sequences
def predict(image):
    """Return the caption string generated for `image`.

    The image is turned into pixel values by `feature_extractor`
    (requested as NumPy tensors), passed through `generate`, and the
    resulting token ids are decoded by `tokenizer` with special tokens
    skipped. Surrounding whitespace is stripped from every decoded string
    and the first one is returned.
    """
    features = feature_extractor(images=image, return_tensors="np")
    token_ids = generate(features.pixel_values)
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    captions = [text.strip() for text in decoded]
    return captions[0]
def _compile():
    """Run `predict` once on a bundled sample image.

    The returned caption is discarded; the call exists only for its side
    effect (presumably triggering the `jax.jit` compilation of `generate`
    ahead of the first real request -- confirm with the caller).

    The image is opened in a `with` block so the underlying file is closed
    even when `predict` raises.
    """
    image_path = 'samples/val_000000039769.jpg'
    with Image.open(image_path) as image:
        predict(image)
# Run the one-off warm-up prediction at import time.
_compile()

# Directory holding the sample images and the image-id list.
sample_dir = "./samples/"

# Integer ids parsed from the names of the files in `sample_dir` that start
# with the "COCO_val2017_" prefix (prefix and ".jpg" are removed first).
sample_image_ids = tuple(
    int(f.replace('COCO_val2017_', '').replace('.jpg', ''))
    for f in os.listdir(sample_dir)
    if f.startswith('COCO_val2017_')
)

# Image ids loaded from the JSON file shipped next to the samples.
with open(
    os.path.join(sample_dir, "coco-val2017-img-ids.json"),
    "r",
    encoding="UTF-8",
) as fp:
    coco_2017_val_image_ids = json.load(fp)
def get_random_image_id():
    """Return one randomly drawn entry of `coco_2017_val_image_ids`."""
    # A sample of size one always holds exactly one element; unpack it.
    [picked] = random.sample(coco_2017_val_image_ids, k=1)
    return picked