/* main.c - main program for the fast-dm project
 *
 * Copyright (C) 2006  Jochen Voss, Andreas Voss.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301 USA.
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <time.h>
#include <assert.h>

#include "fast-dm.h"


/**********************************************************************
 * computation of the KS test statistics (simultaneously for many z)
 */

/* Table of KS test statistics on an equidistant grid of z-values.
 *
 * The grid consists of N cells, i.e. N+1 grid points reaching from
 * 'zmin' to 'zmax'; T[i] holds the statistic at the i-th grid point.
 */
struct KS_stat {
	int  N;			/* number of grid cells */
	double *T;		/* T-values, one per grid point (length N+1) */
	double  zmin;		/* z-value belonging to T[0] */
	double  zmax;		/* z-value belonging to T[N] */
};

static struct KS_stat *
new_KS_stat (int N, double zmin, double zmax)
/* Allocate a 'struct KS_stat' for a grid of N cells from zmin to zmax.
 *
 * The T array gets room for N+1 values; this function does not write
 * to its entries.  The caller owns the returned structure and releases
 * it with 'delete_KS_stat'.
 */
{
	struct KS_stat *ks = xnew(struct KS_stat, 1);

	ks->T = xnew(double, N+1);
	ks->N = N;
	ks->zmin = zmin;
	ks->zmax = zmax;

	return  ks;
}

static void
delete_KS_stat (struct KS_stat *result)
/* Release the T array owned by 'result', then the structure itself.  */
{
	double *values = result->T;

	xfree (values);
	xfree (result);
}

static double
KS_get_z (const struct KS_stat *result, double z)
/* Use linear interpolation to get the T-value for a given z.
 *
 * The grid cell containing 'z' is located and the T-values at its two
 * end points are blended linearly.  The cell index is clamped to
 * 0..N-1 on both sides, so a 'z' outside zmin..zmax (e.g. because of
 * rounding) is extrapolated from the first or last cell instead of
 * indexing outside the T array.
 *
 * If the table has no cells (N <= 0) there is nothing to interpolate
 * and T[0] is returned.
 */
{
	int  N = result->N;
	double  dz, ratio;
	int  step_before;

	/* A single grid point: avoid the division by zero below.  */
	if (N <= 0)  return  result->T[0];

	dz = (result->zmax - result->zmin) / N;
	/* Truncation towards zero selects the cell to the left of z.  */
	step_before = (int)((z - result->zmin) / dz);
	if (step_before >= N)  step_before = N-1;
	if (step_before < 0)  step_before = 0;
	ratio = (z - (result->zmin + dz*step_before)) / dz;
	return  (1-ratio) * result->T[step_before]
		+ ratio * result->T[step_before+1];
}

static void
KS_fold_step (double *T, const double *F, int N,
	      double p_before, double p_after)
/* Fold one jump of the empirical distribution function into T.
 *
 * 'p_before' and 'p_after' are the empirical probabilities on the two
 * sides of the jump, F[0..N] the theoretical values at the same time.
 * Each T[i] is raised to the larger of the two distances to F[i] if
 * that exceeds its current value.
 */
{
	int  i;

	for (i=0; i<=N; ++i) {
		double  gap;

		gap = fabs (p_before - F[i]);
		if (gap > T[i])  T[i] = gap;
		gap = fabs (p_after - F[i]);
		if (gap > T[i])  T[i] = gap;
	}
}

static struct KS_stat *
KS_stat_get (const struct samples *samples, const double *para)
/* Compute the KS test statistic (simultaneously for many z)
 *
 * The statistic measures the difference between the empirical
 * distribution given by 'samples' and the theoretical distribution
 * described by the parameter set 'para'.  The empirical probability
 * starts at p_mid = minus_used/(plus_used+minus_used); every entry of
 * 'plus_data' moves it up by one step 'dp', every entry of
 * 'minus_data' moves it down by one step.
 *
 * The function returns a 'struct KS_stat' which the caller must free
 * after use (see 'delete_KS_stat').
 */
{
	struct F_calculator *fc;
	struct KS_stat *ks;
	double  dp, p_mid;
	int  N, idx;

	dp = 1.0/(samples->plus_used+samples->minus_used);
	p_mid = samples->minus_used*dp;

	fc = F_new (para);
	N = F_get_N (fc);
	ks = new_KS_stat (N, F_get_z (fc, 0), F_get_z (fc, N));

	/* start from "no difference seen yet" at every grid point */
	for (idx=0; idx<=N; ++idx)  ks->T[idx] = 0;

	/* upper boundary: walk upwards from p_mid */
	F_start (fc, b_upper);
	for (idx=0; idx<samples->plus_used; ++idx) {
		const double *F = F_get_F (fc, samples->plus_data[idx]);

		KS_fold_step (ks->T, F, N,
			      p_mid+idx*dp, p_mid+(idx+1)*dp);
	}

	/* lower boundary: walk downwards from p_mid */
	F_start (fc, b_lower);
	for (idx=0; idx<samples->minus_used; ++idx) {
		const double *F = F_get_F (fc, samples->minus_data[idx]);

		KS_fold_step (ks->T, F, N,
			      p_mid-idx*dp, p_mid-(idx+1)*dp);
	}

	F_delete (fc);
	return  ks;
}

