/* Nessuslib -- the Nessus Library
 * Copyright (C) 1998 - 2002 Renaud Deraison
 * SSL Support Copyright (C) 2001 Michel Arboi
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * Network Functions
 */ 

#define EXPORTING
#include <includes.h>
#include <stdarg.h>
#include "libnessus.h"
#include "network.h"
#include "resolve.h"

#include <setjmp.h>

#ifdef HAVE_SSL
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#endif

/* Default timeout, in seconds, used when a connection does not carry its own */
#define TIMEOUT 20

/* Fallback for platforms whose headers do not provide INADDR_NONE */
#ifndef INADDR_NONE
#define INADDR_NONE 0xffffffff
#endif



/*----------------------------------------------------------------*
 * Low-level connection management                                *
 *----------------------------------------------------------------*/
 
/*
 * Nessus "FILE" structure.
 *
 * One entry of the connections[] table. It ties a "Nessus descriptor"
 * (table index + NESSUS_FD_OFF) to the underlying socket, the transport
 * in use and an optional read buffer.
 */
typedef struct {
 int fd;		/* socket number, or whatever */
 int transport;	/* "transport" layer code when stream is encapsulated.
		 * A value <= 0 marks a free slot (see get_connection_fd());
		 * released slots are explicitly set to -1 */
 int timeout;	  /* timeout, in seconds
		   * special values: -2 for default */
 int options;			/* Misc options - see libnessus.h */
  
 int port;			 /* set to 0 by nessus_register_connection(); debug only */
#ifdef HAVE_SSL
  SSL_CTX* 	ssl_ctx;	/* SSL context 	*/
  SSL_METHOD* 	ssl_mt;		/* SSL method   */
  SSL* 		ssl;		/* SSL handler  */
  int		last_ssl_err;	/* Last SSL error code */
#endif
 pid_t		pid;		/* Owner - for debugging only */
#if 0
  int		last_sock_err;	/* last socket level error */
#endif
  char*		buf;		/* NULL if unbuffered */
  int		bufsz, bufcnt, bufptr;	/* buffer capacity, number of unread
					 * bytes, offset of the first unread
					 * byte in buf */
} nessus_connection;

/* 
 * The role of this offset is:
 * 1. To detect bugs when the program tries to write to a bad fd
 * 2. See if a fd is a real socket or a "nessus descriptor". This is a
 * quick & dirty hack and should be changed!!!
 */
/* Number of slots in the connection table */
#define NESSUS_FD_MAX 1024
/* Value added to a slot index to build the descriptor handed to callers */
#define NESSUS_FD_OFF 1000000

/* Connection table, indexed by (Nessus descriptor - NESSUS_FD_OFF) */
static nessus_connection connections[NESSUS_FD_MAX];

/*
 * Quick & dirty patch to run Nessus from behind a picky firewall (e.g.
 * FW/1 and his 'Rule 0'): Nessus will never open more than 1 connection at
 * a time.
 * Define NESSUS_CNX_LOCK, recompile and install nessus-library, and restart nessusd
 *
 * WARNING: waiting on the lock file may be long, so increase the default
 * script timeout or some scripts may be killed.
 */
#undef NESSUS_CNX_LOCK
/*#define NESSUS_CNX_LOCK	"/tmp/NessusCnx"*/

#ifdef NESSUS_CNX_LOCK
static int	lock_cnt = 0;
static int	lock_fd = -1;
#endif

/*
 * NESSUS_STREAM(x) is TRUE if <x> is a Nessus-ified fd, i.e. if
 * NESSUS_FD_OFF <= x < NESSUS_FD_OFF + NESSUS_FD_MAX.
 *
 * The argument is fully parenthesized so that expressions (conditional,
 * comparison, ...) can be passed safely. It is still evaluated twice:
 * do not pass an expression with side effects.
 */
#define NESSUS_STREAM(x) ((((x) - NESSUS_FD_OFF) < NESSUS_FD_MAX) && (((x) - NESSUS_FD_OFF) >= 0))


/*
 * perror() look-alike: writes "[pid] <error> : <strerror(errno)>" on
 * stderr. Always returns 0.
 */
static int
nessus_perror(const char *error)
{
  const char *reason = strerror(errno);

  fprintf(stderr, "[%d] %s : %s\n", getpid(), error, reason);
  return 0;
}

/*
 * Returns a free file descriptor
 */
static int
get_connection_fd()
{
 int i;
 
 for ( i = 0; i < NESSUS_FD_MAX ; i++)
 {
  if(connections[i].transport <= 0) /* Not used */
  {
   bzero(&(connections[i]),  sizeof(connections[i]));
   connections[i].pid = getpid();
   return i + NESSUS_FD_OFF;
  }
 }
 fprintf(stderr, "[%d] %s:%d : Out of Nessus file descriptors\n", 
	 getpid(), __FILE__, __LINE__);
 errno = EMFILE;
 return -1;
}

/*
 * Releases the Nessus descriptor <fd>: frees the read buffer and the SSL
 * objects, shuts down and closes the underlying socket, then marks the
 * slot as free (transport = -1).
 *
 * Returns 0 on success. Returns -1 with errno = EINVAL if <fd> is not a
 * Nessus descriptor, or with errno = EBADF if the slot was already
 * released.
 */
static int
release_connection_fd(fd)
 int fd;
{
 nessus_connection *p;
 
 if(!NESSUS_STREAM(fd))
    {
     errno = EINVAL;
     return -1;
    }

 p = &(connections[fd - NESSUS_FD_OFF]);

 /*
  * A released slot has been zeroed and tagged with transport = -1, so
  * its fd member is 0. Going further would shutdown() and close()
  * descriptor 0 of the process on a double release.
  */
 if (p->transport < 0)
    {
     errno = EBADF;
     return -1;
    }

 efree(&p->buf);

#ifdef HAVE_SSL
 if (p->ssl != NULL)
  SSL_free(p->ssl);
 if (p->ssl_ctx != NULL)
  SSL_CTX_free(p->ssl_ctx);
#endif

/* 
 * So far, fd is always a socket. If this is changed in the future, this
 * code shall be fixed
 */
 if (p->fd >= 0)
  {
   if (shutdown(p->fd, 2) < 0)
    {
#if DEBUG_SSL > 1
    /*
     * It's not uncommon to see that one fail, since a lot of
     * services close the connection before we ask them to
     * (ie: http), so we don't show this error by default
     */
    nessus_perror("release_connection_fd: shutdown()");
#endif    
    }
   if (socket_close(p->fd)  < 0)
    nessus_perror("release_connection_fd: close()");
  }

 /* Wipe the slot and flag it as free */
 bzero(p, sizeof(*p));
 p->transport = -1; 
 return 0;
}

/* ******** Compatibility function ******** */

