risotto/funcs.py
Emmanuel Garette 946506f27c first commit
2022-03-08 20:47:55 +01:00

233 lines
7.4 KiB
Python

from tiramisu import valid_network_netmask, valid_ip_netmask, valid_broadcast, valid_in_network, valid_not_equal as valid_differ, valid_not_equal, calc_value
from ipaddress import ip_address
from os.path import dirname, abspath, join as _join, isdir as _isdir, isfile as _isfile
from typing import List
from json import load
from secrets import token_urlsafe as _token_urlsafe
from rougail.utils import normalize_family
from utils import multi_function, CONFIGS
from x509 import gen_cert as _x509_gen_cert, gen_ca as _x509_gen_ca, gen_pub as _x509_gen_pub, has_pub as _x509_has_pub
# =============================================================
# fork of risotto-setting/src/risotto_setting/config/config.py
# Zone/server inventory. JSON text is UTF-8, so decode it explicitly instead
# of relying on the platform's locale encoding. NOTE: the path is resolved
# against the current working directory, not against this module's directory.
with open('servers.json', 'r', encoding='utf-8') as server_fh:
    ZONES_SERVER = load(server_fh)
# Lazily built caches, filled on first use by load_zones() / load_domains().
ZONES = None
DOMAINS = None
# Directory of this module; passed to the x509 helpers.
HERE = dirname(abspath(__file__))
def load_zones():
    """Build the module-level ZONES cache from ZONES_SERVER (only once).

    ZONES is the ``zones`` dictionary of ZONES_SERVER itself (not a copy).
    Each zone listed in a server's ``informations.zones_name`` gets that
    server's name appended to its ``hosts`` list. Servers without an
    ``informations`` entry are ignored.

    A server in several zones must give one extra domain name per
    additional zone. The n-th extra name is the server's name in the
    (n+1)-th zone. When a zone has a ``domain_name``, the server's name in
    that zone must be ``<host>.<domain_name>``.

    Every server is validated before anything is written. A failure
    therefore leaves ZONES unset and the zone dicts untouched, instead of
    leaving a half-built cache that later calls would silently reuse.

    :raises Exception: wrong number of extra_domainnames, or a server name
        that does not match its zone's domain_name
    :raises KeyError: a server references an unknown zone
    """
    global ZONES
    if ZONES is not None:
        return
    zones = ZONES_SERVER['zones']
    hosts_by_zone = {}
    for server_name, server in ZONES_SERVER['servers'].items():
        if 'informations' not in server:
            continue
        server_zones = server['informations']['zones_name']
        server_extra_domainnames = server['informations'].get('extra_domainnames', [])
        if len(server_zones) > 1 and len(server_zones) != len(server_extra_domainnames) + 1:
            raise Exception(f'the server "{server_name}" has more than one zone, please set correct number of extra_domainnames ({len(server_zones) - 1} instead of {len(server_extra_domainnames)})')
        for idx, zone_name in enumerate(server_zones):
            # domain_name is optional (load_domains skips zones without it).
            zone_domain_name = zones[zone_name].get('domain_name')
            if idx == 0:
                zone_server_name = server_name
            else:
                zone_server_name = server_extra_domainnames[idx - 1]
            if zone_domain_name:
                # partition() yields '' for a dotless name, which is then
                # reported as a mismatch rather than crashing with IndexError.
                server_domain_name = zone_server_name.partition('.')[2]
                if zone_domain_name != server_domain_name:
                    raise Exception(f'wrong server_name "{zone_server_name}" in zone "{zone_name}" should end with "{zone_domain_name}"')
            hosts_by_zone.setdefault(zone_name, []).append(server_name)
    # Everything validated: now record the hosts and publish the cache.
    for zone_name, hosts in hosts_by_zone.items():
        zones[zone_name].setdefault('hosts', []).extend(hosts)
    ZONES = zones
