/*
*	rp,dynlib,examp.c
*
*	Example of user-defined 3d penalty function.
*	Please do not distribute this.
*	You are free to modify as long as you change the file name.
*	And please keep my name attached.
*
*	To compile this as a dynamic library on Solaris, use
*		cc -c rp,dynlib,examp.c
*		ld -G -o example.so.1 rp,dynlib,examp.o
*	You may need to use
*		setenv LD_LIBRARY_PATH
*	to include the directory where you store the dynamic library.
*
*	As a sanity check, run "nm example.so.1"; you should see
*	rp_init, rp_grad, etc. listed as symbols defined in the library.
*
*	Copyright Jan 1998,	Jeff Fessler	University of Michigan
*/
#include <stdio.h>
#include <math.h>
#include "rp,dynlib,proto.h"

/* Square of x.  Note that x is evaluated twice. */
#ifndef Sqr
#define	Sqr(x)	((x)*(x))
#endif

/*
*	Quadratic-potential helpers used by the gradient and Newton code.
*	Rd_quad(d,e) assigns the neighbor difference e to d and supplies its
*	own trailing ';', so call sites write it WITHOUT a semicolon.
*	Rg_quad(d) returns d, the derivative of the quadratic potential d^2/2
*	that r3_penalty_quad_1st_scalar() sums.
*/
#define Rd_quad(d,e)	d = (e);
#define Rg_quad(d)	(d)


/* element type of the binary support mask */
typedef unsigned char byte;

	/* penalty weights; rp_init() sets them to 2^log2betaxy, 2^log2betaz */
	static double betaxy=-1, betaz=0;
	/* volume size recorded by rp_init(); rp_newton1() reads these */
	static int nx_global=-1, ny_global=-1, nz_global=-1;

/*
*	rp_init()
*	Initialize the static local variables of this module,
*	using parameters extracted from :userinfo.
*	In this example userinfo is "log2betaxy,log2betaz",
*	e.g. "3,1" yields betaxy = 8 and betaz = 2.
*	The image size is always recorded (rp_newton1 needs it);
*	the betas are replaced only when both numbers parse.
*	Returns 1 on success, 0 if userinfo is absent or malformed.
*/
int rp_init(
const void *DoNotTouch1,
const void *DoNotTouch2,
const int nx,
const int ny,
const int nz,
const int chat)
{
	const char *userinfo;
	double log2bxy, log2bz;

	(void) DoNotTouch2;

	nx_global = nx;
	ny_global = ny;
	nz_global = nz;

	/*
	*	DoNotTouch1 points at the userinfo string pointer.
	*	Although it says do not touch, you may do this for now :-)
	*	Guard both levels so a missing string fails cleanly
	*	instead of dereferencing NULL.
	*/
	userinfo = DoNotTouch1 ? *((const char *const *) DoNotTouch1) : NULL;
	if (!userinfo) {
		printf("Fail: no userinfo string\n");
		return 0;
	}

	if (chat)
		printf("userinfo = '%s'\n", userinfo);

	/*
	*	extract log2 betas from userinfo string; parse into locals so
	*	a partial match cannot leave the globals half-updated
	*/
	if (2 != sscanf(userinfo, "%lg,%lg", &log2bxy, &log2bz)) {
		printf("Fail: no log2betaxy,log2betaz in '%s'\n", userinfo);
		return 0;
	}
	if (chat)
		printf("log2betaxy=%g log2betaz=%g\n", log2bxy, log2bz);

	betaxy = pow(2., log2bxy);
	betaz = pow(2., log2bz);

	return 1;
}


/*
*	rp_free()
*	Release module-owned memory.  This example never allocates,
*	so the call only reports success (1).
*/
int rp_free(
const void *DoNotTouch)
{
	const int status = 1;	/* nothing to release */

	(void) DoNotTouch;	/* opaque handle is not needed here */
	return status;
}


