#include "config.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <ctype.h>
#include <fcntl.h>
#include <limits.h>
#include <pwd.h>
#include <errno.h>
#include <sys/uio.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/time.h>
#include <netinet/in.h>

#include "sftp.h"

/*
 * Return a pointer to the final path component of STR: the character just
 * after the last '/', or STR itself when it contains no '/'.  The result
 * points into STR (nothing is copied); for a path ending in '/' it points
 * at the terminating NUL.
 */
char *
findbasename(char *str) {
	char *slash = strrchr(str, '/');

	return (slash != NULL) ? slash + 1 : str;
}

/*
 * Read up to COUNT bytes from FD into BUF, looping over short reads and
 * retrying reads interrupted by a signal (EINTR).
 *
 * Returns the number of bytes read; this is less than COUNT only when
 * end-of-file is reached first.  Returns -1 on a read error (errno as set
 * by read(2)); bytes consumed before the error are not reported.
 *
 * Because the result is an int, COUNT must not exceed INT_MAX.  Larger
 * requests fail with -1 / EINVAL instead of overflowing a signed counter.
 */
static int
readn(int fd, void *buf, size_t count) {
	char *p = buf;
	size_t total = 0;
	ssize_t n;

	if (count > INT_MAX) {
		errno = EINVAL;
		return -1;
	}

	while (total < count) {
		n = read(fd, p + total, count - total);
		if (n == 0)
			break;			/* EOF: report what we have */
		if (n < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}
		total += (size_t)n;
	}
	return (int)total;
}

/*
 * Write COUNT bytes from BUF to FD, looping over short writes and retrying
 * writes interrupted by a signal (EINTR).
 *
 * Returns the number of bytes written.  This is less than COUNT only if
 * write(2) reports 0 bytes written.  Returns -1 on a write error (errno as
 * set by write(2)); bytes written before the error are not reported.
 *
 * Because the result is an int, COUNT must not exceed INT_MAX.  Larger
 * requests fail with -1 / EINVAL instead of overflowing a signed counter.
 */
static int
writen(int fd, void *buf, size_t count) {
	const char *p = buf;
	size_t total = 0;
	ssize_t n;

	if (count > INT_MAX) {
		errno = EINVAL;
		return -1;
	}

	while (total < count) {
		n = write(fd, p + total, count - total);
		if (n == 0)
			break;			/* no progress possible: give up */
		if (n < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}
		total += (size_t)n;
	}
	return (int)total;
}

/*
 * Send message M on SOCK as one frame:
 *
 *   channel (1 byte) | command (1 byte) | len (4 bytes, network order) |
 *   data (len bytes)
 *
 * The whole frame is first offered to writev(2).  If only part of it is
 * accepted, the remainder (the rest of the 6-byte header, then the rest of
 * the payload) is pushed out with writen().
 *
 * Returns 0 once the entire frame has been written, or -1 on error.
 */
int
send_message(int sock, message m) {
	struct iovec v[4];
	u_int32_t nlen = htonl(m.len);
	size_t framelen = (size_t)m.len + 6;
	size_t sent, rest;
	ssize_t n;
	int w;

	v[0].iov_base = (void *)&m.channel;
	v[0].iov_len = 1;

	v[1].iov_base = (void *)&m.command;
	v[1].iov_len = 1;

	v[2].iov_base = (void *)&nlen;
	v[2].iov_len = 4;

	v[3].iov_base = (void *)m.data;
	v[3].iov_len = m.len;

	/* Nothing is written when writev is interrupted, so just retry. */
	do {
		n = writev(sock, v, 4);
	} while (n < 0 && errno == EINTR);

	/*
	 * Check for failure before any size comparison.  Comparing a negative
	 * int against the unsigned frame length would turn it into a huge
	 * value and report a failed write as success.
	 */
	if (n < 0)
		return -1;
	sent = (size_t)n;
	if (sent >= framelen)
		return 0;

	if (sent < 6) {
		/* Part of the header is still unsent: rebuild it and finish it. */
		u_int8_t hdr[6];

		hdr[0] = m.channel;
		hdr[1] = m.command;
		memcpy(hdr + 2, &nlen, 4);
		w = writen(sock, hdr + sent, 6 - sent);
		if (w < 0 || (size_t)w < 6 - sent)
			return -1;
		sent = 0;
	} else
		sent -= 6;		/* now counts payload bytes already sent */

	rest = m.len - sent;
	if (rest == 0)
		return 0;		/* avoid pointer arithmetic on a NULL payload */
	w = writen(sock, (char *)m.data + sent, rest);
	if (w < 0 || (size_t)w < rest)
		return -1;
	return 0;
}

