#
# bootstrap.py : functions for building and using bootstrap
#
# Copyright 2010, Intel Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; version 2 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.


import os
import subprocess
import re
import glob
import xml.dom.minidom
import tarfile

try:
    import sqlite3 as sqlite
except ImportError:
    import sqlite
import _sqlitecache

try:
    from xml.etree import cElementTree
except ImportError:
    import cElementTree
xmlparse = cElementTree.parse

from mic.imgcreate.errors import *
from mic.imgcreate.fs import *
from mic.imgcreate.misc import *

def _decompress_metadata(path):
    """Decompress a downloaded ``.gz`` / ``.bz2`` metadata file in place.

    The compressed file is replaced by its decompressed counterpart (the
    tools run with ``-f``) and the decompressed path is returned.  Paths
    without one of those suffixes are returned unchanged.

    Raises CreatorError if the decompression tool exits with a non-zero status.
    """
    for suffix, tool in ((".gz", "gunzip"), (".bz2", "bunzip2")):
        if path.endswith(suffix):
            rc = subprocess.call([tool, "-f", path])
            if rc != 0:
                raise CreatorError("Failed to decompress %s with %s" % (path, tool))
            return path[:-len(suffix)]
    return path

def get_repo_metadata(repo, cachedir, reponame, proxies = None, arch = None):
    """Fetch a repository's primary metadata and return a local sqlite db path.

    Downloads ``repodata/repomd.xml`` from *repo* into
    ``<cachedir>/<reponame>/``.  If it references a primary.sqlite file, that
    file is downloaded and decompressed when it ends in .gz/.bz2.  Otherwise
    primary.xml is downloaded, decompressed if needed, and converted with
    gen_sqlitedb_from_xml().  Either way the database normally ends up at
    ``<cachedir>/<reponame>/.tmp.primary.sqlite``.

    :param repo: base URL of the repository.
    :param cachedir: local cache directory.
    :param reponame: repository name; selects the cache sub-directory.
    :param proxies: proxy mapping passed to myurlgrab() (default: none).
    :param arch: target arch forwarded to gen_sqlitedb_from_xml(); only used
        on the primary.xml path.
    :returns: path of the primary sqlite database.
    :raises CreatorError: when repomd.xml lists neither primary.sqlite nor
        primary.xml, or when decompression fails.
    """
    if proxies is None:
        proxies = {}
    repodir = "%s/%s" % (cachedir, reponame)
    makedirs(repodir)

    print("Geting repomd.xml from repo %s..." % reponame)
    url = str(repo + "/repodata/repomd.xml")
    filename = str(repodir + "/repomd.xml")
    repomd = myurlgrab(url, filename, proxies)
    f = open(repomd, "r")
    try:
        content = f.read()
    finally:
        f.close()

    # Fixed local name of the database; build_bootstrap() and is_mainrepo()
    # open this same path.
    primarydb = str(repodir + "/.tmp.primary.sqlite")

    m = re.match('.*href="(.*?primary.sqlite.*?)"', content, re.M | re.S)
    if m:
        primarydburl = str(repo + "/" + m.group(1))
        # Keep the compression suffix so the file can be decompressed below.
        if primarydburl.endswith(".gz"):
            primarydb = primarydb + ".gz"
        elif primarydburl.endswith(".bz2"):
            primarydb = primarydb + ".bz2"
        print("Geting primary.sqlite from repo %s..." % reponame)
        primarydb = myurlgrab(primarydburl, primarydb, proxies)
        return _decompress_metadata(primarydb)

    m = re.match('.*href="(.*?primary.xml.*?)"', content, re.M | re.S)
    if not m:
        raise CreatorError("Failed to get metadata from repo %s" % reponame)
    primaryxml = str(repodir + "/" + os.path.basename(m.group(1)))
    primaryxmlurl = str(repo + "/" + m.group(1))
    print("Geting primary.xml from repo %s..." % reponame)
    primaryxml = myurlgrab(primaryxmlurl, primaryxml, proxies)
    primaryxml = _decompress_metadata(primaryxml)
    print("Generating primary.sqlite from primary.xml for repo %s..." % reponame)
    gen_sqlitedb_from_xml(primaryxml, primarydb, arch)
    return primarydb