/*
*	r3_quad_diff_sum()
*	Over one [nx,ny] slice starting at `plane`, for rows iy0..ny-1 and
*	columns ix0..nx-1, sum half the squared difference between each pixel
*	and the pixel `offset` elements earlier in memory.  Each row's sum is
*	halved before it is added to the running total.
*/
static double r3_quad_diff_sum(
const float *plane,	/* first pixel of the current slice */
const int nx, const int ny,
const int ix0, const int iy0,
const int offset)	/* 1: left, nx: above, nx*ny: previous slice */
{
	double total = 0.;
	int row, col;

	for (row = iy0; row < ny; ++row) {
		const float *line = plane + row * nx;
		double rowsum = 0;

		for (col = ix0; col < nx; ++col)
			rowsum += Sqr(line[col] - line[col - offset]);
		total += rowsum / 2;
	}
	return total;
}


/*
*	r3_penalty_quad_1st_scalar()
*	Standard 1st-order 3d quadratic roughness penalty
*		betaxy * \sum_j \sum_{k in {j-1, j-nx}} (xj - xk)^2 / 2
*		+ betaz * \sum_j (xj - x_{j-nxy})^2 / 2
*	counting only neighbor pairs that lie inside the [nx,ny,nz] volume.
*/
static double r3_penalty_quad_1st_scalar(
const float *image,		/* [nx,ny,nz]	*/
const int nx, const int ny, const int nz)
{
	const int nxy = nx * ny;
	double penal = 0.;
	int slice;
	printf("in r3_penalty_quad_1st_scalar %g %g %d %d %d\n", betaxy, betaz, nx, ny, nz);

	/* in-plane terms for each slice: left neighbors, then above neighbors */
	for (slice = 0; slice < nz; ++slice) {
		const float *plane = image + slice * nxy;

		penal += betaxy * r3_quad_diff_sum(plane, nx, ny, 1, 0, 1);
		penal += betaxy * r3_quad_diff_sum(plane, nx, ny, 0, 1, nx);
	}

	/* through-plane terms: every slice after the first vs. its predecessor */
	for (slice = 1; slice < nz; ++slice)
		penal += betaz * r3_quad_diff_sum(image + slice * nxy,
					nx, ny, 0, 0, nxy);

	return penal;
}


/*
*	rp_penalty()
*	Entry point: value of the roughness penalty for the whole volume.
*	This example ignores both the opaque handle and the support mask.
*/
double rp_penalty(
const float *image,             /* [nx,ny,nz] image volume      */
const void *DoNotTouch,
const byte *mask,      /* [nx,ny,nz] binary support mask */
const int nx,                   /* image dimensions */
const int ny,
const int nz)
{
	double value;

	(void) mask;
	(void) DoNotTouch;

	printf("in rp_penalty %g %g %d %d %d\n", betaxy, betaz, nx, ny, nz);
	value = r3_penalty_quad_1st_scalar(image, nx, ny, nz);
	return value;
}


/*
*	r3_grad_quad_1st_scalar()
*	Gradient of roughness penalty: \gradient R(x)
*	For quadratic penalties, this is simply R * image.
*	Every element of rx is written.  When a mask is given, voxels
*	outside it get 0; every other voxel j gets
*		betaxy * \sum_{in-plane nbrs k} (xj - xk)
*		+ betaz * \sum_{adjacent-slice nbrs k} (xj - xk)
*	using only neighbors that lie inside the volume.  This is the
*	gradient of the penalty computed by r3_penalty_quad_1st_scalar(),
*	and it treats edge voxels the same way rp_newton1() does.
*	mask may be NULL (no support constraint).
*/
static void r3_grad_quad_1st_scalar(
float	*rx,		/* [nx,ny,nz]	*/
const float *image,	/* [nx,ny,nz]	*/
const byte *mask,	/* [nx,ny,nz]	*/
const int nx, const int ny, const int nz)
{
	int ix, iy, iz;
	const int nxy = nx * ny;

	for (iz=0; iz < nz; ++iz) {
		float		*ppr = rx	+ iz * nxy;
		const float	*ppi = image	+ iz * nxy;
	    for (iy=0; iy < ny; ++iy) {
		for (ix=0; ix < nx; ++ix) {
			const int ip = ix + iy * nx;
			float		*pr = ppr + ip;
			const float	*pi = ppi + ip;
			double d, sumg = 0;

			/* index the mask only when present: offsetting a
			   NULL pointer is undefined behavior */
			if (mask && !mask[iz * nxy + ip]) {
				*pr = 0;
				continue;
			}

			/* in-plane neighbors that exist */
			if (ix > 0)	{ Rd_quad(d, *pi-pi[-1])	sumg += Rg_quad(d); }
			if (ix < nx-1)	{ Rd_quad(d, *pi-pi[ 1])	sumg += Rg_quad(d); }
			if (iy > 0)	{ Rd_quad(d, *pi-pi[-nx])	sumg += Rg_quad(d); }
			if (iy < ny-1)	{ Rd_quad(d, *pi-pi[ nx])	sumg += Rg_quad(d); }

			*pr = betaxy * sumg;

			/* neighbors in the adjacent slices */
			if (iz > 0) {
				Rd_quad(d, *pi-pi[-nxy])
				*pr += betaz * Rg_quad(d);
			}
			if (iz < nz-1) {
				Rd_quad(d, *pi-pi[nxy])
				*pr += betaz * Rg_quad(d);
			}
		}
	    }
	}
}


