#!/usr/bin/env python

import unittest, POW, base64, sys, os, socket, time

# The key and certificate fixtures used below are written to paths under
# 'working/', so create that directory (relative to the current directory)
# if it is not already there.
if not os.path.isdir('working'):
   os.mkdir('working')

#--------------- Hash test case ---------------#

class HashTestCase(unittest.TestCase):
   'Hash algorithm tests'

   # Every expected value below is the base64 encoded digest of this string.
   plainText = 'My extremely silly pass phrase!'

   def _symmetricGeneralTest(self, type, expected=None):
      'Digest plainText with algorithm `type`, compare base64 result to `expected`'
      hasher = POW.Digest( type )
      hasher.update( self.plainText )
      rawDigest = hasher.digest()
      # encodestring appends a newline, hence the '\n' in the expected values
      encoded = base64.encodestring( rawDigest )
      self.assertEqual( encoded, expected, 'Digest result incorrect' )

   def testMd2(self):
      'Generate and check MD2 hash'
      expected = 'O9VUpKqYAHkCgPyAkclL8g==\n'
      self._symmetricGeneralTest( POW.MD2_DIGEST, expected )

   def testMd5(self):
      'Generate and check MD5 hash'
      expected = 'kzb1VPPjrYNNA0gwsoKsQw==\n'
      self._symmetricGeneralTest( POW.MD5_DIGEST, expected )

   def testSha(self):
      'Generate and check SHA hash'
      expected = 'ptkIj1ilu9oFTFbP3A6o3KuJL+Q=\n'
      self._symmetricGeneralTest( POW.SHA_DIGEST, expected )

   def testSha1(self):
      'Generate and check SHA1 hash'
      expected = '7zk06ujVcAWhzREYzY4s4lCw4WQ=\n'
      self._symmetricGeneralTest( POW.SHA1_DIGEST, expected )

   def testRipemd160(self):
      'Generate and check RIPEMD160 hash'
      expected = 'R+ve9PdUxqr45duMhG8CBQiahkU=\n'
      self._symmetricGeneralTest( POW.RIPEMD160_DIGEST, expected )

#--------------- Hash test case ---------------#
#--------------- Hmac test case ---------------#

class HmacTestCase(unittest.TestCase):
   'HMAC algorithm tests'

   # Every expected value below is the base64 encoded MAC of plainText
   # keyed with password.
   plainText = 'My extremely silly pass phrase!'
   password = 'Puny pass word'

   def _symmetricGeneralTest(self, type, expected=None):
      'MAC plainText with digest `type`, compare base64 result to `expected`'
      keyedHash = POW.Hmac( type, self.password )
      keyedHash.update( self.plainText )
      rawMac = keyedHash.mac()
      # encodestring appends a newline, hence the '\n' in the expected values
      encoded = base64.encodestring( rawMac )
      self.assertEqual( encoded, expected, 'HMAC result incorrect' )

   def testHmacMd2(self):
      'Generate and check MD2 HMAC'
      expected = 'UgWmfru6kM68GFn3HMmbeg==\n'
      self._symmetricGeneralTest( POW.MD2_DIGEST, expected )

   def testHmacMd5(self):
      'Generate and check MD5 HMAC'
      expected = '+l1oP2UbL0dW7L51lw2LSg==\n'
      self._symmetricGeneralTest( POW.MD5_DIGEST, expected )

   def testHmacSha(self):
      'Generate and check SHA HMAC'
      expected = 'xuLEZcpj96p2Uo0/Ief1zjUdJdM=\n'
      self._symmetricGeneralTest( POW.SHA_DIGEST, expected )

   def testHmacSha1(self):
      'Generate and check SHA1 HMAC'
      expected = 'nnT7qPYMHjJ46JXQWmR/Ap0XK2E=\n'
      self._symmetricGeneralTest( POW.SHA1_DIGEST, expected )

   def testHmacRipemd160(self):
      'Generate and check RIPEMD160 HMAC'
      expected = 'AeSjVffp5FPIBBtabpD/nwVDz/s=\n'
      self._symmetricGeneralTest( POW.RIPEMD160_DIGEST, expected )

#--------------- Hmac test case ---------------#
#--------------- Symmetric cipher test case ---------------#

