# magi / modelling_magi.py
# Last change by ragavsachdeva: "Update modelling_magi.py" (commit 9ebec84, verified)
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
from transformers.models.conditional_detr.modeling_conditional_detr import (
ConditionalDetrMLPPredictionHead,
ConditionalDetrModelOutput,
ConditionalDetrHungarianMatcher,
inverse_sigmoid,
)
from .configuration_magi import MagiConfig
from .processing_magi import MagiProcessor
from torch import nn
from typing import Optional, List
import torch
from einops import rearrange, repeat, einsum
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
class MagiModel(PreTrainedModel):
    """Manga model combining detection, association heads, crop embeddings and OCR.

    Sub-models are only constructed when the corresponding config flag allows it:

    * ``ocr_model`` (``VisionEncoderDecoderModel``) unless ``config.disable_ocr``;
    * ``crop_embedding_model`` (``ViTMAEModel``) unless ``config.disable_crop_embeddings``;
    * ``detection_transformer`` (``ConditionalDetrModel``) and its prediction /
      matching heads unless ``config.disable_detections``.

    Public ``predict_*`` / ``get_*`` methods assert (or check) that the
    sub-model they need was enabled.
    """

    config_class = MagiConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.processor = MagiProcessor(config)
        if not config.disable_ocr:
            self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
        if not config.disable_crop_embeddings:
            self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)
        if not config.disable_detections:
            # The last ``num_non_obj_tokens`` decoder outputs are not object
            # queries. In this file, index ``-num_non_obj_tokens`` is read as the
            # character-character (c2c) token and the next one as the
            # text-character (t2c) token; the remaining ones are not read here.
            self.num_non_obj_tokens = 5
            self.detection_transformer = ConditionalDetrModel(config.detection_model_config)
            self.bbox_predictor = ConditionalDetrMLPPredictionHead(
                input_dim=config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=4, num_layers=3
            )
            self.is_this_text_a_dialogue = ConditionalDetrMLPPredictionHead(
                input_dim=config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1,
                num_layers=3
            )
            # Input is [obj_i, (crop_i), obj_j, (crop_j), c2c]: 3 * d_model, plus two
            # crop embeddings when crop embeddings are enabled.
            self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
                input_dim=3 * config.detection_model_config.d_model + (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1, num_layers=3
            )
            # Input is [text_i, char_j, t2c]: 3 * d_model.
            self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
                input_dim=3 * config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1, num_layers=3
            )
            self.class_labels_classifier = nn.Linear(
                config.detection_model_config.d_model, config.detection_model_config.num_labels
            )
            self.matcher = ConditionalDetrHungarianMatcher(
                class_cost=config.detection_model_config.class_cost,
                bbox_cost=config.detection_model_config.bbox_cost,
                giou_cost=config.detection_model_config.giou_cost
            )

    def move_to_device(self, input):
        """Move ``input`` to this model's device using the ``utils.move_to_device`` helper."""
        # Resolves to the module-level helper imported from ``.utils``, not this method.
        return move_to_device(input, self.device)

    def predict_detections_and_associations(
        self,
        images,
        move_to_device_fn=None,
        character_detection_threshold=0.3,
        panel_detection_threshold=0.2,
        text_detection_threshold=0.25,
        character_character_matching_threshold=0.65,
        text_character_matching_threshold=0.4,
    ):
        """Detect panels/texts/characters and score their associations.

        Runs the detection transformer once. It then hands the raw box/class
        predictions, together with three scoring callbacks, to
        ``self.processor.postprocess_detections_and_associations``, which applies
        the thresholds.

        Args:
            images: images whose ``.shape[:2]`` is used as the original size
                (presumably H x W x C numpy arrays; TODO confirm).
            move_to_device_fn: callable used to move inputs to a device;
                defaults to ``self.move_to_device``.
            *_threshold: forwarded unchanged to the processor's post-processing.

        Returns:
            Whatever ``postprocess_detections_and_associations`` returns.
        """
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images)
        inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
        detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)

        # Callback: per-image character-character affinity (sigmoid) matrices for
        # the selected character indices, using crop embeddings of those boxes.
        def get_character_character_matching_scores(batch_character_indices, batch_bboxes):
            predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
            predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
            crop_bboxes = [batch_bboxes[i][batch_character_indices[i]] for i in range(len(batch_character_indices))]
            crop_embeddings_for_batch = self.predict_crop_embeddings(images, crop_bboxes, move_to_device_fn)
            character_obj_tokens_for_batch = []
            c2c_tokens_for_batch = []
            for predicted_obj_tokens, predicted_c2c_tokens, character_indices in zip(predicted_obj_tokens_for_batch, predicted_c2c_tokens_for_batch, batch_character_indices):
                character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
                c2c_tokens_for_batch.append(predicted_c2c_tokens)
            return self._get_character_character_affinity_matrices(
                character_obj_tokens_for_batch=character_obj_tokens_for_batch,
                crop_embeddings_for_batch=crop_embeddings_for_batch,
                c2c_tokens_for_batch=c2c_tokens_for_batch,
                apply_sigmoid=True,
            )

        # Callback: per-image text-character affinity (sigmoid) matrices.
        def get_text_character_matching_scores(batch_text_indices, batch_character_indices):
            predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
            predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
            text_obj_tokens_for_batch = []
            character_obj_tokens_for_batch = []
            t2c_tokens_for_batch = []
            for predicted_obj_tokens, predicted_t2c_tokens, text_indices, character_indices in zip(predicted_obj_tokens_for_batch, predicted_t2c_tokens_for_batch, batch_text_indices, batch_character_indices):
                text_obj_tokens_for_batch.append(predicted_obj_tokens[text_indices])
                character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
                t2c_tokens_for_batch.append(predicted_t2c_tokens)
            return self._get_text_character_affinity_matrices(
                character_obj_tokens_for_batch=character_obj_tokens_for_batch,
                text_obj_tokens_for_this_batch=text_obj_tokens_for_batch,
                t2c_tokens_for_batch=t2c_tokens_for_batch,
                apply_sigmoid=True,
            )

        # Callback: per-image sigmoid "is this text a dialogue" score, one per selected text.
        def get_dialog_confidence_scores(batch_text_indices):
            predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
            dialog_confidence = []
            for predicted_obj_tokens, text_indices in zip(predicted_obj_tokens_for_batch, batch_text_indices):
                confidence = self.is_this_text_a_dialogue(predicted_obj_tokens[text_indices]).sigmoid()
                dialog_confidence.append(rearrange(confidence, "i 1 -> i"))
            return dialog_confidence

        return self.processor.postprocess_detections_and_associations(
            predicted_bboxes=predicted_bboxes,
            predicted_class_scores=predicted_class_scores,
            original_image_sizes=torch.stack([torch.tensor(img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device),
            get_character_character_matching_scores=get_character_character_matching_scores,
            get_text_character_matching_scores=get_text_character_matching_scores,
            get_dialog_confidence_scores=get_dialog_confidence_scores,
            character_detection_threshold=character_detection_threshold,
            panel_detection_threshold=panel_detection_threshold,
            text_detection_threshold=text_detection_threshold,
            character_character_matching_threshold=character_character_matching_threshold,
            text_character_matching_threshold=text_character_matching_threshold,
        )

    def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
        """Embed image crops with the ViT-MAE crop embedding model.

        Args:
            images: images to crop from, paired positionally with ``crop_bboxes``.
            crop_bboxes: a list holding one collection of bboxes per image.
            move_to_device_fn: callable used to move inputs to a device;
                defaults to ``self.move_to_device``.
            mask_ratio: MAE mask ratio used for this call only. The model's
                previous value is always restored before returning.
            batch_size: maximum number of crops per forward pass.

        Returns:
            ``None`` if crop embeddings are disabled. Otherwise, a list with one
            entry per image containing the first token (index 0) of the last
            hidden state for each of that image's crops. If there are no crops
            at all, every entry is an empty Python list instead of a tensor.
        """
        if self.config.disable_crop_embeddings:
            return None
        assert isinstance(crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
        # Temporarily change the mask ratio from its default to the requested one.
        # It is restored in ``finally``, so neither the early "no crops" return
        # nor an exception can leave the override in place for later calls.
        embeddings_config = self.crop_embedding_model.embeddings.config
        old_mask_ratio = embeddings_config.mask_ratio
        embeddings_config.mask_ratio = mask_ratio
        try:
            crops_per_image = []
            num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
            for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
                crops = self.processor.crop_image(image, bboxes)
                assert len(crops) == num_crops
                crops_per_image.extend(crops)
            if len(crops_per_image) == 0:
                return [[] for _ in crop_bboxes]
            crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(crops_per_image)
            crops_per_image = move_to_device_fn(crops_per_image)
            # Process the crops in chunks to avoid running out of memory.
            embeddings = []
            for i in range(0, len(crops_per_image), batch_size):
                crops = crops_per_image[i:i + batch_size]
                embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
                embeddings.append(embeddings_per_batch)
            embeddings = torch.cat(embeddings, dim=0)
            # Split the flat embedding tensor back into per-image chunks.
            crop_embeddings_for_batch = []
            for num_crops in num_crops_per_batch:
                crop_embeddings_for_batch.append(embeddings[:num_crops])
                embeddings = embeddings[num_crops:]
            return crop_embeddings_for_batch
        finally:
            embeddings_config.mask_ratio = old_mask_ratio

    def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32, max_new_tokens=64):
        """Run OCR on the given crops of each image.

        Args:
            images: images to crop from, paired positionally with ``crop_bboxes``.
            crop_bboxes: one collection of bboxes per image.
            move_to_device_fn: callable used to move inputs to a device;
                defaults to ``self.move_to_device``.
            use_tqdm: show a tqdm progress bar over crop chunks.
            batch_size: maximum number of crops per ``generate`` call.
            max_new_tokens: forwarded to ``self.ocr_model.generate``.

        Returns:
            A list with one list of strings per image, one string per bbox, with
            newline characters removed.
        """
        assert not self.config.disable_ocr
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
        crops_per_image = []
        num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
        for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
            crops = self.processor.crop_image(image, bboxes)
            assert len(crops) == num_crops
            crops_per_image.extend(crops)
        if len(crops_per_image) == 0:
            return [[] for _ in crop_bboxes]
        crops_per_image = self.processor.preprocess_inputs_for_ocr(crops_per_image)
        crops_per_image = move_to_device_fn(crops_per_image)
        # Process the crops in chunks to avoid running out of memory.
        all_generated_texts = []
        if use_tqdm:
            from tqdm import tqdm
            pbar = tqdm(range(0, len(crops_per_image), batch_size))
        else:
            pbar = range(0, len(crops_per_image), batch_size)
        for i in pbar:
            crops = crops_per_image[i:i + batch_size]
            generated_ids = self.ocr_model.generate(crops, max_new_tokens=max_new_tokens)
            generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
            all_generated_texts.extend(generated_texts)
        # Split the flat list of texts back into per-image lists.
        texts_for_images = []
        for num_crops in num_crops_per_batch:
            texts_for_images.append([x.replace("\n", "") for x in all_generated_texts[:num_crops]])
            all_generated_texts = all_generated_texts[num_crops:]
        return texts_for_images

    def visualise_single_image_prediction(
        self, image_as_np_array, predictions, filename=None
    ):
        """Delegate to ``utils.visualise_single_image_prediction``."""
        # Resolves to the module-level helper imported from ``.utils``, not this method.
        return visualise_single_image_prediction(image_as_np_array, predictions, filename)

    def generate_transcript_for_single_image(
        self, predictions, ocr_results, filename=None
    ):
        """Build a speaker-annotated transcript for one image.

        Args:
            predictions: dict with ``"character_cluster_labels"`` (indexed by
                character index) and ``"text_character_associations"`` (iterable
                of ``(text_index, character_index)`` pairs; if a text appears
                more than once, the last pair wins).
            ocr_results: OCR strings, in text-index order.
            filename: if given, the transcript is also written there as UTF-8.

        Returns:
            The transcript string: one ``"<speaker>: text"`` line per OCR
            result, using ``<?>`` when the text has no associated character.
        """
        character_clusters = predictions["character_cluster_labels"]
        text_to_character = {k: v for k, v in predictions["text_character_associations"]}
        transcript = " ### Transcript ###\n"
        for index, text in enumerate(ocr_results):
            if index in text_to_character:
                speaker = character_clusters[text_to_character[index]]
                speaker = f"<{speaker}>"
            else:
                speaker = "<?>"
            transcript += f"{speaker}: {text}\n"
        if filename is not None:
            # Explicit encoding so OCR output with non-ASCII characters does not
            # depend on the platform's default encoding.
            with open(filename, "w", encoding="utf-8") as file:
                file.write(transcript)
        return transcript

    def get_affinity_matrices_given_annotations(
        self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
    ):
        """Compute affinity matrices for ground-truth boxes.

        Each annotated box is matched to a predicted query with the Hungarian
        matcher. The matched queries' object tokens are then fed to the
        association heads. In this method label ``0`` is treated as character
        and ``1`` as text.

        Args:
            images: input images.
            annotations: per-image dicts with ``"bboxes_as_x1y1x2y2"`` and
                ``"labels"``.
            move_to_device_fn: callable used to move inputs to a device;
                defaults to ``self.move_to_device``.
            apply_sigmoid: apply a sigmoid to the head outputs.

        Returns:
            dict with per-image ``text_character_affinity_matrices``,
            ``character_character_affinity_matrices``, ``text_bboxes_for_batch``
            and ``character_bboxes_for_batch``.
        """
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
        character_bboxes_in_batch = [[bbox for bbox, label in zip(a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
        crop_embeddings_for_batch = self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn)
        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
        inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
        processed_targets = inputs_to_detection_transformer.pop("labels")
        detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
        predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
        predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
        predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
        matching_dict = {
            "logits": predicted_class_scores,
            "pred_boxes": predicted_bboxes,
        }
        indices = self.matcher(matching_dict, processed_targets)
        matched_char_obj_tokens_for_batch = []
        matched_text_obj_tokens_for_batch = []
        t2c_tokens_for_batch = []
        c2c_tokens_for_batch = []
        text_bboxes_for_batch = []
        character_bboxes_for_batch = []
        for j, (pred_idx, tgt_idx) in enumerate(indices):
            target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
            targets_for_this_image = processed_targets[j]
            indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
            indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
            predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
            predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
            text_bboxes_for_batch.append(
                [annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_text_boxes_in_annotation]
            )
            character_bboxes_for_batch.append(
                [annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_char_boxes_in_annotation]
            )
            matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
            matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
            t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
            c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
        text_character_affinity_matrices = self._get_text_character_affinity_matrices(
            character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
            text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
            t2c_tokens_for_batch=t2c_tokens_for_batch,
            apply_sigmoid=apply_sigmoid,
        )
        character_character_affinity_matrices = self._get_character_character_affinity_matrices(
            character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
            crop_embeddings_for_batch=crop_embeddings_for_batch,
            c2c_tokens_for_batch=c2c_tokens_for_batch,
            apply_sigmoid=apply_sigmoid,
        )
        return {
            "text_character_affinity_matrices": text_character_affinity_matrices,
            "character_character_affinity_matrices": character_character_affinity_matrices,
            "text_bboxes_for_batch": text_bboxes_for_batch,
            "character_bboxes_for_batch": character_bboxes_for_batch,
        }

    def get_obj_embeddings_corresponding_to_given_annotations(
        self, images, annotations, move_to_device_fn=None
    ):
        """Return object tokens of the queries matched to each annotated box.

        Matching uses the Hungarian matcher. Labels ``0``/``1``/``2`` are grouped
        as character/text/panel respectively.

        Returns:
            dict with per-image lists under ``"character"``, ``"text"``,
            ``"panel"`` (matched object tokens), ``"t2c"`` and ``"c2c"``
            (the per-image association tokens).
        """
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
        inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
        processed_targets = inputs_to_detection_transformer.pop("labels")
        detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
        predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
        predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
        predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
        matching_dict = {
            "logits": predicted_class_scores,
            "pred_boxes": predicted_bboxes,
        }
        indices = self.matcher(matching_dict, processed_targets)
        matched_char_obj_tokens_for_batch = []
        matched_text_obj_tokens_for_batch = []
        matched_panel_obj_tokens_for_batch = []
        t2c_tokens_for_batch = []
        c2c_tokens_for_batch = []
        for j, (pred_idx, tgt_idx) in enumerate(indices):
            target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
            targets_for_this_image = processed_targets[j]
            indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
            indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
            indices_of_panel_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 2]
            predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
            predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
            predicted_panel_indices = [target_idx_to_pred_idx[i] for i in indices_of_panel_boxes_in_annotation]
            matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
            matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
            matched_panel_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_panel_indices])
            t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
            c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
        return {
            "character": matched_char_obj_tokens_for_batch,
            "text": matched_text_obj_tokens_for_batch,
            "panel": matched_panel_obj_tokens_for_batch,
            "t2c": t2c_tokens_for_batch,
            "c2c": c2c_tokens_for_batch,
        }

    def sort_panels_and_text_bboxes_in_reading_order(
        self,
        batch_panel_bboxes,
        batch_text_bboxes,
    ):
        """Return per-image index orderings for panels and text boxes.

        Panels are ordered with ``utils.sort_panels``. Text boxes are ordered
        with ``utils.sort_text_boxes_in_reading_order`` relative to the sorted
        panels.

        Returns:
            ``(batch_sorted_panel_indices, batch_sorted_text_indices)``.
        """
        batch_sorted_panel_indices = []
        batch_sorted_text_indices = []
        for batch_index in range(len(batch_text_bboxes)):
            panel_bboxes = batch_panel_bboxes[batch_index]
            text_bboxes = batch_text_bboxes[batch_index]
            sorted_panel_indices = sort_panels(panel_bboxes)
            sorted_panels = [panel_bboxes[i] for i in sorted_panel_indices]
            sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, sorted_panels)
            batch_sorted_panel_indices.append(sorted_panel_indices)
            batch_sorted_text_indices.append(sorted_text_indices)
        return batch_sorted_panel_indices, batch_sorted_text_indices

    def _get_detection_transformer_output(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None
    ):
        """Run the ConditionalDETR backbone/decoder and return its output object.

        Raises:
            ValueError: if detections are disabled in the config.
        """
        if self.config.disable_detections:
            raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
        return self.detection_transformer(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            return_dict=True
        )

    def _get_predicted_obj_tokens(
        self,
        detection_transformer_output: ConditionalDetrModelOutput
    ):
        """Decoder outputs for the object queries (all but the trailing non-object tokens)."""
        return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens]

    def _get_predicted_c2c_tokens(
        self,
        detection_transformer_output: ConditionalDetrModelOutput
    ):
        """Per-image character-character token: decoder output at index ``-num_non_obj_tokens``."""
        return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens]

    def _get_predicted_t2c_tokens(
        self,
        detection_transformer_output: ConditionalDetrModelOutput
    ):
        """Per-image text-character token: decoder output at index ``-num_non_obj_tokens + 1``."""
        return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1]

    def _get_predicted_bboxes_and_classes(
        self,
        detection_transformer_output: ConditionalDetrModelOutput,
    ):
        """Predict class logits and sigmoid-normalised boxes for the object queries.

        The box head's first two outputs are offset by the inverse-sigmoid of the
        decoder reference points before the final sigmoid.

        Raises:
            ValueError: if detections are disabled in the config.
        """
        if self.config.disable_detections:
            raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
        obj = self._get_predicted_obj_tokens(detection_transformer_output)
        predicted_class_scores = self.class_labels_classifier(obj)
        # The slice is on dim 0 and is followed by transpose(0, 1), which implies
        # reference_points is laid out (queries, batch, 2). TODO confirm.
        reference = detection_transformer_output.reference_points[:-self.num_non_obj_tokens]
        reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
        predicted_boxes = self.bbox_predictor(obj)
        predicted_boxes[..., :2] += reference_before_sigmoid
        predicted_boxes = predicted_boxes.sigmoid()
        return predicted_class_scores, predicted_boxes

    def _get_character_character_affinity_matrices(
        self,
        character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
        crop_embeddings_for_batch: List[torch.FloatTensor] = None,
        c2c_tokens_for_batch: List[torch.FloatTensor] = None,
        apply_sigmoid=True,
    ):
        """Compute one square character-character affinity matrix per image.

        Without detections, the matrix is the cosine similarity of the crop
        embeddings (``apply_sigmoid`` is not used on this path). With
        detections, every ordered pair ``[char_i, char_j, c2c]`` (each char
        optionally concatenated with its crop embedding) is scored by
        ``character_character_matching_head``. The result is symmetrised as
        ``(A + A.T) / 2`` and optionally passed through a sigmoid.
        """
        assert self.config.disable_detections or (character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
        assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
        assert not self.config.disable_detections or not self.config.disable_crop_embeddings
        if self.config.disable_detections:
            affinity_matrices = []
            for crop_embeddings in crop_embeddings_for_batch:
                if not torch.is_tensor(crop_embeddings) and len(crop_embeddings) == 0:
                    # predict_crop_embeddings returns plain empty lists when the whole
                    # batch has no crops; those have no ``.norm``.
                    affinity_matrices.append(torch.zeros(0, 0))
                    continue
                crop_embeddings = crop_embeddings / crop_embeddings.norm(dim=-1, keepdim=True)
                affinity_matrix = crop_embeddings @ crop_embeddings.T
                affinity_matrices.append(affinity_matrix)
            return affinity_matrices
        affinity_matrices = []
        for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
            if character_obj_tokens.shape[0] == 0:
                affinity_matrices.append(torch.zeros(0, 0).type_as(character_obj_tokens))
                continue
            if not self.config.disable_crop_embeddings:
                crop_embeddings = crop_embeddings_for_batch[batch_index]
                assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
                character_obj_tokens = torch.cat([character_obj_tokens, crop_embeddings], dim=-1)
            # Build all (i, j) pairs as rows of [char_i, char_j], then append c2c.
            char_i = repeat(character_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
            char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=character_obj_tokens.shape[0])
            char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
            c2c = repeat(c2c, "d -> repeat d", repeat=char_ij.shape[0])
            char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
            character_character_affinities = self.character_character_matching_head(char_ij_c2c)
            character_character_affinities = rearrange(character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
            # Symmetrise so that affinity(i, j) == affinity(j, i).
            character_character_affinities = (character_character_affinities + character_character_affinities.T) / 2
            if apply_sigmoid:
                character_character_affinities = character_character_affinities.sigmoid()
            affinity_matrices.append(character_character_affinities)
        return affinity_matrices

    def _get_text_character_affinity_matrices(
        self,
        character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
        text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
        t2c_tokens_for_batch: List[torch.FloatTensor] = None,
        apply_sigmoid=True,
    ):
        """Compute one (num_texts, num_characters) affinity matrix per image.

        Every pair ``[text_i, char_j, t2c]`` is scored by
        ``text_character_matching_head``, optionally followed by a sigmoid.
        An all-zero matrix of the right shape is returned when either side is
        empty.
        """
        assert not self.config.disable_detections
        assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
        affinity_matrices = []
        for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
            if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
                affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
                continue
            # Build all (text i, char j) pairs as rows of [text_i, char_j], then append t2c.
            text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
            char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
            text_char = rearrange([text_i, char_j], "two i j d -> (i j) (two d)")
            t2c = repeat(t2c, "d -> repeat d", repeat=text_char.shape[0])
            text_char_t2c = torch.cat([text_char, t2c], dim=-1)
            text_character_affinities = self.text_character_matching_head(text_char_t2c)
            text_character_affinities = rearrange(text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
            if apply_sigmoid:
                text_character_affinities = text_character_affinities.sigmoid()
            affinity_matrices.append(text_character_affinities)
        return affinity_matrices