/*
 * Wraps the already-open socket <s> (and, optionally, its SSL handler
 * <ssl>) into a Nessus descriptor.
 *
 * The transport is NESSUS_ENCAPS_SSLv23 when <ssl> is not NULL and
 * NESSUS_ENCAPS_IP otherwise; the timeout is set to TIMEOUT.
 * Returns the Nessus descriptor, or -1 if no slot is available.
 */
ExtFunc int
nessus_register_connection(s, ssl)
     int	s;
#ifdef HAVE_SSL
     SSL	*ssl;
#else
     void	*ssl;
#endif
{
  nessus_connection	*cnx;
  int			nfd = get_connection_fd();

  if (nfd < 0)
    return -1;

  cnx = &connections[nfd - NESSUS_FD_OFF];
  cnx->fd = s;
  cnx->port = 0;			/* just used for debug */
  cnx->timeout = TIMEOUT;		/* default value */
#ifdef HAVE_SSL 
  cnx->ssl = ssl;			/* will be freed on close */
  cnx->ssl_mt = NULL;			/* shall be freed elsewhere */
  cnx->ssl_ctx = NULL;
#endif  
  if (ssl != NULL)
    cnx->transport = NESSUS_ENCAPS_SSLv23;
  else
    cnx->transport = NESSUS_ENCAPS_IP;

  return nfd;
}

/*
 * Forgets the Nessus descriptor <fd> without touching the underlying
 * socket: the slot is only wiped and flagged as free.
 *
 * Returns 0, or -1 with errno = EINVAL if <fd> is not a Nessus descriptor.
 */
ExtFunc int
nessus_deregister_connection(fd)
 int fd;
{
 nessus_connection * slot;

 if (NESSUS_STREAM(fd))
 {
  slot = &connections[fd - NESSUS_FD_OFF];
  bzero(slot, sizeof(*slot));
  slot->transport = -1; 
  return 0;
 }

 errno = EINVAL;
 return -1;
}

/*----------------------------------------------------------------*
 * High-level connection management                               *
 *----------------------------------------------------------------*/

/*
 * Reset to 0 at the start of open_socket() and set to 1 when the
 * connection attempt fails with something else than a select() timeout
 * (connect() error, or non-zero SO_ERROR).
 */
static int __port_closed;

/*
 * Adds O_NONBLOCK to the file status flags of <soc>.
 * Returns 0 on success, -1 (after logging) if fcntl() fails.
 */
static int unblock_socket(int soc)
{
  int	current = fcntl(soc, F_GETFL, 0);

  if (current >= 0)
    {
      if (fcntl(soc, F_SETFL, O_NONBLOCK | current) >= 0)
	return 0;

      nessus_perror("fcntl(F_SETFL,O_NONBLOCK)");
      return -1;
    }

  nessus_perror("fcntl(F_GETFL)");
  return -1;
}

/*
 * Removes O_NONBLOCK from the file status flags of <soc>.
 * Returns 0 on success, -1 (after logging) if fcntl() fails.
 */
static int block_socket(int soc)
{
  int	current = fcntl(soc, F_GETFL, 0);

  if (current >= 0)
    {
      if (fcntl(soc, F_SETFL, (~O_NONBLOCK) & current) >= 0)
	return 0;

      nessus_perror("fcntl(F_SETFL,~O_NONBLOCK)");
      return -1;
    }

  nessus_perror("fcntl(F_GETFL)");
  return -1;
}

/*
 * Initialize the SSL library (error strings and algorithms) and try
 * to set the pseudo random generator to something less silly than the
 * default value: 1 according to SVID 3, BSD 4.3, ISO 9899 :-(
 */

#ifdef HAVE_SSL
/* Adapted from stunnel source code */
/*
 * Prints "[pid] <txt>: <OpenSSL error string for err>" on stderr.
 *
 * The error string is built with ERR_error_string_n() so that OpenSSL
 * can never write past the end of the local buffer.
 */
ExtFunc
void sslerror2(txt, err)
     char	*txt;
     int	err;
{
  char string[256];

  ERR_error_string_n(err, string, sizeof(string));
  fprintf(stderr, "[%d] %s: %s\n", getpid(), txt, string);
}

/*
 * Prints <txt> followed by the description of the oldest error of the
 * OpenSSL error queue (which is removed from the queue).
 */
void
sslerror(txt)
     char	*txt;
{
  /*
   * sslerror2() takes an int: convert explicitly instead of relying on
   * a prototype being in scope at this point.
   */
  int code = (int) ERR_get_error();

  sslerror2(txt, code);
}
#endif

/*
 * Initializes the SSL library (algorithms and error strings) and tries
 * to seed the OpenSSL pseudo random generator.
 *
 * path: entropy pool file name, or NULL to use the default file name
 *       given by RAND_file_name().
 *
 * Returns 0 when the library is initialized, -1 on failure or when the
 * library was built without SSL support.
 *
 * With an explicit <path>, a failure to load the file is tolerated (it
 * may not exist yet); the file is then (re)written. With the default
 * file, a load failure is reported.
 */
ExtFunc int
nessus_SSL_init(path)
     char	*path;		/* entropy pool file name */
{
#ifdef HAVE_SSL
  SSL_library_init();
  SSL_load_error_strings();

#ifdef HAVE_RAND_STATUS
  if (RAND_status() == 1)
    {
    /* The random generator is happy with its current entropy pool */
    return 0;
    }
#endif

  /*
   * Init the random generator
   *
   * OpenSSL provides nice functions for this job.
   * OpenSSL also ensures that each thread uses a different seed.
   * So this function should be called *before* forking.
   * Cf. http://www.openssl.org/docs/crypto/RAND_add.html#
   *
   * On systems that have /dev/urandom, SSL uses it transparently to seed 
   * its PRNG
   */

#if 0
  RAND_screen();	/* Only available under MSWin */
#endif

#ifdef EGD_PATH
  /*
   * We have the entropy gathering daemon.
   * However, OpenSSL automatically query it if it is not in some odd place
   */
  if(RAND_egd(EGD_PATH) > 0)
	  return 0;
#endif

  if (path != NULL)
    {
      /* Best effort: the pool file may not exist yet */
      (void) RAND_load_file(path, -1);
      RAND_write_file(path);
      return 0;
    }
  else
    {
      /*
       * Try with the default path
       */
      char default_path[1024];

      if (RAND_file_name(default_path, sizeof(default_path) - 1) == NULL)
	return -1;
      default_path[sizeof(default_path) - 1] = '\0';
      if (RAND_load_file(default_path, -1) < 0)
	return -1;
      RAND_write_file(default_path);
      return 0;
    }
#else
  /* No SSL support compiled in */
  return -1;
#endif
}



/*
 * Returns the socket behind the Nessus descriptor <fd>.
 *
 * If <fd> is not a Nessus descriptor, a message is logged and <fd> is
 * returned unchanged. If the slot is not in use, a message is logged
 * and -1 is returned.
 */
