from typing import List as _List
from os.path import join as _join, isfile as _isfile, isdir as _isdir, abspath as _abspath, basename as _basename
from datetime import datetime as _datetime
from ipaddress import ip_network, ip_address
from subprocess import run as _run
from os import makedirs as _makedirs
from shutil import rmtree as _rmtree, copy2 as _copy2
from glob import glob as _glob
from filecmp import cmp as _cmp

from tiramisu import DomainnameOption
from risotto.utils import multi_function as _multi_function


# Root directory under which DNSSEC keys, zone copies and signed zones are
# stored; resolved to an absolute path from the working directory when this
# module is loaded.
_PKI_DIR = _abspath('pki/dnssec')
# Algorithm name handed to ``ldns-keygen -a`` in _gen_key().
_ALGO = 'ECDSAP256SHA256'
# Key sizes handed to ``ldns-keygen -b`` in _gen_key(); the KSK uses the same
# size as the ZSK.
_ZSK_LEN = 512
_KSK_LEN = _ZSK_LEN


def nsd_serial() -> str:
    """Return the current time formatted as ``MMDDHHMMSS``.

    The value is a 10-digit string made of the month, day, hour, minute and
    second returned by ``datetime.now()``.
    """
    now = _datetime.now()
    return f'{now:%m%d%H%M%S}'


def value_in(value: str,
             values: _List[str],
             ) -> bool:
    """Tell whether *value* compares equal to at least one item of *values*.

    Returns ``False`` when *values* is empty.
    """
    return any(value == candidate for candidate in values)


def nsd_concat_lists(*args,
                     ip: str=None,
                     cidr: bool=False
                     ) -> _List[str]:
    """Merge several lists of addresses into one sorted list without duplicates.

    Args:
        *args: any number of iterables of strings to merge.
        ip: optional extra address added to the result when truthy.
        cidr: when true, every entry that does not already contain a ``/``
            gets a ``/32`` suffix. The same rule is applied to *ip*, so an
            *ip* that already carries a prefix length is left untouched
            instead of receiving a second suffix.

    Returns:
        The distinct entries, sorted as plain strings.
    """
    def _with_prefix(address):
        # Only append the host prefix when none is present yet.
        return address if '/' in address else address + '/32'

    merged = set()
    for lst in args:
        if cidr:
            merged.update(_with_prefix(item) for item in lst)
        else:
            merged.update(lst)
    if ip:
        merged.add(_with_prefix(ip) if cidr else ip)
    return sorted(merged)


def get_reverse_name(network: str) -> str:
    """Build the ``in-addr.arpa.`` reverse zone name of a dotted-quad network.

    The first three dot-separated parts of *network* are reversed and
    suffixed with ``in-addr.arpa.``.

    Returns:
        The reverse name, or ``None`` when *network* is empty or ``None``.

    Raises:
        ValueError: if the prefix length is shorter than 24, if
            ``ip_network`` rejects the value, or if *network* does not split
            into exactly four dot-separated parts.
    """
    if network:
        if ip_network(network).prefixlen < 24:
            raise ValueError('only netmask greater than 24 is supported for DNS reverse name')
        # The fourth part (which may carry the "/prefix") is not used.
        first, second, third, _last = network.split('.')
        return '.'.join((third, second, first, 'in-addr.arpa.'))
    return None


def _gen_key(cn:str,
             authority_cn: str,
             type: str,
             ) -> str:
    """Return the base path of a key for a zone, generating it when absent.

    Keys live in ``_PKI_DIR/<cn>/<authority_cn>/<type>``. If that directory
    already holds a file matching ``K<authority_cn>.+*.key``, the path of the
    first match without its ``.key`` extension is returned. Otherwise the
    directory is recreated empty and ``ldns-keygen`` is run inside it.

    Args:
        cn: first path component under ``_PKI_DIR``.
        authority_cn: zone name passed to ``ldns-keygen``.
        type: ``'zsk'`` for a key generated without ``-k``; any other value
            adds the ``-k`` flag.

    Returns:
        The key directory joined with the name printed by ``ldns-keygen`` on
        stdout (or the existing key's path without extension).

    Raises:
        Exception: if ``ldns-keygen`` exits with a non-zero status.
    """
    key_dir = _join(_PKI_DIR, cn, authority_cn, type)
    if _isdir(key_dir):
        existing = _glob(_join(key_dir, f'K{authority_cn}.+*.key'))
        if existing:
            # Strip the ".key" extension to get the key's base path.
            return existing[0].rsplit('.', 1)[0]
        # Directory exists but holds no key: start again from an empty one.
        _rmtree(key_dir)
    _makedirs(key_dir)
    size = _ZSK_LEN if type == 'zsk' else _KSK_LEN
    cmd = ['ldns-keygen', '-a', _ALGO, '-b', str(size)]
    if type != 'zsk':
        cmd.append('-k')
    cmd.append(authority_cn)
    result = _run(cmd, cwd=key_dir, capture_output=True)
    if result.returncode != 0:
        raise Exception(f'cannot generate {type}: {result.stdout.decode()}, {result.stderr.decode()}')
    return _join(key_dir, result.stdout.decode().strip())


