#
# Copyright 2009 Canonical Ltd.
#
# Written by:
#     Gustavo Niemeyer <gustavo.niemeyer@canonical.com>
#     Sidnei da Silva <sidnei.da.silva@canonical.com>
#
# This file is part of the Image Store Proxy.
#
# This program is free software: you can redistribute it and/or modify it 
# under the terms of the GNU General Public License version 3, as published 
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but 
# WITHOUT ANY WARRANTY; without even the implied warranties of 
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR 
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along 
# with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import thread

from twisted.internet.defer import Deferred
from twisted.internet import reactor

from imagestore.lib.twistedutil import mergeDeferreds

from imagestore.lib.service import (
    Service, ThreadedService, ServiceTask, UnknownTaskError, NotStartedError,
    taskHandler, taskHandlerInThread, ServiceHub, ServiceError)
from imagestore.lib.tests import TestCase


class ServiceTaskTest(TestCase):
    """Tests for ServiceTask: its deferred and its execute helpers."""

    def setUp(self):
        self.task = ServiceTask()

    def testGetDeferred(self):
        # assertTrue(type(x), Deferred) would treat Deferred as the failure
        # message and pass for any truthy type, so check the type for real.
        self.assertTrue(isinstance(self.task.getDeferred(), Deferred))

    def testInitWithoutSuper(self):
        """A subclass that skips ServiceTask.__init__ still gets a deferred."""
        class MyTask(ServiceTask):
            def __init__(self):
                # No super call.  This makes the creation of tasks
                # less error prone.
                pass
        task = MyTask()
        self.assertTrue(isinstance(task.getDeferred(), Deferred))

    def testExecuteWithValueResult(self):
        """execute() fires the task's deferred with the function's result.

        The function receives the task first, followed by the positional
        and keyword arguments that were passed to execute().
        """
        deferred = self.task.getDeferred()
        self.task.execute(lambda task, a, b: (task, a + b), "success", b="!")
        def callback(result):
            task, value = result
            self.assertEquals(task, self.task)
            self.assertEquals(value, "success!")
        deferred.addCallback(callback)
        return deferred

    def testExecuteWithFailureResult(self):
        """An exception raised by the function errbacks the task's deferred."""
        def throw(task, error): raise error
        class MyError(Exception): pass
        deferred = self.task.getDeferred()
        self.task.execute(throw, MyError())
        def errback(failure):
            self.assertEquals(type(failure.value), MyError)
            return failure
        deferred.addErrback(errback)
        return self.assertFailure(deferred, MyError)

    def testExecuteWithDeferredResult(self):
        """A Deferred returned by the function is chained to the task's."""
        another_deferred = Deferred()
        reactor.callLater(0, another_deferred.callback, "success!")
        deferred = self.task.getDeferred()
        self.task.execute(lambda task: another_deferred)
        def callback(result):
            self.assertEquals(result, "success!")
        deferred.addCallback(callback)
        return deferred

    def testFromThreadExecuteWithValueResult(self):
        """fromThreadExecute() takes the reactor, then behaves like execute().

        NOTE(review): presumably meant to be called from a worker thread;
        here it is invoked from the test's own thread.
        """
        deferred = self.task.getDeferred()
        self.task.fromThreadExecute(reactor,
                                    lambda task, a, b: (task, a + b),
                                    "success", b="!")
        def callback(result):
            task, value = result
            self.assertEquals(task, self.task)
            self.assertEquals(value, "success!")
        deferred.addCallback(callback)
        return deferred

    def testFromThreadExecuteWithFailureResult(self):
        """Exceptions propagate through fromThreadExecute() as failures."""
        def throw(task, error): raise error
        class MyError(Exception): pass
        deferred = self.task.getDeferred()
        self.task.fromThreadExecute(reactor, throw, MyError())
        def errback(failure):
            self.assertEquals(type(failure.value), MyError)
            return failure
        deferred.addErrback(errback)
        return self.assertFailure(deferred, MyError)

    def testFromThreadExecuteWithDeferredResult(self):
        """Deferred results are chained when using fromThreadExecute()."""
        another_deferred = Deferred()
        reactor.callLater(0, another_deferred.callback, "success!")
        deferred = self.task.getDeferred()
        self.task.fromThreadExecute(reactor, lambda task: another_deferred)
        def callback(result):
            self.assertEquals(result, "success!")
        deferred.addCallback(callback)
        return deferred


