from datetime import datetime as _datetime
from os import makedirs as _makedirs, symlink as _symlink, unlink as _unlink, listdir as _listdir, environ as _environ, \
        getcwd as _getcwd
from os.path import join as _join, isdir as _isdir, isfile as _isfile, exists as _exists

from OpenSSL.crypto import load_certificate as _load_certificate, load_privatekey as _load_privatekey, \
        dump_certificate as _dump_certificate, dump_privatekey as _dump_privatekey, \
        dump_publickey as _dump_publickey, PKey as _PKey, X509 as _X509, X509Extension as _X509Extension, \
        TYPE_RSA as _TYPE_RSA, FILETYPE_PEM as _FILETYPE_PEM

from risotto.utils import RISOTTO_CONFIG as _RISOTTO_CONFIG, multi_function as _multi_function


# Relative location of the PKI tree (CAs, certificates, standalone keys)
# under the base directory.
_PKI_DIR = 'pki/x509'
# Base directory the PKI tree is created under. $PWD (the shell's working
# directory, which typically keeps symlinked path components) is preferred as
# before; when it is not exported or empty (e.g. when started by a service
# manager or a non-shell subprocess), fall back to the process working
# directory instead of failing with KeyError at import time.
_HERE = _environ.get('PWD') or _getcwd()


@_multi_function
def get_chain(cn: str,
              authority_cn: str,
              authority_name: str,
              hide: bool,
              ):
    """Return the current CA certificate (PEM text) of an authority.

    The authority is identified by ``authority_name`` and ``authority_cn``; its
    CA is created or rotated by ``_gen_ca`` under ``_HERE``.

    :param cn: certificate common name, forwarded to ``_gen_ca``
    :param authority_cn: DNS/common name of the authority
    :param authority_name: name of the authority
    :param hide: when true, return the placeholder ``"XXXXX"`` instead
    :returns: the PEM text, or ``None`` if either authority value is empty
    """
    if hide:
        return "XXXXX"
    if authority_cn and authority_name:
        return _gen_ca(cn, authority_cn, authority_name, _HERE)
    return None


@_multi_function
def get_certificate(cn,
                    authority_name: str,
                    hide: bool,
                    authority_cn: str=None,
                    extra_domainnames: list=None,
                    type: str='server',
                    ):
    """Return the PEM certificate chain for ``cn``, generating it when needed.

    :param cn: common name of the certificate
    :param authority_name: name of the issuing authority; nothing is generated
        when it is ``None``
    :param hide: when true, return the placeholder ``"XXXXX"`` instead
    :param authority_cn: common name of the authority (``_gen_cert`` falls back
        to ``cn`` for ``'server'`` certificates when it is ``None``)
    :param extra_domainnames: additional DNS names; defaults to no extra name
    :param type: ``'server'`` or ``'client'`` certificate
    :returns: the PEM text; ``[]`` (when ``cn`` is a list) or ``None`` when
        ``cn`` is empty or ``authority_name`` is ``None``
    :raises Exception: when ``cn`` is a list and ``extra_domainnames`` is set
    """
    if hide:
        return "XXXXX"
    # ``None`` default instead of a mutable ``[]`` shared by every call.
    if extra_domainnames is None:
        extra_domainnames = []
    if isinstance(cn, list) and extra_domainnames:
        raise Exception('cn cannot be a list with extra_domainnames set')
    if not cn or authority_name is None:
        if isinstance(cn, list):
            return []
        return
    return _gen_cert(cn,
                     extra_domainnames,
                     authority_cn,
                     authority_name,
                     type,
                     'crt',
                     _HERE,
                     )


@_multi_function
def get_private_key(cn: str,
                    hide: bool,
                    authority_name: str=None,
                    authority_cn: str=None,
                    type: str='server',
                    ):
    """Return the PEM private key associated with ``cn``.

    With an ``authority_name``, this is the key of the certificate created by
    ``get_certificate`` (``_gen_cert`` raises if it does not exist yet). Without
    one, it is the private half of a standalone pair created by
    ``get_public_key``, if such a pair exists.

    :param cn: common name
    :param hide: when true, return the placeholder ``"XXXXX"`` instead
    :param authority_name: name of the issuing authority, or ``None``
    :param authority_cn: common name of the authority
    :param type: ``'server'`` or ``'client'`` certificate
    :returns: the PEM text; ``[]`` (when ``cn`` is a list) or ``None`` when
        there is nothing to return
    """
    if hide:
        return "XXXXX"
    nothing = [] if isinstance(cn, list) else None
    if not cn:
        return nothing
    if authority_name is not None:
        return _gen_cert(cn, [], authority_cn, authority_name, type, 'key', _HERE)
    if _has_pub(cn, _HERE):
        return _gen_pub(cn, 'key', _HERE)
    return nothing


