import __main__
from os import (urandom as _urandom, environ as _environ, getcwd as _getcwd,
                makedirs as _makedirs)
from hashlib import sha1 as _sha1
from base64 import encodebytes as _encodebytes, b64encode as _b64encode
from json import load as _load, dump as _dump
from os.path import join as _join, isfile as _isfile, dirname as _dirname


# Directory the tooling runs from. Prefer the shell-provided PWD (it keeps the
# path as the user typed it), but fall back to the process working directory
# when PWD is not exported (non-shell launchers, Windows) instead of failing
# with KeyError at import time.
_HERE = _environ.get('PWD') or _getcwd()
# Path of the JSON cache of already generated SSHA hashes. Despite the ``_DIR``
# suffix (kept for backward compatibility) this is a file path.
_SSHA_PASSWORD_DIR = _join(_HERE, 'password', 'ssha.json')


# unproudly borrowed from
# http://www.openldap.org/faq/data/cache/347.html
def ssha_encode(password):
    # do not regenerate SSHA
    if _isfile(_SSHA_PASSWORD_DIR):
        with open(_SSHA_PASSWORD_DIR, 'r') as fh:
            passwords = _load(fh)
    else:
        passwords = {}
    if password in passwords:
        return passwords[password]
    salt = _urandom(4)
    h = _sha1(password.encode())
    h.update(salt)
    ret = _b64encode(b"{SSHA}" + _encodebytes(h.digest() + salt)[:-1]).decode()
    passwords[password] = ret
    with open(_SSHA_PASSWORD_DIR, 'w') as fh:
        _dump(passwords, fh)
    return ret


def calc_ldapclient_base_dn(ldap_base_dn: str,
                            family_name: str=None,
                            base: bool=False,
                            group: bool=False,
                            ) -> str:
    """Build the DN of an LDAP container located below *ldap_base_dn*.

    The logic mirrors the helper of the same name in ldap-client.

    - ``group`` true: ``ou=groups,<ldap_base_dn>``.
    - Otherwise the accounts branch ``ou=accounts,<ldap_base_dn>`` is used.
      The prefix is not added twice if *ldap_base_dn* already has it. Then:

      - ``base`` true or ``family_name == 'all'``: the accounts branch itself;
      - empty or ``None`` *family_name*: ``ou=users,<accounts>``;
      - ``family_name == '-'``: ``ou=families,<accounts>``;
      - any other family: ``ou=<family_name>,ou=families,<accounts>``.

    Returns ``None`` when *ldap_base_dn* is ``None``.
    """
    if ldap_base_dn is None:
        return None
    if group:
        return f'ou=groups,{ldap_base_dn}'

    if ldap_base_dn.startswith('ou=accounts,'):
        accounts_dn = ldap_base_dn
    else:
        accounts_dn = f'ou=accounts,{ldap_base_dn}'

    # 'all' is a pseudo-family meaning "the whole accounts branch"
    if base or family_name == 'all':
        return accounts_dn
    if not family_name:
        return f'ou=users,{accounts_dn}'

    families_dn = f'ou=families,{accounts_dn}'
    # '-' selects the families container itself rather than one family
    if family_name == '-':
        return families_dn
    return f'ou={family_name},{families_dn}'


def get_default_base_dn(prefix: str) -> str:
    """Derive a default base DN from a dotted host name.

    The logic mirrors the helper of the same name in ldap-client. The last two
    labels become ``o=`` components. Every preceding label becomes an ``ou=``
    component, in the original order. For example, ``server.example.com``
    gives ``ou=server,o=example,o=com``.

    Returns ``None`` when *prefix* is empty or has fewer than three labels
    (at least ``subdomain.domain.tld`` is required).
    """
    if not prefix or '.' not in prefix:
        return None
    labels = prefix.split('.')
    if len(labels) < 3:
        return None
    *sub_labels, org, tld = labels
    components = [f'ou={label}' for label in sub_labels]
    components.append(f'o={org},o={tld}')
    return ','.join(components)


def valid_base_dn(base_dn: str) -> None:
    """Check that *base_dn* starts with an accepted root attribute.

    The logic mirrors the helper of the same name in ldap-client. The accepted
    prefixes are ``o=``, ``dc=`` and ``ou=``, and the check is case-sensitive.

    :raises ValueError: if *base_dn* starts with none of the accepted prefixes.
    """
    accepted_prefixes = ('o=', 'dc=', 'ou=')
    if base_dn.startswith(accepted_prefixes):
        return
    raise ValueError('La racine doit débuter par une organisation (o=), une composante du domaine (dc=) ou une unité organisationnelle (ou=)')