# Written by Bram Cohen
# Modified by Cameron Dale
# see LICENSE.txt for license information

# $Id: StreamCheck.py 266 2007-08-18 02:06:35Z camrdale-guest $

"""Not used.

@type logger: C{logging.Logger}
@var logger: the logger to send all log messages to for this module

"""

import logging
import struct
from binascii import b2a_hex
from cStringIO import StringIO
from socket import error as socketerror
from urllib import quote

from DebTorrent.__init__ import protocol_name, make_readable
import Connecter

# Module-wide logger; every StreamCheck instance reports through it.
logger = logging.getLogger('DebTorrent.BT1.StreamCheck')

# Eight zero bytes.
option_pattern = '\x00' * 8

# Order of the elements in a stream:
# header, reserved, download id, my id, [length, message]

# Number of streams created so far; gives each new stream its number.
streamno = 0


class StreamCheck:
    """Debugging state machine that parses a peer byte stream and logs it.

    Data is fed in arbitrary sized chunks through L{write}. The instance
    remembers how many bytes the element currently being read needs
    (C{next_len}) and which method consumes it (C{next_func}); every
    consumer returns that pair for the element that follows. The elements
    are read in this order: header length, header, 8 reserved bytes,
    20 byte download ID, 20 byte peer ID, then repeatedly a 4 byte length
    followed by a message of that length.

    @type no: C{int}
    @ivar no: the number of this stream, prefixed to all its log messages
    @type buffer: C{StringIO}
    @ivar buffer: the bytes collected so far for the element being read
    @type next_len: C{int}
    @ivar next_len: the number of bytes the element being read needs
    @type next_func: C{method}
    @ivar next_func: the method to call once C{next_len} bytes are buffered

    """

    # Lengths above this are reported as bad (but still followed).
    _MAX_MESSAGE_LEN = 2 ** 23

    def __init__(self):
        """Take the next stream number and start at the header length."""
        global streamno
        self.no = streamno
        streamno += 1
        self.buffer = StringIO()
        self.next_len, self.next_func = 1, self.read_header_len

    def read_header_len(self, s):
        """Check the single byte that gives the protocol name's length.

        @type s: C{string}
        @param s: the byte read from the stream
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call
            with it

        """
        if ord(s) != len(protocol_name):
            logger.warning('%s BAD HEADER LENGTH', self.no)
        # Continue with the expected length even after a mismatch.
        return len(protocol_name), self.read_header

    def read_header(self, s):
        """Check that the header is the protocol name.

        @type s: C{string}
        @param s: the header read from the stream
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call
            with it

        """
        if s != protocol_name:
            logger.warning('%s BAD HEADER', self.no)
        return 8, self.read_reserved

    def read_reserved(self, s):
        """Skip over the reserved bytes without inspecting them.

        @type s: C{string}
        @param s: the reserved bytes read from the stream
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call
            with it

        """
        return 20, self.read_download_id

    def read_download_id(self, s):
        """Log the download ID in hexadecimal.

        @type s: C{string}
        @param s: the download ID read from the stream
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call
            with it

        """
        logger.debug('%s download ID %s', self.no, b2a_hex(s))
        return 20, self.read_peer_id

    def read_peer_id(self, s):
        """Log the peer ID in readable form.

        @type s: C{string}
        @param s: the peer ID read from the stream
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call
            with it

        """
        logger.debug('%s peer ID%s', self.no, make_readable(s))
        return 4, self.read_len

    def read_len(self, s):
        """Decode the 4 byte big-endian length of the next message.

        Lengths that are negative or larger than C{_MAX_MESSAGE_LEN} are
        reported. A negative length is treated as zero, since a negative
        count would make L{write} slice from the wrong end of the data.

        @type s: C{string}
        @param s: the 4 bytes read from the stream
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call
            with it

        """
        l = struct.unpack('>i', s)[0]
        if l < 0 or l > self._MAX_MESSAGE_LEN:
            # Hex keeps unprintable bytes out of the log.
            logger.warning('%s BAD LENGTH: %s (%s)', self.no, l, b2a_hex(s))
        return max(l, 0), self.read_message

    def read_message(self, s):
        """Inspect one message and log what it contains.

        @type s: C{string}
        @param s: the message read from the stream (may be empty)
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call
            with it

        """
        if not s:
            # Empty message: nothing to inspect.
            return 4, self.read_len
        m = s[0]
        if ord(m) > 8:
            logger.warning('%s BAD MESSAGE: %s', self.no, ord(m))
        if m == Connecter.REQUEST:
            self._log_block_message('Request', s)
        elif m == Connecter.CANCEL:
            self._log_block_message('Cancel', s)
        elif m == Connecter.PIECE:
            if len(s) < 9:
                # Too short to hold the index and begin fields; unpacking
                # would raise struct.error.
                logger.warning('%s BAD PIECE SIZE: %s', self.no, len(s))
            else:
                index, begin = struct.unpack('>ii', s[1:9])
                length = len(s) - 9
                logger.info('%s Piece: %s: %s-%s+%s',
                            self.no, index, begin, begin, length)
        else:
            logger.info('%s Message %s (length %s)',
                        self.no, ord(m), len(s))
        return 4, self.read_len

    def _log_block_message(self, name, s):
        """Log a REQUEST or CANCEL message, or warn if its size is wrong.

        @type name: C{string}
        @param name: the message name to use in the log ('Request' or
            'Cancel')
        @type s: C{string}
        @param s: the complete message, including the type byte

        """
        if len(s) != 13:
            logger.warning('%s BAD %s SIZE: %s',
                           self.no, name.upper(), len(s))
            return
        index, begin, length = struct.unpack('>iii', s[1:])
        logger.info('%s %s: %s: %s-%s+%s',
                    self.no, name, index, begin, begin, length)

    def write(self, s):
        """Feed more stream data to the checker.

        Complete elements are passed to the current C{next_func}; any
        trailing partial element is kept in the buffer until more data
        arrives.

        @type s: C{string}
        @param s: the newly arrived data

        """
        while True:
            # Bytes still missing from the element being read.
            i = self.next_len - self.buffer.tell()
            if i > len(s):
                self.buffer.write(s)
                return
            self.buffer.write(s[:i])
            s = s[i:]
            m = self.buffer.getvalue()
            # Empty the buffer for the next element.
            self.buffer.seek(0)
            self.buffer.truncate()
            self.next_len, self.next_func = self.next_func(m)