def get_public_key(cn: str,
                   hide: bool,
                   ):
    """Return the PEM public key of the standalone key pair of ``cn``.

    The pair is generated on first use by ``_gen_pub`` under ``_HERE``.

    :param cn: name identifying the key pair
    :param hide: when true, return the placeholder ``"XXXXX"`` instead
    :returns: the PEM text, or ``None`` when ``cn`` is empty
    """
    if hide:
        return "XXXXX"
    if cn:
        return _gen_pub(cn, 'pub', _HERE)
    return None


def _gen_key_pair():
    """Generate and return a new 4096-bit RSA key pair as a pyOpenSSL ``PKey``."""
    rsa_bits = 4096
    new_key = _PKey()
    new_key.generate_key(_TYPE_RSA, rsa_bits)
    return new_key


def _read_pem_bytes(filename):
    """Return the content of the text (PEM) file ``filename`` encoded as UTF-8 bytes."""
    with open(filename) as fh:
        return bytes(fh.read(), 'utf-8')


def _next_serial_number(root_dir_name):
    """Return the next serial number of the authority in ``root_dir_name`` and persist it.

    The last issued number is stored in ``<root_dir_name>/serial_number``; it is
    incremented and written back before being returned.
    """
    sn_filename = _join(root_dir_name, 'serial_number')
    if _isfile(sn_filename):
        with open(sn_filename, 'r') as fh:
            serial_number = int(fh.read().strip()) + 1
    else:
        # RFC 5280 section 4.1.2.2: the serial number MUST be a positive
        # integer, so the first certificate of an authority gets 1, not 0.
        serial_number = 1
    with open(sn_filename, 'w') as fh:
        fh.write(str(serial_number))
    return serial_number


