/*
 * Copyright (C) 2011 Martin Willi
 *
 * Copyright (C) secunet Security Networks AG
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 */

#include "whitelist_control.h"

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <errno.h>

#include <daemon.h>
#include <collections/linked_list.h>
#include <processing/jobs/callback_job.h>

#include "whitelist_msg.h"

typedef struct private_whitelist_control_t private_whitelist_control_t;

/**
 * Private data of a whitelist_control_t object.
 */
struct private_whitelist_control_t {

	/**
	 * Public whitelist_control_t interface.
	 *
	 * NOTE(review): presumably has to stay the first member, so a pointer to
	 * the public interface can be converted back to the private object;
	 * confirm against the METHOD() macro.
	 */
	whitelist_control_t public;

	/**
	 * Whitelist listener the received commands (add, remove, list, flush,
	 * enable, disable) get applied to. Not destroyed by this object.
	 */
	whitelist_listener_t *listener;

	/**
	 * Whitelist stream service accepting control clients, destroyed along
	 * with this object.
	 */
	stream_service_t *service;
};

/**
 * List whitelist entries matching an identity to a client.
 *
 * Works on a private copy: matching entries are cloned while enumerating the
 * listener, the enumerator is released, and only then the clones are written
 * to the stream, one WHITELIST_LIST message each. A WHITELIST_END message
 * with a zeroed identity is sent last, even if writing an entry failed.
 *
 * @param this		control instance providing the listener
 * @param stream	client stream the messages are written to
 * @param id		identity the listed entries have to match
 */
static void list(private_whitelist_control_t *this,
				 stream_t *stream, identification_t *id)
{
	whitelist_msg_t reply = {
		.type = htonl(WHITELIST_LIST),
	};
	linked_list_t *matches;
	enumerator_t *entries;
	identification_t *entry;

	/* collect clones of all matching entries first */
	matches = linked_list_create();
	entries = this->listener->create_enumerator(this->listener);
	while (entries->enumerate(entries, &entry))
	{
		if (!entry->matches(entry, id))
		{
			continue;
		}
		matches->insert_last(matches, entry->clone(entry));
	}
	entries->destroy(entries);

	/* send one message per clone, stop at the first failed write */
	for (;;)
	{
		if (matches->remove_first(matches, (void**)&entry) != SUCCESS)
		{
			break;
		}
		snprintf(reply.id, sizeof(reply.id), "%Y", entry);
		entry->destroy(entry);
		if (stream->write_all(stream, &reply, sizeof(reply)))
		{
			continue;
		}
		DBG1(DBG_CFG, "listing whitelist failed: %s", strerror(errno));
		break;
	}
	/* releases any clones left over after a failed write */
	matches->destroy_offset(matches, offsetof(identification_t, destroy));

	/* terminate the listing; the result of this write is not checked */
	memset(reply.id, 0, sizeof(reply.id));
	reply.type = htonl(WHITELIST_END);
	stream->write_all(stream, &reply, sizeof(reply));
}

/**
 * Information about a client connection.
 */
typedef struct {
	/** control instance the client connected to */
	private_whitelist_control_t *this;
	/** buffer the message currently being received is read into */
	whitelist_msg_t msg;
	/** number of bytes of msg received so far, reset after each message */
	size_t read;
} whitelist_conn_t;

/**
 * Information needed for async disconnect job.
 */
typedef struct {
	/** connection state, freed by the disconnect job */
	whitelist_conn_t *conn;
	/** client stream, destroyed by the disconnect job */
	stream_t *stream;
} disconnect_data_t;

/**
 * Asynchronous callback to disconnect client
 *
 * Destroys the client stream first and frees the connection state
 * afterwards. The disconnect_data_t itself is not released here; disconnect()
 * passes free() to callback_job_create() along with it, presumably as the
 * cleanup function for the job data (confirm against callback_job.h).
 *
 * @param data		stream and connection state to release
 * @return			JOB_REQUEUE_NONE, always
 */
CALLBACK(disconnect_async, job_requeue_t,
	disconnect_data_t *data)
{
	data->stream->destroy(data->stream);
	free(data->conn);
	return JOB_REQUEUE_NONE;
}

/**
 * Disconnect a connected client
 *
 * Nothing is released directly: the stream and the connection state are
 * handed over to a queued job running disconnect_async().
 *
 * @param conn		connection state to free
 * @param stream	client stream to destroy
 */
static void disconnect(whitelist_conn_t *conn, stream_t *stream)
{
	disconnect_data_t *ctx;
	job_t *job;

	INIT(ctx,
		.stream = stream,
		.conn = conn,
	);
	/* free() is handed over for ctx itself, disconnect_async() only releases
	 * what ctx refers to */
	job = (job_t*)callback_job_create(disconnect_async, ctx, free, NULL);
	lib->processor->queue_job(lib->processor, job);
}