/*
*	rp_grad()
*	Entry point: writes the penalty gradient for the whole volume into rx.
*	The opaque handle is unused; always reports success (1).
*/
int rp_grad(
float	*rx,		/* [nx,ny,nz] output gradient vector */
const float *image,	/* [nx,ny,nz] input current image */
const void *DoNotTouch,
const byte *mask,	/* [nx,ny,nz] binary support mask */
const int nx,		/* image dimensions */
const int ny,
const int nz)
{
	const int status = 1;

	(void) DoNotTouch;	/* not needed by this example */
	printf("in rp_grad\n");
	r3_grad_quad_1st_scalar(rx, image, mask, nx, ny, nz);
	return status;
}


/*
*	rp_newton1()
*	For the single voxel j = (ix,iy,iz), return
*		(der1 - dR/dx_j) / (ner2 + depierro * d2R/dx_j^2)
*	where R is this file's quadratic penalty.  Both the first derivative
*	and the curvature count only neighbors inside the [nx,ny,nz] volume
*	recorded by rp_init(), so edge voxels are handled consistently.
*	NOTE(review): der1 / ner2 are presumably the caller's data-fit first
*	derivative and negated second derivative (a 1D Newton step) -- confirm.
*/
extern double rp_newton1(
const double der1,		/* d1 */
const double ner2,		/* n2 */
const float *image,		/* [nx,ny,nz] image */
const int ix,			/* which pixel */
const int iy,
const int iz,
const void *DoNotTouch,
const double depierro)		/* depierro factor */
{
	const int nx = nx_global;
	const int ny = ny_global;
	const int nz = nz_global;
	const int nxy = nx * ny;
	const int ip = ix + iy * nx + iz * nxy;
	const float *pi	= image+ip;
	const double	bxy	= betaxy;
	const double	bz	= betaz;
	/* number of neighbors actually inside the volume; the curvature
	   must use the same neighbor set as the derivative sums below */
	const int	nbr_xy	= (ix > 0) + (ix < nx-1) + (iy > 0) + (iy < ny-1);
	const int	nbr_z	= (iz > 0) + (iz < nz-1);
	const double	denom	= ner2 + depierro * (nbr_xy * bxy + bz * nbr_z);
	double d, sumx=0, sumz=0;
	(void) DoNotTouch;

	if (ix > 0)	{ Rd_quad(d,*pi-pi[-1])		sumx += Rg_quad(d); }
	if (ix < nx-1)	{ Rd_quad(d,*pi-pi[ 1])		sumx += Rg_quad(d); }
	if (iy > 0)	{ Rd_quad(d,*pi-pi[-nx])	sumx += Rg_quad(d); }
	if (iy < ny-1)	{ Rd_quad(d,*pi-pi[ nx])	sumx += Rg_quad(d); }
	if (iz > 0)	{ Rd_quad(d,*pi-pi[-nxy])	sumz += Rg_quad(d); }
	if (iz < nz-1)	{ Rd_quad(d,*pi-pi[ nxy])	sumz += Rg_quad(d); }

	return (der1 - bxy * sumx - bz * sumz) / denom;
}

/* retire this file's private helper macros */
#undef Sqr
#undef Rg_quad
#undef Rd_quad
