/* $Id: ArkSocket.cpp,v 1.24 2003/02/04 20:18:15 zongo Exp $
** 
** Ark - Libraries, Tools & Programs for MMORPG developpements.
** Copyright (C) 1999-2000 The Contributors of the Ark Project
** Please see the file "AUTHORS" for a list of contributors
**
** This program is free software; you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation; either version 2 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/


#include <Ark/ArkSocket.h>
#include <Ark/ArkSystem.h>


#ifdef HAVE_UNISTD_H
#include <unistd.h>
#endif


#include <errno.h>
#include <stdlib.h>
#include <stdio.h>

// STL list for packet list
#include <list>
#include <iostream>


#ifdef HAVE_WINSOCK

/* Ugh, should use WinSock... Why the hell can't they use
** the same header files, since this is (almost) the same
** API as the socket API on POSIX systems ??
*/
#include <winsock.h>

#define close(s)      closesocket(s)
#define read(s,d,sz)  recv(s,d,sz, 0)
#define write(s,d,sz) send(s,d,sz, 0)

#define EWOULDBLOCK WSAEWOULDBLOCK

/* Dunno why, but the size arguments to accept and getsockopt are
** int under Winsock, but uint on POSIX systems (at least under
** linux).
*/
typedef int sock_size_t;
typedef SOCKET socktype;

#else /* POSIX system */

#include <sys/socket.h>
#include <sys/time.h>

#include <netinet/tcp.h>
#include <netinet/in.h>

#include <arpa/inet.h>

#include <fcntl.h>
#include <stdio.h>
#include <netdb.h>

typedef unsigned int sock_size_t;
typedef int socktype;
#endif

namespace Ark
{

   // Switch for the socket trace output: when non-zero, the read/write
   // state machines below print packet sizes and payload bytes to
   // stderr/stdout.
   static const int SOCKDEBUG = 1;

   class SocketP
   {
	 /*
	  * Return true if errno is a real error (ie: not a socket in
	  * non-blocking mode unable to write data).
	  */
	 bool RealError ()
	 {
	    // HIGH PRIORITY FIXME: should do some error checking...
	    return errno != 0
	       && errno != EAGAIN
	       && errno != EWOULDBLOCK;
	 }

         /*
	  * If the int given as the 2nd argument is errval (-1 is the value
	  * most socket calls return on error), check if it is a real error,
	  * and if yes mark the given socket as dead.
	  */
	 bool
	 CheckError (int i, int errval = -1)
	 {
	    if (i == errval)
	    {
	       // Check if it's EAGAIN or something. if not, kill socket.
	       if (RealError ())
	       {
		  // If so, mark this socket as dead
		  Destroy();
	       }
	      
	       // In any case, we don't want to continue looping so return 0.
	       return false;
	    }
	    
	    return true;
	 }
	 
	 // Return true if data are directly readable from the socket.
	 bool DataReady (socktype sock)
	 {
	    fd_set rfds;
	    struct timeval tv;
	    
	    FD_ZERO (&rfds);
	    FD_SET (sock, &rfds);
	    
	    tv.tv_sec = tv.tv_usec = 0;
	    select (sock+1, &rfds, NULL, NULL, &tv);
	    
	    return FD_ISSET (sock, &rfds) != 0;
	 }

      public:
	 // List of packets type
	 typedef std::list< Packet* > PacketList;
	     
	 // Input and output packet queues.
	 PacketList m_Out;
	 std::list<std::istringstream*> m_In;

	 /// If set to true, it means the a connection cannot be opened, or
	 /// that it was closed.
	 int m_Dead;
	 
      private:
	 // State machine stuff 
	 int m_rState, m_rPos;
	 int m_wState, m_wPos;
	 char m_rBuf[4];

	 /// The currently written packet buffer.
	 std::string m_wPacket;
	 
	 /// This is a pointer to the packet which is being read but hasn't
	 /// been fully loaded from network yet
	 std::string m_rPacket;

	 /// The OS socket identifier
	 socktype m_Socket;
	 struct sockaddr_in m_sin;
	 
      public:

