"""Module implementing a SQL based repository for cfvers"""

# Copyright 2003 Iustin Pop
#
# This file is part of cfvers.
#
# cfvers is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# cfvers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cfvers; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import base64
import quopri
import os
import re
import bz2
import zlib

BACKEND_SQLITE   = 1
BACKEND_POSTGRES = 2
BACKEND_GADFLY   = 3

# (backend code, DB-API driver module name) pairs probed at import time.
_BACKENDS = [(BACKEND_SQLITE, "sqlite"),
             (BACKEND_POSTGRES, "pgdb"),
             (BACKEND_GADFLY, "gadfly"),
             ]


def _probe_backends(backends):
    """Import every importable driver module and return its backend codes.

    Each module that imports successfully is bound as a global of this
    module under its own name (e.g. ``sqlite``), which is how the Repo*
    classes reach the driver. Modules raising ImportError are skipped.
    """
    available = []
    for code, modname in backends:
        try:
            # Passing globals() keeps the same lookup context an
            # ``import <modname>`` statement in this module would use.
            module = __import__(modname, globals())
        except ImportError:
            continue
        globals()[modname] = module
        available.append(code)
    return available

# Codes of the backends whose driver module could be imported.
_AVBACK = _probe_backends(_BACKENDS)

__all__ = ["BACKEND_SQLITE", "BACKEND_POSTGRES", "BACKEND_GADFLY",
           "Repository",
           ]

from main import *

