Skip to content

Commit

Permalink
plan: use vector instead of array for length/stride
Browse files Browse the repository at this point in the history
* plan: use vector instead of array for length/stride

This way, we don't also need to pass plan rank along with length/stride.
  • Loading branch information
evetsso authored Aug 31, 2023
1 parent 87e628d commit da2d66b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 41 deletions.
17 changes: 8 additions & 9 deletions library/src/include/plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ struct rocfft_plan_description_t
rocfft_array_type inArrayType = rocfft_array_type_unset;
rocfft_array_type outArrayType = rocfft_array_type_unset;

std::array<size_t, 3> inStrides = {0, 0, 0};
std::array<size_t, 3> outStrides = {0, 0, 0};
std::vector<size_t> inStrides;
std::vector<size_t> outStrides;

size_t inDist = 0;
size_t outDist = 0;
Expand All @@ -72,17 +72,16 @@ struct rocfft_plan_description_t
// type of transform it will be for. Once that's known, we can
// initialize default values for in/out type, stride, dist if they're
// unspecified.
void init_defaults(rocfft_transform_type transformType,
rocfft_result_placement placement,
size_t rank,
const std::array<size_t, 3>& lengths);
void init_defaults(rocfft_transform_type transformType,
rocfft_result_placement placement,
const std::vector<size_t>& lengths);
};

struct rocfft_plan_t
{
size_t rank = 1;
std::array<size_t, 3> lengths = {1, 1, 1};
size_t batch = 1;
size_t rank = 1;
std::vector<size_t> lengths;
size_t batch = 1;

rocfft_result_placement placement = rocfft_placement_inplace;
rocfft_transform_type transformType = rocfft_transform_type_complex_forward;
Expand Down
57 changes: 25 additions & 32 deletions library/src/plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ static size_t offset_count(rocfft_array_type type)
: 1;
}

