# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# See LICENSE for details.


"""
Test cases for twisted.protocols package.
"""

from twisted.trial import unittest
from twisted.protocols import basic, wire, portforward
from twisted.internet import reactor, protocol, defer, task, error

import struct
import StringIO

class StringIOWithoutClosing(StringIO.StringIO):
    """
    A StringIO that can't be closed.
    """
    def close(self):
        """
        Do nothing.
        """

class LineTester(basic.LineReceiver):
    """
    A line receiver that records every line it receives and reacts to a few
    command lines.

    Recognised lines: C{''} (switch to raw mode), C{'pause'}, C{'rawpause'},
    C{'stop'}, C{'len N'} (set how many raw bytes to read next), and lines
    starting with C{'produce'} / C{'unproduce'} (register / unregister this
    protocol as its transport's producer).

    @type delimiter: C{str}
    @ivar delimiter: character used between received lines.
    @type MAX_LENGTH: C{int}
    @ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called.
    @type clock: L{twisted.internet.task.Clock}
    @ivar clock: clock simulating reactor callLater. Pass it to the
        constructor if you want to use the pause/rawpause commands.
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def __init__(self, clock=None):
        """
        @param clock: optional object providing C{callLater}; only the
            C{'pause'} and C{'rawpause'} commands need it.
        """
        self.clock = clock

    def connectionMade(self):
        """
        Start every connection with an empty record of received data.
        """
        self.received = []

    def lineReceived(self, line):
        """
        Record C{line}, then carry out the command it names, if any.
        """
        self.received.append(line)
        if not line:
            self.setRawMode()
        elif line in ('pause', 'rawpause'):
            self.pauseProducing()
            if line == 'rawpause':
                # Raw bytes that follow are accumulated into a fresh,
                # initially empty entry.
                self.setRawMode()
                self.received.append('')
            # Resume on the next clock tick, so the paused state can be
            # observed before the clock is advanced.
            self.clock.callLater(0, self.resumeProducing)
        elif line == 'stop':
            self.stopProducing()
        elif line.startswith('len '):
            self.length = int(line[4:])
        elif line.startswith('produce'):
            self.transport.registerProducer(self, False)
        elif line.startswith('unproduce'):
            self.transport.unregisterProducer()

    def rawDataReceived(self, data):
        """
        Append raw bytes to the latest entry until the byte count set by the
        last C{'len'} line is used up, then hand any surplus back to line
        mode.
        """
        wanted = self.length
        chunk, leftover = data[:wanted], data[wanted:]
        self.length = wanted - len(chunk)
        self.received[-1] += chunk
        if self.length == 0:
            self.setLineMode(leftover)

    def lineLengthExceeded(self, line):
        """
        Drop the first C{MAX_LENGTH + 1} characters of the over-long data and
        re-parse whatever follows them in line mode.
        """
        skip = self.MAX_LENGTH + 1
        if len(line) > skip:
            self.setLineMode(line[skip:])


class LineOnlyTester(basic.LineOnlyReceiver):
    """
    A L{basic.LineOnlyReceiver} that simply keeps every complete line.
    """
    MAX_LENGTH = 64
    delimiter = '\n'

    def connectionMade(self):
        """
        Reset the list of kept lines for the new connection.
        """
        self.received = []

    def lineReceived(self, line):
        """
        Keep C{line}, preserving arrival order.
        """
        self.received.append(line)

class WireTestCase(unittest.TestCase):
    """
    Test wire protocols.
    """

    def _connect(self, proto):
        """
        Connect C{proto} to an in-memory transport and return the buffer
        that collects everything the protocol writes.
        """
        output = StringIOWithoutClosing()
        proto.makeConnection(protocol.FileWrapper(output))
        return output

    def testEcho(self):
        """
        wire.Echo writes every received chunk straight back, in order.
        """
        echo = wire.Echo()
        output = self._connect(echo)
        for word in ("hello", "world", "how", "are", "you"):
            echo.dataReceived(word)
        self.failUnlessEqual(output.getvalue(), "helloworldhowareyou")

    def testWho(self):
        """
        wire.Who writes its fixed answer as soon as it is connected.
        """
        output = self._connect(wire.Who())
        self.failUnlessEqual(output.getvalue(), "root\r\n")

    def testQOTD(self):
        """
        wire.QOTD writes its quote as soon as it is connected.
        """
        output = self._connect(wire.QOTD())
        self.failUnlessEqual(output.getvalue(),
                             "An apple a day keeps the doctor away.\r\n")

    def testDiscard(self):
        """
        wire.Discard never writes anything back.
        """
        discard = wire.Discard()
        output = self._connect(discard)
        for word in ("hello", "world", "how", "are", "you"):
            discard.dataReceived(word)
        self.failUnlessEqual(output.getvalue(), "")

class LineReceiverTestCase(unittest.TestCase):
    """
    Test LineReceiver, using the C{LineTester} wrapper.

    Most tests replay their input in packets of 1 to 9 bytes, so that line
    and raw-mode transitions are hit at many different packet boundaries.
    """
    buffer = '''\
len 10

0123456789len 5

1234
len 20
foo 123

0123456789
012345678len 0
foo 5

1234567890123456789012345678901234567890123456789012345678901234567890
len 1

a'''

    output = ['len 10', '0123456789', 'len 5', '1234\n',
              'len 20', 'foo 123', '0123456789\n012345678',
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']

    def _feed(self, proto, data, packetSize):
        """
        Deliver C{data} to C{proto} in consecutive C{packetSize}-byte chunks.

        When C{len(data)} is an exact multiple of C{packetSize}, the last
        chunk delivered is the empty string.
        """
        # Floor division (//) keeps the count an int even under true
        # division, unlike the classic "/" this replaces.
        for i in range(len(data) // packetSize + 1):
            proto.dataReceived(data[i * packetSize:(i + 1) * packetSize])

    def testBuffer(self):
        """
        Test buffering for different packet size, checking received matches
        expected data.
        """
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            a = LineTester()
            a.makeConnection(protocol.FileWrapper(t))
            self._feed(a, self.buffer, packet_size)
            self.failUnlessEqual(self.output, a.received)


    pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n'

    pause_output1 = ['twiddle1', 'twiddle2', 'pause']
    pause_output2 = pause_output1+['twiddle3']

    def testPausing(self):
        """
        Test pause inside data receiving. It uses fake clock to see if
        pausing/resuming work.
        """
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            clock = task.Clock()
            a = LineTester(clock)
            a.makeConnection(protocol.FileWrapper(t))
            self._feed(a, self.pause_buf, packet_size)
            self.failUnlessEqual(self.pause_output1, a.received)
            clock.advance(0)
            self.failUnlessEqual(self.pause_output2, a.received)

    rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'

    rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
    rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345',
                        'twiddle3']

    def testRawPausing(self):
        """
        Test pause inside raw data receiving.
        """
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            clock = task.Clock()
            a = LineTester(clock)
            a.makeConnection(protocol.FileWrapper(t))
            self._feed(a, self.rawpause_buf, packet_size)
            self.failUnlessEqual(self.rawpause_output1, a.received)
            clock.advance(0)
            self.failUnlessEqual(self.rawpause_output2, a.received)

    stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'

    stop_output = ['twiddle1', 'twiddle2', 'stop']

    def testStopProducing(self):
        """
        Test stop inside producing: lines after 'stop' are not delivered.
        """
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            a = LineTester()
            a.makeConnection(protocol.FileWrapper(t))
            self._feed(a, self.stop_buf, packet_size)
            self.failUnlessEqual(self.stop_output, a.received)


    def testLineReceiverAsProducer(self):
        """
        Test produce/unproduce in receiving.
        """
        a = LineTester()
        t = StringIOWithoutClosing()
        a.makeConnection(protocol.FileWrapper(t))
        a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
        self.assertEquals(a.received,
                          ['produce', 'hello world', 'unproduce', 'goodbye'])


class LineOnlyReceiverTestCase(unittest.TestCase):
    """
    Test L{basic.LineOnlyReceiver}, through the C{LineOnlyTester} subclass.
    """
    buffer = """foo
    bleakness
    desolation
    plastic forks
    """

    def _makeTester(self):
        """
        Return a C{LineOnlyTester} connected to an in-memory transport.
        """
        tester = LineOnlyTester()
        tester.makeConnection(protocol.FileWrapper(StringIOWithoutClosing()))
        return tester

    def testBuffer(self):
        """
        Feeding the buffer one character at a time yields exactly its lines.
        """
        tester = self._makeTester()
        for char in self.buffer:
            tester.dataReceived(char)
        expected = self.buffer.split('\n')[:-1]
        self.failUnlessEqual(tester.received, expected)

    def testLineTooLong(self):
        """
        Data longer than MAX_LENGTH with no delimiter makes C{dataReceived}
        return an L{error.ConnectionLost} instance.
        """
        result = self._makeTester().dataReceived('x' * 200)
        self.assertTrue(isinstance(result, error.ConnectionLost))


class TestMixin:
    """
    Mixin for string-receiving test protocols: records every received string
    and flags when the connection is lost.
    """

    MAX_LENGTH = 50
    closed = 0

    def connectionMade(self):
        """
        Reset the record of received strings.
        """
        self.received = []

    def stringReceived(self, s):
        """
        Record C{s}, preserving arrival order.
        """
        self.received.append(s)

    def connectionLost(self, reason):
        """
        Flag that the connection has gone away; C{reason} is ignored.
        """
        self.closed = 1


class TestNetstring(TestMixin, basic.NetstringReceiver):
    """
    L{basic.NetstringReceiver} that records received strings via
    C{TestMixin}.
    """


class LPTestCaseMixin:
    """
    Shared tests for length-prefixed string receivers.

    Subclasses set C{protocol} to the receiver class under test and
    C{illegal_strings} to inputs that must leave the transport closed.
    """

    illegal_strings = []
    protocol = None

    def getProtocol(self):
        """
        Build an instance of C{self.protocol} connected to an in-memory
        transport, and return it.
        """
        proto = self.protocol()
        proto.makeConnection(protocol.FileWrapper(StringIOWithoutClosing()))
        return proto

    def testIllegal(self):
        """
        Each illegal input, delivered one character at a time, ends with the
        transport marked closed.
        """
        for bad in self.illegal_strings:
            proto = self.getProtocol()
            for char in bad:
                proto.dataReceived(char)
            self.assertEquals(proto.transport.closed, 1)


class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Tests for L{basic.NetstringReceiver} via C{TestNetstring}.
    """

    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]

    illegal_strings = [
        '9999999999999999999999', 'abc', '4:abcde',
        '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]

    protocol = TestNetstring

    def testBuffer(self):
        """
        Strings encoded with C{sendString} are decoded back by
        C{dataReceived} for every packet size from 1 to 9 bytes.
        """
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            a = TestNetstring()
            # Raise the limit above the longest test string (515 bytes).
            a.MAX_LENGTH = 699
            a.makeConnection(protocol.FileWrapper(t))
            for s in self.strings:
                a.sendString(s)
            out = t.getvalue()
            # Floor division (//) keeps range()'s bound an int even under
            # true division, unlike the classic "/" used previously.
            for i in range(len(out) // packet_size + 1):
                s = out[i*packet_size:(i+1)*packet_size]
                if s:
                    a.dataReceived(s)
            self.assertEquals(a.received, self.strings)


class TestInt32(TestMixin, basic.Int32StringReceiver):
    """
    L{basic.Int32StringReceiver} that records received strings via
    C{TestMixin}, with C{MAX_LENGTH} set to 50.
    """
    MAX_LENGTH = 50


class Int32TestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Tests for L{basic.Int32StringReceiver} via C{TestInt32}.
    """

    protocol = TestInt32
    strings = ["a", "b" * 16]
    illegal_strings = ["\x10\x00\x00\x00aaaaaa"]
    partial_strings = ["\x00\x00\x00", "hello there", ""]

    def testPartial(self):
        """
        None of the C{partial_strings} inputs, fed one character at a time,
        yields a received string.
        """
        for partial in self.partial_strings:
            proto = self.getProtocol()
            proto.MAX_LENGTH = 99999999
            for char in partial:
                proto.dataReceived(char)
            self.assertEquals(proto.received, [])

    def testReceive(self):
        """
        Frames made of a 4-byte big-endian length followed by the payload
        decode back into the original strings.
        """
        proto = self.getProtocol()
        for s in self.strings:
            frame = struct.pack("!i", len(s)) + s
            for char in frame:
                proto.dataReceived(char)
        self.assertEquals(proto.received, self.strings)


class OnlyProducerTransport(object):
    """
    Minimal transport stand-in: it records written chunks and whether it has
    been asked to pause, and does nothing else.
    """

    disconnecting = False
    paused = False

    def __init__(self):
        self.data = []

    def pauseProducing(self):
        """
        Mark the transport as paused.
        """
        self.paused = True

    def resumeProducing(self):
        """
        Mark the transport as no longer paused.
        """
        self.paused = False

    def write(self, bytes):
        """
        Record C{bytes} as one written chunk.
        """
        self.data.append(bytes)


class ConsumingProtocol(basic.LineReceiver):
    """
    Line receiver that writes each line to its transport and then pauses
    itself after every single line.
    """

    def lineReceived(self, line):
        """
        Write C{line} to the transport, then pause producing.
        """
        self.transport.write(line)
        self.pauseProducing()


class ProducerTestCase(unittest.TestCase):
    """
    Test how L{basic.LineReceiver} pausing and resuming interact with its
    transport.
    """

    def testPauseResume(self):
        """
        A protocol that pauses after every line delivers one buffered line per
        C{resumeProducing} call, and the transport's pause state follows the
        protocol's.
        """
        proto = ConsumingProtocol()
        transport = OnlyProducerTransport()
        proto.makeConnection(transport)

        def checkPaused(expected):
            # Transport first, then protocol; both must match C{expected}.
            for flag in (transport.paused, proto.paused):
                if expected:
                    self.failUnless(flag)
                else:
                    self.failIf(flag)

        # A partial line is only buffered: nothing written, nothing paused.
        proto.dataReceived('hello, ')
        self.failIf(transport.data)
        checkPaused(False)

        # Completing the line delivers it, and the protocol pauses.
        proto.dataReceived('world\r\n')
        self.assertEquals(transport.data, ['hello, world'])
        checkPaused(True)

        proto.resumeProducing()
        checkPaused(False)

        # Two lines at once: only the first is delivered before pausing.
        proto.dataReceived('hello\r\nworld\r\n')
        self.assertEquals(transport.data, ['hello, world', 'hello'])
        checkPaused(True)

        # Each resume releases the next buffered line.
        proto.resumeProducing()
        proto.dataReceived('goodbye\r\n')
        self.assertEquals(transport.data, ['hello, world', 'hello', 'world'])
        checkPaused(True)

        proto.resumeProducing()
        self.assertEquals(transport.data,
                          ['hello, world', 'hello', 'world', 'goodbye'])
        checkPaused(True)

        # Buffer drained: resuming now leaves everything unpaused.
        proto.resumeProducing()
        self.assertEquals(transport.data,
                          ['hello, world', 'hello', 'world', 'goodbye'])
        checkPaused(False)


class Portforwarding(unittest.TestCase):
    """
    Test port forwarding.
    """
    def setUp(self):
        self.serverProtocol = wire.Echo()
        self.clientProtocol = protocol.Protocol()
        self.openPorts = []

    def tearDown(self):
        """
        Drop any connections the test opened and stop all listening ports.

        A protocol that never got connected has no usable transport
        (presumably C{None}; confirm against L{protocol.BaseProtocol}), so it
        is skipped explicitly. This replaces a bare C{except: pass}, which
        also hid every unrelated error.
        """
        for proto in (self.clientProtocol, self.serverProtocol):
            transport = getattr(proto, 'transport', None)
            if transport is not None:
                transport.loseConnection()
        return defer.gatherResults(
            [defer.maybeDeferred(p.stopListening) for p in self.openPorts])

    def testPortforward(self):
        """
        Test port forwarding through Echo protocol: bytes written by the
        client go through the proxy to the echo server and come back intact.
        """
        realServerFactory = protocol.ServerFactory()
        realServerFactory.protocol = lambda: self.serverProtocol
        realServerPort = reactor.listenTCP(0, realServerFactory,
                                           interface='127.0.0.1')
        self.openPorts.append(realServerPort)

        proxyServerFactory = portforward.ProxyFactory('127.0.0.1',
                                realServerPort.getHost().port)
        proxyServerPort = reactor.listenTCP(0, proxyServerFactory,
                                            interface='127.0.0.1')
        self.openPorts.append(proxyServerPort)

        nBytes = 1000
        received = []
        d = defer.Deferred()

        def clientDataReceived(data):
            received.extend(data)
            # Fire exactly once, as soon as everything has come back. The
            # comparison happens in a callback on d: an assertion raised
            # here would surface inside the reactor's dataReceived call
            # rather than through the Deferred this test returns, so the
            # test would only end at its timeout.
            if len(received) >= nBytes and not d.called:
                d.callback(''.join(received))
        self.clientProtocol.dataReceived = clientDataReceived

        def clientConnectionMade():
            self.clientProtocol.transport.write('x' * nBytes)
        self.clientProtocol.connectionMade = clientConnectionMade

        clientFactory = protocol.ClientFactory()
        clientFactory.protocol = lambda: self.clientProtocol

        reactor.connectTCP(
            '127.0.0.1', proxyServerPort.getHost().port, clientFactory)

        def checkEchoed(data):
            self.assertEquals(data, 'x' * nBytes)
        d.addCallback(checkEchoed)
        return d