class Repository(object):
    """Abstract base class for SQL-backed cfvers repositories.

    Concrete subclasses open a DB-API connection in their constructor and
    store it in ``self.conn``, together with their BACKEND_* code in
    ``self.backend``. All SQL below uses ``%s`` placeholders.

    Apart from _create(), no method here commits; callers use
    commit()/rollback() explicitly.
    """

    def __init__(self, create=False, cnxargs=None):
        """Reject direct instantiation of the abstract base class.

        The arguments are ignored here. Raises TypeError when the object
        being built is a plain Repository (not a subclass).
        """
        if type(self) is Repository:
            raise TypeError("You can't instantiate a Repository!")

    def _parse(cnxargs):
        """Split a ``backend:data`` string into (repository class, data).

        Recognised backends are 'sqlite', 'postgres' and 'gadfly'; data is
        everything after the first colon. Raises ValueError when cnxargs
        is missing, has no ``backend:`` prefix, or names an unknown backend.
        """
        if cnxargs is None:
            # re.match(None) would raise an unhelpful TypeError.
            raise ValueError("Unknown repository path %s" % cnxargs)
        m = re.match('^([^:]+):(.*)', cnxargs)
        if m is None:
            raise ValueError("Unknown repository path %s" % cnxargs)
        bk = m.group(1)
        data = m.group(2)
        if bk == 'sqlite':
            return RepoSqlite, data
        elif bk == 'postgres':
            return RepoPostgres, data
        elif bk == 'gadfly':
            return RepoGadfly, data
        raise ValueError("Unknown repository type %s" % bk)

    _parse = staticmethod(_parse)

    def open(cnxargs=None, create=False):
        """Open (or, with create=True, initialise) the repository at cnxargs.

        cnxargs has the form ``backend:data`` (see _parse); data is passed
        to the matching subclass constructor.
        """
        rclass, rdata = Repository._parse(cnxargs)
        return rclass(cnxargs=rdata, create=create)

    open = staticmethod(open)

    def _create(self, doarea=True):
        """Create the schema, optionally the default area, then commit."""
        cursor = self.conn.cursor()
        self._init_schema(cursor)
        if doarea:
            self._init_data(cursor)
        self.conn.commit()

    def _init_schema(self, cursor):
        """Create the areas, arearevs, items and revisions tables + indexes."""
        cursor.execute("CREATE TABLE areas ( "
                       "id INTEGER PRIMARY KEY, "
                       "server TEXT, "
                       "name TEXT, "
                       "root TEXT, "
                       "revno INTEGER, "
                       "ctime TIMESTAMP, "
                       "description TEXT)")
        cursor.execute("CREATE UNIQUE INDEX areas_sn_idx ON "
                       "areas (server, name)")
        cursor.execute("CREATE TABLE arearevs ("
                       "area INTEGER, "
                       "revno INTEGER, "
                       "logmsg TEXT, "
                       "ctime TIMESTAMP, "
                       "uid INTEGER, "
                       "gid INTEGER, "
                       "commiter TEXT)")
        cursor.execute("CREATE UNIQUE INDEX arearevs_ar_idx ON "
                       "arearevs (area, revno)")
        cursor.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, area "
                       "INTEGER, name TEXT, ctime TIMESTAMP)")
        cursor.execute("CREATE UNIQUE INDEX items_an_idx ON "
                       "items (area, name)")
        cursor.execute("CREATE TABLE revisions (item INTEGER, "
                       "revno INTEGER, filename TEXT, filetype INTEGER, "
                       "filecontents TEXT, mode INTEGER, mtime INTEGER, "
                       "atime INTEGER, uid INTEGER, gid INTEGER, "
                       "rdev INTEGER, encoding TEXT)")
        cursor.execute("CREATE UNIQUE INDEX revisions_ir_idx ON "
                       "revisions (item, revno)")

    def _init_data(self, cursor):
        """Insert the "default" area, rooted at /, for this host.

        The server name is the node name from os.uname(). cursor is unused;
        addArea() obtains its own cursor.
        """
        a = Area(server=os.uname()[1], name="default",
                 description="Default area", root="/")
        self.addArea(a)

    def close(self):
        """Close the underlying connection."""
        self.conn.close()

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.conn.rollback()

    def getItem(self, id):
        """Return the Item with the given id, or None if there is none."""
        cursor = self.conn.cursor()
        cursor.execute("select id, area, name, ctime from items where id = %s", (id,))
        row = cursor.fetchone()
        if row is None:
            return None
        return Item(id=row[0], area=self.getArea(row[1]), name=row[2], ctime=row[3])

    def getItemByName(self, area, name):
        """Return the Item called name in area, or None if there is none.

        area is compared to items.area directly, so it presumably must be
        an area id rather than an Area object -- TODO confirm with callers.
        """
        cursor = self.conn.cursor()
        cursor.execute("select id, area, name, ctime from items where area = %s and name = %s", (area, name))
        row = cursor.fetchone()
        if row is None:
            return None
        return Item(row[0], self.getArea(row[1]), row[2], row[3])

    def addItem(self, item):
        """Insert item (its id is assigned by the database)."""
        self.conn.cursor().execute("insert into items (area, name, ctime) values (%s, %s, %s)",
                                   (item.area.id, item.name, item.ctime))

    def updItem(self, item):
        """Placeholder: updating an item is not implemented (no-op)."""
        pass

    def addEntry(self, entry):
        """Store revision entry; its file contents go through _encode_payload."""
        cursor = self.conn.cursor()
        payload, encoding = self._encode_payload(entry.filecontents)
        cursor.execute("insert into revisions (item, revno, filename, filetype, filecontents, mode, mtime, atime, uid, gid, rdev, encoding) values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                       (entry.item, entry.revno, entry.filename,
                        entry.filetype, payload, entry.mode,
                        entry.mtime, entry.atime, entry.uid, entry.gid,
                        entry.rdev, encoding)
                       )

    def items(self, area=None):
        """Return the Items of area, or of every area when area is None."""
        cursor = self.conn.cursor()
        if area is None:
            # Load every area once and resolve item.area via this map
            # instead of calling getArea() per item.
            mareas = {}
            for a in self.areas():
                mareas[a.id] = a
            cursor.execute("select id, area, name, ctime from items")
            return [Item(row[0], mareas[row[1]], row[2], row[3])
                    for row in cursor.fetchall()]
        cursor.execute("select id, area, name, ctime from items where area = %s", (area.id,))
        return [Item(row[0], area, row[2], row[3]) for row in cursor.fetchall()]

    def areas(self):
        """Return every Area, each with numitems set to its item count.

        Only the Postgres and SQLite backends are handled. Any other backend
        raises ValueError.
        """
        cursor = self.conn.cursor()
        if self.backend == BACKEND_POSTGRES:
            cursor.execute("select id, server, name, root, ctime, revno, (select count(*) from items where items.area = areas.id) as nitems from areas")
            return [Area(id=row[0], server=row[1], name=row[2], root=row[3], ctime=row[4], revno=row[5], numitems=row[6])
                    for row in cursor.fetchall()]
        elif self.backend == BACKEND_SQLITE:
            # One count query per area instead of a correlated subquery.
            cursor.execute("select id, server, name, root, ctime, revno from areas")
            c2 = self.conn.cursor()
            result = []
            for row in cursor.fetchall():
                c2.execute("select count(*) from items where items.area = %s", (row[0],))
                result.append(Area(id=row[0], server=row[1], name=row[2], root=row[3], ctime=row[4], revno=row[5], numitems=int(c2.fetchone()[0])))
            return result
        raise ValueError("Backend not handled")

    def getEntry(self, item, revno):
        """Return the newest revision of item with revision number <= revno.

        When revno is None the newest revision overall is returned. The
        file contents are decoded. Returns None if no revision matches.
        """
        cursor = self.conn.cursor()
        if revno is None:
            cursor.execute("select item, revno, filename, filecontents, filetype, mode, mtime, atime, uid, gid, rdev, encoding from revisions where item = %s order by revno desc limit 1", (item.id,))
        else:
            cursor.execute("select item, revno, filename, filecontents, filetype, mode, mtime, atime, uid, gid, rdev, encoding from revisions where item = %s and revno <= %s order by revno desc limit 1", (item.id, revno))
        row = cursor.fetchone()
        if row is None:
            return None
        rev = RevEntry()
        (rev.item, rev.revno, rev.filename, payload, rev.filetype, rev.mode, rev.mtime, rev.atime, rev.uid, rev.gid, rev.rdev, encoding) = row
        rev.filecontents = self._decode_payload(payload, encoding)
        return rev

    def _encode_payload(self, payload):
        """Prepare file contents for storage; return (payload, encoding).

        bzip2 compression is kept only if it makes the data smaller. If the
        result then contains NUL bytes, it is additionally encoded with
        base64 or quoted-printable, whichever output is shorter (ties go to
        quoted-printable). encoding lists the applied steps outermost
        first, each followed by ':' (e.g. "base64:bzip2:"). It is "" when
        no step was applied.
        """
        encoding = ""
        ndata = bz2.compress(payload)
        if len(ndata) < len(payload):
            payload = ndata
            encoding = "bzip2:%s" % encoding
        if payload.find("\0") != -1:
            # Embedded NULs: presumably unsafe for the TEXT column/driver.
            b64 = base64.encodestring(payload)
            qp = quopri.encodestring(payload, quotetabs=1)
            if len(b64) < len(qp):  # the data is mostly binary
                encoding = "base64:%s" % encoding
                payload = b64
            else:
                encoding = "quoted-printable:%s" % encoding
                payload = qp
        return payload, encoding

    def _decode_payload(self, payload, encoding):
        """Reverse the steps recorded by _encode_payload.

        encoding is a colon-separated list of steps, outermost first. Each
        step is undone in that order and processing stops at the first
        empty step. An empty or NULL (None) encoding returns payload
        unchanged. Raises ValueError for an unknown step.
        """
        if not encoding:
            # Rows stored without any transformation (or with NULL).
            return payload
        for enc in encoding.split(":"):
            if enc == "":
                break
            elif enc == "base64":
                payload = base64.decodestring(payload)
            elif enc == "quoted-printable":
                payload = quopri.decodestring(payload)
            elif enc == "bzip2":
                payload = bz2.decompress(payload)
            elif enc == "gzip":
                payload = zlib.decompress(payload)
            else:
                raise ValueError("Unknown encoding '%s'!" % enc)
        return payload

    def getRevList(self, item):
        """Return RevEntry metadata (no file contents) for item, newest first."""
        cursor = self.conn.cursor()
        cursor.execute("select item, revno, filename, filetype, mode, mtime, atime, uid, gid, rdev from revisions where item = %s order by revno desc", (item.id,))
        revs = []
        for row in cursor.fetchall():
            rev = RevEntry()
            (rev.item, rev.revno, rev.filename, rev.filetype, rev.mode, rev.mtime, rev.atime, rev.uid, rev.gid, rev.rdev) = row
            revs.append(rev)
        return revs

    def getRevNumbers(self, item):
        """Return the revision numbers of item in ascending order."""
        cursor = self.conn.cursor()
        cursor.execute("select revno from revisions where item = %s order by revno", (item.id,))
        return [x[0] for x in cursor.fetchall()]

    def addArea(self, a):
        """Insert area a, including its description (id assigned by the DB)."""
        # description was previously dropped here, so a description given
        # at creation time was never stored even though updArea writes it.
        self.conn.cursor().execute("insert into areas (server, name, root, revno, ctime, description) values (%s, %s, %s, %s, %s, %s)",
                                   (a.server, a.name, a.root, a.revno, a.ctime, a.description))

    def updArea(self, a):
        """Update server, name, description and revno of the area a.id.

        root and ctime are not modified.
        """
        self.conn.cursor().execute("update areas set server = %s, name = %s, description = %s, revno = %s where id = %s",
                                   (a.server, a.name, a.description, a.revno, a.id))

    def getArea(self, id):
        """Return Area id with numitems set to its item count, or None."""
        cursor = self.conn.cursor()
        cursor.execute("select id, server, name, root, ctime, revno, (select count(*) from items where items.area = %s) as nitems from areas where id = %s", (id, id))
        row = cursor.fetchone()
        if row is None:
            return None
        return Area(id=row[0], server=row[1], name=row[2], root=row[3], ctime=row[4], revno=row[5], numitems=row[6])

    def getAreaByName(self, server, name):
        """Return the Area identified by (server, name), or None."""
        cursor = self.conn.cursor()
        cursor.execute("select id from areas where server = %s and name = %s", (server, name))
        row = cursor.fetchone()
        if row is None:
            return None
        return self.getArea(row[0])

    def getAreaRevs(self, id):
        """Return the AreaRev records of area id, newest revision first."""
        cursor = self.conn.cursor()
        cursor.execute("select area, revno, logmsg, ctime, uid, gid, commiter from arearevs where area = %s order by revno desc", (id,))
        ars = []
        for row in cursor.fetchall():
            r = AreaRev()
            (r.area, r.revno, r.logmsg, r.ctime, r.uid, r.gid, r.commiter) = row
            ars.append(r)
        return ars

    def getAreaRevItems(self, ar):
        """Return ids of the items in ar.area that have revision ar.revno."""
        c2 = self.conn.cursor()
        c2.execute("select items.id from items, revisions where items.area = %s and revisions.revno = %s and revisions.item = items.id",
                   (ar.area, ar.revno))
        return [row[0] for row in c2.fetchall()]

    def putAreaRev(self, ar):
        """Insert the area revision record ar."""
        cursor = self.conn.cursor()
        cursor.execute("insert into arearevs (area, revno, logmsg, ctime, uid, gid, commiter) values (%s, %s, %s, %s, %s, %s, %s)",
                       (ar.area, ar.revno, ar.logmsg, ar.ctime, ar.uid, ar.gid, ar.commiter))

    def numAreas(self):
        """Return the number of areas in the repository."""
        cursor = self.conn.cursor()
        cursor.execute("select count(1) from areas")
        row = cursor.fetchone()
        return int(row[0])


