# -*- coding: utf-8 -*-
# Moovida - Home multimedia server
# Copyright (C) 2006-2009 Fluendo Embedded S.L. (www.fluendo.com).
# All rights reserved.
#
# This file is available under one of two license agreements.
#
# This file is licensed under the GPL version 3.
# See "LICENSE.GPL" in the root of this distribution including a special
# exception to use Moovida with Fluendo's plugins.
#
# The GPL part of Moovida is also available under a commercial licensing
# agreement from Fluendo.
# See "LICENSE.Moovida" in the root directory of this distribution package
# for details on that license.
#
# Authors: Alessandro Decina <alessandro@fluendo.com>

import os
import pkg_resources
import platform
import sys
from tempfile import mktemp

from twisted.internet import defer, reactor
from twisted.internet.protocol import ProcessProtocol
from twisted.python import log
from twisted.internet.error import CannotListenError, ProcessExitedAlready

from elisa.plugins.amp.protocol import MasterFactory, newCookie

class StartError(Exception):
    """Error a Master reports through the Deferred returned by
    startSlaves() when every spawned slave process died before all the
    expected slaves had connected back."""

class Slave(object):
    """Bookkeeping record for one slave process spawned by a Master.

    The Master fills in the attributes below as the slave goes through
    spawning, connecting and dying; all of them start out as None.
    """
    # NOTE(review): never assigned in this module; the Master stores the
    # spawned process transport in `process` instead.
    process_transport = None
    process_protocol = None
    # transport returned by reactor.spawnProcess (set by Master._startSlave);
    # stays None when the process was never spawned
    process = None
    # pending spawn-timeout call from reactor.callLater (set by
    # Master._startSpawnTimeout)
    timeout_call = None

    # the amp protocol (note that there's no amp transport as all the slaves
    # connect to the same port/socket)
    amp = None

    def __init__(self, cookie):
        self.cookie = cookie

    def kill(self, force=False):
        """ Send KILL signal to the process and close the file
        descriptors if force is set to True (in the case where
        connection between master and slave died due to a ping
        timeout).

        Does nothing if no process was ever attached to this slave, and
        nothing more if the process had already exited.
        """
        if self.process is None:
            # never spawned (or spawning failed): there is nothing to signal
            return

        try:
            self.process.signalProcess('KILL')
        except ProcessExitedAlready:
            pass
        else:
            if force:
                self.process.loseConnection()

class SlaveProcessProtocol(ProcessProtocol):
    """Process protocol attached to one spawned slave.

    Whatever the slave writes to stdout/stderr is forwarded to the twisted
    log, and the owning master is told when the process goes away.
    """

    def __init__(self, master, slave_cookie):
        # ProcessProtocol defines no __init__ of its own, so there is
        # nothing to chain up to here
        self.slave_cookie = slave_cookie
        self.master = master

    # logging helpers -- kept as methods so subclasses can redirect them

    def debug(self, *args, **kw):
        log.msg(*args, **kw)

    # informational messages go through the same channel as debug ones
    info = debug

    def warning(self, *args, **kw):
        log.err(*args, **kw)

    # process callbacks

    def outReceived(self, data):
        origin = '(%s stdout)' % self
        self.info(data + origin)

    def errReceived(self, data):
        origin = '(%s stderr)' % self
        self.warning(data + origin)

    def processEnded(self, reason):
        message = 'slave %s process dead: %s' % (self, reason)
        self.info(message)
        # let the master drop its bookkeeping for this slave
        self.master._slaveDead(self.slave_cookie, process_dead=True)

    def __str__(self):
        return 'Slave-%s' % self.slave_cookie

