# -*- coding: utf-8 -*-

# Author: Natalia B. Bidart <natalia.bidart@canonical.com>
# Author: Alejandro J. Cura <alecu@canonical.com>
#
# Copyright 2011 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Tests for the oauth_headers helper function and the SyncTimestampChecker."""

import time

from twisted.application import internet, service
from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.trial.unittest import TestCase
from twisted.web import server, resource

from ubuntu_sso import utils
from ubuntu_sso.utils import oauth, oauth_headers, SyncTimestampChecker
from ubuntu_sso.tests import TOKEN


class FakedOAuthRequest(object):
    """Stand-in for oauth.OAuthRequest that records its factory kwargs."""

    # Class-level record shared by every instance: each call to
    # from_consumer_and_token merges its keyword arguments in here so a
    # test can inspect them afterwards.
    params = {}

    def __init__(self):
        """Install do-nothing signing and header hooks on the instance."""

        def _ignore_signing(*args, **kwargs):
            """Accept any arguments and do nothing."""
            return None

        def _empty_header(*args, **kwargs):
            """Accept any arguments and return a fresh, empty header dict."""
            return {}

        self.sign_request = _ignore_signing
        self.to_header = _empty_header

    @staticmethod
    def from_consumer_and_token(oauth_consumer, **kwargs):
        """Merge 'kwargs' into the shared params and return a new fake."""
        FakedOAuthRequest.params.update(kwargs)
        return FakedOAuthRequest()


class SignWithCredentialsTestCase(TestCase):
    """Test suite for the oauth_headers method."""

    url = u'http://example.com'
    fake_timestamp_value = 1

    @defer.inlineCallbacks
    def setUp(self):
        """Initialize this test suite.

        utils.timestamp_checker.get_faithful_time is replaced by a fake that
        returns 'fake_timestamp_value' and records that it was called, so
        these tests never contact the time server.
        """
        yield super(SignWithCredentialsTestCase, self).setUp()
        self.timestamp_called = False

        def fake_timestamp():
            """A fake timestamp that records the call."""
            self.timestamp_called = True
            return self.fake_timestamp_value

        self.patch(utils.timestamp_checker, "get_faithful_time",
                   fake_timestamp)

    def build_header(self, url, http_method='GET'):
        """Build a reference OAuth header for 'url', signed with TOKEN.

        Uses HMAC-SHA1 directly through the oauth module, so the result can
        be compared against what oauth_headers produces.
        """
        consumer = oauth.OAuthConsumer(TOKEN['consumer_key'],
                                       TOKEN['consumer_secret'])
        token = oauth.OAuthToken(TOKEN['token'],
                                 TOKEN['token_secret'])
        get_request = oauth.OAuthRequest.from_consumer_and_token
        oauth_req = get_request(oauth_consumer=consumer, token=token,
                                http_method=http_method, http_url=url)
        oauth_req.sign_request(oauth.OAuthSignatureMethod_HMAC_SHA1(),
                               consumer, token)
        return oauth_req.to_header()

    def dictify_header(self, header):
        """Convert an OAuth 'Authorization' header value into a dict.

        Fields are separated by ', ' and each one looks like 'key="value"';
        the surrounding double quotes are stripped from the value. Only the
        first '=' splits a field, so a value containing '=' is kept whole
        instead of raising ValueError on unpacking.
        """
        result = {}
        for field in header.split(', '):
            key, value = field.split('=', 1)
            result[key] = value.strip('"')
        return result

    def assert_header_equal(self, expected, actual):
        """Assert that two OAuth headers are equal up to per-request fields.

        The nonce, timestamp and signature differ on every signing, so they
        are removed from both sides before comparing.
        """
        expected = self.dictify_header(expected['Authorization'])
        actual = self.dictify_header(actual['Authorization'])
        for header in (expected, actual):
            header.pop('oauth_nonce')
            header.pop('oauth_timestamp')
            header.pop('oauth_signature')

        self.assertEqual(expected, actual)

    def assert_method_called(self, path, query_str='', http_method='GET'):
        """Assert oauth_headers signs 'self.url + path + query_str' correctly.

        The header returned by oauth_headers is compared with a reference
        header built by build_header for the same url (utf-8 encoded).
        Both sides use the full url, including 'query_str'.
        """
        url = self.url + path + query_str
        expected = self.build_header(url.encode('utf8'),
                                     http_method=http_method)
        actual = oauth_headers(url=url, credentials=TOKEN)
        self.assert_header_equal(expected, actual)

    def test_call(self):
        """oauth_headers produces the expected OAuth header for a url."""
        path = u'/test/'
        self.assert_method_called(path)

    def test_quotes_path(self):
        """oauth_headers handles paths with spaces, commas and non-ASCII."""
        path = u'/test me more, sí!/'
        self.assert_method_called(path)

    def test_adds_parameters_to_oauth_request(self):
        """The query string from the path is used in the oauth request."""
        # FakedOAuthRequest.params is a class-level dict shared by every use
        # of the fake; start from an empty one (restored after the test) so
        # data recorded elsewhere cannot leak into this assertion.
        self.patch(FakedOAuthRequest, 'params', {})
        self.patch(oauth, 'OAuthRequest', FakedOAuthRequest)

        path = u'/test/something?foo=bar'
        oauth_headers(url=self.url + path, credentials=TOKEN)

        self.assertIn('parameters', FakedOAuthRequest.params)
        params = FakedOAuthRequest.params['parameters']
        # oauth_headers also adds the timestamp; drop it before comparing.
        del params["oauth_timestamp"]
        self.assertEqual(params, {'foo': 'bar'})

    def test_oauth_headers_uses_timestamp_checker(self):
        """The oauth_headers function uses the timestamp_checker."""
        oauth_headers(u"http://protocultura.net", TOKEN)
        self.assertTrue(self.timestamp_called,
                        "the timestamp MUST be requested.")


