/*
    This file is part of SNIProxy.
    Copyright (C) 2009, Marcelo Reina Aguilar <m6rc310@yahoo.es>.

    SNIProxy is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    SNIProxy is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with SNIProxy.  If not, see <http://www.gnu.org/licenses/>.
*/

#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <string.h>
#include <time.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <poll.h>
#include <unistd.h>

/* Backlog passed to listen() for the proxy's listening socket. */
#define QUEUE_LENGTH 8
/* Address/protocol family used for every socket created in this file. */
#define FAMILY_PROTOCOL PF_INET
/* TCP port the proxy binds to and accepts client connections on. */
#define PORT 4433
/* TCP port used when connecting to the selected upstream server. */
#define SSL_PORT 443
/* Size in bytes of both the record buffer and the server-name buffer. */
#define BUFFER_LENGTH 1024

/* Reads the first record from a connected socket into `buffer`. */
int readFragment (int);
/* Looks for a server name in the record held in `buffer`; NULL if none. */
char * readSNI ();

/* Host name main() connects to: the result of readSNI() or serverDefault. */
char serverName[BUFFER_LENGTH];
/* Fallback host name used when readSNI() returns NULL. */
char *serverDefault="savannah.gnu.org";
/* Raw bytes of the first record received from the client, header included. */
unsigned char buffer[BUFFER_LENGTH];
/* Total size of that record, header included, as computed by readFragment(). */
int length;
/* Set by readFragment(): 0 for an SSLv2-style header, 1 otherwise. */
int upSSLv3;

/* Decode the big-endian 16-bit integer stored at p[0] and p[1]. */
static int readUint16 (const unsigned char *p) {
  return (p[0]<<8)+p[1];
}

/*
 * Extract the Server Name Indication host name from the TLS Client Hello
 * record held in the global `buffer` (`length` valid bytes).
 *
 * Every offset is validated against the record before it is dereferenced,
 * so a truncated or malicious hello yields NULL instead of an out-of-bounds
 * read.  The name is copied into private static storage and NUL-terminated
 * there; `buffer` itself is left untouched because the caller forwards it
 * verbatim to the upstream server.
 *
 * Returns a pointer to the NUL-terminated host name (valid until the next
 * call), or NULL when the record is not a TLS Client Hello, is malformed,
 * or carries no usable host_name entry.
 */
char * readSNI () {
  static char name[256];   /* DNS names never exceed 255 bytes */
  int i,end,type,size,nameLength;
  time_t t;
  struct tm *tm;

  /* Never trust `length` beyond the storage that actually backs it. */
  end=length;
  if (end>(int) sizeof buffer) end=(int) sizeof buffer;

  /* Record header (5) + handshake header (4) + version (2) + random (32)
     + session id length (1) must all be present before they are read. */
  if (end<44) return NULL;
  /* Content type 22 is "handshake"; this also rejects SSLv2-style records,
     whose layout the offsets below do not match. */
  if (buffer[0]!=22) return NULL;
  if (buffer[5]!=1) return NULL;
  printf ("Handshake type Client Hello.\n");
  printf ("Protocol Version %d.%d\n",buffer[9],buffer[10]);
  /* Assemble in an unsigned type: shifting a promoted int by 24 overflows. */
  t=(time_t) (((unsigned long) buffer[11]<<24)|((unsigned long) buffer[12]<<16)|
    ((unsigned long) buffer[13]<<8)|(unsigned long) buffer[14]);
  tm=gmtime (&t);
  printf ("GMT Unix Time %s\n",tm!=NULL ? asctime (tm) : "(invalid)");
  printf ("SessionID Length %d\n",buffer[43]);

  /* Skip the session id, then the cipher suite list. */
  i=44+buffer[43];
  if (i+2>end) return NULL;
  size=readUint16 (buffer+i);
  printf ("Numbers of Cipher Suite %d\n",size);
  i+=2+size;

  /* Skip the compression method list. */
  if (i+1>end) return NULL;
  printf ("Numbers of Compression Method %d\n",buffer[i]);
  i+=1+buffer[i];

  /* A hello without an extensions block simply has no SNI. */
  if (i+2>end) return NULL;
  size=readUint16 (buffer+i);
  printf ("Extensions Length %d\n",size);
  i+=2;
  /* Extensions may end before the record does, but never after it. */
  if (size<end-i) end=i+size;

  /* Walk the extensions: 2-byte type, 2-byte length, then the payload. */
  while (i+4<=end) {
    type=readUint16 (buffer+i);
    size=readUint16 (buffer+i+2);
    printf ("Extension Type %d\n",type);
    printf ("Extension Length %d\n",size);
    i+=4;
    if (size>end-i) return NULL;
    if (type==0) {
      /* server_name: list length (2), name type (1), name length (2), name. */
      if (size<5) return NULL;
      printf ("Server Name List Length %d\n",readUint16 (buffer+i));
      printf ("Server Name Type %d\n",buffer[i+2]);
      if (buffer[i+2]) return NULL;   /* only host_name (0) is defined */
      nameLength=readUint16 (buffer+i+3);
      printf ("Server Name Length %d\n",nameLength);
      i+=5;
      if (nameLength==0 || nameLength>size-5) return NULL;
      if (nameLength>=(int) sizeof name) return NULL;
      /* An embedded NUL would make the resolver see a shorter name. */
      if (memchr (buffer+i,0,(size_t) nameLength)!=NULL) return NULL;
      memcpy (name,buffer+i,(size_t) nameLength);
      name[nameLength]='\0';
      printf ("Server Name [%s]\n",name);
      return name;
    }
    i+=size;
  }
  return NULL;
}

