#!/usr/bin/python
# vim: tabstop=8 expandtab shiftwidth=8 softtabstop=8

import sys
import httplib2
import urlparse
import StringIO
import bz2
import gzip
import hashlib
import subprocess
import tempfile
import os
import time
import re
import xml.etree.ElementTree as ET
import signal
import ConfigParser
#import sqlite3
import psycopg2

# Module-wide switches read by the functions below.
debug = False    # true: VersionComparator.compare_part prints each comparison step
verbose = True   # true: package imports/updates are reported on stdout
config = None    # not assigned or read in this part of the file

class VersionComparator:
        """Compare Debian-style version strings.

        Versions have the form ``[epoch:]upstream_version[-debian_revision]``.
        ``compare(a, b)`` returns -1, 0 or 1 (like Python 2's ``cmp``).
        Upstream and revision parts are compared by alternating non-digit
        runs (character-wise, see char_to_ord) and digit runs (numerically).
        """

        re_number = re.compile(r"(\d+)")
        re_notnumber = re.compile(r"(\D+)")
        re_epoch = re.compile(r"(\d+):(.+)")
        # greedy first group: the revision is whatever follows the LAST '-'
        re_parts = re.compile(r"(.+)-(.+)")

        def _leading(self, pattern, s, default):
                """Return the run matched by ``pattern`` at the start of ``s``, or ``default``."""
                match = pattern.match(s)
                if match:
                        return match.group(1)
                return default

        def compare_part(self, a, b):
                """Compare two upstream-version or revision strings.

                Repeatedly strips a leading non-digit run and then a leading
                digit run from both strings, comparing each pair, until a
                difference is found or both strings are used up.  A missing
                digit run counts as 0; ``None`` is treated like ``''``.
                """
                a = a or ''
                b = b or ''
                while a or b:
                        nondigits_a = self._leading(self.re_notnumber, a, "")
                        nondigits_b = self._leading(self.re_notnumber, b, "")
                        a = a[len(nondigits_a):]
                        b = b[len(nondigits_b):]
                        comp = self.compare_nondigits(nondigits_a, nondigits_b)
                        if debug:
                                print("compare '%s' to '%s': %d" % (nondigits_a, nondigits_b, comp))
                                print("rest: %s %s" % (a, b))
                        if comp != 0:
                                return comp

                        digits_a = self._leading(self.re_number, a, "0")
                        digits_b = self._leading(self.re_number, b, "0")
                        a = a[len(digits_a):]
                        b = b[len(digits_b):]
                        x, y = int(digits_a), int(digits_b)
                        comp = (x > y) - (x < y)
                        if debug:
                                print("compare %s to %s: %d" % (digits_a, digits_b, comp))
                                print("rest: %s %s" % (a, b))
                        if comp != 0:
                                return comp
                return 0

        def compare_nondigits(self, a, b):
                """Compare two non-digit runs character by character (see char_to_ord).

                Positions past the end of the shorter run compare as the
                empty string.
                """
                len_a, len_b = len(a), len(b)
                for i in range(max(len_a, len_b)):
                        char_a = a[i] if i < len_a else None
                        char_b = b[i] if i < len_b else None
                        r = self.compare_character(char_a, char_b)
                        if r != 0:
                                return r
                return (len_a > len_b) - (len_a < len_b)

        def char_to_ord(self, a):
                """Return the sort weight of one character.

                '~' sorts before everything (-1), the empty string/None next
                (0), then letters (their ordinal), then all other characters
                (ordinal + 256).
                """
                if a == '~':
                        return -1
                if not a:
                        return 0
                if not a.isalpha():
                        return ord(a) + 256
                return ord(a)

        def compare_character(self, a, b):
                """Return -1, 0 or 1 comparing the sort weights of two characters."""
                x = self.char_to_ord(a)
                y = self.char_to_ord(b)
                return (x > y) - (x < y)

        def separate_epoch(self, s):
                """Split ``epoch:rest`` into ``(rest, int(epoch))``; the epoch defaults to 0."""
                match = self.re_epoch.match(s)
                if not match:
                        return (s, 0)
                return (match.group(2), int(match.group(1)))

        def _split_revision(self, s):
                """Split ``upstream-revision`` at the last '-'; the revision defaults to ''."""
                match = self.re_parts.match(s)
                if match:
                        return match.group(1), match.group(2)
                return s, ''

        def compare(self, a, b):
                """Compare two full version strings; return -1, 0 or 1.

                Epochs are compared first, then upstream versions, then
                revisions.
                """
                a, epoch_a = self.separate_epoch(a)
                b, epoch_b = self.separate_epoch(b)
                if epoch_a != epoch_b:
                        return 1 if epoch_a > epoch_b else -1

                upstream_a, revision_a = self._split_revision(a)
                upstream_b, revision_b = self._split_revision(b)

                comp = self.compare_part(upstream_a, upstream_b)
                if comp != 0:
                        return comp
                return self.compare_part(revision_a, revision_b)

class PuppetReportReader:
    """Scan puppet YAML reports and pass a per-client error count to a handler.

    ``handler(client_name, count)`` is called once per client directory that
    holds at least one ``*.yaml`` report.  ``count`` is the text found two
    lines below the ``- failure`` metric, the integer 1 when a log line
    contains ``ruby/sym err``, or the integer 0 when neither is found.
    """

    def __init__(self, handler, path='/var/lib/puppet/reports'):
        self.puppet_report_path = path
        self.report_client_errors = handler

    def process_all_reports(self):
        """Process the newest report of every client directory under the report path."""
        base = self.puppet_report_path
        for dirname in os.listdir(base):
            client_dir = os.path.join(base, dirname)
            # stray plain files are not client directories; os.listdir() on
            # them would raise
            if not os.path.isdir(client_dir):
                continue
            self.process_client_report(client_dir, dirname)

    def process_client_report(self, path, client_name):
        """Process the most recently modified ``*.yaml`` file in ``path``, if any."""
        newest = None
        newest_mtime = 0
        for filename in os.listdir(path):
            if not filename.endswith('.yaml'):
                continue
            mtime = os.stat(os.path.join(path, filename)).st_mtime
            if mtime > newest_mtime:
                newest, newest_mtime = filename, mtime
        if newest is not None:
            self.process_report_file(os.path.join(path, newest), client_name)

    def process_report_file(self, path, client_name):
        """Report the error count of one puppet YAML report.

        This is a minimal line scanner, not a YAML parser: it tracks whether
        it is inside a ``metrics:`` or ``logs:`` section and reports as soon
        as it finds a failure count or an error-level log entry.
        """
        metrics_found = False
        failure_found = False
        logs_found = False
        offset = 0

        with open(path, 'r') as f:
            for line in f:
                if failure_found:
                    offset += 1
                    if offset == 2:
                        # the value sits two lines below "- failure"
                        failure_count = line.strip().split(' ')[1]
                        self.report_client_errors(client_name, failure_count)
                        return
                if logs_found and 'ruby/sym err' in line:
                    self.report_client_errors(client_name, 1)
                    return

                if metrics_found:
                    if '- failure' in line:
                        failure_found = True
                elif 'metrics:' in line:
                    logs_found = False
                    metrics_found = True
                    failure_found = False
                elif 'logs:' in line:
                    logs_found = True
                    metrics_found = False
                    failure_found = False

        # nothing found: report 0 errors
        self.report_client_errors(client_name, 0)