class SymmetricTestCase(unittest.TestCase):
   'Symmetric algorithm tests'

   # Pass phrase handed to encryptInit/decryptInit for every cipher.
   password = 'Hello :)'

   # NOTE: this longer sample is overridden by the assignment that follows
   # it, so it is not what the tests currently encrypt.  It is kept as an
   # alternative multi-line fixture.
   plainText = '''
# Basic system aliases that MUST be present.
postmaster:	root
mailer-daemon:	postmaster

# amavis
virusalert:	root

# General redirections for pseudo accounts in /etc/passwd.
administrator:	root
daemon:		root
lp:		root
news:		root
uucp:		root
games:		root
man:		root
at:		root
postgres:	root
mdom:		root
amanda:		root
ftp:		root
wwwrun:		root
squid:		root
msql:		root
gnats:		root
nobody:		root
'''

   # The text every test below actually round-trips.
   plainText = 'Hello World'

   def _symmetricGeneralTest(self, type):
      'Encrypt plainText with cipher `type`, decrypt it and compare with the original'
      symmetric = POW.Symmetric( type )
      symmetric.encryptInit( self.password )
      cipherText = symmetric.update( self.plainText ) + symmetric.final()
      # the same object is re-initialised for decryption with the same password
      symmetric.decryptInit( self.password )
      decipheredText = symmetric.update( cipherText ) + symmetric.final()
      self.assertEqual( self.plainText, decipheredText, 'decrypted cipher text not equal to original text' )

   def testDES_ECB(self):
      'Generate and check DES_ECB encrypted text'
      self._symmetricGeneralTest( POW.DES_ECB )

   def testDES_EDE(self):
      'Generate and check DES_EDE encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE )

   def testDES_EDE3(self):
      'Generate and check DES_EDE3 encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE3 )

   def testDES_CFB(self):
      'Generate and check DES_CFB encrypted text'
      # previously exercised DES_ECB by mistake
      self._symmetricGeneralTest( POW.DES_CFB )

   def testDES_EDE_CFB(self):
      'Generate and check DES_EDE_CFB encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE_CFB )

   def testDES_EDE3_CFB(self):
      'Generate and check DES_EDE3_CFB encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE3_CFB )

   def testDES_OFB(self):
      'Generate and check DES_OFB encrypted text'
      self._symmetricGeneralTest( POW.DES_OFB )

   def testDES_EDE_OFB(self):
      'Generate and check DES_EDE_OFB encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE_OFB )

   def testDES_EDE3_OFB(self):
      'Generate and check DES_EDE3_OFB encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE3_OFB )

   def testDES_CBC(self):
      'Generate and check DES_CBC encrypted text'
      self._symmetricGeneralTest( POW.DES_CBC )

   def testDES_EDE_CBC(self):
      'Generate and check DES_EDE_CBC encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE_CBC )

   def testDES_EDE3_CBC(self):
      'Generate and check DES_EDE3_CBC encrypted text'
      self._symmetricGeneralTest( POW.DES_EDE3_CBC )

   def testDESX_CBC(self):
      'Generate and check DESX_CBC encrypted text'
      self._symmetricGeneralTest( POW.DESX_CBC )

   def testRC4(self):
      'Generate and check RC4 encrypted text'
      self._symmetricGeneralTest( POW.RC4 )

   def testRC4_40(self):
      'Generate and check RC4_40 encrypted text'
      # previously exercised DES_EDE3_CBC by mistake
      self._symmetricGeneralTest( POW.RC4_40 )

   def testIDEA_ECB(self):
      'Generate and check IDEA_ECB encrypted text'
      self._symmetricGeneralTest( POW.IDEA_ECB )

   def testIDEA_CFB(self):
      'Generate and check IDEA_CFB encrypted text'
      self._symmetricGeneralTest( POW.IDEA_CFB )

   def testIDEA_OFB(self):
      'Generate and check IDEA_OFB encrypted text'
      self._symmetricGeneralTest( POW.IDEA_OFB )

   def testIDEA_CBC(self):
      'Generate and check IDEA_CBC encrypted text'
      self._symmetricGeneralTest( POW.IDEA_CBC )

   def testRC2_ECB(self):
      'Generate and check RC2_ECB encrypted text'
      self._symmetricGeneralTest( POW.RC2_ECB )

   def testRC2_CBC(self):
      'Generate and check RC2_CBC encrypted text'
      self._symmetricGeneralTest( POW.RC2_CBC )

   def testRC2_40_CBC(self):
      'Generate and check RC2_40_CBC encrypted text'
      self._symmetricGeneralTest( POW.RC2_40_CBC )

   def testRC2_CFB(self):
      'Generate and check RC2_CFB encrypted text'
      self._symmetricGeneralTest( POW.RC2_CFB )

   def testRC2_OFB(self):
      'Generate and check RC2_OFB encrypted text'
      self._symmetricGeneralTest( POW.RC2_OFB )

   def testBF_ECB(self):
      'Generate and check BF_ECB encrypted text'
      self._symmetricGeneralTest( POW.BF_ECB )

   def testBF_CBC(self):
      'Generate and check BF_CBC encrypted text'
      self._symmetricGeneralTest( POW.BF_CBC )

   def testBF_CFB(self):
      'Generate and check BF_CFB encrypted text'
      self._symmetricGeneralTest( POW.BF_CFB )

   def testBF_OFB(self):
      'Generate and check BF_OFB encrypted text'
      self._symmetricGeneralTest( POW.BF_OFB )

   def testCAST5_ECB(self):
      'Generate and check CAST5_ECB encrypted text'
      self._symmetricGeneralTest( POW.CAST5_ECB )

   def testCAST5_CBC(self):
      'Generate and check CAST5_CBC encrypted text'
      self._symmetricGeneralTest( POW.CAST5_CBC )

   def testCAST5_CFB(self):
      'Generate and check CAST5_CFB encrypted text'
      self._symmetricGeneralTest( POW.CAST5_CFB )

   def testCAST5_OFB(self):
      'Generate and check CAST5_OFB encrypted text'
      self._symmetricGeneralTest( POW.CAST5_OFB )

   # The four RC5 tests previously all exercised CAST5_OFB by mistake.

   def testRC5_32_12_16_CBC(self):
      'Generate and check RC5_32_12_16_CBC encrypted text'
      self._symmetricGeneralTest( POW.RC5_32_12_16_CBC )

   def testRC5_32_12_16_CFB(self):
      'Generate and check RC5_32_12_16_CFB encrypted text'
      self._symmetricGeneralTest( POW.RC5_32_12_16_CFB )

   def testRC5_32_12_16_ECB(self):
      'Generate and check RC5_32_12_16_ECB encrypted text'
      self._symmetricGeneralTest( POW.RC5_32_12_16_ECB )

   def testRC5_32_12_16_OFB(self):
      'Generate and check RC5_32_12_16_OFB encrypted text'
      self._symmetricGeneralTest( POW.RC5_32_12_16_OFB )