# Column order of the ``packages`` table created by gen_sqlitedb_from_xml();
# rows are inserted positionally in this order.
_PACKAGE_COLUMNS = (
    "pkgKey", "pkgId", "name", "arch", "version", "epoch", "release",
    "summary", "description", "url", "time_file", "time_build",
    "rpm_license", "rpm_vendor", "rpm_group", "rpm_buildhost",
    "rpm_sourcerpm", "rpm_header_start", "rpm_header_end", "rpm_packager",
    "size_package", "size_installed", "size_archive", "location_href",
    "location_base", "checksum_type")

# <package> children whose text goes into a packages column.
_PACKAGE_TEXT_COLUMNS = {
    "name": "name",
    "arch": "arch",
    "checksum": "pkgId",
    "summary": "summary",
    "packager": "rpm_packager",
    "url": "url",
}

# <package> children whose attributes go into packages columns.
_PACKAGE_ATTR_COLUMNS = {
    "version": (("epoch", "epoch"), ("ver", "version"), ("rel", "release")),
    "checksum": (("type", "checksum_type"),),
    "time": (("file", "time_file"), ("build", "time_build")),
    "size": (("package", "size_package"), ("installed", "size_installed"),
             ("archive", "size_archive")),
    "location": (("href", "location_href"), ("base", "location_base")),
}

# <format> children whose text goes into a packages column.
_FORMAT_TEXT_COLUMNS = {
    "rpm:license": "rpm_license",
    "rpm:vendor": "rpm_vendor",
    "rpm:group": "rpm_group",
    "rpm:buildhost": "rpm_buildhost",
    "rpm:sourcerpm": "rpm_sourcerpm",
}


def _element_children(node):
    """Return the element children of minidom *node* (text/comments skipped)."""
    return [child for child in node.childNodes
            if child.nodeType == child.ELEMENT_NODE]


def _xml_text(node, default=""):
    """Return the text directly inside *node*, or *default* if it has none.

    Empty elements such as ``<rpm:vendor/>`` have no text children and yield
    *default* instead of raising IndexError.
    """
    parts = [child.data for child in node.childNodes
             if child.nodeType in (child.TEXT_NODE, child.CDATA_SECTION_NODE)]
    if not parts:
        return default
    return "".join(parts)


def _rpm_entries(node):
    """Return the ``<rpm:entry>`` elements directly under *node*."""
    return [child for child in _element_children(node)
            if child.nodeName == "rpm:entry"]


def _parse_format(node, pkg, provides, requires):
    """Fill *pkg*, *provides* and *requires* from a package's <format> element."""
    pkgKey = pkg["pkgKey"]
    for sub in _element_children(node):
        tag = sub.nodeName
        if tag in _FORMAT_TEXT_COLUMNS:
            pkg[_FORMAT_TEXT_COLUMNS[tag]] = _xml_text(sub)
        elif tag == "rpm:header-range":
            pkg["rpm_header_start"] = sub.getAttribute("start")
            pkg["rpm_header_end"] = sub.getAttribute("end")
        elif tag == "rpm:provides":
            for entry in _rpm_entries(sub):
                # primary.xml stores the comparison operator in "flags".
                provides.append((entry.getAttribute("name"),
                                 entry.getAttribute("flags"),
                                 entry.getAttribute("epoch"),
                                 entry.getAttribute("ver"),
                                 entry.getAttribute("rel"),
                                 pkgKey))
        elif tag == "rpm:requires":
            for entry in _rpm_entries(sub):
                requires.append((entry.getAttribute("name"),
                                 entry.getAttribute("flags"),
                                 entry.getAttribute("epoch"),
                                 entry.getAttribute("ver"),
                                 entry.getAttribute("rel"),
                                 pkgKey,
                                 entry.getAttribute("pre")))


def _parse_package(node, pkgKey):
    """Parse one <package> element.

    Returns ``(pkg, provides, requires)``.  *pkg* maps every name in
    _PACKAGE_COLUMNS to its value, using "" when the element or attribute is
    absent, so nothing carries over from a previous package.  *provides* and
    *requires* are row tuples for the tables of the same name.
    """
    pkg = dict.fromkeys(_PACKAGE_COLUMNS, "")
    pkg["pkgKey"] = pkgKey
    provides = []
    requires = []
    description = None
    for sub in _element_children(node):
        tag = sub.nodeName
        if tag in _PACKAGE_TEXT_COLUMNS:
            pkg[_PACKAGE_TEXT_COLUMNS[tag]] = _xml_text(sub)
        for attr, column in _PACKAGE_ATTR_COLUMNS.get(tag, ()):
            pkg[column] = sub.getAttribute(attr)
        if tag == "description":
            description = _xml_text(sub, None)
        elif tag == "format":
            _parse_format(sub, pkg, provides, requires)
    # A missing or empty <description> falls back to the summary.
    if description is None:
        description = pkg["summary"]
    pkg["description"] = description
    return pkg, provides, requires