class PackageDb:
        """Access layer for the package-list PostgreSQL database.

        One package list is selected at a time (selectList()); the package
        methods operate on that list.  insert/update/setInstallOnId/find/
        findMulti/getPackages run on the shared cursor opened by connect();
        every other query runs on a short-lived cursor that is always closed.
        Changes are committed by commit() and close().
        """
        db = None
        package_list_id = -1
        package_list_name = None
        cursor = None

        def connect(self, dbspec):
                """Open the connection described by ``dbspec`` and the shared cursor."""
                self.db = psycopg2.connect(dbspec)
                self.cursor = self.db.cursor()

        def commit(self):
                self.db.commit()

        def _execute(self, sql, params=None, fetch=None):
                """Run one statement on a fresh cursor and always close that cursor.

                ``fetch`` picks the result: 'one' -> fetchone(), 'all' ->
                fetchall(), anything else -> None.
                """
                cursor = self.db.cursor()
                try:
                        cursor.execute(sql, params)
                        if fetch == 'one':
                                return cursor.fetchone()
                        if fetch == 'all':
                                return cursor.fetchall()
                        return None
                finally:
                        cursor.close()

        # --- packages of the selected list (shared cursor) ---

        def insert(self, package, version, arch, install=False):
                self.cursor.execute(
                        "INSERT INTO packages (name,version,arch,list_id,install) "
                        "VALUES (%s,%s,%s,%s,%s)",
                        (package, version, arch, self.package_list_id, install))

        def update(self, package, version, arch):
                self.cursor.execute(
                        "UPDATE packages SET version=%s "
                        "WHERE list_id=%s AND name=%s AND arch=%s",
                        (version, self.package_list_id, package, arch))

        def setInstallOnId(self, packageId, install):
                self.cursor.execute("UPDATE packages SET install=%s WHERE id=%s",
                        (install, packageId))

        def find(self, package, arch):
                """Return (id, name, version, arch) of ``package``/``arch`` in the selected list, or None."""
                self.cursor.execute(
                        "SELECT id,name,version,arch FROM packages "
                        "WHERE list_id=%s AND name=%s AND arch=%s",
                        (self.package_list_id, package, arch))
                return self.cursor.fetchone()

        def findMulti(self, package):
                """Return all (id, name, version, arch) rows of ``package`` in the selected list."""
                self.cursor.execute(
                        "SELECT id,name,version,arch FROM packages "
                        "WHERE list_id=%s AND name=%s",
                        (self.package_list_id, package))
                return self.cursor.fetchall()

        def getPackages(self):
                """Query all packages of the selected list, ordered by name.

                Returns the shared cursor itself: iterate it for
                (id, name, version, arch, install) rows before running another
                query on the shared cursor.
                """
                self.cursor.execute(
                        "SELECT id,name,version,arch,install FROM packages "
                        "WHERE list_id=%s ORDER BY name",
                        (self.package_list_id,))
                return self.cursor

        # --- package lists ---

        def findList(self, name):
                """Return (id, name, autoupdate) of the list called ``name``, or None."""
                # NOTE(review): the name is sent UTF-8 encoded as before; under
                # Python 2 a non-ASCII byte string would fail in encode() --
                # confirm callers pass unicode or ASCII names.
                return self._execute(
                        "SELECT id,name,autoupdate FROM package_lists WHERE name=%s",
                        (name.encode('utf-8'),), 'one')

        def selectList(self, name):
                """Make ``name`` the current list; exit with status 2 if it does not exist."""
                row = self.findList(name)
                if row is None:
                        print("%s not found" % (name,))
                        sys.exit(2)
                self.package_list_id = row[0]
                self.package_list_name = row[1]

        def emptyList(self):
                """Delete every package of the selected list."""
                self._execute("DELETE FROM packages WHERE list_id=%s",
                        (self.package_list_id,))

        def copyList(self):
                """Replace the selected list's packages with a copy of its origin list's.

                Also marks the selected list as not approved.  If the list has
                no origin (NULL or 0, or the list row is missing), a message is
                printed and nothing is changed.
                """
                cursor = self.db.cursor()
                try:
                        cursor.execute("SELECT origin FROM package_lists WHERE id=%s",
                                (self.package_list_id,))
                        row = cursor.fetchone()
                        origin = row[0] if row is not None else None
                        if origin is None or origin == 0:
                                print("package list has no origin")
                                return
                        cursor.execute("DELETE FROM packages WHERE list_id=%s",
                                (self.package_list_id,))
                        cursor.execute(
                                "INSERT INTO packages "
                                "(name,version,description,install,arch,list_id) "
                                "SELECT name,version,description,install,arch,%s "
                                "FROM packages WHERE list_id=%s",
                                (self.package_list_id, origin))
                        cursor.execute("UPDATE package_lists SET approved=false WHERE id=%s",
                                (self.package_list_id,))
                finally:
                        # closed on every path, including the early return
                        cursor.close()

        def getListName(self):
                return self.package_list_name

        def getListId(self):
                return self.package_list_id

        def getLists(self):
                """Return (id, name, autoupdate) of all lists ordered by name."""
                return self._execute(
                        "SELECT id,name,autoupdate FROM package_lists ORDER BY name",
                        fetch='all')

        def getListsForRepoId(self, repo_id):
                """Return (id, name, autoupdate) of the lists using repository ``repo_id``."""
                return self._execute(
                        "SELECT pl.id,pl.name,pl.autoupdate "
                        "FROM package_lists AS pl, repo_list_members AS rlm "
                        "WHERE pl.id=rlm.list_id AND rlm.repo_id=%s ORDER BY pl.name",
                        (repo_id,), 'all')

        # --- repositories ---

        def getAllRepositories(self):
                """Return all autoupdate repositories ordered by name and priority."""
                return self._execute(
                        "SELECT id,name,type,url,dist,arch,component,priority "
                        "FROM repositories WHERE autoupdate=true ORDER BY name,priority",
                        fetch='all')

        def findRepository(self, name):
                return self._execute(
                        "SELECT id,name,type,url,dist,arch,component,priority "
                        "FROM repositories WHERE name=%s",
                        (name,), 'one')

        def createRepository(self, name, typ, url, dist, arch, component, priority):
                self._execute(
                        "INSERT INTO repositories "
                        "(name,type,url,dist,arch,component,priority,autoupdate) "
                        "VALUES (%s,%s,%s,%s,%s,%s,%s,true)",
                        (name, typ, url, dist, arch, component, priority))

        def addRepository(self, repo_id):
                """Attach repository ``repo_id`` to the selected list."""
                self._execute(
                        "INSERT INTO repo_list_members (repo_id,list_id) VALUES (%s,%s)",
                        (repo_id, self.package_list_id))

        def removeRepository(self, repo_id):
                """Detach repository ``repo_id`` from the selected list."""
                self._execute(
                        "DELETE FROM repo_list_members WHERE repo_id=%s AND list_id=%s",
                        (repo_id, self.package_list_id))

        def getRepositoryIds(self):
                return self._execute(
                        "SELECT r.id FROM repositories AS r, repo_list_members AS m "
                        "WHERE r.id=m.repo_id AND m.list_id=%s;",
                        (self.package_list_id,), 'all')

        def getRepositories(self):
                return self._execute(
                        "SELECT r.name FROM repositories AS r, repo_list_members AS m "
                        "WHERE r.id=m.repo_id AND m.list_id=%s ORDER BY r.name",
                        (self.package_list_id,), 'all')

        def close(self):
                """Commit, then close the shared cursor and the connection."""
                self.commit()
                self.cursor.close()
                self.db.close()

        # --- profiles and computers ---

        def getProfiles(self):
                return self._execute(
                        "SELECT id,name,config_extras FROM profiles ORDER BY name",
                        fetch='all')

        def getListsForProfile(self, profileName):
                """Return (list_id, list_name) of the package lists of profile ``profileName``."""
                return self._execute(
                        "SELECT pm.list_id,pl.name "
                        "FROM profile_members AS pm, package_lists AS pl, profiles as pr "
                        "WHERE pl.id=pm.list_id AND pm.profile_id=pr.id AND pr.name=%s",
                        (profileName,), 'all')

        def updateComputerStatus(self, computerName, status):
                """Set the status value of ``computerName``, creating the row if needed."""
                if self.getComputerStatus(computerName) is None:
                        self.createComputerStatus(computerName, status)
                        return
                self._execute(
                        "UPDATE status_count SET value=%s "
                        "WHERE name=%s AND category='computer'",
                        (status, computerName))

        def createComputerStatus(self, computerName, status):
                self._execute(
                        "INSERT INTO status_count (name,category,value) "
                        "VALUES (%s,'computer',%s)",
                        (computerName, status))

        def getComputerStatus(self, computerName):
                """Return the status value of ``computerName`` (0 if NULL), or None if absent."""
                row = self._execute(
                        "SELECT value FROM status_count "
                        "WHERE name=%s AND category='computer'",
                        (computerName,), 'one')
                if row is None:
                        return None
                return 0 if row[0] is None else row[0]

        def getComputers(self):
                return self._execute(
                        "SELECT computers.name,profiles.name,computers.options,"
                        "computers.ip_address,computers.mac_address "
                        "FROM computers,profiles "
                        "WHERE computers.profile_id = profiles.id ORDER BY computers.name",
                        fetch='all')

        def getComputer(self, name):
                return self._execute(
                        "SELECT computers.name,profiles.name,computers.options,"
                        "computers.ip_address,computers.mac_address "
                        "FROM computers,profiles "
                        "WHERE computers.profile_id = profiles.id AND computers.name=%s "
                        "ORDER BY computers.name",
                        (name,), 'one')

        def createComputer(self, name, profile_name, ip_address, mac_address):
                """Insert computer ``name`` or update its addresses and profile if it exists."""
                if self.getComputer(name) is not None:
                        self._execute(
                                "UPDATE computers SET ip_address=%s,mac_address=%s,"
                                "profile_id=(SELECT id FROM profiles WHERE name=%s) "
                                "WHERE name=%s",
                                (ip_address, mac_address, profile_name, name))
                else:
                        self._execute(
                                "INSERT INTO computers (name,ip_address,mac_address,profile_id) "
                                "VALUES (%s,%s,%s, (SELECT id FROM profiles WHERE name=%s))",
                                (name, ip_address, mac_address, profile_name))