class RepoSqlite(Repository):
    """Repository stored in an SQLite database."""

    def __init__(self, create=False, cnxargs=None):
        """Connect to the SQLite database given by cnxargs.

        cnxargs is handed unchanged to sqlite.connect(). With create set,
        the schema and default data are created and committed via
        _create(). Raises ValueError if the sqlite module was not importable.
        """
        if BACKEND_SQLITE not in _AVBACK:
            raise ValueError("SQLite backend not available - check your python installation")
        self.conn = sqlite.connect(cnxargs)
        self.backend, self.cnxargs = BACKEND_SQLITE, cnxargs
        if not create:
            return
        self._create()

    
class RepoGadfly(Repository):
    """Repository stored in a Gadfly database named "gadfly"."""

    def __init__(self, create=False, cnxargs=None):
        """Open, or with create set start up and initialise, the database.

        cnxargs is passed as the second argument to gadfly's
        gadfly()/startup() (presumably the database directory -- confirm
        against the gadfly docs). Raises ValueError if the gadfly module was
        not importable.
        """
        if BACKEND_GADFLY not in _AVBACK:
            raise ValueError("Gadfly backend not available - check your python installation")
        self.backend = BACKEND_GADFLY
        self.cnxargs = cnxargs
        if not create:
            # Attach to an existing database.
            self.conn = gadfly.gadfly("gadfly", self.cnxargs)
            return
        # A new database must be started up before the schema can be built.
        self.conn = gadfly.gadfly()
        self.conn.startup("gadfly", self.cnxargs)
        self._create()


