#!/usr/bin/env python3
"""systemctl: subset of systemctl used for image construction

Mask/preset systemd units
"""

import argparse
import fnmatch
import os
import re
import sys

from collections import namedtuple
from pathlib import Path

# Script version (not referenced by the code below)
version = 1.0

# Filesystem root: the default when no --root is given, and the prefix used
# when turning a path below the image root into an absolute link target.
ROOT = Path("/")
# The directories below are deliberately relative so that they can be joined
# onto whichever root directory is being operated on.
SYSCONFDIR = Path("etc")
BASE_LIBDIR = Path("lib")
LIBDIR = Path("usr", "lib")


class SystemdFile():
    """Class representing a single systemd configuration file"""
    def __init__(self, root, path):
        # Maps section name -> {key: [whitespace separated value words]}
        self.sections = dict()
        self._parse(root, path)

    def _parse(self, root, path):
        """Parse a systemd syntax configuration file

        Lines that are neither a comment, a section header nor a
        ``key=value`` assignment inside a section are ignored.

        Args:
            root: A pathlib.Path object for the root directory; absolute
                  targets of dangling symlinks are looked up below it
            path: A pathlib.Path object pointing to the file

        """
        skip_re = re.compile(r"^\s*([#;]|$)")
        section_re = re.compile(r"^\s*\[(?P<section>.*)\]")
        # The key must not contain '=', otherwise a value such as
        # "Environment=FOO=bar" would be split at the last '=' of the word.
        kv_re = re.compile(r"^\s*(?P<key>[^\s=]+)\s*=\s*(?P<value>.*)")
        section = None

        # Path.exists() follows symlinks, so it is False for a dangling link.
        # (Path.resolve() only raises for those on Python < 3.6.)
        if path.is_symlink() and not path.exists():
            link = Path(os.readlink(str(path)))
            if link.is_absolute():
                # broken symlink, try relative to root
                path = root / link.relative_to(ROOT)

        with path.open() as f:
            for line in f:
                if skip_re.match(line):
                    continue

                line = line.rstrip("\n")
                m = section_re.match(line)
                if m:
                    # A repeated section header extends the existing section
                    section = self.sections.setdefault(m.group('section'),
                                                       dict())
                    continue

                # Join continuation lines; the trailing backslash is replaced
                # by a space so it does not end up as a word of the value.
                while line.endswith("\\"):
                    line = line[:-1] + " "
                    continuation = f.readline()
                    if not continuation:
                        # backslash on the last line of the file
                        break
                    line += continuation.rstrip("\n")

                m = kv_re.match(line)
                if m is None or section is None:
                    # not an assignment, or an assignment outside a section
                    continue

                k = m.group('key')
                v = m.group('value')
                section.setdefault(k, list()).extend(v.split())

    def get(self, section, prop):
        """Get a property from section

        Args:
            section: Section to retrieve property from
            prop: Property to retrieve

        Returns:
            List representing all properties of type prop in section.

        Raises:
            KeyError: if ``section`` or ``prop`` not found
        """
        return self.sections[section][prop]


