/* usual C/C++ includes */
#include <stdio.h>
#include <stdlib.h>

/* PAM image support */
#include "pamalign.h"

/* Image handled by main(): filled by load_pam(), transposed on the
 * device, then written back out by save_pam(). */
imgInfo img;
/* NOTE(review): not referenced anywhere in the visible code. */
imgInfo ref;

/* Path of the input image passed to load_pam(). */
const char imgname[] = "image.pam";
/* Path of the output image passed to save_pam(). */
const char imgsave[] = "saved.pam";

/* OpenCL includes */
#define CL_USE_DEPRECATED_OPENCL_1_1_APIS
#include <CL/cl.h>

/*
 * Terminate the program when an OpenCL call did not succeed.
 *
 * error:   status code returned by the OpenCL call
 * message: short description of the attempted operation, appended to
 *          the diagnostic printed on stderr
 *
 * Returns normally only when error is CL_SUCCESS; otherwise prints the
 * numeric code with the message and exits with status 1.
 */
void check_ocl_error(cl_int error, const char *message) {
	if (error == CL_SUCCESS)
		return;

	fprintf(stderr, "error %d %s\n", error, message);
	exit(1);
}

/*
 * Read the whole file `fname` into a newly allocated, NUL-terminated
 * buffer.
 *
 * Returns the buffer on success; the caller owns it and must release it
 * with free(). Returns NULL, after printing a diagnostic on stderr, when
 * the file cannot be opened, its size cannot be determined, memory
 * cannot be allocated, or fewer bytes than expected were read.
 * On success the number of bytes read is reported on stdout.
 */
char *read_file(const char *fname) {
	size_t fsize, readsize;
	long fpos;
	char *buff;

	FILE *fd = fopen(fname, "rb");
	if (!fd) {
		fprintf(stderr, "%s not found\n", fname);
		return NULL;
	}

	/* ftell() reports failure as -1, which must not be turned into a
	 * (huge) size_t */
	if (fseek(fd, 0, SEEK_END) != 0 || (fpos = ftell(fd)) < 0) {
		fprintf(stderr, "could not determine size of %s\n", fname);
		fclose(fd);
		return NULL;
	}
	fsize = (size_t)fpos;

	/* one extra byte for the terminating NUL written below */
	buff = (char *)malloc(fsize + 1);
	if (!buff) {
		fprintf(stderr, "could not allocate %zu bytes for %s\n",
			fsize + 1, fname);
		fclose(fd);
		return NULL;
	}

	rewind(fd);
	readsize = fread(buff, 1, fsize, fd);
	/* the stream is not needed past this point, whatever the outcome */
	fclose(fd);

	if (fsize != readsize) {
		fprintf(stderr, "could only read %zu/%zu bytes from %s\n",
			readsize, fsize, fname);
		free(buff);
		return NULL;
	}
	buff[fsize] = '\0';

	printf("read %zu bytes from %s\n", fsize, fname);

	return buff;
}

