#!/usr/bin/python

# Copyright (C) 2006 by Tapsell-Ferrier Limited

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; see the file COPYING.  If not, write to the
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
# Boston, MA 02110-1301 USA


"""Module to make libxslt as easy as possible.
"""

import os.path
import StringIO
import libxml2
import libxslt
import logging
import types
import sys

import pdb
import traceback

# Global stylesheet cache: maps the stylesheet_xml argument given to
# __load_xslt (a file path, or the stylesheet text itself) to the parsed
# libxslt stylesheet object.
STYLESHEETS = {}


class xsltError(Exception):
    """Raised when loading a stylesheet or running a transform fails.

    The object passed to the constructor is kept on the ``value``
    attribute; the string form of the exception is ``str(value)``.
    """

    def __init__(self, value):
        # Keep whatever the raiser supplied (a message, the offending
        # source object, ...) so callers can inspect it.
        self.value = value

    def __str__(self):
        return str(self.value)


def preload(stylesheet_xml):
    """Load the stylesheet ahead of time so it is already in the cache.

    stylesheet_xml is handed unchanged to __load_xslt; the loaded
    stylesheet itself is not returned (the result is always None).
    """
    __load_xslt(stylesheet_xml)


def xslt(stylesheet_xml, out, src_doc):
    """Run src_doc through the stylesheet and deliver the result to out.

    A libxslt error handler is registered first; it forwards every
    non-blank message to the "xslt" logger at error level.  The stylesheet
    is then obtained through __load_xslt and applied by __transform.
    Always returns None.
    """
    log = logging.getLogger("xslt")

    ## FIXME we should accept an error handler passed into this func.
    def log_libxslt_message(ctx, message):
        # Drop trailing whitespace/newlines and skip messages that end up blank.
        text = message.rstrip()
        if text != '':
            log.error(text)

    libxslt.registerErrorHandler(log_libxslt_message, "")
    __transform(__load_xslt(stylesheet_xml), out, src_doc)
    

def __load_xslt(stylesheet_xml):
    """loader implementation."""

    logger = logging.getLogger("__load_xslt")
    global STYLESHEETS
    script = None

    try:
        script = STYLESHEETS[stylesheet_xml]
    except KeyError:
        # The script doesn't exist in the cache so it must be new
        try:
            if os.path.exists(stylesheet_xml):
                logger.info("stylesheet file is: %s" % (stylesheet_xml))
                # FIXME: should record the file against the XSLT object here
                script = libxslt.parseStylesheetFile(stylesheet_xml)
                if not script:
                    raise xsltError("some problem with parsing the stylesheet")
                
                script.filename = stylesheet_xml
                script.filename_cached_date = os.path.getmtime(stylesheet_xml)
            else:
                xsl_doc = libxml2.readMemory(stylesheet_xml, len(stylesheet_xml), "file:///-", "UTF-8", 0)
                script = libxslt.parseStylesheetDoc(xsl_doc)

            # Try and dynamically load functions - this is still a hack
            register_modules_by_xpath(script.doc())
        except Exception, e:
            raise xsltError("with %s we got %s with tb %s" % (stylesheet_xml, str(e), traceback.extract_tb(sys.exc_info()[2])))
    else:
        # Has the script been touched since we cached it?
        try:
            filename = script.filename
            if os.path.getmtime(filename) > script.filename_cached_date:
                script = libxslt.parseStylesheetFile(filename)
                script.filename_cached_date = os.path.getmtime(filename)
                # Try and dynamically load functions - this is still a hack
                register_modules_by_xpath(script.doc())
        except AttributeError:
            pass

    # Record the stylesheet in the cache and return it
    STYLESHEETS[stylesheet_xml] = script
    
    return script


def __transform(script, out, src_doc):
    """Apply the stylesheet script to src_doc and deliver the result to out.

    src_doc may be an open file (parsed from its file descriptor), a
    libxml2.xmlDoc (used as is), a str of XML, or a plain function that
    writes XML to the stream it is given.  Anything else raises xsltError.

    out may be a libxml2.xmlNode (the result's root element is added to it
    as a child), a file (the result is saved to it) or any other object
    usable with ``print >>`` (the serialised result is printed to it).

    Raises xsltError for an unsupported src_doc or when the stylesheet
    produces no result document.  Always returns None otherwise.
    """

    # Get a doc for the input.  Only documents parsed here are ours to free;
    # an xmlDoc supplied by the caller stays the caller's responsibility.
    owns_doc = True
    if isinstance(src_doc, file):
        doc = libxml2.readFd(src_doc.fileno(), "file:///", "utf-8", 0)
    elif isinstance(src_doc, libxml2.xmlDoc):
        doc = src_doc
        owns_doc = False
    elif isinstance(src_doc, str):
        doc = libxml2.readMemory(src_doc, len(src_doc), "file:///", "utf-8", 0)
    elif isinstance(src_doc, types.FunctionType):
        doc = __function_output_to_dom__(src_doc)
    else:
        raise xsltError(src_doc)

    # Transform it.
    try:
        result = script.applyStylesheet(doc, {})
    finally:
        # libxml2 documents are not released by Python's garbage collector
        if owns_doc:
            doc.freeDoc()

    if result is None:
        raise xsltError("the stylesheet produced no result document")

    # Handle the result
    if isinstance(out, libxml2.xmlNode):
        root = result.getRootElement()
        out.addChild(root)
        root.setTreeDoc(out.get_doc())
        # NOTE(review): the result document is deliberately not freed in
        # this branch because its root element now lives in out's tree -
        # confirm whether the emptied document can be released safely.
    else:
        try:
            if isinstance(out, file):
                script.saveResultToFile(out, result)
            else:
                str_val = script.saveResultToString(result)
                print >>out, str_val
        finally:
            result.freeDoc()

    return None