def gen_sqlitedb_from_xml(primaryxml, sqlitedb, target_arch = None):
    """Build the primary sqlite database *sqlitedb* from the file *primaryxml*.

    Existing ``packages``, ``provides`` and ``requires`` tables in *sqlitedb*
    are dropped and recreated.  A package is skipped when:

      * *target_arch* is not given and its arch starts with "arm";
      * *target_arch* is given and its arch is neither it nor "noarch";
      * its arch is "src".

    <package> elements with fewer than 12 child nodes (text nodes included)
    are ignored.  Every <package> that passes that check consumes a pkgKey,
    even if it is then skipped by arch.
    """
    con = sqlite.connect(sqlitedb)
    try:
        con.executescript("""
            DROP table IF EXISTS packages;
            DROP table IF EXISTS provides;
            DROP table IF EXISTS requires;
        """)
        con.execute("CREATE TABLE IF NOT EXISTS packages (  pkgKey INTEGER PRIMARY KEY,  pkgId TEXT,  name TEXT,  arch TEXT,  version TEXT,  epoch TEXT,  release TEXT,  summary TEXT,  description TEXT,  url TEXT,  time_file INTEGER,  time_build INTEGER,  rpm_license TEXT,  rpm_vendor TEXT,  rpm_group TEXT,  rpm_buildhost TEXT,  rpm_sourcerpm TEXT,  rpm_header_start INTEGER,  rpm_header_end INTEGER,  rpm_packager TEXT,  size_package INTEGER,  size_installed INTEGER,  size_archive INTEGER,  location_href TEXT,  location_base TEXT,  checksum_type TEXT);")
        con.execute("CREATE TABLE IF NOT EXISTS provides (  name TEXT,  flags TEXT,  epoch TEXT,  version TEXT,  release TEXT,  pkgKey INTEGER );")
        con.execute("CREATE TABLE IF NOT EXISTS requires (  name TEXT,  flags TEXT,  epoch TEXT,  version TEXT,  release TEXT,  pkgKey INTEGER , pre BOOLEAN DEFAULT FALSE);")

        insert_package = "INSERT INTO packages VALUES (%s)" % ",".join(["?"] * len(_PACKAGE_COLUMNS))
        insert_provides = "INSERT INTO provides values (?,?,?,?,?,?)"
        insert_requires = "INSERT INTO requires values (?,?,?,?,?,?,?)"

        dom = xml.dom.minidom.parse(primaryxml)
        pkgKey = 0
        for node in dom.getElementsByTagName("package"):
            if node.nodeType != node.ELEMENT_NODE or len(node.childNodes) < 12:
                continue
            pkgKey += 1
            pkg, provides, requires = _parse_package(node, pkgKey)
            arch = pkg["arch"]
            if not target_arch and arch.startswith("arm"):
                continue
            if target_arch is not None and arch != target_arch and arch != "noarch":
                continue
            if arch == "src":
                continue
            con.execute(insert_package, tuple([pkg[column] for column in _PACKAGE_COLUMNS]))
            con.executemany(insert_provides, provides)
            con.executemany(insert_requires, requires)

        con.commit()
    finally:
        con.close()

