#######################################################################
#  This file is part of GNOWSYS: Gnowledge Networking and
#  Organizing System.
#
#  GNOWSYS is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as
#  published by the Free Software Foundation; either version 3 of
#  the License, or (at your option) any later version.
#
#  GNOWSYS is distributed in the hope that it will be useful, but
#  WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public
#  License along with GNOWSYS (gpl.txt); if not, write to the
#  Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
#  Boston, MA  02110-1301  USA
#
######################################################################
# Contributor: "Dinesh Joshi" <dinesh.joshi@yahoo.com>

import sys
import os
import psycopg2
import psycopg2.extensions
import psycopg2.extras

from storageSpec import *
from pgtable import *
from datatypes import *


class tblBase:
    """
    Base class shared by the table access classes of this module.

    Keeps the database cursor and the debug flag, and offers small helpers
    for debug output, primary key retrieval and value lookups.
    """
    def __init__( self, cur, debug_mode ):
        """
        initializes the cursor and debug mode parameters

        cur        -- cursor every query of this object is executed on
        debug_mode -- 0 silences debug_print, any other value enables it
        """
        self.cur = cur
        self.debug_mode = debug_mode

    def debug_print( self, clsname, val=None ):
        """
        prints some debug info on the terminal when debug mode is enabled

        Can be called as debug_print( clsname, val ) or with the message
        as the only argument; in the latter case the class name of the
        object is used as the tag. The single-argument form is what the
        other methods of this class use.
        """
        if val is None:
            # single-argument call: the only argument is the message itself
            clsname, val = self.__class__.__name__, clsname
        if self.debug_mode != 0:
            print( "Debug[%s]: %s " % ( clsname, val ) )

    def set_intersect( self, set1, set2 ):
        """
        calculates the intersection of two sets
        returns the number of common elements as an integer
        """
        return len( set1.intersection( set2 ) )

    def get_pkey( self, tblname, pkname ):
        """
        returns the id of the primary key last generated

        tblname -- table owning the serial column
        pkname  -- name of the serial ( primary key ) column
        """
        # both names are plain text arguments of pg_get_serial_sequence,
        # so they can be handed over as query parameters
        query = "SELECT currval( pg_get_serial_sequence( %s, %s ) );"
        self.cur.execute( query, ( tblname, pkname ) )
        rs = self.cur.fetchall()
        return rs[0][0]

    def pg_sqlize_value( self, fieldDef, value ):
        """
        Requires the fielddef to be from storageSpec class!!

        returns the value as a SQL literal cast to the type found in
        fieldDef[1]; lists are rendered through their Python
        representation prefixed with ARRAY.
        """
        self.debug_print( "Fields Def: %s" % ( fieldDef, ) )
        self.debug_print( "Field val : %s" % ( value, ) )

        # NOTE(review): quotes inside the value are not escaped here
        if isinstance( value, list ):
            selValue = "ARRAY%s" % value
        else:
            selValue = "'%s'" % ( value, )
        return "%s::%s" % ( selValue, fieldDef[1] )

    def does_value_exist( self, tblname, pkcolname, colname, value ):
        """
        checks:
         1. a given table = tblname
         2. with a primary column = pkcolname
         3. in a column = colname
         4. for a value = value

         The value is placed in the query as it is, so it must already be
         a valid SQL expression ( e.g. the result of pg_sqlize_value ).

         if not found, returns 0
         if found, returns the primary key column of the first match

         note: potential problem when the value stored itself is 0 and does exist :P
        """
        query = "SELECT %s FROM %s WHERE %s=%s GROUP BY %s;" % ( pkcolname, tblname, colname, value, pkcolname )
        self.debug_print( "Query: %s" % query )
        self.cur.execute( query )
        res = self.cur.fetchall()
        self.debug_print( "Result: %s" % ( res, ) )

        if not res:
            return 0
        return res[0][0]


class tbl_nodetype( tblBase ):
    """
    This is the class which handles the abstraction to the gbnodetypes table
    """
    def __init__( self, cursor ):
        """
        initializes the cursor ( debug output stays switched off )
        """
        tblBase.__init__( self, cursor, 0 )

    def getntid( self, nodetype ):
        """
        returns the nodetype id of a given nodetype
        returns 0 if the nodetype is not known
        """
        # the name is passed as a query parameter so that the driver takes
        # care of the quoting
        query = "SELECT ntid FROM gbnodetypes WHERE nodename=%s;"
        self.cur.execute( query, ( nodetype, ) )
        rs = self.cur.fetchall()
        if rs:
            return rs[0][0]
        return 0


class tbl_datatypes( tblBase ):
    """
    This is the class which handles the abstraction to the gbdatatypes table
    """
    def __init__( self, cursor ):
        """
        initializes the cursor ( debug output stays switched off )
        """
        tblBase.__init__( self, cursor, 0 )

    def getdtid( self, datatype ):
        """
        returns the datatype id of the given datatype
        returns 0 if the datatype is not known
        """
        # the name is passed as a query parameter so that the driver takes
        # care of the quoting
        query = "SELECT datatypeid FROM gbdatatypes WHERE datatypename=%s;"
        self.cur.execute( query, ( datatype, ) )
        rs = self.cur.fetchall()
        if rs:
            return rs[0][0]
        return 0


