forked from stove/dataset
210 lines
6.3 KiB
Python
210 lines
6.3 KiB
Python
from typing import List as _List
|
|
from os.path import join as _join, isfile as _isfile, isdir as _isdir, abspath as _abspath, basename as _basename
|
|
from datetime import datetime as _datetime
|
|
from ipaddress import ip_network, ip_address
|
|
from subprocess import run as _run
|
|
from os import makedirs as _makedirs
|
|
from shutil import rmtree as _rmtree, copy2 as _copy2
|
|
from glob import glob as _glob
|
|
from filecmp import cmp as _cmp
|
|
|
|
from tiramisu import DomainnameOption
|
|
from risotto.utils import multi_function as _multi_function
|
|
|
|
|
|
# Root directory for generated DNSSEC material (keys, zone copies, signed zones).
# Resolved against the current working directory at import time.
_PKI_DIR = _abspath('pki/dnssec')

# Signing algorithm handed to ldns-keygen via "-a".
_ALGO = 'ECDSAP256SHA256'

# Key sizes handed to ldns-keygen via "-b" for the ZSK and KSK respectively.
# NOTE(review): with ECDSA P-256 the key size is fixed by the curve, so ldns-keygen
# presumably ignores this value — confirm before relying on it.
_ZSK_LEN = _KSK_LEN = 512
|
|
|
|
|
|
def nsd_serial() -> str:
    """Return a 10-digit zone serial derived from the current local time.

    The layout is ``MMDDhhmmss`` (month, day, hour, minute, second), e.g.
    ``'0314152609'``. The largest possible value (``1231235959``) fits in an
    unsigned 32-bit SOA serial. The year is not part of the value, so the
    serial is NOT monotonic across a year boundary: it drops from ``1231...``
    back to ``0101...``.
    """
    return f'{_datetime.now():%m%d%H%M%S}'
|
|
|
|
|
|
def value_in(value: str,
             values: _List[str],
             ) -> bool:
    """Tell whether *value* compares equal (``==``) to any item of *values*.

    Stops at the first match and returns ``False`` for an empty *values*.
    Only ``==`` is used; there is no identity shortcut as with ``in``.
    """
    return any(value == candidate for candidate in values)
|
|
|
|
|
|
def nsd_concat_lists(*args,
                     ip: str=None,
                     cidr: bool=False
                     ) -> _List[str]:
    """Merge address lists into one sorted, de-duplicated list.

    :param args: iterables of IP addresses and/or networks; ``None`` or empty
                 arguments are skipped
    :param ip: optional single extra address (or network) to include
    :param cidr: when true, bare addresses are turned into host networks:
                 ``/32`` for IPv4, ``/128`` for IPv6. Entries that already
                 carry a ``/prefix`` are kept unchanged, whether they come
                 from *args* or from *ip*.
    :returns: the merged entries, sorted as strings (lexicographic, not numeric)
    """
    def _as_entry(addr: str) -> str:
        # Only bare addresses need a host prefix; never append to an existing one.
        if not cidr or '/' in addr:
            return addr
        # A host route is /128 in IPv6; "/32" there would denote a huge network.
        return f'{addr}/128' if ':' in addr else f'{addr}/32'

    ret = set()
    for lst in args:
        if not lst:
            continue
        ret.update(_as_entry(addr) for addr in lst)
    if ip:
        ret.add(_as_entry(ip))
    return sorted(ret)
|
|
|
|
|
|
def get_reverse_name(network: str) -> str:
    """Return the /24 reverse-lookup zone name covering an IPv4 network.

    :param network: IPv4 network in any form accepted by ``ipaddress.ip_network``,
                    e.g. ``'192.168.1.0/24'``, ``'192.168.1.0/255.255.255.0'`` or a
                    bare address (treated as /32)
    :returns: e.g. ``'1.168.192.in-addr.arpa.'`` (with trailing dot), or ``None``
              when *network* is empty/``None``
    :raises ValueError: if *network* is not a valid network (host bits set,
                        malformed), is not IPv4, or its prefix is shorter than /24
    """
    if not network:
        return None
    network_obj = ip_network(network)
    if network_obj.version != 4:
        raise ValueError('only IPv4 networks are supported for DNS reverse name')
    if network_obj.prefixlen < 24:
        raise ValueError('only netmask greater than or equal to 24 is supported for DNS reverse name')
    # Read the octets from the parsed address, not the raw string, so that
    # netmask notation ("a.b.c.d/255.255.255.0") works as well as "/24".
    o1, o2, o3, _ = str(network_obj.network_address).split('.')
    return f'{o3}.{o2}.{o1}.in-addr.arpa.'