def __gen_cert(is_ca,
               common_names,
               root_dir_name,
               validity_end_in_seconds,
               key_file,
               cert_file,
               type=None,
               ca_cert=None,
               ca_key=None,
               email_address=None,
               country_name=None,
               locality_name=None,
               state_or_province_name=None,
               organization_name=None,
               organization_unit_name=None,
               ):
    """Create, sign and write an X.509 certificate and its private key as PEM.

    With ``is_ca`` true, a self-signed CA certificate is produced. Otherwise a
    leaf certificate is produced and signed with the CA read from the
    ``ca_cert`` / ``ca_key`` files.

    :param is_ca: build a self-signed CA certificate instead of a leaf one
    :param common_names: ``common_names[0]`` becomes the subject CN; for a leaf,
        every entry is added as a ``DNS:`` subjectAltName
    :param root_dir_name: directory holding the ``serial_number`` counter file
    :param validity_end_in_seconds: validity period, starting now
    :param key_file: private key path; an existing key is reused, otherwise a
        new RSA pair is generated; the key is (re)written there in both cases
    :param cert_file: output certificate path; for a leaf, the CA certificate
        is appended after the leaf certificate
    :param type: leaf only; ``'server'`` gives extendedKeyUsage serverAuth,
        any other value clientAuth
    :param ca_cert: leaf only; path of the CA PEM certificate
    :param ca_key: leaf only; path of the CA PEM private key
    :param email_address: subject emailAddress (also the CA's subjectAltName)
    :param country_name: subject C
    :param locality_name: subject L
    :param state_or_province_name: subject ST
    :param organization_name: subject O
    :param organization_unit_name: subject OU
    """
    # can look at generated file using openssl:
    # openssl x509 -inform pem -in selfsigned.crt -noout -text
    if _isfile(key_file):
        key = _load_privatekey(_FILETYPE_PEM, _read_pem_bytes(key_file))
    else:
        key = _gen_key_pair()
    cert = _X509()
    cert.set_version(2)  # zero-based: 2 means X.509 v3, required for extensions
    # get_subject() wraps the certificate's own subject: setting fields on it
    # modifies the certificate.
    subject = cert.get_subject()
    subject.C = country_name
    subject.ST = state_or_province_name
    subject.L = locality_name
    subject.O = organization_name
    subject.OU = organization_unit_name
    subject.CN = common_names[0]
    subject.emailAddress = email_address
    # The public key must be attached BEFORE the subjectKeyIdentifier extension
    # is built: OpenSSL hashes the subject certificate's key when the extension
    # is created. Setting it afterwards (as this code used to do) gave every
    # certificate the hash of an empty key as identifier.
    cert.set_pubkey(key)
    if is_ca:
        extensions = [
            _X509Extension(b'basicConstraints', False, b'CA:TRUE'),
            _X509Extension(b"keyUsage", True, b'keyCertSign, cRLSign'),
            _X509Extension(b'subjectAltName', False, f'email:{email_address}'.encode()),
        ]
    else:
        alt_names = ", ".join([f'DNS:{common_name}' for common_name in common_names]).encode('ascii')
        extended_usage = b'serverAuth' if type == 'server' else b'clientAuth'
        extensions = [
            _X509Extension(b'basicConstraints', False, b'CA:FALSE'),
            _X509Extension(b'keyUsage', True, b'digitalSignature, keyEncipherment'),
            _X509Extension(b'subjectAltName', False, alt_names),
            _X509Extension(b'extendedKeyUsage', True, extended_usage),
        ]
    extensions.append(_X509Extension(b'subjectKeyIdentifier', False, b"hash", subject=cert))
    cert.add_extensions(extensions)
    cert.set_serial_number(_next_serial_number(root_dir_name))
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(validity_end_in_seconds)
    if is_ca:
        # self-signed: the certificate being built is its own issuer
        ca_cert = cert
        ca_key = key
    else:
        ca_cert = _load_certificate(_FILETYPE_PEM, _read_pem_bytes(ca_cert))
        ca_key = _load_privatekey(_FILETYPE_PEM, _read_pem_bytes(ca_key))
    cert.set_issuer(ca_cert.get_subject())
    # keyid:always copies the issuer's subjectKeyIdentifier and fails if it cannot
    cert.add_extensions([_X509Extension(b"authorityKeyIdentifier", False, b'keyid:always', issuer=ca_cert)])
    cert.sign(ca_key, "sha512")

    with open(cert_file, "wt") as f:
        f.write(_dump_certificate(_FILETYPE_PEM, cert).decode("utf-8"))
        if not is_ca:
            # leaf files carry the chain: leaf certificate first, then the CA
            f.write(_dump_certificate(_FILETYPE_PEM, ca_cert).decode("utf-8"))
    with open(key_file, "wt") as f:
        f.write(_dump_privatekey(_FILETYPE_PEM, key).decode("utf-8"))


def _gen_ca(cn,
            authority_dns,
            authority_name,
            base_dir,
            ):
    """Return this week's CA certificate (stripped PEM text) of an authority.

    The CA lives in ``<base_dir>/<_PKI_DIR>/<authority_name>+<authority_dns>/ca``
    as ``certificate_<ISO week number>.crt`` plus ``private.key``. When this
    week's certificate is missing, it is generated (self-signed, valid for
    10 days, reusing ``private.key`` if present), and the other ``.crt`` files
    of that directory are deleted.

    NOTE(review): the file name encodes only the ISO week, not the year.

    :param cn: unused; kept for interface compatibility
    :param authority_dns: DNS/common name of the authority
    :param authority_name: name of the authority
    :param base_dir: directory under which the PKI tree lives
    :returns: the CA certificate PEM text
    """
    authority_cn = authority_name + '+' + authority_dns
    week_number = _datetime.now().isocalendar().week
    root_dir_name = _join(base_dir, _PKI_DIR, authority_cn)
    ca_dir_name = _join(root_dir_name, 'ca')
    key_ca_name = _join(ca_dir_name, 'private.key')
    cert_ca_name = f'certificate_{week_number}.crt'
    cert_ca_filename = _join(ca_dir_name, cert_ca_name)
    if not _isfile(cert_ca_filename):
        # exist_ok avoids the check-then-create race of isdir() + makedirs()
        _makedirs(ca_dir_name, exist_ok=True)
        ca_config = _RISOTTO_CONFIG['cert_authority']
        __gen_cert(True,
                   [authority_cn],
                   root_dir_name,
                   10*24*60*60,
                   key_ca_name,
                   cert_ca_filename,
                   email_address=ca_config['email'],
                   country_name=ca_config['country'],
                   locality_name=ca_config['locality'],
                   state_or_province_name=ca_config['state'],
                   organization_name=ca_config['org_name'],
                   organization_unit_name=ca_config['org_unit_name'],
                   )
        # drop the certificates of previous weeks
        for filename in _listdir(ca_dir_name):
            if filename.endswith('.crt') and filename != cert_ca_name:
                _unlink(_join(ca_dir_name, filename))
    with open(cert_ca_filename, 'r') as fh:
        return fh.read().strip()


