# Copyright (C) 2015  IPA Project Contributors, see COPYING for license

from __future__ import print_function
from base64 import b64encode, b64decode
from custodia.store.interface import CSStore
from jwcrypto.common import json_decode, json_encode
from ipaplatform.paths import paths
from ipapython import ipautil
from ipapython.secrets.common import iSecLdap
import ldap
import os
import shutil
import sys
import StringIO
import tempfile


class UnknownKeyName(Exception):
    """Raised when a requested database or key name is not recognized."""


class DBMAPHandler(object):
    """Abstract interface for a handler bound to one database mapping.

    Implementations receive the store configuration, their NAME_DB_MAP
    entry and the key nickname, and provide export_key()/import_key().
    Every method here must be overridden.
    """

    def __init__(self, config, dbmap, nickname):
        """Bind the handler to *dbmap* and *nickname* (not implemented)."""
        raise NotImplementedError

    def export_key(self):
        """Return the serialized key material (not implemented)."""
        raise NotImplementedError

    def import_key(self, value):
        """Store the serialized key material *value* (not implemented)."""
        raise NotImplementedError


def log_error(error):
    """Report *error* on the process's standard error stream."""
    stream = sys.stderr
    print(error, file=stream)


def PKI_TOMCAT_password_callback(conf_path=None):
    """Return the 'internal' password from the PKI Tomcat password file.

    This is the password used for the PKI Tomcat NSS database.

    :param conf_path: password file to read; defaults to
        paths.PKI_TOMCAT_PASSWORD_CONF
    :returns: the value of the first 'internal' entry, or None if the file
        has no such entry
    """
    if conf_path is None:
        conf_path = paths.PKI_TOMCAT_PASSWORD_CONF
    with open(conf_path) as f:
        for line in f:
            # Split only on the first '=' so a password containing '=' is
            # kept intact, and skip lines without '=' (e.g. blank lines)
            # instead of failing to unpack them.
            key, sep, value = line.strip().partition('=')
            if sep and key.strip() == 'internal':
                return value
    return None


def HTTPD_password_callback():
    """Return the HTTPD NSS database password.

    The full content of paths.ALIAS_PWDFILE_TXT is returned verbatim; no
    whitespace or trailing newline is stripped.
    """
    with open(paths.ALIAS_PWDFILE_TXT) as pwdfile:
        return pwdfile.read()


class NSSCertDB(DBMAPHandler):
    """Handler moving a certificate and its key in/out of an NSS database.

    Key material is exchanged as a PKCS#12 file produced or consumed by
    pk12util, serialized as JSON with the keys 'export password' and
    'pkcs12 data' (base64-encoded PKCS#12 content).
    """

    def __init__(self, config, dbmap, nickname):
        """Validate *dbmap* and fetch the NSS database password.

        :param config: store configuration (not used by this handler)
        :param dbmap: mapping with 'type' == 'NSSDB', the database 'path'
            and a 'pwcallback' returning the database password
        :param nickname: nickname of the certificate in the database
        :raises ValueError: if *dbmap* has the wrong type or lacks a
            required entry
        """
        # .get() so that a missing 'type' yields the intended ValueError
        # rather than a KeyError raised while formatting the message.
        dbtype = dbmap.get('type')
        if dbtype != 'NSSDB':
            raise ValueError('Invalid type "%s",'
                             ' expected "NSSDB"' % (dbtype,))
        if 'path' not in dbmap:
            raise ValueError('Configuration does not provide NSSDB path')
        if 'pwcallback' not in dbmap:
            raise ValueError('Configuration does not provide '
                             'Password Callback')
        self.nssdb_path = dbmap['path']
        self.nickname = nickname
        self.nssdb_password = dbmap['pwcallback']()

    @staticmethod
    def _write_file(path, data, mode='w'):
        """Write *data* to the file at *path*, truncating it first."""
        with open(path, mode) as f:
            f.write(data)

    def _run_pk12util(self, tdir, mode_flag, pk12file, pk12password):
        """Run pk12util on the NSS database with passwords passed via files.

        :param tdir: private temporary directory for the password files
        :param mode_flag: '-o' to export into *pk12file*, '-i' to import it
        :param pk12file: path of the PKCS#12 file
        :param pk12password: password protecting the PKCS#12 file
        """
        # pk12util receives both passwords as files (-k / -w), never on
        # the command line.
        nsspwfile = os.path.join(tdir, 'nsspwfile')
        self._write_file(nsspwfile, self.nssdb_password)
        pk12pwfile = os.path.join(tdir, 'pk12pwfile')
        self._write_file(pk12pwfile, pk12password)
        ipautil.run([paths.PK12UTIL,
                     "-d", self.nssdb_path,
                     mode_flag, pk12file,
                     "-n", self.nickname,
                     "-k", nsspwfile,
                     "-w", pk12pwfile])

    def export_key(self):
        """Export the certificate and key as JSON-encoded PKCS#12 data.

        A freshly generated random password protects the PKCS#12 file and
        is returned alongside it.
        """
        tdir = tempfile.mkdtemp(dir=paths.TMP)
        try:
            password = b64encode(os.urandom(16))
            pk12file = os.path.join(tdir, 'pk12file')
            self._run_pk12util(tdir, '-o', pk12file, password)
            # PKCS#12 is binary data: read it without text translation.
            with open(pk12file, 'rb') as f:
                data = f.read()
        finally:
            # Always remove the password files and the exported key.
            shutil.rmtree(tdir)
        return json_encode({'export password': password,
                            'pkcs12 data': b64encode(data)})

    def import_key(self, value):
        """Import a certificate and key from export_key()-style JSON.

        :param value: JSON string with 'export password' and 'pkcs12 data'
        """
        v = json_decode(value)
        tdir = tempfile.mkdtemp(dir=paths.TMP)
        try:
            pk12file = os.path.join(tdir, 'pk12file')
            # Decoded PKCS#12 content is binary: write it as such.
            self._write_file(pk12file, b64decode(v['pkcs12 data']), 'wb')
            self._run_pk12util(tdir, '-i', pk12file, v['export password'])
        finally:
            shutil.rmtree(tdir)


