/*
 *  methods for encryption/security mechanisms for cryptmount
 *  $Revision: 154 $, $Date: 2007-03-31 16:55:38 +0100 (Sat, 31 Mar 2007) $
 *  Copyright 2005-2007 RW Penney
 */

/*
    This file is part of cryptmount

    cryptmount is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    As a special exemption, permission is granted to link cryptmount
    with the OpenSSL project's "OpenSSL" library and distribute
    the linked code without invoking clause 2(b) of the GNU GPL version 2.

    cryptmount is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with cryptmount; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

#include <config.h>

#if HAVE_DLFCN && USE_MODULES
#  include <dlfcn.h>
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "armour.h"
#include "cryptmount.h"
#include "utils.h"
#ifdef TESTING
#  include "cmtesting.h"
#endif



/*
 *  ==== OpenSSL key-management routines ====
 */


#if HAVE_OPENSSL
#  if !USE_MODULES || defined(AS_MODULE)
#    include <openssl/bio.h>
#    include <openssl/err.h>
#    include <openssl/evp.h>
#    include <openssl/objects.h>

/* marker at the start of salted keyfiles: written by kmssl_put_key() and
   checked by kmssl_is_compat() & kmssl_get_key(), always without its
   terminating NUL (i.e. sizeof(ssl_saltstr)-1 characters) */
const char ssl_saltstr[]="Salted__";

#if defined(TESTING) && defined(AS_MODULE)
/* testing context, recorded by kmssl_testctxt(); defined here only when
   built as a separate module - presumably the main program provides it
   otherwise (TODO confirm) */
cm_testinfo_t *test_ctxtptr;
#endif


static int kmssl_get_algos(const keyinfo_t *keyinfo,
            const EVP_CIPHER **cipher, const EVP_MD **digest)
    /* look up the OpenSSL cipher & digest named in keyinfo, falling back
       to SN_bf_cbc / SN_md5 where no name is given;
       both outputs are reset to NULL before lookup,
       and ERR_BADALGORITHM is returned if either name is unknown */
{   const char *cname = SN_bf_cbc, *dname = SN_md5;

    *cipher = NULL;
    *digest = NULL;

    if (keyinfo->cipheralg != NULL) cname = keyinfo->cipheralg;
    if (keyinfo->digestalg != NULL) dname = keyinfo->digestalg;

    if ((*cipher = EVP_get_cipherbyname(cname)) == NULL) {
        fprintf(stderr,
            _("couldn't find OpenSSL cipher \"%s\"\n"), cname);
        return ERR_BADALGORITHM;
    }

    if ((*digest = EVP_get_digestbyname(dname)) == NULL) {
        fprintf(stderr,
            _("couldn't find OpenSSL digest \"%s\"\n"), dname);
        return ERR_BADALGORITHM;
    }

    return ERR_NOERROR;
}


#  ifdef TESTING

static int kmssl_test_getalgos(void)
    /* check that kmssl_get_algos() resolves both the default algorithms
       and explicitly-named ones to the expected OpenSSL objects */
{   keyinfo_t keyinfo;
    const EVP_CIPHER *cipher=NULL;
    const EVP_MD *digest=NULL;

    CM_TEST_START("OpenSSL algorithm-identification");

    /* only cipheralg & digestalg are read by kmssl_get_algos(),
       but don't leave the remaining fields as stack garbage: */
    memset((void*)&keyinfo, 0, sizeof(keyinfo));

    keyinfo.cipheralg = NULL;
    keyinfo.digestalg = NULL;
    CM_ASSERT_EQUAL(ERR_NOERROR,
        kmssl_get_algos(&keyinfo, &cipher, &digest));
    CM_ASSERT_DIFFERENT(NULL, cipher);
    CM_ASSERT_DIFFERENT(NULL, digest);

    keyinfo.cipheralg = "bf-ecb";
    keyinfo.digestalg = "sha1";
    CM_ASSERT_EQUAL(ERR_NOERROR,
        kmssl_get_algos(&keyinfo, &cipher, &digest));
    CM_ASSERT_EQUAL(EVP_bf_ecb(), cipher);
    CM_ASSERT_EQUAL(EVP_sha1(), digest);

    keyinfo.cipheralg = "aes-192-cbc";
    keyinfo.digestalg = "md5";
    CM_ASSERT_EQUAL(ERR_NOERROR,
        kmssl_get_algos(&keyinfo, &cipher, &digest));
    CM_ASSERT_EQUAL(EVP_aes_192_cbc(), cipher);
    CM_ASSERT_EQUAL(EVP_md5(), digest);

    /* no 'context' variable exists here - use the same form as
       the other tests in this file: */
    CM_TEST_OK();
}