class RepoPostgres(Repository):
    """Repository stored in a PostgreSQL database accessed through pgdb."""

    def __init__(self, create=False, cnxargs=None):
        """Connect with pgdb.connect(cnxargs); build the schema if create.

        Raises ValueError if the pgdb module was not importable.
        """
        if BACKEND_POSTGRES not in _AVBACK:
            raise ValueError("Postgres backend not available - check your python installation")
        self.conn = pgdb.connect(cnxargs)
        self.backend, self.cnxargs = BACKEND_POSTGRES, cnxargs
        if not create:
            return
        self._create()

    def _init_schema(self, cursor):
        """Create the common schema, then add Postgres-specific objects.

        The sequences provide default ids for areas and items. The foreign
        keys tie items, revisions and arearevs to their parent rows.
        """
        Repository._init_schema(self, cursor)
        extra = (
            "CREATE SEQUENCE areas_id_seq",
            "ALTER TABLE areas ALTER COLUMN id SET DEFAULT nextval('areas_id_seq')",
            "CREATE SEQUENCE items_id_seq",
            "ALTER TABLE items ALTER COLUMN id SET DEFAULT nextval('items_id_seq')",
            "ALTER TABLE items ADD constraint items_area_fk FOREIGN KEY (area) REFERENCES areas(id)",
            "ALTER TABLE revisions ADD CONSTRAINT revisions_item_fk FOREIGN KEY (item) REFERENCES items(id)",
            "ALTER TABLE arearevs ADD CONSTRAINT arearevs_area_fk FOREIGN KEY (area) REFERENCES areas(id)",
        )
        for statement in extra:
            cursor.execute(statement)
