#!/usr/bin/env python import configargparse from flask import Flask, jsonify, request from waitress import serve from onmt.translate import TranslationServer, ServerModelError import logging from logging.handlers import RotatingFileHandler STATUS_OK = "ok" STATUS_ERROR = "error" def start(config_file, url_root="./translator", host="0.0.0.0", port=5000, debug=False): def prefix_route(route_function, prefix="", mask="{0}{1}"): def newroute(route, *args, **kwargs): return route_function(mask.format(prefix, route), *args, **kwargs) return newroute if debug: logger = logging.getLogger("main") log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") file_handler = RotatingFileHandler( "debug_requests.log", maxBytes=1000000, backupCount=10 ) file_handler.setFormatter(log_format) logger.addHandler(file_handler) app = Flask(__name__) app.route = prefix_route(app.route, url_root) translation_server = TranslationServer() translation_server.start(config_file) @app.route("/models", methods=["GET"]) def get_models(): out = translation_server.list_models() return jsonify(out) @app.route("/health", methods=["GET"]) def health(): out = {} out["status"] = STATUS_OK return jsonify(out) @app.route("/clone_model/", methods=["POST"]) def clone_model(model_id): out = {} data = request.get_json(force=True) timeout = -1 if "timeout" in data: timeout = data["timeout"] del data["timeout"] opt = data.get("opt", None) try: model_id, load_time = translation_server.clone_model(model_id, opt, timeout) except ServerModelError as e: out["status"] = STATUS_ERROR out["error"] = str(e) else: out["status"] = STATUS_OK out["model_id"] = model_id out["load_time"] = load_time return jsonify(out) @app.route("/unload_model/", methods=["GET"]) def unload_model(model_id): out = {"model_id": model_id} try: translation_server.unload_model(model_id) out["status"] = STATUS_OK except Exception as e: out["status"] = STATUS_ERROR out["error"] = str(e) return jsonify(out) @app.route("/translate", methods=["POST"]) def translate(): inputs = request.get_json(force=True) if debug: logger.info(inputs) out = {} try: trans, scores, n_best, _, aligns, align_scores = translation_server.run( inputs ) assert len(trans) == len(inputs) * n_best assert len(scores) == len(inputs) * n_best assert len(aligns) == len(inputs) * n_best out = [[] for _ in range(n_best)] for i in range(len(trans)): response = { "src": inputs[i // n_best]["src"], "tgt": trans[i], "n_best": n_best, "pred_score": scores[i], } if len(aligns[i]) > 0 and aligns[i][0] is not None: response["align"] = aligns[i] response["align_score"] = align_scores[i] out[i % n_best].append(response) except ServerModelError as e: model_id = inputs[0].get("id") if debug: logger.warning( "Unload model #{} " "because of an error".format(model_id) ) translation_server.models[model_id].unload() out["error"] = str(e) out["status"] = STATUS_ERROR if debug: logger.info(out) return jsonify(out) @app.route("/to_cpu/", methods=["GET"]) def to_cpu(model_id): out = {"model_id": model_id} translation_server.models[model_id].to_cpu() out["status"] = STATUS_OK return jsonify(out) @app.route("/to_gpu/", methods=["GET"]) def to_gpu(model_id): out = {"model_id": model_id} translation_server.models[model_id].to_gpu() out["status"] = STATUS_OK return jsonify(out) serve(app, host=host, port=port) def _get_parser(): parser = configargparse.ArgumentParser( config_file_parser_class=configargparse.YAMLConfigFileParser, description="OpenNMT-py REST Server", ) parser.add_argument("--ip", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default="5000") parser.add_argument("--url_root", type=str, default="/translator") parser.add_argument("--debug", "-d", action="store_true") parser.add_argument( "--config", "-c", type=str, default="./available_models/conf.json" ) return parser def main(): parser = _get_parser() args = parser.parse_args() start( args.config, url_root=args.url_root, host=args.ip, port=args.port, debug=args.debug, ) if __name__ == "__main__": main()