def process_package(db,package,version,arch,
                update=False,import_packages=False,install=False):
        """Update or import one package entry in the currently selected list.

        If the list already holds ``package``/``arch`` with a different
        version, that version is replaced by ``version`` when ``update`` or
        ``import_packages`` is set -- unless the stored version compares
        newer, in which case it is kept.  A package not yet in the list is
        inserted (with the ``install`` flag) only when ``import_packages``
        is set.
        """
        dbresult = db.find(package, arch)
        if dbresult is None:
                if import_packages:
                        if verbose:
                                print("importing %s" % (package,))
                        db.insert(package, version, arch, install)
                return

        _id, name, stored_version, _arch = dbresult
        if stored_version == version:
                return
        # never downgrade: keep a stored version that compares newer
        if VersionComparator().compare(stored_version, version) > 0:
                return
        if update or import_packages:
                if verbose:
                        print("%s %s -> %s" % (name, stored_version, version))
                db.update(package, version, arch)

def parse_release_file(filehandle,wanted,checksum):
        """Verify ``checksum`` against the SHA256 entry for ``wanted`` in a Release file.

        ``filehandle`` yields the Release file's lines.  Only the indented
        ``<sha256> <size> <path>`` lines of the ``SHA256:`` field are
        considered.  Returns None when the entry matches.  Exits the process
        with status 3 when the checksum differs or when ``wanted`` is not
        listed at all (an unlisted file cannot be verified).
        """
        in_sha256 = False
        for line in filehandle:
                if not line.startswith(' '):
                        # an unindented line starts a new field; only the
                        # continuation lines of SHA256 are of interest
                        in_sha256 = line.rstrip('\r\n') == 'SHA256:'
                        continue
                if not in_sha256:
                        continue
                parts = line.split()
                if len(parts) < 3 or parts[2] != wanted:
                        continue
                if parts[0] == checksum:
                        return
                print("ERROR: Checksum for Packages file does not match, aborting! (is %s, expecting %s)" % (checksum, parts[0]))
                sys.exit(3)

        # fail closed instead of silently accepting an unverified file
        print("ERROR: %s is not listed in the SHA256 section of the Release file, aborting!" % (wanted,))
        sys.exit(3)
                
def line_parser(filehandle,interesting_keywords,handler):
        """Parse deb822-style stanzas and pass each one to ``handler``.

        Continuation lines (starting with a space) are skipped.  For every
        ``Keyword: value`` line whose keyword is in ``interesting_keywords``
        the value is collected.  Each blank line calls ``handler(values)``
        with the collected dict and starts a new one.  A final stanza not
        followed by a blank line is passed to ``handler`` as well.  Exits
        with status 9 on a line that has no ``Keyword:`` prefix.
        """
        values = {}
        pending = False  # a keyword line was seen since the last handler call
        for line in filehandle:
                if line.startswith(' '):
                        continue
                line = line.strip()
                if line == '':
                        handler(values)
                        values = {}
                        pending = False
                        continue

                index = line.find(':')
                if index <= 0:
                        print("Error parsing '%s'" % (line,))
                        sys.exit(9)

                pending = True
                keyword = line[:index]
                if keyword in interesting_keywords:
                        # tolerate "Key:value" as well as "Key: value"
                        values[keyword] = line[index + 1:].strip()

        if pending:
                handler(values)
        
def parse_repo_packagelist(db,filehandle,update=True,import_all=False):
        """Feed every stanza of a Debian Packages file to process_package().

        Only the Package, Version and Architecture fields are used.  A blank
        line ends a stanza.  Stanzas without a Package field (e.g. produced by
        repeated blank lines) are ignored, and a last stanza not followed by
        a blank line is still processed.  Exits with status 9 on a line that
        has no ``Keyword:`` prefix.
        """
        package = version = arch = None
        for line in filehandle:
                if line.startswith(' '):
                        continue
                line = line.strip()
                if line == '':
                        # without this guard a stray blank line would insert a
                        # NULL package when import_all is set
                        if package is not None:
                                process_package(db, package, version, arch, update, import_all)
                        package = version = arch = None
                        continue

                index = line.find(':')
                if index <= 0:
                        print("Error parsing '%s'" % (line,))
                        sys.exit(9)

                keyword = line[:index]
                value = line[index + 1:].strip()
                if keyword == 'Package':
                        package = value
                elif keyword == 'Version':
                        version = value
                elif keyword == 'Architecture':
                        arch = value

        if package is not None:
                process_package(db, package, version, arch, update, import_all)

def httplib_auth(httplib,url):
        """Move credentials embedded in ``url`` into the ``httplib`` client.

        If ``url`` contains ``user:password@``, they are registered with
        ``httplib.add_credentials()`` and the URL is returned without them
        (host lower-cased by urlparse, port and IPv6 brackets kept).
        Otherwise ``url`` is returned unchanged.
        """
        r = urlparse.urlparse(url)
        if not r.username:
                return url
        httplib.add_credentials(r.username, r.password)
        host = r.hostname or ''
        # .hostname strips the brackets of an IPv6 literal; restore them
        if ':' in host:
                host = '[%s]' % (host,)
        netloc = host
        if r.port:
                # r.port is an int, so it must be formatted, not concatenated
                netloc = '%s:%d' % (host, r.port)
        return urlparse.urlunparse((r.scheme, netloc, r.path, r.params, r.query, r.fragment))

