#define CL_TARGET_OPENCL_VERSION 120
#include "ocl_boiler.h"
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Fatal-error helper: write the message plus a newline to stderr,
 * then terminate the process with exit status 1. Never returns. */
void error(const char *err)
{
	fputs(err, stderr);
	putc('\n', stderr);
	exit(1);
}

/* Launch the init_array kernel on d_array over a 2D nrows x ncols range.
 *
 * Work-groups are lws_arg x lws_arg. Each global dimension is the matching
 * matrix dimension passed through round_mul_up with lws_arg (presumably
 * rounded up to a multiple of the work-group edge — the kernel is then
 * expected to bounds-check; NOTE(review): confirm against the kernel).
 * Kernel arguments, in order: d_array, nrows, ncols.
 *
 * Returns the launch event; per the OpenCL API the caller is responsible
 * for releasing it. Errors are reported through ocl_check.
 */
cl_event init_array(cl_command_queue que, cl_kernel init_kernel, cl_mem d_array,
	cl_int nrows, cl_int ncols, int lws_arg)
{
	/* NDRange dimension 0 walks columns, dimension 1 walks rows */
	const size_t lws[2] = { lws_arg, lws_arg };
	const size_t gws[2] = {
		round_mul_up(ncols, lws[0]),
		round_mul_up(nrows, lws[1]),
	};

	cl_event launch_evt;
	cl_int status;
	cl_int arg = 0;

	status = clSetKernelArg(init_kernel, arg, sizeof d_array, &d_array);
	ocl_check(status, "set init_array arg %d", arg);
	++arg;

	status = clSetKernelArg(init_kernel, arg, sizeof nrows, &nrows);
	ocl_check(status, "set init_array arg %d", arg);
	++arg;

	status = clSetKernelArg(init_kernel, arg, sizeof ncols, &ncols);
	ocl_check(status, "set init_array arg %d", arg);

	status = clEnqueueNDRangeKernel(que, init_kernel, 2, NULL, gws, lws,
		0, NULL, &launch_evt);
	ocl_check(status, "enqueue init");

	return launch_evt;
}

/* Launch the transpose kernel, reading d_in (nrows_in x ncols_in) and
 * writing d_out.
 *
 * The 2D NDRange covers the INPUT matrix: dimension 0 walks its columns,
 * dimension 1 its rows. Each dimension is passed through round_mul_up with
 * the work-group edge lws_arg. Kernel arguments, in order: d_out, d_in, a
 * local-memory scratch buffer of lws_arg*lws_arg cl_ints (NULL value =>
 * __local allocation, per clSetKernelArg semantics), nrows_in, ncols_in.
 *
 * Returns the launch event; per the OpenCL API the caller is responsible
 * for releasing it. Errors are reported through ocl_check.
 */
cl_event transpose_array(cl_command_queue que, cl_kernel transpose_kernel,
	cl_mem d_out, cl_mem d_in,
	cl_int nrows_in, cl_int ncols_in, int lws_arg)
{
	const size_t lws[2] = { lws_arg, lws_arg };
	const size_t gws[2] = {
		round_mul_up(ncols_in, lws[0]),
		round_mul_up(nrows_in, lws[1]),
	};
	/* one cl_int of scratch per work-item in a work-group */
	const size_t lmem_bytes = lws[0] * lws[1] * sizeof(cl_int);

	cl_event launch_evt;
	cl_int status;
	cl_int arg = 0;

	status = clSetKernelArg(transpose_kernel, arg, sizeof d_out, &d_out);
	ocl_check(status, "set transpose_array arg %d", arg);
	++arg;

	status = clSetKernelArg(transpose_kernel, arg, sizeof d_in, &d_in);
	ocl_check(status, "set transpose_array arg %d", arg);
	++arg;

	status = clSetKernelArg(transpose_kernel, arg, lmem_bytes, NULL);
	ocl_check(status, "set transpose_array arg %d", arg);
	++arg;

	status = clSetKernelArg(transpose_kernel, arg, sizeof nrows_in, &nrows_in);
	ocl_check(status, "set transpose_array arg %d", arg);
	++arg;

	status = clSetKernelArg(transpose_kernel, arg, sizeof ncols_in, &ncols_in);
	ocl_check(status, "set transpose_array arg %d", arg);

	status = clEnqueueNDRangeKernel(que, transpose_kernel, 2, NULL, gws, lws,
		0, NULL, &launch_evt);
	ocl_check(status, "enqueue transpose");

	return launch_evt;
}