/**********************************************************************
 * Minimise the KS statistic
 */

static double
check_bounds(const double *para)
/* Check whether the parameters 'para' are valid.
 *
 * In case of invalid parameters, a value >1 is returned, the
 * magnitude gives the 'badness' of the violation.  If the parameter
 * set is valid, 0 is returned.
 */
{
	double  penalty = 0;
	int  bad = 0;

	if (para[p_sz] < 0) {
		bad = 1;
		penalty += -para[p_sz];
	}
	if (para[p_st0] < 0) {
		bad = 1;
		penalty += -para[p_st0];
	}
	if (para[p_sv] < 0) {
		bad = 1;
		penalty += -para[p_sv];
	}

	if (para[p_a] < para[p_sz]) {
		bad = 1;
		penalty += para[p_sz] - para[p_a];
	}
	if (para[p_t0] < 0.5*para[p_st0]) {
		bad = 1;
		penalty += 0.5*para[p_st0] - para[p_t0];
	}

	/* avoid problems caused by rounding errors */
	return  bad ? 1+penalty : 0.0;
}

static double
binom(int a, int b)
/* Compute the binomial coefficient a over b.
 *
 * 'a' is expected to be non-negative.  For 0 <= b <= a the
 * coefficient is returned as a double; if 'b' lies outside this range
 * the result is 0.
 */
{
	double  result = 1;
	int  i;

	if (b < 0 || b > a)  return  0;

	/* Use the symmetry C(a,b) = C(a,a-b) to keep the loop short.
	 * Comparing against a-b avoids computing 2*b, which could
	 * overflow.  */
	if (b > a-b)  b = a-b;
	for (i=1; i<=b; ++i) {
		result *= a-i+1;
		result /= i;
	}
	return  result;
}

static double
T_to_p(double T, int n)
/* Convert KS T-values into standard p-values.
 *
 * 'T' is the value of the test statistic, 'n' the number of samples.
 *
 * Return (1 - G(t)^2) where G(t) is given by (4), p.431 in Conover.
 * With s = 1 - G(t) being the sum computed below, this equals
 * (2 - s)*s.
 *
 * For T <= 0 the result is 1: this is the limit of the formula for
 * T -> 0, whereas evaluating the sum directly would compute
 * pow(0,-1)*0, which is not a number.
 */
{
	double  sum = 0;
	int  i;
	int  max;

	if (T <= 0)  return  1;

	max = n - n*T;
	for (i=0; i<=max; ++i) {
		double  a, b, c;
		a = binom(n,i);
		b = pow(1-T-i/(double)n, (n-i));
		c = pow(T+i/(double)n, (i-1));
		sum += a*b*c;
	}
	sum *= T;
	return  (2 - sum)*sum;
}

static void
find_best_log_p (struct KS_stat *const*result, const struct dataset *ds,
		 const int *z_used, double *dist_ret, double *ret_z)