ExtFunc int
nessus_get_socket_from_connection(fd)
     int	fd;
{
  nessus_connection	*cnx;

  if (NESSUS_STREAM(fd))
    {
      cnx = &connections[fd - NESSUS_FD_OFF];
      if (cnx->transport > 0)
	return cnx->fd;

      fprintf(stderr, "nessus_get_socket_from_connection: fd <%d> is closed\n", fd);
      return -1;
    }

  fprintf(stderr,
	  "[%d] nessus_get_socket_from_connection: bad fd <%d>\n", getpid(), fd);
  fflush(stderr);
  return fd;
}


#ifdef HAVE_SSL

/*
 * OpenSSL password callback: copies the pass phrase found in <userdata>
 * (a NUL terminated string, or NULL) into <buf>.
 *
 * buf:      destination buffer
 * size:     size of <buf>, in bytes
 * rwflag:   unused
 * userdata: the pass phrase, or NULL if there is none
 *
 * Returns the length of the pass phrase stored in <buf>; it is truncated
 * to size - 1 characters if needed and always NUL terminated. Returns 0
 * without touching <buf> when <buf> is NULL or <size> is not positive.
 */
int
nessus_SSL_password_cb(char *buf, int size, int rwflag, void *userdata)
{
  size_t len;

  (void) rwflag;

  /* No room at all: writing even the terminator would overflow */
  if (buf == NULL || size <= 0)
    return 0;

  if (userdata == NULL)
    {
      buf[0] = '\0';
      return 0;
    }

  len = strlen((const char *) userdata);
  if (len > (size_t) size - 1)
    len = (size_t) size - 1;

  memcpy(buf, userdata, len);
  buf[len] = '\0';
  return (int) len;
}


/*
 * Makes <ssl_ctx> use nessus_SSL_password_cb() as its default password
 * callback, with <pass> handed to it as user data.
 */
ExtFunc void
nessus_install_passwd_cb(ssl_ctx, pass)
     SSL_CTX	*ssl_ctx;
     char	*pass;
{
  /* The two settings are independent of each other */
  SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, pass);
  SSL_CTX_set_default_passwd_cb(ssl_ctx, nessus_SSL_password_cb);
}



/*
 * Returns the SSL handler attached to the Nessus descriptor <fd>.
 *
 * Returns NULL if the slot is not in use, or (with errno = EINVAL) if
 * <fd> is not a Nessus descriptor.
 */
ExtFunc SSL*
stream_get_ssl(fd)
     int	fd;
{
 nessus_connection * cnx;

 if (NESSUS_STREAM(fd))
    {
     cnx = &connections[fd - NESSUS_FD_OFF];
     return (cnx->transport > 0) ? cnx->ssl : NULL;
    }

 errno = EINVAL;
 return NULL;
}

#endif

/*
 * Sets the timeout of the Nessus descriptor <fd> and returns the
 * previous one. Returns 0 (with errno = EINVAL) if <fd> is not a Nessus
 * descriptor.
 */
ExtFunc int
stream_set_timeout(fd, timeout)
 int fd;
 int timeout;
{
 nessus_connection * cnx;
 int previous;

 if (NESSUS_STREAM(fd))
    {
     cnx = &connections[fd - NESSUS_FD_OFF];
     previous = cnx->timeout;
     cnx->timeout = timeout;
     return previous;
    }

 errno = EINVAL;
 return 0;
}

/*
 * Updates the options of the Nessus descriptor <fd>: the bits of
 * <reset_opt> are cleared first, then the bits of <set_opt> are set.
 *
 * Returns 0, or -1 with errno = EINVAL if <fd> is not a Nessus descriptor.
 */
ExtFunc int
stream_set_options(fd, reset_opt, set_opt)
     int	fd, reset_opt, set_opt;
{
 nessus_connection * cnx;

 if (NESSUS_STREAM(fd))
    {
     cnx = &connections[fd - NESSUS_FD_OFF];
     cnx->options = (cnx->options & ~reset_opt) | set_opt;
     return 0;
    }

 errno = EINVAL;
 return -1;
}


/*
 * Reads between <min_len> and <max_len> bytes from <fd> into <buf0>,
 * bypassing the read buffer of the stream.
 *
 * <fd> is either a Nessus descriptor or a plain socket (handled like a
 * NESSUS_ENCAPS_IP stream with the default timeout). The timeout comes
 * from the stream, unless it is -2 (default: TIMEOUT). The wait is made
 * of select() slices of INCR_TIMEOUT seconds; a timeout <= 0 disables
 * the time limit.
 *
 * Returns the number of bytes read, which may be less than <min_len> on
 * timeout, error or end of file. Returns -1 with errno = EBADF if the
 * socket cannot be used with select(), or with errno = EINVAL if the
 * stream is closed or its transport is not handled.
 */