def _gen_keys(cn: str,
              authority_cn: str,
              ) -> tuple:
    """Return the ``(zsk, ksk)`` key base paths for a zone.

    Both keys are obtained through ``_gen_key``, which generates any key
    that does not exist yet. The ``'zsk'`` key is requested first.

    Args:
        cn: first path component under ``_PKI_DIR``, forwarded to ``_gen_key``.
        authority_cn: zone name, forwarded to ``_gen_key``.

    Returns:
        A 2-tuple ``(zsk_path, ksk_path)`` of key paths without extension.
    """
    zsk = _gen_key(cn, authority_cn, 'zsk')
    ksk = _gen_key(cn, authority_cn, 'ksk')
    return zsk, ksk


def gen_cert(cn: str,
             authority_cn: str,
             ) -> str:
    """Format the ``ksk`` public key of a zone as a quoted one-line entry.

    The keys are created on demand through ``_gen_keys``. The ``<ksk>.key``
    file is then split on whitespace: fields 4 to 6 are copied verbatim and
    field 7 is wrapped in double quotes.

    Returns:
        A string of the form ``"<authority_cn>." <f4> <f5> <f6> "<f7>";``.
    """
    # Only the ksk path is needed; the call still makes sure both keys exist.
    _zsk, ksk = _gen_keys(cn, authority_cn)
    with open(f'{ksk}.key') as fh:
        fields = fh.read().strip().split()
    middle = ' '.join(fields[3:6])
    return f'"{authority_cn}." {middle} "{fields[6]}";'


def sign(zone_filename: str,
         cn: str,
         ) -> str:
    """Return the signed content of a zone file, re-signing only when needed.

    The zone name is the file name of *zone_filename* without its last
    extension. A copy of the last signed source and the matching signed
    output are kept under ``_PKI_DIR/<cn>/<zone name>/``. The zone is signed
    again when either file is missing or when ``filecmp.cmp`` reports that
    the source differs from the stored copy.

    Args:
        zone_filename: path of the unsigned zone file.
        cn: first path component under ``_PKI_DIR``.

    Returns:
        The stripped content of the stored signed zone.

    Raises:
        Exception: if ``ldns-signzone`` exits with a non-zero status.
    """
    authority_cn = zone_filename.rsplit('/', 1)[-1].rsplit('.', 1)[0]
    copy_file = _join(_PKI_DIR, cn, authority_cn, _basename(zone_filename))
    signed_filename = f'{copy_file}.signed'
    if not _isfile(copy_file) or not _isfile(signed_filename) or not _cmp(zone_filename, copy_file):
        zsk, ksk = _gen_keys(cn, authority_cn)
        cmd = ['ldns-signzone', '-n', zone_filename, zsk, ksk]
        proc = _run(cmd, capture_output=True)
        if proc.returncode != 0:
            raise Exception(f'cannot sign {zone_filename}: {proc.stdout.decode()}, {proc.stderr.decode()}')
        # The signed output is expected next to the source zone file.
        new_signed_filename = f'{zone_filename}.signed'
        with open(new_signed_filename) as fh:
            content = fh.read().strip()
        # str.replace() returns a new string, so the result has to be kept;
        # previously it was discarded and the placeholder never substituted.
        # NOTE(review): the substitution happens after ldns-signzone has run,
        # so the stored text differs from the signer's output wherever the
        # placeholder occurs - confirm this is the intended behaviour.
        content = content.replace('0000000000', nsd_serial())
        with open(signed_filename, 'w') as fh:
            fh.write(content)
        # Record the source as "signed" only once the signed output exists,
        # so a failed run is retried on the next call instead of returning a
        # stale signed zone.
        _copy2(zone_filename, copy_file)
    with open(signed_filename) as fh:
        content = fh.read().strip()
    return content