/* Combine the T-values from 'result' into a common p-value.
 *
 * This computes p-values as the product of the probabilities for all
 * Ts from result and optimises over 'z'.  This takes the different
 * experimental conditions into account: for every z-parameter k
 * (0 <= k < ds->z->used) the samples j with z_used[j] == k form one
 * group.  Within a group, z is optimised over the intersection of the
 * z-ranges of the group members only.
 *
 * On entry 'ret_z' must point to an array of size 'ds->z->used'.
 *
 * The distance is returned in '*dist_ret'.  If the members of some
 * group have no z-value in common, a value >1 is returned; in this
 * case ret_z[k] of the failing group is set to the middle of the gap
 * and the entries for later groups are not written.  Otherwise one
 * minus the p-value is returned and the z-parameter values
 * corresponding to this (minimal) distance are returned in 'ret_z'.
 */
{
	double  total_log;
	double  penalty;
	int  k;

	total_log = 0;
	penalty = 0;
	for (k=0; k<ds->z->used; k++) {
		double  dz;
		double  best_logp, best_z;
		double  zmin, zmax;
		int  have_range = 0;
		int  i, j, n;

		/* Intersect the z-ranges of the members of group k.
		 * Samples which belong to other groups must not
		 * restrict the range.  */
		zmin = zmax = 0;
		for (j=0; j<ds->samples_used; ++j) {
			if (z_used[j]!=k) continue;
			if (! have_range) {
				zmin = result[j]->zmin;
				zmax = result[j]->zmax;
				have_range = 1;
				continue;
			}
			if (zmin < result[j]->zmin) zmin = result[j]->zmin;
			if (zmax > result[j]->zmax) zmax = result[j]->zmax;
		}
		if (! have_range) {
			/* No sample uses this z: fall back to the
			 * range of the first sample so that the
			 * search below stays well-defined.  */
			zmin = result[0]->zmin;
			zmax = result[0]->zmax;
		}
		if (zmax < zmin) {
			/* no common z is available: abort the
			 * computation and return a penalty value >1.  */
			ret_z[k] = (zmax+zmin)/2.0;
			penalty = 1 + (zmin-zmax);
			break;
		}

		/* Use sub-sampling of the interval zmin..zmax
		 * to find the best z.  */
		n = (zmax - zmin) / 0.0001 + 1.5;
		dz = (zmax - zmin) / n;
		best_logp = - DBL_MAX;
		/* Defined fallback in case no grid point improves on
		 * the initial 'best_logp' (e.g. all p-values are 0).  */
		best_z = zmin;
		for (i=0; i<=n; ++i) {
			double  z = zmin + i*dz;
			double  logp = 0;

			for (j=0; j<ds->samples_used; ++j) {
				double	T;
				int  N;

				if (z_used[j] != k)  continue;
				if (result[j]->N > 0) {
					T = KS_get_z (result[j], z);
				} else { /* no point in interpolating */
					T = result[j]->T[0];
				}
				N = ds->samples[j]->plus_used
					+ ds->samples[j]->minus_used;
				logp += log(T_to_p(T,N));
			}
			if (logp > best_logp) {
				best_logp = logp;
				best_z = z;
			}
		}
		total_log += best_logp;
		ret_z[k] = best_z;
	}

	if (penalty == 0)
		*dist_ret = 1 - exp(total_log);
	else
		*dist_ret = penalty;
}

static double
find_fixed_log_p (struct KS_stat *const*result, const struct dataset *ds,
		  double z)
/* Combine the T-values from 'result' into a common p-value for a
 * given, fixed 'z'.
 *
 * The admissible z-range is the intersection of the ranges of all
 * samples.  If 'z' lies outside this range, one plus the distance to
 * the range is returned (a value >1).  Otherwise the p-values of all
 * samples at 'z' are multiplied (by summing their logarithms) and one
 * minus this product is returned.
 */
{
	double  lo = result[0]->zmin;
	double  hi = result[0]->zmax;
	double  log_sum = 0;
	int  s;

	for (s=1; s<ds->samples_used; ++s) {
		if (result[s]->zmin > lo)  lo = result[s]->zmin;
		if (result[s]->zmax < hi)  hi = result[s]->zmax;
	}
	if (z < lo)  return  1 + (lo-z);
	if (z > hi)  return  1 + (z-hi);

	for (s=0; s<ds->samples_used; ++s) {
		int  count = ds->samples[s]->plus_used
			+ ds->samples[s]->minus_used;
		double  T;

		/* a table with a single grid point cannot be interpolated */
		T = (result[s]->N > 0)
			? KS_get_z (result[s], z)
			: result[s]->T[0];
		log_sum += log(T_to_p(T, count));
	}
	return  1 - exp(log_sum);
}

static void
badness (const struct dataset *ds, const double *x,
	 double *dist_ret, double *z_ret)