#  endif    /* TESTING */


static int kmssl_init_algs(void)
    /* register all OpenSSL ciphers & digests, so that they can later
       be found by name via EVP_get_cipherbyname()/EVP_get_digestbyname();
       always reports success (zero) */
{   const int status = 0;

    OpenSSL_add_all_ciphers();
    OpenSSL_add_all_digests();
    return status;
}


static int kmssl_free_algs(void)
    /* counterpart of kmssl_init_algs(): discard the registered
       cipher + hash algorithms; always reports success (zero) */
{   const int status = 0;

    EVP_cleanup();
    return status;
}


static void kmssl_mk_default(keyinfo_t *keyinfo)
    /* fill in any unset algorithm names with this key-manager's defaults
       ("md5" digest, "bf-cbc" cipher); names already present are kept */
{
    if (keyinfo != NULL) {
        if (keyinfo->digestalg == NULL) keyinfo->digestalg = cm_strdup("md5");
        if (keyinfo->cipheralg == NULL) keyinfo->cipheralg = cm_strdup("bf-cbc");
    }
}


static int kmssl_is_compat(const keyinfo_t *keyinfo, FILE *fp_key)
    /* decide whether this key-manager can handle a target's key:
     *  - if a key "format" is configured, it must be "openssl";
     *  - otherwise, an existing key-file must begin with ssl_saltstr.
     *  Returns non-zero if compatible. When fp_key is examined, up to
     *  sizeof(ssl_saltstr) bytes are consumed from it. */
{   char buff[32];
    size_t nread;

    if (keyinfo->format != NULL) {
        return (strcmp(keyinfo->format, "openssl") == 0);
    }

    if (fp_key == NULL) return 0;

    /* check header of existing key-file; an empty or truncated file
       cannot match, and must not lead to comparing uninitialized bytes: */
    nread = fread((void*)buff, (size_t)1, sizeof(ssl_saltstr), fp_key);
    if (nread < sizeof(ssl_saltstr) - 1) return 0;

    return (strncmp(buff, ssl_saltstr, sizeof(ssl_saltstr)-1) == 0);
}


static int kmssl_needs_pw(const keyinfo_t *keyinf)
    /* OpenSSL-format keys always require a password,
       whatever the key configuration */
{
    (void)keyinf;       /* configuration is irrelevant here */
    return 1;
}