def __function_output_to_dom__(func):
    """Call func with a stream, then parse what it wrote as XML.

    func receives a StringIO object and is expected to write an XML
    document to it.  Returns the document parsed by libxml2.

    Raises xsltError(func) when func fails or its output cannot be parsed;
    the underlying error is logged first so that it is not lost.
    """
    stream = StringIO.StringIO()
    try:
        try:
            func(stream)
            str_src = stream.getvalue()
        finally:
            # Release the buffer whether or not func succeeded
            stream.close()
        return libxml2.readMemory(str_src, len(str_src), "file:///-", "UTF-8", 0)
    except Exception:
        # sys.exc_info() is used so the cause can be reported while the
        # exception raised to the caller stays xsltError(func).
        logging.getLogger("__function_output_to_dom__").error(
            "%s did not produce parseable XML: %s" % (func, sys.exc_info()[1]))
        raise xsltError(func)




## XSL type mapping

def py_xslfn_typemap(value, node = None):
    """Map the Python value to some libxml2 value.

    * A dict becomes a <mappings> element holding one <mapping> per entry,
      each with a <key> and a <value> child; the <mappings> element is
      returned.
    * Any other object with an ``__iter__`` attribute (or a tuple/list)
      becomes an <items> element with one <item> child per member; the
      <items> element is returned.
    * Everything else is converted with str().

    When node is given, the generated element or text is added to it; text
    is added with addContent() and None is returned.  When node is None,
    containers are built inside a fresh libxml2 document and plain values
    are returned as the string itself.
    """

    if isinstance(value, dict):
        if node is None:
            node = libxml2.newDoc("1.0")

        # Render the mappings as XML
        mappings = node.newChild(None, "mappings", None)
        for key, data in value.iteritems():
            pair = mappings.newChild(None, "mapping", None)
            py_xslfn_typemap(key, pair.newChild(None, "key", None))
            py_xslfn_typemap(data, pair.newChild(None, "value", None))

        return mappings

    # dicts also have __iter__, so they must be handled before this test
    if hasattr(value, "__iter__") or isinstance(value, (tuple, list)):
        if node is None:
            node = libxml2.newDoc("1.0")

        # Now we need a list of items
        items = node.newChild(None, "items", None)
        for val in value:
            py_xslfn_typemap(val, items.newChild(None, "item", None))

        return items

    # Numbers, booleans and every remaining type are all rendered as text
    text = str(value)
    if node is None:
        return text

    node.addContent(text)
    return None


# Stores allocations per context: maps each ctx handed to py_xslfn_glue to
# the list of unlinked libxml2 elements it has returned for that ctx.
# Nothing in the code visible here ever removes an entry.
PER_CONTEXTS_ALLOC = {}

def py_xslfn_glue(fname, ctx, *args):
    """Glue function to make ordinary python code accessible to XSLT.

    Calls fname(*args) and converts the outcome with py_xslfn_typemap.

    If the conversion yields a libxml2 node, the root element of that
    node's document is unlinked, remembered in PER_CONTEXTS_ALLOC under
    ctx, and returned as a one-element list.  Any other converted value is
    returned unchanged.
    """
    converted = py_xslfn_typemap(fname(*args), None)

    # Plain values need no bookkeeping
    if not isinstance(converted, libxml2.xmlNode):
        return converted

    # Can't return DOMs - only elements, and they have to be unlinked
    element = converted.doc.getRootElement()
    element.unlinkNode()

    ## FIXME: should py_xsl_typemap do this?
    # Keep a reference per context so the element can be found later
    kept = PER_CONTEXTS_ALLOC.setdefault(ctx, [])
    if element not in kept:
        kept.append(element)

    return [element]



# 'Dynamic' function loading

import imp
import re
import string

# List of dynamic function language namespaces: maps a namespace URI that a
# stylesheet may declare on its root element to a short language tag.
# register_modules_by_xpath only acts on namespaces listed here.
DYNAMIC_FUNCTION = {
    "http://www.tapsellferrier.co.uk/xslt-dynamic-function-languages/python": "python"
    }