/* Get the 'distance' between theoretical and target distribution.
 *
 * The target distribution is described by the dataset 'ds', the
 * theoretical distribution is described by the parameters 'x'.  The
 * correspondence between entries of 'x' and the model parameters is
 * encoded in the 'ds->cmds' field: the commands are executed in
 * order, 'c_copy_param' and 'c_copy_const' fill in the current
 * parameter set (a negative 'arg1' refers to the z parameter) and
 * 'c_run' evaluates sample 'arg1' with the current parameter set.
 *
 * If the parameter 'z' is being optimised, 'z_ret' must on entry
 * point to an array of size 'ds->z->used'.  If the parameter 'z' is
 * fixed, the value 'z_ret' must be 'NULL'.
 *
 * The distance between empirical and theoretically predicted
 * distribution function is returned in '*dist_ret'.  If the
 * parameters are invalid, a value >1 is returned and 'z_ret' is not
 * written.  Otherwise one minus the p-value is returned.  If 'z_ret'
 * is non-null, the z-parameter values corresponding to this (minimal)
 * distance are returned in '*z_ret'.
 */
{
	struct KS_stat **result;
	int *z_used;
	double  para[p_count], para_z, penalty;
	int  i, z_idx;

	result = xnew(struct KS_stat *, ds->samples_used);
	for (i=0; i<ds->samples_used; ++i) result[i] = NULL;
	z_used = xnew(int, ds->samples_used);

	/* Give both z-related variables a defined value: only one of
	 * them is set by the commands (depending on whether z is
	 * optimised or fixed), but 'z_idx' is copied on every
	 * 'c_run'.  */
	z_idx = 0;
	para_z = 0;

	penalty = 0;
	for (i=0; i<ds->cmds_used; ++i) {
		int  arg1 = ds->cmds[i].arg1;
		int  arg2 = ds->cmds[i].arg2;
		switch (ds->cmds[i].cmd) {
		case c_copy_param:
			if (arg1 >= 0) {
				para[arg1] = x[arg2];
			} else {
				assert(z_ret);
				z_idx = arg2;
			}
			break;
		case c_copy_const:
			if (arg1 >= 0) {
				para[arg1] = ds->consts[arg2];
			} else {
				assert(! z_ret);
				para_z = ds->consts[arg2];
			}
			break;
		case c_run:
			penalty = check_bounds (para);
			if (penalty > 0) break;
			result[arg1] = KS_stat_get (ds->samples[arg1], para);
			z_used[arg1] = z_idx;
			break;
		default:
			break;
		}
		/* stop at the first invalid parameter set */
		if (penalty>0) break;
	}

	if (penalty>0) {
		assert (penalty > 1);
		*dist_ret = penalty;
	} else if (z_ret) {
		find_best_log_p(result, ds, z_used, dist_ret, z_ret);
	} else {
		*dist_ret = find_fixed_log_p(result, ds, para_z*para[p_a]);
	}

	xfree(z_used);
	/* after an early abort some entries are still NULL */
	for (i=0; i<ds->samples_used; ++i) {
		if (result[i]) delete_KS_stat (result[i]);
	}
	xfree(result);
}

static double
minimiser (const double *x, void *data)
/* Objective function for 'simplex': evaluate 'badness' at 'x'.
 *
 * 'data' is the 'struct dataset' being fitted.  If the dataset has
 * z parameters, a scratch array for them is allocated for the
 * duration of the call; only the distance is passed back.
 */
{
	struct dataset *ds = data;
	double *scratch_z = NULL;
	double  distance;

	if (ds->z->used > 0)  scratch_z = xnew(double, ds->z->used);
	badness (ds, x, &distance, scratch_z);
	xfree(scratch_z);

	return  distance;
}

static void
initialise_parameters (const struct dataset *ds, double *x, double *eps)
/* Fill in start values 'x' and initial step sizes 'eps' for the
 * parameter search.
 *
 * The command list 'ds->cmds' is traversed from its end to its
 * beginning.  A 'c_run' command hands the sample to 'EZ_par' together
 * with pointers to the current start values of a, v and t0
 * (presumably so that 'EZ_par' replaces them by estimates for this
 * sample -- see its definition).  A 'c_copy_param' command for a
 * model parameter (arg1 >= 0) copies the current start value and step
 * size of that parameter into slot 'arg2' of 'x' and 'eps'.  All
 * other commands are ignored.
 */
{
	double  start [p_count], step [p_count];
	int  pos;

	start[p_a] = 1;		step[p_a] = 0.5;
	start[p_v] = 0;		step[p_v] = 1;
	start[p_t0] = 0.3;	step[p_t0] = 0.5;
	start[p_sz] = 0.2;	step[p_sz] = 0.2;
	start[p_sv] = 0.2;	step[p_sv] = 0.2;
	start[p_st0] = 0.2;	step[p_st0] = 0.2;

	pos = ds->cmds_used;
	while (--pos >= 0) {
		int  arg1 = ds->cmds[pos].arg1;
		int  arg2 = ds->cmds[pos].arg2;

		if (ds->cmds[pos].cmd == c_run) {
			EZ_par (ds->samples[arg1],
				start+p_a, start+p_v, start+p_t0);
		} else if (ds->cmds[pos].cmd == c_copy_param && arg1 >= 0) {
			x[arg2] = start[arg1];
			eps[arg2] = step[arg1];
		}
	}
}

