/*
 * Copyright 2013-2024 Pierre Ossman for Cendio AB
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include "includes.h"

#include <limits.h>

#include "log.h"

#undef read
#undef write

#undef close

#undef accept
#undef bind
#undef connect
#undef gethostbyaddr
#undef gethostbyname
#undef gethostname
#undef getpeername
#undef getprotobyname
#undef getservbyname
#undef getsockopt
#undef getsockname
#undef listen
#undef setsockopt
#undef socket

#undef select

#ifdef WIN32

#include <winternl.h>
#include <ntstatus.h>

extern ssize_t win32_read_console(int fildes, void *buf, size_t nbyte);
extern int win32_filter_console_events(HANDLE console);

/*
 * read() replacement that works on consoles, sockets and plain file
 * descriptors.
 *
 * Consoles are handed to win32_read_console(). Otherwise the descriptor
 * is first tried as a socket with recv(). If Winsock reports that it is
 * not a (connected) socket, the call falls back to the CRT read().
 *
 * Returns the number of bytes read, or -1 with errno set.
 */
ssize_t
win32_read(int fildes, void *buf, size_t nbyte)
{
	int ret;
	int len;

	if (isatty(fildes))
		return win32_read_console(fildes, buf, nbyte);

	/*
	 * recv() takes an int length. Clamp rather than let a huge size_t
	 * turn into a negative or truncated value; POSIX permits short
	 * reads.
	 */
	len = (nbyte > INT_MAX) ? INT_MAX : (int)nbyte;

	ret = recv(fildes, buf, len, 0);
	if (ret != SOCKET_ERROR)
		return ret;

	errno = get_win32_error();

	/* Only fall back to read() when this is apparently not a socket */
	if ((errno != WSAENOTSOCK) && (errno != WSAEBADF) &&
	    (errno != WSAENOTCONN)) {
		return -1;
	}

	return read(fildes, buf, nbyte);
}

/*
 * write() replacement that works on sockets and plain file descriptors.
 *
 * The descriptor is first tried as a socket with send(). If Winsock
 * reports that it is not a (connected) socket, the call falls back to
 * the CRT write().
 *
 * Returns the number of bytes written, or -1 with errno set. A full
 * pipe is reported as EAGAIN.
 */
ssize_t
win32_write(int fildes, const void *buf, size_t nbyte)
{
	int ret;
	int len;
	ssize_t wret;

	/*
	 * send() takes an int length. Clamp rather than let a huge size_t
	 * turn into a negative or truncated value; POSIX permits short
	 * writes.
	 */
	len = (nbyte > INT_MAX) ? INT_MAX : (int)nbyte;

	ret = send(fildes, buf, len, 0);
	if (ret != SOCKET_ERROR)
		return ret;

	errno = get_win32_error();

	/* Only fall back to write() when this is apparently not a socket */
	if ((errno != WSAENOTSOCK) && (errno != WSAEBADF) &&
	    (errno != WSAENOTCONN)) {
		return -1;
	}

	wret = write(fildes, buf, nbyte);

	/* Due to a quirk of mingw_write() we get ENOSPC instead of
	   EAGAIN when we have a full pipe. */
	if ((wret == -1) && (errno == ENOSPC)) {
		errno = EAGAIN;
	}

	return wret;
}

/*
 * close() replacement.
 *
 * Sockets recognised by win32_is_afunix() go to win32_afunix_close().
 * Everything else is first tried with closesocket(). If Winsock says
 * the descriptor is not a socket, the CRT close() is used instead.
 *
 * Returns closesocket()'s/close()'s result, or -1 with errno set.
 */
int
win32_close(int fd)
{
	int rc;

	if (win32_is_afunix(fd))
		return win32_afunix_close(fd);

	rc = closesocket(fd);
	if (rc == SOCKET_ERROR) {
		errno = get_win32_error();
		if ((errno == WSAENOTSOCK) || (errno == WSAEBADF))
			return close(fd);
		return -1;
	}

	return rc;
}

