# Schedwi
# Copyright (C) 2013 Herve Quatremain
#
# This file is part of Schedwi.
#
# Schedwi is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Schedwi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Cluster functions."""

import sqlalchemy.orm.session
from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound

from tables.clusters import clusters
from tables.host_clusters import host_clusters
from tables.hosts import hosts


def name2cluster_list(sql_session, name):
    """Convert a cluster name pattern to a list of cluster database objects.

    @param sql_session:
                SQLAlchemy session (it can be an opened session)
    @param name:
                cluster names to look for. Can contain wildcards `*' and `?'.
                A byte string is decoded as UTF-8; a unicode string is used
                as-is.
    @return:    list of L{tables.clusters.clusters} database objects, sorted
                by name.
    """
    # Translate the shell-style wildcards into their SQL LIKE equivalents.
    # NOTE(review): literal `%' and `_' in the name are not escaped, so they
    # also act as LIKE wildcards.
    pattern = name.replace("*", "%").replace("?", "_")
    if isinstance(pattern, bytes):
        pattern = pattern.decode('utf-8')

    own_session = not isinstance(sql_session, sqlalchemy.orm.session.Session)
    if own_session:
        session = sql_session.open_session()
    else:
        session = sql_session
    try:
        q = session.query(clusters).filter(clusters.name.like(pattern))
        return q.order_by(clusters.name).all()
    finally:
        # Close the session only if it was opened here, even when the
        # query fails.
        if own_session:
            sql_session.close_session(session)


def name2cluster(sql_session, name):
    """Look up the single cluster whose name matches C{name}.

    @param sql_session:
                SQLAlchemy session (it can be an opened session)
    @param name:
                cluster name to look for. Can contain wildcards `*' and `?'.
    @return:    the L{tables.clusters.clusters} database object.
    @raise sqlalchemy.orm.exc.NoResultFound:
                no cluster matches the given name.
    @raise sqlalchemy.orm.exc.MultipleResultsFound:
                more than one cluster matches the given name.
    """
    matches = name2cluster_list(sql_session, name)
    # Exactly one match is required; zero or several matches are errors.
    if len(matches) == 1:
        return matches[0]
    if matches:
        raise MultipleResultsFound
    raise NoResultFound


def num_host_in_cluster(sql_session, cluster_id):
    """Return the number of hosts in the given cluster.

    @param sql_session:
                SQLAlchemy session (it can be an opened session)
    @param cluster_id:
                cluster database ID.
    @return:
                the number of hosts associated with the cluster (0 when no
                host is associated with this ID, including when the cluster
                does not exist).
    """
    own_session = not isinstance(sql_session, sqlalchemy.orm.session.Session)
    if own_session:
        session = sql_session.open_session()
    else:
        session = sql_session
    try:
        q = session.query(host_clusters)
        return q.filter(host_clusters.cluster_id == cluster_id).count()
    finally:
        # Close the session only if it was opened here.
        if own_session:
            sql_session.close_session(session)


def get_hosts_in_cluster(sql_session, cluster_id):
    """Return the hosts in the given cluster.

    @param sql_session:
                SQLAlchemy session (it can be an opened session)
    @param cluster_id:
                cluster database ID.
    @return:
                the list of L{tables.hosts.hosts} objects, sorted by
                hostname (empty if no host belongs to the cluster).
    """
    own_session = not isinstance(sql_session, sqlalchemy.orm.session.Session)
    if own_session:
        session = sql_session.open_session()
    else:
        session = sql_session
    try:
        # Join hosts with the host/cluster association table on host ID.
        q = session.query(hosts).filter(hosts.id == host_clusters.host_id)
        q = q.filter(host_clusters.cluster_id == cluster_id)
        return q.order_by(hosts.hostname).all()
    finally:
        # Close the session only if it was opened here, even when the
        # query fails.
        if own_session:
            sql_session.close_session(session)


def get_cluster_by_id(sql_session, cluster_id):
    """Retrieve a cluster from the database by its ID.

    @param sql_session:
                SQLAlchemy session (it can be an opened session)
    @param cluster_id:
                the cluster ID to look for in the database.
    @return:    the L{tables.clusters.clusters} database object.
    @raise sqlalchemy.orm.exc.NoResultFound:
                the given cluster ID is not in the database.
    """
    own_session = not isinstance(sql_session, sqlalchemy.orm.session.Session)
    if own_session:
        session = sql_session.open_session()
    else:
        session = sql_session
    try:
        # one() raises NoResultFound when the ID is unknown.
        return session.query(clusters).filter(clusters.id == cluster_id).one()
    finally:
        # Close the session only if it was opened here, whether the lookup
        # succeeded or raised.
        if own_session:
            sql_session.close_session(session)
