#define CL_TARGET_OPENCL_VERSION 120
#include "ocl_boiler.h"
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

/*
 * Fatal-error helper: write the message followed by a newline to stderr,
 * then terminate the process with exit status 1. Never returns.
 */
void error(const char *err)
{
	fputs(err, stderr);
	fputc('\n', stderr);
	exit(1);
}

/*
 * Set the arguments of init_kernel (d_in1, d_in2, nels, in that order) and
 * enqueue it as a 1-D NDRange with no wait list.
 *
 * The global size is nels rounded up via round_mul_up to a multiple of the
 * local size. The local size is lws_arg when it is positive; otherwise it is
 * preferred_rounding_init. The local size is passed to the runtime only when
 * lws_arg > 0; otherwise NULL is passed and the implementation picks it.
 * NOTE(review): since the global size can exceed nels, the kernel presumably
 * guards its accesses against nels — confirm in vecsum.ocl.
 *
 * Any OpenCL error aborts via ocl_check. Returns the event of the enqueued
 * kernel; the caller owns that event and is responsible for releasing it.
 */
cl_event init_arrays(cl_command_queue que, cl_kernel init_kernel,
	cl_mem d_in1, cl_mem d_in2,
	int nels, size_t preferred_rounding_init, int lws_arg)
{
	size_t lws[] = { lws_arg > 0 ? (size_t)lws_arg : preferred_rounding_init };
	size_t gws[] = { round_mul_up(nels, lws[0]) };

	/* nels is a signed int, so it must be printed with %d (not %u). */
	printf("init: %d | %zu = %zu\n", nels, lws[0], gws[0]);
	cl_int err;
	cl_event ret;

	int arg = 0;
	err = clSetKernelArg(init_kernel, arg, sizeof(d_in1), &d_in1);
	ocl_check(err, "set init_array arg %d", arg++);
	err = clSetKernelArg(init_kernel, arg, sizeof(d_in2), &d_in2);
	ocl_check(err, "set init_array arg %d", arg++);
	err = clSetKernelArg(init_kernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set init_array arg %d", arg++);

	err = clEnqueueNDRangeKernel(que, init_kernel, 1,
		NULL, gws, (lws_arg > 0 ? lws : NULL),
		0, NULL,  &ret);
	ocl_check(err, "enqueue init");

	return ret;
}

/*
 * Set the arguments of sum_kernel (d_out, d_in1, d_in2, nels, in that order)
 * and enqueue it as a 1-D NDRange. The launch waits on init_evt, so it starts
 * only after that command has completed.
 *
 * The global size is nels rounded up via round_mul_up to a multiple of the
 * local size. The local size is lws_arg when it is positive; otherwise it is
 * preferred_rounding_sum. The local size is passed to the runtime only when
 * lws_arg > 0; otherwise NULL is passed and the implementation picks it.
 * NOTE(review): since the global size can exceed nels, the kernel presumably
 * guards its accesses against nels — confirm in vecsum.ocl.
 *
 * Any OpenCL error aborts via ocl_check. Returns the event of the enqueued
 * kernel; the caller owns that event and is responsible for releasing it.
 */
cl_event sum_arrays(cl_command_queue que, cl_kernel sum_kernel, cl_event init_evt,
	cl_mem d_out, cl_mem d_in1, cl_mem d_in2,
	int nels,
	size_t preferred_rounding_sum, int lws_arg)
{
	size_t lws[] = { lws_arg > 0 ? (size_t)lws_arg : preferred_rounding_sum };
	size_t gws[] = { round_mul_up(nels, lws[0]) };

	/* nels is a signed int, so it must be printed with %d (not %u). */
	printf("sum: %d | %zu = %zu\n", nels, lws[0], gws[0]);
	cl_int err;
	cl_event ret;

	int arg = 0;
	err = clSetKernelArg(sum_kernel, arg, sizeof(d_out), &d_out);
	ocl_check(err, "set sum_array arg %d", arg++);
	err = clSetKernelArg(sum_kernel, arg, sizeof(d_in1), &d_in1);
	ocl_check(err, "set sum_array arg %d", arg++);
	err = clSetKernelArg(sum_kernel, arg, sizeof(d_in2), &d_in2);
	ocl_check(err, "set sum_array arg %d", arg++);
	err = clSetKernelArg(sum_kernel, arg, sizeof(nels), &nels);
	ocl_check(err, "set sum_array arg %d", arg++);

	err = clEnqueueNDRangeKernel(que, sum_kernel, 1,
		NULL, gws, (lws_arg > 0 ? lws : NULL),
		1, &init_evt,  &ret);
	ocl_check(err, "enqueue sum");

	return ret;
}


/*
 * Check the first nels entries of array against the expected value nels.
 * Each entry that differs is reported on stderr; nothing is returned and the
 * scan always runs to the end.
 * NOTE(review): the expected value nels presumes the init/sum kernels make
 * every output element equal to nels — confirm against vecsum.ocl.
 */
void verify(const cl_int *array, int nels)
{
	int idx = 0;
	while (idx < nels) {
		const cl_int got = array[idx];
		if (got != nels) {
			fprintf(stderr, "mismatch @ %d: %d != %d\n",
				idx, got, nels);
		}
		++idx;
	}
}

int main(int argc, char *argv[])
{
	if (argc < 2) error("please specify number of elements");

	int nels = atoi(argv[1]);

	if (nels <= 0) error("please specify a positive integer");

	int lws = 0;
	if (argc == 3) {
		lws = atoi(argv[2]);
	}

	cl_platform_id p = select_platform();
	cl_device_id d = select_device(p);
	cl_context ctx = create_context(p, d);
	cl_command_queue que = create_queue(ctx, d);
	cl_program prog = create_program("vecsum.ocl", ctx, d);

	cl_int err;
	cl_kernel init_kernel = clCreateKernel(prog, "init_kernel", &err);
	ocl_check(err, "create init_kernel");
	cl_kernel sum_kernel = clCreateKernel(prog, "sum_kernel_v1", &err);
	ocl_check(err, "create sum_kernel");

	size_t memsize = nels*sizeof(cl_int);

	cl_mem d_in1 = clCreateBuffer(ctx, CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
		memsize, NULL, &err);
	ocl_check(err, "create d_in1 failed");
	cl_mem d_in2 = clCreateBuffer(ctx, CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
		memsize, NULL, &err);
	ocl_check(err, "create d_in2 failed");
	cl_mem d_out = clCreateBuffer(ctx, CL_MEM_WRITE_ONLY | CL_MEM_ALLOC_HOST_PTR,
		memsize, NULL, &err);
	ocl_check(err, "create d_out failed");

	size_t preferred_rounding_init;
	size_t preferred_rounding_sum;

	err = clGetKernelWorkGroupInfo(init_kernel, d, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
		sizeof(preferred_rounding_init), &preferred_rounding_init, NULL);
	ocl_check(err, "get preferred work-group size multiple");
	err = clGetKernelWorkGroupInfo(sum_kernel, d, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
		sizeof(preferred_rounding_sum), &preferred_rounding_sum, NULL);
	ocl_check(err, "get preferred work-group size multiple");

	cl_event init_evt = init_arrays(que, init_kernel, d_in1, d_in2, nels, preferred_rounding_init, lws);
	cl_event sum_evt = sum_arrays(que, sum_kernel, init_evt, d_out, d_in1, d_in2, nels, preferred_rounding_sum, lws);

	cl_event map_evt, unmap_evt;

	int *h_array = clEnqueueMapBuffer(que, d_out, CL_TRUE,
		CL_MAP_READ, 0, memsize,
		1, &sum_evt, &map_evt, &err);
	ocl_check(err, "map buffer");

	verify(h_array, nels);

	err = clEnqueueUnmapMemObject(que, d_out, h_array,
		0, NULL, &unmap_evt);
	ocl_check(err, "unmap buffer");

	err = clFinish(que);
	ocl_check(err, "finish");

	double init_runtime = runtime_ms(init_evt);
	double sum_runtime = runtime_ms(sum_evt);
	double map_runtime = runtime_ms(map_evt);
	double unmap_runtime = runtime_ms(unmap_evt);

	printf("init: %gms, %gGE/s, %gGB/s\n", init_runtime, 2*nels/init_runtime/1.0e6, 2*memsize/init_runtime/1.e6);
	printf("sum: %gms, %gGE/s, %gGB/s\n", sum_runtime, nels/sum_runtime/1.0e6, 3*memsize/sum_runtime/1.0e6);
	printf("map: %gms, %gGE/s, %gGB/s\n", map_runtime, nels/map_runtime/1.0e6, memsize/map_runtime/1.0e6);
	printf("unmap: %gms\n", unmap_runtime);

	clReleaseKernel(init_kernel);
	clReleaseMemObject(d_out);
	clReleaseMemObject(d_in1);
	clReleaseMemObject(d_in2);
	clReleaseProgram(prog);
	clReleaseCommandQueue(que);
	clReleaseContext(ctx);
}
