File size: 338 Bytes
51ba5d6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import copy

from registry import MODEL_REGISTRY


def get_model(model_configs, *, registry=None):
    """Instantiate a model whose constructor is looked up by name in a registry.

    The ``"registered_model_name"`` entry of *model_configs* selects the
    constructor. All remaining entries are passed to that constructor as a
    single positional argument: a shallow copy of *model_configs* with the
    name key removed. If no entries remain, the constructor is called with
    no arguments.

    Args:
        model_configs: Mapping that must contain ``"registered_model_name"``.
            It is not modified. A shallow copy is made before the name key
            is removed, so the same config can be reused across calls.
        registry: Object exposing ``get(name)`` that returns the constructor
            registered under *name*. Defaults to ``MODEL_REGISTRY``.
            NOTE(review): presumably a name -> model-class mapping; confirm
            against the ``registry`` module.

    Returns:
        Whatever the registered constructor returns.

    Raises:
        KeyError: If ``"registered_model_name"`` is missing from
            *model_configs*, or if ``registry.get`` returns ``None`` for the
            requested name.
    """
    if registry is None:
        registry = MODEL_REGISTRY
    name = model_configs["registered_model_name"]
    registered_model = registry.get(name)
    if registered_model is None:
        # Fail with a descriptive error instead of the opaque
        # "'NoneType' object is not callable" a None lookup would produce.
        raise KeyError(f"No model registered under the name {name!r}")
    # Copy before popping so the caller's config is left intact; copy.copy
    # keeps the mapping's own type rather than forcing a plain dict.
    remaining_configs = copy.copy(model_configs)
    remaining_configs.pop("registered_model_name")
    if remaining_configs:
        return registered_model(remaining_configs)
    return registered_model()