def fetch_release_file(httplib,url,dist,arch,component):
        """Download a Debian Release file and verify its detached GPG signature.

        The file is ``<url>/dists/<dist>/Release`` (or ``<url>/Release``
        when ``dist`` is empty); the signature is ``Release.gpg`` next to it.
        ``arch`` and ``component`` are not used here.  Exits with status 2 on
        an HTTP error and status 3 when gpg rejects the signature.  Returns
        the Release file content.
        """
        url = httplib_auth(httplib, url)
        if not url.endswith('/'):
                url = url + '/'
        if dist:
                url_release = url + 'dists/' + dist + '/Release'
        else:
                url_release = url + 'Release'

        resp, content_release = httplib.request(url_release)
        if resp.status >= 400:
                print("Error fetching %s %s" % (url_release, resp.status))
                sys.exit(2)

        resp, content_gpg = httplib.request(url_release + '.gpg')
        if resp.status >= 400:
                print("Error fetching %s %s" % (url_release + '.gpg', resp.status))
                sys.exit(2)

        # gpg cannot read both the signature and the signed data from stdin,
        # so the Release content goes to a temporary file and the signature
        # to stdin.  An argument list (no shell) avoids any quoting issues.
        tf = tempfile.NamedTemporaryFile()
        try:
                tf.write(content_release)
                tf.flush()
                with open(os.devnull, 'w') as devnull:
                        child = subprocess.Popen(
                                ["gpg", "--quiet", "--verify", "-", tf.name],
                                stdin=subprocess.PIPE, stdout=devnull, stderr=devnull)
                        child.communicate(content_gpg)
        finally:
                tf.close()

        if child.returncode != 0:
                print("ERROR: failed to verify Release file GPG signature, aborting!")
                sys.exit(3)

        return content_release
        
def fetch_packages_file(httplib,package_db,url,dist,arch,component,content_release):
        """Download, verify and decompress a repository's Packages index.

        ``Packages.bz2`` is tried first, then ``Packages.gz``.  The SHA256 of
        the downloaded file is checked against ``content_release`` with
        parse_release_file().  ``package_db`` is not used here.  Exits with
        status 2 when neither file can be fetched.  Returns the decompressed
        Packages content.
        """
        url = httplib_auth(httplib, url)
        if not url.endswith('/'):
                url = url + '/'
        if dist:
                wanted_prefix = component + '/binary-' + arch + '/'
                url_base = url + 'dists/' + dist + '/' + wanted_prefix
        else:
                wanted_prefix = ''
                url_base = url

        compression = 'bz2'
        url_packages = url_base + 'Packages.bz2'
        resp, content_packages = httplib.request(url_packages)
        if resp.status >= 400:
                compression = 'gz'
                url_packages = url_base + 'Packages.gz'
                resp, content_packages = httplib.request(url_packages)
                if resp.status >= 400:
                        print("Error fetching %s %s" % (url_packages, resp.status))
                        sys.exit(2)

        # verify in memory only -- no dump to a predictable /tmp path, which
        # another local user could pre-create as a symlink
        checksum = hashlib.sha256(content_packages).hexdigest()
        release_handle = StringIO.StringIO(content_release)
        try:
                parse_release_file(release_handle, wanted_prefix + 'Packages.' + compression, checksum)
        finally:
                release_handle.close()

        if compression == 'gz':
                gz = gzip.GzipFile(mode='r', fileobj=StringIO.StringIO(content_packages))
                try:
                        return gz.read()
                finally:
                        gz.close()
        return bz2.decompress(content_packages)

def make_url(url, filename):
        """Insert ``filename`` as the last path component of ``url``.

        A query string is preserved: ``http://h/p?x`` and ``http://h/p/?x``
        both become ``http://h/p/<filename>?x``.  Only the first '?'
        separates path and query; later '?' characters are left alone.
        """
        path, sep, query = url.partition('?')
        if not path.endswith('/'):
                path = path + '/'
        return path + filename + sep + query

def fetch_repomd_file(httplib,url,checkSignature=True):
        """Download ``repodata/repomd.xml`` below ``url`` and optionally verify it.

        With ``checkSignature`` the detached signature
        ``repodata/repomd.xml.asc`` is fetched and checked with gpg.  Exits
        with status 2 on an HTTP error and status 3 when gpg rejects the
        signature.  Returns the repomd.xml content.
        """
        url_repomd = make_url(url, 'repodata/repomd.xml')
        resp, content_repomd = httplib.request(url_repomd)
        if resp.status >= 400:
                print("Error fetching %s %s" % (url_repomd, resp.status))
                sys.exit(2)

        if not checkSignature:
                return content_repomd

        # built with make_url, not by appending '.asc', so that a query
        # string in ``url`` stays at the end of the URL
        url_signature = make_url(url, 'repodata/repomd.xml.asc')
        resp, content_gpg = httplib.request(url_signature)
        if resp.status >= 400:
                print("Error fetching %s %s" % (url_signature, resp.status))
                sys.exit(2)

        # gpg cannot read both the signature and the signed data from stdin,
        # so the repomd.xml content goes to a temporary file and the
        # signature to stdin.  An argument list (no shell) avoids quoting issues.
        tf = tempfile.NamedTemporaryFile()
        try:
                tf.write(content_repomd)
                tf.flush()
                with open(os.devnull, 'w') as devnull:
                        child = subprocess.Popen(
                                ["gpg", "--quiet", "--verify", "-", tf.name],
                                stdin=subprocess.PIPE, stdout=devnull, stderr=devnull)
                        child.communicate(content_gpg)
        finally:
                tf.close()

        if child.returncode != 0:
                print("ERROR: failed to verify repomd.xml GPG signature, aborting!")
                sys.exit(3)

        return content_repomd

def fetch_primary_xml(httplib,url,checksum,hasch):
        # fetch primary.xml file
        resp, compressed_content = httplib.request(url)
        if resp.status >=400:
                print "Error fetching",url,resp.status
                sys.exit(2)

       # verify checksum
        hasch.update(compressed_content)
        calculated_checksum=hasch.hexdigest()
        if calculated_checksum!=checksum:
                print "ERROR: invalid checksum %s for %s (should be %s)" % (calculated_checksum, url, checksum)
                sys.exit(3)

        gziphandle=StringIO.StringIO(compressed_content)
        f=gzip.GzipFile(mode='r',fileobj=gziphandle)
        primary_content=f.read()
        f.close()

        return primary_content

def find_by_attr(parent,tag,attr,value):
        """Return the first ``tag`` child of ``parent`` whose ``attr`` equals ``value``, else None."""
        matches = (child for child in parent.findall(tag) if child.get(attr) == value)
        return next(matches, None)