/*
 * Receive one frame from SOCK into *M.  The wire format is
 * channel (1 byte) | command (1 byte) | len (4 bytes, network order) |
 * data (len bytes).
 *
 * On success returns 0.  m->data is then a malloc'd buffer of m->len bytes,
 * or NULL when len is 0; presumably the caller frees it.
 *
 * Returns -1 in these cases:
 *  - the header cannot be read in full; *M may be partly unchanged;
 *  - the payload length exceeds INT_MAX, the most readn() can report;
 *  - the payload buffer cannot be allocated;
 *  - the payload cannot be read in full.
 * In the last three cases *M is zeroed.
 */
int
recv_message(int sock, message *m) {
	unsigned char hdr[6];
	u_int32_t nlen;
	int ret;

	ret = readn(sock, hdr, sizeof(hdr));
	if (ret < (int)sizeof(hdr))
		return (-1);
	m->channel = hdr[0];
	m->command = hdr[1];
	memcpy(&nlen, &hdr[2], 4);
	m->len = ntohl(nlen);

	if (m->len == 0) {
		m->data = NULL;
		return 0;
	}

	/* The length is peer-controlled; readn() can only report up to INT_MAX. */
	if (m->len > INT_MAX) {
		memset(m, 0, sizeof(message));
		return (-1);
	}

	m->data = malloc(m->len);
	if (m->data == NULL) {
		memset(m, 0, sizeof(message));
		return (-1);
	}

	ret = readn(sock, m->data, m->len);
	/*
	 * Test ret < 0 explicitly.  Comparing it directly with the unsigned
	 * length would turn -1 into a huge value and accept a failed read.
	 */
	if (ret < 0 || (u_int32_t)ret < m->len) {
		free(m->data);
		memset(m, 0, sizeof(message));
		return (-1);
	}
	return 0;
}

/*
 * Check SOCK without blocking.  If it is readable, receive one message into
 * *M via recv_message() and return its result.
 *
 * Returns -1 with m->data set to NULL when:
 *  - no data is pending;
 *  - select(2) fails, including EINTR, with errno left as set by select;
 *  - SOCK cannot be placed in an fd_set.
 */
int
query_message(int sock, message *m) {
	struct timeval tv;
	fd_set fds;
	int ret;

	/* FD_SET on a descriptor outside [0, FD_SETSIZE) is undefined behaviour. */
	if (sock < 0 || sock >= FD_SETSIZE) {
		m->data = NULL;
		return -1;
	}

	memset(&tv, 0, sizeof(tv));	/* zero timeout: check, do not wait */
	FD_ZERO(&fds);
	FD_SET(sock, &fds);
	ret = select(sock + 1, &fds, NULL, NULL, &tv);
	if (ret <= 0) {
		/*
		 * Nothing pending, or select() failed.  An error must not fall
		 * through to recv_message(), which would then block.
		 */
		m->data = NULL;
		return -1;
	}
	return recv_message(sock, m);
}

/*
 * Build a message on channel 0 by delegating to _data_message().
 * DATA is stored by reference; nothing is copied.
 */
message
_message(u_int8_t command, void *data, u_int32_t len) {
	const u_int8_t channel = 0;

	return _data_message(channel, command, data, len);
}

message
_data_message(u_int8_t channel, u_int8_t command, void *data, u_int32_t len) {
	message m;
	m.channel = channel;
	m.command = command;
	m.len = len;
	m.data = data;
	return m;
}