	 /*
	  * Bind an existing socket to a new C++ structure. This is used in
	  * the Accept() function where we have to construct Ark Sockets
	  * from already existing socket identifiers;
	  */
	 SocketP (int socket, sockaddr_in sin)
	 {
	    m_sin = sin;
	    m_Socket = socket;

	    m_rState = 0;
	    m_rPos = 0;
	    m_wState = 0;
	    m_wPos = 0;
	    m_Out.clear();
	    m_In.clear();
	    m_Dead = false;
	 }

	 SocketP (const String &resource,
		  const String &proto)
	 {
	    struct servent *iservice;
	    struct protoent *itransport;
	    
	    int s, type, tproto;

	    // If there's an error creating the socket, the end of the function
	    // wont be reached, and the socket will be considered as dead.
	    m_Dead = true;
	    
	    memset (&m_sin, 0, sizeof (m_sin));
	    m_sin.sin_family = AF_INET;
	    
	    // Search for the port.
	    m_sin.sin_port = htons (atoi (resource.c_str()));

	    if (!m_sin.sin_port)
	    {
	       iservice = getservbyname(resource.c_str(), proto.c_str());
	       
	       if (!iservice)
	       {
		  Sys()->Error ("Unknow service '%s'.", resource.c_str());
		  return;
	       }
	       
	       m_sin.sin_port = iservice->s_port;
	    }
	    
	    // Find the protocol number. 
	    itransport = getprotobyname (proto.c_str());
	    
	    if (!itransport)
	    {
	       Sys()->Error ("Unknow protocol '%s'.", proto.c_str());
	       return;
	    }
	    
	    tproto = itransport->p_proto;	    
	    if (proto == "udp")
	       type = SOCK_DGRAM;
	    else
	       type = SOCK_STREAM;

	    /* Cration de la socket */
	    s = socket(AF_INET, type, tproto);
	    
	    if (s < 0)
	    { 
	       Sys()->Error ("Unable to create socket.\n");
	       return;
	    }

	    m_Socket = s;
	  
	    m_rState = 0;
	    m_rPos = 0;
	    m_wState = 0;
	    m_wPos = 0;
	    
	    m_Out.clear();
	    m_In.clear();
	    m_Dead = false;
	 }
	 
	 ~SocketP ()
	 {
	    Destroy();
	 }

	 void Destroy()
	 {
	    if (m_Socket)
	    {
	       close (m_Socket);
	       m_Socket = 0;
	    }

	    m_Dead = 1;
	    
	    std::list<Packet*>::iterator i;
	    std::list<std::istringstream*>::iterator j;
	    
	    // Unref all input packets
	    for (j = m_In.begin(); j != m_In.end();  ++j)
	       delete (*j);
	    
	    m_In.clear();
	    
	    // Unref all output packets
	    for (i = m_Out.begin(); i != m_Out.end();  ++i)
	       (*i)->Unref ();
	    
	    m_Out.clear();	   
	 }

         /*
	  * Try to connect to remote \c host (which can be either an IP
	  * address either an hostname), with the protocol/ports that
	  * were givent to the constructor.
	  */
	 bool Connect (const String &host)
	 {
	    // We should find the ip address of the host. 
	    m_sin.sin_addr.s_addr = inet_addr(host.c_str());
	    
	    if (m_sin.sin_addr.s_addr == INADDR_NONE)
	    {
	       // If host is a server name, find it! 
	       struct hostent *ihost = gethostbyname(host.c_str());
	       
	       if (!ihost)
	       {
		  Sys()->Error ("No host '%s'.\n", host.c_str());
		  return false;
	       }

	       memcpy(&m_sin.sin_addr, ihost->h_addr, ihost->h_length);
	    }

	    // Connection! 
	    if (connect (m_Socket, (struct sockaddr *)&m_sin,
			 sizeof (m_sin)) < 0)
	    {
	       Destroy ();
	       return false;
	    }   

	    return true;
	 }

