from typing import Union, Tuple, Optional
#python3-crypto
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import AES, PKCS1_OAEP
from hmac import new as hmac_new
from secrets import token_bytes

from time import time
from json import dumps
from hashlib import pbkdf2_hmac, sha256
#from aiohttp import ClientSession
from requests import session
from base64 import b64encode, b64decode
from hkdf import hkdf_expand
from collections import namedtuple
from os.path import isfile
from jwt import encode as jwt_encode, decode as jwt_decode



#BITWARDEN_PRIVATE_KEY = '/var/lib/vaultwarden_rs/rsa_key.der'

# Extractors used by cipher_string_from_str: each one receives the raw
# (enc_type, iv, mac, ct) parts of a parsed cipher string and returns the value
# stored in the CipherString field of the same name.  Only 'enc_type' is
# converted (str -> int); the other parts are passed through unchanged.
cipher_string_fields = dict(
    enc_type=lambda enc_type, iv, mac, ct: int(enc_type),
    iv=lambda enc_type, iv, mac, ct: iv,
    mac=lambda enc_type, iv, mac, ct: mac,
    ct=lambda enc_type, iv, mac, ct: ct,
)
# Field order follows the dict above: (enc_type, iv, mac, ct).
CipherString = namedtuple('CipherString', list(cipher_string_fields))


def cipher_string_from_str(cipher_string: str) -> CipherString:
    """Split a VaultWarden 'CipherString' into its decoded parts.

    Type 2 strings have the form ``"2.<iv>|<ct>|<mac>"``; the three parts are
    base64-decoded into ``iv``, ``ct`` and ``mac``.  For any other type the
    whole payload after the first ``.`` is base64-decoded into ``ct`` and
    ``iv``/``mac`` are left as ``None``.  ``enc_type`` is returned as an int.
    """
    enc_type, payload = cipher_string.split('.', 1)
    if enc_type != '2':
        iv = mac = None
        ct = b64decode(payload)
    else:
        iv, ct, mac = map(b64decode, payload.split('|', 2))
    fields = {name: extract(enc_type, iv, mac, ct)
              for name, extract in cipher_string_fields.items()}
    return CipherString(**fields)


