from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from surya.input.processing import convert_if_not_rgb
from surya.postprocessing.math.latex import fix_math, contains_math
from surya.postprocessing.text import truncate_repetitions
from surya.settings import settings


def get_batch_size():
    """Return the recognition batch size.

    Uses ``settings.RECOGNITION_BATCH_SIZE`` when it is set; otherwise falls
    back to a default chosen from ``settings.TORCH_DEVICE_MODEL``
    (64 for "mps", 256 for "cuda", 32 for anything else).
    """
    batch_size = settings.RECOGNITION_BATCH_SIZE
    if batch_size is None:
        batch_size = 32
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 64  # 12GB RAM max
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 256
    return batch_size


def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None):
    """Run greedy OCR decoding over ``images`` in batches.

    Args:
        images: PIL images, one per text region to recognize.
        languages: One list of language codes per image; each list may hold
            at most ``settings.RECOGNITION_MAX_LANGS`` entries.
        model: Recognition model. It is called with decoder inputs, masks,
            kv caches and pixel values, and must return a dict containing
            "logits", "past_key_values" and "encoder_last_hidden_state".
        processor: Callable producing a mapping with "pixel_values" and
            "langs" for the whole input; also provides ``tokenizer``
            (``pad_id``, ``eos_id``, ``batch_decode``).
        batch_size: Images per forward pass. Defaults to ``get_batch_size()``.

    Returns:
        A tuple ``(output_text, confidences)`` of two lists with one entry
        per input image: the decoded (and post-processed) text, and the mean
        of the non-zero per-position max softmax probabilities.

    Raises:
        AssertionError: If an element of ``images`` is not a PIL image, if
            ``images`` and ``languages`` differ in length, or if an image has
            too many languages.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(languages)
    for langs in languages:
        assert len(langs) <= settings.RECOGNITION_MAX_LANGS, f"OCR only supports up to {settings.RECOGNITION_MAX_LANGS} languages per image, you passed {langs}."

    images = [image.convert("RGB") for image in images]  # also copies the images
    if batch_size is None:
        batch_size = get_batch_size()

    output_text = []
    confidences = []

    dec_config = model.config.decoder
    layer_count = dec_config.decoder_layers
    kv_heads = dec_config.kv_heads
    head_dim = int(dec_config.d_model / dec_config.decoder_attention_heads)
    # Additive mask value: the most negative number the model dtype can hold.
    min_val = torch.finfo(model.dtype).min
    pad_id = processor.tokenizer.pad_id
    eos_id = processor.tokenizer.eos_id
    prefill_len = settings.RECOGNITION_MAX_LANGS + 1  # The +1 accounts for start token

    if settings.RECOGNITION_STATIC_CACHE:
        # We'll re-use these for all batches to avoid recopying
        kv_mask = torch.full(
            (batch_size, 1, 1, settings.RECOGNITION_MAX_TOKENS + 1),
            min_val, dtype=model.dtype, device=model.device,
        )
        initial_attn_mask = torch.full(
            (batch_size, 1, prefill_len, prefill_len),
            min_val, dtype=model.dtype, device=model.device,
        )

        # Decoder kv cache
        # 7 (layers) x 2 (kv) x bs x 4 (heads) x max tokens x 64 (head dim)
        decoder_cache = [
            torch.zeros(
                (2, batch_size, kv_heads, settings.RECOGNITION_MAX_TOKENS, head_dim),
                dtype=model.dtype, device=model.device,
            )
            for _ in range(layer_count)
        ]

        # Prefill
        decoder_input = torch.zeros((batch_size, prefill_len), dtype=torch.long, device=model.device)
    else:
        initial_kv_mask = torch.zeros((batch_size, 1, 1, 1), dtype=model.dtype, device=model.device)
        initial_attn_mask = torch.zeros(
            (batch_size, 1, prefill_len, prefill_len),
            dtype=model.dtype, device=model.device,
        )

    processed_batches = processor(text=[""] * len(images), images=images, lang=languages)

    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
        # Raw language codes decide whether LaTeX post-processing applies.
        has_math = ["_math" in lang for lang in languages[i:i + batch_size]]
        batch_pixel_values = processed_batches["pixel_values"][i:i + batch_size]
        batch_langs = processed_batches["langs"][i:i + batch_size]
        max_lang_len = max([len(lang) for lang in batch_langs])

        # Pad languages to max length if needed, to ensure we can convert to a tensor
        for lang_idx, lang in enumerate(batch_langs):
            missing = max_lang_len - len(lang)
            if missing > 0:
                batch_langs[lang_idx] = [pad_id] * missing + lang

        batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
        current_batch_size = len(batch_pixel_values)

        batch_langs = torch.tensor(np.stack(batch_langs, axis=0), dtype=torch.long, device=model.device)
        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device)
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)

        token_count = 0
        inference_token_count = batch_decoder_input.shape[-1]
        batch_predictions = [[] for _ in range(current_batch_size)]

        # Rows appended to each new input token so a short final batch still
        # fills the static batch dimension.
        decoder_input_pad = torch.zeros((batch_size - current_batch_size, 1), dtype=torch.long, device=model.device)

        if settings.RECOGNITION_STATIC_CACHE:
            # Reset shared tensors
            if i > 0:
                # Decoder cache
                for layer_cache in decoder_cache:
                    layer_cache.fill_(0)

            # KV mask (the fills below are no-ops on the first batch, where
            # the tensors still hold their initial values)
            kv_mask.fill_(min_val)
            kv_mask[:, :, :, -1] = 0
            kv_mask[:, :, :, :inference_token_count] = 0

            # Attention mask
            initial_attn_mask.fill_(min_val)

            # Prefill
            decoder_input.fill_(0)

            # Prefill attention mask
            attention_mask = initial_attn_mask
            attention_mask[:, :, -inference_token_count:, -inference_token_count:] = 0

            # Prefill input
            decoder_input[:current_batch_size, -inference_token_count:] = batch_decoder_input
            batch_decoder_input = decoder_input

            # Pad to max batch size
            missing_rows = batch_size - current_batch_size
            batch_langs = torch.cat(
                [
                    batch_langs,
                    torch.zeros((missing_rows, batch_langs.shape[-1]), dtype=torch.long, device=model.device),
                ],
                dim=0,
            )
            batch_pixel_values = torch.cat(
                [
                    batch_pixel_values,
                    torch.zeros((missing_rows,) + batch_pixel_values.shape[1:], dtype=model.dtype, device=model.device),
                ],
                dim=0,
            )
        else:
            # Select seed attention mask
            kv_mask = initial_kv_mask[:current_batch_size]
            kv_mask.fill_(0)

            # Select prefill attention mask
            attention_mask = initial_attn_mask[:current_batch_size, :, :inference_token_count, :inference_token_count]
            decoder_cache = [None] * layer_count

        encoder_outputs = None
        sequence_scores = None
        encoder_cache = [None] * layer_count
        all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

        with torch.no_grad():  # inference_mode doesn't work with torch.compile
            # Run post-prefill tokens
            while token_count < settings.RECOGNITION_MAX_TOKENS:
                is_prefill = token_count == 0
                return_dict = model(
                    decoder_input_ids=batch_decoder_input,
                    decoder_attention_mask=attention_mask,
                    decoder_self_kv_cache=None if is_prefill else decoder_cache,
                    decoder_cross_kv_cache=None if is_prefill else encoder_cache,
                    decoder_past_token_count=token_count,
                    decoder_langs=batch_langs,
                    pixel_values=batch_pixel_values,
                    encoder_outputs=encoder_outputs,
                    return_dict=True,
                )

                logits = return_dict["logits"][:current_batch_size]  # Ignore batch padding
                preds = torch.argmax(logits[:, -1], dim=-1)
                scores = torch.max(F.softmax(logits, dim=-1), dim=-1).values
                done = (preds == eos_id) | (preds == pad_id)
                all_done = all_done | done

                # Finished sequences contribute zero so they can be excluded
                # from the confidence average below.
                scores[all_done] = 0

                if is_prefill:
                    sequence_scores = scores
                    encoder_outputs = (return_dict["encoder_last_hidden_state"],)
                else:
                    sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                if all_done.all():
                    break

                past_key_values = return_dict["past_key_values"]
                token_range = torch.arange(token_count, token_count + inference_token_count, device=model.device)

                for layer_idx, layer in enumerate(past_key_values):
                    if is_prefill:
                        encoder_cache[layer_idx] = layer[1]

                    if settings.RECOGNITION_STATIC_CACHE:
                        # Fill in entries in static kv cache
                        decoder_cache[layer_idx][:, :, :, token_range, :] = layer[0][:, :, :, -inference_token_count:, :]
                    elif is_prefill:
                        decoder_cache[layer_idx] = layer[0]
                    else:
                        # Cat to generate new kv cache including current tokens
                        decoder_cache[layer_idx] = torch.cat([decoder_cache[layer_idx], layer[0]], dim=3)

                batch_decoder_input = preds.unsqueeze(1)
                if settings.RECOGNITION_STATIC_CACHE:
                    # Setup new attention mask and input token
                    kv_mask[:, :, :, token_count:(token_count + inference_token_count)] = 0
                    batch_decoder_input = torch.cat([batch_decoder_input, decoder_input_pad], dim=0)  # Pad to full batch
                else:
                    kv_mask = torch.cat(
                        [
                            kv_mask,
                            torch.zeros((current_batch_size, 1, 1, inference_token_count), dtype=model.dtype, device=model.device),
                        ],
                        dim=-1,
                    )

                attention_mask = kv_mask

                # One bulk transfer per step instead of a sync per element.
                for j, (pred, finished) in enumerate(zip(preds.tolist(), all_done.tolist())):
                    if not finished:
                        batch_predictions[j].append(pred)

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[-1]

        # Mean over non-zero scores. A sequence that finished on its very
        # first prediction has no non-zero scores; clamp the count so its
        # confidence is 0.0 rather than NaN (0 / 0).
        scored_positions = torch.sum(sequence_scores != 0, dim=-1).clamp(min=1)
        sequence_scores = torch.sum(sequence_scores, dim=-1) / scored_positions

        detected_text = processor.tokenizer.batch_decode(batch_predictions)
        detected_text = [truncate_repetitions(dt) for dt in detected_text]

        # Postprocess to fix LaTeX output (add $$ signs, etc)
        detected_text = [
            fix_math(text) if math and contains_math(text) else text
            for text, math in zip(detected_text, has_math)
        ]
        output_text.extend(detected_text)
        confidences.extend(sequence_scores.tolist())

    return output_text, confidences