def update_package_lists_from_rpmmd(repo_id,httplib,package_db,url,repo_type,
        arch,autoupdate=True,import_all=False):
        """Apply an RPM-MD repository's primary.xml to the package lists using it.

        repomd.xml is fetched (signature-checked unless ``repo_type``
        contains "unsigned"), the primary.xml it references is downloaded
        and checked against the listed checksum, and its packages are
        processed for every list attached to ``repo_id``.  With
        ``autoupdate`` set, lists whose autoupdate flag is off are skipped.
        Exits with status 3 when repomd.xml has no usable primary entry.
        """
        url = httplib_auth(httplib, url)

        check_signature = "unsigned" not in repo_type
        content_repomd = fetch_repomd_file(httplib, url, check_signature)

        prefix = '{http://linux.duke.edu/metadata/repo}'
        root = ET.fromstring(content_repomd)
        data = find_by_attr(root, prefix + 'data', 'type', 'primary')
        location = data.find(prefix + 'location') if data is not None else None
        if location is None:
                print("ERROR: no primary data location in repomd.xml of %s, aborting!" % (url,))
                sys.exit(3)
        primary_url = location.get('href')

        # use the first checksum whose algorithm hashlib knows; 'sha' is
        # hashed with SHA-1, every other type name is passed to hashlib.new
        primary_checksum = None
        hasch = None
        for el in data.findall(prefix + 'checksum'):
                algorithm = el.get('type')
                if algorithm == 'sha':
                        algorithm = 'sha1'
                try:
                        hasch = hashlib.new(algorithm)
                except (TypeError, ValueError):
                        continue
                primary_checksum = el.text
                break
        if hasch is None:
                print("ERROR: no supported checksum for primary data in repomd.xml of %s, aborting!" % (url,))
                sys.exit(3)

        content_packages = fetch_primary_xml(httplib, make_url(url, primary_url), primary_checksum, hasch)
        xmlroot = ET.fromstring(content_packages)

        for _list_id, list_name, list_autoupdate in package_db.getListsForRepoId(repo_id):
                if autoupdate and not list_autoupdate:
                        continue
                if verbose:
                        if import_all:
                                print("  importing into %s" % (list_name,))
                        else:
                                print("  updating package list: %s" % (list_name,))
                package_db.selectList(list_name)
                process_primary_xml(package_db, xmlroot, arch, autoupdate, import_all)

def process_primary_xml(package_db,root,wanted_arch,autoupdate,import_all):
        """Pass every binary RPM for ``wanted_arch`` or noarch to process_package().

        Source packages are skipped.  The version is rendered as
        ``[epoch:]ver-rel``; the epoch is omitted when it is "0" or absent,
        and ``-rel`` is omitted when the release attribute is missing/empty.
        """
        prefix = '{http://linux.duke.edu/metadata/common}'
        for pkg in root.findall(prefix + "package"):
                if pkg.get("type") != "rpm":
                        continue
                arch = pkg.findtext(prefix + 'arch')
                if arch == 'src':
                        continue
                if arch != wanted_arch and arch != 'noarch':
                        continue
                name = pkg.findtext(prefix + 'name')
                v = pkg.find(prefix + 'version')
                epoch = v.get('epoch')
                release = v.get('rel')
                combined_version = v.get('ver')
                if release:
                        combined_version = combined_version + '-' + release
                # a missing epoch attribute used to hit None + ':' (TypeError)
                if epoch and epoch != '0':
                        combined_version = epoch + ':' + combined_version
                process_package(package_db, name, combined_version, arch, autoupdate, import_all)

def update_package_lists_from_repo(repo_id,httplib,package_db,url,dist,arch,component,
                autoupdate=True,import_all=False):
        """Apply a Debian repository's Packages index to the package lists using it.

        The Release file and Packages index are fetched and verified once;
        the index is then parsed for every list attached to ``repo_id``.
        With ``autoupdate`` set, lists whose autoupdate flag is off are
        skipped.
        """
        release = fetch_release_file(httplib, url, dist, arch, component)
        packages = fetch_packages_file(httplib, package_db, url, dist, arch, component, release)

        if import_all:
                action = "  importing into"
        else:
                action = "  updating package list:"

        for _list_id, list_name, list_autoupdate in package_db.getListsForRepoId(repo_id):
                if autoupdate and not list_autoupdate:
                        continue
                if verbose:
                        print("%s %s" % (action, list_name))
                package_db.selectList(list_name)
                handle = StringIO.StringIO(packages)
                parse_repo_packagelist(package_db, handle, autoupdate, import_all)
                handle.close()

def create_package_list_xml(db,path,priority=0,t='deb'):
        """Write a ``<PackageList>`` XML file for the selected package list.

        Only packages whose install flag is set are listed, each with its
        version and type ``t``.  Attribute values are XML-escaped, and the
        file is closed even if writing fails.
        """
        from xml.sax.saxutils import escape

        def attr(value):
                # values are written inside single quotes, so escape those too
                return escape('%s' % (value,), {"'": "&apos;"})

        with open(path, "w") as f:
                f.write("<PackageList type='%s' priority='%d'>\n" % (attr(t), priority))
                f.write("\t<Group name='packages-%s'>\n" % (attr(db.getListName()),))
                for _id, name, version, _arch, install in db.getPackages():
                        if install:
                                f.write("\t\t<Package name='%s' version='%s' type='%s'/>\n"
                                        % (attr(name), attr(version), attr(t)))
                f.write("\t</Group>\n")
                f.write("</PackageList>\n")

def create_package_list_puppet(db,path,multiarch=False):
        """Write a puppet class ``packagemanager::<list name>`` pinning package versions.

        Only packages whose install flag is set are emitted.  With
        ``multiarch`` set, a package whose architecture is not the default
        one, 'all' or 'any' gets an ``:arch`` suffix on its name.  The default
        architecture is the most frequent one in the list (ignoring 'all'),
        or 'i386' if there is none.
        """
        # read the rows once; getPackages() hands back a one-shot cursor
        packages = list(db.getPackages())

        arch_counts = {}
        for _id, _name, _version, arch, _install in packages:
                arch_counts[arch] = arch_counts.get(arch, 0) + 1
        default_arch = "i386"
        best = 0
        for arch, count in arch_counts.items():
                if count > best and arch != 'all':
                        default_arch = arch
                        best = count

        # 'with' closes (and flushes) the file; it was previously never closed
        with open(path, "w") as f:
                f.write("class packagemanager::%s {\n" % (db.getListName().lower(),))
                for _id, name, version, arch, install in packages:
                        if not install:
                                continue
                        if not multiarch or arch in (default_arch, 'all', 'any'):
                                f.write("\tpackage { \"%s\": name => \"%s\", ensure => \"%s\" }\n"
                                        % (name, name, version))
                        else:
                                f.write("\tpackage { \"%s\": name => \"%s:%s\", ensure => \"%s\" }\n"
                                        % (name, name, arch, version))
                f.write("}\n")

def create_package_bundle_xml(db,path,priority=0):
        """Write a bundle named packages-<list name> for the selected list.

        The bundle contains an 'apt-get-update' action and one Package entry
        per package whose install flag is set. Names are XML-escaped. The
        file is only opened once all rows have been read. priority is
        accepted for symmetry with create_package_list_xml but is not used.
        """
        from xml.sax.saxutils import escape

        def attr(value):
                # attributes are single-quoted, so the quote must be escaped too
                return escape('%s' % (value,), {"'": "&apos;"})

        lines=["<Bundle name='packages-%s'>\n" % attr(db.getListName()),
               "\t<Action name='apt-get-update'/>\n"]
        lines.extend("\t<Package name='%s'/>\n" % attr(name)
                     for _id,name,_version,_arch,install in db.getPackages()
                     if install)
        lines.append("</Bundle>\n")

        with open(path,"w") as f:
                f.writelines(lines)
        
