/*
Copyright (C) 2015 John Tse

This file is part of Libknit.

Libknit is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Libknit is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Libknit.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "_knit.h"

#include <stdlib.h>
#include <string.h>

/*
 * Compute an HMAC of m using the hash selected by knit->hash and the key
 * stored in knit (knit->k / knit->k_length); the result of hmac() is stored
 * in *t.
 *
 * Returns KNIT_ERROR_NO_ALGO (leaving *t untouched) when knit->hash is not
 * MD2, MD4 or MD5, otherwise KNIT_ERROR_OK.
 */
static uint8_t _knit_hmac(KNIT knit, void *m, size_t m_length, uint8_t **t) {
	uint8_t *(*digest_fn)(uint8_t *, uint64_t);
	uint16_t block_len;
	uint16_t digest_len;

	/* Pick the underlying hash together with its block and digest sizes. */
	if (knit->hash == KNIT_MD2) {
		digest_fn  = md2;
		block_len  = KNIT_MD2_BLOCK_SIZE;
		digest_len = KNIT_MD2_DIGEST_SIZE;
	} else if (knit->hash == KNIT_MD4) {
		digest_fn  = md4;
		block_len  = KNIT_MD4_BLOCK_SIZE;
		digest_len = KNIT_MD4_DIGEST_SIZE;
	} else if (knit->hash == KNIT_MD5) {
		digest_fn  = md5;
		block_len  = KNIT_MD5_BLOCK_SIZE;
		digest_len = KNIT_MD5_DIGEST_SIZE;
	} else {
		return KNIT_ERROR_NO_ALGO;
	}

	*t = hmac(digest_fn, block_len, digest_len, knit->k, knit->k_length, m, m_length);

	return KNIT_ERROR_OK;
}

/*
 * Run the block cipher selected by knit->cipher (Blowfish or RC2) over the
 * i_length bytes at i, in the mode selected by knit->mode (CBC, OFB, or ECB
 * for any other value), and store a newly calloc'd result in *o.
 *
 * The input is rounded up to a whole number of blocks. A partial final block
 * is zero-padded in a temporary copy, so the mode helpers are never handed
 * fewer bytes than the n_blocks * block_size they are told to process.
 *
 * op       - KNIT_ENCRYPT encrypts; any other value decrypts. OFB always
 *            runs the cipher's encrypt function.
 * o        - receives the output buffer (n_blocks * block_size bytes), or
 *            NULL if knit->cipher is not a block cipher or no buffer could
 *            be allocated.
 * o_length - if non-NULL, receives the output length, or 0 when *o is NULL.
 */
static void _knit_crypt_block(KNIT knit, uint8_t op, void *i, size_t i_length, uint8_t **o, size_t *o_length) {
	uint32_t block_size;
	uint32_t n_blocks;
	size_t padded_length;
	void *in;
	void (*encrypt)(uint64_t, uint64_t *);
	void (*decrypt)(uint64_t, uint64_t *);

	*o = NULL;
	if (o_length != NULL)
		*o_length = 0;

	switch (knit->cipher) {
		case KNIT_BLOWFISH:
			blowfish_key_expansion(knit->k, knit->k_length);

			block_size = KNIT_BLOWFISH_BLOCK_SIZE;
			encrypt    = blowfish_encrypt;
			decrypt    = blowfish_decrypt;

			break;
		case KNIT_RC2:
			rc2_key_expansion(knit->k, knit->k_length, knit->ekb ? knit->ekb : knit->k_length * 8);

			block_size = KNIT_RC2_BLOCK_SIZE;
			encrypt    = rc2_encrypt;
			decrypt    = rc2_decrypt;

			break;
		default:
			/* Not a block cipher: block_size and the cipher functions
			 * would otherwise be used uninitialized below. */
			return;
	}

	n_blocks      = (i_length + block_size - 1) / block_size;
	padded_length = (size_t)n_blocks * block_size;

	*o = calloc(n_blocks, block_size);
	if (*o == NULL)
		return;

	in = i;
	if (padded_length > i_length) {
		/* The mode helpers are given n_blocks whole blocks, so a partial
		 * final block must come from a zero-padded copy instead of bytes
		 * past the end of i. NOTE(review): this copy holds caller data and
		 * is freed without being wiped. */
		in = calloc(n_blocks, block_size);
		if (in == NULL) {
			free(*o);
			*o = NULL;
			return;
		}
		memcpy(in, i, i_length);
	}

	switch (knit->mode) {
		default:
		case KNIT_ECB:
			ecb(op == KNIT_ENCRYPT ? encrypt : decrypt, block_size, n_blocks, in, *o);
			break;
		case KNIT_CBC:
			if (op == KNIT_ENCRYPT)
				cbc_encrypt(encrypt, block_size, n_blocks, s2i(knit->iv, knit->iv_length), in, *o);
			else
				cbc_decrypt(decrypt, block_size, n_blocks, s2i(knit->iv, knit->iv_length), in, *o);
			break;
		case KNIT_OFB:
			/* OFB only ever runs the forward cipher. */
			ofb(encrypt, block_size, n_blocks, s2i(knit->iv, knit->iv_length), in, *o);
			break;
	}

	if (in != i)
		free(in);

	if (o_length != NULL)
		*o_length = padded_length;
}

