#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h> /* close */
#include <openssl/md5.h>
#include <zlib.h>
#include "mcast.h"

/* Larger of two values.  Both arguments may be evaluated twice, so never
 * pass expressions with side effects. */
#undef max
#define max(a, b) (((a) > (b)) ? (a) : (b))

/* Capacity of the client table when the client count is announced at runtime. */
#define MAXCLIENTS 256

/* Non-zero enables the extra tracing done under #ifdef DEBUG. */
int g_debug = 0;

/* signal handler */

/* Descriptors closed by signalHandler(); 0 means "not opened yet". */
int g_sd = 0;             /* multicast data socket */
int g_server_socket = 0;  /* TCP sync listening socket */

/*
 * Termination handler: log the signal number, close the recorded sync
 * listening socket and then the multicast socket (skipping any still 0),
 * and exit with status 0.
 *
 * NOTE(review): exit() and (presumably stdio-based) logit() are not
 * async-signal-safe per POSIX; consider setting a flag checked by the main
 * loop, or calling _exit(), once the log output no longer depends on it.
 */
static void signalHandler(int nr) {
  int fds[2];
  int k;

  logit("Retrive signal: %d", nr);

  fds[0] = g_server_socket;
  fds[1] = g_sd;
  for (k = 0; k < 2; k++) {
    if (fds[k] != 0) {
      close(fds[k]);
    }
  }
  exit(0);
}

/*
 * Add O_NONBLOCK to the file status flags of fd, keeping the other flags.
 * Returns the F_SETFL result (0 on success), or -1 if the current flags
 * cannot be read or the update fails.
 */
int set_nonblock_fd(int fd)
{
  int current = fcntl(fd, F_GETFL);

  if (current == -1) {
    return -1;
  }
  return fcntl(fd, F_SETFL, current | O_NONBLOCK);
}

/*
 * Create a TCP socket bound to INADDR_ANY:port.  The socket is not yet
 * listening; the caller calls listen().
 *
 * SO_REUSEADDR is enabled *before* bind(), the only point at which it has
 * any effect, so a restarted server can rebind while old connections are
 * still in TIME_WAIT.  A failure to set it is logged and otherwise ignored.
 *
 * Returns the socket fd, or -1 (after logging) if the port is out of range
 * or socket()/bind() fails.  On failure no descriptor is left open.
 */
int create_tcp_serversocket(int port)
{
  int s;
  int reuse = 1;
  struct sockaddr_in adr_srvr;

  /* htons() would silently truncate e.g. 65536 to 0 (an ephemeral port). */
  if (port < 0 || port > 65535) {
    logit("cannot create server socket on invalid port %d", port);
    return -1;
  }

  /* socket address */
  memset(&adr_srvr, 0, sizeof(adr_srvr));
  adr_srvr.sin_family = AF_INET;
  adr_srvr.sin_port = htons((unsigned short)port);
  adr_srvr.sin_addr.s_addr = htonl(INADDR_ANY);

  /* socket */
  if ((s = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    logit("line 67: cannot create server socket on port %d", port);
    return -1;
  }

  /* best effort: only speeds up rebinding after a restart */
  if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
    logit("cannot set SO_REUSEADDR on server socket port %d", port);
  }

  /* bind the server address */
  if (bind(s, (struct sockaddr *)&adr_srvr, sizeof(adr_srvr)) < 0) {
    logit("line 73: cannot bind the server socket on port %d", port);
    close(s);  /* previously leaked on this path */
    return -1;
  }

  return s;
}

/*
 * Send one SynPacket {HEADER, type, value} to each of the first nsockets
 * descriptors in client_socket[].
 *
 * A failed or short write() is logged but does not stop delivery to the
 * remaining clients.  Always returns 0.
 */
int sync_request(int client_socket[], int nsockets, long value, int type) {

  SynPacket packet;
  int i;

  /* Zero first so struct padding does not put stale stack bytes on the wire. */
  memset(&packet, 0x00, sizeof(SynPacket));
  packet.header = HEADER;
  packet.type = type;
  packet.value = value;

  for (i = 0; i < nsockets; i++) {
    ssize_t rc;
#ifdef DEBUG
    logit("<-- %d %ld", i, value);   /* value is long: %ld, not %d */
#endif
    rc = write(client_socket[i], &packet, sizeof(SynPacket));
    if (rc != (ssize_t)sizeof(SynPacket)) {
      logit("sync_request: write to client %d failed", i);
    }
  }

  return 0;
}

