# surya reading-order model loader
from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \
AutoModel
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig
from surya.model.ordering.decoder import MBartOrder
from surya.model.ordering.encoder import VariableDonutSwinModel
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.model.ordering.processor import OrderImageProcessor
from surya.settings import settings
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the reading-order VisionEncoderDecoder model from a checkpoint.

    Rebuilds the pretrained config with the project's custom encoder/decoder
    config classes, registers those classes with transformers' Auto* factories
    so ``from_pretrained`` instantiates the custom implementations, and moves
    the model to the target device in eval mode.

    Args:
        checkpoint: HF-style checkpoint path or hub id to load from.
        device: torch device string/object the model is moved to.
        dtype: torch dtype the weights are loaded as.

    Returns:
        An ``OrderVisionEncoderDecoderModel`` in eval mode on ``device``.

    Raises:
        TypeError: if the loaded encoder/decoder are not the expected
            custom classes (would indicate a registration failure).
    """
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    # Re-wrap the loaded sub-configs in the project's custom config classes
    # so from_pretrained dispatches to the custom encoder/decoder below.
    decoder_config = vars(config.decoder)
    decoder = MBartOrderConfig(**decoder_config)
    config.decoder = decoder

    encoder_config = vars(config.encoder)
    encoder = VariableDonutSwinConfig(**encoder_config)
    config.encoder = encoder

    # Get transformers to load custom model
    AutoModel.register(MBartOrderConfig, MBartOrder)
    AutoModelForCausalLM.register(MBartOrderConfig, MBartOrder)
    AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)

    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Explicit checks instead of `assert`: asserts are stripped under -O,
    # and a silently-wrong class here would fail much later and obscurely.
    if not isinstance(model.decoder, MBartOrder):
        raise TypeError(f"Expected MBartOrder decoder, got {type(model.decoder).__name__}")
    if not isinstance(model.encoder, VariableDonutSwinModel):
        raise TypeError(f"Expected VariableDonutSwinModel encoder, got {type(model.encoder).__name__}")

    model = model.to(device)
    model = model.eval()
    print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
    return model