/*
 * accept() wrapper.
 *
 * Sockets recognised by win32_is_afunix() are delegated to
 * win32_afunix_accept(). Winsock failures are reported through errno.
 *
 * Returns the new socket, or -1 on failure.
 */
int
win32_accept(int socket, struct sockaddr *address, socklen_t *address_len)
{
	SOCKET client;

	if (win32_is_afunix(socket))
		return win32_afunix_accept(socket, address, address_len);

	client = accept(socket, address, address_len);
	if (client != INVALID_SOCKET)
		return client;

	errno = get_win32_error();
	return -1;
}

/*
 * bind() wrapper.
 *
 * Sockets recognised by win32_is_afunix() are delegated to
 * win32_afunix_bind(). Winsock failures are reported through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_bind(int socket, const struct sockaddr *address,
           socklen_t address_len)
{
	if (win32_is_afunix(socket))
		return win32_afunix_bind(socket, address, address_len);

	if (bind(socket, address, address_len) != SOCKET_ERROR)
		return 0;

	errno = get_win32_error();
	return -1;
}

/*
 * connect() wrapper.
 *
 * Sockets recognised by win32_is_afunix() are delegated to
 * win32_afunix_connect(). Winsock failures are reported through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_connect(int socket, const struct sockaddr *address,
              socklen_t address_len)
{
	if (win32_is_afunix(socket))
		return win32_afunix_connect(socket, address, address_len);

	if (connect(socket, address, address_len) != SOCKET_ERROR)
		return 0;

	/*
	 * A non-blocking connect() that cannot finish immediately should
	 * fail with EINPROGRESS according to POSIX, but Windows reports
	 * the generic EWOULDBLOCK. Translate it.
	 */
	errno = (WSAGetLastError() == WSAEWOULDBLOCK) ?
	        WSAEINPROGRESS : get_win32_error();

	return -1;
}

/*
 * gethostbyaddr() wrapper that stores the Winsock error in errno when
 * the lookup fails.
 *
 * Returns the host entry, or NULL on failure.
 */
struct hostent*
win32_gethostbyaddr(const void *addr, socklen_t len, int type)
{
	struct hostent *host = gethostbyaddr(addr, len, type);

	if (host == NULL)
		errno = get_win32_error();

	return host;
}

/*
 * gethostbyname() wrapper that stores the Winsock error in errno when
 * the lookup fails.
 *
 * Returns the host entry, or NULL on failure.
 */
struct hostent*
win32_gethostbyname(const char *name)
{
	struct hostent *host = gethostbyname(name);

	if (host == NULL)
		errno = get_win32_error();

	return host;
}

/*
 * gethostname() wrapper that reports Winsock failures through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_gethostname(char *name, size_t len)
{
	if (gethostname(name, len) != SOCKET_ERROR)
		return 0;

	errno = get_win32_error();
	return -1;
}

/*
 * getpeername() wrapper that reports Winsock failures through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
{
	int rc = getpeername(sockfd, addr, addrlen);

	if (rc != SOCKET_ERROR)
		return 0;

	errno = get_win32_error();
	return -1;
}

/*
 * getprotobyname() wrapper that stores the Winsock error in errno when
 * the lookup fails.
 *
 * Returns the protocol entry, or NULL on failure.
 */
struct protoent*
win32_getprotobyname(const char *name)
{
	struct protoent *proto = getprotobyname(name);

	if (proto == NULL)
		errno = get_win32_error();

	return proto;
}

/*
 * getservbyname() wrapper that stores the Winsock error in errno when
 * the lookup fails.
 *
 * Returns the service entry, or NULL on failure.
 */
struct servent*
win32_getservbyname(const char *name, const char *proto)
{
	struct servent *serv = getservbyname(name, proto);

	if (serv == NULL)
		errno = get_win32_error();

	return serv;
}

