#define CL_TARGET_OPENCL_VERSION 120
#include "ocl_boiler.h"
#include <stdio.h>
#include <stdlib.h>

/* Write the message `err` to stderr, terminated by a newline,
 * and end the process with exit status 1. Never returns. */
void error(const char *err)
{
	fputs(err, stderr);
	fputc('\n', stderr);
	exit(1);
}

/* Set the arguments of `init_kernel` and enqueue it on `que` as a 2D launch.
 *
 * Kernel arguments, in positional order: d_array, nrows, ncols, pitch_el.
 * The work-group is lws_arg x lws_arg; the global size is obtained by passing
 * ncols (dimension 0) and nrows (dimension 1) through round_mul_up together
 * with the work-group side.
 *
 * Every OpenCL status is handed to ocl_check (declared in ocl_boiler.h).
 * Returns the event associated with the enqueued kernel.
 */
cl_event init_array(cl_command_queue que, cl_kernel init_kernel, cl_mem d_array,
	cl_int nrows, cl_int ncols, cl_int pitch_el, int lws_arg)
{
	size_t local_size[2];
	size_t global_size[2];

	local_size[0] = local_size[1] = lws_arg;
	global_size[0] = round_mul_up(ncols, local_size[0]);
	global_size[1] = round_mul_up(nrows, local_size[1]);

	/* One entry per kernel argument; the table index is the argument index. */
	const struct {
		size_t size;
		const void *value;
	} kernel_args[] = {
		{ sizeof(d_array),  &d_array  },
		{ sizeof(nrows),    &nrows    },
		{ sizeof(ncols),    &ncols    },
		{ sizeof(pitch_el), &pitch_el },
	};
	const cl_int num_args = sizeof(kernel_args)/sizeof(kernel_args[0]);

	cl_int status;
	for (cl_int argi = 0; argi < num_args; ++argi) {
		status = clSetKernelArg(init_kernel, argi,
			kernel_args[argi].size, kernel_args[argi].value);
		ocl_check(status, "set init_array arg %d", argi);
	}

	cl_event kernel_evt;
	status = clEnqueueNDRangeKernel(que, init_kernel, 2,
		NULL, global_size, local_size,
		0, NULL, &kernel_evt);
	ocl_check(status, "enqueue init");

	return kernel_evt;
}

/* Host-side check of a pitched 2D array.
 *
 * array:    first element of the array; consecutive rows start pitch_el
 *           elements apart
 * nrows:    number of rows to check
 * ncols:    number of columns to check in each row
 * pitch_el: row stride, in elements
 *
 * The element at row r, column c is expected to hold r - c. Every mismatch
 * is reported on stderr; the function does not stop at the first one and
 * returns nothing.
 */
void verify(const cl_int *array, int nrows, int ncols, int pitch_el)
{
	for (int r = 0; r < nrows; ++r) {
		/* Compute the row offset in size_t: r*pitch_el evaluated in int
		 * can overflow (undefined behaviour) on large arrays. */
		const cl_int *row = array + (size_t)r * (size_t)pitch_el;
		for (int c = 0; c < ncols; ++c) {
			const int a = row[c];
			const int expected = r - c;
			if (a != expected)
				fprintf(stderr, "mismatch @ %d, %d: %d != %d\n",
					r, c, a, expected);
		}
	}
}

int main(int argc, char *argv[])
{
	if (argc < 4) error("matinit rows cols lws");

	int nrows = atoi(argv[1]);
	int ncols = atoi(argv[2]);
	int lws = atoi(argv[3]);

	if (nrows <= 0 || ncols <= 0 || lws <= 0) error("please specify a positive integer");

	cl_platform_id p = select_platform();
	cl_device_id d = select_device(p);
	cl_context ctx = create_context(p, d);
	cl_command_queue que = create_queue(ctx, d);
	cl_program prog = create_program("matinit.ocl", ctx, d);

	cl_int err;
	cl_kernel init_kernel = clCreateKernel(prog, "init_array_pitch_el", &err);
	ocl_check(err, "create init_kernel");

	cl_uint pitch_align;
	err = clGetDeviceInfo(d, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(pitch_align), &pitch_align, NULL);
	ocl_check(err, "get pitch align");

	const size_t pitch_byte = round_mul_up(ncols*sizeof(cl_int), pitch_align/8);
	const size_t pitch_el = pitch_byte/sizeof(cl_int);

	printf("pitch: %d => %zu\n", ncols, pitch_el);

	int nels = nrows*ncols;
	size_t natural_memsize = nels*sizeof(cl_int);
	size_t memsize = pitch_byte*nrows;

	printf("%g%% extra memory\n", memsize*100.0/natural_memsize - 100.0);

	cl_mem d_array = clCreateBuffer(ctx, CL_MEM_WRITE_ONLY | CL_MEM_ALLOC_HOST_PTR,
		memsize, NULL, &err);
	ocl_check(err, "create d_buffer failed");

	cl_event init_evt = init_array(que, init_kernel, d_array, nrows, ncols, pitch_el, lws);

	cl_event map_evt, unmap_evt;

	int *h_array = clEnqueueMapBuffer(que, d_array, CL_TRUE,
		CL_MAP_READ, 0, memsize,
		1, &init_evt, &map_evt, &err);
	ocl_check(err, "map buffer");

	verify(h_array, nrows, ncols, pitch_el);

	err = clEnqueueUnmapMemObject(que, d_array, h_array,
		0, NULL, &unmap_evt);
	ocl_check(err, "unmap buffer");

	err = clFinish(que);
	ocl_check(err, "finish");

	double init_runtime = runtime_ms(init_evt);
	double map_runtime = runtime_ms(map_evt);
	double unmap_runtime = runtime_ms(unmap_evt);

	printf("init: %gms, %gGB/s\n", init_runtime, natural_memsize/init_runtime/1.0e6);
	printf("map: %gms, %gGB/s\n", map_runtime, natural_memsize/map_runtime/1.0e6);
	printf("unmap: %gms, %gGB/s\n", unmap_runtime, natural_memsize/unmap_runtime/1.0e6);

	clReleaseKernel(init_kernel);
	clReleaseMemObject(d_array);
	clReleaseProgram(prog);
	clReleaseCommandQueue(que);
	clReleaseContext(ctx);
}
