#include <assert.h>
#include <math.h>
#include <stdlib.h>
#include <strings.h>

#include <Rmath.h>

#include "cgeneric.h"

/*
 * Typed wrapper around malloc(): requests room for n_ objects of type type_ and
 * casts the result to (type_ *). The memory is not initialised and the result is
 * NULL when malloc() fails; callers in this file check it with assert().
 */
#define Malloc(n_, type_) (type_ *) malloc((n_) * sizeof(type_))

/*
 * Fallback definitions for the math constants used below. M_PI and M_SQRT2 are
 * not part of ISO C, so <math.h> may not provide them in strict modes.
 */
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#ifndef M_SQRT2
#define M_SQRT2 1.41421356237309504880
#endif

/*
 * Kernel identifiers stored in MMDCache_tp.kernel_id. The value 0 is deliberately
 * unused: kernel_id_from_name() returns 0 for an unrecognised kernel name.
 */
enum {
	KERNEL_GAUSSIAN = 1,
	KERNEL_LAPLACE = 2,
	KERNEL_CAUCHY = 3
};

typedef struct {
	int kernel_id;
	double bandwidth;
	double bandwidth2;
	double lambda;
	double sigma_init;
	int gh_n;
	double *gh_nodes;
	double *gh_weights;
} MMDCache_tp;

/* Returns x multiplied by itself. */
static double sqr(double x)
{
	const double squared = x * x;

	return squared;
}

/*
 * Case-insensitive lookup among the double-valued data vectors.
 *
 * Walks data->doubles[0 .. n_doubles-1] in order and returns the first entry whose
 * name equals `name` (compared with strcasecmp). Returns NULL if nothing matches.
 */
static const inla_cgeneric_vec_tp *find_double_vec(inla_cgeneric_data_tp *data, const char *name)
{
	const inla_cgeneric_vec_tp *found = NULL;

	for (int idx = 0; idx < data->n_doubles && found == NULL; idx++) {
		if (strcasecmp(data->doubles[idx]->name, name) == 0) {
			found = data->doubles[idx];
		}
	}

	return found;
}

/*
 * Case-insensitive lookup among the character-valued data vectors.
 *
 * Scans data->chars[0 .. n_chars-1] from the front; the first entry whose name
 * matches `name` under strcasecmp is returned, or NULL when there is none.
 */
static const inla_cgeneric_vec_tp *find_char_vec(inla_cgeneric_data_tp *data, const char *name)
{
	int idx = 0;

	while (idx < data->n_chars) {
		if (strcasecmp(data->chars[idx]->name, name) == 0) {
			return data->chars[idx];
		}
		idx++;
	}

	return NULL;
}

/*
 * Case-insensitive lookup among the integer-valued data vectors.
 *
 * Returns the first element of data->ints[0 .. n_ints-1] whose name matches `name`
 * (strcasecmp), or NULL if no element matches.
 */
static const inla_cgeneric_vec_tp *find_int_vec(inla_cgeneric_data_tp *data, const char *name)
{
	for (int idx = 0; idx < data->n_ints; idx++) {
		if (strcasecmp(data->ints[idx]->name, name) != 0) {
			continue;
		}
		return data->ints[idx];
	}

	return NULL;
}

/*
 * Translate a kernel name into its KERNEL_* identifier.
 *
 * The comparison is case-insensitive. Recognised names are "Gaussian", "Laplace"
 * and "Cauchy"; any other string yields 0.
 */
static int kernel_id_from_name(const char *kernel)
{
	static const struct {
		const char *label;
		int id;
	} known[] = {
		{ "Gaussian", KERNEL_GAUSSIAN },
		{ "Laplace", KERNEL_LAPLACE },
		{ "Cauchy", KERNEL_CAUCHY },
	};
	const size_t n_known = sizeof(known) / sizeof(known[0]);

	for (size_t k = 0; k < n_known; k++) {
		if (strcasecmp(kernel, known[k].label) == 0) {
			return known[k].id;
		}
	}

	return 0;
}

