/* -----------------------------------------------------------------------------
 *
 * Module      : Shape
 * Copyright   : [2008..2011] Manuel M T Chakravarty, Gabriele Keller, Sean Lee, Trevor L. McDonell
 * License     : BSD3
 *
 * Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
 * Stability   : experimental
 *
 * ---------------------------------------------------------------------------*/

#ifndef __ACCELERATE_CUDA_SHAPE_H__
#define __ACCELERATE_CUDA_SHAPE_H__

#include <stdint.h>
#include <cuda_runtime.h>

/* Scalar type of a single extent or index component. */
typedef int32_t Ix;

/* The rank-0 shape has no components; the rank-1 shape is a bare Ix. */
typedef void*   DIM0;
typedef Ix      DIM1;

/*
 * Rank-2 to rank-9 shapes. Members are declared from the highest-numbered
 * component down to a0; a0 is the member that indexHead() reads.
 */
typedef struct { Ix a1; Ix a0; }                                                  DIM2;
typedef struct { Ix a2; Ix a1; Ix a0; }                                           DIM3;
typedef struct { Ix a3; Ix a2; Ix a1; Ix a0; }                                    DIM4;
typedef struct { Ix a4; Ix a3; Ix a2; Ix a1; Ix a0; }                             DIM5;
typedef struct { Ix a5; Ix a4; Ix a3; Ix a2; Ix a1; Ix a0; }                      DIM6;
typedef struct { Ix a6; Ix a5; Ix a4; Ix a3; Ix a2; Ix a1; Ix a0; }               DIM7;
typedef struct { Ix a7; Ix a6; Ix a5; Ix a4; Ix a3; Ix a2; Ix a1; Ix a0; }        DIM8;
typedef struct { Ix a8; Ix a7; Ix a6; Ix a5; Ix a4; Ix a3; Ix a2; Ix a1; Ix a0; } DIM9;

#ifdef __cplusplus

/* -----------------------------------------------------------------------------
 * Shape construction and destruction
 */

/*
 * Build a shape value from its individual components.
 *
 * Arguments are given in the same order as the members of the resulting DIMn
 * type: the first argument fills the highest-numbered member and the last
 * argument (always named 'a') fills a0. Called with no arguments, the rank-0
 * shape is returned; called with one argument, that value is itself the shape.
 */
static __inline__ __device__ DIM0 shape()
{
    /* The rank-0 shape carries no data; it is represented by a null pointer. */
    return NULL;
}

static __inline__ __device__ DIM1 shape(const Ix a)
{
    const DIM1 result = a;
    return result;
}

static __inline__ __device__ DIM2 shape(const Ix b, const Ix a)
{
    DIM2 result;
    result.a1 = b;
    result.a0 = a;
    return result;
}

static __inline__ __device__ DIM3 shape(const Ix c, const Ix b, const Ix a)
{
    DIM3 result;
    result.a2 = c;
    result.a1 = b;
    result.a0 = a;
    return result;
}

static __inline__ __device__ DIM4 shape(const Ix d, const Ix c, const Ix b, const Ix a)
{
    DIM4 result;
    result.a3 = d;
    result.a2 = c;
    result.a1 = b;
    result.a0 = a;
    return result;
}

static __inline__ __device__ DIM5 shape(const Ix e, const Ix d, const Ix c, const Ix b, const Ix a)
{
    DIM5 result;
    result.a4 = e;
    result.a3 = d;
    result.a2 = c;
    result.a1 = b;
    result.a0 = a;
    return result;
}

static __inline__ __device__ DIM6 shape(const Ix f, const Ix e, const Ix d, const Ix c, const Ix b, const Ix a)
{
    DIM6 result;
    result.a5 = f;
    result.a4 = e;
    result.a3 = d;
    result.a2 = c;
    result.a1 = b;
    result.a0 = a;
    return result;
}

static __inline__ __device__ DIM7 shape(const Ix g, const Ix f, const Ix e, const Ix d, const Ix c, const Ix b, const Ix a)
{
    DIM7 result;
    result.a6 = g;
    result.a5 = f;
    result.a4 = e;
    result.a3 = d;
    result.a2 = c;
    result.a1 = b;
    result.a0 = a;
    return result;
}

static __inline__ __device__ DIM8 shape(const Ix h, const Ix g, const Ix f, const Ix e, const Ix d, const Ix c, const Ix b, const Ix a)
{
    DIM8 result;
    result.a7 = h;
    result.a6 = g;
    result.a5 = f;
    result.a4 = e;
    result.a3 = d;
    result.a2 = c;
    result.a1 = b;
    result.a0 = a;
    return result;
}

static __inline__ __device__ DIM9 shape(const Ix i, const Ix h, const Ix g, const Ix f, const Ix e, const Ix d, const Ix c, const Ix b, const Ix a)
{
    DIM9 result;
    result.a8 = i;
    result.a7 = h;
    result.a6 = g;
    result.a5 = f;
    result.a4 = e;
    result.a3 = d;
    result.a2 = c;
    result.a1 = b;
    result.a0 = a;
    return result;
}

