Skip to content

Commit

Permalink
reduce duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Smith committed Jan 23, 2024
1 parent 26f34a4 commit e6ec701
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 74 deletions.
12 changes: 6 additions & 6 deletions six/modules/c++/six.sicd/include/six/sicd/NearestNeighbors.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ struct NearestNeighbors final
uint8_t find_nearest(six::zfloat phase_direction, six::zfloat v) const;
uint8_t getPhase(six::zfloat) const;

#if SIX_sicd_ComplexToAMP8IPHS8I_unseq
#if SIX_sicd_ComplexToAMP8IPHS8I_unseq
void nearest_neighbors_(execution_policy, std::span<const zfloat> inputs, std::span<AMP8I_PHS8I_t> results) const;

template<typename ZFloatV>
auto nearest_neighbors_unseq_T(std::span<const zfloat> p) const; // TODO: std::span<T, N> ... ?
template<typename ZFloatV, int elements_per_iteration>
void nearest_neighbors_unseq_(std::span<const zfloat> inputs, std::span<AMP8I_PHS8I_t> results) const;
auto unseq_nearest_neighbors(std::span<const zfloat> p) const; // TODO: std::span<T, N> ... ?

template<typename ZFloatV, int elements_per_iteration>
void nearest_neighbors_par_unseq_T(std::span<const zfloat> inputs, std::span<AMP8I_PHS8I_t> results) const;
#endif
void nearest_neighbors_T(execution_policy, std::span<const zfloat> inputs, std::span<AMP8I_PHS8I_t> results) const;
#endif

};
}
Expand Down
98 changes: 30 additions & 68 deletions six/modules/c++/six.sicd/source/NearestNeighbors_unseq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ template<typename ZFloatV>
using IntV = decltype(::getPhase(ZFloatV{}, 0.0f));