def get_internal_info_in_zone(zones: dict,
                              domain_name: str,
                              type: str,
                              index: int=None,
                              ) -> _List[str]:
    """Look up host information in the zone whose domain name matches.

    Args:
        zones: mapping whose values are zone dictionaries with the keys
            ``'domain_name'``, ``'hosts'`` and ``'host_ip'``.
        domain_name: domain name of the zone to search for.
        type: ``'host'`` to get host names; any other value to get an
            address-like value (see Returns).
        index: 1-based position in the zone's ``'hosts'`` values. ``None``
            and ``0`` both select the zone's ``'host_ip'`` instead.

    Returns:
        - ``[]`` when no zone has the requested *domain_name*;
        - ``['host', <host keys...>]`` when *type* is ``'host'``;
        - ``zone['host_ip']`` when *index* is falsy;
        - otherwise the ``index``-th value of ``zone['hosts']``.
    """
    for zone in zones.values():
        if domain_name == zone['domain_name']:
            break
    else:
        # No zone matched.
        return []
    if type == 'host':
        return ['host'] + list(zone['hosts'])
    if not index:
        return zone['host_ip']
    return list(zone['hosts'].values())[index - 1]
    
                                                                                  
def get_internal_zones(zones) -> _List[str]:
    """Return the ``'domain_name'`` of every zone, in the mapping's order."""
    domain_names = []
    for zone in zones.values():
        domain_names.append(zone['domain_name'])
    return domain_names


@_multi_function
def calc_reverse_names(names):
    """Return *names* without duplicates, keeping first-seen order."""
    unique = []
    for candidate in names:
        if candidate not in unique:
            unique.append(candidate)
    return unique


@_multi_function
def calc_reverse_networks(names):
    """Strip the last dot-separated part of each name and drop duplicates.

    Each name is cut at its last ``.`` (a name without a dot is kept whole);
    the resulting values are returned once each, in first-seen order.
    """
    networks = []
    for name in names:
        stem = name.rsplit('.', 1)[0]
        if stem not in networks:
            networks.append(stem)
    return networks


def valid_dns_hostname(hostname,
                       domainname,
                       zone_names,
                      ):
    """Validate a host name that is about to be declared inside a zone.

    The wildcard ``'*'`` is accepted without any check. Any other value is
    checked by instantiating a tiramisu ``DomainnameOption`` of type
    ``'hostname'`` with IP addresses disallowed, and
    ``<hostname>.<domainname>`` must not be one of *zone_names*.

    Raises:
        ValueError: re-raised from ``DomainnameOption`` (with its ``prefix``
            attribute cleared), or raised here when the full name is also a
            zone name.
    """
    if hostname == '*':
        return
    try:
        DomainnameOption('a', '', hostname, type='hostname', allow_ip=False)
    except ValueError as err:
        # presumably ``prefix`` is prepended to the message by tiramisu;
        # verify against the library.
        err.prefix = ''
        raise err from err
    fqdn = '.'.join((hostname, domainname))
    if fqdn in zone_names:
        raise ValueError(f'"{hostname}.{domainname}" is also a zone name')


@_multi_function
def get_dnssec_ds(cn: str,
                  names: _List[str],
                  ) -> _List[str]:
    """Collect the content of the ``.ds`` files found for the given zone names.

    For each name, the directory ``_PKI_DIR/<cn>/<name>/ksk`` is searched for
    a file matching ``K<name>.+*.ds``; the stripped content of the first match
    is added to the result. Names without such a directory or file are
    silently skipped.

    Args:
        cn: first path component under ``_PKI_DIR``.
        names: zone names to look up. A name ending in ``arpa.`` has its
            trailing dot removed before the lookup.

    Returns:
        One string per ``.ds`` file found, in the order of *names*.
    """
    records = []
    for name in names:
        if name.endswith('arpa.'):
            name = name[:-1]
        key_dir = _join(_PKI_DIR, cn, name, 'ksk')
        if not _isdir(key_dir):
            continue
        ds_files = _glob(_join(key_dir, f'K{name}.+*.ds'))
        if not ds_files:
            continue
        with open(ds_files[0]) as fh:
            records.append(fh.read().strip())
    return records