#define CL_TARGET_OPENCL_VERSION 120
#include "ocl_boiler.h"
#include <stdio.h>
#include <stdlib.h>

/*
 * Report a fatal problem and terminate.
 *
 * Writes `err` followed by a newline to stderr, then exits the process
 * with status 1. Never returns.
 */
void error(const char *err)
{
	fputs(err, stderr);
	fputc('\n', stderr);
	exit(1);
}

/*
 * Enqueue the array-initialisation kernel.
 *
 * que                      command queue the kernel is submitted to
 * init_kernel              kernel object to launch
 * d_in                     device buffer passed as kernel argument 0
 * nels                     element count, passed as kernel argument 1
 * preferred_rounding_init  granularity used to round the global size when
 *                          no explicit work-group size is requested
 * lws_arg                  requested work-group size; a value <= 0 means
 *                          "let the OpenCL runtime pick one"
 *
 * Returns the event associated with the launch. The launch itself has no
 * wait list. The caller owns the returned event.
 */
cl_event init_array(cl_command_queue que, cl_kernel init_kernel,
	cl_mem d_in, int nels, size_t preferred_rounding_init, int lws_arg)
{
	/* Granularity for the global size: the explicit work-group size when
	 * one was given, the preferred multiple otherwise. */
	size_t lws[] = { lws_arg > 0 ? (size_t)lws_arg : preferred_rounding_init };
	/* round_mul_up comes from ocl_boiler.h; presumably it rounds nels up
	 * to the next multiple of lws[0] -- confirm against the header. */
	size_t gws[] = { round_mul_up(nels, lws[0]) };

	/* nels is a signed int, so it is printed with %d. */
	printf("init: %d | %zu = %zu\n", nels, lws[0], gws[0]);
	cl_int err;
	cl_event ret;

	/* The argument index is advanced in its own statement so that it
	 * changes exactly once per argument regardless of how ocl_check
	 * evaluates its parameters. */
	int arg = 0;
	err = clSetKernelArg(init_kernel, arg, sizeof(d_in), &d_in);
	ocl_check(err, "set init_array arg %d", arg);
	++arg;
	err = clSetKernelArg(init_kernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set init_array arg %d", arg);
	++arg;

	/* One-dimensional launch. The local size is only forced when the
	 * caller asked for one; NULL leaves the choice to the runtime. */
	err = clEnqueueNDRangeKernel(que, init_kernel, 1,
		NULL, gws, (lws_arg > 0 ? lws : NULL),
		0, NULL,  &ret);
	ocl_check(err, "enqueue init");

	return ret;
}

/*
 * Number of work-groups used by one reduction pass over `nels` items with
 * work-groups of `lws_arg` work-items.
 *
 * The count is round_div_up(nels, lws_arg); when that is more than one it
 * is further passed through round_mul_up(count, 4). Both helpers come from
 * ocl_boiler.h (presumably ceiling division and round-up-to-multiple --
 * confirm against the header). A count of 0 or 1 is returned unchanged.
 */
size_t reduce_groups(int nels, int lws_arg)
{
	size_t count = round_div_up(nels, lws_arg);

	/* A single group (or none) is left as it is. */
	if (count <= 1)
		return count;

	return round_mul_up(count, 4);
}

/*
 * Enqueue one pass of the reduction kernel.
 *
 * que            command queue the kernel is submitted to
 * reduce_kernel  kernel object to launch
 * init_evt       event the launch waits on before running
 * d_out          device buffer passed as kernel argument 0
 * d_in           device buffer passed as kernel argument 1
 * nels           item count, passed as kernel argument 3
 * lws_arg        work-group size used for the launch
 * ngroups_out    receives the number of work-groups launched
 *
 * Kernel argument 2 is set with a size of lws_arg cl_int values and a NULL
 * value pointer, which is how OpenCL sizes a __local kernel argument.
 *
 * Returns the event associated with the launch. The caller owns it.
 */
cl_event reduce(cl_command_queue que, cl_kernel reduce_kernel, cl_event init_evt,
	cl_mem d_out, cl_mem d_in, int nels, int lws_arg, int *ngroups_out)
{
	size_t lws[] = { (size_t)lws_arg };
	size_t ngroups = reduce_groups(nels, lws_arg);
	/* Global size is a whole number of work-groups. */
	size_t gws[] = { ngroups*lws[0] };

	/* nels is a signed int, so it is printed with %d. */
	printf("reduce: %d | %zu * %zu = %zu\n", nels, lws[0], ngroups, gws[0]);

	cl_int err;
	cl_event ret;

	/* The argument index is advanced in its own statement so that it
	 * changes exactly once per argument regardless of how ocl_check
	 * evaluates its parameters. */
	int arg = 0;
	err = clSetKernelArg(reduce_kernel, arg, sizeof(d_out), &d_out);
	ocl_check(err, "set reduce_array arg %d", arg);
	++arg;
	err = clSetKernelArg(reduce_kernel, arg, sizeof(d_in), &d_in);
	ocl_check(err, "set reduce_array arg %d", arg);
	++arg;
	/* Size-only argument: one cl_int per work-item in the group. */
	err = clSetKernelArg(reduce_kernel, arg, lws[0]*sizeof(cl_int), NULL);
	ocl_check(err, "set reduce_array arg %d", arg);
	++arg;
	err = clSetKernelArg(reduce_kernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set reduce_array arg %d", arg);
	++arg;

	/* One-dimensional launch that waits for init_evt. */
	err = clEnqueueNDRangeKernel(que, reduce_kernel, 1,
		NULL, gws, lws,
		1, &init_evt,  &ret);
	ocl_check(err, "enqueue reduce");

	/* The interface reports the group count as int; narrow explicitly. */
	*ngroups_out = (int)ngroups;
	return ret;
}


/*
 * Compare the reduction result against the closed-form expectation
 * (nels - 1)*(nels/2) and report a mismatch on stderr.
 *
 * sum   value produced by the reduction
 * nels  number of elements that were reduced
 *
 * The formula equals 0 + 1 + ... + (nels - 1) when nels is even, which
 * presumably matches what the init kernel writes -- the kernel source is
 * not visible here, so confirm against it.
 *
 * The product is formed in unsigned arithmetic: in plain int it overflows
 * (undefined behaviour) as soon as nels exceeds roughly 65538. Unsigned
 * multiplication wraps modulo 2^N instead, and the result is then
 * converted back to int for the comparison.
 */
void verify(const cl_int sum, int nels)
{
	const unsigned int last = (unsigned int)(nels - 1);
	const unsigned int half = (unsigned int)(nels/2);
	const int expected = (int)(last*half);

	if (expected != sum)
		fprintf(stderr, "mismatch: %d != %d\n", sum, expected);
}

int main(int argc, char *argv[])
{
	if (argc < 3) error("please specify number of elements");

	int nels = atoi(argv[1]);

	if (nels <= 0) error("please specify a positive integer");
	if (nels & 3) error("please specify a multiple of 4");

	int lws = atoi(argv[2]);
	if (lws <= 0) error("lws must be > 0");
	if (lws & (lws - 1)) error("lws must be a power of 2");

	int reduction_steps = 1;
	int nquarts = nels/4;
	while (1) {
		int ngroups = reduce_groups(nquarts, lws);
		if (ngroups == 1) break;
		++reduction_steps;
		nquarts = ngroups/4;
	}
	nquarts = nels/4;

	printf("%u => %u\n", nels, reduction_steps);

	cl_platform_id p = select_platform();
	cl_device_id d = select_device(p);
	cl_context ctx = create_context(p, d);
	cl_command_queue que = create_queue(ctx, d);
	cl_program prog = create_program("reduce.ocl", ctx, d);

	cl_int err;
	cl_kernel init_kernel = clCreateKernel(prog, "init_kernel", &err);
	ocl_check(err, "create init_kernel");
	cl_kernel reduce_kernel = clCreateKernel(prog, "reduce_lmem", &err);
	ocl_check(err, "create reduce_kernel");

	size_t memsize = nels*sizeof(cl_int);

	cl_mem d_in = clCreateBuffer(ctx, CL_MEM_READ_WRITE, memsize, NULL, &err);
	ocl_check(err, "create d_in1 failed");
	cl_mem d_out = clCreateBuffer(ctx, CL_MEM_READ_WRITE, reduce_groups(nquarts, lws)*sizeof(cl_int),
		NULL, &err);
	ocl_check(err, "create d_out failed");

	size_t preferred_rounding_init;

	err = clGetKernelWorkGroupInfo(init_kernel, d, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
		sizeof(preferred_rounding_init), &preferred_rounding_init, NULL);
	ocl_check(err, "get preferred work-group size multiple");

	cl_event reduce_evt[reduction_steps+1];

	reduce_evt[0] = init_array(que, init_kernel, d_in, nels, preferred_rounding_init, lws);

	for (int l = 0; l < reduction_steps; ++l) {
		printf("%d: %d\n", l, nquarts);
		int ngroups;
		reduce_evt[l+1] =
			reduce(que, reduce_kernel, reduce_evt[l], d_out, d_in, nquarts, lws, &ngroups);
		cl_mem t = d_in;
		d_in = d_out;
		d_out = t;
		nquarts = ngroups/4;
	}

	cl_int r;

	cl_event read_evt;
	err = clEnqueueReadBuffer(que, d_in, CL_TRUE, 0, sizeof(cl_int),
		&r, 1, reduce_evt + reduction_steps, &read_evt);
	ocl_check(err, "read value");

	verify(r, nels);

	err = clFinish(que);
	ocl_check(err, "finish");

	double init_runtime = runtime_ms(reduce_evt[0]);
	double reduce_runtime = total_runtime_ms(reduce_evt[1], reduce_evt[reduction_steps]);
	double read_runtime = runtime_ms(read_evt);

	printf("init: %gms, %gGE/s, %gGB/s\n", init_runtime, nels/init_runtime/1.0e6, memsize/init_runtime/1.e6);
	nquarts = nels/4;
	for (int l = 0; l < reduction_steps; ++l) {
		double this_runtime = runtime_ms(reduce_evt[l+1]);
		int ngroups = reduce_groups(nquarts, lws);
		printf("reduce[%d]: %gms, %gGE/s, %gGB/s\n", l,
			this_runtime, nquarts*4.0/this_runtime/1.0e6, (nquarts*4.0 + ngroups)*sizeof(cl_int)/this_runtime/1.0e6);
		nquarts = ngroups/4;
	}
	printf("reduce: %gms, %gGE/s\n", reduce_runtime, nels/reduce_runtime/1.0e6);
	printf("read: %gms, %gGB/s\n", read_runtime, sizeof(cl_int)/read_runtime/1.0e6);

	clReleaseKernel(init_kernel);
	clReleaseMemObject(d_out);
	clReleaseMemObject(d_in);
	clReleaseProgram(prog);
	clReleaseCommandQueue(que);
	clReleaseContext(ctx);
}