#--------------- Symmetric cipher test case ---------------#
#--------------- Asymmetric cipher test case ---------------#

class AsymmetricUtilities:
   'Creates, loads and removes the RSA key pair files used as test fixtures'

   # entity name -> paths of its PEM encoded private and public key files
   keys = { 'client' : { 'priv' : 'working/key1Priv', 'pub' : 'working/key1Pub' },
            'server' : { 'priv' : 'working/key2Priv', 'pub' : 'working/key2Pub' },
            'ca'     : { 'priv' : 'working/key3Priv', 'pub' : 'working/key3Pub' },
            'ca2'    : { 'priv' : 'working/key4Priv', 'pub' : 'working/key4Pub' },
            'ca3'    : { 'priv' : 'working/key5Priv', 'pub' : 'working/key5Pub' },
            'server2': { 'priv' : 'working/key6Priv', 'pub' : 'working/key6Pub' }   }

   # pass phrase the private key files are encrypted with
   password = 'Silly password'

   def prepCiphers(self):
      'Generate and store a key pair for every entity in keys'
      for entity in self.keys.keys():
         self.makeCipher(entity)

   def unPrepCiphers(self):
      'Delete the key files of every entity in keys'
      for entity in self.keys.keys():
         self.remCipher(entity)

   def _readFile(self, path):
      'Return the contents of path, closing the file even if the read fails'
      handle = open( path )
      try:
         return handle.read()
      finally:
         handle.close()

   def _writeFile(self, path, text):
      'Write text to path, closing the file even if the write fails'
      handle = open( path, 'w' )
      try:
         handle.write( text )
      finally:
         handle.close()

   def getCipher(self, entry):
      '''Load the stored key pair of entity `entry`.

      Returns a (public, private) tuple of the objects POW.pemRead yields.
      Raises KeyError for an unknown entry.'''
      paths = self.keys[entry]
      privPem = self._readFile( paths['priv'] )
      pubPem = self._readFile( paths['pub'] )
      priv = POW.pemRead( POW.RSA_PRIVATE_KEY, privPem, self.password )
      pub = POW.pemRead( POW.RSA_PUBLIC_KEY, pubPem )
      return (pub, priv)

   def makeCipher(self, entry):
      '''Generate a fresh key pair and store it in the files of entity `entry`.

      The private key is written encrypted with password.'''
      paths = self.keys[entry]
      cipher = POW.Asymmetric()
      # Produce both PEM texts before touching the files so that a failure
      # here does not leave truncated key files behind.
      privPem = cipher.pemWrite( POW.RSA_PRIVATE_KEY, POW.DES_EDE3_CFB, self.password )
      pubPem = cipher.pemWrite( POW.RSA_PUBLIC_KEY )
      self._writeFile( paths['priv'], privPem )
      self._writeFile( paths['pub'], pubPem )

   def remCipher(self, entry):
      'Delete both key files of entity `entry`; files already gone are ignored'
      for kind in ( 'priv', 'pub' ):
         try:
            os.remove( self.keys[entry][kind] )
         except OSError:
            # best effort: the file may never have been created
            pass

