Skip to content

Commit 269cdef

Browse files
committed
Fixed a bug in median and kth_element visitors related to handling NaNs
1 parent cd449f3 commit 269cdef

File tree

3 files changed

+60
-22
lines changed

3 files changed

+60
-22
lines changed

include/DataFrame/DataFrameStatsVisitors.h

+32-14
Original file line numberDiff line numberDiff line change
@@ -953,10 +953,9 @@ struct CorrVisitor {
953953

954954
const auto &idx = *idx_begin;
955955
const size_type col_s =
956-
std::min(std::distance(idx_begin, idx_end),
957-
std::distance(column_begin1, column_end1));
958-
959-
assert((col_s == size_type(std::distance(column_begin2, column_end2))));
956+
std::min ({ std::distance(idx_begin, idx_end),
957+
std::distance(column_begin1, column_end1),
958+
std::distance(column_begin2, column_end2) });
960959

961960
if (type_ == correlation_type::pearson) {
962961
while (column_begin1 < column_end1 && column_begin2 < column_end2)
@@ -2532,29 +2531,44 @@ struct KthValueVisitor {
25322531

25332532
DEFINE_VISIT_BASIC_TYPES_2
25342533

2535-
template<typename U>
2536-
using vec_type = std::vector<U, typename allocator_declare<U, A>::type>;
2534+
using vec_type = std::vector<T, typename allocator_declare<T, A>::type>;
25372535

25382536
template <forward_iterator K, forward_iterator H>
25392537
inline void
25402538
operator() (const K &, const K &,
25412539
const H &values_begin, const H &values_end) {
25422540

2543-
vec_type<value_type> aux (values_begin, values_end);
2541+
vec_type aux;
2542+
const size_type col_s = std::distance(values_begin, values_end);
2543+
2544+
if (skip_nan_) {
2545+
aux.reserve(col_s);
2546+
std::copy_if(values_begin, values_end,
2547+
std::back_inserter(aux),
2548+
[](T x) -> bool { return (! is_nan__(x)); });
2549+
}
2550+
else
2551+
aux.insert(aux.begin(), values_begin, values_end);
2552+
compute_size_ = aux.size();
2553+
2554+
const size_type kth =
2555+
std::round(double(kth_element_ * compute_size_) / double(col_s));
25442556

2545-
result_ = find_kth_element_(aux, 0, aux.size() - 1, kth_element_);
2557+
result_ = find_kth_element_(aux, 0, compute_size_ - 1, kth);
25462558
}
25472559

25482560
inline void pre () { result_ = value_type(); }
25492561
inline void post () { }
25502562
inline result_type get_result () const { return (result_); }
2563+
inline size_type get_compute_size() const { return (compute_size_); }
25512564

25522565
explicit KthValueVisitor (size_type ke, bool skipnan = true)
25532566
: kth_element_(ke), skip_nan_(skipnan) { }
25542567

25552568
private:
25562569

25572570
result_type result_ { };
2571+
size_type compute_size_ { 0 };
25582572
const size_type kth_element_;
25592573
const bool skip_nan_;
25602574

@@ -2612,17 +2626,21 @@ struct MedianVisitor {
26122626
operator() (const K &idx_begin, const K &idx_end,
26132627
const H &column_begin, const H &column_end) {
26142628

2615-
GET_COL_SIZE2
2616-
2617-
KthValueVisitor<value_type, index_type, A> kv_visitor (col_s >> 1);
2618-
2629+
const std::size_t col_s =
2630+
std::distance(column_begin, column_end);
2631+
const std::size_t half = col_s >> 1;
2632+
KthValueVisitor<value_type, index_type, A> kv_visitor (half + 1);
26192633

26202634
kv_visitor.pre();
26212635
kv_visitor(idx_begin, idx_end, column_begin, column_end);
26222636
kv_visitor.post();
26232637
result_ = kv_visitor.get_result();
2624-
if (! (col_s & 0x01)) { // Even
2625-
KthValueVisitor<value_type, I, A> kv_visitor2 ((col_s >> 1) + 1);
2638+
2639+
const size_type cs = kv_visitor.get_compute_size();
2640+
2641+
if (! (cs & 0x01)) { // Even
2642+
KthValueVisitor<value_type, I, A> kv_visitor2 (
2643+
cs < col_s ? half + 2 : half);
26262644

26272645
kv_visitor2.pre();
26282646
kv_visitor2(idx_begin, idx_end, column_begin, column_end);

test/dataframe_tester.cc

+22-2
Original file line numberDiff line numberDiff line change
@@ -2131,7 +2131,7 @@ static void test_median() {
21312131
double result =
21322132
df.single_act_visit<double>("dblcol_1", med_visit, true).get_result();
21332133

2134-
assert(result == 10.0);
2134+
assert(result == 11.0);
21352135

21362136
result = df.single_act_visit<double>("dblcol_2", med_visit).get_result();
21372137
assert(result == 10.50);
@@ -2140,10 +2140,30 @@ static void test_median() {
21402140
int result2 =
21412141
df.single_act_visit<int>("intcol_1", med_visit2).get_result();
21422142

2143-
assert(result2 == 10);
2143+
assert(result2 == 11);
21442144

21452145
result2 = df.single_act_visit<int>("intcol_2", med_visit2).get_result();
21462146
assert(result2 == 10);
2147+
2148+
using TestDF = StdDataFrame<std::string>;
2149+
2150+
std::vector<std::string> syms = { "AAPL", "IBM", "TSLA", "MSFT", "CSCO" };
2151+
std::vector<double> c1 = { 1.0, 2.0, 3.0, 4.0 };
2152+
std::vector<double> c2 = { 0.01, 0.02, 0.03 };
2153+
std::vector<double> c3 =
2154+
{ 0.0, std::numeric_limits<double>::quiet_NaN(), 0.1 };
2155+
TestDF testDF ;
2156+
2157+
testDF.load_data(std::move(syms),
2158+
std::make_pair("c1", c1),
2159+
std::make_pair("c2", c2),
2160+
std::make_pair("c3", c3));
2161+
2162+
MedianVisitor<double, std::string> md;
2163+
2164+
assert((testDF.single_act_visit<double>("c1", md).get_result() == 2.5));
2165+
assert((testDF.single_act_visit<double>("c2", md).get_result() == 0.02));
2166+
assert((testDF.single_act_visit<double>("c3", md).get_result() == 0.05));
21472167
}
21482168

21492169
// -----------------------------------------------------------------------------

test/dataframe_tester_2.cc

+6-6
Original file line numberDiff line numberDiff line change
@@ -2388,18 +2388,18 @@ static void test_LowessVisitor() {
23882388
df.single_act_visit<double, double>("dep_var", "indep_var", l_v);
23892389

23902390
auto actual_yfit = StlVecType<double> {
2391-
68.1432, 119.432, 122.75, 135.633, 142.724, 165.905, 169.447, 185.617,
2392-
186.017, 191.865, 198.03, 202.234, 206.178, 215.053, 216.586, 220.408,
2393-
226.671, 229.052, 229.185, 230.023, 231.657,
2391+
67.988, 119.351, 122.673, 135.574, 142.677, 165.901, 169.442, 185.5469,
2392+
185.946, 191.751, 197.912, 202.10997, 206.052, 214.933, 216.473,
2393+
220.319, 226.653, 229.068, 229.203, 230.054, 231.714,
23942394
};
23952395

23962396
for (size_t idx = 0; idx < actual_yfit.size(); ++idx)
23972397
assert(fabs(l_v.get_result()[idx] - actual_yfit[idx]) < 0.001);
23982398

23992399
auto actual_weights = StlVecType<double> {
2400-
0.641773, 0.653544, 0.940738, 0.865302, 0.990575, 0.971522, 0.92929,
2401-
0.902444, 0.918228, 0.924041, 0.855054, 0.824388, 0.586045, 0.945216,
2402-
0.94831, 0.998031, 0.999834, 0.991263, 0.993165, 0.972067, 0.990308,
2400+
0.665908, 0.674181, 0.945216, 0.873828, 0.991117, 0.973495, 0.934069,
2401+
0.909536, 0.923308, 0.928475, 0.863709, 0.837148, 0.612762, 0.948307,
2402+
0.951239, 0.998073, 0.99984, 0.991830, 0.993602, 0.974109, 0.990844,
24032403
};
24042404

24052405
for (size_t idx = 0; idx < actual_weights.size(); ++idx)

0 commit comments

Comments
 (0)