template<typename ZFloatV>
auto six::sicd::NearestNeighbors::nearest_neighbors_unseq_T(std::span<const zfloat> p) const // TODO: std::span<T, N> ... ? The compiler can sometimes do better optimization with fixed-size structures.
auto six::sicd::NearestNeighbors::unseq_nearest_neighbors(std::span<const zfloat> p) const // TODO: std::span<T, N> ... ? The compiler can sometimes do better optimization with fixed-size structures.
{
ZFloatV v;
assert(p.size() == size(v));
Expand Down Expand Up @@ -931,7 +931,8 @@ static void finish_nearest_neighbors_unseq(const six::sicd::NearestNeighbors& im
}

template<typename ZFloatV, int elements_per_iteration>
void six::sicd::NearestNeighbors::nearest_neighbors_unseq_(std::span<const zfloat> inputs, std::span<AMP8I_PHS8I> results) const
void six::sicd::NearestNeighbors::nearest_neighbors_T(execution_policy policy,
std::span<const zfloat> inputs, std::span<AMP8I_PHS8I> results) const
{
// View the data as chunks of *elements_per_iteration*. This allows iterating
// to go *elements_per_iteration* at a time; and each chunk can be processed
Expand All @@ -949,38 +950,22 @@ void six::sicd::NearestNeighbors::nearest_neighbors_unseq_(std::span<const zfloa

const auto func = [&](const auto& v)
{
return nearest_neighbors_unseq_T<ZFloatV>(v);
return unseq_nearest_neighbors<ZFloatV>(v);
};
std::transform(/*std::execution::unseq,*/ b, e, d, func);

// Then finish off anything left
finish_nearest_neighbors_unseq<elements_per_iteration>(*this, inputs, results);
}

template<typename ZFloatV, int elements_per_iteration>
void six::sicd::NearestNeighbors::nearest_neighbors_par_unseq_T(std::span<const zfloat> inputs, std::span<AMP8I_PHS8I> results) const
{
// View the data as chunks of *elements_per_iteration*. This allows iterating
// to go *elements_per_iteration* at a time; and each chunk can be processed
// using `nearest_neighbors_unseq_T()`, above.
using extents_t = coda_oss::dextents<size_t, 2>; // two dimensions: M×N
const extents_t extents{ inputs.size() / elements_per_iteration, elements_per_iteration };
const coda_oss::mdspan<const zfloat, extents_t> md_inputs(inputs.data(), extents);
assert(md_inputs.size() <= inputs.size());
auto const b = cbegin(md_inputs);
auto const e = cend(md_inputs);

const coda_oss::mdspan<AMP8I_PHS8I, extents_t> md_results(results.data(), extents);
assert(md_results.size() <= results.size());
auto const d = begin<mdspan_iterator_value>(md_results);

const auto func = [&](const auto& v)
if (policy == execution_policy::unseq)
{
return nearest_neighbors_unseq_T<ZFloatV>(v);
};
//std::transform(std::execution::par_unseq, b, e, d, func);
mt::Transform_par(b, e, d, func);

std::transform(/*std::execution::unseq,*/ b, e, d, func);
}
else if (policy == execution_policy::par_unseq)
{
//std::transform(std::execution::par_unseq, b, e, d, func);
mt::Transform_par(b, e, d, func);
}
else
{
throw std::logic_error("Unsupported execution_policy");
}

// Then finish off anything left
finish_nearest_neighbors_unseq<elements_per_iteration>(*this, inputs, results);
Expand Down Expand Up @@ -1019,74 +1004,51 @@ std::string SIX_SICD_API six_sicd_set_nearest_neighbors_unseq(std::string unseq)
return retval;
}

void six::sicd::NearestNeighbors::nearest_neighbors_unseq(std::span<const zfloat> inputs, std::span<AMP8I_PHS8I> results) const
void six::sicd::NearestNeighbors::nearest_neighbors_(execution_policy policy,
std::span<const zfloat> inputs, std::span<AMP8I_PHS8I_t> results) const
{
// TODO: there could be more complicated logic here to determine which UNSEQ
// implementation to use.


// This is very simple as it's only used for unit-testing
const auto& unseq = ::nearest_neighbors_unseq_;
#if SIX_sicd_has_simd
if (unseq == unseq_simd)
{
return nearest_neighbors_unseq_<simd_zfloatv, simd_elements_per_iteration>(inputs, results);
return nearest_neighbors_T<simd_zfloatv, simd_elements_per_iteration>(policy, inputs, results);
}
#endif
#if SIX_sicd_has_VCL
if (unseq == unseq_vcl)
{
return nearest_neighbors_unseq_<vcl_zfloatv, vcl_elements_per_iteration>(inputs, results);
return nearest_neighbors_T<vcl_zfloatv, vcl_elements_per_iteration>(policy, inputs, results);
}
#endif
#if SIX_sicd_has_valarray
if (unseq == unseq_valarray)
{
return nearest_neighbors_unseq_<valarray_zfloatv, valarray_elements_per_iteration>(inputs, results);
return nearest_neighbors_T<valarray_zfloatv, valarray_elements_per_iteration>(policy, inputs, results);
}
#endif
#if SIX_sicd_has_ximd
if (unseq == unseq_ximd)
{
return nearest_neighbors_unseq_<ximd_zfloatv, ximd_elements_per_iteration>(inputs, results);
return nearest_neighbors_T<ximd_zfloatv, ximd_elements_per_iteration>(policy, inputs, results);
}
#endif

throw std::logic_error("Don't know how to implement nearest_neighbors_unseq() for unseq=" + unseq);
throw std::logic_error("Don't know how to implement nearest_neighbors_() for unseq=" + unseq);
}
/**
 * Run nearest-neighbors with the `unseq` execution policy.
 *
 * Thin wrapper: all of the work is done by nearest_neighbors_().
 */
void six::sicd::NearestNeighbors::nearest_neighbors_unseq(std::span<const zfloat> inputs, std::span<AMP8I_PHS8I> results) const
{
    // TODO: there could be more complicated logic here to determine which UNSEQ
    // implementation to use.
    const auto policy = execution_policy::unseq;
    this->nearest_neighbors_(policy, inputs, results);
}

/**
 * Run nearest-neighbors with the `par_unseq` execution policy.
 *
 * Selection of the concrete implementation is left to nearest_neighbors_(),
 * so the dispatch logic is not repeated here.
 *
 * @param inputs  source values, passed through to nearest_neighbors_()
 * @param results destination span, passed through to nearest_neighbors_()
 */
void six::sicd::NearestNeighbors::nearest_neighbors_par_unseq(std::span<const zfloat> inputs, std::span<AMP8I_PHS8I> results) const
{
    // TODO: there could be more complicated logic here to determine which UNSEQ
    // implementation to use.
    nearest_neighbors_(execution_policy::par_unseq, inputs, results);
}
#endif // SIX_sicd_ComplexToAMP8IPHS8I_unseq

0 comments on commit e6ec701

Please sign in to comment.