#define CL_TARGET_OPENCL_VERSION 120
#include "ocl_boiler.h"
#include <stdio.h>
#include <stdlib.h>

/* Write the message `err` and a trailing newline to stderr, then
 * terminate the process with exit status 1. Does not return. */
void error(const char *err)
{
	fputs(err, stderr);
	fputc('\n', stderr);
	exit(1);
}

/*
 * Enqueue init_kernel on `que` over a one-dimensional range and return the
 * event associated with that launch.
 *
 *   que                      command queue the kernel is enqueued on
 *   init_kernel              kernel to launch
 *   d_in                     buffer bound to kernel argument 0
 *   nels                     element count, bound to kernel argument 1
 *   preferred_rounding_init  granularity used for the global size when no
 *                            explicit local size was requested
 *   lws_arg                  if > 0, used as the local work size; otherwise
 *                            NULL is passed as the local size
 *
 * Every OpenCL status code is handed to ocl_check (from ocl_boiler.h).
 * The returned event is not released here; the caller owns it.
 */
cl_event init_array(cl_command_queue que, cl_kernel init_kernel,
	cl_mem d_in, int nels, size_t preferred_rounding_init, int lws_arg)
{
	const int explicit_lws = (lws_arg > 0);
	size_t lws[] = { explicit_lws ? (size_t)lws_arg : preferred_rounding_init };
	/* round_mul_up comes from ocl_boiler.h; presumably it rounds nels up to
	 * the next multiple of lws[0] -- TODO confirm against the header */
	size_t gws[] = { round_mul_up(nels, lws[0]) };

	/* nels is an int, so it is printed with %d */
	printf("init: %d | %zu = %zu\n", nels, lws[0], gws[0]);

	cl_int err;
	cl_event ret;

	/* The argument index is advanced in its own statement so that it does
	 * not depend on how ocl_check evaluates its arguments. */
	int arg = 0;
	err = clSetKernelArg(init_kernel, arg, sizeof(d_in), &d_in);
	ocl_check(err, "set init_array arg %d", arg);
	++arg;
	err = clSetKernelArg(init_kernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set init_array arg %d", arg);
	++arg;

	err = clEnqueueNDRangeKernel(que, init_kernel, 1,
		NULL, gws, (explicit_lws ? lws : NULL),
		0, NULL, &ret);
	ocl_check(err, "enqueue init");

	return ret;
}

/*
 * Enqueue one launch of reduce_kernel on `que`, made to wait on `init_evt`,
 * and return the event associated with the launch.
 *
 *   que                        command queue the kernel is enqueued on
 *   reduce_kernel              kernel to launch
 *   init_evt                   event the launch waits for
 *   d_out                      buffer bound to kernel argument 0
 *   d_in                       buffer bound to kernel argument 1
 *   nels                       count bound to kernel argument 2; also the
 *                              basis for the global work size
 *   preferred_rounding_reduce  granularity used for the global size when no
 *                              explicit local size was requested
 *   lws_arg                    if > 0, used as the local work size; otherwise
 *                              NULL is passed as the local size
 *
 * Every OpenCL status code is handed to ocl_check (from ocl_boiler.h).
 * The returned event is not released here; the caller owns it.
 */
cl_event reduce(cl_command_queue que, cl_kernel reduce_kernel, cl_event init_evt,
	cl_mem d_out, cl_mem d_in, int nels,
	size_t preferred_rounding_reduce, int lws_arg)
{
	const int explicit_lws = (lws_arg > 0);
	size_t lws[] = { explicit_lws ? (size_t)lws_arg : preferred_rounding_reduce };
	/* round_mul_up comes from ocl_boiler.h; presumably it rounds nels up to
	 * the next multiple of lws[0] -- TODO confirm against the header */
	size_t gws[] = { round_mul_up(nels, lws[0]) };

	/* nels is an int, so it is printed with %d */
	printf("reduce: %d | %zu = %zu\n", nels, lws[0], gws[0]);

	cl_int err;
	cl_event ret;

	/* The argument index is advanced in its own statement so that it does
	 * not depend on how ocl_check evaluates its arguments. */
	int arg = 0;
	err = clSetKernelArg(reduce_kernel, arg, sizeof(d_out), &d_out);
	ocl_check(err, "set reduce_array arg %d", arg);
	++arg;
	err = clSetKernelArg(reduce_kernel, arg, sizeof(d_in), &d_in);
	ocl_check(err, "set reduce_array arg %d", arg);
	++arg;
	err = clSetKernelArg(reduce_kernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set reduce_array arg %d", arg);
	++arg;

	err = clEnqueueNDRangeKernel(que, reduce_kernel, 1,
		NULL, gws, (explicit_lws ? lws : NULL),
		1, &init_evt, &ret);
	ocl_check(err, "enqueue reduce");

	return ret;
}


void verify(const cl_int sum, int nels)
{
	int expected = (nels - 1)*(nels/2);
	if (expected != sum)
		fprintf(stderr, "mismatch: %d != %d\n", sum, expected);
}

/*
 * Usage: <program> nels [lws]
 *
 *   nels  number of elements; must be a positive power of 4 (1 is rejected,
 *         because the search below starts at 4^1)
 *   lws   optional local work size; only read when exactly two arguments
 *         are given, and only honoured by the launch helpers when > 0
 *
 * Builds the kernels "init_kernel" and "reduce_v4" from reduce.ocl, runs the
 * initialisation followed by log4(nels) reduce launches (each launch is
 * given a quarter of the previous count and the two buffers are swapped
 * after every launch), reads back the first cl_int of the final buffer,
 * checks it with verify() and prints timing figures for every stage.
 */
int main(int argc, char *argv[])
{
	if (argc < 2) error("please specify number of elements");

	int nels = atoi(argv[1]);

	if (nels <= 0) error("please specify a positive integer");
	if (nels & (nels - 1)) error("please specify a power of 2");

	/* Smallest log4 >= 1 such that 4^log4 >= nels */
	int log4 = 1;
	while ((1 << 2*log4) < nels) {
		log4++;
	}
	/* all three values are ints, so they are printed with %d */
	printf("%d => %d (%d)\n", nels, log4, (1 << 2*log4));
	if ((1 << 2*log4) != nels) error("nels must be a power of 4");

	int lws = 0;
	if (argc == 3) {
		lws = atoi(argv[2]);
	}

	cl_platform_id p = select_platform();
	cl_device_id d = select_device(p);
	cl_context ctx = create_context(p, d);
	cl_command_queue que = create_queue(ctx, d);
	cl_program prog = create_program("reduce.ocl", ctx, d);

	cl_int err;
	cl_kernel init_kernel = clCreateKernel(prog, "init_kernel", &err);
	ocl_check(err, "create init_kernel");
	cl_kernel reduce_kernel = clCreateKernel(prog, "reduce_v4", &err);
	ocl_check(err, "create reduce_kernel");

	size_t memsize = nels*sizeof(cl_int);

	/* Two equally sized buffers; their roles are swapped after each pass */
	cl_mem d_in = clCreateBuffer(ctx, CL_MEM_READ_WRITE, memsize, NULL, &err);
	ocl_check(err, "create d_in1 failed");
	cl_mem d_out = clCreateBuffer(ctx, CL_MEM_READ_WRITE, memsize, NULL, &err);
	ocl_check(err, "create d_out failed");

	size_t preferred_rounding_init;
	size_t preferred_rounding_reduce;

	err = clGetKernelWorkGroupInfo(init_kernel, d, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
		sizeof(preferred_rounding_init), &preferred_rounding_init, NULL);
	ocl_check(err, "get preferred work-group size multiple");
	err = clGetKernelWorkGroupInfo(reduce_kernel, d, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
		sizeof(preferred_rounding_reduce), &preferred_rounding_reduce, NULL);
	ocl_check(err, "get preferred work-group size multiple");

	/* Slot 0 holds the init event, slots 1..log4 one event per reduce pass */
	cl_event reduce_evt[log4+1];

	reduce_evt[0] = init_array(que, init_kernel, d_in, nels, preferred_rounding_init, lws);

	int nquarts = nels;
	for (int l = 0; l < log4; ++l) {
		nquarts /= 4;
		printf("%d: %d\n", l, nquarts);
		/* each pass waits on the event of the previous stage */
		reduce_evt[l+1] =
			reduce(que, reduce_kernel, reduce_evt[l],
				d_out, d_in, nquarts, preferred_rounding_reduce, lws);
		/* the output of this pass becomes the input of the next one */
		cl_mem t = d_in;
		d_in = d_out;
		d_out = t;
	}

	cl_int r;

	/* After the final swap the result lives in d_in; blocking read of its
	 * first element, ordered after the last reduce pass. */
	cl_event read_evt;
	err = clEnqueueReadBuffer(que, d_in, CL_TRUE, 0, sizeof(cl_int),
		&r, 1, reduce_evt + log4, &read_evt);
	ocl_check(err, "read value");

	verify(r, nels);

	err = clFinish(que);
	ocl_check(err, "finish");

	double init_runtime = runtime_ms(reduce_evt[0]);
	double reduce_runtime = total_runtime_ms(reduce_evt[1], reduce_evt[log4]);
	double read_runtime = runtime_ms(read_evt);

	printf("init: %gms, %gGE/s, %gGB/s\n", init_runtime, nels/init_runtime/1.0e6, memsize/init_runtime/1.e6);
	nquarts = nels/4;
	for (int l = 0; l < log4; ++l) {
		double this_runtime = runtime_ms(reduce_evt[l+1]);
		/* pass l is accounted as 4*nquarts elements read plus nquarts written */
		printf("reduce[%d]: %gms, %gGE/s, %gGB/s\n", l,
			this_runtime, nquarts*4.0/this_runtime/1.0e6, (nquarts*4.0 + nquarts)*sizeof(cl_int)/this_runtime/1.0e6);
		nquarts /= 4;
	}
	printf("reduce: %gms, %gGE/s\n", reduce_runtime, nels/reduce_runtime/1.0e6);
	printf("read: %gms, %gGB/s\n", read_runtime, sizeof(cl_int)/read_runtime/1.0e6);

	/* Timing data has been collected: release every event obtained above */
	for (int l = 0; l <= log4; ++l) {
		clReleaseEvent(reduce_evt[l]);
	}
	clReleaseEvent(read_evt);

	clReleaseKernel(reduce_kernel);
	clReleaseKernel(init_kernel);
	clReleaseMemObject(d_out);
	clReleaseMemObject(d_in);
	clReleaseProgram(prog);
	clReleaseCommandQueue(que);
	clReleaseContext(ctx);

	return 0;
}