	 // ------------------------------------------
	 bool Bind (const String &resource,
		    const String &proto,
		    int lqueue)
	 {
	    /* We should find the ip address of the host. */
	    m_sin.sin_addr.s_addr = INADDR_ANY;
	    
	    /* Allocation du numro de port */
	    if (bind (m_Socket, (struct sockaddr *)&m_sin, sizeof (m_sin)) < 0)
	    {
	       Sys()->Error ("Unable to bind socket '%d'.", m_Socket);
	       Destroy();
	       return false;
	    }
	    
	    if (proto != "udp" && listen(m_Socket, lqueue) < 0)
	    {

	       Sys()->Error ("Cannot listen to port '%s'.",
				   resource.c_str());
	       Destroy();
	       return false;
	    }

	    return true;
	 }

	 // ------------------------------------------
	 bool SetBlocking (bool blocking)
	 {
	    int i;
	    sock_size_t j;
	    
#ifdef HAVE_FCNTL
	    // Set the filedes to nonblock. Print an error message if we
	    // can't and abort.
	    int f = fcntl(m_Socket, F_GETFL);
	    
	    if (f == -1)
	    {
	       Sys()->Fatal ("Can't make %i nonblocking.", m_Socket);
	       return false;
	    }
	    
	    if (fcntl (m_Socket, F_SETFL, f|O_NONBLOCK) == -1)
	    {
	       Sys()->Fatal ("No cant make %i nonblocking dude.",
				   m_Socket);
	       return false;
	    }
#elif defined HAVE_WINSOCK
	    /* Dont have fcntl, so use ioctls */
	    
	    unsigned long buffer = 1;
	    ioctlsocket (m_Socket, FIONBIO, &buffer);
#endif
	    
	    // These flags are not crucial for the code to work but if I
	    // can't set these flags then latency will skyrocket :
	    // TCP_NODELAY disables the Nagle algorithm. Nagle buffers packets
	    // for up to .2 second (!) and that is BAAAAD.
	    
	    i = -1;
	    j = sizeof(i);
	    
	    char *data = (char *) &i;
	    setsockopt (m_Socket, IPPROTO_TCP, TCP_NODELAY, data, sizeof(i));
	    getsockopt (m_Socket, IPPROTO_TCP, TCP_NODELAY, data, &j);
	    
	    if (!i)
	    {
	       Sys()->Warning
		  ("Could not set TCP_NODELAY for %i.", m_Socket);
	       return false;
	    }

	    return true;
	 }

	 bool WaitForData (int ms)
	 {
	    fd_set rfds;
	    int n = 5;
	    
	    do
	    {
	       /* Wait for data. */
	       FD_ZERO (&rfds);
	       FD_SET (m_Socket, &rfds);
	       
	       if (ms)
	       {
		  struct timeval tv;
		  tv.tv_sec = ms/1000;
		  tv.tv_usec = (ms-tv.tv_sec*1000)*1000;

		  select (m_Socket + 1, &rfds, NULL, NULL, &tv);
	       }
	       else
	       {
		  select (1, &rfds, NULL, NULL, NULL);
	       }
	       
	       if (FD_ISSET (m_Socket, &rfds))
	       {
		  return true;
	       }
	    } while (--n);
	    
	    return false;
	 }

	 Socket *Accept (bool wait)
	 {
	    struct sockaddr_in sin;
	    sock_size_t lsin = sizeof (sin);
	    int ssock;
	    fd_set rfds;
	    struct timeval tv;
	    
	    if (m_Dead)
	       return NULL;
	    
	    FD_ZERO (&rfds);
	    FD_SET (m_Socket, &rfds);
	    
	    tv.tv_sec = 0;
	    tv.tv_usec = 0;
	    
	    select (m_Socket + 1, &rfds, NULL, NULL, &tv);
	    
	    if (wait || FD_ISSET (m_Socket, &rfds))
	       ssock = accept(m_Socket, (struct sockaddr *)&sin, &lsin);
	    else
	       return NULL;
	    
	    // A connection has been opened. Allocate a new C++ wrapper for it.
	    if (ssock)
	    {
	       Socket *csocket = new Socket ();
	       
	       assert (csocket != NULL);
	       
	       csocket->m_Priv = new SocketP (ssock, sin);
	       csocket->SetBlocking (false);	 
	       return csocket;
	    }
	    
	    return NULL;
	 }
	 
