34 #include "../../types.h"
// for_each_kernel -- device-side element loop launched by the for_each host
// wrapper in this file.
//
// Template parameters:
//   InputIterator -- type of `input`.
//   Function      -- type of the per-element callable `func`.
//   index_type    -- integer width of the loop index (presumably the type of
//                    `index`, whose declaration is not visible here); for_each
//                    instantiates this kernel with uint32 and with uint64.
// Parameters:
//   input  -- sequence being visited.
//   length -- number of elements.
//   func   -- callable applied per element.
//
// Launch layout: 1-D; only the .x components of blockIdx, blockDim and
// gridDim are read.
//
// NOTE(review): the loop condition and loop body are not visible in this
// excerpt; presumably the loop runs while index < length and applies
// func(input[index]) -- confirm against the full source.
39 template <
typename InputIterator,
typename Function,
typename index_type>
40 __global__
void for_each_kernel(InputIterator input,
size_t length, Function func)
// Each thread starts at its flat global thread index and advances by the
// total number of threads in the grid (blockDim.x * gridDim.x).
// NOTE(review): both products are evaluated in 32-bit unsigned arithmetic
// (the CUDA built-in index/dimension components are unsigned int) before
// being widened to index_type, so even the uint64 instantiation wraps if the
// launched thread count reaches 2^32.
44 for(index = blockIdx.x * blockDim.x + threadIdx.x;
46 index += blockDim.x * gridDim.x)
// for_each -- host-side launcher that runs for_each_kernel over `length`
// elements of `input`, passing `func` to the kernel.
//
// launch_params is used as <<<launch_params.x, launch_params.y>>>, so .x is
// the number of blocks (grid size) and .y the threads per block:
//   * default {0, 0}: configurations come from launch_parameters() for both
//     the uint64 and the uint32 kernel instantiations. The rest of this
//     branch's condition is not visible in this excerpt; presumably it also
//     requires launch_params.y == 0 -- confirm.
//   * .x == 0 (else-if branch): .x is set to ceil(length / .y).
//   * otherwise: .x is clamped to ceil(length / .y), with a warning printed
//     to stderr if the caller asked for more blocks than that.
// In every path, the uint64-index kernel is launched when
// length + blocks * threads >= uint64(1 << 31). Otherwise the uint32-index
// kernel is launched.
//
// Kernel launches are asynchronous; their status is not checked here.
//
// NOTE(review): uint64(1 << 31) is not 2^31. `1 << 31` is evaluated in 32-bit
// signed int, where 2^31 is not representable. On two's-complement targets
// it yields INT_MIN, and widening that to uint64 sign-extends to
// 0xFFFFFFFF80000000 (~1.8e19). As a result, the uint32 kernel is chosen
// even for lengths past 2^32, where its index overflows. The intended
// constant is presumably uint64(1) << 31 -- confirm and fix.
// NOTE(review): `params.x * params.y` / `launch_params.x * launch_params.y`
// are int multiplies and can overflow before being widened to uint64.
// NOTE(review): no cudaGetLastError() follows the launches, so invalid
// configurations go unreported. For example, length == 0 makes the two
// non-default branches launch 0 blocks.
52 template <
typename InputIterator,
typename Function>
53 void for_each(InputIterator input,
size_t length, Function func, int2 launch_params = { 0, 0 })
// Default configuration: query launch_parameters() for each index width.
55 if (launch_params.x == 0 &&
58 int2 params_64 =
launch_parameters(for_each_kernel<InputIterator, Function, uint64>, length);
59 int2 params_32 =
launch_parameters(for_each_kernel<InputIterator, Function, uint32>, length);
// Pick the 64-bit-index kernel when the 32-bit configuration's indices could
// reach the threshold; otherwise the 32-bit-index kernel.
62 if (
uint64(length) + params_32.x * params_32.y >=
uint64(1 << 31))
65 for_each_kernel<InputIterator, Function, uint64> <<<params_64.x, params_64.y>>>(input, length, func);
68 for_each_kernel<InputIterator, Function, uint32> <<<params_32.x, params_32.y>>>(input, length, func);
70 }
else if (launch_params.x == 0) {
// Only the block size (.y) was supplied: launch ceil(length / .y) blocks so
// that there is one thread per element.
// NOTE(review): the int(...) cast truncates if that block count exceeds
// INT_MAX.
71 launch_params.x = int((length + launch_params.y - 1) / launch_params.y);
73 if (
uint64(length) + launch_params.x * launch_params.y >=
uint64(1 << 31))
75 for_each_kernel<InputIterator, Function, uint64> <<<launch_params.x, launch_params.y>>>(input, length, func);
77 for_each_kernel<InputIterator, Function, uint32> <<<launch_params.x, launch_params.y>>>(input, length, func);
// Caller supplied .x (block count). Never launch more blocks than needed to
// give each element one thread.
// NOTE(review): this branch is reached with .x != 0 and any .y. If .y == 0,
// the expression below divides by zero. The condition that follows also
// recomputes the same value instead of reusing max_blocks.
81 int max_blocks = int((length + launch_params.y - 1) / launch_params.y);
83 if (launch_params.x >
int((length + launch_params.y - 1) / launch_params.y))
// NOTE(review): the message says "block size", but the value being reduced
// is launch_params.x, which is the grid size (number of blocks).
85 fprintf(stderr,
"WARNING: for_each call overcommitted, reducing block size to %d\n", max_blocks);
86 launch_params.x = max_blocks;
90 if (
uint64(length) + launch_params.x * launch_params.y >=
uint64(1 << 31))
92 for_each_kernel<InputIterator, Function, uint64> <<<launch_params.x, launch_params.y>>>(input, length, func);
95 for_each_kernel<InputIterator, Function, uint32> <<<launch_params.x, launch_params.y>>>(input, length, func);
// NOTE(review): this function's header line is not visible in this excerpt.
// It is presumably
//   int2 for_each_launch_parameters(InputIterator input, size_t length, Function func)
// -- confirm.
// Visible behaviour:
//   * queries launch_parameters() for the uint64 and uint32 instantiations
//     of for_each_kernel;
//   * applies the same index-width test as for_each's default path.
// The branch bodies are not visible here. Presumably the function returns
// params_64 or params_32 accordingly -- confirm.
// NOTE(review): uint64(1 << 31) is not 2^31. `1 << 31` is evaluated in
// 32-bit signed int (INT_MIN on two's-complement targets) and sign-extends
// to 0xFFFFFFFF80000000 when widened, so the test is false for any realistic
// length. Presumably uint64(1) << 31 was intended.
100 template <
typename InputIterator,
typename Function>
103 int2 params_64 =
launch_parameters(for_each_kernel<InputIterator, Function, uint64>, length);
104 int2 params_32 =
launch_parameters(for_each_kernel<InputIterator, Function, uint32>, length);
106 if (
uint64(length) + params_32.x * params_32.y >=
uint64(1 << 31))
__global__ void for_each_kernel(InputIterator input, size_t length, Function func)
int2 for_each_launch_parameters(InputIterator input, size_t length, Function func)
int2 launch_parameters(T kernel, size_t elements, int dynamic_smem_size=0)
void for_each(InputIterator input, size_t length, Function func, int2 launch_params={0, 0})