"""Module implementing a SQLite based repository for cfvers"""

# Copyright 2003-2005 Iustin Pop
#
# This file is part of cfvers.
#
# cfvers is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# cfvers is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cfvers; if not, write to the Free Software Foundation,
# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

# $Id: r_sqlite.py 218 2005-10-30 09:26:23Z iusty $


from pysqlite2 import dbapi2 as sqlite
#import sqlite
import os.path
import fnmatch

import cfvers.repository.sql
import cfvers
from cfvers.repository.sql import Param
from cfvers.main import ProgrammingException

class RSqlite(cfvers.repository.sql.RSql):
    """SQLite backend for the cfvers SQL repository.

    Extends cfvers.repository.sql.RSql with a pysqlite2 connection,
    qmark-style parameter quoting, buffer-based binary encoding and a
    few SQLite-specific query adaptations.

    """

    def __init__(self, create=False, cnxargs=None, createopts=None):
        """Open the SQLite database and check its schema.

        @param create: when true, call self._create(createopts=...)
            before checking the schema
        @param cnxargs: passed unchanged to sqlite.connect
        @param createopts: forwarded to self._create

        """
        self._quote = self._quote_qmark
        self.conn = sqlite.connect(cnxargs)
        self.backend = "sqlite"
        self.cnxargs = cnxargs
        # Register the SQL helper functions before any SQL runs on this
        # connection, so they are also available to _create/_check_schema.
        self.conn.create_function("basename", 1, self._sql_basename)
        self.conn.create_function("fnmatch", 2, self._sql_fnmatch)
        if create:
            self._create(createopts=createopts)
        cursor = self.conn.cursor()
        try:
            self._check_schema(cursor)
        finally:
            cursor.close()

    def _sql_basename(path):
        """SQL function basename(path): os.path.basename, NULL-safe."""
        if path is None:
            return None
        return os.path.basename(path)

    _sql_basename = staticmethod(_sql_basename)

    def _sql_fnmatch(name, pattern):
        """SQL function fnmatch(name, pattern).

        Returns 1 if name matches the shell-style pattern, 0 if not, and
        NULL if either argument is NULL. fnmatch.filter cannot be used
        here: it expects a list of names and returns a list, which is
        not a value SQLite can handle.

        """
        if name is None or pattern is None:
            return None
        return int(fnmatch.fnmatch(name, pattern))

    _sql_fnmatch = staticmethod(_sql_fnmatch)

    def _build_area(self, cursor, row):
        """Build a cfvers.Area from an areas row.

        @param cursor: a cursor that may be freely reused for the
            item-count and max-revno queries
        @param row: (name, root, ctime, description) from the areas table
        @return: cfvers.Area; revno is None when the area has no revisions

        """
        name = row[0]
        self._exec(cursor, "select count(*) from items where items.area = ", Param(name))
        nitems = int(cursor.fetchone()[0])
        self._exec(cursor, "select max(revno) from revisions where area = ", Param(name))
        revno = cursor.fetchone()[0]
        if revno is not None:
            revno = int(revno)
        return cfvers.Area(name=name, root=row[1], ctime=row[2],
                           numitems=nitems, revno=revno,
                           description=row[3])

    def getAreas(self):
        """Return a list of cfvers.Area objects, one per row in areas."""
        cursor = self.conn.cursor()
        stats = self.conn.cursor()
        try:
            cursor.execute("select name, root, ctime, description from areas")
            # All area rows are fetched up front; a separate cursor runs
            # the per-area statistics queries.
            return [self._build_area(stats, row) for row in cursor.fetchall()]
        finally:
            stats.close()
            cursor.close()

    def getArea(self, name):
        """Return the cfvers.Area called name, or None if it does not exist."""
        cursor = self.conn.cursor()
        try:
            self._exec(cursor, "select name, root, ctime, description from areas where name = ", Param(name))
            row = cursor.fetchone()
            if row is None:
                return None
            # The row has been fetched, so the cursor can be reused.
            return self._build_area(cursor, row)
        finally:
            cursor.close()

    def TimestampFromMx(mxtimestamp):
        """Convert a timestamp to the value stored in SQLite (its str())."""
        return str(mxtimestamp)

    TimestampFromMx = staticmethod(TimestampFromMx)

    def getEntries(self, options):
        """Return the list of entries selected by options.

        The base-class query always runs with options.allentries set to
        True. Unless the caller requested all entries, only the entries
        with the highest revno for each (area, item) pair are kept here.
        NOTE(review): presumably this emulates a "latest only" query the
        base class cannot express on SQLite -- confirm.
        options.allentries is restored afterwards, even on error.

        """
        mall = getattr(options, "allentries", False)
        options.allentries = True
        try:
            # list() consumes the result while allentries is still forced.
            elist = list(super(RSqlite, self).getEntries(options))
        finally:
            options.allentries = mall
        if mall:
            return elist
        # FIXME maybe key is only e.item since item unique over areas
        latest = {}
        for e in elist:
            key = (e.areaname, e.item)
            latest[key] = max(latest.get(key, -1), e.revno)
        return [e for e in elist
                if e.revno == latest[(e.areaname, e.item)]]

    def _encode_binary(value):
        """Wrap a str in a buffer so it is stored as a BLOB.

        @raise ProgrammingException: if value is not a str

        """
        if not isinstance(value, str):
            raise ProgrammingException("Invalid type passed to _encode_binary: %s" % type(value))
        return buffer(value)

    _encode_binary = staticmethod(_encode_binary)

    def _decode_binary(value):
        """Convert a BLOB value read from SQLite back to a str."""
        return str(value)

    _decode_binary = staticmethod(_decode_binary)

    def putRevision(self, r):
        """Adds a Revision to the database and fills its revno

        It is important to note that only this function can return a
        valid new revision number. Until stored in the repository, the
        new revision number cannot be safely determined.

        """
        cursor = self.conn.cursor()
        try:
            # Intended to lock the areas row so the next query computes the
            # revno safely. NOTE(review): SQLite has no row-level locks; the
            # new revno is computed inside the single INSERT statement
            # below -- confirm whether this select is still needed.
            self._exec(cursor, "select * from areas where name = ", Param(r.area))
            self._exec(cursor, "insert into revisions (area, revno, logmsg, ctime, uid, gid, uname, gname, commiter, server) values (", Param(r.area), ", (select coalesce(max(revno), 0)+1 from revisions where area = ", Param(r.area), "), ", Param(r.logmsg), ",", Param(self.TimestampFromMx(r.ctime)), ",", Param(r.uid), ",", Param(r.gid), ",", Param(r.uname), ",", Param(r.gname), ",", Param(r.commiter), ",", Param(r.server), ")")
            # Read back the revno that the subselect assigned to the new row.
            oid = cursor.lastrowid
            self._exec(cursor, "select revno from revisions where oid = ", Param(oid))
            row = cursor.fetchone()
            r.revno = row[0]
        finally:
            cursor.close()
        return r

    def quotestr(value):
        """Return value quoted as an SQLite string literal.

        SQLite follows standard SQL: the only escape inside a '...'
        literal is a doubled single quote. Backslashes are ordinary
        characters and must not be doubled; doubling them would store
        two backslashes for every one in value.

        """
        return "'%s'" % value.replace("'", "''")

    quotestr = staticmethod(quotestr)