         /* State machine algorithm
	  * =======================
	  *
	  * Depending on rstate the following action is performed :
	  *     - rstate = 0: go to next state.
	  *     - rstate = 1: read next packet size.
	  *     - rstate = 2: read packet data. if everything has been read,
	  *                   read next packet (if there is one) by going to
	  *                   state 0.
	  *
	  * Remark about this code: it has been taken from a library called
	  * GTC (the GameToolChest), distributed under the GNU GPL, and
	  * available at gtc.seul.org
	  * The author is Sbastien Loisel.
	  */
	 bool ProcessRead (int first)
	 {
	    if (m_Dead)
	       return false;

	    int i, left, k;

	    switch(m_rState)
	    {
	       // Initial state. Falls into the next state.
	       case 0:
		  m_rState = 1;
		  m_rPos = 0;
		  m_rPacket.resize(0);

		  // Read packet length
	       case 1:
		  // This is how many bytes we yet have to read
		  left = 4-m_rPos;
		  
		  if (m_rPos == 0 && !DataReady (m_Socket))
		     return 0;

		  i = read (m_Socket, m_rBuf, left);
		  if (SOCKDEBUG)
		     std::cerr << m_Socket << " read " << i << " (" << left
			       << ")\n";
		  
		  // Read error
		  if (!CheckError (i, -1))
		  {
		     if (SOCKDEBUG)
			perror("Read error");

		     return false;
		  }
		  
		  // Readcount is 0! buffer empty or socket dead?
		  if (i == 0)
		  {
		     // Socket dead.
		     if (first)
			Destroy();
		     
		     // In any case stop looping
		     return false;
		  }
		  
		  // Wow we've read all the bytes.
		  if (i == left)
		  {
		     // Convert into an int.
		     i = *(int*)m_rBuf;
		     
		     // Zero sized packets are ignored
		     if(i == 0)
		     {
			m_rState = 0;
			return true;
		     }
		     
		     // Allocate the packet and prepare for reading
		     m_rPacket.resize(i);
		     m_rState = 2;
		     m_rPos = 0;

		     if (SOCKDEBUG)
			std::cerr << "read packet size " << i << "\n";
		     
		     // Loop again
		     return true;
		  }
		  
		  m_rPos += i;
		  
		  // We haven't read all the bytes. Stop looping.
		  return false;
		  break;
		  
		  // Read packet itself
	       case 2:
		  left = m_rPacket.size() - m_rPos;
		  assert (left >= 0);

		  errno = 0;
		  i = read (m_Socket, &m_rPacket[m_rPos], left);
		  
		  if (!CheckError (i, -1))
		     return false;
		  
		  // Read no bytes
		  if(i == 0)
		  {
		     if (first)
			Destroy ();
		     
		     // Stop looping
		     return false;
		  }

		  if (SOCKDEBUG)
		  {
		     std::cerr << "received "<< i << "\n";
		     //int w = std::cerr.width(2); 
		     //int f = std::cerr.setf(ios::hex);
		     
		     for (k = 0;k < i; k++)
			std::cerr  << int(m_rPacket[m_rPos + k] & 0xff) << " ";

		     //std::cerr.width(w);
		     //std::cerr.setf(f);
		  
		     fprintf(stderr, "\n");
		  }

		  // Read all bytes
		  if (i == left)
		  {
		     if (SOCKDEBUG)
			puts ("received packet");
		     
		     // Push packet on queue, go to state 0
		     m_In.push_back (new std::istringstream (m_rPacket));
		     m_rPacket = "";
		     m_rState = 0;
		     
		     return true;
		  }
		  
		  m_rPos += i;

		  /* The read stopped short. don't loop again. */
		  return false;
		  break;
		  
	       default:
		  // if we get into this state there's something really wrong 
		  Sys()->Fatal ("illegal rstate %i (fd is %i)\n",
                          m_rState, m_Socket);
	    }
	    
	    return false;
	 }