# Export the DM password hash so it can be set on replicas. This way a
# replica can be installed without knowing the DM password, and the DM
# password is still kept synchronized across replicas.
class DMLDAP(DBMAPHandler):
    """Handler for the Directory Manager password hash stored in LDAP.

    The hash is read from / written to the 'nsslapd-rootpw' attribute of
    'cn=config' and serialized as JSON of the form {"dmhash": <hash>}.
    """

    def __init__(self, config, dbmap, nickname):
        """Validate *dbmap* and *nickname* and prepare the LDAP helper.

        :param config: store configuration; must provide 'ldap_uri' and may
            provide 'auth_type'
        :param dbmap: mapping with 'type' == 'DMLDAP'
        :param nickname: must be 'DMHash'
        :raises ValueError: if *dbmap* has the wrong type
        :raises UnknownKeyName: if *nickname* is not 'DMHash'
        """
        # .get() so that a missing 'type' yields the intended ValueError
        # rather than a KeyError raised while formatting the message.
        dbtype = dbmap.get('type')
        if dbtype != 'DMLDAP':
            raise ValueError('Invalid type "%s",'
                             ' expected "DMLDAP"' % (dbtype,))
        if nickname != 'DMHash':
            raise UnknownKeyName("Unknown Key Named '%s'" % nickname)
        self.ldap = iSecLdap(config['ldap_uri'],
                             config.get('auth_type', None))

    def export_key(self):
        """Return the DM password hash as JSON.

        :raises RuntimeError: if the hash cannot be found
        """
        conn = self.ldap.connect()
        try:
            r = conn.search_s('cn=config', ldap.SCOPE_BASE,
                              attrlist=['nsslapd-rootpw'])
        finally:
            # Release the connection instead of leaking it.
            conn.unbind_s()
        if len(r) != 1:
            raise RuntimeError('DM Hash not found!')
        values = r[0][1].get('nsslapd-rootpw')
        if not values:
            raise RuntimeError('DM Hash not found!')
        return json_encode({'dmhash': values[0]})

    def import_key(self, value):
        """Replace the DM password hash with the one in JSON *value*."""
        v = json_decode(value)
        conn = self.ldap.connect()
        try:
            mods = [(ldap.MOD_REPLACE, 'nsslapd-rootpw', str(v['dmhash']))]
            conn.modify_s('cn=config', mods)
        finally:
            conn.unbind_s()


# Maps the database component of a key name ("keys/<db>/<nickname>") to the
# handler class that serves it and the settings that handler validates.
NAME_DB_MAP = dict(
    ca=dict(
        type='NSSDB',
        path=paths.PKI_TOMCAT_ALIAS_DIR,
        handler=NSSCertDB,
        pwcallback=PKI_TOMCAT_password_callback,
    ),
    ra=dict(
        type='NSSDB',
        path=paths.HTTPD_ALIAS_DIR,
        handler=NSSCertDB,
        pwcallback=HTTPD_password_callback,
    ),
    dm=dict(
        type='DMLDAP',
        handler=DMLDAP,
    ),
)


class iSecStore(CSStore):
    """Custodia store that serves keys directly from IPA databases.

    Key names have the form 'keys/<db>/<nickname>'; '<db>' selects an entry
    of NAME_DB_MAP whose handler exports (get) or imports (set) the key.
    """

    def __init__(self, config=None):
        self.config = config

    def _get_handler(self, key):
        """Instantiate the handler responsible for *key*.

        :raises ValueError: if *key* is not of the form 'keys/<db>/<name>'
        :raises UnknownKeyName: if '<db>' is not in NAME_DB_MAP
        """
        path = key.split('/', 3)
        if len(path) != 3 or path[0] != 'keys':
            raise ValueError('Invalid name')
        if path[1] not in NAME_DB_MAP:
            raise UnknownKeyName("Unknown DB named '%s'" % path[1])
        dbmap = NAME_DB_MAP[path[1]]
        return dbmap['handler'](self.config, dbmap, path[2])

    def get(self, key):
        """Return the exported value of *key*, or None on any failure.

        Failures are logged to stderr rather than raised.
        """
        try:
            key_handler = self._get_handler(key)
            value = key_handler.export_key()
        except Exception as e:  # pylint: disable=broad-except
            log_error('Error retrieving key "%s": %s' % (key, str(e)))
            value = None
        return value

    def set(self, key, value, replace=False):
        """Import *value* as the key *key*.

        Failures are logged to stderr rather than raised. *replace* is
        accepted for interface compatibility but is not used.
        """
        try:
            key_handler = self._get_handler(key)
            key_handler.import_key(value)
        except Exception as e:  # pylint: disable=broad-except
            log_error('Error storing key "%s": %s' % (key, str(e)))

    def list(self, keyfilter=None):
        raise NotImplementedError

    def cut(self, key):
        raise NotImplementedError