class AsymmetricTestCase(unittest.TestCase):
   'Asymmetric algorithm tests'

   # Text that gets encrypted or signed with the client key pair, which is
   # loaded through the module level `ciphers` object.
   plainText = 'A little text to encrypt!'

   def testPublicEncrypt(self):
      'Encrypt text using public RSA cipher, decrypt and compare'
      public, private = ciphers.getCipher('client')
      encrypted = public.publicEncrypt( self.plainText )
      recovered = private.privateDecrypt( encrypted )
      self.assertEqual( self.plainText, recovered )

   def testPrivateEncrypt(self):
      'Encrypt text using private RSA cipher, decrypt and compare'
      public, private = ciphers.getCipher('client')
      encrypted = private.privateEncrypt( self.plainText )
      recovered = public.publicDecrypt( encrypted )
      self.assertEqual( self.plainText, recovered )

   def testSign(self):
      'Sign text using private RSA cipher and verify'
      public, private = ciphers.getCipher('client')
      hasher = POW.Digest( POW.SHA1_DIGEST )
      hasher.update( self.plainText )
      signature = private.sign( hasher.digest(), POW.SHA1_DIGEST )
      verified = public.verify( signature, hasher.digest(), POW.SHA1_DIGEST )
      self.failUnless( verified )

#--------------- Asymmetric cipher test case ---------------#
#--------------- X509 test case ---------------#

class X509Utilities:
   '''Creates, loads and removes the X509 certificate files used as test fixtures.

   Constructing an instance also generates the RSA key pairs the
   certificates are built from (see asymUtils).'''

   # entity name -> path of its PEM encoded certificate
   certs = {   'client' : 'working/cert1',
               'server' : 'working/cert2',
               'ca'     : 'working/cert3',
               'ca2'    : 'working/cert4',
               'ca3'    : 'working/cert5',
               'server2': 'working/cert6'    }

   clientName = ( ('C', 'GB'), ('ST', 'Hertfordshire'),
                  ('O', 'The House'), ('CN', 'Client') )

   serverName = ( ('C', 'GB'), ('ST', 'Hertfordshire'),
                  ('O', 'The House'), ('CN', 'Server') )

   caName = (  ('C', 'GB'), ('ST', 'Hertfordshire'),
               ('O', 'The House'), ('CN', 'CA') )

   ca2Name = (  ('C', 'GB'), ('ST', 'Hertfordshire'),
               ('O', 'The House'), ('CN', 'CA2') )

   ca3Name = (  ('C', 'GB'), ('ST', 'Hertfordshire'),
               ('O', 'The House'), ('CN', 'CA3') )

   server2Name = (  ('C', 'GB'), ('ST', 'Hertfordshire'),
               ('O', 'The House'), ('CN', 'server2') )

   # validity window written into every certificate
   notBefore = 1005960447
   notAfter = 1037496447

   caSerial = 0
   serverSerial = 1
   clientSerial = 2
   ca2Serial = 3
   ca3Serial = 4
   server2Serial = 5

   def __init__(self):
      self.asymUtils = AsymmetricUtilities()
      self.asymUtils.prepCiphers()

   def __del__(self):
      # __del__ also runs when __init__ failed before asymUtils was
      # assigned, so do not assume the attribute exists.
      asymUtils = getattr( self, 'asymUtils', None )
      if asymUtils is not None:
         asymUtils.unPrepCiphers()

   def prepCerts(self):
      'Create and store a certificate for every entity in certs'
      for cert in self.certs.keys():
         self.makeCert(cert)

   def unPrepCerts(self):
      'Delete the certificate file of every entity in certs'
      for cert in self.certs.keys():
         self.remCert(cert)

   def getCert(self, entry):
      'Load and return the stored certificate of entity `entry`'
      certFile = open( self.certs[entry] )
      try:
         pem = certFile.read()
      finally:
         certFile.close()
      return POW.pemRead( POW.X509_CERTIFICATE, pem )

   def _certSpec(self, entry):
      '''Return (issuer, subject, serial, subject key owner, signing key owner)
      for entity `entry`; raises ValueError for an unknown entry.'''
      # ca is self signed and also issues server, client and ca2;
      # ca2 issues ca3 and ca3 issues server2.
      specs = {
         'server' : ( self.caName,  self.serverName,  self.serverSerial,  'server',  'ca'  ),
         'client' : ( self.caName,  self.clientName,  self.clientSerial,  'client',  'ca'  ),
         'ca'     : ( self.caName,  self.caName,      self.caSerial,      'ca',      'ca'  ),
         'ca2'    : ( self.caName,  self.ca2Name,     self.ca2Serial,     'ca2',     'ca'  ),
         'ca3'    : ( self.ca2Name, self.ca3Name,     self.ca3Serial,     'ca3',     'ca2' ),
         'server2': ( self.ca3Name, self.server2Name, self.server2Serial, 'server2', 'ca3' ),
      }
      spec = specs.get( entry )
      if spec is None:
         raise ValueError( 'Entry should be ca, ca2, ca3, server, server2 or client!' )
      return spec

   def makeCert(self, entry):
      '''Build, sign and store the certificate of entity `entry`.

      Raises ValueError (an Exception subclass) for an unknown entry, before
      any key file is read.'''
      issuer, subject, serial, keyOwner, signer = self._certSpec( entry )
      # getCipher returns (public, private)
      subjectKeys = self.asymUtils.getCipher( keyOwner )
      signerKeys = self.asymUtils.getCipher( signer )

      cert = POW.X509()
      cert.setIssuer( issuer )
      cert.setSubject( subject )
      cert.setSerial( serial )
      cert.setNotBefore( self.notBefore )
      cert.setNotAfter( self.notAfter )
      cert.setPublicKey( subjectKeys[0] )
      cert.sign( signerKeys[1] )

      # Serialise first so a failure does not leave a truncated file behind.
      pem = cert.pemWrite()
      certFile = open( self.certs[entry], 'w' )
      try:
         certFile.write( pem )
      finally:
         certFile.close()

   def remCert(self, entry):
      'Delete the certificate file of entity `entry`; a missing file is ignored'
      try:
         os.remove( self.certs[entry] )
      except OSError:
         # best effort: the file may never have been created
         pass