/*
 * Receive exactly `count` bytes from socket `s` into `dst`.
 * Returns 0 on success, -1 on error or if the peer closes early.
 */
static int recvExact (int s,unsigned char *dst,int count) {
  int done=0;
  ssize_t n;

  while (done<count) {
    n=recv (s,dst+done,(size_t) (count-done),MSG_WAITALL);
    if (n==-1 && errno==EINTR) continue;
    if (n<1) {
      fprintf (stderr,"Error recv %d\n",errno);
      return -1;
    }
    done+=(int) n;
  }
  return 0;
}

/*
 * Read one complete record from socket `s` into the global `buffer`.
 *
 * The first three bytes decide the framing: a set high bit in byte 0
 * together with message type 1 in byte 2 is treated as an SSLv2-style
 * 2-byte header, anything else as a 5-byte SSLv3/TLS record header.
 * `upSSLv3` records which one was seen.
 *
 * The length field comes from the peer, so it is checked against the size
 * of `buffer` before any payload is stored; an oversized or inconsistent
 * record is rejected rather than written past the end of the array.
 *
 * On success returns 0 and sets `length` to the total record size, header
 * included.  On failure returns -1 and leaves `length` at 0 so that callers
 * which do not check the result cannot forward stale or partial data.
 */
int readFragment (int s) {
  int header,total;

  length=0;
  if (recvExact (s,buffer,3)==-1) return -1;
  if (buffer[0]&0x80 && buffer[2]==1) {
    header=3;
    total=((buffer[0]&0x7f)<<8)+buffer[1]+2;
    upSSLv3=0;
    printf ("SSLv2\n\n");
  } else {
    if (recvExact (s,buffer+3,2)==-1) return -1;
    header=5;
    total=(buffer[3]<<8)+buffer[4]+5;
    upSSLv3=1;
    printf ("SSLv3\n\n");
  }
  /* total<header: the record claims to be shorter than what was already
     read.  total>sizeof buffer: the payload would overflow the buffer. */
  if (total<header || total>(int) sizeof buffer) {
    fprintf (stderr,"Error record length %d\n",total);
    return -1;
  }
  if (recvExact (s,buffer+header,total-header)==-1) return -1;
  length=total;
  return 0;
}

/*
 * Write all `count` bytes at `data` to socket `fd`, retrying short writes.
 * Returns 0 on success, -1 on error.
 */
static int sendAll (int fd,const unsigned char *data,int count) {
  int sent=0;
  ssize_t n;

  while (sent<count) {
    n=send (fd,data+sent,(size_t) (count-sent),0);
    if (n==-1) {
      if (errno==EINTR) continue;
      fprintf (stderr,"Error send %d\n",errno);
      return -1;
    }
    sent+=(int) n;
  }
  return 0;
}

/*
 * Move whatever is currently readable on `from` over to `to`.
 * Returns 1 when data was forwarded, 0 when `from` was closed by its peer,
 * and -1 on a receive or send error.
 */
static int forwardChunk (int from,int to) {
  unsigned char chunk[BUFFER_LENGTH];
  ssize_t n;

  do n=recv (from,chunk,sizeof chunk,0);
  while (n==-1 && errno==EINTR);
  if (n==-1) {
    fprintf (stderr,"Error recv %d\n",errno);
    return -1;
  }
  if (n==0) return 0;
  return sendAll (to,chunk,(int) n)==-1 ? -1 : 1;
}

/*
 * Shuttle bytes between client socket `sc` and server socket `ss` until
 * either side closes or fails.  Sleeps in poll() instead of spinning on
 * non-blocking reads, and treats every receive error as terminal.
 */
static void relay (int sc,int ss) {
  struct pollfd fds[2];

  fds[0].fd=sc;
  fds[0].events=POLLIN;
  fds[1].fd=ss;
  fds[1].events=POLLIN;
  for (;;) {
    fds[0].revents=0;
    fds[1].revents=0;
    if (poll (fds,2,-1)==-1) {
      if (errno==EINTR) continue;
      fprintf (stderr,"Error poll %d\n",errno);
      return;
    }
    /* Any event, including hang-up or error, is resolved by the recv()
       inside forwardChunk(), which reports close/error for us. */
    if (fds[1].revents && forwardChunk (ss,sc)<1) return;
    if (fds[0].revents && forwardChunk (sc,ss)<1) return;
  }
}

