import base64 import functools import os import tempfile import time import types import uuid from functools import partial from io import BytesIO import numpy as np from PIL.Image import Resampling from gradio_utils.grclient import check_job from src.enums import valid_imagegen_models, valid_imagechange_models, valid_imagestyle_models, docs_joiner_default, \ llava16_model_max_length, llava16_image_tokens, llava16_image_fudge, VIDEO_EXTENSIONS, IMAGE_EXTENSIONS from src.image_utils import fix_image_file from src.utils import is_gradio_version4, get_docs_tokens, get_limited_text, makedirs, call_subprocess_onetask, \ have_fiftyone, sanitize_filename def is_animated_gif(file_path): if not file_path.endswith('.gif'): return False from PIL import Image, UnidentifiedImageError try: gif = Image.open(file_path) except (FileNotFoundError, UnidentifiedImageError): return False try: gif.seek(1) except EOFError: return False else: return True def gif_to_mp4(gif_path): from moviepy.editor import VideoFileClip """ Convert an animated GIF to an MP4 video. :param gif_path: Path to the input GIF file. :param mp4_path: Path to the output MP4 file. """ clip = VideoFileClip(gif_path) mp4_path = gif_path.replace('.gif', '.mp4') clip.write_videofile(mp4_path, codec='libx264') return mp4_path def is_video_file(file_path): """ Determine if the file is a video by checking its extension, frame count, and frame rate. :param file_path: Path to the file. :return: True if the file is a video, False otherwise. """ ext = os.path.splitext(file_path)[-1].lower() if ext not in VIDEO_EXTENSIONS: return False import cv2 video = cv2.VideoCapture(file_path) frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT) frame_rate = video.get(cv2.CAP_PROP_FPS) video.release() # A valid video should have more than 0 frames and a positive frame rate return frame_count >= 1 and frame_rate > 0 def img_to_base64(image_file, resolution=None, output_format=None, str_bytes=True): # assert image_file.lower().endswith('jpg') or image_file.lower().endswith('jpeg') from PIL import Image from pathlib import Path ext = Path(image_file).suffix iformat = IMAGE_EXTENSIONS.get(ext) assert iformat is not None, "Invalid file extension %s for file %s" % (ext, image_file) image = Image.open(image_file) if resolution: image = image.resize(resolution, resample=Resampling.BICUBIC) if output_format: oformat = output_format.upper() elif iformat not in ['JPEG', 'PNG']: # use jpeg by default if nothing set, so most general format allowed oformat = 'JPEG' else: oformat = iformat buffered = BytesIO() image.save(buffered, format=oformat) img_str = base64.b64encode(buffered.getvalue()) # FIXME: unsure about below if str_bytes: img_str = str(bytes("data:image/%s;base64," % oformat.lower(), encoding='utf-8') + img_str) else: img_str = f"data:image/{oformat.lower()};base64,{img_str.decode('utf-8')}" return img_str def base64_to_img(img_str, output_path): """ Convert a base64 string to an image or video file. :param img_str: The base64 encoded string with the image or video data. :param output_path: The path (without extension) where the output file will be saved. :return: The path to the saved file. """ if img_str.startswith("b'"): # check if was a string of bytes joined like when str_bytes=True in above function img_str = img_str[2:-1] # This removes the first b' and the last ' # Split the string on "," to separate the metadata from the base64 data meta, base64_data = img_str.split(",", 1) # Extract the format from the metadata img_format = meta.split(';')[0].split('/')[-1] # Decode the base64 string to bytes img_bytes = base64.b64decode(base64_data) # Create output file path with the correct format extension output_file = f"{output_path}.{img_format}" # Write the bytes to a file with open(output_file, "wb") as f: f.write(img_bytes) print(f"Image saved to {output_file} with format {img_format}") return output_file def video_to_base64frames(video_path): import cv2 video = cv2.VideoCapture(video_path) base64Frames = [] while video.isOpened(): success, frame = video.read() if not success: break _, buffer = cv2.imencode(".jpg", frame) base64Frames.append(base64.b64encode(buffer).decode("utf-8")) video.release() print(len(base64Frames), "frames read.") return base64Frames @functools.lru_cache(maxsize=10000, typed=False) def video_to_frames(video_path, output_dir, resolution=None, image_format="jpg", video_frame_period=None, extract_frames=None, verbose=False): import cv2 """ Convert video to frames, save them as image files in the specified format, and return the list of file names. :param video_path: Path to the input video file. :param output_dir: Directory where the output frames will be saved. :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution. :param image_format: String specifying the desired image format (e.g., "jpg", "png"). :param video_frame_period: How often to sample frames from the video. If None, every 20th frame is saved. e.g. if pass non-real-time video, can set to 1 to save all frames, to mimic passing actual frames separately otherwise :param extract_frames: Number of frames to extract from the video. If None, all frames are saved. :param verbose: Boolean to control whether to print progress messages. :return: List of file names for the saved frames. Example usage: file_names = video_to_frames("input_video.mp4", "output_frames", resolution=(640, 480), image_format="png", verbose=True) print(file_names) """ if output_dir is None: output_dir = os.path.join(tempfile.gettempdir(), 'image_path_%s' % sanitize_filename(video_path)) enable_fiftyone = True # optimal against issues if using function server if enable_fiftyone and \ have_fiftyone and \ (video_frame_period is not None and video_frame_period < 1 or not os.path.isfile(video_path)): # handles either automatic period or urls from src.vision.extract_movie import extract_unique_frames args = () urls = [video_path] if not os.path.isfile(video_path) else None file = video_path if os.path.isfile(video_path) else None kwargs = {'urls': urls, 'file': file, 'download_dir': None, 'export_dir': output_dir, 'extract_frames': extract_frames} # fifty one is complex program and leaves around processes if False: # NOTE: Assumes using function server to handle isolation if want production grade behavior func_new = partial(call_subprocess_onetask, extract_unique_frames, args, kwargs) else: func_new = functools.partial(extract_unique_frames, *args, **kwargs) export_dir = func_new() return [os.path.join(export_dir, x) for x in os.listdir(export_dir)] if video_frame_period and video_frame_period < 1: video_frame_period = None if video_frame_period in [None, 0]: # e.g. if no fiftyone and so can't do 0 case, then assume ok to do period based total_frames = count_frames(video_path) extract_frames = min(20, extract_frames or 20) # no more than 20 frames total for now video_frame_period = total_frames // extract_frames video = cv2.VideoCapture(video_path) makedirs(output_dir) image_format = image_format or '.jpg' frame_count = 0 file_names = [] while True: success, frame = video.read() if not success: break # keep first frame, then keep a frame every video_frame_resolution frames if frame_count % video_frame_period != 0: frame_count += 1 continue if resolution: frame = cv2.resize(frame, resolution) frame_filename = os.path.join(output_dir, f"frame_{frame_count:04d}.{image_format}") cv2.imwrite(frame_filename, frame) file_names.append(frame_filename) frame_count += 1 video.release() if verbose: print(f"{frame_count} frames saved to {output_dir}.") return file_names def count_frames(video_path): import cv2 # Open the video file video = cv2.VideoCapture(video_path) # Check if video opened successfully if not video.isOpened(): print("Error: Could not open video.") return -1 # Get the total number of frames total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) # Release the video capture object video.release() return total_frames def process_file_list(file_list, output_dir, resolution=None, image_format="jpg", rotate_align_resize_image=True, video_frame_period=None, extract_frames=None, verbose=False): # FIXME: resolution is not used unless video, could use for every case, but resolution is set later when byte encoding for LLMs """ Process a list of files, converting any videos to frames and updating the list to only contain image files. :param file_list: List of file paths to be processed. :param output_dir: Directory where the output frames will be saved. :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution. Does not affect images as inputs, handled elsewhere when converting to base64 for LLM :param image_format: String specifying the desired image format (e.g., "jpg", "png"). :param rotate_align_resize_image: Whether to apply rotation, alignment, resize before giving to LLM :param video_frame_period: Period to save frames, if <1 then automatic :param extract_frames: how many frames to extract if automatic period mode :param verbose: Boolean to control whether to print progress messages. :return: Updated list of file names containing only image files. """ if file_list is None: file_list = [] if image_format is None: image_format = 'jpg' image_files = [] for file in file_list: # i.e. if not file, then maybe youtube url is_maybe_video = os.path.isfile(file) and is_video_file(file) or not os.path.isfile(file) or is_animated_gif( file) if is_animated_gif(file): # FIXME: could convert gif -> mp4 with gif_to_mp4(gif_path)() # fiftyone can't handle animated gifs extract_frames = None if video_frame_period is not None and video_frame_period < 1: video_frame_period = None if is_maybe_video: # If it's a valid video, extract frames if verbose: print(f"Processing video file: {file}") # output_dir is None means only use file for location frame_files = video_to_frames(file, None, resolution, image_format, video_frame_period, extract_frames, verbose) image_files.extend(frame_files) else: # If it's not a valid video, add it to the image file list if rotate_align_resize_image: file_fixed = fix_image_file(file, do_align=True, do_rotate=True, do_pad=False, relaxed_resize=True) else: file_fixed = file image_files.append(file_fixed) return image_files def fix_llava_prompt(file, prompt=None, allow_prompt_auto=True, ): if prompt in ['auto', None] and allow_prompt_auto: prompt = "Describe the image and what does the image say?" # prompt = "According to the image, describe the image in full details with a well-structured response." if file in ['', None]: # let model handle if no prompt and no file prompt = '' # allow prompt = '', will describe image by default if prompt is None: if os.environ.get('HARD_ASSERTS'): raise ValueError('prompt is None') else: prompt = '' return prompt def llava_prep(file_list, llava_model, image_model='llava-v1.6-vicuna-13b', client=None): assert client is not None or len(file_list) == 1 file_list_new = [] image_model_list_new = [] for file in file_list: image_model_new, client, file_new = _llava_prep(file, llava_model, image_model=image_model, client=client) file_list_new.append(file_new) image_model_list_new.append(image_model_new) assert len(image_model_list_new) >= 1 assert len(file_list_new) >= 1 return image_model_list_new[0], client, file_list_new def _llava_prep(file, llava_model, image_model='llava-v1.6-vicuna-13b', client=None): prefix = '' if llava_model.startswith('http://'): prefix = 'http://' if llava_model.startswith('https://'): prefix = 'https://' llava_model = llava_model[len(prefix):] llava_model_split = llava_model.split(':') assert len(llava_model_split) >= 2 # FIXME: Allow choose model in UI if len(llava_model_split) >= 2: pass # assume default model is ok # llava_ip = llava_model_split[0] # llava_port = llava_model_split[1] if len(llava_model_split) >= 3: image_model = llava_model_split[2] llava_model = ':'.join(llava_model_split[:2]) # add back prefix llava_model = prefix + llava_model if client is None: from gradio_utils.grclient import GradioClient client = GradioClient(llava_model, check_hash=False, serialize=is_gradio_version4) client.setup() if not is_gradio_version4 and file and os.path.isfile(file): file = img_to_base64(file) assert image_model, "No image model specified" if isinstance(file, np.ndarray): from PIL import Image im = Image.fromarray(file) file = "%s.jpeg" % str(uuid.uuid4()) im.save(file) return image_model, client, file server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" def get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer): if tokenizer is None: raise RuntimeError("Not setup for multi-image without tokenizer") # from transformers import AutoTokenizer # tokenizer = AutoTokenizer.from_pretrained(base_model) if hasattr(tokenizer, 'model_max_length'): model_max_length = tokenizer.model_max_length else: model_max_length = llava16_model_max_length user_part = '\n\nReduce the above information into single correct answer to the following question: ' + prompt user_part_tokens = len(tokenizer.encode(user_part)) text_context_list = ['Answer #%s:\n\n%s' % (ii, text) for ii, text in enumerate(texts)] # see if too many tokens text_tokens_trial = len(tokenizer.encode(docs_joiner_default.join(text_context_list))) if user_part_tokens + text_tokens_trial + max_new_tokens >= model_max_length: max_new_tokens = min_max_new_tokens fudge = llava16_image_fudge max_input_tokens = model_max_length - max_new_tokens - fudge # fudge for extra chars top_k_docs, one_doc_size, num_doc_tokens = \ get_docs_tokens(tokenizer, text_context_list=text_context_list, max_input_tokens=max_input_tokens) text_context_list_cut = text_context_list[:top_k_docs] texts_joined = docs_joiner_default.join(text_context_list_cut) prompt_with_texts = '\n"""\n' + texts_joined + '\n"""\n' prompt_with_texts += user_part return prompt_with_texts.replace('image', 'document').replace('Image', 'Document') def get_llava_response(file=None, llava_model=None, prompt=None, chat_conversation=[], allow_prompt_auto=False, image_model='llava-v1.6-vicuna-13b', temperature=0.2, top_p=0.7, max_new_tokens=512, min_max_new_tokens=512, tokenizer=None, image_process_mode="Default", include_image=False, client=None, max_time=None, force_stream=True, verbose=False, ): max_new_tokens = min(max_new_tokens, 1024) # for hard_cutoff to be easy to know kwargs = locals().copy() force_stream |= isinstance(file, list) and len(file) > 1 if isinstance(file, str): file_list = [file] elif isinstance(file, list): file_list = file if len(file_list) == 0: file_list = [None] else: file_list = [None] if force_stream: text = '' for res in get_llava_stream(**kwargs): text = res return text, prompt image_model = os.path.basename(image_model) # in case passed HF link prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto) max_new_tokens1 = max_new_tokens if len(file_list) <= 4 else min(max_new_tokens, min_max_new_tokens) if tokenizer: model_max_length = tokenizer.model_max_length else: model_max_length = llava16_model_max_length image_tokens = llava16_image_tokens if len(file_list) >= 1 and file_list[0] is not None else 0 fudge = llava16_image_fudge hard_limit_tokens = model_max_length - max_new_tokens1 - fudge - image_tokens prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer, verbose=False) image_model, client, file_list = \ llava_prep(file_list, llava_model, image_model=image_model, client=client) reses = [] for file in file_list: res = client.predict(prompt, chat_conversation if len(file_list) == 1 else [], file, image_process_mode, include_image, image_model, temperature, top_p, max_new_tokens1, api_name='/textbox_api_submit') reses.append(res) if len(reses) > 1: reses = [x for x in reses if server_error_msg not in x] prompt_with_texts = get_prompt_with_texts(reses, prompt, max_new_tokens, min_max_new_tokens, tokenizer) res = client.predict(prompt_with_texts, chat_conversation, None, image_process_mode, include_image, image_model, temperature, top_p, max_new_tokens, api_name='/textbox_api_submit') else: res = reses[0] return res, prompt def get_llava_stream(file, llava_model, prompt=None, chat_conversation=[], allow_prompt_auto=False, image_model='llava-v1.6-vicuna-13b', temperature=0.2, top_p=0.7, max_new_tokens=512, min_max_new_tokens=512, tokenizer=None, image_process_mode="Default", include_image=False, client=None, verbose_level=0, max_time=None, force_stream=True, # dummy arg verbose=False, ): max_new_tokens = min(max_new_tokens, 1024) # for hard_cutoff to be easy to know if isinstance(file, str): file_list = [file] elif isinstance(file, list): file_list = file if len(file_list) == 0: file_list = [None] else: file_list = [None] image_model = os.path.basename(image_model) # in case passed HF link prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto) max_new_tokens1 = max_new_tokens if len(file_list) <= 4 else min(max_new_tokens, min_max_new_tokens) if tokenizer: model_max_length = tokenizer.model_max_length else: model_max_length = llava16_model_max_length image_tokens = llava16_image_tokens if len(file_list) >= 1 and file_list[0] is not None else 0 fudge = llava16_image_fudge hard_limit_tokens = model_max_length - max_new_tokens1 - fudge - image_tokens prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer) image_model, client, file_list = \ llava_prep(file_list, llava_model, image_model=image_model, client=client) jobs = [] for file in file_list: job = client.submit(prompt, chat_conversation, file, image_process_mode, include_image, image_model, temperature, top_p, max_new_tokens1, api_name='/textbox_api_submit') jobs.append(job) t0 = time.time() job_outputs_nums = [0] * len(jobs) texts = [''] * len(jobs) done_all = False reses = [''] * len(jobs) while True: for ji, job in enumerate(jobs): if verbose_level == 2: print("Inside: %s" % llava_model, time.time() - t0, flush=True) e = check_job(job, timeout=0, raise_exception=False) if e is not None: continue if max_time is not None and time.time() - t0 > max_time: done_all = True break outputs_list = job.outputs().copy() job_outputs_num_new = len(outputs_list[job_outputs_nums[ji]:]) for num in range(job_outputs_num_new): reses[ji] = outputs_list[job_outputs_nums[ji] + num] if verbose_level == 2: print('Stream %d: %s' % (num, reses[ji]), flush=True) elif verbose_level == 1: print('Stream %d' % (job_outputs_nums[ji] + num), flush=True) if reses[ji]: texts[ji] = reses[ji] if len(jobs) == 1: yield texts[ji] job_outputs_nums[ji] += job_outputs_num_new time.sleep(0.005) if done_all or all([job.done() for job in jobs]): break for ji, job in enumerate(jobs): e = check_job(job, timeout=0, raise_exception=False) if e is not None: continue outputs_list = job.outputs().copy() job_outputs_num_new = len(outputs_list[job_outputs_nums[ji]:]) for num in range(job_outputs_num_new): reses[ji] = outputs_list[job_outputs_nums[ji] + num] if verbose_level == 2: print('Final Stream %d: %s' % (num, reses[ji]), flush=True) elif verbose_level == 1: print('Final Stream %d' % (job_outputs_nums[ji] + num), flush=True) if reses[ji]: texts[ji] = reses[ji] if len(jobs) == 1: yield texts[ji] job_outputs_nums[ji] += job_outputs_num_new if verbose_level == 1: print("total job_outputs_num=%d" % job_outputs_nums[ji], flush=True) if len(jobs) > 1: # recurse without image(s) ntexts_before = len(texts) texts = [x for x in texts if server_error_msg not in x] ntexts_after = len(texts) if ntexts_after != ntexts_before: print("texts: %s -> %s" % (ntexts_before, ntexts_after)) prompt_with_texts = get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer) text = '' max_new_tokens = max_new_tokens if len(jobs) > 4 else min(max_new_tokens, min_max_new_tokens) for res in get_llava_stream(None, llava_model, prompt=prompt_with_texts, chat_conversation=chat_conversation, allow_prompt_auto=allow_prompt_auto, image_model=image_model, temperature=temperature, top_p=top_p, # avoid long outputs max_new_tokens=max_new_tokens, min_max_new_tokens=min_max_new_tokens, tokenizer=tokenizer, image_process_mode=image_process_mode, include_image=include_image, client=client, verbose_level=verbose_level, max_time=max_time, force_stream=force_stream, # dummy arg verbose=verbose, ): text = res yield text else: assert len(texts) == 1 text = texts[0] return text def get_image_model_dict(enable_image, image_models, image_gpu_ids, ): image_dict = {} if not enable_image: return image_dict if image_gpu_ids is None: image_gpu_ids = ['auto'] * len(image_models) if not image_gpu_ids: image_gpu_ids = ['auto'] * len(image_models) for image_model_name in valid_imagegen_models + valid_imagechange_models + valid_imagestyle_models: if image_model_name in image_models: imagegen_index = image_models.index(image_model_name) if image_model_name == 'sdxl_turbo': from src.vision.sdxl_turbo import get_pipe_make_image, make_image elif image_model_name == 'playv2': from src.vision.playv2 import get_pipe_make_image, make_image elif image_model_name == 'sdxl': from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image elif image_model_name == 'sd3': from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image get_pipe_make_image = functools.partial(get_pipe_make_image, base_model='stabilityai/stable-diffusion-3-medium-diffusers', refiner_model=None) make_image = functools.partial(make_image, base_model='stabilityai/stable-diffusion-3-medium-diffusers', refiner_model=None) elif image_model_name == 'flux.1-dev': from src.vision.flux import get_pipe_make_image, make_image elif image_model_name == 'flux.1-schnell': from src.vision.flux import get_pipe_make_image_2 as get_pipe_make_image from src.vision.flux import make_image elif image_model_name == 'sdxl_change': from src.vision.sdxl_turbo import get_pipe_change_image as get_pipe_make_image, change_image make_image = change_image # FIXME: style else: raise ValueError("Invalid image_model_name=%s" % image_model_name) pipe = get_pipe_make_image(gpu_id=image_gpu_ids[imagegen_index]) image_dict[image_model_name] = dict(pipe=pipe, make_image=make_image) return image_dict def pdf_to_base64_pngs(pdf_path, quality=75, max_size=(1024, 1024), ext='png', pages=None): """ Define the function to convert a pdf slide deck to a list of images. Note that we need to ensure we resize images to keep them within Claude's size limits. """ # https://github.com/anthropics/anthropic-cookbook/blob/main/multimodal/reading_charts_graphs_powerpoints.ipynb from PIL import Image import io import fitz import tempfile # Open the PDF file doc = fitz.open(pdf_path) # Iterate through each page of the PDF images = [] if pages is None: pages = list(range(doc.page_count)) else: assert isinstance(pages, (list, tuple, types.GeneratorType)) for page_num in pages: # Load the page page = doc.load_page(page_num) # Render the page as a PNG image pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72)) # Save the PNG image output_path = f"{tempfile.mkdtemp()}/page_{page_num + 1}.{ext}" pix.save(output_path) images.append(output_path) # Close the PDF document doc.close() if ext == 'png': iformat = 'PNG' elif ext in ['jpeg', 'jpg']: iformat = 'JPEG' else: raise ValueError("No such ext=%s" % ext) images = [Image.open(image) for image in images] base64_encoded_pngs = [] for image in images: # Resize the image if it exceeds the maximum size if image.size[0] > max_size[0] or image.size[1] > max_size[1]: image.thumbnail(max_size, Image.Resampling.LANCZOS) image_data = io.BytesIO() image.save(image_data, format=iformat, optimize=True, quality=quality) image_data.seek(0) base64_encoded = base64.b64encode(image_data.getvalue()).decode('utf-8') base64_encoded_pngs.append(base64_encoded) return base64_encoded_pngs