static int kmssl_get_key(const char *ident, const keyinfo_t *keyinfo,
            unsigned char **key, int *keylen, FILE *fp_key)
    /* extract key from openssl-encrypted file */
{   enum { BUFFSZ=8192 };
    BIO *keybio=NULL,*blkenc=NULL;
    char passwd[256],buff[BUFFSZ];
    unsigned char salt[PKCS5_SALT_LEN], cryptkey[EVP_MAX_KEY_LENGTH],
                    cryptiv[EVP_MAX_IV_LENGTH];
    int i,len,lmt,eflag=ERR_NOERROR;
    const EVP_CIPHER *cipher=NULL;
    const EVP_MD *digest=NULL;

    /*
     *  this routine is strongly influenced by apps/enc.c in openssl-0.9.8
     */

    *key = NULL; *keylen = 0;

    eflag = kmssl_get_algos(keyinfo, &cipher, &digest);
    if (eflag != ERR_NOERROR) goto bail_out;


    BIO_snprintf(buff, sizeof(buff),
                _("enter password for target \"%s\": "), ident);
#ifdef TESTING
    strncpy(passwd, (test_ctxtptr->argpassword[0] != NULL ? test_ctxtptr->argpassword[0] : ""),
            sizeof(passwd));
#else
    if (EVP_read_pw_string(passwd, (int)sizeof(passwd), buff, 0) != 0) {
        fprintf(stderr, _("bad password for target \"%s\"\n"), ident);
        eflag = ERR_BADPASSWD;
        goto bail_out;
    }
#endif

    keybio = BIO_new_fp(fp_key, BIO_NOCLOSE);
    if (keybio == NULL) {
        fprintf(stderr,
                _("failed to connect to OpenSSL keyfile for target \"%s\"\n"),
                ident);
        eflag = ERR_BADFILE;
        goto bail_out;
    }

    len = sizeof(ssl_saltstr) - 1;
    i = BIO_read(keybio, buff, len);
    if (i != len || strncmp(buff,ssl_saltstr,(size_t)len) != 0) {
        fprintf(stderr, _("bad OpenSSL keyfile \"%s\"\n"), keyinfo->filename);
        eflag = ERR_BADFILE;
        goto bail_out;
    }
    i = BIO_read(keybio, (void*)salt, (int)sizeof(salt));
    EVP_BytesToKey(cipher, digest, salt,
            (unsigned char*)passwd, (int)strlen(passwd), 1, cryptkey, cryptiv);
    OPENSSL_cleanse(passwd, sizeof(passwd));

    blkenc = BIO_new(BIO_f_cipher());
    BIO_set_cipher(blkenc, cipher, cryptkey, cryptiv, 0);

    BIO_push(blkenc, keybio);

    /* read and decrypt data from keyfile: */
    for (;;) {
        lmt = (keyinfo->maxlen > 0 && (*keylen + BUFFSZ) > keyinfo->maxlen
                ? keyinfo->maxlen - *keylen : BUFFSZ);
        len = BIO_read(blkenc, (void*)buff, lmt);
        if (len <= 0) break;

        /* copy new block of data onto end of current key: */
        *key = (unsigned char*)sec_realloc((void*)*key, (size_t)(*keylen+len));
        memcpy((void*)(*key + *keylen), (const void*)buff, (size_t)len);
        *keylen += len;
    }

    if ((i = ERR_peek_last_error()) != 0) {
        fprintf(stderr, _("key-extraction failed [%x] for \"%s\"\n"),
                i, keyinfo->filename);
        eflag = ERR_BADDECRYPT;
    }

    if (blkenc != NULL) BIO_free_all(blkenc);

  bail_out:

    return eflag;
}


