|
from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \ |
|
AutoModel |
|
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig |
|
from surya.model.ordering.decoder import MBartOrder |
|
from surya.model.ordering.encoder import VariableDonutSwinModel |
|
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel |
|
from surya.model.ordering.processor import OrderImageProcessor |
|
from surya.settings import settings |
|
|
|
|
|
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the reading-order encoder/decoder model and prepare it for inference.

    The checkpoint's ``VisionEncoderDecoderConfig`` is read first, and its
    decoder and encoder sub-configs are rebuilt as ``MBartOrderConfig`` and
    ``VariableDonutSwinConfig`` from their attribute dictionaries. The
    project's model classes are then registered with the transformers auto
    factories before the weights are loaded.

    Args:
        checkpoint: Identifier passed to ``from_pretrained``; defaults to
            ``settings.ORDER_MODEL_CHECKPOINT``.
        device: Target passed to ``model.to``; defaults to
            ``settings.TORCH_DEVICE_MODEL``.
        dtype: Value forwarded as ``torch_dtype`` when loading the weights;
            defaults to ``settings.MODEL_DTYPE``.

    Returns:
        The loaded ``OrderVisionEncoderDecoderModel``, moved to ``device``
        and switched to eval mode.

    Raises:
        AssertionError: If the loaded decoder or encoder is not an instance
            of ``MBartOrder`` / ``VariableDonutSwinModel`` respectively.
    """
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    # Re-create each sub-config with the project-specific config class,
    # seeding it with the attributes of the sub-config that was just loaded.
    config.decoder = MBartOrderConfig(**vars(config.decoder))
    config.encoder = VariableDonutSwinConfig(**vars(config.encoder))

    # Register the custom config -> model pairs with the auto factories
    # before the weights are loaded below.
    registrations = (
        (AutoModel, MBartOrderConfig, MBartOrder),
        (AutoModelForCausalLM, MBartOrderConfig, MBartOrder),
        (AutoModel, VariableDonutSwinConfig, VariableDonutSwinModel),
    )
    for auto_factory, config_cls, model_cls in registrations:
        auto_factory.register(config_cls, model_cls)

    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Sanity-check that the custom sub-model classes were actually instantiated.
    assert isinstance(model.decoder, MBartOrder)
    assert isinstance(model.encoder, VariableDonutSwinModel)

    model = model.to(device).eval()
    print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
    return model