/*
 * Handle one accepted client: read its first record, choose the upstream
 * host from the SNI name (or serverDefault), connect, replay the record and
 * relay traffic both ways.  Always closes `sc` before returning.
 * Returns 0 after a completed session, -1 if it could not be set up.
 */
static int serveClient (int sc,const struct sockaddr_in *ac) {
  struct sockaddr_in ae;
  struct hostent *host;
  char *sn;
  int ss,status=-1;

  printf ("Connect %s:%d\n",inet_ntoa (ac->sin_addr),ntohs (ac->sin_port));
  if (readFragment (sc)==-1) goto closeClient;
  sn=readSNI ();
  /* snprintf always NUL-terminates, unlike strncpy on a full-length name. */
  snprintf (serverName,BUFFER_LENGTH,"%s",sn!=NULL ? sn : serverDefault);
  printf ("Server %s:%d\n",serverName,SSL_PORT);

  host=gethostbyname (serverName);
  if (host==NULL || host->h_addr_list[0]==NULL ||
      host->h_length!=(int) sizeof ae.sin_addr) {
    fprintf (stderr,"Error resolve %s\n",serverName);
    goto closeClient;
  }
  memset (&ae,0,sizeof ae);
  ae.sin_family=FAMILY_PROTOCOL;
  ae.sin_port=htons (SSL_PORT);
  memcpy (&ae.sin_addr,host->h_addr_list[0],sizeof ae.sin_addr);

  if ((ss=socket (FAMILY_PROTOCOL,SOCK_STREAM,0))==-1) {
    fprintf (stderr,"Error socket %d\n",errno);
    goto closeClient;
  }
  if ((connect (ss,(struct sockaddr *) &ae,sizeof ae))==-1) {
    fprintf (stderr,"Error connect %d\n",errno);
    goto closeServer;
  }
  printf ("Connect %s:%d\n",inet_ntoa (ae.sin_addr),ntohs (ae.sin_port));

  /* Replay the record that was consumed while looking for the SNI. */
  if (sendAll (ss,buffer,length)==-1) goto closeServer;
  relay (sc,ss);
  status=0;
  printf ("Disconnect %s:%d\n",inet_ntoa (ac->sin_addr),ntohs (ac->sin_port));

closeServer:
  shutdown (ss,SHUT_RDWR);
  close (ss);
closeClient:
  shutdown (sc,SHUT_RDWR);
  close (sc);
  return status;
}

/*
 * Listen on PORT and fork one child per accepted connection.
 *
 * The parent only accepts: it closes its copy of each client socket right
 * after forking and lets the kernel reap finished children.  Each child
 * closes the listening socket, serves exactly one client and then exits,
 * so it never falls back into the accept loop.
 *
 * Returns -1 if the listening socket cannot be set up or accept() fails
 * fatally; otherwise the parent runs until it is killed.
 */
int main (int argc,char *argv[]) {
  int s,sc;
  struct sockaddr_in a,ac;
  socklen_t l;
  pid_t pid;

  (void) argc;
  (void) argv;

  /* Ignoring SIGCHLD keeps exited children from lingering as zombies. */
  signal (SIGCHLD,SIG_IGN);

  memset (&a,0,sizeof a);
  a.sin_family=FAMILY_PROTOCOL;
  a.sin_port=htons (PORT);
  a.sin_addr.s_addr=htonl (INADDR_ANY);

  if ((s=socket (FAMILY_PROTOCOL,SOCK_STREAM,0))==-1) {
    fprintf (stderr,"Error socket %d\n",errno);
    return -1;
  }
  if ((bind (s,(struct sockaddr *) &a,sizeof a))==-1) {
    fprintf (stderr,"Error bind %d\n",errno);
    return -1;
  }
  if ((listen (s,QUEUE_LENGTH))==-1) {
    fprintf (stderr,"Error listen %d\n",errno);
    return -1;
  }

  for (;;) {
    /* accept() overwrites the length, so it is reset on every pass. */
    l=sizeof ac;
    if ((sc=accept (s,(struct sockaddr *) &ac,&l))==-1) {
      if (errno==EINTR || errno==ECONNABORTED) continue;
      fprintf (stderr,"Error accept %d\n",errno);
      return -1;
    }
    pid=fork ();
    if (pid==0) {
      close (s);
      return serveClient (sc,&ac);
    }
    /* Serving inline would stall the accept loop, so on fork failure the
       connection is dropped instead. */
    if (pid==-1) fprintf (stderr,"Error fork %d\n",errno);
    close (sc);
  }
}