static void jacobi_eigen(int n, double *A, double *values, double *vectors)
{
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			vectors[i * n + j] = (i == j ? 1.0 : 0.0);
		}
	}

	int max_iter = 50 * n * n;
	for (int iter = 0; iter < max_iter; iter++) {
		int p = 0, q = 1;
		double max_off = 0.0;

		for (int i = 0; i < n; i++) {
			for (int j = i + 1; j < n; j++) {
				double off = fabs(A[i * n + j]);
				if (off > max_off) {
					max_off = off;
					p = i;
					q = j;
				}
			}
		}

		if (max_off < 1e-14) {
			break;
		}

		double app = A[p * n + p];
		double aqq = A[q * n + q];
		double apq = A[p * n + q];
		double phi = 0.5 * atan2(2.0 * apq, aqq - app);
		double c = cos(phi);
		double s = sin(phi);

		for (int k = 0; k < n; k++) {
			if (k != p && k != q) {
				double aik = A[k * n + p];
				double akq = A[k * n + q];
				double new_kp = c * aik - s * akq;
				double new_kq = s * aik + c * akq;
				A[k * n + p] = new_kp;
				A[p * n + k] = new_kp;
				A[k * n + q] = new_kq;
				A[q * n + k] = new_kq;
			}
		}

		A[p * n + p] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
		A[q * n + q] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
		A[p * n + q] = 0.0;
		A[q * n + p] = 0.0;

		for (int k = 0; k < n; k++) {
			double vip = vectors[k * n + p];
			double viq = vectors[k * n + q];
			vectors[k * n + p] = c * vip - s * viq;
			vectors[k * n + q] = s * vip + c * viq;
		}
	}

	for (int i = 0; i < n; i++) {
		values[i] = A[i * n + i];
	}
}

/*
 * Build an n-point Gauss-Hermite quadrature rule.
 *
 * The nodes are the eigenvalues of the symmetric tridiagonal matrix with zero
 * diagonal and off-diagonal entries sqrt(k / 2), k = 1 .. n-1. Each weight is
 * sqrt(pi) times the squared first component of the matching eigenvector.
 * For n == 1 the rule is the single node 0 with weight sqrt(pi).
 * Nodes are returned in ascending order with their weights kept alongside.
 *
 *   n            number of nodes; asserted to be positive.
 *   nodes_out    receives a malloc'ed array of n nodes.
 *   weights_out  receives a malloc'ed array of n weights.
 *
 * Both output arrays are owned by the caller, who releases them with free().
 */
static void build_gauss_hermite_rule(int n, double **nodes_out, double **weights_out)
{
	assert(n > 0);

	double *nodes = Malloc(n, double);
	double *weights = Malloc(n, double);
	assert(nodes != NULL);
	assert(weights != NULL);

	if (n == 1) {
		nodes[0] = 0.0;
		weights[0] = sqrt(M_PI);
	} else {
		const size_t n_entries = (size_t) n * (size_t) n;
		double *jacobi = Malloc(n_entries, double);
		double *eigvals = Malloc(n, double);
		double *eigvecs = Malloc(n_entries, double);
		assert(jacobi != NULL);
		assert(eigvals != NULL);
		assert(eigvecs != NULL);

		/* Zero matrix with sqrt(k / 2) on the first sub- and super-diagonal. */
		for (size_t e = 0; e < n_entries; e++) {
			jacobi[e] = 0.0;
		}
		for (int k = 1; k < n; k++) {
			const double beta = sqrt(k / 2.0);
			jacobi[(k - 1) * n + k] = beta;
			jacobi[k * n + (k - 1)] = beta;
		}

		jacobi_eigen(n, jacobi, eigvals, eigvecs);

		/* eigvecs[j] is row 0, column j: the first component of eigenvector j. */
		const double root_pi = sqrt(M_PI);
		for (int j = 0; j < n; j++) {
			nodes[j] = eigvals[j];
			weights[j] = root_pi * eigvecs[j] * eigvecs[j];
		}

		/* Selection sort on the nodes, moving each weight with its node. */
		for (int lo = 0; lo + 1 < n; lo++) {
			int smallest = lo;
			for (int j = lo + 1; j < n; j++) {
				if (nodes[j] < nodes[smallest]) {
					smallest = j;
				}
			}
			if (smallest == lo) {
				continue;
			}

			const double held_node = nodes[lo];
			const double held_weight = weights[lo];
			nodes[lo] = nodes[smallest];
			weights[lo] = weights[smallest];
			nodes[smallest] = held_node;
			weights[smallest] = held_weight;
		}

		free(jacobi);
		free(eigvals);
		free(eigvecs);
	}

	*nodes_out = nodes;
	*weights_out = weights;
}

