# ketanmore's picture
# Upload folder using huggingface_hub
# 2720487 verified
# raw
# history blame
# 1.57 kB
from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \
AutoModel
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig
from surya.model.ordering.decoder import MBartOrder
from surya.model.ordering.encoder import VariableDonutSwinModel
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.model.ordering.processor import OrderImageProcessor
from surya.settings import settings
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the reading order model from ``checkpoint`` and prepare it for inference.

    The decoder and encoder sub-configs read from the checkpoint are rebuilt as
    ``MBartOrderConfig`` and ``VariableDonutSwinConfig``, and the matching model
    classes are registered with the transformers auto classes before the
    weights are loaded.

    Args:
        checkpoint: Passed to ``from_pretrained`` for both the config and the model.
        device: Target passed to ``model.to``.
        dtype: Passed as ``torch_dtype`` when loading the model.

    Returns:
        The loaded ``OrderVisionEncoderDecoderModel``, moved to ``device`` and
        switched to eval mode.

    Raises:
        AssertionError: If the loaded decoder or encoder is not an instance of
            ``MBartOrder`` / ``VariableDonutSwinModel`` respectively.
    """
    model_config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    # Rebuild each sub-config as its custom config class, reusing the
    # attributes that were loaded from the checkpoint.
    model_config.decoder = MBartOrderConfig(**vars(model_config.decoder))
    model_config.encoder = VariableDonutSwinConfig(**vars(model_config.encoder))

    # Get transformers to load custom model
    registrations = (
        (AutoModel, MBartOrderConfig, MBartOrder),
        (AutoModelForCausalLM, MBartOrderConfig, MBartOrder),
        (AutoModel, VariableDonutSwinConfig, VariableDonutSwinModel),
    )
    for auto_class, config_class, model_class in registrations:
        auto_class.register(config_class, model_class)

    order_model = OrderVisionEncoderDecoderModel.from_pretrained(
        checkpoint,
        config=model_config,
        torch_dtype=dtype,
    )

    # Sanity-check that the custom classes were actually instantiated.
    assert isinstance(order_model.decoder, MBartOrder)
    assert isinstance(order_model.encoder, VariableDonutSwinModel)

    order_model = order_model.to(device).eval()

    print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
    return order_model