class ServiceErrorTest(TestCase):
    """Sanity checks for the ServiceError exception type."""

    def testIsException(self):
        # ServiceError has to be catchable as an ordinary Exception.
        base = Exception
        self.assertTrue(issubclass(ServiceError, base))

    def testMessage(self):
        # The constructor argument is exposed through ``message``.
        text = "Some message"
        self.assertEquals(ServiceError(text).message, text)


class ServiceTest(TestCase):
    """Tests for Service, reused by subclasses for other implementations.

    Subclasses override ServiceClass and createService() so that the same
    tests run against a different service class.
    """

    class MyTask(ServiceTask): pass

    ServiceClass = Service

    def createService(self, serviceClass):
        """Instantiate serviceClass; overridden when the ctor needs args."""
        return serviceClass()

    def setUp(self):
        self.service = self.createService(self.ServiceClass)
        self.task = self.MyTask()

    def testInitializeHandlersOnCreation(self):
        """initializeHandlers() is called exactly once by the constructor."""
        called = []
        class MyService(self.ServiceClass):
            def initializeHandlers(self):
                called.append(True)
        self.createService(MyService)
        self.assertEquals(called, [True])

    def testAddUnhandledTask(self):
        """Adding a task with no registered handler raises UnknownTaskError."""
        self.assertRaises(UnknownTaskError,
                          self.service.addTask, self.task)

    def testAddKnownTask(self):
        """addTask() returns the task's own deferred."""
        # Handlers are called with the task as their first argument.
        self.service.addHandler(self.MyTask, lambda task: None)
        deferred = self.service.addTask(self.task)
        self.assertEquals(deferred, self.task.getDeferred())

    def testStartAndStopTwice(self):
        """Stopping an already stopped service raises NotStartedError."""
        self.service.start()
        deferred = self.service.stop()
        self.assertRaises(NotStartedError, self.service.stop)
        # Wait for the first stop to finish before the test ends.
        return deferred

    def testStopReturnsDeferred(self):
        self.service.start()
        deferred = self.service.stop()
        self.assertTrue(isinstance(deferred, Deferred))
        return deferred

    def testRunTask(self):
        """A queued task runs its handler with the registered extra args."""
        called = []
        def handler(task, a, b):
            called.append(True)
            self.assertEquals(task, self.task)
            self.assertEquals(a, 1)
            self.assertEquals(b, 2)
        self.service.addHandler(self.MyTask, handler, 1, b=2)
        deferred1 = self.service.addTask(self.task)
        self.service.start()
        deferred2 = self.service.stop()
        # Gather results so that we see any errors.
        deferred = mergeDeferreds([deferred1, deferred2])
        def callback(result):
            # Without this check the test would pass even if the handler
            # were never invoked.
            self.assertEquals(called, [True])
        deferred.addCallback(callback)
        return deferred

    def testRunThreeTasks(self):
        """Every queued task is handled before stop() completes."""
        counter = [0]
        def handler(task):
            counter[0] += 1
        self.service.addHandler(self.MyTask, handler)
        deferred1 = self.service.addTask(self.MyTask())
        deferred2 = self.service.addTask(self.MyTask())
        deferred3 = self.service.addTask(self.MyTask())
        self.service.start()
        deferred4 = self.service.stop()
        # Gather results so that we see any errors.
        deferred = mergeDeferreds([deferred1, deferred2, deferred3, deferred4])
        def callback(result):
            self.assertEquals(counter, [3])
        deferred.addCallback(callback)
        return deferred

    def testHandlerDecorator(self):
        """@taskHandler registers methods, dispatched in submission order."""
        class MyOtherTask(ServiceTask): pass

        called = []
        class MyService(self.ServiceClass):
            @taskHandler(self.MyTask)
            def handler1(self, task):
                called.append(("handler1", self, task))
            @taskHandler(MyOtherTask)
            def handler2(self, task):
                called.append(("handler2", self, task))

        task1 = self.MyTask()
        task2 = MyOtherTask()
        service = self.createService(MyService)
        service.addTask(task1)
        service.addTask(task2)
        service.start()

        def callback(result):
            self.assertEquals(called,
                              [("handler1", service, task1),
                               ("handler2", service, task2)])

        return service.stop().addCallback(callback)