def create_autoprofiles_xml(db,path):
        """Write the group definitions for all profiles of db to path.

        Each profile becomes a public profile group that includes its
        comma-separated config groups and one packages-<list> group per
        package list assigned to the profile. Each of those packages-<list>
        groups is then defined to include the bundle of the same name.
        Profile names are printed to stdout as they are processed.
        """
        from xml.sax.saxutils import escape

        def attr(value):
                # attributes are single-quoted, so the quote must be escaped too
                return escape('%s' % (value,), {"'": "&apos;"})

        with open(path,"w") as f:
                f.write("<Groups>\n")
                for pr_id,pr_name,pr_configgroups in db.getProfiles():
                        print(pr_name)
                        f.write("\t<Group name='%s' profile='true' public='true'>\n" % attr(pr_name))
                        if pr_configgroups:
                                for configgroup in pr_configgroups.split(','):
                                        f.write("\t\t<Group name='%s'/>\n" % attr(configgroup.strip()))
                        # materialise: the lists are walked twice below, which
                        # would silently yield nothing the second time if the
                        # database returned a one-shot iterator
                        packetlists=list(db.getListsForProfile(pr_name))
                        for pl_id,pl_name in packetlists:
                                f.write("\t\t<Group name='packages-%s'/>\n" % attr(pl_name))
                        f.write("\t</Group>\n")

                        for pl_id,pl_name in packetlists:
                                f.write("\t<Group name='packages-%s'>\n" % attr(pl_name))
                                f.write("\t\t<Bundle name='packages-%s'/>\n" % attr(pl_name))
                                f.write("\t</Group>\n")
                        f.write("\n")

                f.write("</Groups>\n")

def create_autoprofiles_puppet(db,path):
        """Write one Puppet class file <path>/<profile>.pp per profile of db.

        Each class includes the profile's comma-separated config classes, the
        base "packagemanager" class and packagemanager::<list> for every
        package list assigned to the profile. Each file is closed even if a
        database call fails while it is being written.
        """
        for pr_id,pr_name,pr_configgroups in db.getProfiles():
                with open(path+'/'+pr_name+'.pp','w') as f:
                        f.write("# file created automatically, do not edit!\n")
                        f.write("class %s {\n" % (pr_name))
                        if pr_configgroups:
                                for configgroup in pr_configgroups.split(','):
                                        f.write("  include \"%s\"\n" % (configgroup.strip()))
                        f.write("\n#\n# package lists\n")
                        f.write("  include \"packagemanager\"\n")
                        for pl_id,pl_name in db.getListsForProfile(pr_name):
                                f.write("  include \"packagemanager::%s\"\n" % pl_name)
                        f.write("}\n")

def load_config():
        """(Re)load the global configuration from /etc/packagemanager.conf.

        A missing file yields an empty configuration, because
        ConfigParser.read skips files it cannot open. A malformed file is
        reported on stderr instead of aborting the tool. In that case the
        configuration may be incomplete and callers fall back to their
        defaults.
        """
        global config

        config = ConfigParser.ConfigParser()
        try:
                config.read('/etc/packagemanager.conf')
        except ConfigParser.Error as e:
                # best effort: warn, but keep running with what was parsed
                sys.stderr.write("warning: cannot parse /etc/packagemanager.conf: %s\n" % e)
 
def get_config(section,name):
        """Return option name of section from the configuration, or None.

        The configuration is loaded on first use. A missing section or
        option, or a value that fails interpolation, yields None. Other
        exceptions are no longer swallowed.
        """
        if config is None:
                load_config()

        try:
                return config.get(section,name)
        except ConfigParser.Error:
                return None

def create_nodes_puppet(db,path):
        """Write a Puppet node definition for every computer of db to path.

        Each node includes the classes listed in its comma-separated options,
        followed by its profile class. Option entries that are not plain
        [-A-Za-z0-9_] names are silently skipped. If the [puppet] option
        node_inherit is configured, every node inherits from that node.
        """
        valid_class=re.compile('^[-A-Za-z0-9_]+$')

        systems=db.getComputers()
        node_inherit=get_config('puppet','node_inherit')
        if node_inherit is not None:
                node_inherit="inherits "+node_inherit
        else:
                node_inherit=""

        with open(path,'w') as f:
                f.write("# file created automatically, do not edit!\n")
                for name,profile,options,ip,mac in systems:
                        f.write("node \"%s\" %s {\n" % (name,node_inherit))
                        if options:
                                for o in options.split(','):
                                        clss=o.strip()
                                        if valid_class.match(clss):
                                                f.write("  include %s\n" % clss)
                        f.write("  include %s\n" % profile)
                        f.write("}\n")
#
# Reads a PackageList XML file and imports the packages into
# the current package list. If the package is not yet in the
# package list, the install flag is set, otherwise it is not
# changed
#
def import_packagelist_xml(db,path,defaultarch,priority=0):
        """Hand every <Package> element of the XML file at path to process_package.

        The element's name and version attributes are used together with
        defaultarch as the architecture, and the flags True, True, True.
        priority is accepted but not used.
        """
        for element in ET.parse(path).iter('Package'):
                process_package(db,element.get('name'),element.get('version'),
                        defaultarch,True,True,True)

#
# Reads a PackageList XML file, extracts the package names from the XML
# tree and sets the install flag on the packages with the same name
# in the current package list
#
def install_packagelist_xml(db,path):
        """Set the install flag for every package named in the XML file at path.

        Each <Package> name is looked up with db.findMulti in the currently
        selected list. Names with no match, or with more than one match, are
        reported on stdout and skipped.
        """
        for element in ET.parse(path).getroot().iter('Package'):
                name=element.get('name')
                # operate on the db that was passed in, not on a global
                results=db.findMulti(name)
                if not results:
                        print("package %s not found in package list, skipping" % name)
                        continue
                if len(results)>1:
                        print("multiple packages %s found in package list, skipping" % name)
                        continue
                db.setInstallOnId(results[0][0],True)
                
#
# Reads the output of "rpm -qa --qf '%{Name} %{Version}-%{Release} %{Arch}\n'"
# or "dpkg-query --show -f '${Package} ${Version} ${Architecture}\n'"
# and sets the install flag on the corresponding packages
# in the current package list
#
def install_packagelist_csv(db,input_file):
        """Set the install flag for the packages listed in input_file.

        Each line has the form "<name> <version> <arch>", as produced by
        "rpm -qa --qf" or "dpkg-query --show". A "name:arch" package name
        overrides the third column. Blank lines are ignored, and lines
        without any architecture are reported and skipped. Packages not
        found via db.find in the currently selected list are reported and
        skipped.
        """
        for line in input_file.readlines():
                parts=line.split()
                if not parts:
                        continue
                name=parts[0]
                if ':' in name:
                        name,arch=name.split(':',1)
                elif len(parts)>2:
                        arch=parts[2]
                else:
                        print("no architecture for package %s, skipping" % name)
                        continue

                # operate on the db that was passed in, not on a global
                result=db.find(name,arch)
                if not result:
                        print("package %s not found in package list, skipping" % name)
                        continue
                db.setInstallOnId(result[0],True)

def parse_csv(line):
        """Split a "<name> <version> <arch>" line into a (name, version, arch) tuple.

        Extra columns are ignored. A package name of the form "name:arch"
        replaces the architecture column.
        """
        fields=line.split()
        name,version,arch=fields[0],fields[1],fields[2]
        if ':' not in name:
                return (name,version,arch)
        name,arch=name.split(':')
        return (name,version,arch)
#
# Reads the output of "rpm -qa --qf '%{Name} %{Version}-%{Release} %{Arch}\n'"
# or "dpkg-query --show -f '${Package} ${Version} ${Architecture}\n'"
# and imports all packages into the current package list
#
def import_packagelist_csv(db,input_file):
        """Import every "<name> <version> <arch>" line of input_file into db.

        Lines are parsed with parse_csv and passed one at a time to
        process_package with the flags True, True, True.
        """
        # a generator keeps parsing and processing interleaved line by line
        parsed=(parse_csv(line) for line in input_file.readlines())
        for name,version,arch in parsed:
                process_package(db,name,version,arch,True,True,True)