class RootResource(resource.Resource):
    """Leaf resource that records every HEAD request it renders."""

    isLeaf = True

    def __init__(self, *args, **kwargs):
        """Set up an empty header log and a zeroed request counter."""
        resource.Resource.__init__(self, *args, **kwargs)
        # requestHeaders of each rendered HEAD request, in arrival order.
        self.request_headers = []
        # Number of HEAD requests rendered so far.
        self.count = 0

    # pylint: disable=C0103
    def render_HEAD(self, request):
        """Log the request headers, bump the counter, reply with no body."""
        incoming_headers = request.requestHeaders
        self.request_headers.append(incoming_headers)
        self.count = self.count + 1
        return ""


class MockWebServer(object):
    """A mock webserver for testing.

    Serves a RootResource on an OS-chosen TCP port bound to the loopback
    interface only, so the test server is not reachable from other hosts.
    """

    def __init__(self):
        """Start serving on a free loopback port."""
        # pylint: disable=E1101
        self.root = RootResource()
        site = server.Site(self.root)
        application = service.Application('web')
        self.service_collection = service.IServiceCollection(application)
        # Port 0 lets the OS pick a free port; get_url reports the result.
        self.tcpserver = internet.TCPServer(0, site, interface='127.0.0.1')
        self.tcpserver.setServiceParent(self.service_collection)
        self.service_collection.startService()

    def get_url(self):
        """Build the url for this mock server (same address it is bound to)."""
        # pylint: disable=W0212
        port_num = self.tcpserver._port.getHost().port
        return "http://127.0.0.1:%d/" % port_num

    def stop(self):
        """Shut it down.

        Returns whatever stopService returns (possibly a Deferred that fires
        once the listening port is closed), so callers such as trial's
        addCleanup can wait for shutdown to finish.
        """
        # pylint: disable=E1101
        return self.service_collection.stopService()


class FakedError(Exception):
    """Exception raised deliberately by tests to simulate a failure."""


class TimestampCheckerTestCase(TestCase):
    """Tests for the timestamp checker.

    The checker's calls are made through deferToThread, so the reactor stays
    free to run the mock web server while the checker waits on it.
    """

    @defer.inlineCallbacks
    def setUp(self):
        """Initialize a fake webserver.

        SyncTimestampChecker.SERVER_URL is pointed at the mock server, so the
        HEAD requests made by the checker are recorded by its root resource.
        """
        yield super(TimestampCheckerTestCase, self).setUp()
        self.ws = MockWebServer()
        self.addCleanup(self.ws.stop)
        self.patch(SyncTimestampChecker, "SERVER_URL", self.ws.get_url())

    @defer.inlineCallbacks
    def test_returned_value_is_int(self):
        """The returned value is an integer."""
        checker = SyncTimestampChecker()
        timestamp = yield deferToThread(checker.get_faithful_time)
        self.assertEqual(type(timestamp), int)

    @defer.inlineCallbacks
    def test_first_call_does_head(self):
        """The first call gets the clock from our web."""
        checker = SyncTimestampChecker()
        yield deferToThread(checker.get_faithful_time)
        self.assertEqual(self.ws.root.count, 1)

    @defer.inlineCallbacks
    def test_second_call_is_cached(self):
        """For the second call, the time is cached."""
        checker = SyncTimestampChecker()
        yield deferToThread(checker.get_faithful_time)
        yield deferToThread(checker.get_faithful_time)
        self.assertEqual(self.ws.root.count, 1)

    @defer.inlineCallbacks
    def test_after_timeout_cache_expires(self):
        """After some time, the cache expires."""
        fake_timestamp = 1
        # The lambda reads fake_timestamp at call time, so bumping the local
        # below moves the patched clock forward.
        self.patch(time, "time", lambda: fake_timestamp)
        checker = SyncTimestampChecker()
        yield deferToThread(checker.get_faithful_time)
        fake_timestamp += SyncTimestampChecker.CHECKING_INTERVAL
        yield deferToThread(checker.get_faithful_time)
        self.assertEqual(self.ws.root.count, 2)

    @defer.inlineCallbacks
    def test_server_date_sends_nocache_headers(self):
        """Getting the server date sends the no-cache headers."""
        checker = SyncTimestampChecker()
        yield deferToThread(checker.get_server_time)
        # A test assertion (not a bare assert, which -O strips) so a wrong
        # request count is reported as a proper test failure.
        self.assertEqual(len(self.ws.root.request_headers), 1)
        headers = self.ws.root.request_headers[0]
        result = headers.getRawHeaders("Cache-Control")
        self.assertEqual(result, ["no-cache"])

    @defer.inlineCallbacks
    def test_server_error_means_skew_not_updated(self):
        """When server can't be reached, the skew is not updated."""
        fake_timestamp = 1
        self.patch(time, "time", lambda: fake_timestamp)
        checker = SyncTimestampChecker()

        def failing_get_server_time():
            """Let's fail while retrieving the server time."""
            raise FakedError()

        self.patch(checker, "get_server_time", failing_get_server_time)
        yield deferToThread(checker.get_faithful_time)
        self.assertEqual(checker.skew, 0)
        self.assertEqual(checker.next_check,
                         fake_timestamp + SyncTimestampChecker.ERROR_INTERVAL)