/*
 * Allocate a zero-filled sftp_channel on the heap.
 *
 * Returns NULL if allocation fails; the original dereferenced the
 * unchecked malloc() result in memset().  The caller presumably releases
 * the channel with free().
 */
sftp_channel *
new_channel(void) {
	/* calloc zero-fills, replacing the separate memset. */
	sftp_channel *channel = calloc(1, sizeof(*channel));

	return channel;
}

/*
 * Convert an "ls -l" style permission string such as "rwxr-x---" into the
 * corresponding S_I* permission bits.
 *
 * Each of the nine positions sets its bit only when it holds the expected
 * letter ('r', 'w' or 'x'); any other character, typically '-', leaves the
 * bit clear.  Parsing stops at the first NUL.  The original always read
 * nine bytes, which overran strings shorter than that.
 */
u_int16_t
strtomode(char *str) {
	static const char letters[] = "rwxrwxrwx";
	static const u_int16_t bits[9] = {
		S_IRUSR, S_IWUSR, S_IXUSR,
		S_IRGRP, S_IWGRP, S_IXGRP,
		S_IROTH, S_IWOTH, S_IXOTH
	};
	u_int16_t mode = 0;
	size_t i;

	for (i = 0; i < 9 && str[i] != '\0'; i++) {
		if (str[i] == letters[i])
			mode |= bits[i];
	}
	return mode;
}

/*
 * Render the nine permission bits of MODE as an "ls -l" style string, for
 * example "rwxr-xr-x", into STR.
 *
 * Exactly nine characters are written and no terminating NUL is added, so
 * the caller must supply at least nine bytes and terminate the string
 * itself if needed.
 */
void
modetostr(u_int16_t mode, char *str) {
	static const char letters[] = "rwxrwxrwx";
	static const u_int16_t bits[9] = {
		S_IRUSR, S_IWUSR, S_IXUSR,
		S_IRGRP, S_IWGRP, S_IXGRP,
		S_IROTH, S_IWOTH, S_IXOTH
	};
	int pos;

	for (pos = 0; pos < 9; pos++)
		str[pos] = (mode & bits[pos]) ? letters[pos] : '-';
}

/*
 * Expand a leading "~" or "~user" in PATH to that user's home directory.
 *
 * - PATH not starting with '~' is returned unchanged (the same pointer).
 * - "~" or "~user" with no '/' returns pw_dir from the passwd entry.
 * - "~/rest" or "~user/rest" returns "<home>/rest" built in a static buffer.
 *
 * Returns NULL when:
 *  - the user is unknown;
 *  - the user name is too long;
 *  - the expanded path does not fit in 256 bytes.
 *
 * PATH is only read.  The original wrote a NUL over the first '/',
 * truncating the caller's string.
 *
 * Note - this is _not_ thread safe: results live in a static buffer or in
 * the getpw* static storage, and are overwritten by later calls.
 */
char *
tildeexpand(char *path) {
	static char newpath[256];
	char user[256];
	struct passwd *pwd;
	const char *rest;
	size_t namelen;
	int n;

	if (path[0] != '~')
		return path;

	/* Split "~user/rest" into user and rest without modifying PATH. */
	rest = strchr(path, '/');
	namelen = (rest != NULL) ? (size_t)(rest - path) - 1 : strlen(path) - 1;
	if (namelen >= sizeof(user))
		return NULL;
	memcpy(user, path + 1, namelen);
	user[namelen] = '\0';
	if (rest != NULL)
		rest++;			/* skip the '/' */

	if (namelen == 0)
		pwd = getpwuid(getuid());
	else
		pwd = getpwnam(user);
	if (pwd == NULL)
		return NULL;
	if (rest == NULL)
		return pwd->pw_dir;

	n = snprintf(newpath, sizeof(newpath), "%s/%s", pwd->pw_dir, rest);
	if (n < 0 || (size_t)n >= sizeof(newpath))
		return NULL;		/* would not fit */
	return newpath;
}