static int kmssl_put_key(const char *ident, const keyinfo_t *keyinfo,
            const unsigned char *key, const int keylen, FILE *fp_key)
    /* store key in openssl-encrypted file
     *
     *  Writes ssl_saltstr ("Salted__"), a fresh salt from get_randkey(),
     *  then the key encrypted with a cipher-key & IV derived from a newly
     *  entered password by EVP_BytesToKey() - the layout that
     *  kmssl_get_key() reads back.
     *  Returns ERR_NOERROR or an ERR_* code. On every exit path the
     *  password and derived cipher-key/IV are wiped and all BIOs released
     *  (fp_key itself is never closed, because of BIO_NOCLOSE).
     */
{   enum { BUFFSZ=8192 };
    BIO *keybio=NULL,*blkenc=NULL;
    char passwd[256],buff[BUFFSZ];
    unsigned char salt[PKCS5_SALT_LEN], cryptkey[EVP_MAX_KEY_LENGTH],
                    cryptiv[EVP_MAX_IV_LENGTH];
    int i,len=0,lmt,pos,eflag=ERR_NOERROR;
    const EVP_CIPHER *cipher=NULL;
    const EVP_MD *digest=NULL;

    eflag = kmssl_get_algos(keyinfo, &cipher, &digest);
    if (eflag != ERR_NOERROR) goto bail_out;

    BIO_snprintf(buff, sizeof(buff),
                _("enter new password for target \"%s\": "), ident);
#ifdef TESTING
    strncpy(passwd, (test_ctxtptr->argpassword[1] != NULL ? test_ctxtptr->argpassword[1] : ""),
            sizeof(passwd));
    passwd[sizeof(passwd) - 1] = '\0';      /* strncpy() need not terminate */
#else
    if (EVP_read_pw_string(passwd, (int)sizeof(passwd), buff, 1) != 0) {
        eflag = ERR_BADPASSWD;
        goto bail_out;
    }
#endif

    keybio = BIO_new_fp(fp_key, BIO_NOCLOSE);
    if (keybio == NULL) {
        fprintf(stderr, _("failed to create file handle\n"));
        eflag = ERR_BADFILE;
        goto bail_out;
    }

    len = sizeof(ssl_saltstr) - 1;
    if (BIO_write(keybio, ssl_saltstr, len) != len) {
        fprintf(stderr, _("bad keyfile \"%s\"\n"), keyinfo->filename);
        eflag = ERR_BADFILE;
        goto bail_out;
    }
    get_randkey(salt, sizeof(salt));
    if (BIO_write(keybio, (const void*)salt, (int)sizeof(salt))
            != (int)sizeof(salt)) {
        fprintf(stderr, _("bad keyfile \"%s\"\n"), keyinfo->filename);
        eflag = ERR_BADFILE;
        goto bail_out;
    }

    EVP_BytesToKey(cipher, digest, salt,
            (unsigned char*)passwd, (int)strlen(passwd), 1, cryptkey, cryptiv);
    OPENSSL_cleanse(passwd, sizeof(passwd));

    blkenc = BIO_new(BIO_f_cipher());
    if (blkenc == NULL) {
        fprintf(stderr, _("key-writing failed [%d] for \"%s\"\n"),
                0, keyinfo->filename);
        eflag = ERR_BADENCRYPT;
        goto bail_out;
    }
    BIO_set_cipher(blkenc, cipher, cryptkey, cryptiv, 1);

    /* from here on, BIO_free_all(blkenc) also frees keybio: */
    BIO_push(blkenc, keybio);

    /* encrypt and write data into keyfile: */
    for (pos=0; pos<keylen; ) {
        lmt = ((pos + BUFFSZ) > keylen ? (keylen - pos) : BUFFSZ);
        len = BIO_write(blkenc, (const void*)(key + pos), lmt);
        if (len <= 0) break;

        pos += len;
    }
    if (pos < keylen) {
        /* a failed write would otherwise leave a silently-truncated key */
        fprintf(stderr, _("key-writing failed [%d] for \"%s\"\n"),
                len, keyinfo->filename);
        eflag = ERR_BADENCRYPT;
        goto bail_out;
    }

    /* flush (and pad) the final cipher block; BIO_flush() returns 1 on
       success, but either 0 or -1 on failure: */
    if ((i = BIO_flush(blkenc)) <= 0) {
        fprintf(stderr, _("key-writing failed [%d] for \"%s\"\n"),
                i, keyinfo->filename);
        eflag = ERR_BADENCRYPT;
    }

  bail_out:

    if (blkenc != NULL) {
        BIO_free_all(blkenc);
    } else if (keybio != NULL) {
        BIO_free(keybio);
    }

    /* wipe password & password-derived material on all paths: */
    OPENSSL_cleanse(passwd, sizeof(passwd));
    OPENSSL_cleanse(cryptkey, sizeof(cryptkey));
    OPENSSL_cleanse(cryptiv, sizeof(cryptiv));

    return eflag;
}


static void *kmssl_md_prepare(void)
    /* create a SHA1 message-digest context, ready to accept data;
       returns NULL if the context cannot be created or initialized */
{   EVP_MD_CTX *mdcontext;

    /* EVP_MD_CTX_create() allocates & initializes in one step, and
       (unlike malloc(sizeof(EVP_MD_CTX))) also works with OpenSSL
       releases in which EVP_MD_CTX is an opaque type: */
    mdcontext = EVP_MD_CTX_create();
    if (mdcontext == NULL) return NULL;

    if (EVP_DigestInit(mdcontext, EVP_sha1()) != 1) {
        EVP_MD_CTX_destroy(mdcontext);
        return NULL;
    }

    return (void*)mdcontext;
}


static void kmssl_md_block(void *state, unsigned char *buff, size_t len)
    /* feed another block of data into the digest */
{   EVP_MD_CTX *mdcontext=(EVP_MD_CTX*)state;

    if (mdcontext == NULL) return;

    /* pass len unchanged - narrowing to unsigned would truncate
       very large blocks where the API takes size_t: */
    EVP_DigestUpdate(mdcontext, buff, len);
}


static void kmssl_md_final(void *state, unsigned char **mdval, size_t *mdlen)
    /* complete the digest: *mdval receives a malloc()'d buffer holding
       *mdlen bytes of hash value, which the caller must free();
       on failure, *mdval is NULL and *mdlen is zero */
{   EVP_MD_CTX *mdcontext=(EVP_MD_CTX*)state;
    unsigned umdlen=0;

    *mdval = NULL;
    *mdlen = 0;
    if (mdcontext == NULL) return;

    *mdval = (unsigned char*)malloc((size_t)EVP_MAX_MD_SIZE);
    if (*mdval == NULL) return;

    if (EVP_DigestFinal(mdcontext, *mdval, &umdlen) != 1) {
        free((void*)*mdval);
        *mdval = NULL;
        return;
    }
    *mdlen = (size_t)umdlen;
}