class X509TestCase(unittest.TestCase):
   'X509 tests'

   # All tests inspect the server certificate loaded through the module
   # level `certs` object.

   def _serverCert(self):
      'Load the server certificate from its file'
      return certs.getCert('server')

   def testIssuer(self):
      'Check the issuer is correct for server certificate'
      issuer = self._serverCert().getIssuer()
      self.assertEqual( certs.caName, issuer )

   def testSubject(self):
      'Check the subject is correct for server certificate'
      subject = self._serverCert().getSubject()
      self.assertEqual( certs.serverName, subject )

   def testVersion(self):
      'Check version number is correct for server certificate'
      # NOTE(review): this compares the *serial* number with 1 although the
      # test is named for the version -- confirm which value is intended
      # before changing it.
      serial = self._serverCert().getSerial()
      self.assertEqual( 1, serial )

   def testSerial(self):
      'Check serial number is correct for server certificate'
      serial = self._serverCert().getSerial()
      self.assertEqual( certs.serverSerial, serial )

   def testNotBefore(self):
      'Check notBefore date is correct for server certificate'
      notBefore = self._serverCert().getNotBefore()
      # only the first element of the returned value is compared
      self.assertEqual( certs.notBefore, notBefore[0] )

   def testNotAfter(self):
      'Check notAfter date is correct for server certificate'
      notAfter = self._serverCert().getNotAfter()
      # only the first element of the returned value is compared
      self.assertEqual( certs.notAfter, notAfter[0] )

#--------------- X509 test case ---------------#
#--------------- X509 Store test case ---------------#

class X509StoreTestCase(unittest.TestCase):
   'X509 Store tests'

   def testVerify(self):
      'Verify server\'s certificate against CA certificate'
      trusted = certs.getCert('ca')
      candidate = certs.getCert('server')

      trustStore = POW.X509Store()
      trustStore.addTrust( trusted )
      verified = trustStore.verify( candidate )
      self.failUnless( verified )

   def testVerifyChain(self):
      'Verify chain of certificates against CA certificate'
      trusted = certs.getCert('ca')
      # ca2 and ca3 are passed along as the chain rather than being trusted
      ca2Cert = certs.getCert('ca2')
      ca3Cert = certs.getCert('ca3')
      candidate = certs.getCert('server2')

      trustStore = POW.X509Store()
      trustStore.addTrust( trusted )
      chain = [ca3Cert, ca2Cert ]
      self.failUnless( trustStore.verifyChain( candidate, chain ) )


#--------------- X509 Store test case ---------------#
#--------------- X509 Revoked test case ---------------#

class X509RevokedTestCase(unittest.TestCase):
   'X509 Revoked tests'

   # values the revocation entry is constructed with and must report back
   serial = 7
   revokedOn = 1005960447

   def testRevoked(self):
      'Create X509 revocation and check values are correct'
      entry = POW.X509Revoked( self.serial, self.revokedOn )
      # only the first element of the returned date is compared
      reportedDate = entry.getDate()[0]
      self.assertEqual( reportedDate, self.revokedOn )
      reportedSerial = entry.getSerial()
      self.assertEqual( reportedSerial, self.serial )