class tbl_nidinid( tblBase ):
    """
    This class abstracts the interface to the gbnidinid table
    """
    def __init__( self, cursor ):
        """
        initializes the cursor ( debug output stays switched off )
        """
        tblBase.__init__( self, cursor, 0 )

    def setval( self, nid, nodetype ):
        """
        inserts ( nid, nodetype ) in the gbnidinid table

        returns 0, if the value ( nid, nodetype ) already exists
        returns 0, if the insertion fails ( the error is printed )
        returns the primary key id, if the value is sucessfully inserted
        """
        # the nid text itself lives in the varchar value table; gbnidinid
        # only stores its vid
        vt = tbl_values( self.cur, '', 'varchar' )
        vid = vt.insert( nid )

        nt = tbl_nodetype( self.cur )
        ntid = nt.getntid( nodetype )

        # It does exist! Sorry, no duplicates allowed!
        if self.isvalid_nid( nid, nodetype ) != 0:
            print( "Sorry, ( %s, %s ) already exists!" % ( nid, nodetype ) )
            return 0

        query = "INSERT INTO gbnidinid ( nid, inid, ntid ) VALUES ( %s, DEFAULT, %s )"
        try:
            self.cur.execute( query, ( vid, ntid ) )
            return self.get_pkey( 'gbnidinid', 'inid' )
        except StandardError as err:
            print( "Error:  %s" % ( err, ) )
            # report the failure with the same value as "not inserted"
            # instead of falling through and returning None
            return 0

    def isvalid_nid( self, gbnid, gbnt ):
        """
        function is to check if the given nid,nodetype is valid or not

        returns the number of gbnidinid rows matching ( gbnid, gbnt ),
        i.e. 0 when the combination is unknown
        """
        query = ( "select count( gbnidinid.nid ) from gbnidinid, datatypes_varchar, gbnodetypes "
                  "where gbnidinid.nid = datatypes_varchar.vid "
                  "AND gbnidinid.ntid = gbnodetypes.ntid "
                  "AND datatypes_varchar.value = %s AND gbnodetypes.nodename=%s;" )

        self.debug_print( 'tbl_nidinid', "QUERY : %s with %s" % ( query, ( gbnid, gbnt ) ) )
        self.cur.execute( query, ( gbnid, gbnt ) )
        rs = self.cur.fetchall()

        return rs[0][0]

    def getinid( self, nid, nodetype ):
        """
        get the inid, given the ( nid, nodetype )
        returns 0 if the combination is unknown
        """
        query = ( "select gbnidinid.inid from gbnidinid, datatypes_varchar, gbnodetypes "
                  "where gbnidinid.nid = datatypes_varchar.vid "
                  "AND gbnidinid.ntid = gbnodetypes.ntid "
                  "AND datatypes_varchar.value = %s AND gbnodetypes.nodename=%s;" )

        self.cur.execute( query, ( nid, nodetype ) )
        rs = self.cur.fetchall()

        if not rs:
            return 0
        return rs[0][0]



class tbl_inidssid( tblBase ):
    """
    Access class for gbinidssid table
    """
    def __init__( self, cursor ):
        """
        initializes the cursor ( debug output stays switched off )
        """
        tblBase.__init__( self, cursor, 0 )

    def setval_using_nid( self, nid, nodetype ):
        """
        get a new ssid given the ( nid, nodetype )

        returns the generated ssid
        returns 0 if no inid is registered for ( nid, nodetype )
        """
        t = tbl_nidinid( self.cur )
        inid = t.getinid( nid, nodetype )

        if inid == 0:
            return inid
        return self.setval_using_inid( inid, nodetype )

    def setval_using_inid( self, inid, nodetype ):
        """
        get a new ssid given the ( inid, nodetype )

        returns the generated ssid
        """
        nt = tbl_nodetype( self.cur )
        ntid = nt.getntid( nodetype )

        query = "INSERT INTO gbinidssid ( inid, ssid, ntid ) VALUES( %s, DEFAULT, %s );"
        self.debug_print( 'tbl_inidssid', "%s with %s" % ( query, ( inid, ntid ) ) )
        self.cur.execute( query, ( inid, ntid ) )
        return self.get_pkey( 'gbinidssid', 'ssid' )

    def get_nid_from_ssid( self, ssid ):
        """
        get the nid the given ssid belongs to
        returns 0 if the ssid is unknown
        """
        query = ( "SELECT datatypes_varchar.value AS nid, gbnodetypes.nodename AS nodename, ssid "
                  "FROM gbnidinid, gbinidssid, datatypes_varchar, gbnodetypes "
                  "WHERE gbnidinid.inid = gbinidssid.inid "
                  "AND gbnidinid.nid = datatypes_varchar.vid "
                  "AND gbnidinid.ntid = gbnodetypes.ntid "
                  "AND gbinidssid.ssid = %s;" )

        self.debug_print( 'tbl_inidssid', "%s with %s" % ( query, ( ssid, ) ) )
        self.cur.execute( query, ( ssid, ) )
        res = self.cur.fetchall()
        self.debug_print( 'tbl_inidssid', res )

        # an unknown ssid yields no row at all
        if not res:
            return 0
        return res[0][0]

class tbl_gbusers( tblBase ):
    """
    Access class for gbusers table
    """
    def __init__( self, cursor ):
        """
        initializes the cursor and the node type name handled by this class
        """
        tblBase.__init__( self, cursor, 0 )
        self.nodename = 'gbusers'

    def insert( self, dictParams ):
        """
        inserts a record in the gbusers table

        dictParams -- field values of the new user; the 'nid' entry is
                      the user name. The dictionary is not modified.

        returns 0 if the user could not be registered ( already exists )
        returns the inid, which is the uid of the new user, otherwise
        """
        t = tbl_nidinid( self.cur )
        inid = t.setval( dictParams['nid'], self.nodename )

        # anything but a real id ( 0 or None ) means nothing was inserted
        if not inid:
            print( "User already exists! Sorry!" )
            return 0

        # user doesn't exist, insert it
        # TODO: checks
        # does the uid of the creator exist?

        # create a ssid for the user
        inidtbl = tbl_inidssid( self.cur )
        ssid = inidtbl.setval_using_inid( inid, self.nodename )

        # prepare the final dict on a copy, so that the dictionary of the
        # caller is left untouched
        dictInsert = dict( dictParams )
        dictInsert[ 'inid' ] = inid
        dictInsert[ 'ssid' ] = ssid

        # dummy values for now, must be calculated
        dictInsert['noofchangesaftercommit'] = 0
        dictInsert['noofcommits'] = 0
        dictInsert['noofchanges'] = 0
        dictInsert['history'] = [1,1,1,1]
        dictInsert['changes'] = 0
        dictInsert['fieldschanged'] = [ '','','' ]
        dictInsert['changetype'] = [ 1,1,1 ]

        # do the insertion
        gt = genericTable( self.cur, self.nodename, 1 )
        gt.insert( dictInsert )

        # inid is the uid of the user, so return it :)
        return inid

    def get_uid( self, username ):
        """
        returns the uid of the given username
        returns 0 if the user is unknown
        """
        t = tbl_nidinid( self.cur )
        if t.isvalid_nid( username, self.nodename ) == 0:
            return 0
        return t.getinid( username, self.nodename )
        
    