class ThreadedServiceTest(ServiceTest):
    """Run the ServiceTest suite against ThreadedService, plus thread tests.

    Every test inherited from ServiceTest is re-run here; the overrides
    below make those tests build ThreadedService instances instead.
    """

    ServiceClass = ThreadedService

    def createService(self, serviceClass):
        """Instantiate serviceClass, passing it the global reactor."""
        return serviceClass(reactor)

    def testInitializeThreadOnCreation(self):
        """initializeThread() runs once, off the test thread, after start().

        It must not have run before start() is called.
        """
        threadIdent = []
        class MyService(self.ServiceClass):
            def initializeThread(self):
                threadIdent.append(thread.get_ident())
        service = self.createService(MyService)
        service.addHandler(self.MyTask, lambda task: None)
        deferred = service.addTask(self.task)
        self.assertEquals(threadIdent, [])
        service.start()
        while not deferred.called:
            # Spin a bit waiting for the deferred to be called.  When it's
            # called, it means that the task has been run, so thread
            # initialization must have taken place.
            # NOTE(review): busy loop with no sleep or timeout; it never
            # exits if the task's deferred never fires.
            reactor.runUntilCurrent()
        try:
            self.assertEquals(len(threadIdent), 1)
            self.assertNotEquals(threadIdent, [thread.get_ident()])
        finally:
            # Stop the service even if one of the assertions above fails.
            stopDeferred = service.stop()
        return stopDeferred

    def testDeinitializeThreadBeforeStopping(self):
        """destructThread() runs once, off the test thread, during stop().

        It must not have run while the service is still running tasks.
        """
        threadIdent = []
        class MyService(self.ServiceClass):
            def destructThread(self):
                threadIdent.append(thread.get_ident())
        service = self.createService(MyService)
        service.addHandler(self.MyTask, lambda task: None)
        deferred = service.addTask(self.task)
        self.assertEquals(threadIdent, [])
        service.start()
        while not deferred.called:
            # Spin a bit waiting for the deferred to be called.  When it's
            # called, it means that the task has been run, but destructThread
            # has not.
            # NOTE(review): same unbounded busy loop as above.
            reactor.runUntilCurrent()
        try:
            self.assertEquals(len(threadIdent), 0)
        finally:
            # Stop the service even if the assertion above fails.
            stopDeferred = service.stop()
        def callback(result):
            # By the time stop()'s deferred fires, destructThread has run
            # exactly once, and not on the test's thread.
            self.assertEquals(len(threadIdent), 1)
            self.assertNotEquals(threadIdent, [thread.get_ident()])
        return stopDeferred.addCallback(callback)

    def testThreadedHandlerDecorator(self):
        """@taskHandlerInThread registers methods, dispatched in order.

        Each handler receives the service instance as self and the task.
        """
        class MyOtherTask(ServiceTask): pass

        called = []
        class MyService(self.ServiceClass):
            @taskHandlerInThread(self.MyTask)
            def handler1(self, task):
                called.append(("handler1", self, task))
            @taskHandlerInThread(MyOtherTask)
            def handler2(self, task):
                called.append(("handler2", self, task))

        task1 = self.MyTask()
        task2 = MyOtherTask()
        service = self.createService(MyService)
        service.addTask(task1)
        service.addTask(task2)
        service.start()

        def callback(result):
            self.assertEquals(called,
                              [("handler1", service, task1),
                               ("handler2", service, task2)])

        return service.stop().addCallback(callback)


