#define CL_TARGET_OPENCL_VERSION 120
#include "ocl_boiler.h"
#include <stdio.h>
#include <stdlib.h>

/* Report a fatal problem: write the message `err` and a trailing newline
 * to stderr, then terminate the process with exit status 1.
 * This function does not return. */
void error(const char *err)
{
	fputs(err, stderr);
	fputc('\n', stderr);
	exit(1);
}

/* Enqueue the initialization kernel on `que`.
 *
 * que                     : command queue the kernel is enqueued on
 * init_kernel             : kernel object; receives (d_in, nels) as arguments 0 and 1
 * d_in                    : buffer passed to the kernel as argument 0
 * nels                    : number of elements, passed to the kernel as argument 1
 * preferred_rounding_init : rounding granularity for the global work size
 *                           when no explicit local work size is requested
 * lws_arg                 : requested local work size; if <= 0 the local work
 *                           size is left to the OpenCL runtime (NULL is passed)
 *
 * Returns the event associated with the enqueued kernel. Any OpenCL error is
 * handed to ocl_check().
 */
cl_event init_array(cl_command_queue que, cl_kernel init_kernel,
	cl_mem d_in, int nels, size_t preferred_rounding_init, int lws_arg)
{
	/* lws[0] is always used as rounding granularity for gws, but it is only
	 * passed to the launch when the caller explicitly asked for it */
	size_t lws[] = { lws_arg > 0 ? (size_t)lws_arg : preferred_rounding_init };
	/* round_mul_up is declared outside this file; presumably it rounds nels
	 * up to the next multiple of lws[0] -- TODO confirm */
	size_t gws[] = { round_mul_up(nels, lws[0]) };

	/* nels is an int: print it with %d (it was printed with %u) */
	printf("init: %d | %zu = %zu\n", nels, lws[0], gws[0]);
	cl_int err;
	cl_event ret;

	/* The argument index is advanced in its own statement, so that it does
	 * not depend on how ocl_check evaluates its arguments */
	int arg = 0;
	err = clSetKernelArg(init_kernel, arg, sizeof(d_in), &d_in);
	ocl_check(err, "set init_array arg %d", arg);
	++arg;
	err = clSetKernelArg(init_kernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set init_array arg %d", arg);
	++arg;

	err = clEnqueueNDRangeKernel(que, init_kernel, 1,
		NULL, gws, (lws_arg > 0 ? lws : NULL),
		0, NULL,  &ret);
	ocl_check(err, "enqueue init");

	return ret;
}

/* Enqueue the smoothing kernel on `que`, to run after `init_evt` completes.
 *
 * que                       : command queue the kernel is enqueued on
 * smoothkernel              : kernel object; receives (d_out, d_in, nels) as
 *                             arguments 0, 1 and 2
 * init_evt                  : event the kernel launch waits on
 * d_out                     : buffer passed to the kernel as argument 0
 * d_in                      : buffer passed to the kernel as argument 1
 * nels                      : number of elements, passed as argument 2
 * preferred_rounding_smooth : rounding granularity for the global work size
 *                             when no explicit local work size is requested
 * lws_arg                   : requested local work size; if <= 0 the local
 *                             work size is left to the OpenCL runtime
 *
 * Returns the event associated with the enqueued kernel. Any OpenCL error is
 * handed to ocl_check().
 */
cl_event smooth_array(cl_command_queue que, cl_kernel smoothkernel, cl_event init_evt,
	cl_mem d_out, cl_mem d_in,
	int nels,
	size_t preferred_rounding_smooth, int lws_arg)
{
	/* lws[0] is always used as rounding granularity for gws, but it is only
	 * passed to the launch when the caller explicitly asked for it */
	size_t lws[] = { lws_arg > 0 ? (size_t)lws_arg : preferred_rounding_smooth };
	/* round_mul_up is declared outside this file; presumably it rounds nels
	 * up to the next multiple of lws[0] -- TODO confirm */
	size_t gws[] = { round_mul_up(nels, lws[0]) };

	/* nels is an int: print it with %d (it was printed with %u) */
	printf("smooth: %d | %zu = %zu\n", nels, lws[0], gws[0]);
	cl_int err;
	cl_event ret;

	/* The argument index is advanced in its own statement, so that it does
	 * not depend on how ocl_check evaluates its arguments */
	int arg = 0;
	err = clSetKernelArg(smoothkernel, arg, sizeof(d_out), &d_out);
	ocl_check(err, "set smootharray arg %d", arg);
	++arg;
	err = clSetKernelArg(smoothkernel, arg, sizeof(d_in), &d_in);
	ocl_check(err, "set smootharray arg %d", arg);
	++arg;
	err = clSetKernelArg(smoothkernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set smootharray arg %d", arg);
	++arg;

	err = clEnqueueNDRangeKernel(que, smoothkernel, 1,
		NULL, gws, (lws_arg > 0 ? lws : NULL),
		1, &init_evt,  &ret);
	/* the failure message used to say "enqueue sum", which does not match
	 * the kernel launched here */
	ocl_check(err, "enqueue smooth");

	return ret;
}

/* Check the first `nels` entries of `array` against the expected result and
 * report every discrepancy on stderr. Entry i is expected to hold i, except
 * for the final entry, which is expected to hold i - 1. The scan always
 * covers the whole array; nothing is returned. */
void verify(const cl_int *array, int nels)
{
	int i = 0;

	while (i < nels) {
		const int is_last = (i + 1 == nels);
		const int expected = is_last ? i - 1 : i;
		const cl_int actual = array[i];

		if (actual != expected)
			fprintf(stderr, "mismatch @ %d: %d != %d\n",
				i, actual, expected);

		++i;
	}
}

int main(int argc, char *argv[])
{
	if (argc < 2) error("please specify number of elements");

	int nels = atoi(argv[1]);

	if (nels <= 0) error("please specify a positive integer");

	int lws = 0;
	if (argc == 3) {
		lws = atoi(argv[2]);
	}

	cl_platform_id p = select_platform();
	cl_device_id d = select_device(p);
	cl_context ctx = create_context(p, d);
	cl_command_queue que = create_queue(ctx, d);
	cl_program prog = create_program("vecsmooth.ocl", ctx, d);

	cl_int err;
	cl_kernel init_kernel = clCreateKernel(prog, "init_kernel", &err);
	ocl_check(err, "create init_kernel");
	cl_kernel smooth_kernel = clCreateKernel(prog, "smooth_kernel", &err);
	ocl_check(err, "create smooth kernel");

	size_t memsize = nels*sizeof(cl_int);

	cl_mem d_in = clCreateBuffer(ctx, CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
		memsize, NULL, &err);
	ocl_check(err, "create d_in failed");
	cl_mem d_out = clCreateBuffer(ctx, CL_MEM_WRITE_ONLY | CL_MEM_ALLOC_HOST_PTR,
		memsize, NULL, &err);
	ocl_check(err, "create d_out failed");

	size_t preferred_rounding_init;
	size_t preferred_rounding_sum;

	err = clGetKernelWorkGroupInfo(init_kernel, d, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
		sizeof(preferred_rounding_init), &preferred_rounding_init, NULL);
	ocl_check(err, "get preferred work-group size multiple");
	err = clGetKernelWorkGroupInfo(smooth_kernel, d, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
		sizeof(preferred_rounding_sum), &preferred_rounding_sum, NULL);
	ocl_check(err, "get preferred work-group size multiple");

	cl_event init_evt = init_array(que, init_kernel, d_in, nels, preferred_rounding_init, lws);
	cl_event smooth_evt = smooth_array(que, smooth_kernel, init_evt, d_out, d_in, nels, preferred_rounding_sum, lws);

	cl_event map_evt, unmap_evt;

	int *h_array = clEnqueueMapBuffer(que, d_out, CL_TRUE,
		CL_MAP_READ, 0, memsize,
		1, &smooth_evt, &map_evt, &err);
	ocl_check(err, "map buffer");

	verify(h_array, nels);

	err = clEnqueueUnmapMemObject(que, d_out, h_array,
		0, NULL, &unmap_evt);
	ocl_check(err, "unmap buffer");

	err = clFinish(que);
	ocl_check(err, "finish");

	double init_runtime = runtime_ms(init_evt);
	double smooth_runtime = runtime_ms(smooth_evt);
	double map_runtime = runtime_ms(map_evt);
	double unmap_runtime = runtime_ms(unmap_evt);

	printf("init: %gms, %gGE/s, %gGB/s\n", init_runtime, 2*nels/init_runtime/1.0e6, 2*memsize/init_runtime/1.e6);
	printf("smooth: %gms, %gGE/s, %gGB/s\n", smooth_runtime, nels/smooth_runtime/1.0e6, 4*memsize/smooth_runtime/1.0e6);
	printf("map: %gms, %gGE/s, %gGB/s\n", map_runtime, nels/map_runtime/1.0e6, memsize/map_runtime/1.0e6);
	printf("unmap: %gms\n", unmap_runtime);

	clReleaseKernel(init_kernel);
	clReleaseMemObject(d_out);
	clReleaseMemObject(d_in);
	clReleaseProgram(prog);
	clReleaseCommandQueue(que);
	clReleaseContext(ctx);
}