#--------------- X509 Revoked test case ---------------#
#--------------- X509 CRL test case ---------------#

class X509CrlTestCase(unittest.TestCase):
   'X509 CRL tests'

   # (serial, revocation date) pairs; serials run 1..n in order, which
   # testRevoked relies on to find the expected entry for a serial.
   revocationData = (   ( 1, 1005960447 ),
                        ( 2, 1005960448 ),
                        ( 3, 1005960449 ),
                        ( 4, 1005960450 ),
                        ( 5, 1005960451 )    )

   thisUpdate = 1005960447
   nextUpdate = 1037496447

   version = 2

   def setUp(self):
      'Build a CRL from the class level data and sign it with the CA key'
      self.ca = certs.getCert('ca')
      self.caCipher = ciphers.getCipher('ca')

      revocation = []
      for serial, revokedOn in self.revocationData:
         revocation.append( POW.X509Revoked( serial, revokedOn ) )

      self.crl = POW.X509Crl()
      self.crl.setVersion( self.version )
      self.crl.setIssuer( self.ca.getIssuer() )
      self.crl.setThisUpdate( self.thisUpdate )
      self.crl.setNextUpdate( self.nextUpdate )
      self.crl.setRevoked( revocation )
      # getCipher returns (public, private); sign with the private key
      self.crl.sign( self.caCipher[1] )

   def tearDown(self):
      'Drop the objects created by setUp'
      del self.ca
      del self.caCipher
      del self.crl

   def testVersion(self):
      'Create CRL and check version number is correct'
      self.assertEqual( self.version, self.crl.getVersion() )

   def testIssuer(self):
      'Create CRL and check issuer name is correct'
      self.assertEqual( self.ca.getIssuer(), self.crl.getIssuer() )

   def testThisUpdate(self):
      'Create CRL and check thisUpdate is correct'
      self.assertEqual( self.thisUpdate, self.crl.getThisUpdate()[0] )

   def testNextUpdate(self):
      'Create CRL and check nextUpdate is correct'
      self.assertEqual( self.nextUpdate, self.crl.getNextUpdate()[0] )

   def testRevoked(self):
      'Create CRL and check list of revoked objects is correct'
      revokedCerts = self.crl.getRevoked()
      # Without this check an empty or truncated list would pass, because
      # the loop below only inspects the entries that were returned.
      self.assertEqual( len(self.revocationData), len(revokedCerts) )
      for revocation in revokedCerts:
         serial = revocation.getSerial()
         date = revocation.getDate()[0]
         # entries are looked up by serial, so the order returned is irrelevant
         expectedSerial, expectedDate = self.revocationData[serial - 1]
         self.assertEqual( expectedSerial, serial )
         self.assertEqual( expectedDate, date )

#--------------- X509 CRL test case ---------------#
#--------------- SSL test case ---------------#

# TCP port the one-shot SSL test servers bind to and the client connects to.
serverPort = 7777
# Payload the client writes; SimpleSslServer compares what it read with this.
clientMsg = 'Message from client to server...'
# Payload each server writes back to the client.
serverMsg = 'Message from server to client...'

def serverCertKey():
   'Return the (certificate, private key) pair of the server entity'
   # getCipher returns (public, private); index 1 is the private key
   return certs.getCert('server'), ciphers.getCipher('server')[1]

def clientCertKey():
   'Return the (certificate, private key) pair of the client entity'
   # getCipher returns (public, private); index 1 is the private key
   return certs.getCert('client'), ciphers.getCipher('client')[1]

class SimpleSslServer:
   '''One-shot SSL server.

   Constructing an instance accepts a single connection on serverPort,
   reads one message, answers with serverMsg and finally asserts through
   `test` (a TestCase) that the message read equals clientMsg.'''

   def __init__(self, test):
      cert, key = serverCertKey()
      ssl = POW.Ssl( POW.SSLV23_SERVER_METHOD )
      ssl.useCertificate(cert)
      ssl.useKey(key)

      sock = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
      try:
         # The SSL tests bind the same port one after another; allow the
         # rebind while connections of the previous test are in TIME_WAIT.
         sock.setsockopt( socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 )
         sock.bind( ('', serverPort) )
         sock.listen(1)
         conn, addr = sock.accept()
         sock.shutdown(0)
      finally:
         # only one client is served, the listening socket is done either way
         sock.close()

      try:
         ssl.setFd( conn.fileno() )
         ssl.accept()

         msg = ssl.read()
         ssl.write(serverMsg)

         self._shutdownSsl( ssl )
         conn.shutdown(0)
      finally:
         conn.close()
      test.assertEqual( clientMsg, msg, 'client/server communication failiure' )

   def _shutdownSsl(self, ssl):
      'Retry ssl.shutdown() once a second until it succeeds'
      while 1:
         try:
            ssl.shutdown()
            break
         except Exception:
            # not a bare except, so Ctrl-C can still interrupt the loop
            time.sleep(1)