/*
 * Yield the inner-most dimension of a shape only.
 *
 * The primary template serves the struct shapes (DIM2 and above) and returns
 * their a0 member. DIM0 and DIM1 are not structs and therefore have explicit
 * specializations: the rank-0 shape yields 0, and a rank-1 shape is its own
 * head.
 */
template <typename Shape>
static __inline__ __device__ Ix indexHead(const Shape ix)
{
    return ix.a0;
}

/*
 * The explicit specializations carry no storage-class specifier: C++ does not
 * permit one on an explicit specialization, and the specialization takes its
 * linkage from the primary template declared above.
 */
template <>
__inline__ __device__ Ix indexHead(const DIM0 ix)
{
    return 0;
}

template <>
__inline__ __device__ Ix indexHead(const DIM1 ix)
{
    return ix;
}


/*
 * Yield all but the inner-most dimension of a shape.
 *
 * Each overload drops the a0 member of its argument and returns the remaining
 * components, in their original order, as a shape of rank one less. The tail
 * of a rank-1 shape is the rank-0 shape.
 */
static __inline__ __device__ DIM0 indexTail(const DIM1 ix)
{
    const DIM0 tail = 0;
    return tail;
}

static __inline__ __device__ DIM1 indexTail(const DIM2 ix)
{
    const DIM1 tail = ix.a1;
    return tail;
}

static __inline__ __device__ DIM2 indexTail(const DIM3 ix)
{
    DIM2 tail = { ix.a2, ix.a1 };
    return tail;
}

static __inline__ __device__ DIM3 indexTail(const DIM4 ix)
{
    DIM3 tail = { ix.a3, ix.a2, ix.a1 };
    return tail;
}

static __inline__ __device__ DIM4 indexTail(const DIM5 ix)
{
    DIM4 tail = { ix.a4, ix.a3, ix.a2, ix.a1 };
    return tail;
}

static __inline__ __device__ DIM5 indexTail(const DIM6 ix)
{
    DIM5 tail = { ix.a5, ix.a4, ix.a3, ix.a2, ix.a1 };
    return tail;
}

static __inline__ __device__ DIM6 indexTail(const DIM7 ix)
{
    DIM6 tail = { ix.a6, ix.a5, ix.a4, ix.a3, ix.a2, ix.a1 };
    return tail;
}

static __inline__ __device__ DIM7 indexTail(const DIM8 ix)
{
    DIM7 tail = { ix.a7, ix.a6, ix.a5, ix.a4, ix.a3, ix.a2, ix.a1 };
    return tail;
}

static __inline__ __device__ DIM8 indexTail(const DIM9 ix)
{
    DIM8 tail = { ix.a8, ix.a7, ix.a6, ix.a5, ix.a4, ix.a3, ix.a2, ix.a1 };
    return tail;
}


/* -----------------------------------------------------------------------------
 * Shape methods
 */

/*
 * Number of dimensions of a shape.
 *
 * Computed by peeling one dimension off with indexTail() and counting 1 for
 * it, until the rank-0 shape is reached, which contributes 0.
 */
template <typename Shape>
static __inline__ __device__ int dim(const Shape sh)
{
    return dim(indexTail(sh)) + 1;
}

/*
 * Base case of the recursion. No storage-class specifier here: C++ does not
 * permit one on an explicit specialization, and the specialization takes its
 * linkage from the primary template declared above.
 */
template <>
__inline__ __device__ int dim(const DIM0 sh)
{
    return 0;
}


/*
 * Yield the total number of elements in a shape.
 *
 * The result is the product of every component of the shape: the inner-most
 * component (indexHead) multiplied by the size of the remaining components
 * (indexTail). The rank-0 shape has size 1, which terminates the recursion.
 * The product is accumulated in an int; no overflow check is performed.
 */
template <typename Shape>
static __inline__ __device__ int size(const Shape sh)
{
    return size(indexTail(sh)) * indexHead(sh);
}

/*
 * Base case of the recursion. No storage-class specifier here: C++ does not
 * permit one on an explicit specialization, and the specialization takes its
 * linkage from the primary template declared above.
 */
template <>
__inline__ __device__ int size(const DIM0 sh)
{
    return 1;
}


/*
 * Add an index to the head of a shape.
 *
 * The components of 'sh' keep their relative order but each moves up one
 * member position, and 'ix' becomes the a0 member of the result, giving a
 * shape of rank one greater. Consing onto the rank-0 shape yields 'ix' itself.
 */
static __inline__ __device__ DIM1 indexCons(const DIM0 sh, const Ix ix)
{
    const DIM1 extended = ix;
    return extended;
}

static __inline__ __device__ DIM2 indexCons(const DIM1 sh, const Ix ix)
{
    DIM2 extended = { sh, ix };
    return extended;
}

static __inline__ __device__ DIM3 indexCons(const DIM2 sh, const Ix ix)
{
    DIM3 extended = { sh.a1, sh.a0, ix };
    return extended;
}

static __inline__ __device__ DIM4 indexCons(const DIM3 sh, const Ix ix)
{
    DIM4 extended = { sh.a2, sh.a1, sh.a0, ix };
    return extended;
}