class Presets():
    """Class representing all systemd presets"""
    def __init__(self, scope, root):
        # Directives in the order they were read; the first match wins
        self.directives = list()
        self._collect_presets(scope, root)

    def _parse_presets(self, presets):
        """Parse presets out of a set of preset files

        Args:
            presets: iterable of pathlib.Path objects, in the order in which
                     the files are to be read

        Exits the program if a line is neither a directive, a comment nor
        blank.
        """
        skip_re = re.compile(r"^\s*([#;]|$)")
        # Only the first word after the action is the unit name pattern;
        # capturing the rest of the line would pull in trailing whitespace
        # and the pattern would then never match a unit name.
        directive_re = re.compile(r"^\s*(?P<action>enable|disable)\s+(?P<unit_name>\S+)")

        Directive = namedtuple("Directive", "action unit_name")
        for preset in presets:
            with preset.open() as f:
                for line in f:
                    m = directive_re.match(line)
                    if m:
                        directive = Directive(action=m.group('action'),
                                              unit_name=m.group('unit_name'))
                        self.directives.append(directive)
                    elif skip_re.match(line):
                        pass
                    else:
                        sys.exit("Unparsed preset line in {}".format(preset))

    def _collect_presets(self, scope, root):
        """Collect list of preset files

        Args:
            scope: name of the preset directory below each ``systemd``
                   directory
            root: pathlib.Path of the root directory to search in
        """
        locations = [SYSCONFDIR / "systemd"]
        # Handle the usrmerge case by ignoring /lib when it's a symlink.
        # The check has to look inside root, not the current directory.
        if not (root / BASE_LIBDIR).is_symlink():
            locations.append(BASE_LIBDIR / "systemd")
        locations.append(LIBDIR / "systemd")

        presets = dict()
        for location in locations:
            paths = (root / location / scope).glob("*.preset")
            for path in paths:
                # earlier names override later ones
                if path.name not in presets:
                    presets[path.name] = path

        # Files are read in order of their file name
        self._parse_presets([v for k, v in sorted(presets.items())])

    def state(self, unit_name):
        """Return state of preset for unit_name

        Args:
            unit_name: name of the unit

        Returns:
            None: no matching preset
            `enable`: unit_name is enabled
            `disable`: unit_name is disabled
        """
        for directive in self.directives:
            if fnmatch.fnmatch(unit_name, directive.unit_name):
                return directive.action

        return None


def collect_services(root):
    """Collect list of service files

    Args:
        root: pathlib.Path of the root directory to search in

    Returns:
        Dict mapping the file name of every non-directory entry in the
        ``systemd/system`` directories to its path. When a name exists in
        more than one directory the first one found wins.
    """
    locations = [SYSCONFDIR / "systemd"]
    # Handle the usrmerge case by ignoring /lib when it's a symlink.
    # The check has to look inside root, not the current directory.
    if not (root / BASE_LIBDIR).is_symlink():
        locations.append(BASE_LIBDIR / "systemd")
    locations.append(LIBDIR / "systemd")

    services = dict()
    for location in locations:
        paths = (root / location / "system").glob("*")
        for path in paths:
            if path.is_dir():
                continue
            # implement earlier names override later ones
            if path.name not in services:
                services[path.name] = path

    return services


def add_link(path, target):
    """Create the symlink ``path`` pointing at ``target``

    Missing parent directories are created. Nothing is done when ``path``
    already is a symlink; otherwise the equivalent ``ln -s`` command is
    printed before the link is made.

    Args:
        path: pathlib.Path of the link to create
        target: what the link should point to
    """
    parent = path.parent
    try:
        parent.mkdir(parents=True)
    except FileExistsError:
        # the directory is already in place
        pass

    if path.is_symlink():
        return

    print("ln -s {} {}".format(target, path))
    path.symlink_to(target)


def process_deps(root, config, service, location, prop, dirstem):
    """Link a unit into the dependency directories named by an Install key

    For every unit listed under ``prop`` in the ``Install`` section a link
    ``<unit>.<dirstem>/<service>`` is added below ``etc/systemd/system`` in
    ``root``. Nothing happens when the section or key is absent.

    Args:
        root: pathlib.Path of the root directory
        config: SystemdFile of the unit being processed
        service: name of the link to create
        location: pathlib.Path of the unit file below ``root``
        prop: ``Install`` key listing the dependent units
        dirstem: suffix of the dependency directory
    """
    # Link targets are written as seen from inside the root directory
    link_target = ROOT / location.relative_to(root)

    try:
        dependents = config.get('Install', prop)
    except KeyError:
        return

    unit_dir = root / SYSCONFDIR / "systemd" / "system"
    for dependent in dependents:
        dep_dir = unit_dir / "{}.{}".format(dependent, dirstem)
        add_link(dep_dir / service, link_target)