	 /* State machine algorithm
	  * =======================
	  *
	  * Depending on wstate the following action is performed :
	  *     - wstate = 0: go to next state.
	  *     - wstate = 1: write current packet size.
	  *     - wstate = 2: write packet data. if no data, go to next packet
	  *                   (if there is one) and set wstate to 0.
	  *
	  */
	 bool ProcessWrite ()
	 {
	    int i, j;
	    int size = 0;
	    char buf[4];
	    
	    if (m_Dead) return false;

	    // check if the outbound queue is empty. if it is, there's
	    // nothing to do.
	    if (m_Out.empty())
	    {
	       m_wState = m_wPos = 0;
	       m_wPacket = "";
	       return false;
	    }
	    else if (m_wPacket.empty())
	    {
	       m_wPacket = m_Out.front()->m_Stream.str();
	       size = m_wPacket.size();

	       if (size == 0)
	       {
		  m_Out.front()->Unref();
		  m_Out.pop_front();
		  m_wState = m_wPos = 0;
		  m_wPacket = "";
		  return true;
	       }
	    }

	    size = m_wPacket.size();
	    
	    // now the state machine
	    switch (m_wState)
	    {
	       // initialize, fall into the next state.
	       case 0:
		  m_wPos = 0;
		  m_wState = 1;
		  
		  // send packet size
	       case 1:
		  // FIXME: convert packet size into network friendly format.
		  *(int*)buf  = size;
	
		  if (SOCKDEBUG)
		  {
		      std::cerr << "writing packet (size: " << size << ")\n";
		  }
	  
		  errno = 0;

		  // send packet size
		  j = 4 - m_wPos;
		  i = write(m_Socket, &buf[m_wPos], j);
		  
		  if (!CheckError (i, -1))
		     return false;
		  
		  // Have we written all bytes?
		  if(i == j)
		  {
		     // Yes, go to next state (continue looping)
		     m_wState = 2;
		     m_wPos = 0;
		     return true;
		  }
		  
		  // We haven't written all the bytes. Buffer must be full.
		  m_wPos += j;
		  return false;
		  
		  // Send packet
	       case 2:
		  errno = 0;

		  // Bytes to send
		  j = (size - m_wPos);
		  i = write (m_Socket, &m_wPacket[m_wPos], j);
		  
		  if (!CheckError (i, -1))
		     return false;
		  
		  if (SOCKDEBUG && i != 0)
		  {
		     int k;
		     std::cerr << "written " << i << ": "; 
		     
		     for(k = 0; k < i; k++)
			std::cerr << (m_wPacket[m_wPos+k] & 0xff) << " ";
		     std::cerr << "\n";
		  }
		  
		  /* Are we done writing? */
		  if (i == j)
		  {
		     m_Out.front()->Unref();
		     m_Out.pop_front();
		     m_wState = 0;
		     m_wPacket = "";
		     
		     if (SOCKDEBUG)
			puts("written packet");
		     
		     /* and loop again */
		     return true;
		  }
		  
		  // We are not done writing, buffer must be full,
		  // stop looping 
		  m_wPos += i;
		  return false;

	       default:
		  // If we get into this state there's something really
		  // wrong (who in the audience said "a bug" ??)
		  Sys()->Fatal ("illegal wstate %i (fd is %i)\n",
				      m_wState, m_Socket);
	    }
	    
	    return false;
	 }
   };


   // Create a new client socket for service/port 'resource' over protocol
   // 'proto' and connect it to 'host'. If the socket cannot be created or
   // the connection fails, an error is reported and the socket is left
   // dead (see IsDead()).
   Socket::Socket (const String &host,
		   const String &resource,
		   const String &proto)
   {
      m_Priv = new SocketP (resource, proto);

      // Only try to connect when the OS socket was actually created:
      // connecting a dead socket would operate on an invalid descriptor.
      if (!IsDead())
	 m_Priv->Connect (host);

      if (IsDead())
      {
	 Sys()->Error
	    ("Unable to connect to '%s:%s' with protocol '%s'.",
	     host.c_str(), resource.c_str(), proto.c_str());
      }
   }
   