int
main (int argc, char **argv)
/* Program entry point.
 *
 * The optional single command line argument names the experiment
 * control file; without it "experiment.ctl" is used.  For every
 * dataset of the experiment the parameters are fitted by three
 * 'simplex' runs with decreasing tolerance argument, the result is
 * printed, optionally written to the dataset's log file and handed to
 * 'experiment_save'.
 *
 * Returns 0 on success; exits with status 1 if there are too many
 * arguments or the control file cannot be loaded.  A log file which
 * cannot be written is reported on stderr but does not stop the
 * program.
 */
{
	static const double  precision[] = { 0.05, 0.01, 0.001 };
	const int  n_precision = sizeof(precision) / sizeof(precision[0]);
	const char *ex_name = "experiment.ctl";
	struct experiment *ex;
	struct dataset *ds;
	int  dsn, N;
	clock_t	 start, stop, start_total, stop_total;
	double	cpu_time_used;

	if (argc == 2) {
		ex_name = argv[1];
	} else if (argc > 2) {
		fprintf (stderr, "too many arguments\n");
		exit (1);
	}
	ex = new_experiment (ex_name);
	if (! ex) {
		fprintf (stderr,
			 "failed to load experiment control file \"%s\"\n",
			 ex_name);
		exit (1);
	}
	experiment_print (ex);

	N=0;
	start_total = clock ();
	/* dataset numbers 0..999 are probed; missing ones are skipped */
	for (dsn=0; dsn<1000; ++dsn) {
		double  *x, *eps, dist, *z;
		int  i;

		ds = experiment_get_dataset (ex, dsn);
		if (! ds) continue;

		N++;
		dataset_print (ds);

		start = clock ();
		x = xnew(double, ds->param->used);
		eps = xnew(double, ds->param->used);
		initialise_parameters (ds, x, eps);
		/* each run continues from the 'x' left by the previous one */
		for (i=0; i<n_precision; ++i) {
			dist = simplex (ds->param->used, x, eps,
					precision[i], ds, minimiser);
			printf ("  ... p = %g\n", 1-dist);
		}
		stop = clock ();
		cpu_time_used = (double)(stop - start) / CLOCKS_PER_SEC;

		/* evaluate once more at the optimum to obtain the
		 * z-values belonging to it */
		if (ds->z->used > 0)
			z = xnew(double, ds->z->used);
		else
			z = NULL;
		badness (ds, x, &dist, z);
		for (i=0; i<ds->z->used; ++i) {
			printf ("  -> %s = %f\n", ds->z->entry[i], z[i]);
		}

		for (i=0; i<ds->param->used; ++i) {
			printf ("  -> %s = %f\n", ds->param->entry[i], x[i]);
		}

		if (ds->logname) {
			FILE *fd = fopen (ds->logname, "w");

			if (! fd) {
				/* do not lose the fit because of the log */
				perror (ds->logname);
			} else {
				for (i=0; i<ds->z->used; ++i)
					fprintf (fd, "%s = %f\n",
						 ds->z->entry[i], z[i]);

				for (i=0; i<ds->param->used; ++i)
					fprintf (fd, "%s = %f\n",
						 ds->param->entry[i], x[i]);

				fprintf (fd, "p = %f\n", 1-dist);
				fprintf (fd, "time = %f\n", cpu_time_used);
				/* fclose flushes; a failure means the
				 * log is incomplete */
				if (fclose (fd) != 0)
					perror (ds->logname);
			}
		}

		experiment_save(ex, ds, x, dsn+1, z, 1-dist, cpu_time_used);

		xfree(x);
		xfree(z);
		xfree(eps);

		delete_dataset(ds);
	}

	stop_total = clock ();
	cpu_time_used = (double)(stop_total - start_total) / CLOCKS_PER_SEC;
	printf ("%d dataset%s processed, total CPU time used: %fs\n",
		N, N==1?"":"s", cpu_time_used);

	delete_experiment (ex);

	return  0;
}

/*
 * Local Variables:
 * c-file-style: "linux"
 * End:
 */