def _load_dotted_module(module_name):
    """Load module_name with imp, walking through packages for dotted names."""
    parts = module_name.split(".")
    module = None
    search_path = None
    for depth, part in enumerate(parts):
        qualified = ".".join(parts[:depth + 1])
        handle, pathname, description = imp.find_module(part, search_path)
        try:
            module = imp.load_module(qualified, handle, pathname, description)
        finally:
            # imp.load_module never closes the file it was given
            if handle is not None:
                handle.close()

        if depth + 1 < len(parts):
            # Only a package can contain the next component
            try:
                search_path = module.__path__
            except AttributeError:
                raise ImportError("%s is not a package" % (qualified))

    return module


def _register_python_function(name, uri, func):
    """Expose func to libxslt as the extension function name in namespace uri."""
    # A separate function gives every registration its own func binding
    def bridge(ctx, *args):
        return py_xslfn_glue(func, ctx, *args)

    libxslt.registerExtModuleFunction(name, uri, bridge)


def register_modules_by_xpath(xslt_doc):
    """Load and register the Python functions a stylesheet refers to.

    The namespaces declared on the root element of xslt_doc are compared
    with DYNAMIC_FUNCTION.  For every matching prefix, each attribute whose
    value starts with "prefix:" is examined; when the value looks like
    "prefix:module.function(...)", the module is loaded and the function is
    registered with libxslt under the name "module.function" in that
    namespace.

    Attribute values that do not match, that name no module, or that name
    a function the module does not define are skipped.  A module that
    cannot be found raises ImportError.  Always returns None otherwise.
    """

    logger = logging.getLogger("register_modules_by_xpath")

    ctx = xslt_doc.xpathNewContext()
    try:
        # Collect the prefixes bound to a known dynamic-function namespace
        namespaces_to_search_for = {}
        nsll = xslt_doc.getRootElement().nsDefs()
        while nsll is not None:
            ns_str = nsll.get_content()
            ns_prefix = nsll.get_name()
            # A declaration without a prefix cannot be used to build the
            # "prefix:" query below, so it is ignored.
            if ns_prefix is not None and ns_str in DYNAMIC_FUNCTION:
                # Save it away for finding the references later
                namespaces_to_search_for[ns_prefix] = ns_str
                # And register it with xpath
                try:
                    ctx.xpathRegisterNs(ns_prefix, ns_str)
                except libxml2.libxmlError:
                    logger.error("an error while registering %s" % (DYNAMIC_FUNCTION[ns_str]))

            # Go round again
            nsll = nsll.next

        # Find all the references to each namespace
        for tag, uri in namespaces_to_search_for.items():
            query = "//*/@*[starts-with(., '" + tag + ":')]"
            # The regex used here is very conservative and restrictive....
            # ... again, this would be solved by real dynamic loading.
            # The prefix is escaped because it may contain characters such
            # as "." that are special in a regular expression.
            pattern = re.compile(re.escape(tag) + r":([a-zA-Z_][a-zA-Z0-9_.]*)\((.*)\)")

            for func_call in ctx.xpathEval(query):
                call_str = func_call.get_content()
                match = pattern.match(call_str)
                if match is None:
                    logger.debug("failed to match the dynamic loading regexp for: " + call_str)
                    continue

                fq_name = match.group(1)
                parts = fq_name.split(".")
                module_name = ".".join(parts[:-1])
                func_name = parts[-1]
                if not module_name:
                    # "prefix:function()" gives us no module to load
                    logger.warning("no module named in dynamic function call: " + call_str)
                    continue

                # Not sure if this is the right thing to do...
                module = _load_dotted_module(module_name)
                try:
                    func = vars(module)[func_name]
                except KeyError:
                    logger.debug("%s does not define %s" % (module_name, func_name))
                    continue

                logger.debug("registering %s under %s" % (fq_name, uri))
                _register_python_function(fq_name, uri, func)
    finally:
        # Release the xpath context even when loading a module fails
        ctx.xpathFreeContext()

    return None


# Simple xsltproc application
import getopt

def commandline():
    """Basic xsltproc interface"""
    try:
        opts, args = getopt.getopt(sys.argv[1:], "", [ "loadpath=" ])
    except getopt.GetoptError, e:
        print "xslt.py argument failure: %s" % (str(e))
        sys.exit(2)
    else:
        # First handle the option bits
        for o, a in opts:
            if o == "--loadpath":
                sys.path = sys.path + a.split(":")

        # Now the arguments
        if args[0]:
            fname = args[0]
            if os.path.exists(fname) and os.path.isfile(fname):
                src = "-"
                if len(args) > 1:
                    if args[1]:
                        src = args[1]

                if src == "-":
                    xslt(fname, sys.stdout, sys.stdin)
                else:
                    xslt(fname, sys.stdout, open(src))
    return None


if __name__ == "__main__":
    # Run as a script: send log output of every level to the default
    # handler, then hand over to the command line interface.
    logging.basicConfig()
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    commandline()

# End