   // Create a server (listening) socket bound to service/port 'resource'
   // over protocol 'proto'; 'lqueue' is the listen backlog. On failure the
   // socket is left dead (see IsDead()).
   Socket::Socket (const String &resource,
		   const String &proto, int lqueue)
   {
      m_Priv = new SocketP (resource, proto);

      // Binding only makes sense if the OS socket was actually created:
      // binding a dead socket would operate on an invalid descriptor.
      if (!IsDead())
	 m_Priv->Bind (resource, proto, lqueue);
   }
   
   // Destructor: hands the cleanup over to Close().
   Socket::~Socket ()
   {
      this->Close ();
   }
   
   // Release the private implementation. Safe to call more than once:
   // the pointer is reset so that a later Close() (for instance the one
   // issued by the destructor) does not delete it a second time.
   void
   Socket::Close ()
   {
      // Deleting a null pointer is a no-op, so no test is needed.
      delete m_Priv;
      m_Priv = NULL;
   }
   
   // A socket is dead when it has no private implementation or when that
   // implementation has been flagged as dead.
   bool Socket::IsDead ()
   {
      const bool usable = (m_Priv != NULL) && !m_Priv->m_Dead;
      return !usable;
   }
   
   // Forward the blocking-mode request to the private implementation.
   // Does nothing on a dead socket.
   void Socket::SetBlocking (bool blocking)
   {
      if (!IsDead())
	 m_Priv->SetBlocking (blocking);
   }
   
   
   // Hand a pending client connection over as a newly created socket.
   // Yields NULL when this socket is dead; otherwise the result comes
   // straight from the private implementation.
   Socket *Socket::Accept (bool wait)
   {
      Socket *client = NULL;

      if (!IsDead ())
	 client = m_Priv->Accept (wait);

      return client;
   }

   // Wait (through the private implementation, which relies on select())
   // until the socket can receive data. A dead socket never has data.
   bool Socket::WaitForData  (int ms)
   {
      return !IsDead() && m_Priv->WaitForData (ms);
   }
   
   // Refresh the input queue: run the read state machine until it asks to
   // stop. Only the very first step is flagged as such (argument 1), all
   // following steps get 0. Returns false only if the socket was already
   // dead on entry.
   bool Socket::ReceiveData ()
   {
      if (IsDead())
	 return false;

      for (int first = 1; m_Priv->ProcessRead (first); first = 0)
      {
	 // Everything happens inside ProcessRead.
      }

      return true;
   }



   /* Push the content of the output queue to the network.
    * Call this after queueing packets if you want them to be
    * really sent. Returns false only if the socket was already
    * dead on entry.
    */
   bool Socket::SendData ()
   {
      if (IsDead())
	 return false;

      // Keep stepping the write state machine for as long as it
      // asks for another round.
      bool again;
      do
      {
	 again = m_Priv->ProcessWrite ();
      } while (again);

      return true;
   }
   
   /* Pop the oldest received packet off the input queue and return it.
    * The result is NULL when the socket is dead or the queue is empty.
    * The caller becomes responsible for destroying the returned packet.
    */
   std::istringstream *Socket::GetPacket ()
   {
      if (IsDead() || m_Priv->m_In.empty())
	 return NULL;

      std::istringstream *oldest = m_Priv->m_In.front();
      m_Priv->m_In.pop_front ();

      return oldest;
   }
   
   /* Append a packet to the queue of packets to be sent. The socket takes
    * its own reference on the packet, so call Packet::Unref after
    * SendPacket() unless you want to keep the packet or send it to
    * another socket as well. *NEVER* destroy the packet directly: the
    * reference counter would be bypassed while the socket may still be
    * using it. Returns false (and queues nothing) on a dead socket.
    */
   bool Socket::SendPacket (Packet  *packet)
   {
      if (!IsDead())
      {
	 assert (packet);

	 packet->Ref ();
	 m_Priv->m_Out.push_back (packet);

	 return true;
      }

      return false;
   }

/* namespace Ark */
}