def get_my_all_deps_in_order(pkg, con):
    """Return every package *pkg* transitively depends on, dependencies first.

    Requirements are resolved by name through the ``requires`` and
    ``provides`` tables of the sqlite connection *con*.  Requirements of
    ``src`` packages are ignored.  A depth-first walk places each package in
    a tree at the first point it is reached, which also breaks dependency
    cycles.  The tree is then flattened in post-order (children before
    their parent).

    Returns a list of ``{"name": <package name>, "isleaf": <bool>}`` dicts.
    ``isleaf`` is True for packages that received no children in the tree.
    The last entry is always *pkg* itself.
    """
    # Bound parameters keep names containing quotes working; they also stop
    # SQLite from treating double-quoted "literals" as column names.
    requires_sql = ("select requires.name from packages inner join requires"
                    " where packages.name = ? and packages.pkgKey = requires.pkgKey"
                    " and packages.arch <> 'src' group by requires.name")
    provides_sql = ("select packages.name from provides, packages"
                    " where provides.name = ? and packages.pkgKey = provides.pkgKey"
                    " group by packages.name")

    def getmydepends(pkgname):
        """Return the packages providing *pkgname*'s requirements, minus itself."""
        # Materialize requirement names before running the provider queries.
        requirements = [row[0] for row in con.execute(requires_sql, (pkgname,))]
        depends = []
        seen = set()
        for requirement in requirements:
            for row in con.execute(provides_sql, (requirement,)):
                if row[0] not in seen:
                    seen.add(row[0])
                    depends.append(row[0])
        if pkgname in seen:
            depends.remove(pkgname)
        return depends

    placed = set([pkg])

    def walk(name):
        """Return the post-order entries of the subtree rooted at *name*."""
        entries = []
        for dep in getmydepends(name):
            # Each package is placed only once in the whole tree; this also
            # stops at packages already on the current path (cycles).
            if dep in placed:
                continue
            placed.add(dep)
            entries.extend(walk(dep))
        entries.append({"name": name, "isleaf": not entries})
        return entries

    return walk(pkg)

def is_ubuntu():
    """Return True if any host /etc/*-release file mentions "Ubuntu"."""
    for path in glob.glob("/etc/*-release"):
        handle = open(path, "r")
        try:
            text = handle.read()
        finally:
            handle.close()
        # "Ubuntu" holds no newline, so searching the whole text is the same
        # as searching it line by line.
        if "Ubuntu" in text:
            return True
    return False

def install_rpm(pkgpath, pkgname, rootpath, force = False):
    """Install the RPM file *pkgpath* into *rootpath* with the host ``rpm``.

    Runs ``rpm -i`` with arch/OS/digest/signature checks disabled.  *force*
    adds ``--force --nodeps``, and ``--force-debian`` is added on Ubuntu
    hosts.  rpm's own output goes to /dev/null, so a non-zero exit status is
    reported with a printed warning (no exception is raised).

    :param pkgpath: path of the .rpm file.
    :param pkgname: package name, used only for messages.
    :param rootpath: target root directory (``~`` expanded, made absolute).
    :param force: install regardless of conflicts and dependencies.
    :returns: rpm's exit status.
    :raises CreatorError: if neither /bin/rpm nor /usr/bin/rpm exists.
    """
    rpm = None
    for candidate in ("/bin/rpm", "/usr/bin/rpm"):
        if os.path.exists(candidate):
            rpm = candidate
            break
    if rpm is None:
        raise CreatorError("Failed to run 'rpm'.")

    rootpath = os.path.abspath(os.path.expanduser(rootpath))
    print("Installing %s..." % pkgname)
    argv = [rpm, "-i", "--ignorearch", "--ignoreos", "--nodigest", "--nosignature", "--root=" + rootpath, pkgpath]
    if force:
        argv.extend(["--force", "--nodeps"])
    if is_ubuntu():
        argv.append("--force-debian")

    dev_null = os.open(os.devnull, os.O_WRONLY)
    try:
        rc = subprocess.call(argv, stdout = dev_null, stderr = dev_null)
    finally:
        os.close(dev_null)
    if rc != 0:
        print("Warning: rpm exited with status %d while installing %s" % (rc, pkgname))
    return rc

def get_pkg_url(pkg, con):
    """Return the repository-relative ``location_href`` of package *pkg*.

    Looks *pkg* up by name in the ``packages`` table of the sqlite connection
    *con*.  The name is passed as a bound parameter, so names containing
    quotes, or equal to a column name, are matched literally.  Returns the
    first matching row's href, or None if the package is absent.
    """
    row = con.execute("select location_href from packages where name = ?", (pkg,)).fetchone()
    if row is None:
        return None
    return row[0]