class ValidatingSslServer:
   '''One-shot SSL server that requests a client certificate.

   Constructing an instance accepts a single connection on serverPort,
   exchanges one message with the client and finally asserts through
   `test` (a TestCase) that the certificate the client presented verifies
   against a store trusting the CA certificate.'''

   def __init__(self, test):
      cert, key = serverCertKey()
      ssl = POW.Ssl( POW.SSLV23_SERVER_METHOD )
      ssl.useCertificate(cert)
      ssl.useKey(key)
      ssl.setVerifyMode( POW.SSL_VERIFY_PEER )

      store = POW.X509Store()
      store.addTrust( certs.getCert('ca') )

      sock = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
      try:
         # The SSL tests bind the same port one after another; allow the
         # rebind while connections of the previous test are in TIME_WAIT.
         sock.setsockopt( socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 )
         sock.bind( ('', serverPort) )
         sock.listen(1)
         conn, addr = sock.accept()
         sock.shutdown(0)
      finally:
         # only one client is served, the listening socket is done either way
         sock.close()

      try:
         ssl.setFd( conn.fileno() )
         ssl.accept()

         clientCert = ssl.peerCertificate()

         # the message content is not checked here, only the certificate
         msg = ssl.read()
         ssl.write(serverMsg)

         self._shutdownSsl( ssl )
         conn.shutdown(0)
      finally:
         conn.close()
      test.failUnless( store.verify( clientCert ), 'client certificate failed verification' )

   def _shutdownSsl(self, ssl):
      'Retry ssl.shutdown() once a second until it succeeds'
      while 1:
         try:
            ssl.shutdown()
            break
         except Exception:
            # not a bare except, so Ctrl-C can still interrupt the loop
            time.sleep(1)

class SslClient:
   '''SSL client meant to run in a forked child process.

   Constructing an instance connects to serverPort, sends clientMsg, reads
   the reply and then terminates the process with os._exit: status 0 on
   success, 1 if anything failed.  The constructor never returns.'''

   def __init__(self):
      import traceback
      status = 1
      try:
         try:
            self._exchange()
            status = 0
         except Exception:
            # Report the failure here; letting it propagate would hand the
            # forked child back to the parent's test runner.
            traceback.print_exc()
            sys.stderr.flush()
      finally:
         # os._exit skips the interpreter's normal shutdown in the child
         os._exit(status)

   def _exchange(self):
      'Connect to the server, send clientMsg, read the reply and shut down'
      cert, key = clientCertKey()
      ssl = POW.Ssl( POW.SSLV23_CLIENT_METHOD )
      ssl.useCertificate(cert)
      ssl.useKey(key)
      sock = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
      try:
         sock.connect( ('', serverPort) )
         ssl.setFd( sock.fileno() )
         ssl.connect()

         ssl.write(clientMsg)
         ssl.read()

         # retry once a second until the SSL shutdown succeeds
         while 1:
            try:
               ssl.shutdown()
               break
            except Exception:
               time.sleep(1)

         sock.shutdown(0)
      finally:
         sock.close()

class SslTestCase(unittest.TestCase):
   'SSL tests'

   def _serveForkedClient(self, serverClass):
      '''Fork; run SslClient in the child and serverClass(self) in the parent.

      The parent starts listening after 1 second and the child connects
      after 3, so the server is in place first.'''
      import traceback
      pid = os.fork()
      if not pid:
         # Child process.  It must never return into the test runner, so
         # every path out of here ends in os._exit.
         status = 1
         try:
            try:
               time.sleep(3)
               SslClient()
               status = 0
            except Exception:
               traceback.print_exc()
               sys.stderr.flush()
         finally:
            os._exit(status)

      time.sleep(1)
      try:
         serverClass(self)
      finally:
         # reap the child so it does not linger as a zombie
         os.waitpid( pid, 0 )

   def testSimple(self):
      '''Test client/server communication over SSL'''
      self._serveForkedClient( SimpleSslServer )

   def testClientValidation(self):
      '''Request and validate client certificate'''
      self._serveForkedClient( ValidatingSslServer )


#--------------- SSL test case ---------------#
#--------------- Test suite generators ---------------#

def hashSuite():
   'Return a suite holding the HashTestCase tests in their fixed order'
   names = ( 'testMd2', 'testMd5', 'testSha', 'testSha1', 'testRipemd160' )
   tests = unittest.TestSuite()
   for name in names:
      tests.addTest( HashTestCase(name) )
   return tests