void rocfft_plan_description_t::init_defaults(rocfft_transform_type transformType,
rocfft_result_placement placement,
size_t rank,
const std::array<size_t, 3>& lengths)
void rocfft_plan_description_t::init_defaults(rocfft_transform_type transformType,
rocfft_result_placement placement,
const std::vector<size_t>& lengths)
{
const size_t rank = lengths.size();

// assume interleaved data
if(inArrayType == rocfft_array_type_unset)
{
Expand Down Expand Up @@ -117,9 +118,9 @@ void rocfft_plan_description_t::init_defaults(rocfft_transform_type trans
}

// Set inStrides, if not specified
if(inStrides[0] == 0)
if(inStrides.empty())
{
inStrides[0] = 1;
inStrides.push_back(1);

if((transformType == rocfft_transform_type_real_forward)
&& (placement == rocfft_placement_inplace))
Expand All @@ -129,7 +130,7 @@ void rocfft_plan_description_t::init_defaults(rocfft_transform_type trans

for(size_t i = 1; i < rank; i++)
{
inStrides[i] = dist;
inStrides.push_back(dist);
dist *= lengths[i];
}

Expand All @@ -143,7 +144,7 @@ void rocfft_plan_description_t::init_defaults(rocfft_transform_type trans

for(size_t i = 1; i < rank; i++)
{
inStrides[i] = dist;
inStrides.push_back(dist);
dist *= lengths[i];
}

Expand All @@ -155,14 +156,14 @@ void rocfft_plan_description_t::init_defaults(rocfft_transform_type trans
{
// Set the inStrides to deal with contiguous data
for(size_t i = 1; i < rank; i++)
inStrides[i] = lengths[i - 1] * inStrides[i - 1];
inStrides.push_back(lengths[i - 1] * inStrides[i - 1]);
}
}

// Set outStrides, if not specified
if(outStrides[0] == 0)
if(outStrides.empty())
{
outStrides[0] = 1;
outStrides.push_back(1);

if((transformType == rocfft_transform_type_real_inverse)
&& (placement == rocfft_placement_inplace))
Expand All @@ -172,7 +173,7 @@ void rocfft_plan_description_t::init_defaults(rocfft_transform_type trans

for(size_t i = 1; i < rank; i++)
{
outStrides[i] = dist;
outStrides.push_back(dist);
dist *= lengths[i];
}

Expand All @@ -186,7 +187,7 @@ void rocfft_plan_description_t::init_defaults(rocfft_transform_type trans

for(size_t i = 1; i < rank; i++)
{
outStrides[i] = dist;
outStrides.push_back(dist);
dist *= lengths[i];
}

Expand All @@ -197,7 +198,7 @@ void rocfft_plan_description_t::init_defaults(rocfft_transform_type trans
{
// Set the outStrides to deal with contiguous data
for(size_t i = 1; i < rank; i++)
outStrides[i] = lengths[i - 1] * outStrides[i - 1];
outStrides.push_back(lengths[i - 1] * outStrides[i - 1]);
}
}

Expand Down Expand Up @@ -325,17 +326,18 @@ rocfft_status rocfft_plan_description_set_data_layout(rocfft_plan_description de

if(in_strides != nullptr)
{
for(size_t i = 0; i < std::min((size_t)3, in_strides_size); i++)
description->inStrides[i] = in_strides[i];
std::copy(
in_strides, in_strides + in_strides_size, std::back_inserter(description->inStrides));
}

if(in_distance != 0)
description->inDist = in_distance;

if(out_strides != nullptr)
{
for(size_t i = 0; i < std::min((size_t)3, out_strides_size); i++)
description->outStrides[i] = out_strides[i];
std::copy(out_strides,
out_strides + out_strides_size,
std::back_inserter(description->outStrides));
}

if(out_distance != 0)
Expand Down Expand Up @@ -365,7 +367,7 @@ std::string rocfft_bench_command(rocfft_plan plan)
std::stringstream bench;
bench << "rocfft-bench --length ";
std::ostream_iterator<size_t> bench_iter(bench, " ");
std::copy(plan->lengths.rbegin() + (3 - plan->rank), plan->lengths.rend(), bench_iter);
std::copy(plan->lengths.rbegin(), plan->lengths.rend(), bench_iter);
bench << "-b " << plan->batch << " ";

if(plan->placement == rocfft_placement_notinplace)
Expand All @@ -378,12 +380,9 @@ std::string rocfft_bench_command(rocfft_plan plan)
bench << "--itype " << plan->desc.inArrayType << " ";
bench << "--otype " << plan->desc.outArrayType << " ";
bench << "--istride ";
std::copy(
plan->desc.inStrides.rbegin() + (3 - plan->rank), plan->desc.inStrides.rend(), bench_iter);
std::copy(plan->desc.inStrides.rbegin(), plan->desc.inStrides.rend(), bench_iter);
bench << "--ostride ";
std::copy(plan->desc.outStrides.rbegin() + (3 - plan->rank),
plan->desc.outStrides.rend(),
bench_iter);
std::copy(plan->desc.outStrides.rbegin(), plan->desc.outStrides.rend(), bench_iter);
bench << "--idist " << plan->desc.inDist << " ";
bench << "--odist " << plan->desc.outDist << " ";
bench << "--ioffset ";
Expand Down Expand Up @@ -521,13 +520,7 @@ rocfft_status rocfft_plan_create_internal(rocfft_plan plan,

rocfft_plan p = plan;
p->rank = dimensions;
p->lengths[0] = 1;
p->lengths[1] = 1;
p->lengths[2] = 1;
for(size_t ilength = 0; ilength < dimensions; ++ilength)
{
p->lengths[ilength] = lengths[ilength];
}
std::copy(lengths, lengths + dimensions, std::back_inserter(p->lengths));
p->batch = number_of_transforms;
p->placement = placement;
p->precision = precision;
Expand All @@ -538,7 +531,7 @@ rocfft_status rocfft_plan_create_internal(rocfft_plan plan,
{
p->desc = *description;
}
p->desc.init_defaults(p->transformType, p->placement, p->rank, p->lengths);
p->desc.init_defaults(p->transformType, p->placement, p->lengths);

// Check plan validity
switch(transform_type)
Expand Down

0 comments on commit da2d66b

Please sign in to comment.