def build_bootstrap(repourl, cachedir, reponame, bootstrapdir, proxies = None, arch = None):
    """Populate *bootstrapdir* with a minimal root filesystem from a repository.

    Packages are looked up in ``<cachedir>/<reponame>/.tmp.primary.sqlite``
    (the database get_repo_metadata() produces).  They are downloaded into
    ``<cachedir>/<reponame>/packages``; ``file://`` URLs are used in place.
    Each package is installed with install_rpm(), at most once.

    :param repourl: base URL of the repository.
    :param cachedir: metadata/package cache directory.
    :param reponame: repository name; selects the cache sub-directory.
    :param bootstrapdir: root directory to install into (created if needed).
    :param proxies: proxy mapping passed to URLGrabber.urlgrab (default: none).
    :param arch: target architecture; an "arm*" arch sets up a qemu emulator.
    :raises CreatorError: when a package download fails.
    """
    # sys.exc_info() keeps the except clause below valid on every Python version.
    import sys

    if proxies is None:
        proxies = {}
    sqlitedb = "%s/%s/.tmp.primary.sqlite" % (cachedir, reponame)
    g = URLGrabber()
    con = sqlite.connect(sqlitedb)
    try:
        installed_pkgs = set()

        cachedir = os.path.abspath(os.path.expanduser(cachedir))
        bootstrapdir = os.path.abspath(os.path.expanduser(bootstrapdir))
        pkgspath = "%s/%s/packages" % (cachedir, reponame)
        makedirs(pkgspath)
        makedirs(bootstrapdir)
        makedirs(bootstrapdir + "/var/lib/rpm")

        if arch and arch.startswith("arm"):
            setup_qemu_emulator(bootstrapdir, "arm")

        def download_and_install(pkg, force = False):
            """Fetch *pkg* unless cached and install it once into bootstrapdir."""
            if pkg in installed_pkgs:
                return
            href = get_pkg_url(pkg, con)
            if not href:
                # Not in this repository: skipped silently.
                return
            pkgurl = str(repourl + "/" + href)
            pkgfn = str(pkgspath + "/" + os.path.basename(pkgurl))
            if not os.path.exists(pkgfn):
                if pkgurl.startswith("file://"):
                    # Local repository: install straight from it.
                    pkgfn = pkgurl.replace("file://", "")
                else:
                    print("Downloading %s..." % pkg)
                    try:
                        pkgfn = g.urlgrab(url = pkgurl, filename = pkgfn, proxies = proxies)
                    except URLGrabError:
                        raise CreatorError("URLGrabber error: %s: %s" % (sys.exc_info()[1], pkgurl))
                    except Exception:
                        raise CreatorError("URLGrabber error: %s" % pkgurl)
            install_rpm(pkgfn, pkg, bootstrapdir, force)
            installed_pkgs.add(pkg)

        # First install a mini base (forced).
        for pkg in ["bash", "setup", "filesystem", "basesystem", "nss-softokn-freebl",
                    "cpio", "info", "glibc-common", "libgcc", "glibc"]:
            download_and_install(pkg, True)

        pkglist = ["glibc", "passwd", "pam-modules-cracklib", "meego-release",
                   "moblin-release", "fastinit", "nss", "genisoimage", "bzip2",
                   "gzip", "cpio", "perl", "syslinux-extlinux", "mic2",
                   "isomd5sum", "wget"]
        for pkg in pkglist:
            for node in get_my_all_deps_in_order(pkg, con):
                # Forced because our dependency tree isn't complete.
                download_and_install(node["name"], True)
            # NOTE: the dependency list for pkg ends with pkg itself, so pkg is
            # normally already force-installed above and this call is a no-op.
            download_and_install(pkg, False)

        # The rpmdb environment files are meaningless here and only lead to
        # error messages later.
        for rpmdbfile in glob.glob(bootstrapdir + "/var/lib/rpm/__db.*"):
            os.unlink(rpmdbfile)

        # Create an empty /tmp/SampleMedia.tar to avoid a long download time.
        fd = open(bootstrapdir + "/tmp/SampleMedia.tar", "w")
        fd.close()
    finally:
        con.close()

def isbootstrap(rootdir):
    """Tell whether *rootdir* contains the marker files of a bootstrap root.

    Requires /etc/meego-release or /etc/moblin-release, plus /etc/inittab
    and /etc/rc.sysinit, all relative to *rootdir*.
    """
    def present(relpath):
        return os.path.exists(rootdir + relpath)

    if not (present("/etc/meego-release") or present("/etc/moblin-release")):
        return False
    return present("/etc/inittab") and present("/etc/rc.sysinit")