static int 
read_stream_connection_unbuffered(fd, buf0, min_len, max_len)
 int fd;
 void* buf0;
 int min_len, max_len;
{
  int			ret, realfd, trp, t, err;
  int			total = 0, flag = 0, timeout = TIMEOUT, waitall = 0;
  unsigned char		* buf = (unsigned char*)buf0;
  nessus_connection	*fp = NULL;
  fd_set		fdr, fdw;
  struct timeval	tv;
  time_t		now, then;

  int		 	select_status;
 

  if (NESSUS_STREAM(fd))
    {
      fp = &(connections[fd - NESSUS_FD_OFF]);
      trp = fp->transport;
      realfd = fp->fd;
      if (fp->timeout != -2)
	timeout = fp->timeout;
      /*
       * Only check live streams: a closed one must reach the "closed"
       * diagnostic of the default case below.
       */
      if (trp > 0 && (realfd < 0 || realfd >= FD_SETSIZE))
	{
	  errno = EBADF;
	  return -1;
	}
    }
  else
    {
      trp = NESSUS_ENCAPS_IP;
      /* FD_SET() is only defined for 0 <= fd < FD_SETSIZE */
      if(fd < 0 || fd >= FD_SETSIZE)
	{
	  errno = EBADF;
	  return -1;
	}
      realfd = fd;
    }

#ifndef INCR_TIMEOUT
# define INCR_TIMEOUT	1
#endif

#ifdef MSG_WAITALL
  if (min_len == max_len || timeout <= 0)
    waitall = MSG_WAITALL;
#endif

  if(trp == NESSUS_ENCAPS_IP)
    {
      for (t = 0; total < max_len && (timeout <= 0 || t < timeout); )
	{
	  tv.tv_sec = INCR_TIMEOUT; /* Not timeout! */
	  tv.tv_usec = 0;
	  FD_ZERO(&fdr);
	  FD_SET(realfd, &fdr);
	  if(select(realfd + 1, &fdr, NULL, NULL, timeout > 0 ? &tv : NULL) <= 0)
	    {
	      /* Nothing to read during this slice (or select() failed) */
	      t += INCR_TIMEOUT;
	      /* Try to be smart */
	      if (total > 0 && flag) 
		return total;
	      else if (total >= min_len)
		flag ++;
	    }
	  else
	    {
	      errno = 0;
	      ret = recv(realfd, buf + total, max_len - total, waitall);
	      if (ret < 0)
		{
		  if (errno != EINTR)
		    return total;
		  ret = 0;	/* interrupted: just try again */
		}
	      else if (ret == 0) /* EOF */
		return total;
	      /*ret > 0*/
	      total += ret; 
	      if (min_len > 0 && total >= min_len)
		return total;
	      flag = 0;
	    }
	}
      return total;
    }

  switch(trp)
    {
      /* NESSUS_ENCAPS_IP was treated before with the non-Nessus fd */
#ifdef HAVE_SSL
    case NESSUS_ENCAPS_SSLv2:
    case NESSUS_ENCAPS_SSLv23:
    case NESSUS_ENCAPS_SSLv3:
    case NESSUS_ENCAPS_TLSv1:
# if DEBUG_SSL > 0
      if (getpid() != fp->pid)
	{
	  fprintf(stderr, "PID %d tries to use a SSL connection established by PID %d\n", getpid(), fp->pid);
	  errno = EINVAL;
	  return -1;
	}
# endif

      FD_ZERO(&fdr); FD_ZERO(&fdw);
      FD_SET(realfd, &fdr); FD_SET(realfd, &fdw); 
      now = then = time(NULL);
      for (t = 0; timeout <= 0 || t < timeout; t = now - then )
	{
	  now = time(NULL);
	  tv.tv_sec = INCR_TIMEOUT; tv.tv_usec = 0;
	  select_status = select ( realfd + 1, &fdr, &fdw, NULL, &tv );
	  if ( select_status == 0 )
	    {
	      /* Timed out: select() emptied the sets, arm them again */
	      FD_ZERO(&fdr); FD_ZERO(&fdw);
	      FD_SET(realfd, &fdr); FD_SET(realfd, &fdw); 
	    }
	  else if ( select_status > 0 )
	    {
	      ret = SSL_read(fp->ssl, buf + total, max_len - total);
	      if (ret > 0)
		{
		  total += ret;
		  FD_SET(realfd, &fdr);
		  FD_SET(realfd, &fdw); 
		}

	      if (total >= max_len)
		return total;
	      if (ret <= 0)
		{
		  /* Wait only for what the SSL layer asks for */
		  err = SSL_get_error(fp->ssl, ret);
		  FD_ZERO(&fdr); 
		  FD_ZERO(&fdw);
		  switch (err)
		    {
		    case SSL_ERROR_WANT_READ:
#if DEBUG_SSL > 2
		      fprintf(stderr, "SSL_read[%d]: SSL_ERROR_WANT_READ\n", getpid());
#endif
		      FD_SET(realfd, &fdr);
		      break;
		    case SSL_ERROR_WANT_WRITE:
#if DEBUG_SSL > 2
		      fprintf(stderr, "SSL_Connect[%d]: SSL_ERROR_WANT_WRITE\n", getpid());
#endif
		      FD_SET(realfd, &fdr);
		      FD_SET(realfd, &fdw);
		      break;

		    case SSL_ERROR_ZERO_RETURN:
#if DEBUG_SSL > 2
		      fprintf(stderr, "SSL_Connect[%d]: SSL_ERROR_ZERO_RETURN\n", getpid());
#endif
		      return total;

		    default:
#if DEBUG_SSL > 0
		      sslerror2("SSL_read", err);
#endif
		      return total;
		    }
		}
	    }

	  if (min_len <= 0)
	    {
	      /* Be smart */
	      if (total > 0 && flag)
		return total;
	      else
		flag ++;
	    }
	  else if (total >= min_len)
	    return total;
	}
      return total;
#endif
    default :
      if (fp->transport != -1 || fp->fd != 0)
	fprintf(stderr, "Severe bug! Unhandled transport layer %d (fd=%d)\n",
		fp->transport, fd);
      else
	fprintf(stderr, "read_stream_connection_unbuffered: fd=%d is closed\n", fd);
      errno = EINVAL;
      return -1;
    }
  /*NOTREACHED*/
}

/*
 * Reads between <min_len> and <max_len> bytes from <fd> into <buf0>.
 *
 * When the stream has a read buffer, the buffered bytes are served
 * first; if they are not enough, either the rest is read straight into
 * <buf0> (when more than the buffer capacity is still required) or the
 * buffer is refilled and the request is completed from it. Without a
 * read buffer, or if <fd> is not a Nessus descriptor, this is the same
 * as read_stream_connection_unbuffered().
 *
 * Returns the number of bytes stored in <buf0>. On a buffered stream,
 * a <max_len> that is not positive returns 0 and consumes nothing.
 */
ExtFunc int 
read_stream_connection_min(fd, buf0, min_len, max_len)
 int fd;
 void* buf0;
 int min_len, max_len;
{
  nessus_connection	*fp;

  if (NESSUS_STREAM(fd))
    {
      fp = &(connections[fd - NESSUS_FD_OFF]);
      if (fp->buf != NULL)
	{
	  int		l1, l2;

	  /*
	   * Nothing can be delivered. Going on would refill the buffer
	   * over the bytes it still holds, or call memcpy() with a
	   * negative length.
	   */
	  if (max_len <= 0)
	    return 0;

	  if (max_len == 1) min_len = 1; /* avoid "magic read" later */

	  /* First, serve what is already buffered */
	  l2 = max_len > fp->bufcnt ? fp->bufcnt : max_len;
	  if (l2 > 0)
	    {
	      memcpy(buf0, fp->buf + fp->bufptr, l2);
	      fp->bufcnt -= l2;
	      if (fp->bufcnt == 0)
		{
		  fp->bufptr = 0;
		  fp->buf[0] = '\0'; /* debug */
		}
	      else
		fp->bufptr += l2;
	      if (l2 >= min_len || l2 >= max_len)
		return l2;
	      max_len -= l2;
	      min_len -= l2;
	    }

	  /* Too much is still required for the buffer: read directly */
	  if (min_len > fp->bufsz)
	    {
	      l1 = read_stream_connection_unbuffered(fd, (char*)buf0 + l2,
						     min_len, max_len);
	      if (l1 > 0)
		return l1 + l2;
	      else
		return l2;
	    }

	  /* Fill buffer */
	  l1 = read_stream_connection_unbuffered(fd, fp->buf, min_len, fp->bufsz);
	  if (l1 <= 0)
	    return l2;
	  
	  /* Complete the request from the fresh data */
	  fp->bufcnt = l1;
	  l1 = max_len > fp->bufcnt ? fp->bufcnt : max_len;
	  memcpy((char*)buf0 + l2, fp->buf + fp->bufptr, l1);
	  fp->bufcnt -= l1;
	  if (fp->bufcnt == 0)
	    fp->bufptr = 0;
	  else
	    fp->bufptr += l1;
	  return l1 + l2;
	}
    }
  return read_stream_connection_unbuffered(fd, buf0, min_len, max_len);
}