static void kmssl_md_release(void *state)
    /* dispose of a context obtained from kmssl_md_prepare() */
{   EVP_MD_CTX *mdcontext=(EVP_MD_CTX*)state;

    if (mdcontext != NULL) EVP_MD_CTX_destroy(mdcontext);
}


#  ifdef TESTING

static int kmssl_test_hash(void)
    /* check the SHA1 digest routines against a known hash of a short string */
{   void *mdcontext;
    unsigned char *mdval=NULL;
    size_t mdlen, i;
    unsigned q;
    const char *str="random\n";
    const char *hash="5d64b71392b1e00a3ad893db02d381d58262c2d6";

    CM_TEST_START("OpenSSL hashing");

    mdcontext = kmssl_md_prepare();
    CM_ASSERT_DIFFERENT(NULL, mdcontext);
    kmssl_md_block(mdcontext, (unsigned char*)str, strlen(str));
    kmssl_md_final(mdcontext, &mdval, &mdlen);
    CM_ASSERT_DIFFERENT(NULL, mdval);
    CM_ASSERT_EQUAL(strlen(hash)/2, mdlen);
    for (i=0; i<mdlen; ++i) {
        sscanf(hash+2*i, "%2x", &q);
        CM_ASSERT_EQUAL(q, (unsigned)mdval[i]);
    }

    /* hash buffer was malloc()'d by kmssl_md_final() */
    free((void*)mdval);
    kmssl_md_release(mdcontext);

    CM_TEST_OK();
}

static void kmssl_testctxt(cm_testinfo_t *context)
{
    test_ctxtptr = context;
}

static int kmssl_runtests(void)
{
    kmssl_init_algs();
    kmssl_test_hash();
    kmssl_test_getalgos();
    kmssl_free_algs();

    return 0;
}

#  endif    /* TESTING */


/*
 *  Key-manager descriptor for OpenSSL-format keyfiles. Its name "openssl"
 *  is the same string that kmssl_is_compat() compares against the
 *  configured key format, and the symbol name "keymgr_ssl" is what
 *  kmssl_gethandle() asks KM_GETHANDLE for when modules are in use,
 *  so it must not be renamed.
 *  The initializer is positional, so entries must follow keymanager_t's
 *  field order exactly.
 *  NOTE(review): the meaning of the 0 after the name and of the trailing
 *  NULL (possibly a list link) is defined by keymanager_t's declaration,
 *  which is not visible here - confirm there.
 */
keymanager_t keymgr_ssl = {
    "openssl", 0,       kmssl_init_algs, kmssl_free_algs,
                        kmssl_mk_default, kmssl_is_compat, kmssl_needs_pw,
                        kmssl_get_key, kmssl_put_key,
                        kmssl_md_prepare, kmssl_md_block,
                        kmssl_md_final, kmssl_md_release,
    NULL
#ifdef TESTING
    , kmssl_testctxt, kmssl_runtests
#endif
};

#  endif    /* !USE_MODULES || defined(AS_MODULE) */
#endif  /* HAVE_OPENSSL */


#ifndef AS_MODULE

#if defined(TESTING)
#  define MOD_PATH "./cm-ssl.so"
#else
#  define MOD_PATH CM_MODULE_DIR "/cm-ssl.so"
#endif

keymanager_t *kmssl_gethandle(void)
    /* obtain the OpenSSL key-manager:
     *  - built with modules: via KM_GETHANDLE on the module file MOD_PATH;
     *  - built without modules: the statically-linked keymgr_ssl;
     *  - built without OpenSSL: NULL.
     */
{
#if HAVE_OPENSSL
#  if USE_MODULES
    /* NOTE(review): KM_GETHANDLE is assumed to load MOD_PATH, look up the
       "keymgr_ssl" symbol and return from this function - confirm
       against its definition */
    KM_GETHANDLE(MOD_PATH, "keymgr_ssl");
#  else
    return &keymgr_ssl;
#  endif
#else
    return NULL;
#endif
}

#endif  /* !AS_MODULE */

/*
 *  (C)Copyright 2005-2007, RW Penney
 */