/*
 * Return the MMDCache_tp stored in data->cache, creating it on first use.
 *
 * When data->cache is empty, the check is repeated inside a named OpenMP critical
 * section (if compiled with OpenMP) and the cache is built from these data vectors:
 *   chars   "kernel"            kernel name, mapped by kernel_id_from_name()
 *   doubles "bdwth"             bandwidth (> 0)
 *   doubles "lambda"            lambda (> 0)
 *   doubles "sigma_init"        initial sigma (> 0)
 *   ints    "quadrature_order"  number of Gauss-Hermite nodes (> 0)
 * Missing vectors and out-of-range values are rejected with assert(). The
 * Gauss-Hermite rule is built only when the kernel is Cauchy.
 */
static MMDCache_tp *get_cache(inla_cgeneric_data_tp *data)
{
	if (data->cache) {
		return *((MMDCache_tp **) (&data->cache));
	}

#ifdef _OPENMP
#pragma omp critical (MMDCLOG_LIK_CACHE)
#endif
	{
		/* Another thread may have installed the cache while this one waited. */
		if (!(data->cache)) {
			const inla_cgeneric_vec_tp *v_kernel = find_char_vec(data, "kernel");
			const inla_cgeneric_vec_tp *v_bdwth = find_double_vec(data, "bdwth");
			const inla_cgeneric_vec_tp *v_lambda = find_double_vec(data, "lambda");
			const inla_cgeneric_vec_tp *v_sigma0 = find_double_vec(data, "sigma_init");
			const inla_cgeneric_vec_tp *v_order = find_int_vec(data, "quadrature_order");

			assert(v_kernel != NULL);
			assert(v_bdwth != NULL && v_bdwth->len >= 1);
			assert(v_lambda != NULL && v_lambda->len >= 1);
			assert(v_sigma0 != NULL && v_sigma0->len >= 1);
			assert(v_order != NULL && v_order->len >= 1);

			MMDCache_tp *fresh = Malloc(1, MMDCache_tp);
			assert(fresh != NULL);

			fresh->kernel_id = kernel_id_from_name(v_kernel->chars);
			fresh->bandwidth = v_bdwth->doubles[0];
			fresh->bandwidth2 = sqr(fresh->bandwidth);
			fresh->lambda = v_lambda->doubles[0];
			fresh->sigma_init = v_sigma0->doubles[0];
			fresh->gh_n = v_order->ints[0];
			fresh->gh_nodes = NULL;
			fresh->gh_weights = NULL;

			assert(fresh->kernel_id != 0);
			assert(fresh->bandwidth > 0.0);
			assert(fresh->lambda > 0.0);
			assert(fresh->sigma_init > 0.0);
			assert(fresh->gh_n > 0);

			/* Only the Cauchy kernel is integrated numerically. */
			if (fresh->kernel_id == KERNEL_CAUCHY) {
				build_gauss_hermite_rule(fresh->gh_n, &fresh->gh_nodes, &fresh->gh_weights);
			}

			*((MMDCache_tp **) (&data->cache)) = fresh;
		}
	}

	return *((MMDCache_tp **) (&data->cache));
}

/*
 * Returns exp(-(y - mu)^2 / v) / sqrt(v) with v = bandwidth2 + 2 * sigma^2.
 * NOTE(review): this appears to be E[exp(-(y - X)^2 / bandwidth2)] for
 * X ~ N(mu, sigma^2), divided by the bandwidth (hence "scaled") - confirm.
 */
static double gaussian_m_scaled(double y, double mu, double sigma, double bandwidth2)
{
	const double spread = bandwidth2 + 2.0 * sqr(sigma);
	const double residual = y - mu;

	return exp(-sqr(residual) / spread) / sqrt(spread);
}

/*
 * Returns 1 / sqrt(bandwidth2 + 4 * sigma^2): the y-independent term that
 * pseudo_loglik() subtracts for the Gaussian kernel.
 */
static double gaussian_c_scaled(double sigma, double bandwidth2)
{
	const double spread = bandwidth2 + 4.0 * sqr(sigma);

	return 1.0 / sqrt(spread);
}

/*
 * Computes exp(log_scale) * Phi(z) on the log scale, where Phi is the standard
 * normal lower-tail CDF obtained from pnorm() with log.p = 1.
 * If the log-probability is not finite, the product is taken to be 0.
 */
static double exp_log_pnorm(double log_scale, double z)
{
	const double log_phi = pnorm(z, 0.0, 1.0, 1, 1);

	return isfinite(log_phi) ? exp(log_scale + log_phi) : 0.0;
}