/*
 * Reads at most <len> bytes from <fd> into <buf0>, without any minimum
 * (the minimum length handed to read_stream_connection_min() is -1).
 */
ExtFunc int 
read_stream_connection(fd, buf0, len)
 int fd;
 void* buf0;
 int len;
{
 int got = read_stream_connection_min(fd, buf0, -1, len);

 return got;
}

/*
 * Writes the <n> bytes of <buf0> to the Nessus descriptor <fd>.
 *
 * <i_opt> is handed to send() for NESSUS_ENCAPS_IP streams and ignored
 * for SSL streams. For SSL streams, when the SSL layer needs the socket
 * to become readable or writable, the function waits for it with
 * select(), for the timeout of the stream (TIMEOUT if it is negative).
 *
 * Returns the number of bytes written, which may be less than <n> if an
 * error occurs after a partial write. Returns -1 if nothing could be
 * written although <n> > 0, or (with errno = EINVAL) if <fd> is not a
 * Nessus descriptor, is closed, or uses an unhandled transport.
 */
static int
write_stream_connection4(fd, buf0, n, i_opt) 
 int fd;
 void * buf0;
 int n;
 int	i_opt;
{
  int			err, ret, count;
 unsigned char* buf = (unsigned char*)buf0;
 nessus_connection * fp;
  fd_set		fdr, fdw;
  struct timeval	tv;
  int e;

 if(!NESSUS_STREAM(fd))
   {
#if DEBUG_SSL > 0
     fprintf(stderr, "write_stream_connection: fd <%d> invalid\n", fd);
# if 0
     abort();
# endif
#endif
     errno = EINVAL;
     return -1;
    }

 fp = &(connections[fd - NESSUS_FD_OFF]);
 
#if DEBUG_SSL > 8
 fprintf(stderr, "> write_stream_connection(%d, %p, %d, 0x%x) \tE=%d 0=0x%x\n",
	 fd, (void *) buf, n, (unsigned) i_opt, fp->transport,
	 (unsigned) fp->options);
#endif

 switch(fp->transport)
 {
  case NESSUS_ENCAPS_IP:
   for(count = 0; count < n;)
   {
     ret = send(fp->fd, buf + count, n - count, i_opt);

     if(ret <= 0)
       break;
     
     count += ret;
   }
   break;

#ifdef HAVE_SSL
  case NESSUS_ENCAPS_SSLv2:
  case NESSUS_ENCAPS_SSLv23:
  case NESSUS_ENCAPS_SSLv3:
  case NESSUS_ENCAPS_TLSv1:
      FD_ZERO(&fdr); FD_ZERO(&fdw); 
      FD_SET(fp->fd, & fdr); FD_SET(fp->fd, & fdw);

      /* i_opt ignored for SSL */
      for(count = 0; count < n;)
	{ 
	  ret = SSL_write(fp->ssl, buf + count, n - count);
	  if (ret > 0)
	    count += ret;
	  else
	    {
	      /* Find out what the SSL layer is waiting for */
	      fp->last_ssl_err = err = SSL_get_error(fp->ssl, ret);
	      FD_ZERO(&fdw); FD_ZERO(&fdr); 
	      if (err == SSL_ERROR_WANT_WRITE)
		{
		  FD_SET(fp->fd, &fdw);
#if DEBUG_SSL > 2
		  fprintf(stderr, "SSL_write[%d]: SSL_ERROR_WANT_WRITE\n", getpid());
#endif    
		}
	      else if (err == SSL_ERROR_WANT_READ)
		{
#if DEBUG_SSL > 2
		  fprintf(stderr, "SSL_write[%d]: SSL_ERROR_WANT_READ\n", getpid());
#endif
		  FD_SET(fp->fd, &fdr);
		}
	      else
		{ 
#if DEBUG_SSL > 0
		  sslerror2("SSL_write", err);
#endif      
		  break;
		}

	      if (fp->timeout >= 0)
		tv.tv_sec = fp->timeout;
	      else 
		tv.tv_sec = TIMEOUT;
	      tv.tv_usec = 0;

	      /* Wait for the socket, restarting after a signal */
	      do {
		errno = 0;
		e = select(fp->fd+1, &fdr, &fdw, NULL, &tv);
	      } while ( e < 0 && errno == EINTR );

	      if ( e <= 0 )
		{
#if DEBUG_SSL > 0
		  nessus_perror("select");
#endif
		  break;
		}
	    }
	}
      break;
#endif
   default:
     if (fp->transport != -1 || fp->fd != 0)
       fprintf(stderr, "Severe bug! Unhandled transport layer %d (fd=%d)\n",
	       fp->transport, fd);
     else
       fprintf(stderr, "write_stream_connection: fd=%d is closed\n", fd);
     errno =EINVAL;
     return -1;
  }
  
  
  if(count == 0 && n > 0)
   return -1;
  else 
   return count;
}

/*
 * Writes the <n> bytes of <buf0> to the Nessus descriptor <fd>, with no
 * send() option. See write_stream_connection4() for the return value.
 */
ExtFunc int
write_stream_connection(fd, buf0, n) 
 int fd;
 void * buf0;
 int n;
{
  const int no_option = 0;

  return write_stream_connection4(fd, buf0, n, no_option);
}

/*
 * send() replacement that accepts both Nessus descriptors and plain
 * sockets.
 *
 * A Nessus descriptor is handed to write_stream_connection4(). A plain
 * socket is switched to blocking mode, then the data is sent with
 * os_send() as soon as select() reports the socket as writable.
 *
 * Returns the number of bytes sent; 0 if the plain socket was not
 * writable yet; -1 on error. A closed Nessus descriptor, or a plain
 * socket that select() cannot handle, gives -1 with errno = EBADF.
 */
ExtFunc int
nsend (fd, data, length, i_opt)
 int fd;
 void * data;
 int length, i_opt;
{
  int		n = 0;

 if(NESSUS_STREAM(fd))
 {
  if(connections[fd - NESSUS_FD_OFF].fd < 0)
   {
   /*
    * Do not fall back to the OS path: <fd> is a Nessus descriptor, far
    * beyond what FD_SET() can handle.
    */
   fprintf(stderr, "Nessus file descriptor %d closed ?!\n", fd);
   errno = EBADF;
   return -1;
   }
  return write_stream_connection4(fd, data, length, i_opt);
 }

 /* FD_SET() is only defined for 0 <= fd < FD_SETSIZE */
 if (fd < 0 || fd >= FD_SETSIZE)
 {
  errno = EBADF;
  return -1;
 }

 /* Trying OS's send() */
 block_socket(fd);	
 do
 {
   struct timeval tv = {0,5};
   fd_set wr;
   int e;
       
   FD_ZERO(&wr);
   FD_SET(fd, &wr);
       
   errno = 0;
   e  = select(fd + 1, NULL, &wr, NULL, &tv);
   if ( e > 0 )
     n = os_send(fd, data, length, i_opt);
   else if ( e < 0 && errno == EINTR ) continue;
   else break;
 }
 while (n <= 0 && errno == EINTR);

 if (n < 0)
   fprintf(stderr, "[%d] nsend():send %s\n", getpid(), strerror(errno));
 return n;
}
 