def test_user_creation( cur ):
    """
    Exercises tbl_gbusers: registers the 'admin' user, then a 'test' user
    created by admin, printing the resulting uids on the way.
    """
    users = tbl_gbusers( cur )
    username = 'test'
    print( "Inserting user name = %s" % ( username ) )

    # the admin account comes first, it acts as creator of the test user
    admin_uid = users.insert( { 'nid':'admin', 'uid':0 } )
    if admin_uid == 0:
        print( "Admin already exists!" )
        admin_uid = users.get_uid( 'admin' )
        if admin_uid == 0:
            print( "Error getting admin uid!" )
        else:
            print( "Admin uid is %s" % admin_uid )

    new_uid = users.insert( { 'nid':username, 'uid':admin_uid } )
    if new_uid == 0:
        print( "User %s already exists! Can't insert!" % username )
        new_uid = users.get_uid( username )

    print( "uid of %s is %s" % ( username, new_uid ) )


def test_nid_table( cur ):
    """
    Exercises tbl_nidinid: inserts a nid, checks the validity of a known
    and of an unknown nid and finally looks the inid up again.
    """
    nid, nodetype = "fan", "gbobjecttypes"
    unknown_nid = 'asdkjsd'

    table = tbl_nidinid( cur )
    print( "Inserting nid = %s in nodetype = %s" % ( nid, nodetype ) )

    generated = table.setval( nid, nodetype )
    print( "Generated inid = %s" % generated )

    known_count = table.isvalid_nid( nid, nodetype )
    print( "testing if nid = %s is valid -> result: %s" % ( nid, known_count ) )
    unknown_count = table.isvalid_nid( unknown_nid, nodetype )
    print( "testing if nid = %s is valid -> result: %s" % ( unknown_nid, unknown_count ) )

    looked_up = table.getinid( nid, nodetype )
    print( "getting inid for nid = %s, nodetype = %s -> result: %s" % ( nid, nodetype, looked_up ) )

def test_inid_table( cur ):
    """
    Exercises tbl_inidssid: generates one ssid starting from the nid and a
    second one starting from the inid of the same node.
    """
    nid, nodetype = "fan", "gbobjecttypes"

    ssid_table = tbl_inidssid( cur )
    print( "Inserting nid = %s in nodetype = %s" % ( nid, nodetype ) )

    first_ssid = ssid_table.setval_using_nid( nid, nodetype )
    print( "Generated ssid = %s" % first_ssid )

    # resolve the inid by hand and ask for another ssid with it
    node_inid = tbl_nidinid( cur ).getinid( nid, nodetype )
    print( "Got inid = %s" % node_inid )

    second_ssid = ssid_table.setval_using_inid( node_inid, nodetype )
    print( "Generated ssid = %s" % second_ssid )


def test_field_table( cur ):
    """
    Exercises tbl_field: stores a two element value for the 'title' field
    of gbmetatypes and prints the fid that was generated.
    """
    title_def = storageSpec().dictTNamesFDefs['gbmetatypes']['title']

    # ids needed for the row of the field table
    node_type_id = tbl_nodetype( cur ).getntid( 'gbmetatypes' )
    data_type_id = tbl_datatypes( cur ).getdtid( 'varchar[]' )

    field_table = tbl_field( cur, title_def )
    record = { 'ntid':node_type_id, 'datatypeid':data_type_id, 'value':[ 'one', 'ONE' ] }
    generated = field_table.insert( record )
    print( "Fid generated was %s" % generated )

def test_value_table( cur ):
    """
    Exercises tbl_values: inserts the same int8 value twice and prints the
    vid reported for each attempt.
    """
    values = tbl_values( cur, '', 'int8' )
    number = 23

    first_vid = values.insert( number )
    print( "vid for %s is %s" % ( number, first_vid ) )

    print( "inserting %s again" % number )
    second_vid = values.insert( number )
    print( "vid for %s is %s" % ( number, second_vid ) )
    

def test_retr_ntid( cur ):
    """
    Prints the nodetype id tbl_nodetype reports for every node table name
    listed below.
    """
    node_names = (
        "gbobjects", "gbrelations", "gbattributes",
        "gbmetatypes", "gbobjecttypes", "gbrelationtypes",
        "gbattributetypes", "gbusertypes", "gbusers",
        )

    lookup = tbl_nodetype( cur )
    for name in node_names:
        print( "ntid for %s is %s" % ( name, lookup.getntid( name ) ) )


def test_retr_dtid( cur ):
    """
    Prints the datatype id tbl_datatypes reports for every datatype name
    listed below.
    """
    # one line per base type, followed by its array variant where one is listed
    type_names = (
        "int8", "int8[]",
        "bit", "bit[]",
        "varbit", "varbit[]",
        "boolean", "boolean[]",
        "box", "box[]",
        "bytea", "bytea[]",
        "varchar", "varchar[]",
        "char", "char[]",
        "cidr", "cidr[]",
        "circle", "circle[]",
        "date", "date[]",
        "float8", "float8[]",
        "inet", "inet[]",
        "int4", "int4[]",
        "interval", "interval[]",
        "line", "line[]",
        "lseg", "lseg[]",
        "macaddr", "macaddr[]",
        "money", "money[]",
        "numeric", "numeric[]",
        "path", "path[]",
        "point", "point[]",
        "polygon", "polygon[]",
        "float4", "float4[]",
        "int2", "int2[]",
        "text", "text[]",
        "time", "time[]",
        "timestamptz", "timestamptz[]",
        "abstime", "abstime[]",
        "aclitem", "aclitem[]",
        "bpchar", "bpchar[]",
        "cid", "cid[]",
        "oid", "oid[]",
        "refcursor", "refcursor[]",
        "regclass", "regclass[]",
        "regoper", "regoper[]",
        "regoperator", "regoperator[]",
        "regproc", "regproc[]",
        "regprocedure", "regprocedure[]",
        "regtype", "regtype[]",
        "reltime", "reltime[]",
        "smgr",
        "tid", "tid[]",
        "timetz", "timetz[]",
        "tinterval", "tinterval[]",
        "unknown",
        "xid", "xid[]",
        "int2vector", "int2vector[]",
        "name", "name[]",
        "oidvector", "oidvector[]",
        "serial", "serial[]",
        "serial8", "serial8[]",
        )

    lookup = tbl_datatypes( cur )
    for type_name in type_names:
        print( "dtid for %s is %s" % ( type_name, lookup.getdtid( type_name ) ) )