def has_package_in_repo(con, pkg):
    """Return True if the ``packages`` table of *con* has a row named *pkg*.

    The name is a bound parameter, so quotes or column-like names such as
    "name" are matched literally instead of being parsed as SQL.
    """
    row = con.execute("select 1 from packages where name = ? limit 1", (pkg,)).fetchone()
    return row is not None

def is_mainrepo(cachedir, reponame):
    """Tell whether the cached repository *reponame* is a main repository.

    True when its cached repomd.xml references a comps.xml file and its
    primary sqlite database contains the packages yum, mic2, rpm and fastinit.
    Both files are read from ``<cachedir>/<reponame>/``.  The database
    connection is closed before returning.
    """
    repomd = "%s/%s/repomd.xml" % (cachedir, reponame)
    f = open(repomd, "r")
    try:
        content = f.read()
    finally:
        f.close()
    if not re.match('.*href="(.*?comps.xml.*?)"', content, re.M | re.S):
        return False

    sqlitedb = "%s/%s/.tmp.primary.sqlite" % (cachedir, reponame)
    con = sqlite.connect(sqlitedb)
    try:
        for pkg in ("yum", "mic2", "rpm", "fastinit"):
            if not has_package_in_repo(con, pkg):
                return False
        return True
    finally:
        con.close()

def check_depends(imgfmt):
    """Make sure the host tool needed to produce *imgfmt* images is installed.

    "vmdk" needs /usr/bin/qemu-img and "vdi" needs /usr/bin/VBoxManage.
    Other formats need nothing.  Raises CreatorError naming the missing file.
    """
    requirements = (
        ("vmdk", '/usr/bin/qemu-img',
         "Needed file: /usr/bin/qemu-img not found, please check your installation."),
        ("vdi", '/usr/bin/VBoxManage',
         "Needed file: /usr/bin/VBoxManage not found, please check your installation.."),
    )
    for fmt, tool, message in requirements:
        if imgfmt == fmt and not os.path.exists(tool):
            raise CreatorError(message)

def save_imginfo(outimage, infofile):
    """Write each file name in *outimage* to *infofile*, one per line."""
    print("Saving image info...")
    out = open(infofile, "w")
    out.writelines(name + "\n" for name in outimage)
    out.close()

def get_imgfile(infofile):
    """Consume the image-info file *infofile* and return its raw image entry.

    The file is read and then deleted.  Returns the first line ending in
    ".raw", or None when there is no such line.
    """
    print("Getting image info...")
    src = open(infofile, "r")
    data = src.read()
    src.close()
    os.unlink(infofile)
    raws = [entry for entry in data.split("\n") if entry.endswith(".raw")]
    if raws:
        return raws[0]
    return None

def write_image_vmx(imgfile):
    """Write a VMware ``.vmx`` configuration next to the disk image *imgfile*.

    The config path is *imgfile* with its last four characters replaced by
    "vmx" (``disk.vmdk`` -> ``disk.vmx``).  The disk entry refers to the image
    by basename only.  Returns the path of the written file.
    """
    template = """#!/usr/bin/vmware
.encoding = "UTF-8"
displayName = "MeeGo 1.0"
guestOS = "linux"

memsize = "512"
ide0:0.fileName = "%s"

# DEFAULT SETTINGS UNDER THIS LINE
config.version = "8"
virtualHW.version = "4"

MemAllowAutoScaleDown = "FALSE"
MemTrimRate = "-1"

uuid.location = "56 4d 8a 28 88 e5 86 1f-7e ed 8f 25 45 7d f8 e4"
uuid.bios = "56 4d 8a 28 88 e5 86 1f-7e ed 8f 25 45 7d f8 e4"

uuid.action = "create"

ethernet0.present = "TRUE"
ethernet0.connectionType = "nat"
ethernet0.addressType = "generated"
ethernet0.generatedAddress = "00:0c:29:7d:f8:e4"
ethernet0.generatedAddressOffset = "0"

usb.present = "TRUE"
ehci.present = "TRUE"
sound.present = "TRUE"
sound.autodetect = "TRUE"

scsi0.present = "FALSE"
floppy0.present = "FALSE"
ide0:0.present = "TRUE"
ide0:0.deviceType = "disk"
ide0:1.present = "FALSE"

virtualHW.productCompatibility = "hosted"
tools.upgrade.policy = "manual"

tools.syncTime = "FALSE"

ide0:0.redo = ""
"""
    config_path = imgfile[0:-4] + "vmx"
    handle = open(config_path, "w")
    handle.write(template % os.path.basename(imgfile))
    handle.close()
    return config_path