/*
 * recv() replacement that accepts both Nessus descriptors and plain
 * sockets.
 *
 * A Nessus descriptor is read with read_stream_connection() (<i_opt> is
 * then ignored); -1 is returned if its socket is negative. A plain
 * socket is switched to blocking mode and read with recv(), which is
 * restarted when interrupted by a signal.
 */
ExtFunc int
nrecv (fd, data, length, i_opt)
 int fd;
 void * data;
 int length, i_opt;
{
 int got;

 if(NESSUS_STREAM(fd))
 {
  if(connections[fd - NESSUS_FD_OFF].fd >= 0)
   return read_stream_connection(fd, data, length);

  fprintf(stderr, "Nessus file descriptor %d closed ?!\n", fd);
  return -1;
 }

 block_socket(fd);
 for (;;)
 {
  got = recv(fd, data, length, i_opt);
  if (!(got < 0 && errno == EINTR))
   break;
 }
 return got;
}
 

/*
 * Closes <fd>, which is either a Nessus descriptor (released with
 * release_connection_fd()) or a plain socket (shut down, then closed).
 *
 * A plain socket outside of the 0..1024 range gives -1 with
 * errno = EINVAL.
 */
ExtFunc int
close_stream_connection(fd)
 int fd;
{
  if(NESSUS_STREAM(fd))
   return release_connection_fd(fd);

  if ( fd < 0 || fd > 1024 )
   {
    errno = EINVAL;
    return -1;
   }

  shutdown(fd, 2);
  return socket_close(fd);
}



/*
 * Creates an AF_INET socket of the given <type> / <protocol> and
 * connects it to <paddr>.
 *
 * paddr:   address to connect to
 * port:    unused here (the port is taken from <paddr>)
 * timeout: connection timeout in seconds; -2 means TIMEOUT. When it is
 *          positive the connect() is done in non-blocking mode and
 *          awaited with select(); otherwise connect() blocks.
 *
 * Returns the connected socket, in blocking mode, or -1 on failure with
 * errno set (ETIMEDOUT when the select() wait expires). __port_closed
 * is set to 1 when the connection is refused or fails for a reason other
 * than that timeout.
 */
static int
open_socket(struct sockaddr_in *paddr, 
	    int port, int type, int protocol, int timeout)
{
  fd_set		fd_w;
  struct timeval	to;
  int			soc, x;
  int			opt;
  unsigned int opt_sz;
  int			saved_errno;

  __port_closed = 0;

  if ((soc = socket(AF_INET, type, protocol)) < 0)
    {
      nessus_perror("socket");
      return -1;
    }

  if (timeout == -2)
    timeout = TIMEOUT;

  if (timeout > 0)
    if (unblock_socket(soc) < 0)
      {
	/* Same close routine as every other error path */
	socket_close(soc);
	return -1;
      }

  
  if (connect(soc, (struct sockaddr*) paddr, sizeof(*paddr)) < 0)
    {
#if DEBUG_SSL > 2
      /* Logging must not disturb the errno tested below */
      saved_errno = errno;
      nessus_perror("connect");
      errno = saved_errno;
#endif
again:
      switch (errno)
	{
	case EINPROGRESS:
	case EAGAIN:
	  /* FD_SET() is only defined for fd < FD_SETSIZE */
	  if (soc >= FD_SETSIZE)
	    {
	      socket_close(soc);
	      errno = EMFILE;
	      return -1;
	    }
	  FD_ZERO(&fd_w);
	  FD_SET(soc, &fd_w);
	  to.tv_sec = timeout;
	  to.tv_usec = 0;
	  x = select(soc + 1, NULL, &fd_w, NULL, &to);
	  if (x == 0)
	    {
#if DEBUG_SSL > 2
	      nessus_perror("connect->select: timeout");
#endif
	      socket_close(soc);
	      errno = ETIMEDOUT;
	      return -1;
	    }
	  else if (x < 0)
	    {
	      if ( errno == EINTR )
		{
		  /* Interrupted by a signal: wait again */
		  errno = EAGAIN;
		  goto again;
		}
	      nessus_perror("select");
	      socket_close(soc);
	      return -1;
	    }
 
	  /* The socket is writable: fetch the outcome of the connect() */
	  opt = 0; opt_sz = sizeof(opt);
	  if (getsockopt(soc, SOL_SOCKET, SO_ERROR, &opt, &opt_sz) < 0)
	    {
	      nessus_perror("getsockopt");
	      socket_close(soc);
	      return -1;
	    }
	  
	  if (opt == 0)
	    break;

	  /* Let the caller know why the connection failed */
	  errno = opt;
#if DEBUG_SSL > 2
	  nessus_perror("SO_ERROR");
	  errno = opt;
#endif
	  /* no break; go on */	  
	default:
	  __port_closed = 1;
	  saved_errno = errno;
	  socket_close(soc);
	  errno = saved_errno;
	  return  -1;
	}
    }
  block_socket(soc);
  return soc;
}


/*
 * Resolves <hostname> with nn_resolve() and opens a socket of the given
 * <type> / <protocol> to its port <port>, with open_socket().
 *
 * Returns the socket, or -1 if the name resolves to INADDR_NONE or to
 * 0, or if the connection fails.
 */
ExtFunc 
int open_sock_opt_hn(hostname, port, type, protocol, timeout)
 const char * hostname; 
 unsigned int port; 
 int type;
 int protocol;
 int timeout;
{
  struct sockaddr_in target;

  memset(&target, 0, sizeof(target));
  target.sin_family = AF_INET;
  target.sin_port = htons((unsigned short)port);
  target.sin_addr = nn_resolve(hostname);

  if (target.sin_addr.s_addr != INADDR_NONE && target.sin_addr.s_addr != 0)
    return open_socket(&target, port, type, protocol, timeout);

  fprintf(stderr, "open_sock_opt_hn: invalid socket address\n");
  return  -1;
}


/*
 * Opens a TCP connection to the port <port> of <hostname>, with the
 * default timeout (TIMEOUT). Returns the socket, or -1 on failure.
 */
ExtFunc
int open_sock_tcp_hn(hostname, port)
 const char * hostname;
 unsigned int port;
{
  int soc = open_sock_opt_hn(hostname, port, SOCK_STREAM, IPPROTO_TCP, TIMEOUT);

  return soc;
}





