# Schedwi
# Copyright (C) 2011, 2012 Herve Quatremain
#
# This file is part of Schedwi.
#
# Schedwi is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Schedwi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


"""Host functions."""

import socket
import re

import sqlalchemy.orm.session
from sqlalchemy import or_
from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound

from tables.hosts import hosts


# Compiled once at import time; parse_host() is called for every lookup.
_HOST_RE = re.compile(r'''^             # start of string
    (?:([^@]+)@)?                        # 1: optional user info (ignored)
    (?:\[([0-9a-fA-F:.]+)\]              # 2: IPv6 address (may end in IPv4)
      |([^\[\]:]+))                      # 3: IPv4 address or host name
    (?::(\w+))?                          # 4: optional port number or name
    $                                    # end of string
    ''', re.VERBOSE)


def parse_host(name):
    """Parse a hostname.

    @param name:
                hostname to parse. The given name can contain a port (after
                a `:') and an optional `user@' prefix, which is ignored.
                IPv6 addresses must be enclosed in square brackets.
    @return:
                the tuple (hostname, port) or (None, None) in case of error.
                `port' is None when no port is given.  In case of IPv6, the
                address is returned in `hostname' with the square brackets
                stripped.
    """
    try:
        host = _HOST_RE.match(name)
    except TypeError:
        # `name' is not a string (None, a number, bytes on Python 3, ...)
        return (None, None)
    if host is None:
        return (None, None)
    # IPv4 address or host name
    if host.group(3):
        return (host.group(3), host.group(4))
    # IPv6
    return (host.group(2), host.group(4))


def is_known_port(port):
    """Tell if the given port is a valid one (a number or a known name).

    @param port:
            port number or service name.  Any value accepted by int() is
            considered valid; its numeric range is not checked.
    @return:
            True if the given port is valid or None, False otherwise.
    """
    if port is None:
        return True
    try:
        int(port)
    except (TypeError, ValueError):
        # Not a number: it must then be a name from the services database.
        try:
            socket.getservbyname(port)
        except (socket.error, TypeError, ValueError):
            return False
    return True


def _port_alias(port):
    """Return the alternate spelling of a port (number <-> service name).

    @param port:
                port number or service name, as given by the user.
    @return:
                the service name registered for a numeric port, the port
                number registered for a service name, or None if the system
                services database has no entry for it.
    """
    try:
        portnum = int(port)
    except (TypeError, ValueError):
        # Not a number: look it up as a service name.
        try:
            return socket.getservbyname(port)
        except (socket.error, TypeError, ValueError):
            return None
    try:
        return socket.getservbyport(portnum)
    except (socket.error, OverflowError):
        # Unknown port, or outside the 0-65535 range.
        return None


def name2host_list(sql_session, name):
    """Convert a host name pattern to a list of host database objects.

    @param sql_session:
                SQLAlchemy session (it can be an opened session).  When it is
                not a Session, a session is opened from it and closed before
                returning, even if the query fails.
    @param name:
                hostnames to look for. Can contain wildcards `*' and `?' and
                port. IPv6 addresses must be enclosed in square brackets.
    @return:    list of L{tables.hosts.hosts} database objects, ordered by
                hostname.  The list is empty if no host matches.
    """
    host, port = parse_host(name)
    if host is None:
        # Not parsable as host[:port]: use the raw string as the pattern.
        host = name
        port = None
    # Translate shell-style wildcards to SQL LIKE wildcards.
    # NOTE: literal `%' and `_' in the name are not escaped and therefore
    # also act as LIKE wildcards.
    pattern = host.replace("*", "%").replace("?", "_")
    if isinstance(pattern, bytes):
        pattern = pattern.decode('utf-8')
    own_session = not isinstance(sql_session, sqlalchemy.orm.session.Session)
    session = sql_session.open_session() if own_session else sql_session
    try:
        # Match the name itself, or the name followed by a domain part.
        q = session.query(hosts).filter(
            or_(hosts.hostname.like(pattern),
                hosts.hostname.like("%s.%%" % pattern)))
        if port is not None:
            # The port may be stored as a number or as a service name:
            # accept both spellings when the services database knows it.
            alias = _port_alias(port)
            if alias is None:
                q = q.filter(hosts.portnum == port)
            else:
                q = q.filter(or_(hosts.portnum == port,
                                 hosts.portnum == alias))
        return q.order_by(hosts.hostname).all()
    finally:
        if own_session:
            sql_session.close_session(session)


def name2host(sql_session, name):
    """Look up the single host matching a host name.

    @param sql_session:
                SQLAlchemy session (it can be an opened session)
    @param name:
                hostname to look for. Can contain wildcards `*' and `?' and
                port. IPv6 addresses must be enclosed in square brackets.
    @return:    the database L{tables.hosts.hosts} object.
    @raise sqlalchemy.orm.exc.NoResultFound:
                no host matches the given name.
    @raise sqlalchemy.orm.exc.MultipleResultsFound:
                more than one host matches the given name.
    """
    matches = name2host_list(sql_session, name)
    if len(matches) == 1:
        return matches[0]
    # Either nothing or too much matched.
    if matches:
        raise MultipleResultsFound
    raise NoResultFound


def get_host_by_id(sql_session, host_id):
    """Fetch the host whose database ID is `host_id'.

    @param sql_session:
                SQLAlchemy session (it can be an opened session).  When it is
                not a Session, a session is opened from it and always closed
                before returning or propagating an error.
    @param host_id:
                the host ID to look for in the database.
    @return:    the host database object.
    @raise sqlalchemy.orm.exc.NoResultFound:
                the given host ID is not in the database.
    """
    borrowed = isinstance(sql_session, sqlalchemy.orm.session.Session)
    session = sql_session if borrowed else sql_session.open_session()
    try:
        return session.query(hosts).filter(hosts.id == host_id).one()
    finally:
        if not borrowed:
            sql_session.close_session(session)