|
|
|
|
|
|
def _gen_key(cn: str,
             authority_cn: str,
             type: str,
             ) -> str:
    """Return the basename of a zone key, generating it with ldns-keygen if absent.

    Keys live in ``_PKI_DIR/<cn>/<authority_cn>/<type>/``. If that directory already
    holds a ``K<authority_cn>.+*.key`` file, the first glob match is reused. Otherwise
    the directory is recreated empty and a new key is generated.

    :param cn: first directory level under ``_PKI_DIR``
    :param authority_cn: zone name the key is generated for
    :param type: ``'zsk'`` for a zone-signing key; any other value generates a
                 key-signing key (ldns-keygen ``-k``)
    :returns: path of the key without its ``.key`` extension
    :raises Exception: if ldns-keygen exits with a non-zero status
    """
    key_dir = _join(_PKI_DIR, cn, authority_cn, type)
    if _isdir(key_dir):
        existing = _glob(_join(key_dir, f'K{authority_cn}.+*.key'))
        if existing:
            # Strip the ".key" suffix: callers work with the common key basename.
            return existing[0].rsplit('.', 1)[0]
        # Directory exists but has no usable key: start again from an empty one.
        _rmtree(key_dir)
    _makedirs(key_dir)

    is_zsk = type == 'zsk'
    key_len = _ZSK_LEN if is_zsk else _KSK_LEN
    cmd = ['ldns-keygen', '-a', _ALGO, '-b', str(key_len)]
    if not is_zsk:
        cmd.append('-k')
    cmd.append(authority_cn)

    proc = _run(cmd, cwd=key_dir, capture_output=True)
    if proc.returncode != 0:
        raise Exception(f'cannot generate {type}: {proc.stdout.decode()}, {proc.stderr.decode()}')
    # ldns-keygen prints the basename of the key it created (relative to cwd).
    return _join(key_dir, proc.stdout.decode().strip())
|
|
|
|
|
|
def _gen_keys(cn: str,
              authority_cn: str,
              ) -> tuple:
    """Return the ``(zsk, ksk)`` key basenames for a zone, generating missing keys.

    Both entries are paths without the ``.key`` extension, as returned by
    ``_gen_key`` for the ``'zsk'`` and ``'ksk'`` kinds respectively.
    """
    zsk = _gen_key(cn, authority_cn, 'zsk')
    ksk = _gen_key(cn, authority_cn, 'ksk')
    return zsk, ksk
|
|
|
|
|
|
def gen_cert(cn: str,
             authority_cn: str,
             ) -> str:
    """Return the zone's KSK public key formatted as ``"<zone>." <f> <p> <a> "<key>";``.

    The keys are ensured to exist first (both ZSK and KSK may be generated). Then
    the KSK ``.key`` file is read and split on whitespace.

    NOTE(review): this assumes fields [3:6] are the DNSKEY flags/protocol/algorithm
    and field [6] is the whole base64 key, i.e. no TTL column and no spaces inside
    the key — confirm against the ldns-keygen output format.
    """
    _, ksk = _gen_keys(cn, authority_cn)
    with open(f'{ksk}.key') as key_file:
        fields = key_file.read().strip().split()
    flags_proto_algo = ' '.join(fields[3:6])
    public_key = fields[6]
    return f'"{authority_cn}." {flags_proto_algo} "{public_key}";'
|
|
|
|
|
|
def sign(zone_filename: str,
         cn: str,
         ) -> str:
    """Return the DNSSEC-signed content of *zone_filename*, re-signing only when needed.

    The function keeps two files under ``_PKI_DIR/<cn>/<zone>/``: a copy of the last
    successfully signed source zone, and its signed output. If both exist and the
    copy matches *zone_filename* (``filecmp.cmp``), the cached signed output is
    returned. Otherwise it does the following:

    * obtains the ZSK/KSK, generating them if missing;
    * replaces the ``0000000000`` serial placeholder with ``nsd_serial()``;
    * signs the result with ``ldns-signzone -n``;
    * stores the signed output, and only then refreshes the source copy.

    :param zone_filename: path of the unsigned zone file; its basename without
                          the last extension is used as the zone name
    :param cn: first directory level under ``_PKI_DIR``
    :returns: signed zone text, stripped of surrounding whitespace
    :raises Exception: if ldns-signzone exits with a non-zero status
    """
    authority_cn = zone_filename.rsplit('/', 1)[-1].rsplit('.', 1)[0]
    copy_file = _join(_PKI_DIR, cn, authority_cn, _basename(zone_filename))
    signed_filename = f'{copy_file}.signed'
    if not _isfile(copy_file) or not _isfile(signed_filename) or not _cmp(zone_filename, copy_file):
        # _gen_keys also creates _PKI_DIR/<cn>/<zone>/, which holds the files below.
        zsk, ksk = _gen_keys(cn, authority_cn)
        with open(zone_filename) as fh:
            zone = fh.read()
        # Substitute the serial BEFORE signing: the RRSIG over the SOA covers the
        # serial, so editing the signed output afterwards would invalidate it.
        # The original zone file is left untouched.
        unsigned_filename = f'{copy_file}.unsigned'
        with open(unsigned_filename, 'w') as fh:
            fh.write(zone.replace('0000000000', nsd_serial()))
        cmd = ['ldns-signzone', '-n', unsigned_filename, zsk, ksk]
        proc = _run(cmd, capture_output=True)
        if proc.returncode != 0:
            raise Exception(f'cannot sign {zone_filename}: {proc.stdout.decode()}, {proc.stderr.decode()}')
        # ldns-signzone writes its result to "<input>.signed".
        with open(f'{unsigned_filename}.signed') as fh:
            content = fh.read().strip()
        with open(signed_filename, 'w') as fh:
            fh.write(content)
        # Record the source only after a successful sign. Otherwise a failed run
        # would leave a matching copy next to a stale .signed file, and the next
        # call would serve that stale output.
        _copy2(zone_filename, copy_file)
    with open(signed_filename) as fh:
        return fh.read().strip()