class tbl_values:
    """
    this class handles all value tables
    """
    def __init__( self, cur, tblName, datatype ):
        """
        initializes the cursor, tablename and the datatype

        note: the tblName argument is not used, the name of the value
        table is always looked up from the datatype
        """
        tmp = gnowsysDatatypes( cur )
        self.tblName = tmp.getDataTypeTableName( datatype )

        self.cur = cur
        self.datatype = datatype
        self.debug_mode = 1

    def debug_print( self, val ):
        """
        function to print debug info to the terminal
        """
        if self.debug_mode != 0:
            print( "Debug[%s]: %s " % ( self.tblName, val ) )

    def _escape_nested( self, item ):
        """
        doubles the single quotes of a string or of every string found in
        a ( possibly nested ) list; anything else is returned untouched
        """
        if isinstance( item, list ):
            return [ self._escape_nested( e ) for e in item ]
        # handle both types of strings, unicode as well as ascii
        if isinstance( item, str ) or isinstance( item, unicode ):
            return item.replace( "'", "''" )
        return item

    def process_str_list( self, strList ):
        """
        takes a list and returns it as a string of quoted, sanitized
        elements ( see make_array_into_str ); strings inside nested lists
        are sanitized as well

        anything which is not a list is returned as it is
        """
        if not isinstance( strList, list ):
            return strList
        return self.make_array_into_str( self._escape_nested( strList ) )

    def make_array_into_str( self, arr ):
        """
        formats an array so that it can be inserted in postgres

        every element is put in single quotes, nested lists become nested
        ARRAY[...] constructors. No escaping is done here.
        """
        processedElements = []
        for e in arr:
            if isinstance( e, list ):
                processedElements.append( "ARRAY%s" % self.make_array_into_str( e ) )
            else:
                processedElements.append( "'%s'" % ( e, ) )

        return "[%s]" % ", ".join( processedElements )

    def process_value( self, value ):
        """
        sanitize the given value before inserting into the database

        returns the value as a SQL literal cast to the datatype of this
        value table
        """
        if isinstance( value, list ):
            return "ARRAY%s::%s" % ( self.process_str_list( value ), self.datatype )

        return "'%s'::%s" % ( self._escape_nested( value ), self.datatype )

    def get_pkey( self ):
        """
        returns the last generated primary key
        """
        query = "SELECT currval( pg_get_serial_sequence( '%s', 'vid' ) );" % ( self.tblName )
        self.debug_print( "QUERY: %s" % query )
        self.cur.execute( query )
        rs = self.cur.fetchall()
        return rs[0][0]

    def exists( self, value ):
        """
        checks if the given value exists

        returns the vid of the value if it exists
        returns 0 if the value DOESNT exist
        """
        query = "SELECT vid FROM %s WHERE value=%s" % ( self.tblName, self.process_value( value ) )
        self.debug_print( "QUERY: %s" % query )
        self.cur.execute( query )
        rs = self.cur.fetchall()

        if rs:
            return rs[0][0]
        return 0

    def insert( self, value ):
        """
        inserts the given value in the value table

        logic:
        check if value exists
        if not, add and return vid
        else, return vid
        """
        vid = self.exists( value )
        if vid != 0:
            self.debug_print( "Value %s already exists, vid is %s" % ( value, vid ) )
            return vid

        # value doesn't exist, insert it
        self.debug_print( "Value %s doesn't exist, inserting" % ( value, ) )

        query = "INSERT INTO %s ( vid, value ) VALUES ( DEFAULT, %s )" % ( self.tblName, self.process_value( value ) )
        self.debug_print( "INSERT QUERY: %s " % query )
        self.cur.execute( query )

        # get the just inserted value
        vid = self.get_pkey()
        self.debug_print( "Generated vid is : %s" % vid )
        return vid


class tbl_field:
    """
    handles the field tables
    """
    def __init__( self, cur, fldDef ):
        """
        initializes the cursor and field definition

        fldDef[0] gives the suffix of the field table name and fldDef[1]
        the datatype of the stored values
        """
        self.tblName = "field_%s" % fldDef[0]
        self.datatype = fldDef[1]
        self.valuetable = tbl_values( cur, '', self.datatype )
        self.cur = cur
        self.debug_mode = 1

    def debug_print( self, val ):
        """
        function to print debug info to the terminal
        """
        if self.debug_mode != 0:
            print( "Debug[%s]: %s " % ( self.tblName, val ) )

    def get_pkey( self ):
        """
        returns the last generated primary key
        """
        query = "SELECT currval( pg_get_serial_sequence( '%s', 'fid' ) );" % ( self.tblName )
        self.cur.execute( query )
        rs = self.cur.fetchall()
        return rs[0][0]

    def _fid_for_vid( self, vid ):
        """
        returns the fid of the first row referring to the given vid, 0 if there is none
        """
        query = "SELECT fid FROM %s WHERE vid=%s" % ( self.tblName, vid )
        self.debug_print( "QUERY: %s" % query )
        self.cur.execute( query )

        rs = self.cur.fetchall()
        if rs:
            return rs[0][0]
        return 0

    def _insert_row( self, ntid, vid, dtid ):
        """
        adds a ( ntid, vid, datatypeid ) row to the field table and returns its fid
        """
        query = "INSERT INTO %s ( fid, ntid, vid, datatypeid ) VALUES( DEFAULT, '%s', '%s', '%s' )" % ( self.tblName, ntid, vid, dtid )
        self.debug_print( "QUERY: %s" % query )
        self.cur.execute( query )

        # get the just inserted value
        fid = self.get_pkey()
        self.debug_print( "Generated fid is: %s" % fid )
        return fid

    def exists( self, value ):
        """
        checks if the given value exists

        returns the fid of the value if it exists
        returns 0 ( zero ) if the value DOESNT exist
        """
        vid = self.valuetable.exists( value )
        if vid == 0:
            return 0
        return self._fid_for_vid( vid )

    def insert( self, dictValues ):
        """
        inserts a record in the field table. It requires the following POPULATED dictionary:
        dictValues { 'ntid': '', 'datatypeid': '', 'value': '' }

        logic:
        check if value exists
        if not, add and return fid
        else, return existing fid
        """
        value = dictValues[ 'value' ]
        ntid = dictValues[ 'ntid' ]
        dtid = dictValues[ 'datatypeid' ]

        vid = self.valuetable.exists( value )
        if vid == 0:
            # value doesn't exist, insert it; no row is looked up for a
            # value which has just been created
            self.debug_print( "Value %s doesn't exist, inserting" % ( value, ) )
            vid = self.valuetable.insert( value )
            self.debug_print( "Generated vid is: %s " % vid )
            fid = 0
        else:
            # the vid is already known, so the value table needs no second query
            fid = self._fid_for_vid( vid )

        if fid != 0:
            return fid
        return self._insert_row( ntid, vid, dtid )
                
