from typing import List as _List
from os.path import join as _join, isfile as _isfile, isdir as _isdir, abspath as _abspath, basename as _basename
from datetime import datetime as _datetime
from ipaddress import ip_network, ip_address
from subprocess import run as _run
from os import makedirs as _makedirs
from shutil import rmtree as _rmtree, copy2 as _copy2
from glob import glob as _glob
from filecmp import cmp as _cmp
from tiramisu import DomainnameOption
from risotto.utils import multi_function as _multi_function


# Root directory for DNSSEC keys and signed-zone copies; relative paths are
# resolved against the current working directory at import time.
_PKI_DIR = _abspath('pki/dnssec')
# Algorithm and key sizes passed to ``ldns-keygen`` (``-a`` / ``-b``).
_ALGO = 'ECDSAP256SHA256'
_ZSK_LEN = 512
_KSK_LEN = _ZSK_LEN


def nsd_serial() -> str:
    """Return a fresh SOA serial number as a string.

    The serial is the current Unix time in whole seconds: 10 digits, strictly
    increasing between calls made in different seconds, and within the 32-bit
    unsigned SOA serial range until 2106.

    A ``%m%d%H%M%S`` serial has no year component, so every January 1st it
    would jump from ``1231...`` back to ``0101...``; RFC 1982 serial
    arithmetic treats that as a decrease and secondaries would ignore the
    updated zone.  Current Unix time is larger than any value of that old
    format, so existing zones still see the serial move forward.
    """
    return str(int(_datetime.now().timestamp()))


def value_in(value: str,
             values: _List[str],
             ) -> bool:
    """Return True if ``value`` is one of ``values``."""
    return value in values


def nsd_concat_lists(*args,
                     ip: str=None,
                     cidr: bool=False
                     ) -> _List[str]:
    """Merge address lists into one sorted, de-duplicated list.

    :param args: lists of addresses; ``None`` entries are skipped
    :param ip: optional extra address added to the result
    :param cidr: when True, every address (list entries and ``ip``) that has
                 no ``/prefix`` gets ``/32`` appended; addresses that already
                 carry a prefix are kept unchanged
    :returns: sorted list of unique addresses
    """
    def _normalize(address: str) -> str:
        # NOTE(review): '/32' is a host prefix only for IPv4 addresses.
        if cidr and '/' not in address:
            return address + '/32'
        return address

    ret = set()
    for lst in args:
        if lst is None:
            continue
        ret.update(_normalize(address) for address in lst)
    if ip:
        ret.add(_normalize(ip))
    return sorted(ret)


def get_reverse_name(network: str) -> str:
    """Return the IPv4 reverse-lookup zone name for ``network``.

    Only the first three octets are used, e.g. ``'192.168.1.0/24'`` gives
    ``'1.168.192.in-addr.arpa.'``.

    :param network: IPv4 network (or address) accepted by ``ip_network``
    :returns: the reverse zone name, or None when ``network`` is empty
    :raises ValueError: if ``network`` is not a valid network, is not IPv4,
                        or has a prefix shorter than 24 bits
    """
    if not network:
        return
    network_obj = ip_network(network)
    if network_obj.version != 4:
        raise ValueError('only IPv4 networks are supported for DNS reverse name')
    if network_obj.prefixlen < 24:
        raise ValueError('only netmask greater than or equal to 24 is supported for DNS reverse name')
    # Take the octets from the parsed network address rather than the raw
    # input string, so the prefix suffix never leaks into the parts.
    o1, o2, o3, _ = str(network_obj.network_address).split('.')
    return f'{o3}.{o2}.{o1}.in-addr.arpa.'