class ServiceHubTest(TestCase):
    """Tests for ServiceHub: registration, dispatch, lifecycle and lookup."""

    class MyTask1(ServiceTask): pass
    class MyTask2(ServiceTask): pass
    class MyTask3(ServiceTask): pass
    class MyService1(Service): handled = []
    class MyService2(Service): handled = []
    class MyService3(Service): handled = []

    def setUp(self):
        self.hub = ServiceHub()
        self.service1 = self.MyService1()
        self.service2 = self.MyService2()
        # The class-level ``handled`` lists are shared by every instance and
        # persist across test runs in the same process.  Shadow them with
        # fresh per-instance lists so recorded tasks cannot leak between runs.
        self.service1.handled = []
        self.service2.handled = []
        self.task1 = self.MyTask1()
        self.task2 = self.MyTask2()
        self.task3 = self.MyTask3()

        self.service1.addHandler(self.MyTask1,
                                 lambda task: self.service1.handled.append(task))
        self.service2.addHandler(self.MyTask2,
                                 lambda task: self.service2.handled.append(task))

    def startServices(self):
        self.service1.start()
        self.service2.start()

    def stopServices(self):
        """Stop both services; return a deferred that merges their results."""
        return mergeDeferreds([self.service1.stop(),
                              self.service2.stop()])

    def testAddServicesAndDispatch(self):
        """The hub routes each task to the service that handles its type."""
        self.hub.addService(self.service1)
        self.hub.addService(self.service2)

        self.startServices()

        deferred1 = self.hub.addTask(self.task1)
        deferred2 = self.hub.addTask(self.task2)

        while not (deferred1.called and deferred2.called):
            reactor.runUntilCurrent()

        try:
            self.assertEquals(self.service1.handled, [self.task1])
            self.assertEquals(self.service2.handled, [self.task2])
        finally:
            # Stop the services even if one of the assertions fails.
            stopDeferred = self.stopServices()
        return stopDeferred

    def testRaiseErrorOnUnknownTask(self):
        """A task no registered service handles raises UnknownTaskError."""
        self.hub.addService(self.service1)
        self.hub.addService(self.service2)

        self.assertRaises(UnknownTaskError, self.hub.addTask, self.task3)

    def testStartAndStopServices(self):
        """start()/stop() fan out in registration order; stop() collects."""
        called = []
        class FakeService(object):
            def __init__(self, ident):
                self.id = ident
            def start(self):
                called.append("start-%d" % self.id)
            def stop(self):
                called.append("stop-%d" % self.id)
                return "result-%d" % self.id
        service1 = FakeService(1)
        service2 = FakeService(2)

        self.hub.addService(service1)
        self.hub.addService(service2)

        self.hub.start()
        self.assertEquals(called, ["start-1", "start-2"])
        result = self.hub.stop()
        self.assertEquals(called, ["start-1", "start-2", "stop-1", "stop-2"])
        self.assertEquals(result, ["result-1", "result-2"])

    def testGetService(self):
        """getService() finds a service by class, or raises ServiceError."""
        service1 = self.MyService1()
        service2 = self.MyService2()
        self.hub.addService(service1)
        self.hub.addService(service2)
        self.assertEquals(self.hub.getService(self.MyService1), service1)
        self.assertEquals(self.hub.getService(self.MyService2), service2)
        self.assertRaises(ServiceError, self.hub.getService, self.MyService3)
