# bzr-dbus: dbus support for bzr/bzrlib.
# Copyright (C) 2007 Canonical Limited.
#   Author: Robert Collins.
# 
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; version 2 of the License.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
# 

"""Tests for the dbus hooks."""

from bzrlib.branch import Branch
from bzrlib.smart.server import SmartTCPServer

from bzrlib.plugins.dbus import activity, hook
from bzrlib.plugins.dbus.tests.test_activity import TestCaseWithDBus


class TestHooksAreSet(TestCaseWithDBus):
    """Tests for installation and behaviour of the bzr-dbus hooks."""

    def _call_with_activity(self, activity_class, func, *args):
        """Call func(*args) with activity.Activity temporarily replaced.

        The hooks look up activity.Activity at call time, so replacing the
        module attribute lets a test observe what the hook does with it.
        The original class is always restored, even if func raises.

        :param activity_class: Class installed as activity.Activity for the
            duration of the call.
        :param func: The callable to invoke (a hook function).
        :param args: Positional arguments passed through to func.
        :return: Whatever func returns.
        """
        original_class = activity.Activity
        activity.Activity = activity_class
        try:
            return func(*args)
        finally:
            activity.Activity = original_class

    def test_set_rh_installed(self):
        """Loading the plugin should have installed its hooks."""
        # check by looking in self._preserved_hooks.
        # the set_rh Branch hook to detect branch changes.
        self.assertTrue(hook.on_set_rh in
            self._preserved_hooks[Branch]['set_rh'])
        # the server_started and server_stopped smart server hooks
        # to detect url maps for servers.
        self.assertTrue(hook.on_server_start in
            self._preserved_hooks[SmartTCPServer]['server_started'])
        self.assertTrue(hook.on_server_stop in
            self._preserved_hooks[SmartTCPServer]['server_stopped'])

    def test_install_hooks(self):
        """dbus.hook.install_hooks() should install hooks."""
        hook.install_hooks()
        # check the branch hooks.
        self.assertTrue(hook.on_set_rh in Branch.hooks['set_rh'])
        # check the SmartServer hooks.
        self.assertTrue(hook.on_server_start in
            SmartTCPServer.hooks['server_started'])
        self.assertTrue(hook.on_server_stop in
            SmartTCPServer.hooks['server_stopped'])

    def test_on_set_rh_hook(self):
        """The on_set_rh hook should hand off the branch to advertise it."""
        # replace the global b.p.dbus.activity.Activity to instrument
        # on_set_rh.
        calls = []
        class SampleActivity(object):

            def advertise_branch(self, branch):
                calls.append(('advertise_branch', branch))

        # prevent api skew: check we can use the API SampleActivity presents.
        activity.Activity(bus=self.bus).advertise_branch(self.make_branch('.'))
        # now test the hook
        self._call_with_activity(SampleActivity, hook.on_set_rh,
            'branch', 'history')
        self.assertEqual([('advertise_branch', 'branch')], calls)

    def test_on_server_start_hook(self):
        """The on_server_start hook should add a URL mapping for the server."""
        # replace the global b.p.dbus.activity.Activity to instrument
        # on_server_start.
        calls = []
        class SampleActivity(object):

            def add_url_map(self, source_prefix, target_prefix):
                calls.append(('add_url_map', source_prefix, target_prefix))

        # prevent api skew: check we can use the API SampleActivity presents.
        activity.Activity(bus=self.bus).add_url_map('foo/', 'bar/')
        # now test the hook
        self._call_with_activity(SampleActivity, hook.on_server_start,
            ['source'], 'target')
        self.assertEqual([('add_url_map', 'source', 'target')], calls)

    def test_on_server_stop_hook(self):
        """The on_server_stop hook should remove the server's URL mapping."""
        # replace the global b.p.dbus.activity.Activity to instrument
        # on_server_stop.
        calls = []
        class SampleActivity(object):

            def remove_url_map(self, source_prefix, target_prefix):
                calls.append(('remove_url_map', source_prefix, target_prefix))

        # prevent api skew: check we can use the API SampleActivity presents.
        activity.Activity(bus=self.bus).remove_url_map('foo/', 'bar/')
        # now test the hook
        self._call_with_activity(SampleActivity, hook.on_server_stop,
            ['source'], 'target')
        self.assertEqual([('remove_url_map', 'source', 'target')], calls)