def enable(root, service, location, services, _visited=None):
    """Enable a unit by creating the links its Install section asks for

    Handles ``WantedBy``, ``RequiredBy``, ``Also`` and ``Alias``. Unit files
    that are symlinks are treated as aliases and skipped. A template unit
    without an instance is enabled under its ``DefaultInstance``; if it has
    none it is skipped.

    Args:
        root: pathlib.Path of the root directory
        service: file name of the unit
        location: pathlib.Path of the unit file below ``root``
        services: dict mapping unit names to unit file paths, used to look
                  up the units named by ``Also``
        _visited: internal; set of unit names already handled while
                  following ``Also`` references
    """
    # Units may name each other in Also=; remember what has been handled so
    # such a cycle does not recurse forever.
    if _visited is None:
        _visited = set()
    if service in _visited:
        return
    _visited.add(service)

    if location.is_symlink():
        # ignore aliases
        return

    config = SystemdFile(root, location)
    template = re.match(r"[^@]+@(?P<instance>[^\.]*)\.", service)
    if template and not template.group('instance'):
        try:
            instance = config.get('Install', 'DefaultInstance')[0]
        except (KeyError, IndexError):
            # a bare template without a default instance cannot be enabled
            return
        service = service.replace("@.", "@{}.".format(instance))

    process_deps(root, config, service, location, 'WantedBy', 'wants')
    process_deps(root, config, service, location, 'RequiredBy', 'requires')

    try:
        also_units = config.get('Install', 'Also')
    except KeyError:
        also_units = []
    for also in also_units:
        # A unit that is not installed is skipped without giving up on the
        # remaining entries.
        if also in services:
            enable(root, also, services[also], services, _visited)

    systemdir = root / SYSCONFDIR / "systemd" / "system"
    # Link targets are written as seen from inside the root directory
    target = ROOT / location.relative_to(root)
    try:
        aliases = config.get('Install', 'Alias')
    except KeyError:
        aliases = []
    for dest in aliases:
        add_link(systemdir / dest, target)


def preset_all(root):
    """Enable every unit below ``root`` that the presets do not disable

    Units are read from the ``system-preset`` presets; a unit with no
    matching preset is enabled as well.

    Args:
        root: pathlib.Path of the root directory
    """
    presets = Presets('system-preset', root)
    services = collect_services(root)

    for name in services:
        if presets.state(name) in ("enable", None):
            enable(root, name, services[name], services)


def mask(root, *services):
    """Mask units by linking them to /dev/null in etc/systemd/system

    Args:
        root: pathlib.Path of the root directory
        *services: names of the units to mask
    """
    unit_dir = root / SYSCONFDIR / "systemd" / "system"
    for name in services:
        link = unit_dir / name
        add_link(link, "/dev/null")


def main():
    """Command-line entry point: parse the arguments and run the command"""
    if sys.version_info < (3, 4, 0):
        sys.exit("Python 3.4 or greater is required")

    commands = ['mask', 'preset-all']
    preset_modes = ['full', 'enable-only', 'disable-only']

    parser = argparse.ArgumentParser()
    parser.add_argument('command', nargs=1, choices=commands)
    parser.add_argument('service', nargs=argparse.REMAINDER)
    parser.add_argument('--root')
    parser.add_argument('--preset-mode', choices=preset_modes, default='full')

    args = parser.parse_args()

    # Operate on the real root unless another one was requested
    root = ROOT if not args.root else Path(args.root)
    command = args.command[0]

    if command == "mask":
        mask(root, *args.service)
        return

    if command != "preset-all":
        raise RuntimeError()

    # preset-all takes no unit names and only one mode is implemented
    if args.service:
        sys.exit("Too many arguments.")
    if args.preset_mode != "enable-only":
        sys.exit("Only enable-only is supported as preset-mode.")
    preset_all(root)


# Run the command-line interface only when executed as a script
if __name__ == '__main__':
    main()
