# Schedwi
# Copyright (C) 2011-2013 Herve Quatremain
#
# This file is part of Schedwi.
#
# Schedwi is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Schedwi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


"""SSL functions (abstraction layer)."""

import time
import datetime
import calendar
try:
    from gnutls.crypto import X509Certificate
    import gnutls.errors
except ImportError:
    from M2Crypto import X509
    WITH_GNUTLS = False
else:
    WITH_GNUTLS = True

import babel.dates
import locale_utils

# Values accepted for the type_crt argument of schedwiX509():
#   STRING -- str_cert is the certificate text itself
#   FILE   -- str_cert is the name of a file holding the certificate
STRING = 0
FILE = 1


class schedwiX509Error(Exception):
    """Raised by schedwiX509() when the SSL library cannot parse a certificate.

    The original library exception is passed as the single argument.
    """


if WITH_GNUTLS:
    class schedwiX509:
        """X.509 certificate wrapper backed by python-gnutls.

        Provides the same methods as the M2Crypto-backed implementation
        that is used when python-gnutls cannot be imported.
        """

        def __init__(self, str_cert, type_crt=STRING):
            """Load and parse a certificate.

            Arguments:
            str_cert -- the certificate text when type_crt is STRING,
                        otherwise the name of the file to read it from
            type_crt -- STRING or FILE

            Raise IOError if the file cannot be read, or schedwiX509Error
            if GnuTLS cannot parse the certificate.
            """
            if type_crt == STRING:
                data = str_cert
            else:
                # Close the file as soon as it has been read instead of
                # leaving the handle open until garbage collection.
                with open(str_cert) as f:
                    data = f.read()
            try:
                self.cert = X509Certificate(data)
            except gnutls.errors.GNUTLSError as e:
                raise schedwiX509Error(e)

        def _format_timestamp(self, timestamp):
            """Format a POSIX timestamp as a localized, UTF-8 encoded string.

            The timestamp is converted to the local time zone and rendered
            with babel's "medium" format for the current UI locale.
            """
            dt = datetime.datetime.fromtimestamp(timestamp)
            return babel.dates.format_datetime(
                dt, format="medium",
                locale=locale_utils.get_locale()).encode('utf-8')

        def get_subject(self):
            """Return the certificate subject as provided by python-gnutls."""
            return self.cert.subject

        def get_list_dns(self):
            """Return the DNS names of the subjectAltName extension.

            An empty list is returned when the certificate carries no
            alternative DNS names (same guard as get_dns()).
            """
            alt = self.cert.alternative_names
            if alt and alt.dns:
                return list(alt.dns)
            return []

        def get_dns(self):
            """Return the alternative DNS names as a comma-separated string."""
            return ", ".join(self.get_list_dns())

        def get_not_before(self):
            """Return the activation date as a localized UTF-8 string."""
            return self._format_timestamp(self.cert.activation_time)

        def get_not_after(self):
            """Return the expiration date as a localized UTF-8 string."""
            return self._format_timestamp(self.cert.expiration_time)

        def get_CN(self):
            """Return the common name (CN) of the certificate subject."""
            return self.cert.subject.CN

        def not_yet_active(self):
            """Return True if the certificate activation time is in the future."""
            return time.time() < self.cert.activation_time

        def has_expired(self):
            """Return True if the certificate expiration time has been reached."""
            return time.time() >= self.cert.expiration_time

        def matches_hostname(self, hostname):
            """Test if the provided hostname matches the certificate.

            The hostname is compared, case-insensitively and ignoring
            surrounding whitespace, with the subject CN and with each DNS
            subjectAltName entry.  Names are compared literally (no wildcard
            expansion).  An empty or None hostname always matches.

            Return True on a match, False otherwise.
            """
            if not hostname:
                return True
            wanted = hostname.strip().lower()
            cn = self.get_CN()
            # A certificate may have no CN; only compare when one is present.
            if cn and wanted == cn.strip().lower():
                return True
            return any(wanted == name.strip().lower()
                       for name in self.get_list_dns())