#             self.debug_print( "Value %s already exists, fid is %s" % ( value, vid ) )
#             return vid
    
        

class genericTable:
    """
    This is the mother / father of all classes :) IT can perform insertion in the gnowsys
    tables which is a very complicated task ;)
    """
    def __init__( self, cur, tblName, debugmode=0 ):
        """
        Initializes the following:
        1. cursor
        2. table name
        3. debug mode
        4. GNOWSYS storage information ( pulls it from storageSpec.py )
        5. GNOWSYS datatypes information( pulls it from datatypes.py )
        6. List of table definitions
        7. Dictionary of table definitions
        8. Classifies the field tables and regular tables

        After the classification lstFields holds the names of all fields,
        lstFldTbls / lstFldTblNames the definitions and names of the
        fields kept in field tables, lstRegFlds / lstRegFldNames those of
        the regular fields.
        """
        self.cur = cur
        self.tblName = tblName
        self.debug_mode = debugmode
        self.storageInfo = storageSpec()
        self.datatypeInfo = self.storageInfo.dtTables
        self.tblDef = self.storageInfo.dictTableNamesAndDefs[ self.tblName ]
        self.dictTblDef = self.storageInfo.dictTNamesFDefs[ self.tblName ]    # dict: { fldname: flddef }

        self.lstRegFlds = []
        self.lstFldTbls = []
        self.lstRegFldNames = []
        self.lstFldTblNames = []
        self.lstFields  = []

        for v in self.tblDef:
            # collect every field name ( v[0] ) instead of overwriting the
            # list with the name of the current field
            self.lstFields.append( v[0] )
            # a definition with a non-empty fourth entry is treated as a field table
            if v[3] != "":
                self.lstFldTbls.append( v )
                self.lstFldTblNames.append( v[0] )
                self.debug_print( "Field table   : %s" % ( v, ) )
            else:
                self.lstRegFlds.append( v )
                self.lstRegFldNames.append( v[0] )
                self.debug_print( "Regular field : %s" % ( v, ) )

    def debug_print( self, val ):
        """
        function to print debug info to the terminal
        """

        if self.debug_mode != 0:
            print "Debug[%s]: %s " % ( self.tblName, val )

    def process_str_list( self, strList ):
        """
        takes a list and returns a sanitized list of strings to insert / search in postgres database
        """
        if isinstance( strList, list ):
            i=0
            tmpList = []
            for s in strList:
                if isinstance( s, str ):
                    tmps = s.replace( "'", "''" )
                    print tmps
                    tmpList.append( tmps )
                else:
                    tmpList.append( s )
                i=i+1
            return self.make_array_into_str( tmpList )
        else:
            return strList

    def make_array_into_str( self, arr ):
        """
        takes a sanitized list of strings and returns a string which is sanitized to insert / search in postgres database
        """
        myStr = "[%s]"
        insideStr = ", " . join( [ "'%s'" % e for e in arr ] )
        return myStr % insideStr
        

    def process_value( self, value, datatype ):
        """
        takes in a value, datatype and processes so that it can be inserted in postgres
        returns a string
        """
        if isinstance( value, list ):
            return "ARRAY%s::%s" % ( self.process_str_list( value ), datatype )
        else:
            if isinstance( value, str ):
                value = value.replace( "'", "''" )

            return "'%s'::%s" % ( value, datatype )

    def make_string_safe( self, value ):
        """
        sanitize the given string by doubling its single quotes

        works for ascii as well as unicode strings; any other value is
        returned as it is
        """
        if isinstance( value, str ) or isinstance( value, unicode ):
            return value.replace( "'", "''" )

        return value
        
    def insert( self, dictFlds ):
        """
        makes a insert in a gnowsys table such as gbobjects

        dictFlds -- { field name: value }; names which are not part of the
                    table definition are ignored

        logic:
        1. classifies the passed dictionary fields into
            a) regular fields
            b) field table fields
        2. for each passed regular field, process it so that the
           value is sanitized and ready for insertion in postgres
        3. for each passed field table fields, use tbl_field to insert
           the value in the field table and collect its generated fid
        4. prepare a SQL query to make record in the actual snapshot table
           using the fids generated and the sanitized values
        """
        nt = tbl_nodetype( self.cur )
        self.ntid = nt.getntid( self.tblName )

        setPassedFlds = set( dictFlds.keys() )
        lstPassedRegFlds = list( set( self.lstRegFldNames ).intersection( setPassedFlds ) )
        lstPassedFldTbls = list( set( self.lstFldTblNames ).intersection( setPassedFlds ) )

        self.debug_print( "Passed Regular Fields: %s" % ( lstPassedRegFlds, ) )
        self.debug_print( "Passed Table Fields: %s" % ( lstPassedFldTbls, ) )

        # values of the regular fields go into the query as sanitized literals
        lstVals = [ self.process_value( dictFlds[ f ], self.dictTblDef[ f ][1] ) for f in lstPassedRegFlds ]

        self.debug_print( self.tblDef )

        # values of field table fields are stored in their field table, the
        # snapshot table only gets the generated fid
        dtt = tbl_datatypes( self.cur )
        for f in lstPassedFldTbls:
            fldDef = self.dictTblDef[ f ]
            ft = tbl_field( self.cur, fldDef )
            dtid = dtt.getdtid( fldDef[1] )
            fid = ft.insert( { 'ntid' : self.ntid, 'datatypeid' : dtid, 'value' : dictFlds[ f ] } )
            lstVals.append( fid )

        # joining one combined list keeps the column list free of a leading
        # comma when only field table fields were passed
        strFlds = ", ".join( lstPassedRegFlds + lstPassedFldTbls )
        strVals = ", ".join( [ "%s" % ( v, ) for v in lstVals ] )
        self.debug_print( strFlds )
        self.debug_print( strVals )

        strQuery = "INSERT INTO %s ( %s ) VALUES ( %s );" % ( self.tblName, strFlds, strVals )
        self.debug_print( strQuery )
        self.cur.execute( strQuery )

    def getAllBySSIDCols( self, viewName, colNames, lstSSID, nodeType ):
        """
        Get all SSIDs but only the columns in the list provided
        """
        lstColNames = colNames
        lstColNames.append( 'ssid' )

        strSelCols = ", " . join( lstColNames )
        strLstSSID = ", " . join( "%s" % s for s in lstSSID )

        selectQuery = 'SELECT %s FROM %s WHERE ssid in ( %s ) ORDER BY ssid DESC;' % ( strSelCols, viewName, strLstSSID )
        print selectQuery

        self.cur.execute( selectQuery )
        res = self.cur.fetchall()

        resDict = {}
        for r in res:
            i=0
            tmpDict = {}
            for c in r:
                tmpDict[ lstColNames[i] ] = c 
                i=i+1

            resDict[ tmpDict['ssid'] ] = tmpDict

        return resDict

    def getAllIdsFromTableCols( self, idCol, tblName, colNames, lstids ):
        """
        Get all specified ids from specified table and only the columns in the list provided
        """
        lstColNames = colNames
        lstColNames.append( idCol )

        strSelCols = ", " . join( lstColNames )
        strLstID = ", " . join( "%s" % s for s in lstids )

        selectQuery = 'SELECT %s FROM %s WHERE %s in ( %s ) ORDER BY %s DESC;' % ( strSelCols, tblName, idCol, strLstID, idCol )
        print selectQuery

        self.cur.execute( selectQuery )
        res = self.cur.fetchall()

        resDict = {}
        for r in res:
            i=0
            tmpDict = {}
            for c in r:
                tmpDict[ lstColNames[i] ] = c 
                i=i+1

            resDict[ tmpDict[ idCol ] ] = tmpDict

        return resDict

    def getAllBySSID( self, viewName, lstSSID, nodeType ):
        """
        Get all entries whose SSIDs are in lstSSID

        viewName -- view (or table) to read; its column names are looked up
                    in information_schema.columns
        lstSSID  -- iterable of ssids, each rendered with "%s" into the IN list
        nodeType -- accepted for the callers' sake; not used by this method

        returns a dictionary of dictionaries something like this:
        { 
           ssid1 : { col1: 'value1', col2: 'value2' },
           ssid2 : { col1: 'value1', col2: 'value2' },
           ssid3 : { col1: 'value1', col2: 'value2' },
        }
        An empty lstSSID gives an empty dictionary and no query is run.
        """
        lstSSIDVals = [ "%s" % s for s in lstSSID ]
        if not lstSSIDVals:
            # "WHERE ssid in (  )" is not valid SQL, and nothing could match anyway
            return {}

        query = "SELECT column_name FROM information_schema.columns WHERE table_name='%s';" % ( viewName )
        self.cur.execute( query )

        # each catalog row is a one-element sequence holding the column name
        lstColNames = [ fld[0] for fld in self.cur.fetchall() ]

        strSelCols = ", " . join( lstColNames )
        strLstSSID = ", " . join( lstSSIDVals )

        selectQuery = 'SELECT %s FROM %s WHERE ssid in ( %s ) ORDER BY ssid DESC;' % ( strSelCols, viewName, strLstSSID )
        self.cur.execute( selectQuery )

        resDict = {}
        for row in self.cur.fetchall():
            # columns were selected explicitly, so values line up with lstColNames
            tmpDict = dict( zip( lstColNames, row ) )
            resDict[ tmpDict['ssid'] ] = tmpDict

        return resDict

    def getLatestSSIDFromNid( self, lstNids ):
        """
        Get the highest ssid stored in view_nidinidssid for each given nid

        lstNids -- iterable of nids; every nid is passed through
                   self.make_string_safe before it is quoted into the query

        returns a dictionary mapping each nid found to its MAX( ssid ),
        something like this:
        { 
           nid1 : ssid1,
           nid2 : ssid2,
           nid3 : ssid3,
        }
        nids for which the query returns no row are simply absent.
        An empty lstNids gives an empty dictionary and no query is run.
        """
        lstQuoted = [ "'%s'" % self.make_string_safe( n ) for n in lstNids ]
        if not lstQuoted:
            # "WHERE nid IN (  )" is not valid SQL, and nothing could match anyway
            return {}

        strNids = ", " . join( lstQuoted )
        query = "SELECT MAX( ssid ) AS ssid, nid FROM view_nidinidssid WHERE nid IN ( %s ) GROUP BY nid;" % strNids
        self.cur.execute( query )

        # every result row is ( ssid, nid ); key the result by nid
        resDict = dict( ( row[1], row[0] ) for row in self.cur.fetchall() )

        print( "nid to ssid: %s" % resDict )
        return resDict

    def get_pkey( self, pkcolname ):
        """
        Ask the database for currval() of the serial sequence attached to
        column pkcolname of this table (self.tblName)

        returns the first column of the first row of the answer
        """
        self.cur.execute(
            "SELECT currval( pg_get_serial_sequence( '%s', '%s' ) );" % ( self.tblName, pkcolname )
        )
        firstRow = self.cur.fetchall()[0]
        return firstRow[0]

    def getViewQuery( self, viewname ):
        """
        Returns a SQL query to generate a view

        viewname -- name of the view to create or replace

        The select list is made up, in this order, of:
          * nid and login, resolved through gbnidinid / datatypes_varchar
          * every column in self.lstRegFldNames, qualified with self.tblName
          * every column in self.lstFldTblNames, aliased <name>_fid
          * for every column f in self.lstFldTblNames, value_<f>.value
            aliased to the first element of the field definition of f

        Every column f in self.lstFldTblNames also adds two LEFT OUTER
        JOINs: one on field_<f>, and one on the datatype table named by
        self.datatypeInfo for the second element of the field definition,
        aliased value_<f>.

        Either list may be empty; the pieces are joined so that the select
        list never ends up with a dangling or doubled comma.

        returns the CREATE OR REPLACE VIEW statement as a string; nothing
        is executed here
        """
        tbl = self.tblName

        lstSelect = [ 'nidvals.value AS nid', 'loginidvals.value AS login' ]
        lstSelect.extend( [ "%s.%s" % ( tbl, f ) for f in self.lstRegFldNames ] )
        lstSelect.extend( [ "%s.%s AS %s_fid" % ( tbl, f, f ) for f in self.lstFldTblNames ] )

        strExtraFlds = "LEFT OUTER JOIN gbnidinid ON ( %s.inid = gbnidinid.inid ) LEFT OUTER JOIN datatypes_varchar nidvals ON ( nidvals.vid = gbnidinid.nid ) LEFT OUTER JOIN gbnidinid value_loginid ON ( %s.uid = value_loginid.inid ) LEFT OUTER JOIN datatypes_varchar loginidvals ON ( loginidvals.vid = value_loginid.nid )" % ( tbl, tbl )

        lstClause = [ tbl + " " + strExtraFlds ]
        lstValueCols = []

        for f in self.lstFldTblNames:
            fldTblName = "field_" + f
            dttNameAlias = "value_%s" % f

            fldDef = self.storageInfo.getFieldDef( tbl, f )
            dttName = self.datatypeInfo.getDataTypeTableName( fldDef[1] )
            lstValueCols.append( "%s.value AS %s" % ( dttNameAlias, fldDef[0] ) )

            # main table -> field table on fid, field table -> value table on vid
            tmpONClauseFLD = "( %s.fid = %s.%s )" % ( fldTblName, tbl, f )
            tmpONClauseVT = "( %s.vid = %s.vid )" % ( dttNameAlias, fldTblName )

            lstClause.append( " LEFT OUTER JOIN %s ON %s " % ( fldTblName, tmpONClauseFLD ) )
            lstClause.append( " LEFT OUTER JOIN %s %s ON %s " % ( dttName, dttNameAlias, tmpONClauseVT ) )

        # the value columns go after all the *_fid columns
        lstSelect.extend( lstValueCols )

        mainquery = "SELECT %s FROM %s;" % ( ", " . join( lstSelect ), "" . join( lstClause ) )
        viewquery = "CREATE OR REPLACE VIEW %s AS %s" % ( viewname, mainquery )

        return viewquery


    def get( self ):
        """
        Gets all values from the snapshot table independent of the view

        The table self.tblName is read directly: the columns in
        self.lstRegFldNames are selected as they are, and every column f in
        self.lstFldTblNames is resolved to value_<f>.value through two
        LEFT OUTER JOINs (field_<f>, then the datatype table of the field).
        No nid / login columns are joined in.

        returns a list with one dictionary per row; the keys are the
        selected expressions, i.e. the regular column names followed by
        'value_<f>.value' for every column f in self.lstFldTblNames
        """
        # copy, do not alias: the list is extended below and
        # self.lstRegFldNames must stay as it is for later calls
        joinLst = list( self.lstRegFldNames )
        clause = self.tblName

        for f in self.lstFldTblNames:
            fldTblName = "field_" + f
            dttNameAlias = "value_%s" % f

            fldDef = self.storageInfo.getFieldDef( self.tblName, f )
            dttName = self.datatypeInfo.getDataTypeTableName( fldDef[1] )
            joinLst.append( dttNameAlias + ".value" )

            # main table -> field table on fid, field table -> value table on vid
            tmpONClauseFLD = "( %s.fid = %s.%s )" % ( fldTblName, self.tblName, f )
            tmpONClauseVT = "( %s.vid = %s.vid )" % ( dttNameAlias, fldTblName )

            clause = clause + " LEFT OUTER JOIN %s ON %s " % ( fldTblName, tmpONClauseFLD )
            clause = clause + " LEFT OUTER JOIN %s %s ON %s " % ( dttName, dttNameAlias, tmpONClauseVT )

        selFlds = ", " . join( joinLst )
        mainquery = "SELECT %s FROM %s;" % ( selFlds, clause )

        self.cur.execute( mainquery )

        # row values arrive in the same order as joinLst
        return [ dict( zip( joinLst, row ) ) for row in self.cur.fetchall() ]

    def getFromView( self, viewname, ssid=0, nid='' ):
        """
        Gets records from the view

        viewname -- view to read; its column names are looked up in
                    information_schema.columns
        ssid     -- when not 0, only rows with this ssid are returned
        nid      -- when not '', only rows with this nid are returned; the
                    value is passed through self.make_string_safe before it
                    is quoted into the query
        When both filters are given they are combined with AND.

        Returns a list of dictionaries, one per row and ordered by ssid
        descending, like the following:

        [
           { col1: 'value1', col2: 'value2' },
           { col1: 'value1', col2: 'value2' },
           { col1: 'value1', col2: 'value2' },
        ]
        """
        lstClause = []
        if ssid != 0:
            lstClause.append( ( 'ssid=%s' % ssid ) )

        if nid != '':
            # escape the same way getLatestSSIDFromNid does, so that a quote
            # inside nid cannot break out of the string literal
            lstClause.append( ( "nid='%s'" % self.make_string_safe( nid ) ) )

        strClause = " AND " . join( lstClause )

        print( strClause )

        query = "SELECT column_name FROM information_schema.columns WHERE table_name='%s';" % ( viewname )
        print( query )

        self.cur.execute( query )

        # each catalog row is a one-element sequence holding the column name
        lstColNames = [ fld[0] for fld in self.cur.fetchall() ]

        strSelCols = ", " . join( lstColNames )

        if strClause != '':
            strClause = "WHERE " + strClause

        selectQuery = 'SELECT %s FROM %s %s ORDER BY ssid DESC;' % ( strSelCols, viewname, strClause )
        print( selectQuery )

        self.cur.execute( selectQuery )

        # columns were selected explicitly, so values line up with lstColNames
        return [ dict( zip( lstColNames, row ) ) for row in self.cur.fetchall() ]

    def showRS2( self, res ):
        """
        shows a result set which contains dictionaries
        """
        for r in res:
            for k,v in r.items():
                print "%s : %s" % ( k, v )
            print

    def showRS( self, res ):
        """
        shows a result set which contains only lists

        res -- iterable of rows, each row being an iterable of values

        prints the values of every row on one line
        """
        for r in res:
            for c in r:
                # trailing comma: no newline after the value, so the row stays on one line
                print c,
            # finish the line of the current row
            print