def gen_cert_iter(cn,
                  extra_domainnames,
                  authority_cn,
                  authority_name,
                  type,
                  base_dir,
                  root_cert_dir_name,
                  ):
    """Make sure this week's certificate for ``cn`` exists and return its path.

    The certificate file is ``<root_cert_dir_name>/<type>/certificate_<ISO week>.crt``.
    When it is missing, it is generated and signed with the local CA of
    ``<base_dir>/<_PKI_DIR>/<authority_cn>/ca``. The other ``.crt`` files of
    the same directory are then removed, and for every extra domain name a
    ``<authority_name>+<extra>`` symlink to the authority directory is created.

    :param cn: common name of the certificate
    :param extra_domainnames: additional names, added as subjectAltName
    :param authority_cn: authority directory name (``<name>+<cn>``)
    :param authority_name: name of the authority, used for the symlinks
    :param type: certificate type (``'server'`` or other), also the sub-directory
    :param base_dir: directory under which the PKI tree lives
    :param root_cert_dir_name: per-certificate directory of ``cn``
    :returns: path of the certificate file
    :raises Exception: when no CA certificate exists for this week, when only an
        external CA certificate exists and the certificate must be created, or
        when an extra domain name has no symlink
    """
    week = _datetime.now().isocalendar().week
    certificate_name = f'certificate_{week}.crt'
    authority_dir = _join(base_dir, _PKI_DIR, authority_cn)
    ca_dir = _join(authority_dir, 'ca')
    ca_key_path = _join(ca_dir, 'private.key')
    ca_cert_path = _join(ca_dir, certificate_name)
    external_ca_cert_path = _join(root_cert_dir_name, 'ca', certificate_name)
    cert_dir = _join(root_cert_dir_name, type)
    key_path = _join(cert_dir, 'private.key')
    cert_path = _join(cert_dir, certificate_name)

    def alias_dir(extra):
        # authority directory alias through which ``extra`` is reachable
        return _join(base_dir, _PKI_DIR, authority_name + '+' + extra)

    # A CA certificate placed next to the certificate (external authority,
    # whose private key is not available here) takes precedence over the local CA.
    external = _isfile(external_ca_cert_path)
    if not external and not _isfile(ca_cert_path):
        raise Exception(f'cannot find CA file "{ca_cert_path}" for "{cn}"')

    if not _isfile(cert_path):
        if external:
            raise Exception(f"cannot find CA private key (\"{authority_cn}\") to sign certificat for \"{cn}\" ({ca_key_path}), is it sign with external authority (like Let's Encrypt certification)?")
        if not _isdir(cert_dir):
            _makedirs(cert_dir)
        names = [cn, *extra_domainnames]
        ca_config = _RISOTTO_CONFIG['cert_authority']
        __gen_cert(False,
                   names,
                   authority_dir,
                   10*24*60*60,
                   key_path,
                   cert_path,
                   ca_cert=ca_cert_path,
                   ca_key=ca_key_path,
                   type=type,
                   email_address=ca_config['email'],
                   country_name=ca_config['country'],
                   locality_name=ca_config['locality'],
                   state_or_province_name=ca_config['state'],
                   organization_name=ca_config['org_name'],
                   organization_unit_name=ca_config['org_unit_name'],
                   )
        for extra in extra_domainnames:
            link = alias_dir(extra)
            if not _exists(link):
                _symlink(authority_dir, link)
        # drop the certificates of previous weeks
        for filename in _listdir(cert_dir):
            if filename.endswith('.crt') and filename != certificate_name:
                _unlink(_join(cert_dir, filename))

    for extra in extra_domainnames:
        link = alias_dir(extra)
        if not _exists(link):
            raise Exception(f'file {link} not already exists that means subjectAltName is not set in certificat, please remove {cert_path}')
    return cert_path


