|
import argparse |
|
import time |
|
import logging |
|
import requests |
|
import os |
|
from PIL import Image |
|
from io import BytesIO |
|
|
|
from PIL import Image |
|
import torch |
|
from transformers import AutoTokenizer |
|
|
|
from modeling_tinyllava_elm import TinyLlavaForConditionalGeneration |
|
from configuration import * |
|
from conversion import * |
|
from utils import * |
|
|
|
|
|
|
|
def load_image(image_file, timeout=30):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: Local filesystem path, or a URL starting with
            ``http://`` or ``https://``.
        timeout: Seconds to wait on the HTTP server (URLs only) before
            ``requests`` gives up, so a dead host cannot hang the script.

    Returns:
        A ``PIL.Image.Image`` in RGB mode whose pixel data is already loaded.

    Raises:
        requests.HTTPError: if the URL responds with an error status.
    """
    # Match the full scheme so local files such as "http_cat.jpg" are not
    # mistaken for URLs.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file

    # The context manager closes the underlying file handle; convert() loads
    # the pixels and returns a new, independent image.
    with Image.open(source) as img:
        return img.convert("RGB")
|
|
|
|
|
def generate(
    prompt: str,
    model: str,
    tokenizer=None,
    image: str = None,
    device: str = None,
    max_new_tokens: int = 1024,
    num_beams=1,
    top_p=None,
    temperature=0.2
):
    """Run one TinyLLaVA-OpenELM generation for ``prompt`` (optionally with an image).

    Args:
        prompt: User text. When ``image`` is given, ``DEFAULT_IMAGE_TOKEN``
            and a newline are prepended to it before the conversation
            template (``conv_phi_v0``) is applied.
        model: Either a checkpoint path/identifier, loaded with
            ``TinyLlavaForConditionalGeneration.from_pretrained``, or an
            already-constructed model instance.
        tokenizer: Optional tokenizer. When ``None`` it is loaded with
            ``AutoTokenizer.from_pretrained`` using the model config's
            ``tokenizer_model_max_length`` and ``tokenizer_padding_side``.
        image: Optional local image path or http(s) URL (see ``load_image``).
        device: Torch device string. When falsy, ``cuda:0`` is used if a CUDA
            device is available, otherwise ``cpu``.
        max_new_tokens, num_beams, top_p, temperature: Forwarded to
            ``model.generate``. Sampling is enabled only when
            ``temperature > 0``.

    Returns:
        ``(text, seconds)``: the first decoded sequence with special tokens
        skipped and whitespace stripped, and the wall-clock time spent in
        ``model.generate``.

    Raises:
        ValueError: if a CUDA device is requested but CUDA is unavailable.
    """
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                'No CUDA device detected, using cpu, '
                'expect slower speeds.'
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    checkpoint_path = None
    if isinstance(model, str):
        checkpoint_path = model
        # Half precision is only used off-CPU: many CPU kernels have no
        # float16 implementation on older torch releases.
        if torch.device(device).type == 'cpu':
            load_dtype = torch.float32
        else:
            load_dtype = torch.float16
        model = TinyLlavaForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=load_dtype,
        )

    config = model.config
    if tokenizer is None:
        # from_pretrained needs a path/identifier, not the model object.
        # For a model instance passed in by the caller, fall back to the path
        # its config was loaded from. NOTE(review): this assumes the tokenizer
        # files live alongside the checkpoint, as they do for the path case.
        tokenizer_source = checkpoint_path or config.name_or_path
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_source,
            use_fast=False,
            model_max_length=config.tokenizer_model_max_length,
            padding_side=config.tokenizer_padding_side,
        )
    image_processor = model.vision_tower._image_processor
    model.to(device).eval()

    if image is not None:
        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
    conv = conv_phi_v0.copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Text-only prompts still pass images=None to model.generate.
    image_tensor = None
    if image is not None:
        pil_image = load_image(image)
        # Match the model's dtype, which depends on the load device above or
        # on however the caller built the model instance.
        image_tensor = process_images(pil_image, image_processor, config).to(
            model.device, dtype=model.dtype
        )

    # Place the token ids on the model's device; a hard-coded .cuda() would
    # break CPU inference.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    stime = time.time()
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    generation_time = time.time() - stime

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return outputs.strip(), generation_time
|
def tinyllava_elm_generate_parser(argv=None):
    """Build this script's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse parses ``sys.argv[1:]``, matching the previous behavior;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes ``model``, ``prompt``, ``device``,
        ``image``, ``temperature``, ``top_p``, ``num_beams`` and
        ``max_new_tokens``.

    Note:
        argparse raises SystemExit on invalid input, e.g. a missing
        ``--model``.
    """
    parser = argparse.ArgumentParser('TinyLLaVA-OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the hf converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument(
        "--image", type=str, default=None,
        help='Optional image path or http(s) URL.',
    )
    # Default 0 means greedy decoding (sampling is enabled only when > 0).
    parser.add_argument(
        "--temperature", type=float, default=0,
        help='Sampling temperature; 0 disables sampling.',
    )
    parser.add_argument(
        "--top_p", type=float, default=None,
        help='Nucleus sampling probability mass.',
    )
    parser.add_argument(
        "--num_beams", type=int, default=1,
        help='Number of beams for beam search.',
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=512,
        help='Maximum number of tokens to generate.',
    )
    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == '__main__':
    import shutil

    args = tinyllava_elm_generate_parser()

    # generate() loads the checkpoint itself when given a path, so the model
    # is not pre-loaded here; doing so would load the weights twice.
    output_text, generation_time = generate(
        prompt=args.prompt,
        image=args.image,
        model=args.model,
        device=args.device,
        max_new_tokens=args.max_new_tokens,
        num_beams=args.num_beams,
        top_p=args.top_p,
        temperature=args.temperature,
    )

    # shutil.get_terminal_size falls back to a default width when stdout is
    # not a terminal (piped/redirected); os.get_terminal_size raises there.
    width = shutil.get_terminal_size().columns
    print_txt = (
        f'\r\n{"=" * width}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * width}\r\n'
        f'{output_text}\r\n'
        f'{"-" * width}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)