def createViews( cur ):
    """
    builds the SQL statements that create the views

    For every table listed below two statements are produced through
    genericTable.getViewQuery, one for view_<suffix> and one for
    djview_<suffix>; the statement for view_nidinidssid is added last.

    returns the list of statements; none of them is executed here
    """
    # table name -> suffix used in the names of its views
    suffixByTable = { 
        'gbattributes':'a', 
        'gbattributetypes':'at', 
        'gbobjects':'o', 
        'gbobjecttypes':'ot', 
        'gbrelations':'r', 
        'gbrelationtypes':'rt',
        'gbusertypes':'ut',
        'gbusers':'u',
        'gbmetatypes':'mt',
#         'gbnidinid':'ni',
#         'gbinidssid':'is'
        }

    nidinidssidview = """CREATE OR REPLACE VIEW view_nidinidssid AS SELECT dtnid.value AS nid, gbnidinid.ntid, gbnodetypes.nodename AS nodename, gbnidinid.inid, gbinidssid.ssid FROM gbnidinid LEFT OUTER JOIN gbinidssid ON gbnidinid.inid = gbinidssid.inid LEFT OUTER JOIN datatypes_varchar AS dtnid ON gbnidinid.nid = dtnid.vid LEFT OUTER JOIN gbnodetypes ON gbnidinid.ntid = gbnodetypes.ntid;"""

    lstQueries = []
    for tblName, suffix in suffixByTable.items():
        tbl = genericTable( cur, tblName, 0 )
        # the same definition is emitted under both view name prefixes
        for nameFormat in ( "view_%s", "djview_%s" ):
            lstQueries.append( tbl.getViewQuery( nameFormat % suffix ) )

    lstQueries.append( nidinidssidview )
    return lstQueries