def _gen_cert(cn,
              extra_domainnames,
              authority_cn,
              authority_name,
              type,
              file_type,
              base_dir,
              ):
    """Return the stripped PEM certificate or private key of ``cn``.

    Files live in
    ``<base_dir>/<_PKI_DIR>/<authority_name>+<authority_cn>/certificats/<cn>/<type>/``.

    :param cn: common name of the certificate
    :param extra_domainnames: additional DNS names (``None`` means none)
    :param authority_cn: common name of the authority; for ``'server'``
        certificates it defaults to ``cn`` when ``None``
    :param authority_name: name of the authority; must not contain a dot
    :param type: ``'server'`` or ``'client'``
    :param file_type: ``'crt'`` generates the certificate if needed (via
        ``gen_cert_iter``) and returns it; any other value returns the private
        key, which must already exist
    :param base_dir: directory under which the PKI tree lives
    :raises Exception: on a dotted ``authority_name``, a missing
        ``authority_cn`` for a non-server certificate, or a missing private key
    """
    if '.' in authority_name:
        raise Exception(f'dot is not allowed in authority_name "{authority_name}"')
    if authority_cn is None and type == 'server':
        authority_cn = cn
    if authority_cn is None:
        raise Exception('authority_cn is mandatory when authority type is client')
    extra_domainnames = [] if extra_domainnames is None else extra_domainnames
    auth_cn = authority_name + '+' + authority_cn
    cert_dir = _join(base_dir, _PKI_DIR, auth_cn, 'certificats', cn)
    if file_type != 'crt':
        target = _join(cert_dir, type, 'private.key')
        if not _isfile(target):
            raise Exception(f'cannot find {target}, you must use get_certificate before get_private_key')
    else:
        target = gen_cert_iter(cn,
                               extra_domainnames,
                               auth_cn,
                               authority_name,
                               type,
                               base_dir,
                               cert_dir,
                               )
    with open(target, 'r') as fh:
        return fh.read().strip()


def _has_pub(cn,
             base_dir,
             ):
    """Tell whether a standalone public key file already exists for ``cn``.

    :param cn: name identifying the key pair
    :param base_dir: directory under which the PKI tree lives
    :returns: ``True`` if ``<base_dir>/<_PKI_DIR>/public/<cn>/public.pub`` is a file
    """
    return _isfile(_join(base_dir, _PKI_DIR, 'public', cn, 'public.pub'))


def _gen_pub(cn,
             file_type,
             base_dir,
             ):
    """Return one half of the standalone key pair of ``cn`` as stripped PEM text.

    Keys live in ``<base_dir>/<_PKI_DIR>/public/<cn>/``. With ``file_type ==
    'pub'``, the public key is returned; a new RSA pair (``public.pub`` and
    ``private.key``) is generated first if ``public.pub`` does not exist yet.
    Any other ``file_type`` returns the content of ``private.key`` (``open``
    raises if it is missing).
    """
    key_dir = _join(base_dir, _PKI_DIR, 'public', cn)
    private_path = _join(key_dir, 'private.key')
    if file_type != 'pub':
        wanted = private_path
    else:
        wanted = _join(key_dir, 'public.pub')
        if not _isfile(wanted):
            if not _isdir(key_dir):
                _makedirs(key_dir)
            pair = _gen_key_pair()
            with open(wanted, "wt") as out:
                out.write(_dump_publickey(_FILETYPE_PEM, pair).decode("utf-8"))
            with open(private_path, "wt") as out:
                out.write(_dump_privatekey(_FILETYPE_PEM, pair).decode("utf-8"))
    with open(wanted, 'r') as fh:
        return fh.read().strip()