def _gen_key(cn: str,
             authority_cn: str,
             type: str,
             ) -> str:
    """Return the path prefix of a DNSSEC key for ``authority_cn``, creating it if missing.

    Keys live in ``<_PKI_DIR>/<cn>/<authority_cn>/<type>/``.  If that
    directory already holds a ``K<authority_cn>.+*.key`` file, it is reused.
    Otherwise the directory is recreated and ``ldns-keygen`` is run in it.

    :param cn: first directory level under ``_PKI_DIR`` (presumably the
               server the keys belong to -- confirm with callers)
    :param authority_cn: zone name the key is generated for
    :param type: ``'zsk'`` for a zone-signing key; any other value generates
                 a key with ``-k`` (key-signing key)
    :returns: key file path without its ``.key`` extension
    :raises Exception: if ``ldns-keygen`` exits with a non-zero status
    """
    dir_name = _join(_PKI_DIR, cn, authority_cn, type)
    if _isdir(dir_name):
        # Sorted so the choice is deterministic and matches the .ds file
        # picked by get_dnssec_ds when several keys are present.
        existing = sorted(_glob(_join(dir_name, f'K{authority_cn}.+*.key')))
        if existing:
            return existing[0].rsplit('.', 1)[0]
        # Directory without a usable key: start again from a clean state.
        _rmtree(dir_name)
    _makedirs(dir_name)
    if type == 'zsk':
        cmd = ['ldns-keygen', '-a', _ALGO, '-b', str(_ZSK_LEN), authority_cn]
    else:
        cmd = ['ldns-keygen', '-a', _ALGO, '-b', str(_KSK_LEN), '-k', authority_cn]
    proc = _run(cmd,
                cwd=dir_name,
                capture_output=True,
                )
    if proc.returncode != 0:
        raise Exception(f'cannot generate {type}: {proc.stdout.decode()}, {proc.stderr.decode()}')
    # ldns-keygen prints the base name of the key files it created in cwd.
    return _join(dir_name, proc.stdout.decode().strip())


def _gen_keys(cn,
              authority_cn,
              ) -> tuple:
    """Return ``(zsk, ksk)`` key path prefixes for ``authority_cn``, creating missing keys."""
    zsk = _gen_key(cn, authority_cn, 'zsk')
    ksk = _gen_key(cn, authority_cn, 'ksk')
    return zsk, ksk


def gen_cert(cn: str,
             authority_cn: str,
             ) -> str:
    """Return the KSK public key of ``authority_cn`` as a quoted key statement.

    Both keys are created if missing.  The KSK ``.key`` file is split on
    whitespace, and the result is formatted as
    ``"<authority_cn>." <fields 3-5> "<field 6>";``.

    NOTE(review): this assumes the ``.key`` file holds one DNSKEY record laid
    out as ``owner class DNSKEY flags protocol algorithm key ...`` (no TTL
    column) -- confirm against ldns-keygen output.
    """
    _, ksk = _gen_keys(cn, authority_cn)
    with open(f'{ksk}.key') as fh:
        fields = fh.read().split()
    infos = ' '.join(fields[3:6])
    return f'"{authority_cn}." {infos} "{fields[6]}";'


def sign(zone_filename: str,
         cn: str,
         ) -> str:
    """Return the DNSSEC-signed content of ``zone_filename``.

    ``authority_cn`` is the zone file's base name without its last extension.
    A copy of the last successfully signed zone and its signed version are
    kept in ``<_PKI_DIR>/<cn>/<authority_cn>/``.  The zone is re-signed with
    ``ldns-signzone -n`` only when either file is missing or the zone differs
    from that copy; otherwise the cached signed zone is returned.
    ldns-signzone also leaves ``<zone_filename>.signed`` next to the input.

    NOTE(review): signatures are only regenerated when the zone changes;
    confirm their validity period covers the expected time between changes.

    :raises Exception: if ``ldns-signzone`` exits with a non-zero status
    """
    authority_cn = _basename(zone_filename).rsplit('.', 1)[0]
    copy_file = _join(_PKI_DIR, cn, authority_cn, _basename(zone_filename))
    signed_filename = f'{copy_file}.signed'
    up_to_date = (_isfile(copy_file)
                  and _isfile(signed_filename)
                  and _cmp(zone_filename, copy_file))
    if not up_to_date:
        zsk, ksk = _gen_keys(cn, authority_cn)
        cmd = ['ldns-signzone', '-n', zone_filename, zsk, ksk]
        proc = _run(cmd, capture_output=True)
        if proc.returncode != 0:
            raise Exception(f'cannot sign {zone_filename}: {proc.stdout.decode()}, {proc.stderr.decode()}')
        with open(f'{zone_filename}.signed') as fh:
            content = fh.read().strip()
        # NOTE(review): the SOA serial is kept exactly as in the zone file.
        # Do not rewrite it here (e.g. replacing a '0000000000' placeholder
        # with nsd_serial()): changing the SOA after signing invalidates its
        # RRSIG.  If a fresh serial is needed, set it before ldns-signzone.
        with open(signed_filename, 'w') as fh:
            fh.write(content)
        # Record the signed zone only after signing succeeded, so a failed
        # run is retried next time instead of serving stale signatures.
        _copy2(zone_filename, copy_file)
    with open(signed_filename) as fh:
        return fh.read().strip()