def test_nested_array_insertion( cur ):
    """
    inserts the same nested array twice into an 'int8[]' tbl_values table
    and prints the vid returned by each insert
    """
    valueTbl = tbl_values( cur, '', 'int8[]' )
    nested = [ [1,2], [2,3], [3,5], [7,8] ]

    for attempt in ( 1, 2 ):
        if attempt == 2:
            print( "inserting %s again" % nested )
        vid = valueTbl.insert( nested )
        print( "vid for %s is %s" % ( nested, vid ) )
   

def test_new_functions( cur ):
    """
    calls the newer lookup methods of genericTable with fixed sample
    arguments and prints whatever each of them returns
    """
    objTbl = genericTable( cur, 'gbobjects', 1 )

    # latest ssid for a few nids
    print( objTbl.getLatestSSIDFromNid( [ 'admin', 'Bookshelf', 'CourseMaterial' ] ) )

    # all columns of the view for a few ssids
    sampleSSIDs = [ 307, 308, 309 ]
    print( objTbl.getAllBySSID( 'djview_o', sampleSSIDs, 'gbobjects' ) )

    # selected columns of a plain table for a few ids
    print( objTbl.getAllIdsFromTableCols( 'vid', 'datatypes_varchar', [ 'value' ], [ 3, 4, 5, 6 ] ) )

    # selected columns of the view for the same ssids as above
    print( objTbl.getAllBySSIDCols( 'djview_o', [ 'title', 'description' ], sampleSSIDs, 'gbobjects' ) )