int main(int argc, char **argv) {

	int pam_status;

	cl_uint platformCount = 0;
	cl_platform_id *platforms;

	clGetPlatformIDs(0, NULL, &platformCount);

	platforms = (cl_platform_id*)
		malloc(platformCount*sizeof(*platforms));

	clGetPlatformIDs(platformCount, platforms, NULL);

	cl_uint plat = 0;
	if (argc > 1) {
		plat = atoi(argv[1]);
		if (plat >= platformCount) {
			fprintf(stderr, "WAT? (%u > %u)\n",
				plat, platformCount);
			plat = 0;
		}
	}

#define MAX_NAME_LEN 1024
	char name[MAX_NAME_LEN];
	size_t name_size;
	clGetPlatformInfo(platforms[plat], CL_PLATFORM_NAME,
			MAX_NAME_LEN, name, &name_size);
	if (name_size >= MAX_NAME_LEN) {
		fprintf(stderr, "Looong name! %zu\n", name_size);
		name[MAX_NAME_LEN-1] = '\0';
	}
	printf("Platform %u: %s\n", plat, name);

	cl_uint deviceCount = 0;
	cl_device_id dev;
	clGetDeviceIDs(platforms[plat], CL_DEVICE_TYPE_ALL,
			1, &dev, &deviceCount);
	clGetDeviceInfo(dev, CL_DEVICE_NAME,
			MAX_NAME_LEN, name, &name_size);
	if (name_size >= MAX_NAME_LEN) {
		fprintf(stderr, "Looong name! %zu\n", name_size);
		name[MAX_NAME_LEN-1] = '\0';
	}
	printf("Device: %s\n", name);

	cl_context_properties props[] = {
		CL_CONTEXT_PLATFORM,
		(cl_context_properties)platforms[plat], 0
	};

	cl_int err;
	cl_context ctx = clCreateContext(props, 1, &dev,
			NULL, NULL, &err);
	check_ocl_error(err, "creating context");

	cl_command_queue coda = clCreateCommandQueue(
		ctx, dev, CL_QUEUE_PROFILING_ENABLE, &err);
	check_ocl_error(err, "creating queue");

	/* Try loading the image */
	if (pam_status = load_pam(imgname, &img)) {
		return pam_status;
	}

	printf("%u-channels image %s loaded, %ux%ux%u\n",
		img.channels, imgname, img.width, img.height, img.depth);

	cl_mem dSrc, dDst;

	cl_image_format fmt;
	fmt.image_channel_order = CL_RGBA;
	fmt.image_channel_data_type = CL_UNSIGNED_INT16;

	dSrc = clCreateImage2D(ctx,
			CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
			&fmt,
			img.width, img.height, 0,
			img.data, &err);
	check_ocl_error(err, "creating src buffer");

	dDst = clCreateImage2D(ctx,
			CL_MEM_WRITE_ONLY,
			&fmt,
			img.height, img.width,
			0, NULL, &err);
	check_ocl_error(err, "creating dst buffer");

	char *codice = read_file("transpose_kernels.ocl");

	if (codice == NULL)
		exit(1);

	cl_program prog = clCreateProgramWithSource(ctx,
		1, (const char **)&codice, NULL, &err);
	check_ocl_error(err, "creating program");

	err = clBuildProgram(prog, 1, &dev, NULL,
			NULL, NULL);

	if (err == CL_BUILD_PROGRAM_FAILURE) {
		// build failed! get the build log and print it
		size_t logSize = 0;
		char *log;
		err = clGetProgramBuildInfo(prog, dev,
				CL_PROGRAM_BUILD_LOG, 0, NULL, &logSize);
		check_ocl_error(err, "getting program build info size");
		log = (char *)malloc(logSize);
		err = clGetProgramBuildInfo(prog, dev,
				CL_PROGRAM_BUILD_LOG, logSize, log, NULL);
		check_ocl_error(err, "getting program build info");
		fputs(log, stderr);
		exit(1);
	} else
		check_ocl_error(err, "building program");

	cl_kernel transpose = clCreateKernel(prog,
			"transposeTex", &err);
	check_ocl_error(err, "creating kernel");

	size_t local_size[2] = { 16, 16};
	if (argc > 2) {
		local_size[0] = local_size[1] = atoi(argv[2]);
	}

	size_t work_size[2] = {
		((img.height + local_size[0] - 1)/local_size[0])*
			local_size[0],
		((img.width + local_size[1] - 1)/local_size[1])*
			local_size[1]
	};

	printf("work size: %zu x %zu\n", work_size[0], work_size[1]);

	err = clSetKernelArg(transpose, 0, sizeof(dDst), &dDst);
	check_ocl_error(err, "setting param 0");
	err = clSetKernelArg(transpose, 1, sizeof(dSrc), &dSrc);
	check_ocl_error(err, "setting param 1");
	err = clSetKernelArg(transpose, 2, sizeof(img.width), &img.width);
	check_ocl_error(err, "setting param 2");
	err = clSetKernelArg(transpose, 3, sizeof(img.height), &img.height);
	check_ocl_error(err, "setting param 3");

	cl_event k_evt;
	err = clEnqueueNDRangeKernel(coda, transpose,
			2, NULL, work_size, local_size,
			0, NULL, &k_evt);

	check_ocl_error(err, "kernel launch");

	cl_ulong startTime, endTime;
	err = clFinish(coda);
	check_ocl_error(err, "queue finish");

	clGetEventProfilingInfo(k_evt, CL_PROFILING_COMMAND_START,
			sizeof(cl_ulong), &startTime, NULL);
	clGetEventProfilingInfo(k_evt, CL_PROFILING_COMMAND_END,
			sizeof(cl_ulong), &endTime, NULL);

	clGetKernelInfo(transpose, CL_KERNEL_FUNCTION_NAME,
			MAX_NAME_LEN, name, &name_size);
	if (name_size >= MAX_NAME_LEN) {
		fprintf(stderr, "Looong name! %zu\n", name_size);
		name[MAX_NAME_LEN-1] = '\0';
	}
	printf("Kernel '%s' runtime: %gms\n", name,
		double(endTime - startTime)/1000000);

	size_t origin[3] = {0,0,0};
	size_t region[3] = {img.height, img.width, 1};
	err = clEnqueueReadImage(coda, dDst, CL_TRUE,
			origin, region,
			0, 0, img.data,
			0, NULL, NULL);
	check_ocl_error(err, "reading image to host");

	uint tmp = img.height;
	img.height = img.width;
	img.width = tmp;

	/* Try saving the image */
	if (pam_status = save_pam(imgsave, &img)) {
		return pam_status;
	}
	printf("%u-channels image %s saved, %ux%ux%u\n",
		img.channels, imgsave, img.width, img.height, img.depth);

}