/*
 * getsockopt() wrapper that reports Winsock failures through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_getsockopt(int sockfd, int level, int optname,
                 void *optval, socklen_t *optlen)
{
	if (getsockopt(sockfd, level, optname, optval, optlen) != SOCKET_ERROR)
		return 0;

	errno = get_win32_error();
	return -1;
}

/*
 * getsockname() wrapper that reports Winsock failures through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_getsockname(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
{
	int rc = getsockname(sockfd, addr, addrlen);

	if (rc != SOCKET_ERROR)
		return 0;

	errno = get_win32_error();
	return -1;
}

/*
 * listen() wrapper.
 *
 * Sockets recognised by win32_is_afunix() are delegated to
 * win32_afunix_listen(). Winsock failures are reported through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_listen(int sockfd, int backlog)
{
	int rc;

	if (win32_is_afunix(sockfd))
		return win32_afunix_listen(sockfd, backlog);

	rc = listen(sockfd, backlog);
	if (rc != SOCKET_ERROR)
		return 0;

	errno = get_win32_error();
	return -1;
}

/*
 * setsockopt() wrapper that reports Winsock failures through errno.
 *
 * Returns 0 on success, -1 on failure.
 */
int
win32_setsockopt(int sockfd, int level, int optname,
                 const void *optval, socklen_t optlen)
{
	if (setsockopt(sockfd, level, optname, optval, optlen) != SOCKET_ERROR)
		return 0;

	errno = get_win32_error();
	return -1;
}

/*
 * socket() wrapper.
 *
 * AF_UNIX requests are handed to win32_afunix_socket(). Everything else
 * goes to Winsock, with failures reported through errno.
 *
 * Returns the new socket, or -1 on failure.
 */
int
win32_socket(int domain, int type, int protocol)
{
	SOCKET sock;

	if (domain == AF_UNIX)
		return win32_afunix_socket(domain, type, protocol);

	sock = socket(domain, type, protocol);
	if (sock != INVALID_SOCKET)
		return sock;

	errno = get_win32_error();
	return -1;
}

/* One descriptor together with the WSAEventSelect() events to watch. */
struct fd_events {
	SOCKET fd;      /* descriptor taken from one of the select() sets */
	long events;    /* mask of FD_* network event bits */
};
typedef struct fd_events fd_events;

typedef struct fd_event_set {
	u_int fd_count;
	fd_events fd_array[128];
} fd_event_set;

/*
 * Merge readfds and writefds into one list of descriptors.
 *
 * Each entry is tagged with the WSAEventSelect() network events matching
 * the set(s) it came from:
 *   - read candidates watch FD_READ | FD_ACCEPT | FD_CLOSE;
 *   - write candidates add FD_WRITE | FD_CONNECT.
 *
 * A descriptor present in both sets shares a single entry. exceptfds is
 * accepted for symmetry but is not consulted.
 */
static void
build_fd_event(fd_event_set *events, fd_set *readfds, fd_set *writefds,
               fd_set *exceptfds)
{
	u_int r, w, k;
	fd_events *entry;

	(void)exceptfds;

	events->fd_count = 0;

	/* Every read candidate gets a fresh entry */
	for (r = 0; (readfds != NULL) && (r < readfds->fd_count); r++) {
		entry = &events->fd_array[events->fd_count];
		events->fd_count++;
		entry->fd = readfds->fd_array[r];
		entry->events = FD_READ | FD_ACCEPT | FD_CLOSE;
	}

	if (writefds == NULL)
		return;

	/* Write candidates reuse an existing entry when there is one */
	for (w = 0; w < writefds->fd_count; w++) {
		SOCKET fd = writefds->fd_array[w];

		entry = NULL;
		for (k = 0; k < events->fd_count; k++) {
			if (events->fd_array[k].fd == fd) {
				entry = &events->fd_array[k];
				break;
			}
		}

		if (entry == NULL) {
			entry = &events->fd_array[events->fd_count];
			events->fd_count++;
			entry->fd = fd;
			entry->events = 0;
		}

		entry->events |= FD_WRITE | FD_CONNECT;
	}
}