# Reads a list of installed packages
# and tries to find each package in any of the package lists
# of a given profile.
# Prints any packages which were not found.
def print_extra_packages(db,input_file,profile_name):
        """Print installed packages that none of a profile's lists contain.

        Every package list assigned to profile_name is selected on db in turn,
        and its (name, arch) pairs are collected. Each "<name> <version> <arch>"
        line of input_file (parsed with parse_csv) whose (name, arch) was not
        collected is printed as "name version arch". The last of the
        profile's lists is left selected on db.
        """
        # a set makes each membership test O(1) instead of scanning a list
        known=set()
        for list_id,list_name in db.getListsForProfile(profile_name):
                db.selectList(list_name)
                known.update((name,arch) for _id,name,_version,arch,_install in db.getPackages())

        for line in input_file.readlines():
                name,version,arch=parse_csv(line)
                if (name,arch) not in known:
                        print("%s %s %s" % (name,version,arch))

def bcfg2_import_reports(input_file,handler):
        """Call handler(name, status) for each client line of a report listing.

        The first line of input_file is a header and is skipped. Every
        following line must start with a client name and an integer status.
        """
        for line in input_file.readlines()[1:]:
                fields=line.strip().split()
                handler(fields[0],int(fields[1]))

def bcfg2_import_clients(file_path,db):
        """Create a computer in db for every client element of a clients XML file.

        Each direct child of the root element must carry name, profile and
        address attributes. The fourth createComputer argument is always ''.
        """
        for client in ET.parse(file_path).getroot():
                attrs=client.attrib
                db.createComputer(attrs['name'],attrs['profile'],attrs['address'],'')

def check_repo(package_db,r,import_packages=False):
                r_id,r_name,r_type,r_url,r_dist,r_arch,r_component,r_priority=r

                if verbose:
                    print "Checking repository:",r_name,r_dist,r_component
               
                if r_type=='apt': 
                    update_package_lists_from_repo(r_id,h,package_db,
                        r_url,r_dist,r_arch,r_component,
                        not import_packages,import_packages)
                elif r_type.startswith('rpm') or r_type.startswith('yum'):
                    update_package_lists_from_rpmmd(r_id,h,package_db,
                        r_url,r_type,r_arch,not import_packages,import_packages)

def check_args(command,descriptions):
        """Print usage and exit(1) unless command has enough arguments.

        descriptions holds one (name, description) pair per positional
        argument expected on sys.argv after the command itself.
        """
        provided=len(sys.argv)-2
        if provided>=len(descriptions):
                return
        placeholders=['<'+d[0]+'>' for d in descriptions]
        print("Usage: %s %s %s" % (sys.argv[0],command," ".join(placeholders)))
        print("\tParameters:")
        for placeholder,d in zip(placeholders,descriptions):
                print("\t%-20s:\t%s" % (placeholder,d[1]))
        sys.exit(1)

def is_cmd(commands,c1,c2):
        """Record command name c2 in commands (once) and return c1 == c2.

        Recording every checked name lets the usage message list all commands.
        """
        if c2 not in commands:
                commands.append(c2)
        return c2==c1