class Master(object):
    """Listen for slave connections and manage the slave processes.

    The lifecycle exposed here is: start() listens on the configured
    address, startSlaves() spawns slave processes that connect back to it,
    stopSlaves() asks them to go away and stop() kills whatever is left and
    closes the listening port. Subclasses can override slaveStarted() and
    slaveDead() to be notified when a slave connects or disappears.
    """
    serverFactory = MasterFactory
    # make this customizable for logging and testing
    slaveProcessProtocolFactory = SlaveProcessProtocol
    socket_prefix = 'elisa-master-'

    # FIXME: ping_period must always be > ping_timeout
    ping_period = 3
    ping_timeout = 2

    def __init__(self, address=None, slave_runner=None):
        """
        address: where to listen, either C{'tcp:[interface[:port]]'} or
            C{'unix:[path]'}. Defaults to C{'tcp:'} on Windows and
            C{'unix:'} elsewhere.
        slave_runner: string passed as-is on every slave command line;
            defaults to C{'elisa.plugins.amp.slave.runner'}.
        """
        super(Master, self).__init__()
        if address is None:
            if platform.system() == 'Windows':
                address = 'tcp:'
            else:
                address = 'unix:'
        self._address = address
        # the address actually listened on, as a tuple; set by start()
        self._real_address = None
        # the listening port; set by start(), released by stop()
        self._port = None

        if slave_runner is None:
            slave_runner = 'elisa.plugins.amp.slave.runner'
        self._slave_runner = slave_runner

        # the script every slave process is started with
        self._slave_script = \
            pkg_resources.resource_filename('elisa.plugins.amp', 'slave.py')

        self._start_defer = None
        self._stop_defer = None
        self._started = False
        # number of connected slaves startSlaves() is waiting for
        self._slaves_num = 0
        # cookie -> Slave, for every slave not yet forgotten by _slaveDead
        self._slaves = {}
        # slaves whose AMP connection was reported through _slaveStarted
        self._connected_slaves = []
        self._cookie = newCookie()
        # spawned processes whose end has not been reported yet
        self._spawned = 0

    def _listenTCP(self):
        """Listen on TCP as described by C{self._address}.

        The address has the form C{'tcp[:interface[:port]]'}; a missing or
        empty port means 0, i.e. let the system choose one.

        Returns C{(real_address, port)} where C{real_address} is
        C{('tcp', interface, port_number_as_string)}.
        """
        # maxsplit=2 gives at most ['tcp', interface, port]; with the former
        # maxsplit of 3 a fourth token broke the tuple unpacking
        tokens = self._address.split(':', 2)
        # NOTE(review): per Twisted's listenTCP docs an empty interface binds
        # every IPv4 interface; consider whether loopback would be safer.
        address = ''
        port_number = 0
        if len(tokens) > 1:
            address = tokens[1]
        if len(tokens) > 2:
            port_number = tokens[2] or 0
        port_number = int(port_number)
        port = reactor.listenTCP(port=port_number,
                factory=self.serverFactory(self), interface=address)

        return ('tcp', address, str(port.getHost().port)), port

    def _listenUNIX(self):
        """Listen on a UNIX socket as described by C{self._address}.

        The address has the form C{'unix:[path]'}; without a path a fresh
        name starting with C{socket_prefix} is picked in the temp directory.

        Returns C{(real_address, port)} where C{real_address} is
        C{('unix', path)}.
        """
        # maxsplit=1 keeps socket paths that themselves contain ':' intact;
        # they used to be silently replaced by a temporary path
        tokens = self._address.split(':', 1)
        address = None
        if len(tokens) == 2:
            address = tokens[1]

        if not address:
            # NOTE(review): mktemp only picks a name without creating
            # anything (the tempfile docs flag it as race-prone); if the
            # path gets taken first, listening is expected to fail rather
            # than reuse it -- confirm.
            address = mktemp(prefix=self.socket_prefix, suffix='.socket')

        port = reactor.listenUNIX(address=address,
                factory=self.serverFactory(self))

        return ('unix', address), port

    def start(self):
        """Start listening for slave connections.

        Returns a Deferred that fails with CannotListenError when the
        address cannot be listened on and has otherwise already fired with
        None.
        """
        try:
            if self._address.startswith('tcp'):
                address, port = self._listenTCP()
            else:
                address, port = self._listenUNIX()
        except CannotListenError:
            # defer.fail() without an argument wraps the exception being
            # handled (traceback included); the former defer.failed does not
            # exist and turned the error into an AttributeError
            return defer.fail()

        self._real_address = address
        self._port = port
        return defer.succeed(None)

    def stop(self):
        """Kill every known slave and stop listening.

        Returns a Deferred that fires once the listening port is closed, or
        immediately when nothing is listening.
        """
        for slave in list(self._slaves.values()):
            # don't leave spawn timeouts pending once we are shutting down
            self._cancelSpawnTimeout(slave)
            slave.kill(force=True)

        if self._port is None:
            return defer.succeed(None)

        port, self._port = self._port, None
        return defer.maybeDeferred(port.stopListening)

    def _startSpawnTimeout(self, slave, spawn_timeout):
        """Kill C{slave} unless it connects within C{spawn_timeout} seconds."""
        slave.timeout_call = reactor.callLater(spawn_timeout,
                self._spawnTimeout, slave)

    def _spawnTimeout(self, slave):
        """Called when C{slave} did not connect back in time."""
        slave.kill()

    def _cancelSpawnTimeout(self, slave):
        """Cancel C{slave}'s spawn timeout if it is still pending."""
        # getattr: timeout_call only exists once _startSpawnTimeout ran, and
        # cancelling a call that already fired would raise
        call = getattr(slave, 'timeout_call', None)
        if call is not None and call.active():
            call.cancel()

    def _slaveExecutable(self):
        """Return the executable the slave processes are started with."""
        # FIXME: detect when running under py2exe
        # the command line used to run the slaves
        if hasattr(sys, 'frozen'):
            return os.path.join(os.path.dirname(sys.executable),
                    'deps', 'bin', 'moovida_fork.exe')
        return sys.executable

    def _slaveEnvironment(self):
        """Return the environment dict to spawn a slave with.

        It is built from os.environ, with PYTHONPATH set to the directory
        containing the elisa package followed by the current sys.path.
        """
        if platform.system() == 'Windows':
            # FIXME: reuse only the environment variables that can actually
            # be read back (eg, the ones not containing lower-case
            # accentuated characters). This is done to workaround
            # modifications done to os.environ by one of elisa's runtime
            # dependencies. This hack could not be put in the elisa
            # launcher because the dark-magic modifications of
            # os.environ happen after the launcher has started
            # elisa. Needs further investigation...
            env = {}
            for key in os.environ.keys():
                try:
                    value = os.environ[key]
                except KeyError:
                    continue
                env[key] = value

            # override os.environ because it's accessed from
            # twisted.internet._dumbwin32proc
            os.environ = os._Environ(env)
        else:
            env = dict(os.environ)

        path = os.path.dirname(sys.modules['elisa'].__path__[0])
        env['PYTHONPATH'] = os.pathsep.join([path] + sys.path)
        return env

    def _startSlave(self, spawn_timeout):
        """Spawn one slave process and arm its spawn timeout.

        The slave is registered in C{self._slaves} under a fresh cookie; if
        spawning raises, it is unregistered again and the error propagates.
        """
        cookie = newCookie()
        slave = Slave(cookie)
        executable = self._slaveExecutable()

        # the slave gets the real listening address and its own cookie
        args = [executable, self._slave_script, self._slave_runner,
                ":".join(self._real_address), 'Slave-%s' % cookie]

        if platform.system() == 'Windows':
            # presumably unbuffered output, so the stdout/stderr pipes are
            # relayed promptly -- confirm
            args.insert(1, '-u')

        self._slaves[slave.cookie] = slave

        slave.process_protocol = self.slaveProcessProtocolFactory(self,
                slave.cookie)
        env = self._slaveEnvironment()

        try:
            # *boom*
            slave.process = reactor.spawnProcess(slave.process_protocol,
                    executable, args, env)
        except Exception:
            # don't keep a slave without a process around: stop() and
            # stopSlaves() would trip over it later
            del self._slaves[slave.cookie]
            raise

        self._spawned += 1
        self._startSpawnTimeout(slave, spawn_timeout)

    def startSlaves(self, num, spawn_timeout):
        """Spawn C{num} slaves, each of which is killed unless it connects
        back within C{spawn_timeout} seconds.

        Returns a Deferred that fires with this master once C{num} slaves
        are connected, or fails with StartError if every spawned process
        died first. When C{num} is 0 (or less) it has already fired.
        """
        assert self._start_defer is None
        self._slaves_num = num

        if num <= 0:
            # no slave will ever connect, so _slaveStarted would never fire
            # the Deferred
            return defer.succeed(self)

        self._start_defer = defer.Deferred()

        for i in range(num):
            self._startSlave(spawn_timeout)

        return self._start_defer

    def _stopSlave(self, slave):
        """Ask one slave to go away.

        A connected slave gets its AMP connection closed; a slave that has
        not connected yet has no connection to close, so it is killed.
        """
        if slave.amp is not None:
            slave.amp.transport.loseConnection()
        else:
            slave.kill()
        self._cancelSpawnTimeout(slave)

    def stopSlaves(self):
        """Stop every slave.

        Returns a Deferred that fires with this master once every spawned
        slave process has ended (immediately if none is running).
        """
        assert self._stop_defer is None
        self._slaves_num = 0

        # iterate over a copy in case stopping a slave synchronously reaches
        # _slaveDead, which pops from self._slaves
        for slave in list(self._slaves.values()):
            self._stopSlave(slave)

        if self._spawned == 0:
            return defer.succeed(self)

        self._stop_defer = defer.Deferred()
        return self._stop_defer

    def _slaveStarted(self, cookie, protocol):
        """Record that the slave identified by C{cookie} connected through
        the AMP C{protocol}, firing the startSlaves() Deferred once all the
        expected slaves are connected."""
        slave = self._slaves[cookie]
        slave.amp = protocol
        # the timeout may already have fired if the connection raced with it
        self._cancelSpawnTimeout(slave)
        self._connected_slaves.append(slave)

        self.slaveStarted(slave)

        if self._start_defer is not None and \
                len(self._connected_slaves) == self._slaves_num:
            dfr, self._start_defer = self._start_defer, None
            dfr.callback(self)

    def _slaveDead(self, cookie, process_dead=False):
        """Forget the slave identified by C{cookie}.

        C{process_dead} is True when the slave process itself ended (this is
        how SlaveProcessProtocol.processEnded calls it); otherwise the
        process is killed here. Once no spawned process is left, a pending
        stopSlaves() Deferred fires, or else a pending startSlaves() Deferred
        fails with StartError.
        """
        if process_dead:
            self._spawned -= 1

        slave = self._slaves.pop(cookie, None)
        # slave is None when MasterProtocol.connectionLost called
        # _slaveDead before
        if slave is not None:
            if not process_dead:
                slave.kill(force=True)
            # If the slave was connected, MasterProtocol.connectionLost was
            # expected to clean it up before the process ended, but that
            # turns out not to hold on windows, so handle both orders.
            if slave.amp is not None:
                self._connected_slaves.remove(slave)
                self.slaveDead(slave)

            self._cancelSpawnTimeout(slave)

        if not self._spawned:
            if self._stop_defer:
                dfr, self._stop_defer = self._stop_defer, None
                dfr.callback(self)
            elif self._start_defer:
                dfr, self._start_defer = self._start_defer, None
                dfr.errback(StartError())

    def slaveStarted(self, slave):
        """Hook called once C{slave} has connected; does nothing here."""
        pass

    def slaveDead(self, slave):
        """Hook called when a connected C{slave} goes away; does nothing
        here."""
        pass