/*
 * Slow path for win32_select(): waits on a mix of sockets, console
 * handles and pipes, which Winsock's select() cannot handle on its own.
 *
 * Consoles and pipes are first stripped from the caller's sets. Any
 * remaining sockets are polled once with a zero-timeout select(). If
 * nothing is ready, every socket is bound to one WSA event object. The
 * function then waits on that event, plus the standard input handle
 * when readfds contained a console. Pipes cannot be waited on, so they
 * are polled: the interval starts at 100 ms and doubles until it is at
 * least one second.
 *
 * exceptfds only takes part in the initial select() poll. It is cleared
 * before the blocking wait and never reported afterwards.
 *
 * Returns the number of descriptors left set in the fd_sets, 0 on
 * timeout, or -1 with errno set on failure. It may also return 0 before
 * the timeout expires when win32_filter_console_events() reports no
 * interesting console input.
 */
static int
win32_wait(fd_set *readfds, fd_set *writefds,
           fd_set *exceptfds, struct timeval *timeout)
{
	int ret = 0;
	int saved_errno;
	u_int i;

	struct timeval no_timeout = {0, 0};

	fd_event_set events;
	WSAEVENT socket_event;

	int has_console, has_fifo;

	HANDLE handles[2];
	DWORD dwtimeout, pipe_timeout;

	has_console = 0;
	has_fifo = 0;

	/* Start by assembling the three lists into one */
	build_fd_event(&events, readfds, writefds, exceptfds);

	/*
	 * Now we need to check for pre-existing states, which only
	 * the good ol' select() can do. It only understands sockets,
	 * so consoles and pipes are removed from the sets first.
	 * FD_CLR() compacts the array, hence the i-- to revisit the
	 * slot that was just refilled.
	 */
	if (readfds != NULL) {
		for (i = 0;i < readfds->fd_count;i++) {
			if (isatty(readfds->fd_array[i]))
				has_console = 1;
			else if (isafifo(readfds->fd_array[i]))
				has_fifo = 1;
			else
				continue;
			FD_CLR(readfds->fd_array[i], readfds);
			i--;
		}
	}
	if (writefds != NULL) {
		for (i = 0;i < writefds->fd_count;i++) {
			if (isatty(writefds->fd_array[i]))
				; /* checked in win32_select() */
			else if (isafifo(writefds->fd_array[i]))
				has_fifo = 1;
			else
				continue;
			FD_CLR(writefds->fd_array[i], writefds);
			i--;
		}
	}
	if (exceptfds != NULL) {
		for (i = 0;i < exceptfds->fd_count;i++) {
			if (isatty(exceptfds->fd_array[i]) ||
			    isafifo(exceptfds->fd_array[i])) {
				FD_CLR(exceptfds->fd_array[i], exceptfds);
				i--;
			}
		}
	}

	/*
	 * Winsock's select() fails with WSAEINVAL unless it is given at
	 * least one socket. That is exactly the situation when the caller
	 * only asked about consoles and pipes, so skip the poll then.
	 */
	if (((readfds != NULL) && (readfds->fd_count > 0)) ||
	    ((writefds != NULL) && (writefds->fd_count > 0)) ||
	    ((exceptfds != NULL) && (exceptfds->fd_count > 0))) {
		ret = select(0, readfds, writefds, exceptfds, &no_timeout);
		if (ret > 0)
			return ret;
		if (ret == SOCKET_ERROR) {
			errno = get_win32_error();
			return -1;
		}
	}

	/* Don't need these anymore */
	if (readfds != NULL)
		FD_ZERO(readfds);
	if (writefds != NULL)
		FD_ZERO(writefds);
	if (exceptfds != NULL)
		FD_ZERO(exceptfds);

	/* Connect an event to all sockets */
	socket_event = WSACreateEvent();
	if (socket_event == WSA_INVALID_EVENT) {
		errno = get_win32_error();
		return -1;
	}

	for (i = 0;i < events.fd_count;i++) {
		if (isatty(events.fd_array[i].fd) ||
		    isafifo(events.fd_array[i].fd))
			continue;

		ret = WSAEventSelect(events.fd_array[i].fd, socket_event,
		                     events.fd_array[i].events);
		if (ret == SOCKET_ERROR) {
			errno = get_win32_error();
			ret = -1;
			goto end;
		}
	}

	handles[0] = socket_event;
	handles[1] = GetStdHandle(STD_INPUT_HANDLE);

	if (timeout == NULL) {
		dwtimeout = INFINITE;
	} else {
		unsigned long long msec;

		/*
		 * Compute in 64 bits so that large timeouts cannot overflow.
		 * Stay below INFINITE so that a finite timeout never turns
		 * into an unbounded wait.
		 */
		msec = (unsigned long long)timeout->tv_sec * 1000 +
		       timeout->tv_usec / 1000;
		if (msec >= INFINITE)
			msec = INFINITE - 1;
		dwtimeout = (DWORD)msec;
	}

	pipe_timeout = 100;

	/*
	 * It is impossible to get events for pipes, so we have to resort
	 * to polling if we have one in the list. We start by polling
	 * rapidly and gradually back off to at least a 1 second delay.
	 * This is done to quickly respond to interactive periods, yet
	 * still sleep properly during idle times.
	 */
	do {
		DWORD iter_timeout;

		iter_timeout = dwtimeout;

		if (has_fifo) {
			for (i = 0;i < events.fd_count;i++) {
				HANDLE handle;

				if (!isafifo(events.fd_array[i].fd))
					continue;

				handle = (HANDLE)_get_osfhandle(events.fd_array[i].fd);

				if (events.fd_array[i].events & FD_READ) {
					DWORD avail;

					if (!PeekNamedPipe(handle, NULL, 0, NULL, &avail, NULL)) {
						errno = get_win32_error();
						ret = -1;
						goto end;
					}

					if (avail > 0)
						FD_SET(events.fd_array[i].fd, readfds);
				}

				if (events.fd_array[i].events & FD_WRITE) {
					IO_STATUS_BLOCK iosb;
					FILE_PIPE_LOCAL_INFORMATION fpli;
					NTSTATUS status;

					/*
					 * See comments in win32-misc.c about
					 * this call.
					 *
					 * The NTSTATUS is kept apart from ret so
					 * that the -1 below is what the caller
					 * actually receives.
					 */
					status = NtQueryInformationFile(handle, &iosb,
					                                &fpli, sizeof(fpli),
					                                FilePipeLocalInformation);
					if (status != STATUS_SUCCESS) {
						SetLastError(ERROR_INVALID_PARAMETER);
						errno = get_win32_error();
						ret = -1;
						goto end;
					}

					if (fpli.WriteQuotaAvailable > 0)
						FD_SET(events.fd_array[i].fd, writefds);
				}
			}

			ret = 0;
			if (readfds != NULL)
				ret += readfds->fd_count;
			if (writefds != NULL)
				ret += writefds->fd_count;
			if (ret > 0)
				goto end;

			/* Limit this round's wait to the current poll interval */
			if (iter_timeout > pipe_timeout)
				iter_timeout = pipe_timeout;
			if (dwtimeout != INFINITE)
				dwtimeout -= iter_timeout;
			if (pipe_timeout < 1000)
				pipe_timeout *= 2;
		}

		/* Now do the actual wait...*/
		ret = WaitForMultipleObjects(has_console ? 2 : 1, handles,
		                             FALSE, iter_timeout);
	} while (has_fifo && (ret == WAIT_TIMEOUT) && (dwtimeout > 0));

	if (ret == WAIT_FAILED) {
		errno = get_win32_error();
		ret = -1;
		goto end;
	}

	if (ret == WAIT_TIMEOUT) {
		ret = 0;
		goto end;
	}

	/* We got something on stdin... */
	if (ret == (WAIT_OBJECT_0 + 1)) {
		/*
		 * stdin will signal on a lot of non-data events, so
		 * we need to filter out that crud.
		 */
		ret = win32_filter_console_events(GetStdHandle(STD_INPUT_HANDLE));
		if (ret <= 0)
			goto end;

		for (i = 0;i < events.fd_count;i++) {
			if (!isatty(events.fd_array[i].fd))
				continue;

			if (!(events.fd_array[i].events & FD_READ))
				continue;

			FD_SET(events.fd_array[i].fd, readfds);
		}

		ret = readfds->fd_count;
		goto end;
	}

	/* We got something on a socket. Figure out what... */
	for (i = 0;i < events.fd_count;i++) {
		WSANETWORKEVENTS net_events;

		if (isatty(events.fd_array[i].fd) ||
		    isafifo(events.fd_array[i].fd))
			continue;

		ret = WSAEnumNetworkEvents(events.fd_array[i].fd,
		                           socket_event, &net_events);
		if (ret == SOCKET_ERROR) {
			errno = get_win32_error();
			ret = -1;
			goto end;
		}

		if (readfds != NULL) {
			if (net_events.lNetworkEvents & events.fd_array[i].events &
			    (FD_READ | FD_ACCEPT | FD_CLOSE))
				FD_SET(events.fd_array[i].fd, readfds);
		}

		if (writefds != NULL) {
			if (net_events.lNetworkEvents & events.fd_array[i].events &
			    (FD_WRITE | FD_CONNECT))
				FD_SET(events.fd_array[i].fd, writefds);
		}
	}

	/* Sum up the number of signalled sockets */
	ret = 0;
	if (readfds != NULL)
		ret += readfds->fd_count;
	if (writefds != NULL)
		ret += writefds->fd_count;
	if (exceptfds != NULL)
		ret += exceptfds->fd_count;

end:
	/*
	 * The cleanup below calls functions that may modify errno.
	 * Preserve the value set above for the caller.
	 */
	saved_errno = errno;

	/*
	 * Undocumented misfeature of the week: having an event
	 * connected to a socket makes accept() return WSAENOTSOCK.
	 * Make sure we clear the event from all sockets before
	 * returning.
	 */
	for (i = 0;i < events.fd_count;i++) {
		int local_ret;

		if (isatty(events.fd_array[i].fd) ||
		    isafifo(events.fd_array[i].fd))
			continue;

		local_ret = WSAEventSelect(events.fd_array[i].fd, NULL, 0);
		if (local_ret == SOCKET_ERROR) {
			/* SOCKET is wider than int on Win64; cast for %d */
			error("WSAEventSelect(%d, NULL, 0): %s",
			      (int)events.fd_array[i].fd,
			      win32_strerror(get_win32_error()));
		}
	}

	CloseHandle(socket_event);

	errno = saved_errno;

	return ret;
}