def hmacSuite():
   'Return a suite holding the HmacTestCase tests in their fixed order'
   names = ( 'testHmacMd2', 'testHmacMd5', 'testHmacSha',
             'testHmacSha1', 'testHmacRipemd160' )
   tests = unittest.TestSuite()
   for name in names:
      tests.addTest( HmacTestCase(name) )
   return tests

def symmetricSuite():
   '''Return a suite with one SymmetricTestCase test per available cipher.

   A cipher counts as available when the POW module has an attribute of
   that name; the matching test method is named 'test' + cipher name.'''
   cipherNames = (
      'DES_ECB', 'DES_EDE', 'DES_EDE3',
      'DES_CFB', 'DES_EDE_CFB', 'DES_EDE3_CFB',
      'DES_OFB', 'DES_EDE_OFB', 'DES_EDE3_OFB',
      'DES_CBC', 'DES_EDE_CBC', 'DES_EDE3_CBC', 'DESX_CBC',
      'RC4', 'RC4_40',
      'IDEA_ECB', 'IDEA_CFB', 'IDEA_OFB', 'IDEA_CBC',
      'RC2_ECB', 'RC2_CBC', 'RC2_40_CBC', 'RC2_CFB', 'RC2_OFB',
      'BF_ECB', 'BF_CBC', 'BF_CFB', 'BF_OFB',
      'CAST5_ECB', 'CAST5_CBC', 'CAST5_CFB', 'CAST5_OFB',
      # the ECB entry used to be looked up as 'RC6_32_12_16_ECB', a name
      # that never matched, so that test was silently skipped
      'RC5_32_12_16_CBC', 'RC5_32_12_16_CFB',
      'RC5_32_12_16_ECB', 'RC5_32_12_16_OFB',
   )
   suite = unittest.TestSuite()
   for name in cipherNames:
      if hasattr( POW, name ):
         suite.addTest( SymmetricTestCase( 'test' + name ) )
   return suite

def asymmetricSuite():
   'Return a suite holding the AsymmetricTestCase tests in their fixed order'
   names = ( 'testPublicEncrypt', 'testPrivateEncrypt', 'testSign' )
   tests = unittest.TestSuite()
   for name in names:
      tests.addTest( AsymmetricTestCase(name) )
   return tests

def x509Suite():
   'Return a suite holding the X509TestCase tests in their fixed order'
   names = ( 'testIssuer', 'testSubject', 'testVersion',
             'testSerial', 'testNotBefore', 'testNotAfter' )
   tests = unittest.TestSuite()
   for name in names:
      tests.addTest( X509TestCase(name) )
   return tests

def x509StoreSuite():
   'Return a suite holding the X509StoreTestCase tests in their fixed order'
   names = ( 'testVerify', 'testVerifyChain' )
   tests = unittest.TestSuite()
   for name in names:
      tests.addTest( X509StoreTestCase(name) )
   return tests

def x509RevokedSuite():
   'Return a suite holding the single X509RevokedTestCase test'
   tests = unittest.TestSuite()
   onlyTest = X509RevokedTestCase('testRevoked')
   tests.addTest( onlyTest )
   return tests

def x509CrlSuite():
   'Return a suite holding the X509CrlTestCase tests in their fixed order'
   names = ( 'testVersion', 'testIssuer', 'testThisUpdate',
             'testNextUpdate', 'testRevoked' )
   tests = unittest.TestSuite()
   for name in names:
      tests.addTest( X509CrlTestCase(name) )
   return tests

def sslSuite():
   'Return a suite holding the SslTestCase tests in their fixed order'
   names = ( 'testSimple', 'testClientValidation' )
   tests = unittest.TestSuite()
   for name in names:
      tests.addTest( SslTestCase(name) )
   return tests

#--------------- Test suite generators ---------------#
#--------------- main ---------------#

if __name__ == '__main__':
   # parenthesised single argument: prints the same text on Python 2 and 3
   print( '\n\tGenerating RSA keys and certificates to use for testing...\n' )

   # certs and ciphers are deliberately module globals: the test cases and
   # the SSL helpers above look them up by name.
   certs = X509Utilities()
   ciphers = certs.asymUtils
   try:
      certs.prepCerts()

      runner = unittest.TextTestRunner( sys.stderr, 1, 2)
      # each suite is built just before it runs, in the original order
      for makeSuite in ( hashSuite, hmacSuite, symmetricSuite, asymmetricSuite,
                         x509Suite, x509StoreSuite, x509RevokedSuite,
                         x509CrlSuite, sslSuite ):
         runner.run( makeSuite() )
   finally:
      # remove the generated fixtures even if a run was interrupted
      certs.unPrepCerts()
      ciphers.unPrepCiphers()

#--------------- main ---------------#