/*
 * Wait for TYPE_RETURN_CODE acknowledgements of `seqnum` on sockets[].
 *
 * Polls at most 4 times with a 0.5 s select() timeout (about 2 s total).
 * Each readable socket is drained with read() until it returns <= 0.
 * Each socket is counted at most once per call, so a client that acks
 * the same seqnum twice (e.g. once per resend) is not double-counted.
 * Returns early as soon as at least numClients distinct acks were seen.
 *
 * Returns the number of distinct sockets that acknowledged seqnum.  A
 * select() error other than EINTR ends the wait early.  Descriptors that
 * are negative or >= FD_SETSIZE cannot be polled and are skipped.
 *
 * NOTE(review): each read() is assumed to return one whole SynPacket;
 * TCP does not guarantee that, so a packet split across segments is
 * misparsed.  Proper framing needs per-socket buffering across calls.
 */
int sync_response(int sockets[], int nsockets, long seqnum, int numClients) {

  int i, j, fd, maxfd, done = 0;
  ssize_t n;
  fd_set rfds, acked;
  struct timeval tv;
  SynPacket packet;

  FD_ZERO(&acked);  /* used as a bitset of sockets already counted */

  for (j = 0; j < 4; j++) {

    FD_ZERO(&rfds);
    maxfd = -1;  /* previously read uninitialised by max() */
    for (i = 0; i < nsockets; i++) {
      fd = sockets[i];
      if (fd < 0 || fd >= FD_SETSIZE)  /* FD_SET on these is undefined */
        continue;
      FD_SET(fd, &rfds);
      maxfd = max(maxfd, fd);
    }
    if (maxfd < 0)
      return done;  /* nothing we can wait on */

    /* 0.5 sec */
    tv.tv_sec = 0;
    tv.tv_usec = 500000;
    if (select(maxfd + 1, &rfds, NULL, NULL, &tv) < 0) {
      if (errno == EINTR)
        continue;  /* interrupted by a signal: spend the next poll round */
      logit("sync_response: select error");
      return done;
    }

    /* check */
    for (i = 0; i < nsockets; i++) {
      fd = sockets[i];
      if (fd < 0 || fd >= FD_SETSIZE || !FD_ISSET(fd, &rfds))
        continue;

      memset(&packet, 0x00, sizeof(SynPacket));
      while ((n = read(fd, &packet, sizeof(SynPacket))) > 0) {
        if (packet.header == HEADER &&
            packet.type == TYPE_RETURN_CODE &&
            packet.value == seqnum) {
#ifdef DEBUG
          if (g_debug) logit("--> %d %ld %d", i, (long)packet.value, j);
#endif
          if (!FD_ISSET(fd, &acked)) {
            FD_SET(fd, &acked);
            done++;
          }
        }
#ifdef DEBUG
        else {
          if (g_debug) logit("--> %d %ld %d (X)", i, (long)packet.value, j);
        }
#endif
      }
    }

    if (done >= numClients) return done;

  } /* end for */
  return done;
}

/*
 * Multicast one DataPacket carrying buf[0..n) under sequence number seqnum
 * to servAddr via msocket.  Then wait (sync_response) until the ncsockets
 * TCP clients acknowledge it.
 *
 * While acknowledgements are missing, the packet is resent up to
 * ncsockets / 3 + 1 times.  With CHECKSUM defined, the packet carries the
 * MD5 of the payload, or 16 zero bytes if MD5() fails.
 *
 * Returns n on success.  Returns the negative sendto() result if sending
 * fails.  Returns -1 if n does not fit in the packet or clients are still
 * missing after the last resend.
 */
int send_data(int msocket, struct sockaddr_in servAddr,
              int csockets[], int ncsockets,
              long seqnum, char buf[], int n)  {

  DataPacket packet;
  unsigned char md[17], mdhex[33];
  int rc, numClients, retries;

  /* The payload is copied into a fixed-size body; reject anything larger. */
  if (n < 0 || (size_t)n > sizeof(packet.data)) {
    logit("send_data: payload of %d bytes does not fit in a packet", n);
    return -1;
  }

  numClients = ncsockets;
  retries = ncsockets / 3 + 1;

  memset(md, 0x00, sizeof(md));
  memset(mdhex, 0x00, sizeof(mdhex));
  memset(&packet, 0x00, sizeof(DataPacket));
  packet.header.header = HEADER;
  packet.header.type = TYPE_SEQUENCE_NUMBER;
  packet.header.value = seqnum;
  packet.ndata = n;
  memcpy(packet.data, buf, n);
#ifdef CHECKSUM
  if (MD5((const unsigned char *)buf, (size_t)n, md) != NULL) {
    memcpy(packet.checksum, md, 16);
  #ifdef DEBUG
    md5hex(md, mdhex);
  #endif
  } else {
    memset(packet.checksum, 0x00, 16);
  }
#endif

#ifdef DEBUG
  if (g_debug) logit("seqnum: %ld data: %d md5hex: %s", seqnum, n, mdhex);
#endif

  for (;;) {
    rc = sendto(msocket, &packet, sizeof(DataPacket), 0,
                (struct sockaddr *)&servAddr, sizeof(servAddr));
    if (rc < 0) { return rc; }

    /* sync response */
    numClients -= sync_response(csockets, ncsockets, seqnum, numClients);

    /* numClients drops below 0 when clients that already acked re-ack a
     * resend.  The count alone cannot tell who is still missing, so treat
     * <= 0 as complete.  Resending would only exhaust the retries and
     * report a spurious failure. */
    if (numClients <= 0)
      break;

#ifdef DEBUG
    logit("resent data: numClients %d", numClients);
    g_debug = 1;
#endif
    if (--retries < 0) return -1;
  }
#ifdef DEBUG
  g_debug = 0;
#endif

  return n;
}

