# Schedwi
# Copyright (C) 2011-2013 Herve Quatremain
#
# This file is part of Schedwi.
#
# Schedwi is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Schedwi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


"""Module to manage SQLAlchemy sessions."""

from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine


class SqlSession(object):
    """Build a SQLAlchemy engine for the Schedwi database and hand out sessions.

    The constructor builds a list of candidate database URLs: the bare
    dialect URL first, then one URL per known DBAPI driver of that dialect.
    It keeps the engine of the first URL for which ``create_engine``
    succeeds.
    """

    # Known DBAPI drivers per dialect, tried in this order after the bare
    # dialect URL.
    _dialects = {
        'mssql': ['pyodbc', 'mxodbc', 'pymssql', 'zxjdbc', 'adodbapi'],
        'pgsql': ['psycopg2', 'pypostgresql', 'pg8000', 'zxjdbc'],
        'mysql': ['mysqldb', 'oursql', 'pymysql', 'mysqlconnector',
                  'gaerdbms', 'pyodbc', 'zxjdbc'],
        'sqlite': ['pysqlite']}

    # Accepted (lower-case) server driver names -> (URL scheme, _dialects key).
    # Driver names not listed here (and not SQLite) are used verbatim as the
    # URL scheme.
    _aliases = {
        'pgsql': ('postgresql', 'pgsql'),
        'postgresql': ('postgresql', 'pgsql'),
        'mssql': ('mssql', 'mssql'),
        'msql': ('mssql', 'mssql'),
        'freetds': ('mssql', 'mssql'),
        'mysql': ('mysql', 'mysql')}

    def __init__(self,
                 drivername='sqlite3',
                 user='schedwi',
                 password=None,
                 hostname='localhost',
                 dbname='schedwidb',
                 dbdir='.'):
        """Create the engine, trying each candidate URL until one works.

        :param drivername: database type, case-insensitive: 'sqlite3' or
            'sqlite'; 'pgsql' or 'postgresql'; 'mssql', 'msql' or 'freetds';
            'mysql'.  Any other value is used as the URL scheme unchanged.
        :param user: database user name (ignored for SQLite).
        :param password: database password; None or empty means no password.
        :param hostname: database server host name (ignored for SQLite).
        :param dbname: database name, or the database file name for SQLite.
        :param dbdir: directory containing the SQLite database file
            (SQLite only).
        :raises Exception: the exception ``create_engine`` raised for the
            last candidate URL, when none of the candidates could be used.
        """
        urls, connect_args = self._build_urls(drivername.lower(), user,
                                              password, hostname, dbname,
                                              dbdir)
        last_error = None
        for url in urls:
            try:
                self.engine = create_engine(url, connect_args=connect_args)
            except Exception as error:
                # Remember the failure and fall back to the next candidate.
                # ``error`` itself is unbound once this clause ends (Python 3),
                # so it must be copied to be re-raised later.
                last_error = error
            else:
                break
        else:
            # The loop never hit ``break``: no candidate URL was usable.
            raise last_error
        self.Session = sessionmaker(bind=self.engine)

    def _build_urls(self, drivername, user, password, hostname, dbname,
                    dbdir):
        """Return ``(urls, connect_args)`` for the given connection settings.

        *drivername* must already be lower-cased.  ``urls`` is the ordered
        list of candidate URLs to try.  ``connect_args`` is the dict passed
        to ``create_engine(connect_args=...)``.
        """
        if drivername in ('sqlite3', 'sqlite'):
            location = ':///' + dbdir + '/' + dbname
            urls = ['sqlite' + location]
            urls.extend('sqlite+' + dialect + location
                        for dialect in self._dialects['sqlite'])
            # Forwarded to the DBAPI connect(); for the stdlib sqlite3 module
            # this is how many seconds to wait on a locked database.
            return urls, {'timeout': 15}

        if password:
            try:
                from urllib.parse import quote
            except ImportError:  # Python 2
                from urllib import quote
            # Percent-encode the password so characters such as '@', ':',
            # '/' or '%' cannot break URL parsing; SQLAlchemy decodes the
            # password part of the URL.
            credentials = user + ':' + quote(password, safe='') + '@'
        else:
            credentials = user + '@'
        path = credentials + hostname + '/' + dbname

        if drivername in self._aliases:
            scheme, family = self._aliases[drivername]
            urls = [scheme + '://' + path]
            urls.extend(scheme + '+' + dialect + '://' + path
                        for dialect in self._dialects[family])
        else:
            urls = [drivername + '://' + path]
        return urls, {}

    def open_session(self):
        """Return a new session bound to this object's engine."""
        return self.Session()

    def close_session(self, session):
        """Commit the changes made in *session*.

        Only ``session.commit()`` is called; ``session.close()`` is not.
        """
        session.commit()

    def cancel_session(self, session):
        """Roll back the changes made in *session*."""
        session.rollback()