/**
 * Dispatch a received message
 *
 * Stream read callback. Collects bytes into conn->msg until a complete
 * whitelist_msg_t has been received, applies the requested command to the
 * listener and starts over with the next message. A partially received
 * message is kept in conn across invocations.
 *
 * @param conn		state of the client connection
 * @param stream	client stream to read from
 * @return			TRUE if no more data is available right now, FALSE after
 *					scheduling a disconnect on EOF or a read error.
 *					NOTE(review): presumably TRUE keeps this callback
 *					registered and FALSE removes it; confirm in stream.h
 */
CALLBACK(on_read, bool,
	whitelist_conn_t *conn, stream_t *stream)
{
	private_whitelist_control_t *this = conn->this;
	identification_t *id;
	ssize_t len;

	while (TRUE)
	{
		while (conn->read < sizeof(conn->msg))
		{
			len = stream->read(stream, (char*)&conn->msg + conn->read,
							   sizeof(conn->msg) - conn->read, FALSE);
			if (len <= 0)
			{
				/* errno is valid only if the read failed. On EOF (len == 0)
				 * it may hold a stale EWOULDBLOCK from an earlier call,
				 * which must not keep a closed connection around. */
				if (len < 0 && errno == EWOULDBLOCK)
				{
					return TRUE;
				}
				if (len != 0)
				{
					DBG1(DBG_CFG, "whitelist socket error: %s", strerror(errno));
				}
				disconnect(conn, stream);
				return FALSE;
			}
			conn->read += len;
		}

		/* the identity comes from the client, enforce its termination
		 * before parsing it as a string */
		conn->msg.id[sizeof(conn->msg.id) - 1] = 0;
		id = identification_create_from_string(conn->msg.id);
		switch (ntohl(conn->msg.type))
		{
			case WHITELIST_ADD:
				this->listener->add(this->listener, id);
				break;
			case WHITELIST_REMOVE:
				this->listener->remove(this->listener, id);
				break;
			case WHITELIST_LIST:
				list(this, stream, id);
				break;
			case WHITELIST_FLUSH:
				this->listener->flush(this->listener, id);
				break;
			case WHITELIST_ENABLE:
				this->listener->set_active(this->listener, TRUE);
				break;
			case WHITELIST_DISABLE:
				this->listener->set_active(this->listener, FALSE);
				break;
			default:
				DBG1(DBG_CFG, "received unknown whitelist command");
				break;
		}
		id->destroy(id);
		/* message handled, start receiving the next one */
		conn->read = 0;
	}

	/* not reached, the loop above is only left by returning */
	return TRUE;
}

/**
 * Accept callback for new control client connections.
 *
 * Allocates the per-connection state and registers on_read() for the stream.
 * The state is freed by disconnect_async() once on_read() disconnects the
 * client.
 *
 * @param this		control instance registered in whitelist_control_create()
 * @param stream	stream of the newly connected client
 * @return			TRUE, always. NOTE(review): presumably tells the stream
 *					service to keep the stream open; confirm in
 *					stream_service.h
 */
CALLBACK(on_accept, bool,
	private_whitelist_control_t *this, stream_t *stream)
{
	whitelist_conn_t *conn;

	INIT(conn,
		.this = this,
	);
	stream->on_read(stream, on_read, conn);
	return TRUE;
}

/**
 * Implementation of whitelist_control_t.destroy: destroys the stream service
 * and frees the object. The listener is left untouched.
 */
METHOD(whitelist_control_t, destroy, void,
	private_whitelist_control_t *this)
{
	this->service->destroy(this->service);
	free(this);
}

/**
 * See header
 *
 * Creates the control object for the given listener and a stream service for
 * the URI configured as "<ns>.plugins.whitelist.socket", which defaults to
 * the UNIX socket WHITELIST_SOCKET. Client connections get dispatched to
 * on_accept().
 *
 * @param listener	listener the received commands are applied to
 * @return			control interface, NULL if creating the service failed
 */
whitelist_control_t *whitelist_control_create(whitelist_listener_t *listener)
{
	private_whitelist_control_t *this;
	char *socket_uri;

	INIT(this,
		.listener = listener,
		.public = {
			.destroy = _destroy,
		},
	);

	socket_uri = lib->settings->get_str(lib->settings,
				"%s.plugins.whitelist.socket", "unix://" WHITELIST_SOCKET,
				lib->ns);
	this->service = lib->streams->create_service(lib->streams, socket_uri, 10);
	if (this->service)
	{
		this->service->on_accept(this->service, (stream_service_cb_t)on_accept,
								 this, JOB_PRIO_CRITICAL, 0);
		return &this->public;
	}

	/* no service, nothing else to clean up but the object itself */
	DBG1(DBG_CFG, "creating whitelist socket failed");
	free(this);
	return NULL;
}