/* This function reads a text from the socket stream into the
   argument buffer, always appending a '\0' byte.  The return
   value is the number of bytes read, without the trailing '\0'.

   Reading stops after a '\n' or a '\0' byte (which is stored and
   counted), when bufsiz - 1 bytes have been stored, at end of file, on
   error, or - for a plain socket - when nothing arrives for 5 seconds.

   Returns -1 if an error occurs before anything could be read, or (with
   errno set to EINVAL / EBADF) if the buffer is unusable or the plain
   socket cannot be handled by select().
 */


ExtFunc
int recv_line(soc, buf, bufsiz)
 int soc;
 char * buf;
 size_t bufsiz;
{
  size_t ret = 0;
  int n;

  /* There must be room for the trailing '\0' at least */
  if (buf == NULL || bufsiz == 0)
  {
   errno = EINVAL;
   return -1;
  }
  buf[0] = '\0';

  /*
   * Dirty SSL hack
   */
  if(NESSUS_STREAM(soc))
  {
   /* Keep one byte for the '\0': no received byte is ever overwritten */
   while (ret < bufsiz - 1)
   {
    n = read_stream_connection_min (soc, buf + ret, 1, 1);
    if (n < 0)
    {
     if (ret == 0)
      return -1;
     break;
    }
    if (n == 0)
     break;

    ret ++;
    if (buf[ret - 1] == '\0' || buf[ret - 1] == '\n')
     break;
   }
  }
  else
  {
   /* FD_SET() is only defined for 0 <= fd < FD_SETSIZE */
   if (soc < 0 || soc >= FD_SETSIZE)
   {
    errno = EBADF;
    return -1;
   }

   while (ret < bufsiz - 1)
   {
    fd_set rd;
    struct timeval tv;
    int e;

    errno = 0;
    FD_ZERO(&rd);
    FD_SET(soc, &rd);
    tv.tv_sec = 5;
    tv.tv_usec = 0;
    e = select(soc+1, &rd, NULL, NULL, &tv); 
    if (e < 0 && errno == EINTR)
     continue;		/* interrupted by a signal: wait again */
    if (e <= 0)
     break;		/* timeout or select() error: give up */

    n = recv(soc, buf + ret, 1, 0);
    if (n < 0)
    {
     if (errno == EINTR)
      continue;
     if (ret == 0)
      return -1;
     break;
    }
    if (n == 0)
     break;		/* end of file */

    ret ++;
    if (buf[ret - 1] == '\0' || buf[ret - 1] == '\n')
     break;
   }
  }

  /* ret <= bufsiz - 1, so this is always within the buffer */
  buf[ret] = '\0';
  return (int) ret;
} 

/*
 * Closes the socket <soc> and returns the result of close().
 *
 * When NESSUS_CNX_LOCK is defined, the lock counter is decremented
 * first; the lock file is unlocked and closed when it drops to 0.
 */
ExtFunc int
socket_close(soc)
int soc;
{
#ifdef NESSUS_CNX_LOCK
  if (lock_cnt > 0)
    {
      lock_cnt --;
      if (lock_cnt == 0)
	{
	  if (flock(lock_fd, LOCK_UN) < 0)
	    nessus_perror(NESSUS_CNX_LOCK);
	  if (close(lock_fd) < 0)
	    nessus_perror(NESSUS_CNX_LOCK);
	  lock_fd = -1;
	}
    }
#endif  
  return close(soc);
}

/*
 * auth_printf()
 *
 * Formats a message printf()-style (at most 65533 characters are kept)
 * and writes it to the global socket of the thread with auth_send().
 */
ExtFunc void 
auth_printf(struct arglist * globals, char * data, ...)
{
  char text[65535];
  va_list ap;

  memset(text, 0, sizeof(text));

  va_start(ap, data);
  vsnprintf(text, sizeof(text) - 1, data, ap);
  va_end(ap);

  auth_send(globals, text);
}                    


/*
 * Sends the string <data> on the "global_socket" found in <globals>.
 *
 * Nothing is done if that socket is negative. While sending (unless
 * NESSUSNT is defined) SIGPIPE terminates the process through _exit();
 * it is set to SIG_IGN before returning. A send that fails with ENOMEM
 * (or ENOBUFS) is retried; any other error is logged and aborts the
 * operation. When "confirm" is set in <globals>, one byte is then read
 * back from the socket as an acknowledgement.
 */
ExtFunc void
auth_send(struct arglist * globals, char * data)
{
 int soc = (int)arg_get_value(globals, "global_socket");
 int confirm = (int)arg_get_value(globals, "confirm");
 int length;
 int sent;
 int failed = 0;

 if(soc < 0)
  return;

#ifndef NESSUSNT
 signal(SIGPIPE, _exit);
#endif
 length = strlen(data);
 for (sent = 0; sent < length; )
 {
  int n = nsend(soc, data+sent, length-sent, 0);

  if (n >= 0)
  {
   sent += n;
   continue;
  }

  /* Transient lack of memory: try again */
  if((errno == ENOMEM)
#ifdef ENOBUFS  
   ||(errno==ENOBUFS)
#endif   
   )
   continue;

  nessus_perror("nsend");
  failed = 1;
  break;
 }
 
 if(!failed && confirm)
 {
  /*
   * If confirm is set, then we are a son
   * trying to report some message to our busy
   * father. So we wait until he told us he
   * took care of it
   */
  char ack;
  read_stream_connection_min(soc, &ack, 1, 1);
 }

#ifndef NESSUSNT
 signal(SIGPIPE, SIG_IGN);
#endif
}

/*
 * auth_gets()
 *
 * Reads a line from the global socket of the thread into <buf>, with
 * recv_line(). Returns <buf>, or NULL if nothing could be read.
 */
ExtFunc char * 
auth_gets(globals, buf, bufsiz)
     struct arglist * globals;
     char * buf;
     size_t bufsiz;
{
  int soc = (int)arg_get_value(globals, "global_socket");
  int got = recv_line(soc, buf, bufsiz);

  return (got > 0) ? buf : NULL;
}


/*
 * Select() routines
 */

/* Empties the descriptor set <set>. Always returns 0. */
ExtFunc int
stream_zero(set)
 fd_set * set;
{ 
 fd_set * target = set;

 FD_ZERO(target);
 return 0;
}

/*
 * Adds the socket behind <fd> to <set>.
 *
 * Returns that socket, as given by nessus_get_socket_from_connection();
 * when it is negative, <set> is left untouched.
 */
ExtFunc int
stream_set(fd, set)
 int fd;
 fd_set * set;
{
 int soc = nessus_get_socket_from_connection(fd);

 if (soc < 0)
  return soc;

 FD_SET(soc, set);
 return soc;
}

/*
 * Tells whether the socket behind <fd> is a member of <set>.
 *
 * Returns a non-zero value if it is, 0 if it is not. 0 is also returned
 * when no usable socket is attached to <fd> (closed stream, or socket
 * out of the range FD_ISSET() can handle).
 */