/*
 * Laplace-kernel term depending on the observation y.
 *
 * With a = 1 / bandwidth and d = mu - y this returns
 *   exp(a^2 sigma^2 / 2 - a d) * Phi((d - a sigma^2) / sigma)
 * + exp(a^2 sigma^2 / 2 + a d) * Phi((-d - a sigma^2) / sigma),
 * each product evaluated on the log scale by exp_log_pnorm().
 * For sigma <= 1e-12 the kernel value exp(-|d| * a) is returned directly.
 */
static double laplace_m(double y, double mu, double sigma, double bandwidth)
{
	const double shift = mu - y;
	const double rate = 1.0 / bandwidth;
	double total;

	if (sigma <= 1e-12) {
		total = exp(-fabs(shift) * rate);
	} else {
		const double var = sqr(sigma);
		const double log_common = 0.5 * rate * rate * var;
		const double upper = exp_log_pnorm(log_common - rate * shift, (shift - rate * var) / sigma);
		const double lower = exp_log_pnorm(log_common + rate * shift, (-shift - rate * var) / sigma);

		total = upper + lower;
	}

	return total;
}

/*
 * Laplace-kernel term that does not depend on the observation.
 *
 * With a = 1 / bandwidth this returns 2 * exp(a^2 sigma^2) * Phi(-sqrt(2) a sigma),
 * assembled on the log scale. For sigma <= 1e-12 it returns 1.
 */
static double laplace_c(double sigma, double bandwidth)
{
	if (sigma <= 1e-12) {
		return 1.0;
	}

	const double rate = 1.0 / bandwidth;
	const double log_two = log(2.0);
	const double quad = rate * rate * sqr(sigma);
	const double log_tail = pnorm(-M_SQRT2 * rate * sigma, 0.0, 1.0, 1, 1);

	return exp(log_two + quad + log_tail);
}

/* Cauchy-type kernel in a bandwidth-scaled difference u: 1 / (2 + u^2). */
static double cauchy_kernel(double scaled_diff)
{
	const double denom = 2.0 + sqr(scaled_diff);

	return 1.0 / denom;
}

/*
 * Gauss-Hermite approximation of E[cauchy_kernel((mu + sigma * Z) / bandwidth)]
 * for Z ~ N(0, 1).
 *
 * Every cached node x_j is mapped to z_j = sqrt(2) * x_j, the kernel values are
 * summed with the cached weights, and the sum is divided by sqrt(pi).
 * The cache must hold gh_n nodes and weights.
 */
static double cauchy_expect_standard_normal(const MMDCache_tp *cache, double mu, double sigma)
{
	const double *node = cache->gh_nodes;
	const double *weight = cache->gh_weights;
	double total = 0.0;
	int remaining = cache->gh_n;

	while (remaining > 0) {
		const double z = M_SQRT2 * (*node);
		const double kval = cauchy_kernel((mu + sigma * z) / cache->bandwidth);

		total += (*weight) * kval;
		node++;
		weight++;
		remaining--;
	}

	return total / sqrt(M_PI);
}

/*
 * Cauchy-kernel term depending on the observation y: the quadrature expectation
 * of the kernel at (mu - y) + sigma * Z. For sigma <= 1e-12 the kernel is
 * evaluated directly at (mu - y) / bandwidth.
 */
static double cauchy_m(const MMDCache_tp *cache, double y, double mu, double sigma)
{
	const double shift = mu - y;

	return (sigma <= 1e-12)
	    ? cauchy_kernel(shift / cache->bandwidth)
	    : cauchy_expect_standard_normal(cache, shift, sigma);
}

/*
 * Cauchy-kernel term that does not depend on the observation: the quadrature
 * expectation of the kernel at sqrt(2) * sigma * Z with Z ~ N(0, 1).
 *
 *   cache  supplies the bandwidth and the Gauss-Hermite rule.
 *   sigma  standard deviation of the model.
 *
 * For sigma <= 1e-12 the quadrature is skipped and the kernel value at zero
 * difference is returned. That value is cauchy_kernel(0) = 1/2, not 1: the
 * kernel is 1 / (2 + u^2), and the quadrature branch converges to the same 1/2
 * as sigma -> 0, so the result stays continuous in sigma.
 */
static double cauchy_c(const MMDCache_tp *cache, double sigma)
{
	if (sigma <= 1e-12) {
		return cauchy_kernel(0.0);
	}
	return cauchy_expect_standard_normal(cache, 0.0, M_SQRT2 * sigma);
}