def get_internal_info_in_zone(zones: dict,
                              domain_name: str,
                              type: str,
                              index: int=None,
                              ) -> _List[str]:
    """Look up information from the zone in ``zones`` whose ``domain_name`` matches.

    :param zones: mapping whose values are zone dicts with ``domain_name``,
                  ``hosts`` (a mapping) and ``host_ip`` keys
    :param type: ``'host'`` to list host names; any other value selects the
                 IP branch
    :param index: 1-based position in ``zone['hosts'].values()``; when falsy,
                  ``zone['host_ip']`` is returned instead
    :returns: ``[]`` if no zone matches; for ``'host'``, ``['host']`` followed
              by the keys of ``zone['hosts']``; otherwise the selected value
    """
    for zone in zones.values():
        if domain_name == zone['domain_name']:
            break
    else:
        return []
    if type == 'host':
        return ['host'] + list(zone['hosts'])
    if not index:
        return zone['host_ip']
    return list(zone['hosts'].values())[index - 1]


def get_internal_zones(zones) -> _List[str]:
    """Return the ``domain_name`` of every zone in the ``zones`` mapping."""
    return [zone['domain_name'] for zone in zones.values()]


@_multi_function
def calc_reverse_names(names):
    """Return ``names`` without duplicates, preserving first-seen order."""
    return list(dict.fromkeys(names))


@_multi_function
def calc_reverse_networks(names):
    """Drop the last dot-separated label of each name, then de-duplicate in first-seen order."""
    return list(dict.fromkeys(name.rsplit('.', 1)[0] for name in names))


def valid_dns_hostname(hostname,
                       domainname,
                       zone_names,
                       ):
    """Validate ``hostname`` as a host inside ``domainname``.

    ``'*'`` skips the hostname syntax check.

    :returns: the ValueError raised by ``DomainnameOption`` (with its
              ``prefix`` cleared) if ``hostname`` is not a valid hostname; an
              error message if ``<hostname>.<domainname>`` is one of
              ``zone_names``; None otherwise
    """
    if hostname != '*':
        try:
            DomainnameOption('a', '', hostname, type='hostname', allow_ip=False)
        except ValueError as err:
            err.prefix = ''
            return err
    if hostname + '.' + domainname in zone_names:
        return f'"{hostname}.{domainname}" is also a zone name'


@_multi_function
def get_dnssec_ds(cn: str,
                  names: _List[str],
                  ) -> _List[str]:
    """Return the DS record of the KSK of each zone in ``names`` that has one.

    For every name, the first ``K<name>.+*.ds`` file in
    ``<_PKI_DIR>/<cn>/<name>/ksk/`` is read.  Zones without such a file are
    skipped.  Names ending in ``arpa.`` lose their trailing dot first
    (presumably so they match the key directory names used by sign).

    :returns: list of DS record strings
    """
    dnssec = []
    for name in names:
        if name.endswith('arpa.'):
            name = name[:-1]
        # Sorted to pick the same key as _gen_key when several exist; a
        # missing directory simply yields no match.
        ds_files = sorted(_glob(_join(_PKI_DIR, cn, name, 'ksk', f'K{name}.+*.ds')))
        if ds_files:
            with open(ds_files[0]) as fh:
                dnssec.append(fh.read().strip())
    return dnssec