else:
    class schedwiX509:
        """X.509 certificate wrapper backed by M2Crypto.

        Used when python-gnutls cannot be imported; provides the same
        methods as the GnuTLS-backed implementation.
        """

        # English month abbreviations used in the textual certificate times.
        # They are matched by hand because time.strptime's %b directive
        # follows the process LC_TIME locale and would reject these names
        # if a non-English locale has been activated with locale.setlocale().
        _MONTHS = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec')

        def __init__(self, str_cert, type_crt=STRING):
            """Load and parse a certificate.

            Arguments:
            str_cert -- the certificate text when type_crt is STRING,
                        otherwise the name of the file to read it from
            type_crt -- STRING or FILE

            Raise IOError if the file cannot be read, or schedwiX509Error
            if M2Crypto cannot parse the certificate.
            """
            try:
                if type_crt == STRING:
                    self.cert = X509.load_cert_string(str_cert)
                else:
                    self.cert = X509.load_cert(str_cert)
            except X509.X509Error as e:
                raise schedwiX509Error(e)

        def _asn1_to_timestamp(self, asn1_time):
            """Convert an M2Crypto certificate time to a POSIX timestamp.

            The text form is expected to look like "Jan  2 03:04:05 2013 GMT"
            and is interpreted as UTC.  Raise ValueError if it cannot be
            parsed.
            """
            text = str(asn1_time)
            fields = text.split()
            if len(fields) < 4:
                raise ValueError("unexpected certificate time: %r" % text)
            month, day, hms, year = fields[:4]
            hour, minute, second = hms.split(':')
            return calendar.timegm((int(year), self._MONTHS.index(month) + 1,
                                    int(day), int(hour), int(minute),
                                    int(float(second)), 0, 0, 0))

        def _format_timestamp(self, timestamp):
            """Format a POSIX timestamp as a localized, UTF-8 encoded string.

            The timestamp is converted to the local time zone, as the
            GnuTLS-backed implementation does, and rendered with babel's
            "medium" format for the current UI locale.
            """
            dt = datetime.datetime.fromtimestamp(timestamp)
            return babel.dates.format_datetime(
                dt, format="medium",
                locale=locale_utils.get_locale()).encode('utf-8')

        def get_subject(self):
            """Return the certificate subject as a comma-separated string."""
            # OpenSSL's one-line form is "/C=../O=../CN=.."; drop the empty
            # field before the first slash and join the rest with commas.
            return ','.join(str(self.cert.get_subject()).split('/')[1:])

        def get_list_dns(self):
            """Return the names of the subjectAltName extension as a list.

            "DNS" entries are returned, as well as any entry whose type
            contains "address" (such as "IP Address").  An empty list is
            returned when the extension is absent.
            """
            names = []
            try:
                ext = self.cert.get_ext('subjectAltName')
            except LookupError:
                # M2Crypto raises LookupError when the extension is missing
                return names
            if not ext:
                return names
            for entry in ext.get_value().split(', '):
                kind, sep, value = entry.partition(':')
                if not sep:
                    # Malformed entry without "type:value"; skip it
                    continue
                if kind == "DNS" or 'address' in kind.lower():
                    names.append(value)
            return names

        def get_dns(self):
            """Return the alternative names as a comma-separated string."""
            return ", ".join(self.get_list_dns())

        def get_not_before(self):
            """Return the activation date as a localized UTF-8 string."""
            return self._format_timestamp(
                self._asn1_to_timestamp(self.cert.get_not_before()))

        def get_not_after(self):
            """Return the expiration date as a localized UTF-8 string."""
            return self._format_timestamp(
                self._asn1_to_timestamp(self.cert.get_not_after()))

        def get_CN(self):
            """Return the common name (CN) of the certificate subject."""
            return self.cert.get_subject().CN

        def not_yet_active(self):
            """Return True if the certificate activation time is in the future."""
            starts = self._asn1_to_timestamp(self.cert.get_not_before())
            return time.time() < starts

        def has_expired(self):
            """Return True if the current time is past the expiration time."""
            expires = self._asn1_to_timestamp(self.cert.get_not_after())
            return time.time() > expires

        def matches_hostname(self, hostname):
            """Test if the provided hostname matches the certificate.

            The hostname is compared, case-insensitively and ignoring
            surrounding whitespace, with the subject CN and with each name
            returned by get_list_dns().  Names are compared literally (no
            wildcard expansion).  An empty or None hostname always matches.

            Return True on a match, False otherwise.
            """
            if not hostname:
                return True
            wanted = hostname.strip().lower()
            cn = self.get_CN()
            # Guard against a certificate without a CN.
            if cn and wanted == cn.strip().lower():
                return True
            return any(wanted == name.strip().lower()
                       for name in self.get_list_dns())