def load_domains():
    """Build the module-level DOMAINS cache (only once).

    DOMAINS maps every zone's ``domain_name`` to a pair of tuples
    ``(short_host_names, ip_addresses)``, both in the zone's ``hosts``
    order. A short host name is the part of the server name before the
    first dot. Each IP comes from get_ip(). Zones without a
    ``domain_name`` key are skipped.

    The result is built in a local dict and published only when complete,
    so a failing get_ip() does not leave a partial DOMAINS behind.
    """
    global DOMAINS
    load_zones()
    if DOMAINS is not None:
        return
    domains = {}
    for zone_name, zone in ZONES_SERVER['zones'].items():
        if 'domain_name' not in zone:
            continue
        zone_hosts = ZONES[zone_name].get('hosts', [])
        hosts = tuple(host.split('.', 1)[0] for host in zone_hosts)
        ips = tuple(get_ip(host, [zone_name], 0) for host in zone_hosts)
        domains[zone['domain_name']] = (hosts, ips)
    DOMAINS = domains
def get_ip(server_name: str,
           zones_name: List[str],
           index: str,
           ) -> str:
    """Return the IP address of ``server_name`` in one of its zones.

    The address is the zone's ``start_ip`` plus the server's position in
    the zone's ``hosts`` list.

    :param server_name: server name as registered in the zone's hosts;
        None short-circuits to None
    :param zones_name: zone names of the server
    :param index: position in ``zones_name`` of the zone to use (an int or
        anything ``int()`` accepts)
    :returns: the IP address as a string, or None if server_name is None
    :raises ValueError: unknown zone, or server not attached to that zone
    """
    if server_name is None:
        return None
    load_zones()
    zone_name = zones_name[int(index)]
    if zone_name not in ZONES:
        raise ValueError(f"cannot set IP in unknown zone '{zone_name}'")
    zone = ZONES[zone_name]
    # A zone that no server is attached to may have no 'hosts' key at all;
    # report it as an unknown server instead of crashing with KeyError.
    hosts = zone.get('hosts', [])
    if server_name not in hosts:
        raise ValueError(f"cannot set IP in unknown server '{server_name}'")
    return str(ip_address(zone['start_ip']) + hosts.index(server_name))
@multi_function
def get_chain(authority_cn,
              authority_name,
              ):
    """Return the CA data for ``authority_name`` and ``authority_cn``.

    ``authority_cn`` may be a single common name or a list of them. The
    result mirrors that shape: one value per common name, each being
    whatever ``x509.gen_ca(cn, authority_name, HERE)`` returns.

    With an empty/None ``authority_name`` nothing is looked up. The result
    is [] if either argument is a list (a multi value is expected),
    otherwise None.
    """
    if not authority_name:
        if isinstance(authority_name, list) or isinstance(authority_cn, list):
            return []
        return None
    if isinstance(authority_cn, list):
        return [_x509_gen_ca(auth_cn, authority_name, HERE)
                for auth_cn in authority_cn]
    return _x509_gen_ca(authority_cn, authority_name, HERE)
@multi_function
def get_certificate(cn,
                    authority_name,
                    authority_cn=None,
                    extra_domainnames=None,
                    type='server',
                    ):
    """Return the certificate ('crt') for common name ``cn``.

    :param cn: common name, or a list of common names
    :param authority_name: name of the signing authority; None yields no
        certificate
    :param authority_cn: common name of the signing authority
    :param extra_domainnames: additional domain names; only allowed when
        ``cn`` is a single name (None means no extra names)
    :param type: certificate type forwarded to x509.gen_cert
    :returns: the value of x509.gen_cert, or [] / None (list if ``cn`` is a
        list) when ``cn`` is empty or ``authority_name`` is None
    :raises Exception: ``cn`` is a list and ``extra_domainnames`` is set
    """
    # None instead of a shared mutable [] default: the list is handed to
    # x509.gen_cert, and a shared default could leak state between calls.
    if extra_domainnames is None:
        extra_domainnames = []
    if isinstance(cn, list) and extra_domainnames:
        raise Exception('cn cannot be a list with extra_domainnames set')
    if not cn or authority_name is None:
        if isinstance(cn, list):
            return []
        return None
    return _x509_gen_cert(cn,
                          extra_domainnames,
                          authority_cn,
                          authority_name,
                          type,
                          'crt',
                          HERE,
                          )