class VaultWarden:
    def __init__(self,
                 url: str,
                 email: str,
                 uuid: str,
                 vaultwarden_key: str,
                 ) -> None:
        """Store connection settings; no request is made here.

        :param url: base URL of the VaultWarden server.
        :param email: account e-mail; stored lowercased.
        :param uuid: device identifier sent when logging in.
        :param vaultwarden_key: path to the server's RSA private key file,
            read by ``register`` to sign the e-mail verification token.
        """
        # connection / identity settings
        self.vaultwarden_url = url
        self.vaultwarden_email = email.lower()
        self.vaultwarden_uuid = uuid
        # populated later by login() and load_organizations()
        self.vaultwarden_login = self.vaultwarden_organizations = None
        self.vaultwarden_key = vaultwarden_key

    def register(self,
                 password: str,
                 valid: bool=True,
                 ) -> None:
        """Create the account on the server, optionally verifying its e-mail.

        Generates the user's symmetric key (64 random bytes, stored encrypted
        with keys derived from the master password) and an RSA-2048 key pair
        (private key stored encrypted with that symmetric key), then posts them
        to ``api/accounts/register``.  If the server answers
        "User already exists" the method simply returns.

        When ``valid`` is true and the server private key file
        ``self.vaultwarden_key`` exists, the account is then logged in and its
        e-mail marked verified (see ``_verify_email``).

        :param password: master password of the new account.
        :param valid: also verify the e-mail address.
        :raises Exception: carrying the server message for any other error.
        """
        iterations = self.get_iterations()
        master_key, hash_password = self.hash_password(password,
                                                       iterations,
                                                       )
        # generate symmetric key: token[:32] is the AES key, token[32:] the HMAC key
        token = token_bytes(64)
        enc, mac = self._get_enc_mac(master_key)
        key = self.encrypt_symmetric(token,
                                     enc=enc,
                                     mac=mac,
                                     )
        # generate asymmetric key; the private key is only sent encrypted
        asym_key = RSA.generate(2048)
        enc_private_key = self.encrypt_symmetric(asym_key.exportKey('DER', pkcs=8),
                                                 enc=token[:32],
                                                 mac=token[32:],
                                                 )
        public_key = b64encode(asym_key.publickey().exportKey('DER')).decode()
        data = {'name': self.vaultwarden_email.split('@')[0],
                'email': self.vaultwarden_email,
                'masterPasswordHash': hash_password,
                'masterPasswordHint': None,
                'key': key,
                # kdf type 0, matching the PBKDF2-SHA256 derivation in hash_password
                'kdf': 0,
                'kdfIterations': iterations,
                'referenceId': None,
                'keys': {
                    'publicKey': public_key,
                    'encryptedPrivateKey': enc_private_key
                    }
                }
        register = self._post('api/accounts/register',
                              dumps(data),
                              )
        if 'Object' in register and register['Object'] == 'error':
            if register["ErrorModel"]['Message'] == 'User already exists':
                return
            raise Exception(register["ErrorModel"]["Message"])
        if valid and isfile(self.vaultwarden_key):
            self._verify_email(password)

    def _verify_email(self,
                      password: str,
                      ) -> None:
        """Log in, then confirm the account e-mail with a JWT signed by the server key.

        The token carries ``iss = '<url>|verifyemail'``, is RS256-signed with
        the key read from ``self.vaultwarden_key`` and is valid for 5 days.
        NOTE(review): this presumably mirrors the token VaultWarden itself
        issues for e-mail verification; confirm against the server version.
        """
        self.login(password)
        # the user id is the `sub` claim of our own access token; the signature
        # is not checked (both the PyJWT 1 and PyJWT 2 spellings are passed)
        user_id = jwt_decode(self.vaultwarden_login['access_token'],
                             algorithm="RS256",
                             #pyjwt 1
                             verify=False,
                             #pyjwt 2
                             options={"verify_signature": False},
                             )['sub']
        now = int(time())
        url = self.vaultwarden_url
        if url.endswith('/'):
            url = url[:-1]
        claims = {'nbf': now,
                  'exp': now + 432000,  # 5 days
                  'iss': f'{url}|verifyemail',
                  'sub': user_id,
                  }
        with open(self.vaultwarden_key, 'rb') as private_key_fh:
            private_key = RSA.importKey(private_key_fh.read()).exportKey('PEM')
        token = jwt_encode(claims, private_key, algorithm="RS256")
        # PyJWT 1 returns bytes, PyJWT 2 returns str; json.dumps needs str
        if isinstance(token, bytes):
            token = token.decode()
        data = {'userId': user_id,
                'token': token,
                }
        self._post('api/accounts/verify-email-token', dumps(data))

    def login(self,
              password: str,
              ) -> None:
        """Authenticate against ``identity/connect/token`` with the password grant.

        On success the decoded token response is kept in
        ``self.vaultwarden_login``, extended with the derived ``master_key`` and
        ``hash_password`` so later calls (e.g. ``load_organizations``) can
        reuse them.

        :raises Exception: if the server answers with an error object.
        """
        master_key, hash_password = self.hash_password(password,
                                                       self.get_iterations(),
                                                       )
        form = dict(grant_type='password',
                    username=self.vaultwarden_email,
                    password=hash_password,
                    scope='api offline_access',
                    client_id='desktop',
                    device_type=7,
                    device_identifier=self.vaultwarden_uuid,
                    device_name='risotto',
                    )
        answer = self._post('identity/connect/token', form)
        if 'Object' in answer and answer['Object'] == 'error':
            raise Exception(f'unable to log to VaultWarden: {answer["ErrorModel"]["Message"]}')
        self.vaultwarden_login = answer
        answer['master_key'] = master_key
        answer['hash_password'] = hash_password

    def get_iterations(self):
        """Return the KDF iteration count (``KdfIterations``) the server
        reports for this account via ``api/accounts/prelogin``."""
        payload = dumps({'email': self.vaultwarden_email})
        return self._post('api/accounts/prelogin', payload)['KdfIterations']

    def hash_password(self,
                      password: str,
                      iterations: int,
                      ) -> Tuple[bytes, str]:
        """Derive the master key and the master password hash.

        The master key is PBKDF2-HMAC-SHA256 of the password salted with
        ``self.vaultwarden_email`` over ``iterations`` rounds.  The hash sent
        to the server is one further PBKDF2 round of the master key salted with
        the password, base64-encoded.

        :returns: ``(master_key, base64_password_hash)``.
        """
        master_key = pbkdf2_hmac('sha256',
                                 password.encode(),
                                 self.vaultwarden_email.encode(),
                                 iterations,
                                 )
        passwd = pbkdf2_hmac('sha256',
                             master_key,
                             password.encode(),
                             1,
                             )
        return master_key, b64encode(passwd).decode()

    def decrypt(self,
                cipher_string: str,
                organization_id: str=None,
                ) -> bytes:
        """Decrypt a cipher string and return the plaintext bytes.

        Type 2 strings go through ``decrypt_symmetric`` with the key of
        ``organization_id`` (``None`` selects the user's own key, as stored by
        ``load_organizations``).  Type 4 strings go through
        ``decrypt_asymmetric`` (RSA-OAEP with the user's private key) and
        cannot take an ``organization_id``.

        :raises Exception: for an unsupported type, or for an
            ``organization_id`` given with a type 4 string.
        """
        cipher = cipher_string_from_str(cipher_string)
        if cipher.enc_type == 2:
            return self.decrypt_symmetric(cipher,
                                          organization_id,
                                          )
        if cipher.enc_type == 4:
            if organization_id:
                # f-string: the original literal printed "{cipher.enc_type}" verbatim
                raise Exception(f'cipher type {cipher.enc_type} cannot have organization_id')
            return self.decrypt_asymmetric(cipher)
        raise Exception(f'Unknown cipher type {cipher.enc_type}')

    def decrypt_symmetric(self,
                          cipher: 'CipherString',
                          organization_id: str=None,
                          enc: Optional[bytes]=None,
                          mac: Optional[bytes]=None,
                          ) -> bytes:
        """Verify and decrypt a type 2 (AES-CBC + HMAC-SHA256) cipher.

        :param cipher: parsed cipher string (see ``cipher_string_from_str``).
        :param organization_id: organization whose 64-byte key is used when
            ``enc``/``mac`` are not given (``None`` = the user's own key).
        :param enc: explicit AES key; must be given together with ``mac``.
        :param mac: explicit HMAC-SHA256 key.
        :returns: the plaintext with its PKCS#7 padding removed (left as-is
            when it does not look like valid padding).
        :raises ValueError: if the cipher is not type 2, only one of
            ``enc``/``mac`` is given, or the MAC does not match.
        """
        from hmac import compare_digest

        # i.e: AesCbc256_HmacSha256_B64 (jslib/src/enums/encryptionType.ts)
        # explicit raises instead of `assert`: these checks must survive `python -O`
        if cipher.enc_type != 2:
            raise ValueError(f'cipher type {cipher.enc_type} is not a symmetric cipher')
        if (enc is None) != (mac is None):
            raise ValueError('enc and mac keys must be given together')
        if enc is None:
            key = self.vaultwarden_organizations[organization_id]['key']
            enc, mac = key[:32], key[32:]

        # verify the MAC over iv + ciphertext before decrypting anything;
        # compare_digest avoids leaking the match position through timing
        expected_mac = hmac_new(mac,
                                cipher.iv + cipher.ct,
                                sha256,
                                ).digest()
        if not compare_digest(cipher.mac, expected_mac):
            raise ValueError('invalid MAC: corrupted cipher string or wrong key')

        # decrypt the content
        plaintext = AES.new(enc,
                            AES.MODE_CBC,
                            cipher.iv,
                            ).decrypt(cipher.ct)

        # remove PKCS#7 padding from payload, see RFC 5652
        # https://tools.ietf.org/html/rfc5652#section-6.3
        if plaintext:
            pad_len = plaintext[-1]
            if plaintext[-pad_len:] == bytes([pad_len] * pad_len):
                plaintext = plaintext[:-pad_len]
        return plaintext

    def decrypt_asymmetric(self,
                           cipher: 'CipherString',
                           ) -> bytes:
        """Decrypt a type 4 (RSA-OAEP) cipher with the user's private key.

        The private key comes from the login response (``PrivateKey``).  It is
        itself a cipher string that is decrypted with the user's own key, so
        ``login`` and ``load_organizations`` must have run first.
        """
        # once decrypted this is the key as exported by register (PKCS#8 DER)
        private_key = self.decrypt(self.vaultwarden_login['PrivateKey'])
        oaep = PKCS1_OAEP.new(RSA.importKey(private_key))
        return oaep.decrypt(cipher.ct)

    def encrypt_symmetric(self,
                          content: bytes,
                          organization_id: str=None,
                          enc: Optional[bytes]=None,
                          mac: Optional[bytes]=None,
                          ) -> str:
        """Encrypt ``content`` into a type 2 cipher string ``"2.<iv>|<ct>|<mac>"``.

        Uses AES-CBC with a fresh random 16-byte IV and PKCS#7 padding, then
        HMAC-SHA256 over ``iv + ciphertext``.

        :param content: plaintext bytes.
        :param organization_id: organization whose 64-byte key is used when
            ``enc``/``mac`` are not given (``None`` = the user's own key).
        :param enc: explicit AES key; must be given together with ``mac``.
        :param mac: explicit HMAC-SHA256 key.
        :raises ValueError: if only one of ``enc``/``mac`` is given (a lone
            ``mac`` used to be silently overwritten).
        """
        if (enc is None) != (mac is None):
            raise ValueError('enc and mac keys must be given together')
        if enc is None:
            key = self.vaultwarden_organizations[organization_id]['key']
            enc, mac = key[:32], key[32:]
        iv = token_bytes(16)
        # PKCS#7 always pads, adding a full block when already aligned
        pad_len = 16 - len(content) % 16
        padded = content + bytes([pad_len] * pad_len)
        ct = AES.new(enc,
                     AES.MODE_CBC,
                     iv,
                     ).encrypt(padded)
        tag = hmac_new(mac,
                       iv + ct,
                       sha256,
                       ).digest()
        return f"2.{b64encode(iv).decode()}|{b64encode(ct).decode()}|{b64encode(tag).decode()}"

    def encrypt_asymmetric(self,
                           plaintext: bytes,
                           key: Union[bytes, str],
                           ) -> str:
        """Encrypt ``plaintext`` with RSA-OAEP and return a type 4 cipher string.

        :param plaintext: bytes to encrypt (PKCS1_OAEP works on bytes).
        :param key: RSA key in any format ``RSA.importKey`` accepts (e.g. DER
            bytes or PEM text); for a private key only its public half is used.
        :returns: ``"4.<base64 ciphertext>"``.
        """
        oaep = PKCS1_OAEP.new(RSA.importKey(key))
        return f"4.{b64encode(oaep.encrypt(plaintext)).decode()}"

    def get(self,
            url: str,
            ) -> Union[dict, list, str]:
        """GET ``url`` (relative to the server URL) with the auth header.

        Exactly one ``/`` separates the base URL and ``url``, whatever slashes
        either one carries, so both ``'/api/sync'`` and ``'api/...'`` work.

        :returns: the decoded JSON body, or the raw text if it is not JSON.
        :raises AssertionError: if the status code is not 200.
        """
        full_url = f"{self.vaultwarden_url.rstrip('/')}/{url.lstrip('/')}"
        with session() as req:
            resp = req.get(full_url, headers=self._get_headers())
        # explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept for callers relying on the previous behavior
        if resp.status_code != 200:
            raise AssertionError(f'unable to get url {full_url}: {resp.text}')
        try:
            return resp.json()
        except ValueError:
            # not JSON (requests' JSON decode errors subclass ValueError)
            return resp.text

    def _post(self,
              url: str,
              data: Union[dict, str],
              ) -> Union[dict, list, str]:
        """POST ``data`` to ``url`` (relative to the server URL) with the auth header.

        Exactly one ``/`` separates the base URL and ``url``, whatever slashes
        either one carries.

        :param data: form fields (dict) or an already-serialized JSON string,
            sent as the request body.
        :returns: the decoded JSON body, or the raw text if it is not JSON.
        :raises AssertionError: if the status code is not 200.  The message
            deliberately omits ``data``, which can hold the master password
            hash (login) or key material (register).
        """
        full_url = f"{self.vaultwarden_url.rstrip('/')}/{url.lstrip('/')}"
        with session() as req:
            resp = req.post(full_url,
                            data=data,
                            headers=self._get_headers(),
                            )
        # explicit raise instead of `assert` so the check survives `python -O`
        if resp.status_code != 200:
            raise AssertionError(f'unable to post to url {full_url}: {resp.text}')
        try:
            return resp.json()
        except ValueError:
            # not JSON (requests' JSON decode errors subclass ValueError)
            return resp.text

    def _put(self,
             url: str,
             data: Union[dict, str],
             ) -> Union[dict, list, str]:
        """PUT ``data`` to ``url`` (relative to the server URL) with the auth header.

        Exactly one ``/`` separates the base URL and ``url``, whatever slashes
        either one carries.

        :param data: form fields (dict) or an already-serialized JSON string.
        :returns: the decoded JSON body, or the raw text if it is not JSON.
        """
        full_url = f"{self.vaultwarden_url.rstrip('/')}/{url.lstrip('/')}"
        with session() as req:
            resp = req.put(full_url,
                           data=data,
                           headers=self._get_headers(),
                           )
        # NOTE(review): unlike GET/POST the status code is not checked here, so
        # a failed PUT is only visible in the returned body.
        try:
            return resp.json()
        except ValueError:
            # not JSON (requests' JSON decode errors subclass ValueError)
            return resp.text

    def _get_headers(self,
                     ) -> Optional[dict]:
        """Return the bearer ``Authorization`` header once logged in, else ``None``."""
        if self.vaultwarden_login is None:
            return None
        return {'Authorization': f'Bearer {self.vaultwarden_login["access_token"]}'}

    def load_organizations(self,
                           only_default: bool=False,
                           ) -> None:
        """Decrypt the user key and, unless ``only_default``, all organization keys.

        Fills ``self.vaultwarden_organizations`` as
        ``{organization_id: {'key': bytes, 'name': str, 'collections': {name: id}}}``
        where the ``None`` entry holds the user's own key (name ``'default'``).
        Requires a prior ``login``.

        :param only_default: only decrypt the user's own key; the
            ``/api/sync`` request is then skipped entirely.
        :raises ValueError: if the decrypted user key is not 64 bytes long.
        """
        enc, mac = self._get_enc_mac(self.vaultwarden_login['master_key'])
        # 'decrypt' the user_key to produce the actual keys
        cipher = cipher_string_from_str(self.vaultwarden_login['Key'])
        plaintext_userkey = self.decrypt_symmetric(cipher,
                                                   enc=enc,
                                                   mac=mac,
                                                   )
        # 32-byte AES key followed by 32-byte HMAC key
        if len(plaintext_userkey) != 64:
            raise ValueError(f'unexpected user key length {len(plaintext_userkey)}, expected 64')
        self.vaultwarden_organizations = {None: {'key': plaintext_userkey, 'name': 'default', 'collections': {}}}
        if only_default:
            return
        # the sync payload is only needed for organizations and collections
        values = self.get('/api/sync')
        for organization in values['Profile']['Organizations']:
            # decrypted with the user's own key or RSA key depending on its type
            plaintext = self.decrypt(organization['Key'])
            self._add_organization(plaintext,
                                   organization,
                                   )
        for collection in values['Collections']:
            name = self.decrypt(collection['Name'],
                                collection['OrganizationId'],
                                ).decode()
            self.vaultwarden_organizations[collection['OrganizationId']]['collections'][name] = collection['Id']

    def _get_enc_mac(self,
                     master_key: str,
                     ) -> tuple:
        """Stretch ``master_key`` into an ``(enc, mac)`` pair of 32-byte keys.

        Each key is an HKDF-Expand (SHA-256) of the master key, using ``b'enc'``
        and ``b'mac'`` respectively as the info string.
        """
        enc_key, mac_key = (hkdf_expand(master_key, info, 32, sha256)
                            for info in (b'enc', b'mac'))
        return enc_key, mac_key

    def _add_organization(self,
                          plaintext: bytes,
                          organization: dict,
                          ) -> None:
        """Cache an organization with its decrypted key, keyed by ``organization['Id']``.

        An existing entry with the same id is replaced and its collection map
        starts out empty.
        """
        org_id = organization['Id']
        self.vaultwarden_organizations[org_id] = dict(name=organization['Name'],
                                                      key=plaintext,
                                                      collections={},
                                                      )

    def try_to_confirm(self,
                       organization_id,
                       email,
                       ) -> Tuple[str, bool]:
        """Confirm the member ``email`` of ``organization_id`` when possible.

        Confirmation needs the member's public key: the organization key is
        encrypted with it and posted to the ``confirm`` endpoint.

        :returns: ``(user_id, confirmed)``, where ``user_id`` is the ``Id``
            from the organization's user list and ``confirmed`` is False when
            no public key is known yet or the server answers with an error.
        """
        # user is now in organization
        user = self.get_user_informations(organization_id,
                                          email,
                                          )
        if not user['public_key']:
            # no public key yet: nothing can be encrypted for this user
            return user['user_id'], False
        key = self.encrypt_asymmetric(self.vaultwarden_organizations[organization_id]['key'],
                                      user['public_key'],
                                      )
        confirmed = self._post(f'api/organizations/{organization_id}/users/{user["user_id"]}/confirm',
                               dumps({"key": key}),
                               )
        return user['user_id'], 'Object' not in confirmed or confirmed['Object'] != 'error'

    def get_user_informations(self,
                              organization_id: str,
                              email: str,
                              ) -> dict:
        """Look up a member of ``organization_id`` by e-mail.

        The comparison is case-insensitive; this client stores its own e-mail
        lowercased, so callers may pass either form.

        :returns: ``{'user_id': <member Id>, 'public_key': <bytes or None>}``;
            ``public_key`` is the base64-decoded key from the server, or None
            when the server returns none.
        :raises Exception: if no member has that e-mail.
        """
        wanted = email.lower()
        users = self.get(f'/api/organizations/{organization_id}/users')
        for user in users['Data']:
            if user['Email'].lower() != wanted:
                continue
            user_public_key = self.get(f'/api/users/{user["UserId"]}/public-key')
            encoded_key = user_public_key['PublicKey']
            return {'user_id': user['Id'],
                    'public_key': b64decode(encoded_key) if encoded_key else None,
                    }
        raise Exception(f'unknow email {email} in organization id {organization_id}')

    def create_organization(self,
                            email: str,
                            organization_name: str,
                            ) -> str:
        """Create an organization owned by the logged-in user and return its id.

        A random 64-byte organization key is generated and sent RSA-encrypted
        with the user's key pair; the same key encrypts the default collection
        name, which is ``organization_name``.  Organizations are reloaded
        afterwards so the new key is available in
        ``self.vaultwarden_organizations``.

        :param email: billing e-mail of the organization.
        :param organization_name: name of the organization and of its
            default collection.
        """
        # RSA.importKey on the private key yields the full pair; OAEP
        # encryption only uses its public half
        private_key = self.decrypt(self.vaultwarden_login['PrivateKey'])
        token = token_bytes(64)
        key = self.encrypt_asymmetric(token,
                                      private_key,
                                      )
        data = {
            "key": key,
            # default collection name is the organization name
            "collectionName": self.encrypt_symmetric(organization_name.encode(),
                                                     enc=token[:32],
                                                     mac=token[32:],
                                                     ),
            "name": organization_name,
            "billingEmail": email,
            "planType": 0,
        }
        organization = self._post('api/organizations',
                                  dumps(data),
                                  )
        self.load_organizations()
        return organization['Id']

    def invite(self,
               organization_id: str,
               email: str,
               ) -> None:
        """Invite ``email`` into ``organization_id``.

        The invitation uses membership ``type`` 2 (presumably Bitwarden's
        "User" role — verify against the server) and grants read-only access,
        with passwords visible, to every collection currently cached for that
        organization.
        """
        collections = [{'id': collection_id,
                        'readOnly': True,
                        'hidePasswords': False,
                        }
                       for collection_id in self.vaultwarden_organizations[organization_id]['collections'].values()]
        data = {'emails': [email],
                'collections': collections,
                'accessAll': False,
                'type': 2,
                }
        self._post(f'api/organizations/{organization_id}/users/invite',
                   dumps(data),
                   )

    def create_collection(self,
                          organization_id: str,
                          collection_name: str,
                          user_id: str=None,
                          ) -> str:
        """Create a collection in ``organization_id`` and return its id.

        The name is encrypted with the organization key, and the new id is
        cached under ``collection_name``.  When ``user_id`` is given, that user
        is added to the collection via ``inscript_collection``.
        """
        data = {"groups": [],
                "name": self.encrypt_symmetric(collection_name.encode(),
                                               organization_id,
                                               ),
                }
        collection = self._post(f'api/organizations/{organization_id}/collections',
                                dumps(data),
                                )
        collection_id = collection['Id']
        self.vaultwarden_organizations[organization_id]['collections'][collection_name] = collection_id
        if user_id:
            self.inscript_collection(organization_id,
                                     collection_id,
                                     user_id,
                                     )
        return collection_id

    def inscript_collection(self,
                            organization_id: str,
                            collection_id: str,
                            user_id: str,
                            ) -> None:
        """Give ``user_id`` read-only access (passwords visible) to a collection.

        The list sent to the collection's ``users`` endpoint contains only this
        user.  NOTE(review): a PUT presumably replaces the current member list;
        confirm before using this on collections with several members.
        """
        endpoint = f'api/organizations/{organization_id}/collections/{collection_id}/users'
        members = [dict(id=user_id, readOnly=True, hidePasswords=False)]
        self._put(endpoint, dumps(members))

    def store_password(self,
                       organization_id: str,
                       collection_id: str,
                       name: str,
                       username: str,
                       password: str,
                       uris: list=None,
                       ) -> None:
        """Create a login cipher and store it in a shared collection.

        The name, username and password are encrypted with the key of
        ``organization_id``.  ``uris`` is sent exactly as given.
        """
        def seal(text: str) -> str:
            # encrypt one field with the organization key
            return self.encrypt_symmetric(text.encode(), organization_id)

        # FIXME uris are encoded
        # (they are currently sent unencrypted, unlike the other fields)
        sealed_name = seal(name)
        login = {"response": None,
                 "uris": uris,
                 "username": seal(username),
                 "password": seal(password),
                 "passwordRevisionDate": None,
                 "totp": None,
                 }
        cipher = {"type": 1,
                  "folderId": None,
                  "organizationId": organization_id,
                  "name": sealed_name,
                  "notes": None,
                  "favorite": False,
                  "login": login,
                  }
        payload = {"cipher": cipher,
                   "collectionIds": [collection_id],
                   }
        self._post('api/ciphers/admin',
                   dumps(payload),
                   )

    def get_password(self,
                     organization_id: str,
                     collection_name: str,
                     name: str,
                     username: str,
                     ) -> Optional[str]:
        """Return the password of the cipher matching ``name`` and ``username``.

        Only ciphers of ``organization_id`` belonging to the cached collection
        ``collection_name`` are considered.

        :returns: the decrypted password, or None when the collection is not
            cached, no cipher matches, or the matching cipher has no password.
        """
        collections = self.vaultwarden_organizations[organization_id]['collections']
        if collection_name not in collections:
            return None
        collection_id = collections[collection_name]
        ciphers = self.get(f'api/ciphers/organization-details?organizationId={organization_id}')
        for cipher in ciphers['Data']:
            if collection_id not in cipher['CollectionIds']:
                continue
            cipher_data = cipher['Data']
            # null fields cannot be decrypted and cannot match a given name/username
            if not cipher_data.get('Name') or not cipher_data.get('Username'):
                continue
            if self.decrypt(cipher_data['Name'], organization_id).decode() != name:
                continue
            if self.decrypt(cipher_data['Username'], organization_id).decode() != username:
                continue
            encrypted_password = cipher_data.get('Password')
            if not encrypted_password:
                return None
            return self.decrypt(encrypted_password, organization_id).decode()
        return None