|
|
|
|
|
|
def get_internal_info_in_zone(zones: dict,
                              domain_name: str,
                              type: str,
                              index: int=None,
                              ) -> _List[str]:
    """Look up host information in the zone whose ``domain_name`` matches.

    :param zones: mapping whose values are zone dicts with at least the keys
                  ``'domain_name'``, ``'hosts'`` (a mapping) and ``'host_ip'``
    :param domain_name: zone to look up; the first zone with an equal
                        ``'domain_name'`` wins
    :param type: ``'host'`` to get ``['host', <host keys>...]``; any other value
                 selects address information
    :param index: 1-based position in ``zone['hosts'].values()``; when falsy
                  (``None`` or ``0``), ``zone['host_ip']`` is returned instead
    :returns: ``[]`` when no zone matches, otherwise as described above
    """
    zone = next((candidate for candidate in zones.values()
                 if domain_name == candidate['domain_name']), None)
    if zone is None:
        return []
    if type == 'host':
        return ['host'] + list(zone['hosts'])
    if not index:
        return zone['host_ip']
    return list(zone['hosts'].values())[index - 1]
|
|
|
|
|
|
def get_internal_zones(zones) -> _List[str]:
    """Return the ``'domain_name'`` of every zone in *zones*, in mapping order.

    :param zones: mapping whose values are zone dicts containing ``'domain_name'``
    """
    # Only the values are needed; the mapping keys are irrelevant here.
    return [zone['domain_name'] for zone in zones.values()]
|
|
|
|
|
|
@_multi_function
def calc_reverse_names(names):
    """Return *names* with duplicates removed, keeping first-occurrence order.

    Membership is checked with ``in`` on a list, so items only need ``==``
    (they do not have to be hashable).
    """
    unique = []
    for candidate in names:
        if candidate not in unique:
            unique.append(candidate)
    return unique
|
|
|
|
|
|
@_multi_function
def calc_reverse_networks(names):
    """Return the distinct prefixes of *names*, in first-seen order.

    Each prefix is the name minus its last dot-separated component, e.g.
    ``['192.168.1.10', '192.168.1.20', '10.0.0.1']`` gives
    ``['192.168.1', '10.0.0']``. A name without a dot is kept whole.
    """
    prefixes = []
    for prefix in (name.rsplit('.', 1)[0] for name in names):
        if prefix not in prefixes:
            prefixes.append(prefix)
    return prefixes
|
|
|
|
|
|
def valid_dns_hostname(hostname,
                       domainname,
                       zone_names,
                       ):
    """Validate a host name to be published inside *domainname*.

    :param hostname: host label to check; the wildcard ``'*'`` skips the syntax check
    :param domainname: zone the host belongs to
    :param zone_names: container of zone names; ``<hostname>.<domainname>``
                       must not be one of them
    :raises ValueError: if tiramisu's ``DomainnameOption`` rejects *hostname*,
                        or if ``<hostname>.<domainname>`` is itself a zone name
    """
    if hostname != '*':
        try:
            # Building the option presumably validates its default value; only
            # that validation (the ValueError) matters here.
            DomainnameOption('a', '', hostname, type='hostname', allow_ip=False)
        except ValueError as err:
            # Clear tiramisu's message prefix (presumably so only the reason is shown).
            err.prefix = ''
            # Plain re-raise; "raise err from err" made the error its own cause.
            raise
    if hostname + '.' + domainname in zone_names:
        raise ValueError(f'"{hostname}.{domainname}" is also a zone name')
|
|
|
|
|
|
@_multi_function
def get_dnssec_ds(cn: str,
                  names: _List[str],
                  ) -> _List[str]:
    """Collect the DS record text generated next to each zone's KSK.

    For every zone name, the function looks in ``_PKI_DIR/<cn>/<name>/ksk/``
    and reads the first file matching ``K<name>.+*.ds``. Zones without such a
    directory or file are skipped silently.

    :param cn: first directory level under ``_PKI_DIR``
    :param names: zone names; a trailing dot is dropped for names ending in ``arpa.``
    :returns: one stripped DS text per zone that has one, in *names* order
    """
    dnssec = []
    for name in names:
        if name.endswith('arpa.'):
            # Reverse zone names come with a trailing dot; the key directories are
            # presumably named without it — confirm against the callers.
            name = name[:-1]
        ksk_dir = _join(_PKI_DIR, cn, name, 'ksk')
        if not _isdir(ksk_dir):
            continue
        ds_files = _glob(_join(ksk_dir, f'K{name}.+*.ds'))
        if not ds_files:
            continue
        with open(ds_files[0]) as fh:
            dnssec.append(fh.read().strip())
    return dnssec
|