@multi_function
def get_private_key(cn,
                    authority_name=None,
                    authority_cn=None,
                    type='server',
                    ):
    """Return the private key ('key') for common name ``cn``.

    With an ``authority_name`` the key is obtained through x509.gen_cert
    (the helper get_certificate uses, asked for 'key' instead of 'crt').
    Without one, x509.gen_pub is used, but only when x509.has_pub is true
    for ``cn`` (presumably meaning a key pair already exists — TODO
    confirm). When no key is returned, the result is [] if ``cn`` is a
    list, otherwise None.
    """
    if not cn:
        return [] if isinstance(cn, list) else None
    if authority_name is not None:
        return _x509_gen_cert(cn, [], authority_cn, authority_name, type, 'key', HERE)
    if _x509_has_pub(cn, HERE):
        return _x509_gen_pub(cn, 'key', HERE)
    return [] if isinstance(cn, list) else None
def get_public_key(cn):
    """Return the public key ('pub') for ``cn`` via x509.gen_pub; None if cn is empty."""
    if cn:
        return _x509_gen_pub(cn, 'pub', HERE)
    return None
def zone_information(zone_name: str,
                     type: str,
                     multi: bool=False,
                     index: int=None,
                     ) -> str:
    """Return the ``type`` entry of zone ``zone_name``.

    Returns None without loading anything when ``zone_name`` is empty, and
    for ``type == 'gateway'`` whenever ``index`` is not 0 (including the
    default None). With ``multi`` the value is wrapped in a one-element list.

    :raises ValueError: unknown zone, or the zone has no ``type`` entry
    """
    if not zone_name or (type == 'gateway' and index != 0):
        return None
    load_zones()
    if zone_name not in ZONES:
        raise ValueError(f"cannot get zone informations in unknown zone '{zone_name}'")
    zone_data = ZONES[zone_name]
    if type not in zone_data:
        raise ValueError(f"unknown type '{type}' in zone '{zone_name}'")
    return [zone_data[type]] if multi else zone_data[type]
def get_internal_zones() -> List[str]:
    """Return the domain names of all zones that define a ``domain_name``."""
    load_domains()
    return [domain_name for domain_name in DOMAINS]
@multi_function
def get_zones_info(type: str) -> List[str]:
    """Return the ``type`` entry of every zone, in servers.json order.

    Reads ZONES_SERVER directly and does not call load_zones().

    :raises KeyError: a zone has no ``type`` entry
    """
    return [zone[type] for zone in ZONES_SERVER['zones'].values()]
@multi_function
def get_internal_zone_names() -> List[str]:
    """Return the names of every zone declared in servers.json."""
    load_zones()
    return [*ZONES]
def get_internal_zone_information(zone: str,
                                  info: str,
                                  ) -> str:
    """Return entry ``info`` of zone ``zone``.

    ``'cidr'`` is not stored. It is built as ``<gateway>/<prefix>``, where
    the prefix is the part of the zone's ``network`` after the last '/'.
    """
    load_domains()
    zone_data = ZONES[zone]
    if info == 'cidr':
        gateway = zone_data['gateway']
        prefix_length = zone_data['network'].split('/')[-1]
        return gateway + '/' + prefix_length
    return zone_data[info]
def get_internal_info_in_zone(zone: str,
                              auto: bool,
                              type: str,
                              index: int=None,
                              ) -> List[str]:
    """Return the host names or one IP address of domain ``zone``.

    ``zone`` is matched against the DOMAINS keys, i.e. zone domain names
    (see load_domains()).

    :param auto: when false nothing is computed and None is returned
    :param type: 'host' returns the list of short host names; any other
        value returns the IP address at ``index``
    :param index: position of the host whose IP is wanted
    :returns: a list of host names, an IP string, or None when ``auto`` is
        false or the domain is unknown
    """
    if not auto:
        return None
    # DOMAINS is built lazily. Without this call it is still None unless
    # another function happened to load it first.
    load_domains()
    domain = DOMAINS.get(zone)
    if domain is None:
        return None
    hosts, ips = domain
    if type == 'host':
        return list(hosts)
    return ips[index]
# =============================================================