ExtFunc int
stream_isset(fd, set)
 int fd;
 fd_set * set;
{
 int soc = nessus_get_socket_from_connection(fd);

 /* FD_ISSET() is only defined for 0 <= fd < FD_SETSIZE */
 if (soc < 0 || soc >= FD_SETSIZE)
  return 0;

 return FD_ISSET(soc, set);
}

/*
 * Returns a non-zero value if <fd> lies in the range of the Nessus
 * descriptors, 0 otherwise. The state of the slot is not looked at.
 */
ExtFunc int
fd_is_stream(fd)
     int	fd;
{
  int	in_range = NESSUS_STREAM(fd);	/* Should probably be smarter... */

  return in_range;
}


/*
 * Returns the size of the read buffer of the Nessus descriptor <fd>,
 * or -1 if <fd> is not a Nessus descriptor.
 */
ExtFunc int 
stream_get_buffer_sz ( int fd )
{
  if (NESSUS_STREAM(fd))
    return connections[fd - NESSUS_FD_OFF].bufsz;

  return -1;
}


/*
 * Sets the size of the read buffer of the Nessus descriptor <fd> to
 * <sz> bytes; 0 removes the buffer.
 *
 * Bytes already buffered are kept: they are moved to the start of the
 * buffer before it is resized, and a size smaller than their count is
 * refused.
 *
 * Returns 0 on success. Returns -1 if <fd> is not a Nessus descriptor,
 * if <sz> is negative or too small for the buffered data, or if the
 * memory allocation fails (the previous buffer is then still valid).
 */
ExtFunc int
stream_set_buffer(fd, sz)
     int	fd, sz;
{
  nessus_connection	*p;
  char			*b;

  if (! NESSUS_STREAM(fd))
    return -1;

  p = &(connections[fd - NESSUS_FD_OFF]);
  if (sz < 0 || sz < p->bufcnt)
      return -1;		/* Do not want to lose data */

  if (sz == 0)
    {
      /* bufcnt is 0 here, see the check above */
      efree(&p->buf);
      p->bufsz = 0;
      p->bufptr = 0;
      return 0;
    }

  if (p->buf == NULL)
    {
      p->buf = malloc(sz);
      if (p->buf == NULL)
	return -1;
      p->bufsz = sz;
      p->bufptr = 0;
      p->bufcnt = 0;
      return 0;
    }

  /* Pack the unread bytes at the start, so that they survive a shrink */
  if (p->bufcnt > 0)
    memmove(p->buf, p->buf + p->bufptr, p->bufcnt);
  p->bufptr = 0;

  b = realloc(p->buf, sz);
  if (b == NULL)
    return -1;

  /* realloc() may have moved the block: the old pointer is now invalid */
  p->buf = b;
  p->bufsz = sz;
  return 0;
}



/*------------------------------------------------------------------*/


/*
 * Sends the whole <len> bytes of <buf> on the socket <soc>, calling
 * send() (with the flags <opt>) as many times as needed and restarting
 * it when it is interrupted by a signal.
 *
 * Returns the number of bytes sent, or -1 as soon as send() fails or
 * sends nothing.
 */
int os_send(int soc, void * buf, int len, int opt )
{
 const char * cursor = (const char*)buf;
 int done = 0;

 while ( done < len )
 {
  int rc;

  errno = 0;
  rc = send(soc, cursor + done, len - done, opt);
  if ( rc > 0 )
  {
   done += rc;
   continue;
  }
  if ( rc < 0 && errno == EINTR )
   continue;
  return -1;
 }
 return done;
}

/*
 * Reads exactly <len> bytes from the socket <soc> into <buf>, calling
 * recv() (with the flags <opt>) as many times as needed and restarting
 * it when it is interrupted by a signal.
 *
 * Returns the number of bytes read, or -1 as soon as recv() fails or
 * reports the end of file.
 */
int os_recv(int soc, void * buf, int len, int opt )
{
 char * cursor = (char*)buf;
 int done = 0;

 while ( done < len )
 {
  int rc;

  errno = 0;
  rc = recv(soc, cursor + done, len - done, opt);
  if ( rc > 0 )
  {
   done += rc;
   continue;
  }
  if ( rc < 0 && errno == EINTR )
   continue;
  return -1;
 }
 return done;
}



/*
 * Returns the pending error (SO_ERROR) of the socket behind <fd>, which
 * is either a Nessus descriptor or a plain socket.
 *
 * Returns -1 if getsockopt() fails, or (with errno = EINVAL) if <fd> is
 * a Nessus descriptor whose socket is negative.
 */
int
get_and_clear_stream_socket_errno(int fd)
{
  nessus_connection	*fp = NULL;
  int		sock = fd;
  int		err;
  unsigned	opt_sz = sizeof(err);

  if (NESSUS_STREAM(fd))
    {
      fp = &connections[fd - NESSUS_FD_OFF];
      if (fp->fd < 0)
	{
	  fprintf(stderr,
	    "[%d] get_and_clear_last_socket_error: closed Nessus fd <%d>\n", getpid(), fd);
	  errno = EINVAL;
	  return -1;
	}
      /* Query the real socket, not the Nessus descriptor */
      sock = fp->fd;
    }
#if DEBUG_SSL > 0
  else
    fprintf(stderr,
	    "[%d] get_and_clear_last_socket_error: not a Nessus fd <%d>\n", getpid(), fd);
#endif

#if 0
  if (fp->last_sock_err != 0)
    {
      err = fp->last_sock_err;
      fp->last_sock_err = 0;
      return err;
    }
#endif
  if (getsockopt(sock, SOL_SOCKET, SO_ERROR, &err, &opt_sz) >= 0)
    return err;

  nessus_perror("getsockopt");
  return -1;
}

/*
 * Returns the number of bytes that can be read from the Nessus
 * descriptor <fd> without waiting for the network: the content of the
 * read buffer if it is not empty, else what SSL_pending() reports for a
 * stream that is not NESSUS_ENCAPS_IP, else 0.
 *
 * A free or closed slot, or one without SSL handler, gives 0. Returns
 * -1 with errno = EINVAL if <fd> is not a Nessus descriptor.
 */
ExtFunc int stream_pending(int fd)
{
 nessus_connection * fp;

 if ( ! NESSUS_STREAM(fd) )
 {
  errno = EINVAL;
  return -1;
 }
 fp = &(connections[fd - NESSUS_FD_OFF]);

 if ( fp->bufcnt )
        return fp->bufcnt;
#ifdef HAVE_SSL
 /*
  * A slot that is not in use has a transport <= 0 and no SSL handler:
  * SSL_pending() must not be called on it.
  */
 if ( fp->transport > 0 && fp->transport != NESSUS_ENCAPS_IP && fp->ssl != NULL )
        return SSL_pending(fp->ssl);
#endif

 return 0;
}