/*
 * Validate the key (and the IV, for block modes that use one) against the
 * limits of knit->cipher, then encrypt or decrypt the i_length bytes at i,
 * storing the result in *o and, if o_length is non-NULL, its length.
 *
 * Side effects: knit->k_length is clamped to the cipher's maximum key size,
 * and for CBC/OFB block modes knit->iv_length is clamped to the block size.
 *
 * Returns KNIT_ERROR_OK on success, KNIT_ERROR_NO_KEY if no key is set,
 * KNIT_ERROR_NO_ALGO for an unknown cipher, KNIT_ERROR_MIN_KEY_SIZE if the
 * key is shorter than the cipher's minimum, or KNIT_ERROR_NO_IV if a CBC or
 * OFB block cipher has no IV.
 */
uint8_t _knit_crypt(KNIT knit, uint8_t op, void *i, size_t i_length, uint8_t **o, size_t *o_length) {
	uint16_t block_size = 0; /* only meaningful for block ciphers */
	uint16_t min_key_size;
	uint16_t max_key_size;
	uint8_t cipher;

	if (!knit->k_length)
		return KNIT_ERROR_NO_KEY;

	switch (knit->cipher) {
		case KNIT_BLOWFISH:
			block_size   = KNIT_BLOWFISH_BLOCK_SIZE;
			min_key_size = KNIT_BLOWFISH_MIN_KEY_SIZE;
			max_key_size = KNIT_BLOWFISH_MAX_KEY_SIZE;
			cipher       = _KNIT_BLOCK_CIPHER;

			break;
		case KNIT_RC2:
			block_size   = KNIT_RC2_BLOCK_SIZE;
			min_key_size = KNIT_RC2_MIN_KEY_SIZE;
			max_key_size = KNIT_RC2_MAX_KEY_SIZE;
			cipher       = _KNIT_BLOCK_CIPHER;

			break;
		case KNIT_RC4:
			min_key_size = KNIT_RC4_MIN_KEY_SIZE;
			max_key_size = KNIT_RC4_MAX_KEY_SIZE;
			cipher       = _KNIT_STREAM_CIPHER;

			break;
		default:
			return KNIT_ERROR_NO_ALGO;
	}

	if (knit->k_length < min_key_size)
		return KNIT_ERROR_MIN_KEY_SIZE;
	if (knit->k_length > max_key_size)
		knit->k_length = max_key_size;

	if (cipher == _KNIT_BLOCK_CIPHER) {
		/* CBC and OFB both seed the chain from s2i(knit->iv, knit->iv_length)
		 * in _knit_crypt_block, so both need a present, block-sized IV;
		 * ECB does not read it. */
		if (knit->mode == KNIT_CBC || knit->mode == KNIT_OFB) {
			if (!knit->iv_length)
				return KNIT_ERROR_NO_IV;
			if (knit->iv_length > block_size)
				knit->iv_length = block_size;
		}

		_knit_crypt_block(knit, op, i, i_length, o, o_length);
	} else if (cipher == _KNIT_STREAM_CIPHER) {
		rc4_crypt(knit->k, knit->k_length, i, i_length, o);

		if (o_length != NULL)
			*o_length = i_length;
	}

	return KNIT_ERROR_OK;
}

/*
 * Encrypt the p_length bytes at p with knit's configured cipher, storing
 * the ciphertext in *c and (if non-NULL) its length in *c_length.
 * Returns whatever status _knit_crypt reports.
 */
uint8_t _knit_encrypt(KNIT knit, void *p, size_t p_length, uint8_t **c, size_t *c_length) {
	uint8_t status;

	status = _knit_crypt(knit, KNIT_ENCRYPT, p, p_length, c, c_length);

	return status;
}

/*
 * Decrypt the c_length bytes at c with knit's configured cipher, storing
 * the plaintext in *p and (if non-NULL) its length in *p_length.
 * Returns whatever status _knit_crypt reports.
 */
uint8_t _knit_decrypt(KNIT knit, void *c, size_t c_length, uint8_t **p, size_t *p_length) {
	uint8_t status;

	status = _knit_crypt(knit, KNIT_DECRYPT, c, c_length, p, p_length);

	return status;
}

/*
 * Hash the m_length bytes at m with the algorithm selected by knit->hash
 * and store the digest pointer returned by that algorithm in *d.
 *
 * Returns KNIT_ERROR_NO_ALGO (leaving *d untouched) when knit->hash is not
 * MD2, MD4 or MD5, otherwise KNIT_ERROR_OK.
 */
uint8_t _knit_hash(KNIT knit, void *m, size_t m_length, uint8_t **d) {
	uint8_t *digest;

	switch (knit->hash) {
		case KNIT_MD2:
			digest = md2(m, m_length);
			break;
		case KNIT_MD4:
			digest = md4(m, m_length);
			break;
		case KNIT_MD5:
			digest = md5(m, m_length);
			break;
		default:
			return KNIT_ERROR_NO_ALGO;
	}

	*d = digest;

	return KNIT_ERROR_OK;
}

/*
 * Compute a MAC of the m_length bytes at m using knit's key and the MAC
 * type selected by knit->mac, storing the tag in *t.
 *
 * Returns KNIT_ERROR_NO_KEY if no key is set, KNIT_ERROR_NO_ALGO if knit->mac
 * is not a supported MAC type, otherwise the status of the selected MAC
 * routine (for HMAC, that is _knit_hmac's result).
 */
uint8_t _knit_mac(KNIT knit, void *m, size_t m_length, uint8_t **t) {
	if (!knit->k_length)
		return KNIT_ERROR_NO_KEY;

	switch (knit->mac) {
		case KNIT_HMAC:
			return _knit_hmac(knit, m, m_length, t);
		default:
			/* Report unknown MAC types instead of claiming success while
			 * leaving *t unwritten (matches _knit_hash / _knit_crypt). */
			return KNIT_ERROR_NO_ALGO;
	}
}