if __name__=='__main__':

        if len(sys.argv)>1:
                cmd=sys.argv[1]
        else:
                cmd=None

	if os.environ.has_key('http_proxy'):
		proxy_url=os.environ['http_proxy']
		r=urlparse.urlparse(proxy_url)
        	proxy=r.hostname
        	proxy_port=r.port
		proxy_info = httplib2.ProxyInfo(httplib2.socks.PROXY_TYPE_HTTP_NO_TUNNEL, proxy,proxy_port)
        	h=httplib2.Http("/tmp/update-packages.cache",proxy_info=proxy_info)
	else:
        	h=httplib2.Http("/tmp/update-packages.cache")

        package_db=PackageDb()
        package_db.connect("dbname=packagemanager user=packagemanager")

        import_packages=False

        commands = [ 'autoupdate', 'autoimport', 'cron' ]

        if cmd=='autoupdate' or cmd=='autoimport' or cmd=='cron':
                if cmd=='autoimport':
                        import_packages=True

		if cmd=='cron':
                        verbose=False 
			wait=hash(os.uname()[2]) % 60
			time.sleep(wait*60)
                        # protect from hanging http connections
                        signal.alarm(60*60)

                repos=package_db.getAllRepositories()
                for r in repos:
                    check_repo(package_db,r)

        elif is_cmd(commands,cmd,'importrepo'):
                check_args(cmd, [ ("name","Name of repository to import") ])
                name=sys.argv[2]
                r=package_db.findRepository(name)
                if r==None:
                    print "Repository %s not found" % name
                    sys.exit(1)
                check_repo(package_db,r,import_packages=True)

        elif is_cmd(commands,cmd,'testrpmmd'):
                repo_id=int(sys.argv[2])
                arch=sys.argv[3]
                repo_url=sys.argv[4]
                update_package_lists_from_rpmmd(repo_id,h,package_db,repo_url,arch,True,False)

        elif is_cmd(commands,cmd,'exportxml'):
                check_args(cmd,[ ("directory","directory to write exported files to") ])
                package_lists=package_db.getLists()
		directory=sys.argv[2]
                for pl in package_lists:
                        pl_id,pl_name,pl_autoupdate=pl
                        package_db.selectList(pl_name)
                        create_package_list_xml(package_db,directory+'/Pkgmgr/'+pl_name+'.xml')
                        create_package_bundle_xml(package_db,directory+'/Bundler/packages-'+pl_name+'.xml')

	elif is_cmd(commands,cmd,'exportpuppet'):
                check_args(cmd,[ ("directory","directory to write exported files to") ])
		package_lists=package_db.getLists()
		directory=sys.argv[2]
		for pl in package_lists:
                        pl_id,pl_name,pl_autoupdate=pl
                        package_db.selectList(pl_name)
			create_package_list_puppet(package_db,directory+'/'+pl_name+'.pp')

	elif is_cmd(commands,cmd,'listlists'):
                package_lists=package_db.getLists()
		print "Name                                     Autoupdate  Repos"
                for pl in package_lists:
                        pl_id,pl_name,pl_autoupdate=pl
                        package_db.selectList(pl_name)
                       	print "%-40s %-5s       [" % (pl_name,pl_autoupdate),\
                                " ".join( [ r[0] for r in package_db.getRepositories() ] ),"]"

       	elif is_cmd(commands,cmd,'copylist'):
                check_args(cmd,[ ("list","name of the package list to copy from origin") ])
		list_name=sys.argv[2]
 		package_db.selectList(list_name)
		package_db.copyList()
	
        elif is_cmd(commands,cmd,'autoprofiles'):
                check_args(cmd,[ ("directory","directory to write exported files to") ])
                path=sys.argv[2]
                create_autoprofiles_xml(package_db,path)

	elif is_cmd(commands,cmd,'puppetprofiles'):
                check_args(cmd,[ ("directory","directory to write exported files to") ])
		path=sys.argv[2]
		create_autoprofiles_puppet(package_db,path)

        elif is_cmd(commands,cmd,'puppetnodes'):
                check_args(cmd,[ ("file","file name to write node declarationst to") ])
                path=sys.argv[2]
                create_nodes_puppet(package_db,path)

        elif is_cmd(commands,cmd,'xmlimport'):
                check_args(cmd,[ ("file","file name to import"),
                        ("arch","architecture for imported packages"),
                        ("list","name of package list to import") ])
                xmlfile=sys.argv[2]
                arch=sys.argv[3]
                package_list=sys.argv[4]
                package_db.selectList(package_list)
                import_packagelist_xml(package_db,xmlfile,arch)
	
	elif is_cmd(commands,cmd,'csvimport'):
		if len(sys.argv)<3:
			print "Argument expected: <package list>"
			print "Reads lines from standard input in the following format:"
			print '  <package name> <version> <architecture>'
                        print 'To create a file in this format:'
                        print "  rpm -qa --qf '%{Name} %{Version}-%{Release} %{Arch}\\n'"
                        print "or"
                        print "  dpkg-query --show -f '${Package} ${Version} ${Architecture}\\n'"
                        sys.exit(1)

		csvfile=sys.stdin
		package_list=sys.argv[2]
		package_db.selectList(package_list)
                if len(sys.argv)==4 and sys.argv[3]=='merge':
                        pass
                else:
                        package_db.emptyList()
		import_packagelist_csv(package_db,csvfile)

        elif is_cmd(commands, cmd, 'clearlist'):
                check_args(cmd,[ ("list","name of package list to clear") ])
                listname=sys.argv[2]
                package_db.selectList(listname)
                package_db.emptyList()

	elif is_cmd(commands,cmd,'sshimport'):
		if len(sys.argv)<5:
			print "Argument expected: <package list> <hostname> rpm|dpkg [merge]"
                        sys.exit(1)

		package_list=sys.argv[2]
                hostname=sys.argv[3]
                typ=sys.argv[4]
                merge=False
                if len(sys.argv)==5:
                        merge=True

		package_db.selectList(package_list)
                if merge:
                        package_db.emptyList()
		import_packagelist_csv(package_db,csvfile)

        elif is_cmd(commands,cmd,'xmlinstall'):
                check_args(cmd,[ ("file","file name to import"),
                            ("arch","architecture for packages"),
                            ("list","name of package list to import into") ])
                xmlfile=sys.argv[2]
                arch=sys.argv[3]
                package_list=sys.argv[4]
                package_db.selectList(package_list)
                install_packagelist_xml(package_db,xmlfile)

	elif is_cmd(commands,cmd,'csvinstall'):
                check_args(cmd, [
                            ("list","name of package list to import into from stdin, expects format <name> <version> <arch>") ])
		dpkgfile=sys.stdin
		package_list=sys.argv[2]
		package_db.selectList(package_list)
		install_packagelist_csv(package_db,dpkgfile)

        elif is_cmd(commands,cmd,'addrepo'):
                check_args(cmd, [
                            ("list","name of packagelist"),
                            ("repo","name of repository to add to packagelist") ])
                listname=sys.argv[2]
                reponame=sys.argv[3]
                package_db.selectList(listname)
                repo_id=package_db.findRepository(reponame)[0]
                # print "repo_id of",reponame,"is",repo_id
                package_db.addRepository(repo_id)

        elif is_cmd(commands,cmd,'createrepo'):
                check_args(cmd, [
                            ("name","name of repository to create"),
                            ("type","repository type (yum or apt)"),
                            ("url","repository url"),
                            ("dist","distribution (apt only, e.g. squeeze, can be empty)"),
                            ("component","component (apt only, e.g. universe, can be empty)"),
                            ("arch","architecture for packages"),
                            ("priority","priority for update checks") ])
                        
                name=sys.argv[2]
                typ=sys.argv[3]
                url=sys.argv[4]
                dist=sys.argv[5]
                component=sys.argv[6]
                arch=sys.argv[7]
                priority=sys.argv[8]
                package_db.createRepository(name,typ,url,dist,arch,component,priority) 

	elif is_cmd(commands,cmd,'listrepos'):
                repos=package_db.getAllRepositories()
                for r in repos:
                        r_id,r_name,r_type,r_url,r_dist,r_arch,r_component,r_priority=r
                        print "%-40s %-20s %-8s %-16s %d" % (r_name,r_dist,r_arch,r_component,r_priority)

        elif is_cmd(commands,cmd,'listprofiles'):
		for id,name,extras in package_db.getProfiles():
			print "%-20s" % name,"[", " ".join([ p for i,p in package_db.getListsForProfile(name) ]),"]"

        elif is_cmd(commands,cmd,'listsystems'):
                for name,profile,options,ip,mac in package_db.getComputers():
                        if mac == None:
                                mac = ''
                        print "%-20s %-20s %-16s %-16s" %(name,profile,ip,mac)

        elif is_cmd(commands,cmd,'extrapackages'):
                check_args(cmd, [ ("profile","name of profile") ])
                profilename=sys.argv[2]
                print_extra_packages(package_db,sys.stdin,profilename)

        elif is_cmd(commands,cmd,'listpackages') or is_cmd(commands,cmd,'listuninstalled'):
                check_args(cmd, [("list","name of packagelist")])
                listname=sys.argv[2]
                package_db.selectList(listname)
                print "%-40s %-20s %-10s %-5s" % ("Name","Version","Arch","Install?")
                skipflag=(cmd == 'listuninstalled')
                for id,name,version,arch,install in package_db.getPackages():
                        if install==skipflag:
                                continue
                        print "%-40s %-20s %-10s %-5s" % (name,version,arch,install)

        elif is_cmd(commands,cmd,'install') or is_cmd(commands,cmd,'uninstall'):
                check_args(cmd, [("list","name of packagelist"),
                            ("package","name of package to mark as 'install'"),
                            ("arch","architecture of package")])
                listname=sys.argv[2]
                packagename=sys.argv[3]
                arch=sys.argv[4]
                package_db.selectList(listname)
		result=package_db.find(packagename,arch)
		if result==None:
			print "Package",packagename,"not found"
			sys.exit(2)
                p_id,p_name,p_version,p_arch=result
                if cmd=='install':
                        package_db.setInstallOnId(p_id,True)
                else:
                        package_db.setInstallOnId(p_id,False)

        elif is_cmd(commands,cmd,'showsystem'):
                check_args(cmd, [("name","system name to show")])
                name=sys.argv[2]
                res = package_db.getComputer(name)
                if res is None:
                        print name,"not found"
                        sys.exit(2)

                name,profile,options,ip,mac = res
                print name,profile,options,ip,mac

        elif is_cmd(commands,cmd,'showSystemStatus'):
                check_args(cmd, [("name","system name to show")])
                name=sys.argv[2]
                print name,package_db.getComputerStatus(name)

        elif is_cmd(commands,cmd,'setSystemStatus'):
                check_args(cmd, [("name","system name"),("status","status to set")])
                name=sys.argv[2]
                status=int(sys.argv[3])
                package_db.updateComputerStatus(name,status)

        elif is_cmd(commands,cmd,'puppetReports'):
                handler = lambda name,status: package_db.updateComputerStatus(name,int(status)) 
                reportReader=PuppetReportReader(handler)
                reportReader.process_all_reports()

        elif is_cmd(commands,cmd,'bcfg2Reports'):
                # /usr/sbin/bcfg2-reports -a --fields=bad | /usr/sbin/packagemanager bcfg2Reports
                handler = lambda name,status: package_db.updateComputerStatus(name,int(status))
                bcfg2_import_reports(sys.stdin,handler)
 
        elif is_cmd(commands,cmd,'bcfg2Clients'):
                bcfg2_import_clients('/var/lib/bcfg2/Metadata/clients.xml',package_db)
 
        else:
                print "Usage: %s <command> <arguments>" % (sys.argv[0])
                print "\twhere <command> is one of:\n"
                for c in sorted(commands):
                        print "\t",c
                print
                sys.exit(1)

        package_db.close()