# Manual test driver: connects to a PostgreSQL database and exercises the
# table classes of this module against it.
if __name__ == "__main__":
    # NOTE(review): the connection settings, password included, are hard-coded here
    dictConn = { 
        'dbname':'self', 
        'username':'akula', 
        'password':'akula', 
        'host':'localhost'
        }

    conn = psycopg2.connect( "dbname=%(dbname)s user=%(username)s password=%(password)s host=%(host)s" % dictConn )
    cur = conn.cursor( cursor_factory = psycopg2.extras.DictCursor )

    test_new_functions( cur )
    # NOTE: the script ends at this sys.exit(); everything below it is never
    # executed as the code stands, and on this path the connection is neither
    # committed nor explicitly closed
    sys.exit()

    test_nested_array_insertion( cur )

    nid = "fan"
    nt = "gbobjecttypes"

    nidtbl = tbl_nidinid( cur )
    inid = nidtbl.setval( nid, nt )
    # presumably setval returns 0 when nothing new was stored, so the
    # existing inid is looked up instead -- verify against tbl_nidinid.setval
    if inid == 0:
        inid = nidtbl.getinid( nid, nt )

    inidtbl = tbl_inidssid( cur )
    ssid = inidtbl.setval_using_nid( nid, nt )

    # sample record passed to genericTable.insert below; several string
    # values carry embedded single quotes, presumably to exercise quote
    # escaping -- TODO confirm
    testgbmetatypes = {
            'status':"'st'atu's'text13'",
            'content':"t''''''ext''sd'asd'asd'11",
            'inid': inid, 
            'subtypes':[15, 17, 12],
            'ssid': ssid, 
            'noofcommits':'12',
            'subtypeof':[13, 16, 16],
            'changetype':[1,1,1,1],
            'title':["'te'xt11", "'te'xt118"],
            'uri':'text13',
            'relations':[19, 17, 18],
            'noofchangesaftercommit':'10',
            'instances':[18, 19, 11],
            'noofchanges':'17',
            'description':'text15',
            'attributes':[19, 11, 12],
            'relationtypes':[13, 16, 17],
            'history':[16, 10, 13],
            'fieldschanged':["te'xt1'114", "t'e'x't'14"],
            'attributetypes':[15, 20, 14],
            'uid':'10',
            'structure': [ [ 1,2 ], [3,4 ], [5,6], [7,8] ]
    }


    gt = genericTable( cur, nt, 1 )
    gt.insert( testgbmetatypes )
    
    # make the inserts permanent, then release the connection
    conn.commit()
    conn.close()