/*
 * select() replacement that accepts consoles and pipes in addition to
 * sockets.
 *
 * The cases are handled in this order:
 *   - Consoles in writefds are assumed always writable. If any is
 *     present, all sets are cleared and only those consoles are
 *     reported.
 *   - With no descriptors at all, the call just sleeps for the timeout,
 *     as POSIX code that uses select() as a sleep expects. Winsock's
 *     select() would reject this case.
 *   - Sets that contain only sockets go straight to Winsock's select().
 *   - Anything else is handed to win32_wait().
 *
 * Returns the number of ready descriptors, 0 on timeout, or -1 with
 * errno set on failure.
 */
int
win32_select(int nfds, fd_set *readfds, fd_set *writefds,
             fd_set *exceptfds, struct timeval *timeout)
{
	int ret, only_sockets;
	u_int i;

	/*
	 * We have no good way of checking if a console handle is writeable,
	 * so we assume they always are and let the blocking handle it
	 */
	if (writefds != NULL) {
		for (i = 0;i < writefds->fd_count;i++) {
			if (isatty(writefds->fd_array[i]))
				break;
		}

		if (i != writefds->fd_count) {
			fd_set newwr;

			FD_ZERO(&newwr);

			/* Collect every console from the first one onwards */
			for (;i < writefds->fd_count;i++) {
				if (isatty(writefds->fd_array[i]))
					FD_SET(writefds->fd_array[i], &newwr);
			}

			if (readfds != NULL)
				FD_ZERO(readfds);
			if (exceptfds != NULL)
				FD_ZERO(exceptfds);

			*writefds = newwr;

			return writefds->fd_count;
		}
	}

	/*
	 * Winsock's select() refuses to run without at least one socket.
	 * POSIX code commonly calls select() with no descriptors as a
	 * portable sleep, so emulate that directly.
	 */
	if (((readfds == NULL) || (readfds->fd_count == 0)) &&
	    ((writefds == NULL) || (writefds->fd_count == 0)) &&
	    ((exceptfds == NULL) || (exceptfds->fd_count == 0))) {
		DWORD ms;

		if (timeout == NULL) {
			ms = INFINITE;
		} else {
			unsigned long long total;

			if ((timeout->tv_sec < 0) || (timeout->tv_usec < 0)) {
				errno = EINVAL;
				return -1;
			}

			/* Stay below INFINITE so a finite sleep stays finite */
			total = (unsigned long long)timeout->tv_sec * 1000 +
			        (unsigned long long)timeout->tv_usec / 1000;
			if (total >= INFINITE)
				total = INFINITE - 1;
			ms = (DWORD)total;
		}

		Sleep(ms);
		return 0;
	}

	/*
	 * Windows only allows sockets to select(), so if we have
	 * anything else we have to fall back to more complicated
	 * checks. All three sets are inspected, since win32_wait()
	 * also strips consoles and pipes from exceptfds.
	 */
	only_sockets = 1;
	if (readfds != NULL) {
		for (i = 0;i < readfds->fd_count;i++) {
			if (isatty(readfds->fd_array[i]) ||
			    isafifo(readfds->fd_array[i])) {
				only_sockets = 0;
				break;
			}
		}
	}
	if ((writefds != NULL) && only_sockets) {
		for (i = 0;i < writefds->fd_count;i++) {
			if (isatty(writefds->fd_array[i]) ||
			    isafifo(writefds->fd_array[i])) {
				only_sockets = 0;
				break;
			}
		}
	}
	if ((exceptfds != NULL) && only_sockets) {
		for (i = 0;i < exceptfds->fd_count;i++) {
			if (isatty(exceptfds->fd_array[i]) ||
			    isafifo(exceptfds->fd_array[i])) {
				only_sockets = 0;
				break;
			}
		}
	}

	if (only_sockets) {
		ret = select(nfds, readfds, writefds, exceptfds, timeout);
		if (ret == SOCKET_ERROR)
			errno = get_win32_error();

		return ret;
	}

	return win32_wait(readfds, writefds, exceptfds, timeout);
}

/*
 * pselect() built on top of win32_select().
 *
 * The timespec timeout is converted to a timeval; nanoseconds are
 * truncated to whole microseconds. The signal mask is ignored.
 */
int
win32_pselect(int nfds, fd_set *readfds, fd_set *writefds,
              fd_set *exceptfds, const struct timespec *timeout,
              const sigset_t *mask)
{
	struct timeval tv;

	/*
	 * Windows doesn't really have useful signals, so let's just ignore
	 * everything during the select call
	 */
	(void)mask;

	if (timeout == NULL)
		return win32_select(nfds, readfds, writefds, exceptfds, NULL);

	tv.tv_sec = timeout->tv_sec;
	tv.tv_usec = timeout->tv_nsec / 1000;

	return win32_select(nfds, readfds, writefds, exceptfds, &tv);
}

#endif /* WIN32 */