def convert_to(imgfile, imgfmt):
    """Convert the raw disk image *imgfile* to *imgfmt* ("vmdk" or "vdi").

    The output path is *imgfile* with its last three characters replaced by
    *imgfmt* (``foo.raw`` -> ``foo.vmdk``).  For vmdk a VMware .vmx config is
    also written with write_image_vmx().  After a successful conversion the
    raw image and its XML descriptor are deleted on a best-effort basis.

    :returns: list of produced files (the converted image, plus the .vmx for vmdk).
    :raises CreatorError: for an unsupported format or a failed conversion.
    """
    dst = imgfile[0:-3] + imgfmt
    outimage = [dst]
    if imgfmt == "vmdk":
        argv = ["/usr/bin/qemu-img", "convert",
                "-f", "raw", imgfile,
                "-O", imgfmt, dst]
    elif imgfmt == "vdi":
        argv = ["/usr/bin/VBoxManage", "convertfromraw",
                imgfile, dst,
                "--format", "VDI"]
    else:
        raise CreatorError("Unable to convert to %s: unsupported image format" % imgfmt)

    print("converting %s image to %s" % (imgfile, dst))
    if subprocess.call(argv) != 0:
        raise CreatorError("Unable to convert to %s" % imgfmt)

    if imgfmt == "vmdk":
        outimage.append(write_image_vmx(dst))
    print("convert successfully")

    # Each removal is attempted independently, so a missing XML file does not
    # leave the raw image behind.  The XML name strips the last 8 characters
    # of imgfile (presumably a "-sda.raw" suffix; TODO confirm).
    for leftover in (imgfile[0:-8] + ".xml", imgfile):
        try:
            os.unlink(leftover)
        except OSError:
            pass
    return outimage

def package_image(imagefiles, image_format, destdir, package):
    """Bundle converted image files into a tar archive under *destdir*.

    :param imagefiles: paths to put in the archive.  The archive name comes
        from the first ``.vmdk`` file (minus its last 9 characters, e.g.
        "-sda.vmdk") or ``.vdi`` file (minus its last 8 characters).
    :param image_format: format tag used in the archive and top-level
        directory names.
    :param destdir: directory receiving the archive (``~`` expanded).
    :param package: "tar", or "tar.<comp>" with a tarfile compression such as
        "tar.gz" or "tar.bz2".  "none" or an empty value disables packaging.
    :returns: path of the archive, or None when packaging is disabled or
        *package* is not a tar variant.
    :raises CreatorError: when *imagefiles* has no .vmdk/.vdi file to name
        the archive after.
    """
    if not package or package == "none":
        return None

    (pkg, comp) = os.path.splitext(package)
    comp = comp.lstrip(".")
    if pkg != "tar":
        # Only tar packaging is implemented.
        return None

    name = None
    for path in imagefiles:
        if path.endswith(".vmdk"):
            name = os.path.basename(path)[0:-9]
            break
        if path.endswith(".vdi"):
            name = os.path.basename(path)[0:-8]
            break
    if name is None:
        raise CreatorError("No .vmdk or .vdi image found to package")

    destdir = os.path.abspath(os.path.expanduser(destdir))
    topdir = "%s-%s" % (name, image_format)
    if comp:
        dst = "%s/%s.tar.%s" % (destdir, topdir, comp)
    else:
        dst = "%s/%s.tar" % (destdir, topdir)
    print("creating %s" % dst)
    tar = tarfile.open(dst, "w:" + comp)
    try:
        for path in imagefiles:
            print("adding %s to %s" % (path, dst))
            tar.add(path, arcname=os.path.join(topdir, os.path.basename(path)))
    finally:
        tar.close()
    return dst

def print_imginfo(outimage):
    """Print a header line, then the absolute path of each file in *outimage*."""
    print("Your new image can be found here:")
    for path in map(os.path.abspath, outimage):
        print(path)