/*
 * Pseudo log-likelihood contribution 2 * m - c for one observation, where m is
 * the y-dependent kernel term and c the y-independent one for the kernel
 * selected in cache->kernel_id.
 *
 * An unknown kernel id trips an assertion; with assertions disabled NAN is returned.
 */
static double pseudo_loglik(const MMDCache_tp *cache, double y, double mu, double sigma)
{
	const int kind = cache->kernel_id;
	double fit_term;
	double norm_term;

	if (kind == KERNEL_GAUSSIAN) {
		fit_term = gaussian_m_scaled(y, mu, sigma, cache->bandwidth2);
		norm_term = gaussian_c_scaled(sigma, cache->bandwidth2);
	} else if (kind == KERNEL_LAPLACE) {
		fit_term = laplace_m(y, mu, sigma, cache->bandwidth);
		norm_term = laplace_c(sigma, cache->bandwidth);
	} else if (kind == KERNEL_CAUCHY) {
		fit_term = cauchy_m(cache, y, mu, sigma);
		norm_term = cauchy_c(cache, sigma);
	} else {
		assert(0 == 1);
		return NAN;
	}

	return 2.0 * fit_term - norm_term;
}

/*
 * cloglike entry point for the MMD pseudo-likelihood with one hyperparameter
 * theta[0] = log(sigma).
 *
 * Behaviour per command:
 *   INLA_CLOGLIKE_INITIAL    returns a malloc'ed {1.0, log(sigma_init)}
 *                            (presumably the number of hyperparameters followed
 *                            by the initial value - confirm against the interface).
 *   INLA_CLOGLIKE_LOG_PRIOR  returns a malloc'ed single value
 *                            log(lambda) - lambda * sigma + log(sigma).
 *   INLA_CLOGLIKE_LOGLIKE    fills result[i] = pseudo_loglik(y[0], x[i], sigma)
 *                            for i < nx; returns NULL.
 *   INLA_CLOGLIKE_CDF        fills result[i] = 0.5 for i < nx; returns NULL.
 *   INLA_CLOGLIKE_QUIT       frees the cache held in data->cache and clears it;
 *                            returns NULL.
 *
 * When theta is NULL, log(sigma_init) is used in its place. Arrays returned from
 * this function are allocated with malloc(); who frees them is decided by the
 * caller's interface and is not visible here.
 */
double *inla_cloglike_mmd_theta_tilde_gaussian(inla_cloglike_cmd_tp cmd, double *theta,
					       inla_cgeneric_data_tp *data, int ny, double *y, int nx, double *x, double *result)
{
	double *ret = NULL;

	switch (cmd) {
	case INLA_CLOGLIKE_INITIAL:
	{
		const MMDCache_tp *settings = get_cache(data);

		ret = Malloc(2, double);
		assert(ret != NULL);
		ret[0] = 1.0;
		ret[1] = log(settings->sigma_init);
		break;
	}

	case INLA_CLOGLIKE_LOG_PRIOR:
	{
		const MMDCache_tp *settings = get_cache(data);
		const double log_sigma = (theta ? theta[0] : log(settings->sigma_init));
		const double sd = exp(log_sigma);

		ret = Malloc(1, double);
		assert(ret != NULL);
		/* The trailing log_sigma is the Jacobian of sigma = exp(theta). */
		ret[0] = log(settings->lambda) - settings->lambda * sd + log_sigma;
		break;
	}

	case INLA_CLOGLIKE_LOGLIKE:
	{
		assert(ny >= 1);
		const MMDCache_tp *settings = get_cache(data);
		const double log_sigma = (theta ? theta[0] : log(settings->sigma_init));
		const double sd = exp(log_sigma);
#ifdef _OPENMP
#pragma omp simd
#endif
		for (int k = 0; k < nx; k++) {
			result[k] = pseudo_loglik(settings, y[0], x[k], sd);
		}
		break;
	}

	case INLA_CLOGLIKE_CDF:
	{
#ifdef _OPENMP
#pragma omp simd
#endif
		for (int k = 0; k < nx; k++) {
			result[k] = 0.5;
		}
		break;
	}

	case INLA_CLOGLIKE_QUIT:
	{
		if (data && data->cache) {
			MMDCache_tp *stale = *((MMDCache_tp **) (&data->cache));

			/* free(NULL) is a no-op, so unset quadrature arrays need no guard. */
			free(stale->gh_nodes);
			free(stale->gh_weights);
			free(stale);
			*((void **) (&data->cache)) = NULL;
		}
		break;
	}

	default:
		break;
	}

	return ret;
}