/* Host-side check of the transposed matrix.
 *
 * array is read row-major with nrows_out rows of ncols_out elements each;
 * element (r, c) must equal c - r (presumably the transpose of the pattern
 * written by the init kernel — NOTE(review): confirm against transpose.ocl).
 * Every mismatch is reported on stderr.
 *
 * Returns the number of mismatching elements (0 means the check passed).
 */
size_t verify(const cl_int *array, int nrows_out, int ncols_out)
{
	size_t mismatches = 0;

	for (int r = 0; r < nrows_out; ++r) {
		/* offset computed in size_t: r*ncols_out can exceed INT_MAX */
		const cl_int *row = array + (size_t)r * (size_t)ncols_out;
		for (int c = 0; c < ncols_out; ++c) {
			int a = row[c];
			int expected = c - r;
			if (a != expected) {
				fprintf(stderr, "mismatch @ %d, %d: %d != %d\n",
					r, c, a, expected);
				++mismatches;
			}
		}
	}

	return mismatches;
}

int main(int argc, char *argv[])
{
	if (argc < 4) error("matinit rows cols lws");

	int nrows = atoi(argv[1]);
	int ncols = atoi(argv[2]);
	int lws = atoi(argv[3]);

	int rot = 0;
	if (argc == 5) rot = atoi(argv[4]);

	if (nrows <= 0 || ncols <= 0 || lws <= 0) error("please specify a positive integer");

	cl_platform_id p = select_platform();
	cl_device_id d = select_device(p);
	cl_context ctx = create_context(p, d);
	cl_command_queue que = create_queue(ctx, d);
	cl_program prog = create_program("transpose.ocl", ctx, d);

	cl_int err;
	cl_kernel init_kernel = clCreateKernel(prog, "init_array", &err);
	ocl_check(err, "create init_kernel");
	cl_kernel transpose_kernel = clCreateKernel(prog,
		(rot ? "transpose_lmem_rot" : "transpose_lmem"), &err);
	ocl_check(err, "create transpose_kernel");

	int nels = nrows*ncols;
	size_t memsize = nels*sizeof(cl_int);

	cl_mem d_in = clCreateBuffer(ctx, CL_MEM_READ_WRITE, memsize, NULL, &err);
	ocl_check(err, "create d_in failed");
	cl_mem d_out = clCreateBuffer(ctx, CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, memsize, NULL, &err);
	ocl_check(err, "create d_out failed");

	cl_event init_evt = init_array(que, init_kernel, d_in, nrows, ncols, lws);
	cl_event transpose_evt = transpose_array(que, transpose_kernel, d_out, d_in, nrows, ncols, lws);

	cl_event map_evt, unmap_evt;

	int *h_array = clEnqueueMapBuffer(que, d_out, CL_TRUE,
		CL_MAP_READ, 0, memsize,
		1, &transpose_evt, &map_evt, &err);
	ocl_check(err, "map buffer");

	verify(h_array, ncols, nrows);

	err = clEnqueueUnmapMemObject(que, d_out, h_array,
		0, NULL, &unmap_evt);
	ocl_check(err, "unmap buffer");

	err = clFinish(que);
	ocl_check(err, "finish");

	double init_runtime = runtime_ms(init_evt);
	double transpose_runtime = runtime_ms(transpose_evt);
	double map_runtime = runtime_ms(map_evt);
	double unmap_runtime = runtime_ms(unmap_evt);

	printf("init: %gms, %gGB/s\n", init_runtime, memsize/init_runtime/1.0e6);
	printf("transpose: %gms, %gGB/s\n", transpose_runtime, 2*memsize/transpose_runtime/1.0e6);
	printf("map: %gms, %gGB/s\n", map_runtime, memsize/map_runtime/1.0e6);
	printf("unmap: %gms, %gGB/s\n", unmap_runtime, memsize/unmap_runtime/1.0e6);

	clReleaseKernel(init_kernel);
	clReleaseKernel(transpose_kernel);
	clReleaseMemObject(d_out);
	clReleaseMemObject(d_in);
	clReleaseProgram(prog);
	clReleaseCommandQueue(que);
	clReleaseContext(ctx);
}
