General-purpose programming on GPU

First steps in CUDA

Giuseppe Bilotta, Eugenio Rustico, Alexis Hérault

DMI — Università di Catania
Sezione di Catania — INGV

Some vector examples

Adding two vectors

#include <stdio.h>
#include <malloc.h>
#include <sys/time.h>

#include <cuda_runtime_api.h>


/* initialize one vector element per thread; extra threads in the last block do nothing */
__global__ void initVector(float *vec, int size, int mul) {
    int idx = threadIdx.x + blockDim.x*blockIdx.x;
    if (idx >= size)
        return;
    vec[idx] = mul & 1 ? idx*(mul+1) : idx/float(mul);
}

/* element-wise sum: one thread per output element */
__global__ void vecSum(float *sum, const float *vec1, const float *vec2, int size) {
    int idx = threadIdx.x + blockDim.x*blockIdx.x;
    if (idx >= size)
        return;
    sum[idx] = vec1[idx] + vec2[idx];
}

#define vecsize 20240000
#define bsize 512

int main (int argc, char **argv) {
    int deviceCount = -1; // number of devices
    int dev = 0;

    /* Host and device vector pointers */
    float *hVec1, *hVec2, *hVec3, *hVec3cpu;
    float *dVec1, *dVec2, *dVec3;

    cudaGetDeviceCount(&deviceCount);

    if (deviceCount <= 0) {
        fprintf(stderr, "No CUDA devices found\n");
        return 1;
    }

    cudaError_t error = cudaSetDevice(dev);
    if (error != cudaSuccess) {
        fprintf(stderr, "Error setting device to %d: %s\n",
            dev, cudaGetErrorString(error));
        return 1;
    }

    /* Allocate host vectors */
    hVec1 = (float *)malloc(vecsize*sizeof(float));
    hVec2 = (float *)malloc(vecsize*sizeof(float));
    hVec3 = (float *)malloc(vecsize*sizeof(float));
    hVec3cpu = (float *)malloc(vecsize*sizeof(float));
    if (!hVec1 || !hVec2 || !hVec3 || !hVec3cpu) {
        fprintf(stderr, "Unable to allocate host vector\n");
        return 1;
    }

    /* Allocate device vectors, checking each allocation */
    error = cudaMalloc(&dVec1, vecsize*sizeof(float));
    if (error == cudaSuccess)
        error = cudaMalloc(&dVec2, vecsize*sizeof(float));
    if (error == cudaSuccess)
        error = cudaMalloc(&dVec3, vecsize*sizeof(float));
    if (error != cudaSuccess) {
        fprintf(stderr, "Unable to allocate device memory: %s\n",
            cudaGetErrorString(error));
        return 1;
    }

    /* Zero device vector */
    cudaMemset(dVec1, 0, vecsize*sizeof(float));
    cudaMemset(dVec2, 0, vecsize*sizeof(float));
    cudaMemset(dVec3, 0, vecsize*sizeof(float));

    int gridsize = vecsize/bsize;
    if (bsize*gridsize < vecsize)
        gridsize += 1;

    /* Launch kernel: the first parameter is the number of blocks, the second the block size */
    initVector<<<gridsize, bsize>>>(dVec1, vecsize, 1);
    initVector<<<gridsize, bsize>>>(dVec2, vecsize, 2);

    /* Wait for kernel execution to finish */
    error = cudaDeviceSynchronize();
    if (error != cudaSuccess) {
        fprintf(stderr, "initVector failed: %s\n",
            cudaGetErrorString(error));
        return 1;
    }

    /* Copy data from device to host */
    error = cudaMemcpy(hVec1, dVec1, vecsize*sizeof(float), cudaMemcpyDeviceToHost);
    if (error == cudaSuccess)
        error = cudaMemcpy(hVec2, dVec2, vecsize*sizeof(float), cudaMemcpyDeviceToHost);
    if (error != cudaSuccess) {
        fprintf(stderr, "cudaMemcpy failed: %s\n",
            cudaGetErrorString(error));
        return 1;
    }

    printf("Initialized 2 vectors with %u elements each\n", vecsize);


    cudaEvent_t gpu_start, gpu_stop;
    float gpu_runtime;

    cudaEventCreate(&gpu_start);
    cudaEventCreate(&gpu_stop);
    cudaEventRecord(gpu_start, 0);

    vecSum<<<gridsize, bsize>>>(dVec3, dVec1, dVec2, vecsize);

    cudaEventRecord(gpu_stop, 0);
    cudaEventSynchronize(gpu_stop);
    cudaEventElapsedTime(&gpu_runtime, gpu_start, gpu_stop);
    printf("CUDA runtime: %gms\n", gpu_runtime);

    /* Copy data from device to host */
    error = cudaMemcpy(hVec3, dVec3, vecsize*sizeof(float), cudaMemcpyDeviceToHost);
    if (error != cudaSuccess) {
        fprintf(stderr, "cudaMemcpy failed: %s\n",
            cudaGetErrorString(error));
        return 1;
    }

    struct timeval cpu_start, cpu_stop;
    gettimeofday(&cpu_start, NULL);
    for (int i = 0; i < vecsize; ++i) {
        hVec3cpu[i] = hVec1[i] + hVec2[i];
    }
    gettimeofday(&cpu_stop, NULL);

    long int seconds = cpu_stop.tv_sec - cpu_start.tv_sec;
    long int useconds = cpu_stop.tv_usec - cpu_start.tv_usec;
    double cpu_runtime = (double(seconds) * 1000 + double(useconds)/1000);

    printf("CPU runtime: %lgms\n", cpu_runtime);
    printf("Speedup (els %u, bs %u): %.2lg\n",
                 vecsize, bsize, cpu_runtime/gpu_runtime);

    for (int i = 0; i < vecsize; ++i) {
        if (hVec3[i] != hVec3cpu[i]) {
            printf("ERROR @ %u: CPU %g, GPU %g\n",
                   i, hVec3[i], hVec3cpu[i]);
            break;
        }
    }


    /* Free device and host memory */
    cudaFree(dVec1); cudaFree(dVec2); cudaFree(dVec3);
    free(hVec1); free(hVec2); free(hVec3); free(hVec3cpu);

    return 0;

}

Minimization: a parallel reduction example

A reduction computes a single value from a large collection of values. Typical examples: finding the minimum/maximum of an array, or computing the dot product of two vectors.
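
For reference, the sequential version of the minimum reduction is just a loop carrying a single accumulator (a plain host-side fragment; data and n are illustrative names):

/* sequential minimum of n values (host code) */
float minval = data[0];
for (size_t i = 1; i < n; ++i)
    minval = fmin(minval, data[i]);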

Parallel reduction requires the use of shared memory to allow exchange of information between threads in the same block (CTA = Cooperative Thread Array).

Shared memory is allocated inside a kernel with the __shared__ specifier. Each block is assigned the given amount of shared memory, accessible for read/write by all the threads in the block.

When a thread needs to read data written by some other thread in the same block, it must ensure that the write has completed. This is achieved by waiting on the other threads with the __syncthreads() primitive.
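
A minimal sketch of this pattern (the kernel name, the 256-thread assumption and the ‘copy from your neighbour’ operation are illustrative choices, not part of the examples that follow): each thread writes its own value to shared memory, waits for the whole block, and only then reads a slot written by another thread.

__global__ void neighbourCopy(float *out, const float *in, int size) {
    __shared__ float shMem[256];   /* assumes blockDim.x <= 256 */
    int idx = threadIdx.x + blockDim.x*blockIdx.x;

    /* every thread writes its own slot (a dummy value past the end of the array) */
    shMem[threadIdx.x] = (idx < size ? in[idx] : 0.0f);

    /* every thread must reach the barrier: do not return early before it */
    __syncthreads();

    /* only now is it safe to read what the next thread in the block wrote */
    int next = (threadIdx.x + 1) % blockDim.x;
    if (idx < size)
        out[idx] = shMem[next];
}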


Shared memory arrays can be declared statically in a kernel (size defined at compile time) or dynamically (size specified when the kernel is launched).

Static size example:

#define SHMEM_SIZE 32

__global__ void someKernel(...) {
    __shared__ float shMem[SHMEM_SIZE];
    ....
}

...

someKernel<<<gridSize, blockSize>>>(...)

Runtime size example:

__global__ void someKernel(...) {
    extern __shared__ float shMem[];
    ....
}

...

someKernel<<<gridSize, blockSize, shMemSize>>>(...)

Reduction is not an intrinsically parallel process: to parallelize it, we decompose it into a number of independent, parallel mini-reductions. In shared memory, this is done with a loop like the following:

uint active = blockDim.x >> 1;
do {
    __syncthreads();
    if (threadIdx.x < active)
        shMem[threadIdx.x] = fmin(shMem[threadIdx.x], shMem[threadIdx.x+active]);
    active >>= 1;
} while (active > 0);

At each pass, we halve the number of active threads, with each active thread comparing its own value with the corresponding one in the second half of the currently active range: 16 threads reduce 32 values to 16 values, then 8 threads reduce 16 values to 8 values, and so on, until a single value remains. Note that this halving scheme assumes the block size is a power of two.


Each block can thus reduce a number of elements (in shared memory) equal to the number of threads in the block down to a single element. A single kernel invocation therefore reduces the array to as many values as there were blocks launched. If more than one block was launched, additional reduction steps have to be invoked, until a single block is sufficient to complete the reduction.
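
On the host side this becomes a loop of kernel launches, each pass reducing the output of the previous one. A sketch, with placeholder names (reduceMin, dIn, dOut, blockSize) and assuming a kernel in which each block reduces exactly one block's worth of values (padding of the last block is ignored for simplicity):

uint n = vecsize;                                  /* values still to reduce */
while (n > 1) {
    uint blocks = (n + blockSize - 1)/blockSize;   /* round up */
    reduceMin<<<blocks, blockSize>>>(dOut, dIn, n);
    /* the partial results become the input of the next pass */
    float *tmp = dIn; dIn = dOut; dOut = tmp;
    n = blocks;
}
/* dIn[0] now holds the final result */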


When loading data from global to shared memory, an initial reduction can already be done by letting each thread read more than one value from global memory and store only its partial result in shared memory. With code such as the following:

uint gix = threadIdx.x + blockDim.x*blockIdx.x;
float acc = CUDART_NAN_F;
while (gix < dim) {
    acc = fmin(acc, dSrc[gix]);
    gix += blockDim.x*gridDim.x;
}
shMem[threadIdx.x] = acc;

we can choose how many blocks we want, and they will read as many elements from global memory as necessary to cover the whole array. This means that even a single block could reduce the whole array in one kernel launch; however, this is not efficient, because most of the GPU would be left unused.

The most efficient approach uses two kernel launches: one with the number of blocks necessary to saturate the hardware, the other with a single block to ‘finish up’ the reduction.

Example code is shown next, with the key parameters being the BLOCK_SIZE and the number of blocks nblocks launched in the first step. Try to find the optimal choice for your hardware.
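
A reasonable starting point for nblocks (a heuristic assumption, not something the example below prescribes) is to tie it to the number of multiprocessors reported by the runtime:

cudaDeviceProp props;
cudaGetDeviceProperties(&props, 0);
/* a few blocks per multiprocessor; the factor 4 is arbitrary, tune it */
uint nblocks = 4*props.multiProcessorCount;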


/* usual C/C++ includes */
#include <stdio.h>
#include <string.h>
#include <stdlib.h> // for rand()
#include <errno.h>

#include <cuda_runtime_api.h>
#include <math_constants.h>

#define N (1024*1024*64)

/* host buffer */
float *data;
/* device buffers */
float *dSrc, *dDst;

void check_error(cudaError error, const char *message) {
    if (error != cudaSuccess) {
        fprintf(stderr, "%s (%s)\n", message,
            cudaGetErrorString(error));
        if (dSrc)
            cudaFree(dSrc);
        if (dDst)
            cudaFree(dDst);
        exit(1);
    }
}

#define WARP_SIZE 32
/* the halving reduction below requires a power-of-two block size */
#define BLOCK_SIZE (8*WARP_SIZE)

__global__ void findMin(float *dDst, const float *dSrc, uint dim)
{
    __shared__ float cache[BLOCK_SIZE];

    uint gix = threadIdx.x + blockDim.x*blockIdx.x;

#define tid threadIdx.x

    /* start from NaN: fmin() returns its non-NaN argument, so the first
     * element read will replace this initial value */
    float acc = CUDART_NAN_F;

    /* grid-stride loop: each thread accumulates the minimum of all the
     * elements it is responsible for */
    while (gix < dim) {
        acc = fmin(acc, dSrc[gix]);
        gix += blockDim.x*gridDim.x;
    }

    cache[tid] = acc;

    /* reduction in shared memory: halve the number of active threads at each
     * pass (this assumes blockDim.x is a power of two) */
    uint active = blockDim.x >> 1;

    do {
        __syncthreads();
        if (tid < active)
            cache[tid] = fmin(cache[tid], cache[tid+active]);
        active >>= 1;
    } while (active > 0);

    if (tid == 0)
        dDst[blockIdx.x] = cache[0];
}

int main(int argc, char **argv) {
    data = (float*) calloc(N, sizeof(float));
    if (!data) {
        fprintf(stderr, "Unable to allocate host buffer\n");
        return 1;
    }
    size_t data_size = N * sizeof(float);
    float min = nan(""), d_min = nan("");

    for (size_t i = 0; i < N; ++i) {
        data[i] = N*float(rand())/RAND_MAX;
        min = fmin(min, data[i]);
    }
    printf("%u elements generated, min %g, data size %zu (%zuMB)\n",
            N, min, data_size, data_size>>20);

    cudaError_t err;

    err = cudaMalloc(&dSrc, data_size);
    check_error(err, "allocating array");

    err = cudaMemcpy(dSrc, data, data_size, cudaMemcpyHostToDevice);
    check_error(err, "copy UP");

    uint nblocks = 8;

    err = cudaMalloc(&dDst, nblocks*sizeof(*dDst));
    check_error(err, "allocating Dst array");

    cudaEvent_t start, stop;
    float runtime;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, 0);
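    /* first step: nblocks blocks, each writing one partial minimum to dDst */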
    findMin<<<nblocks,BLOCK_SIZE>>>(dDst, dSrc, N);
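    /* second step: a single block reduces the nblocks partial minima in place */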
    findMin<<<1,BLOCK_SIZE>>>(dDst, dDst, nblocks);
    cudaEventRecord(stop, 0);

    cudaEventSynchronize(stop);
    cudaEventElapsedTime(&runtime, start, stop);

    /* Giga-elements per second */
    printf("%u elements processed in %gms: %gGE/s\n",
        N, runtime, (N/runtime)/(1000000));

    /* Actual bandwith in GB/s */
    uint total_els = N + nblocks;
    float sizeMB = float(total_els)*sizeof(float)/(1024*1024);
    printf("Bandwidth: %u elements (%gMB) read in two steps. "
        "Runtime: %gms (%gGB/s)\n",
        total_els, sizeMB, runtime, sizeMB/runtime);

    err = cudaMemcpy(&d_min, dDst, sizeof(d_min), cudaMemcpyDeviceToHost);
    check_error(err, "copy DOWN");

    cudaFree(dSrc); dSrc = NULL;
    cudaFree(dDst); dDst = NULL;
    free(data);

    printf("Parallel min: %g vs %g\n", d_min, min);
}