int main(int argc, char *argv[]) {

  int sd, fd, i, j, n, sock_opt;
  char loop;
  struct stat file_info;
  unsigned char ttl = 1;
  struct sockaddr_in servAddr; // the multicast address

  int server_socket, *client_socket, nclients, len;
  struct sockaddr_in adr_clnt;
  fd_set rfds;

  struct hostent *h;
  char buf[BUFSIZ];
  long volumeSize, seqnum=1;

  if(argc<5) {
    printf("usage %s <# clients> <mgroup> <port> <image file> ...\n",argv[0]);
    exit(1);
  }
 
  h = gethostbyname(argv[2]);
  if(h==NULL) {
    printf("%s : unknown host '%s'\n",argv[0],argv[2]);
    exit(1);
  }

  /* check if dest address is multicast */
  servAddr.sin_family = h->h_addrtype;
  memcpy((char *) &servAddr.sin_addr.s_addr, h->h_addr_list[0],h->h_length);
  if(!IN_MULTICAST(ntohl(servAddr.sin_addr.s_addr))) {
    printf("%s : given address '%s' is not multicast \n",argv[0],
	      inet_ntoa(servAddr.sin_addr));
    exit(1);
  }

  /* create socket for multicast */
  sd = socket(AF_INET,SOCK_DGRAM,0);
  if (sd<0) {
    printf("%s : cannot open socket (data)\n",argv[0]);
    exit(1);
  }
 
  /* bind the port */
  servAddr.sin_port = htons(atoi(argv[3]));
  if(bind(sd,(struct sockaddr *) &servAddr,sizeof(servAddr))<0) {
    printf("%s : bind error\n", argv[0]);
    exit(1);
  }

  loop = 0;
  if(setsockopt(sd, IPPROTO_IP, IP_MULTICAST_LOOP, &loop, sizeof(char))<0) {
    printf("%s : cannot disable the loopback on the socket",argv[0]);
  }

  if(setsockopt(sd,IPPROTO_IP,IP_MULTICAST_TTL, &ttl,sizeof(ttl))<0) {
    printf("%s : cannot set ttl = %d\n",argv[0],ttl);
    exit(1);
  }

  printf("%s : sending data on multicast group '%s' (%s)\n",argv[0],
	 h->h_name,inet_ntoa(*(struct in_addr *) h->h_addr_list[0]));

  g_sd=sd;

  /* sync server & client */
  server_socket = create_tcp_serversocket(atoi(argv[3])+1);
  if (server_socket<0) {
    printf("%s : cannot open socket (sync)\n",argv[0]);
    close(sd);
    exit(1);
  }
  sock_opt=1;
  if(setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR, 
                (void *)&sock_opt, sizeof(sock_opt))<0) {
    printf("%s : cannot setsockopt(SO_REUSEADDR)");
  }
  g_server_socket=server_socket;

  if( listen(server_socket,10) < 0 ) {
    printf("%s : cannot listen to the socket (sync)\n",argv[0]);
    close(server_socket);
    close(sd);
    exit(1);
  }

  nclients = atoi(argv[1]);
  /* Blake, 2004/04/07, nclients = 0 .. */
  if(nclients==0) {

    nclients=MAXCLIENTS;
    client_socket = (int*)malloc(nclients * sizeof(int));
    for(i=0;i<nclients;i++) {
      len = sizeof(adr_clnt);
      client_socket[i] = accept(server_socket, (struct sockaddr *)&adr_clnt,&len);
      if(client_socket[i]<0) {
        logit("%s:%s> client_socket error",argv[2],argv[3]);
        i--;
      } else {
        /* check if this connect is used to tell me how many nclients */
        struct timeval tv;
        fd_set rfds;

        tv.tv_sec = 1; tv.tv_usec = 500000;
        FD_ZERO(&rfds); FD_SET(client_socket[i],&rfds); 
        if(select(client_socket[i]+1,&rfds,NULL,NULL,&tv)>=0 && 
           FD_ISSET(client_socket[i],&rfds)) {
          n = read(client_socket[i],buf,BUFSIZ);
          if(n>0) { 
            buf[n]='\0'; nclients = atoi(buf); 
            /* more clients ask for connection */
            if(nclients<i) nclients = i;
            logit("%s:%s> nclients is %d, i=%d",argv[2],argv[3],nclients,i);
            close(client_socket[i]);
            /* no clients are waiting.. abort!! */
            if(nclients==0) {
              close(sd);
              for(j=0;j<i;j++) { close(client_socket[j]); }
              free(client_socket);
              close(server_socket);
              printf("%s:%s> 0 clients are waiting.. abort!!",argv[2],argv[3]);
              exit(0);
            }
            i--;
          }
        } else {
          if(set_nonblock_fd(client_socket[i])<0) 
            logit("%s:%s> set nonblocking fd on %s failed",argv[2],argv[3],inet_ntoa(adr_clnt.sin_addr));
          printf("%s:%s> client %s connected\n",argv[2],argv[3],inet_ntoa(adr_clnt.sin_addr));
        }
      }
    }

  } else {
    
    client_socket = (int*)malloc(nclients * sizeof(int));
    for(i=0;i<nclients;i++) {
      len = sizeof(adr_clnt);
      client_socket[i] = accept(server_socket, (struct sockaddr *)&adr_clnt,&len);
      if(client_socket[i]<0) {
        logit("%s:%s> client_socket error",argv[2],argv[3]);
        i--;
      } else {
        if(set_nonblock_fd(client_socket[i])<0) 
          logit("%s:%s> set nonblocking fd on %s failed",argv[2],argv[3],inet_ntoa(adr_clnt.sin_addr));
        printf("%s:%s> client %s connected\n",argv[2],argv[3],inet_ntoa(adr_clnt.sin_addr));
      }
    }

  }

  /* signal handler */
  signal(SIGKILL, signalHandler);

  /* send data */
  for(i=4;i<argc;i++) {

    logit("%s:%s> start sending data %d",argv[2],argv[3],i);
    fd = open(argv[i],O_RDONLY);
    if(fd<0) { 
      logit("%s:%s> cannot open the file '%s'",argv[2], argv[3], argv[i]);
      continue; 
    }
    fstat(fd, &file_info); 
    volumeSize = file_info.st_size;  
    logit("%s:%s> volumeSize: %d",argv[2], argv[3], volumeSize);

#if 0
    /* skip partimage CVolumeHeader (512 bytes) */
    if(i>4) {
      pid_t pid;
	  int pipefd[2];

      pipe(pipefd);
      pid = fork();
	  if(pid!=0) {
	    /* parent process */
	    close(fd);
	    close(pipefd[1]);
	    fd = pipefd[0];
	  } else {
	    /* child process */
	    gzFile gzRF, gzWF;
	  
	    close(pipefd[0]);
        gzRF = gzdopen(fd,"rb");
        gzWF = gzdopen(pipefd[1],"wb");
		/* skip 512 bytes */
        gzseek(gzRF, 512L, SEEK_SET);	  
	    while( (n=gzread(gzRF, buf, BUFSIZ)) > 0) {
          gzwrite(gzWF, buf, BUFSIZ);
        }
        gzclose(gzRF);
	    gzclose(gzWF);
	  
	    exit(0);
	  }
    } // EOF i>4
#endif

    while( (n=read(fd, buf, BUFSIZ)) > 0 ) {

      if( send_data(sd,servAddr,client_socket,nclients,seqnum,buf,n) < 0 ) {
        logit("%s:%s> cannot send data %d",argv[2],argv[3],i);
        goto do_exit;
      }
      seqnum++;

    } /* end while */
    close(fd);

  }/* end for */

  /* End Of Multicast Stream */
  buf[0]='E'; buf[1]='O'; buf[2]='M'; buf[3]='S';
  if( send_data(sd,servAddr,client_socket,nclients,seqnum,buf,4) < 0 ) {
    logit("%s:%s> cannot send data 'EOMS'",argv[2],argv[3]);
  }

do_exit:
  /* close socket and exit */
  close(sd);
  for(i=0;i<nclients;i++) { close(client_socket[i]); }
  free(client_socket);
  close(server_socket);
 
  exit(0);
}