static __inline__ __device__ DIM5 indexCons(const DIM4 sh, const Ix ix)
{
    DIM5 extended = { sh.a3, sh.a2, sh.a1, sh.a0, ix };
    return extended;
}

static __inline__ __device__ DIM6 indexCons(const DIM5 sh, const Ix ix)
{
    DIM6 extended = { sh.a4, sh.a3, sh.a2, sh.a1, sh.a0, ix };
    return extended;
}

static __inline__ __device__ DIM7 indexCons(const DIM6 sh, const Ix ix)
{
    DIM7 extended = { sh.a5, sh.a4, sh.a3, sh.a2, sh.a1, sh.a0, ix };
    return extended;
}

static __inline__ __device__ DIM8 indexCons(const DIM7 sh, const Ix ix)
{
    DIM8 extended = { sh.a6, sh.a5, sh.a4, sh.a3, sh.a2, sh.a1, sh.a0, ix };
    return extended;
}

static __inline__ __device__ DIM9 indexCons(const DIM8 sh, const Ix ix)
{
    DIM9 extended = { sh.a7, sh.a6, sh.a5, sh.a4, sh.a3, sh.a2, sh.a1, sh.a0, ix };
    return extended;
}


/*
 * Yield the index position in a linear, row-major representation of the array.
 * The first argument is the shape of the array, the second the index.
 *
 * The linear position is built recursively: the position of the outer
 * components (indexTail) is scaled by the inner-most extent of the shape and
 * the inner-most index component is added. No bounds check is made against
 * the shape.
 */
template <typename Shape>
static __inline__ __device__ Ix toIndex(const Shape sh, const Shape ix)
{
    return toIndex(indexTail(sh), indexTail(ix)) * indexHead(sh) + indexHead(ix);
}

/*
 * Base cases. No storage-class specifier on these: C++ does not permit one on
 * an explicit specialization, and the specialization takes its linkage from
 * the primary template declared above.
 */
template <>
__inline__ __device__ Ix toIndex(const DIM0 sh, const DIM0 ix)
{
    /* A rank-0 array has a single element, at position 0. */
    return 0;
}

template <>
__inline__ __device__ Ix toIndex(const DIM1 sh, const DIM1 ix)
{
    /* A rank-1 index is already a linear position; the extent is not used. */
    return ix;
}


/*
 * Inverse of 'toIndex': recover a multidimensional index from a linear,
 * row-major position 'ix' within an array of shape 'sh'.
 *
 * At each step the inner-most component is the remainder of 'ix' by the
 * inner-most extent, and the quotient is passed on to the outer dimensions.
 * The extent is used as a divisor without being checked for zero.
 */
template <typename Shape>
static __inline__ __device__ Shape fromIndex(const Shape sh, const Ix ix)
{
    const Ix d = indexHead(sh);
    return indexCons(fromIndex(indexTail(sh), ix / d), ix % d);
}

/*
 * Base cases. No storage-class specifier on these: C++ does not permit one on
 * an explicit specialization, and the specialization takes its linkage from
 * the primary template declared above.
 */
template <>
__inline__ __device__ DIM0 fromIndex(const DIM0 sh, const Ix ix)
{
    return 0;
}

template <>
__inline__ __device__ DIM1 fromIndex(const DIM1 sh, const Ix ix)
{
    /* The outer-most component is returned as-is, not reduced by the extent. */
    return ix;
}


/*
 * Test for the magic index `ignore`.
 *
 * Returns non-zero only when every component of the index equals -1. The
 * check walks from the inner-most component outwards and stops at the first
 * component that is not -1.
 */
template <typename Shape>
static __inline__ __device__ int ignore(const Shape ix)
{
    return indexHead(ix) == -1 && ignore(indexTail(ix));
}

/*
 * Base case: the rank-0 index has no components to test, so it matches. No
 * storage-class specifier here: C++ does not permit one on an explicit
 * specialization, and the specialization takes its linkage from the primary
 * template declared above.
 */
template <>
__inline__ __device__ int ignore(const DIM0 ix)
{
    return 1;
}


#else

/*
 * Plain-C view of the shape interface, seen only when __cplusplus is not
 * defined. C has neither overloading nor templates, so each name gets a single
 * prototype written purely in terms of Ix arguments and an int result.
 *
 * NOTE(review): these are declarations only -- this header provides no bodies
 * for them in the C branch, and 'ignore' has no C prototype at all. Confirm
 * whether definitions are supplied elsewhere or whether these are placeholders.
 */
static __inline__ __device__ int dim(const Ix sh);
static __inline__ __device__ int size(const Ix sh);
static __inline__ __device__ int shape(const Ix sh);
static __inline__ __device__ int toIndex(const Ix sh, const Ix ix);
static __inline__ __device__ int fromIndex(const Ix sh, const Ix ix);
static __inline__ __device__ int indexHead(const Ix ix);
static __inline__ __device__ int indexTail(const Ix ix);
static __inline__ __device__ int indexCons(const Ix sh, const Ix ix);

#endif  // __cplusplus
#endif  // __ACCELERATE_CUDA_SHAPE_H__

