#!/usr/bin/python3
# SPDX-FileCopyrightText: 2004-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""|UDM| command line server"""


import array
import errno
import json
import os
import signal
import socket
import socketserver
import sys
import traceback
from argparse import ArgumentParser, Namespace
from io import StringIO
from logging import getLogger
from select import select
from typing import IO

import univention.admincli.adduser
import univention.admincli.admin
import univention.admincli.license_check
import univention.admincli.passwd
import univention.debug as ud
import univention.logging
from univention.config_registry import ucr


# Logger used by every function in this module.
log = getLogger('ADMIN')
# Name of the active log file; assigned by server_main() and replaced by doit()
# when a request carries its own `--logfile` option.
logfile = ''
# Debug level; assigned by server_main() and reused by doit() when it
# re-initialises logging for a different log file.
loglevel = 1


def recv_fds(sock: socket.socket, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
    """
    Receive a message together with any file descriptors passed alongside it.

    :param sock: Socket to read from.
    :param msglen: Maximum number of payload bytes to read.
    :param maxfds: Number of file descriptors to reserve ancillary buffer space for.
    :returns: The payload bytes and the list of received file descriptors.
    """
    received = array.array("i")
    int_size = received.itemsize
    payload, ancillary, _flags, _addr = sock.recvmsg(msglen, socket.CMSG_LEN(maxfds * int_size))
    for level, kind, data in ancillary:
        if (level, kind) != (socket.SOL_SOCKET, socket.SCM_RIGHTS):
            continue
        # Only whole integers are usable; drop a truncated fragment at the end.
        usable = len(data) - len(data) % int_size
        received.frombytes(data[:usable])
    return payload, received.tolist()


class MyRequestHandler(socketserver.BaseRequestHandler):
    """Handle request on listening socket to open new connection."""

    def handle(self) -> None:
        """
        Read one NUL-terminated request from the client and process it.

        The request is received in chunks of up to 1024 bytes until a chunk
        ends with a NUL byte. File descriptors passed along with the data are
        collected; when exactly two were received they are used as the
        ``stdout`` and ``stderr`` streams of the command, otherwise the
        command runs without client streams.

        If the client closes the connection before sending the terminating
        NUL byte the request is dropped instead of waiting forever.
        """
        log.debug('daemon [%s] new connection [%s]', os.getppid(), os.getpid())
        chunks = []
        fds = []
        complete = False
        while True:
            buf, _fds = recv_fds(self.request, 1024, 2)
            fds.extend(_fds)
            if not buf:
                # Empty read on a stream socket means EOF: the peer is gone and
                # no terminator will ever arrive. Without this check the loop
                # would spin forever on empty reads.
                break
            if buf[-1:] == b'\0':
                chunks.append(buf[:-1])
                complete = True
                break
            chunks.append(buf)

        if complete:
            sarglist = b''.join(chunks)
            try:
                stdout, stderr = (os.fdopen(fd, 'w') for fd in fds)
            except ValueError:
                # Not exactly two descriptors: run without client streams.
                stdout, stderr = None, None
            doit(sarglist, self.request, stdout, stderr)
        else:
            log.warning('daemon [%s] [%s] client closed connection before request was complete', os.getppid(), os.getpid())
            # Release descriptors that were received for the aborted request.
            for fd in fds:
                try:
                    os.close(fd)
                except OSError:
                    pass
        log.debug('daemon [%s] connection closed [%s]', os.getppid(), os.getpid())


class ForkingServer(socketserver.ForkingMixIn, socketserver.UnixStreamServer):
    """
    UDM server listening on UNIX socket.

    Combines :class:`socketserver.UnixStreamServer` with
    :class:`socketserver.ForkingMixIn`, so every accepted connection is
    processed in a forked child process.
    """


def _remove_quietly(path: str, *, directory: bool = False) -> None:
    """Remove the file (or empty directory) `path`, logging instead of raising on failure."""
    try:
        if directory:
            os.rmdir(path)
        else:
            os.unlink(path)
    except OSError as exc:
        log.debug('daemon [%s] could not remove %s: %s', os.getpid(), path, exc)


def server_main(args: Namespace) -> None:
    """
    UDM command line server.

    Sets up logging, makes sure no other server instance is using the same
    socket, creates the socket directory and the listening socket, writes the
    run file containing the own process ID and then serves requests until no
    activity happened for `timeout` seconds.

    :param args: Parsed command line; the attributes `socket`, `logfile`,
        `debug_level` and `timeout` are used.
    :raises SystemExit: if another server is already running, the socket
        directory cannot be created or the socket cannot be set up.
    """
    global logfile, loglevel

    socket_path = args.socket
    socket_dir = os.path.dirname(socket_path)

    logfile = args.logfile
    loglevel = args.debug_level
    univention.logging.basicConfig(filename=logfile, univention_debug_level=args.debug_level, use_structured_logging=ucr.is_true('directory/manager/cmd/debug/structured-logging', False))

    runfilename = '%s.run' % socket_path
    if os.path.isfile(runfilename):
        try:
            with open(runfilename) as runfile:
                pid = int(runfile.readline().strip())
            if pid <= 0:
                # kill(2) interprets 0 and negative values as process groups
                # (or "every process"), so never signal those.
                raise ValueError(pid)
            # Probe whether the recorded process still exists.
            os.kill(pid, signal.SIGCONT)
        except (ValueError, OSError):
            pid = 0
        if pid:
            sys.exit('E: Server already running [Pid: %s]' % pid)
        # No pid found or no server running: remove the leftovers. Each step
        # is best-effort because a previous instance may have removed only
        # part of its state.
        _remove_quietly(socket_path)
        _remove_quietly(runfilename)
        _remove_quietly(socket_dir, directory=True)

    log.debug('server is running as process %d', os.getpid())

    try:
        os.mkdir(socket_dir)
        os.chmod(socket_dir, 0o700)
    except OSError as ex:
        if ex.errno != errno.EEXIST:
            sys.exit('E: %s %s' % (socket_dir, ex))
        else:
            print('E: socket directory exists (%s)' % socket_dir, file=sys.stderr)

    # Clamp to a non-negative 31 bit value; 0 falls back to 300 seconds.
    timeout = max(0, min(args.timeout, (1 << 31) - 1)) or 300

    try:
        sock = ForkingServer(socket_path, MyRequestHandler)
        os.chmod(socket_path, 0o600)
    except OSError:
        log.error('daemon [%s] Failed creating socket (%s). Daemon stopped.', os.getpid(), socket_path)
        sys.exit('E: Failed creating socket (%s). Daemon stopped.' % socket_path)

    try:
        with open(runfilename, 'w') as runfile:
            runfile.write(str(os.getpid()))
    except OSError:
        print('E: Can`t write runfile', file=sys.stderr)

    # Install a no-op SIGCHLD handler and route signals into a pipe so that
    # select() below wakes up when a child terminates.
    signal.signal(signal.SIGCHLD, lambda signo, frame: None)
    (read_fd, write_fd) = os.pipe2(os.O_NONBLOCK | os.O_CLOEXEC)
    signal.set_wakeup_fd(write_fd)
    try:
        while True:
            rlist, _wlist, _xlist = select([sock, read_fd], [], [], float(timeout))
            if sock in rlist:
                sock.handle_request()
            if read_fd in rlist:
                os.read(read_fd, 1)
            sock.service_actions()
            if not rlist:
                # Neither a request nor a signal arrived within the timeout.
                log.debug('daemon [%s] stopped after %s seconds idle', os.getpid(), timeout)
                break
    finally:
        # Best-effort cleanup: a failing step must not skip the remaining
        # ones or mask the exception that ended the loop.
        _remove_quietly(socket_path)
        _remove_quietly(runfilename)
        _remove_quietly(socket_dir, directory=True)
        ud.exit()


def doit(sarglist: bytes, conn: socket.socket, stdout: IO[str] | None = None, stderr: IO[str] | None = None) -> None:
    """
    Process single UDM request.

    :param sarglist: ASCII encoded JSON list holding the command line of the
        client; the base name of its first element selects the command.
    :param conn: Connection to the client. The JSON encoded answer followed by
        a NUL byte is written to it, then it is closed.
    :param stdout: Stream for regular output of the command. When missing, the
        output is collected and appended to the answer instead.
    :param stderr: Stream for error output, handled like `stdout`.
    :raises SystemExit: if neither `/etc/ldap.secret` nor `/etc/machine.secret`
        can be opened and `--binddn` was not given; with status 1 when a real
        operation was requested, with status 0 after help output was served.
    """

    def send_message(output) -> None:
        """Send answer back."""
        back = json.dumps(output, ensure_ascii=True).encode('ASCII')
        # send() may transmit only part of a large answer; sendall() keeps
        # writing until everything is out or an error occurs.
        conn.sendall(back + b'\0')
        conn.close()

    global logfile
    arglist = json.loads(sarglist.decode('ASCII'))

    # Scan the arguments for options which matter before the command runs.
    next_is_logfile = False
    secret = False
    show_help = False
    oldlogfile = logfile
    for arg in arglist:
        if next_is_logfile:
            logfile = arg
            next_is_logfile = False
            continue
        if arg.startswith('--logfile='):
            logfile = arg[len('--logfile='):]
        elif arg == '--logfile':
            next_is_logfile = True
        secret |= arg == '--binddn'
        show_help |= arg in ('--help', '-h', '-?', '--version')

    if not secret:
        # Without explicit bind credentials one of the secret files must be readable.
        for filename in ('/etc/ldap.secret', '/etc/machine.secret'):
            try:
                open(filename).close()
                secret = True
                break
            except OSError:
                continue
        else:
            if not show_help:
                send_message(["E: Permission denied, try --logfile, --binddn and --bindpwd", "OPERATION FAILED"])
                sys.exit(1)

    if logfile != oldlogfile:
        # The request asked for a different log file: restart logging with it.
        ud.exit()
        univention.logging.basicConfig(filename=logfile, univention_debug_level=loglevel, use_structured_logging=ucr.is_true('directory/manager/cmd/debug/structured-logging', False))

    if not stdout:
        stdout = StringIO()
    if not stderr:
        stderr = StringIO()

    output = []
    try:
        # Inside the `try` so that an empty or malformed argument list is
        # reported to the client as a failed operation instead of leaving it
        # without any answer.
        cmdfile = os.path.basename(arglist[0])
        if cmdfile in ('univention-directory-manager', 'udm'):
            log.info('daemon [%s] [%s] Calling univention-directory-manager', os.getppid(), os.getpid())
            log.debug('daemon [%s] [%s] arglist: %s', os.getppid(), os.getpid(), arglist)
            univention.admincli.admin.main(arglist, stdout, stderr)
        elif cmdfile == 'univention-passwd':
            log.info('daemon [%s] [%s] Calling univention-passwd', os.getppid(), os.getpid())
            log.debug('daemon [%s] [%s] arglist: %s', os.getppid(), os.getpid(), arglist)
            output = univention.admincli.passwd.doit(arglist)
        elif cmdfile == 'univention-license-check':
            log.info('daemon [%s] [%s] Calling univention-license-check', os.getppid(), os.getpid())
            log.debug('daemon [%s] [%s] arglist: %s', os.getppid(), os.getpid(), arglist)
            output = univention.admincli.license_check.doit(arglist)
        else:
            log.info('daemon [%s] [%s] Calling univention-adduser', os.getppid(), os.getpid())
            log.debug('daemon [%s] [%s] arglist: %s', os.getppid(), os.getpid(), arglist)
            output = univention.admincli.adduser.doit(arglist)
    except univention.admincli.admin.OperationFailed as exc:
        print(str(exc), file=stderr)
        output.append("OPERATION FAILED")
    except Exception:
        log.exception('Fatal error')
        print(traceback.format_exc(), file=stderr)
        output.append("OPERATION FAILED")
    finally:
        stdout.flush()
        stderr.flush()

    # Output captured in memory has to travel back inside the answer.
    if isinstance(stdout, StringIO):
        output += stdout.getvalue().splitlines()
    if isinstance(stderr, StringIO):
        output += stderr.getvalue().splitlines()
    send_message(output)

    if show_help and not secret:
        log.debug('daemon [%s] [%s] stopped, because User has no read/write permissions', os.getppid(), os.getpid())
        sys.exit(0)


def daemonize(logfile: str) -> None:
    """
    Turn the current process into a background daemon.

    Standard input is connected to the null device, standard output and
    standard error to `logfile` (or the null device if it cannot be opened
    for appending). All other open file descriptors are closed. The process
    then forks twice with a `setsid()` in between; both intermediate parents
    exit with status 0.

    :param logfile: Path of the file receiving standard output and error.
    """
    stdin_fd = os.open(os.path.devnull, os.O_RDONLY)
    try:
        output_fd = os.open(logfile, os.O_WRONLY | os.O_APPEND)
    except OSError:
        output_fd = os.open(os.path.devnull, os.O_WRONLY)

    # Descriptor 0 reads from the null device, 1 and 2 write to the log.
    for target, source in enumerate((stdin_fd, output_fd, output_fd)):
        if source != target:
            os.dup2(source, target)

    highest_fd = max(map(int, os.listdir("/proc/self/fd")))
    os.closerange(3, highest_fd + 1)

    # Prepare to call `setsid()`: the parent leaves, the child continues.
    if os.fork():
        sys.exit(0)

    # Become session leader
    os.setsid()

    # Double fork against CTTY
    if os.fork():
        sys.exit(0)


def parse_args() -> Namespace:
    """
    Parse the command line options of the server.

    Defaults for the debug level and the idle timeout are read from the
    configuration registry; the default socket path contains the user ID of
    the calling process.

    :returns: The parsed options.
    """
    default_debug = ucr.get_int('directory/manager/cmd/debug/level', 1)
    default_timeout = ucr.get_int('directory/manager/cmd/timeout', 300)
    default_socket = '/tmp/admincli_%d/sock' % os.getuid()

    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        '-n', '--daemonize',
        action='store_false',
        default=True,
        help='Run in foreground without daemonizing',
    )
    parser.add_argument(
        '-L', '--logfile',
        default='/var/log/univention/directory-manager-cmd.log',
        help='logfile: %(default)s',
    )
    parser.add_argument(
        '-d', '--debug-level',
        type=int,
        default=default_debug,
        help='debug level: %(default)s',
    )
    parser.add_argument(
        '-t', '--timeout',
        type=int,
        default=default_timeout,
        help='timeout: %(default)s',
    )
    parser.add_argument(
        '-s', '--socket',
        default=default_socket,
        help='socket: %(default)s',
    )
    return parser.parse_args()


def main() -> None:
    """Parse the command line, daemonize unless disabled and run the server."""
    options = parse_args()

    if options.daemonize:
        # An empty log file name sends the daemon's output to the null device.
        log_target = options.logfile or os.path.devnull
        daemonize(log_target)

    server_main(options)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
