@@ -953,10 +953,9 @@ struct CorrVisitor {
953
953
954
954
const auto &idx = *idx_begin;
955
955
const size_type col_s =
956
- std::min (std::distance (idx_begin, idx_end),
957
- std::distance (column_begin1, column_end1));
958
-
959
- assert ((col_s == size_type (std::distance (column_begin2, column_end2))));
956
+ std::min ({ std::distance (idx_begin, idx_end),
957
+ std::distance (column_begin1, column_end1),
958
+ std::distance (column_begin2, column_end2) });
960
959
961
960
if (type_ == correlation_type::pearson) {
962
961
while (column_begin1 < column_end1 && column_begin2 < column_end2)
@@ -2532,29 +2531,44 @@ struct KthValueVisitor {
2532
2531
2533
2532
DEFINE_VISIT_BASIC_TYPES_2
2534
2533
2535
- template <typename U>
2536
- using vec_type = std::vector<U, typename allocator_declare<U, A>::type>;
2534
+ using vec_type = std::vector<T, typename allocator_declare<T, A>::type>;
2537
2535
2538
2536
template <forward_iterator K, forward_iterator H>
2539
2537
inline void
2540
2538
operator () (const K &, const K &,
2541
2539
const H &values_begin, const H &values_end) {
2542
2540
2543
- vec_type<value_type> aux (values_begin, values_end);
2541
+ vec_type aux;
2542
+ const size_type col_s = std::distance (values_begin, values_end);
2543
+
2544
+ if (skip_nan_) {
2545
+ aux.reserve (col_s);
2546
+ std::copy_if (values_begin, values_end,
2547
+ std::back_inserter (aux),
2548
+ [](T x) -> bool { return (! is_nan__ (x)); });
2549
+ }
2550
+ else
2551
+ aux.insert (aux.begin (), values_begin, values_end);
2552
+ compute_size_ = aux.size ();
2553
+
2554
+ const size_type kth =
2555
+ std::round (double (kth_element_ * compute_size_) / double (col_s));
2544
2556
2545
- result_ = find_kth_element_ (aux, 0 , aux. size () - 1 , kth_element_ );
2557
+ result_ = find_kth_element_ (aux, 0 , compute_size_ - 1 , kth );
2546
2558
}
2547
2559
2548
2560
inline void pre () { result_ = value_type (); }
2549
2561
inline void post () { }
2550
2562
inline result_type get_result () const { return (result_); }
2563
+ inline size_type get_compute_size () const { return (compute_size_); }
2551
2564
2552
2565
explicit KthValueVisitor (size_type ke, bool skipnan = true )
2553
2566
: kth_element_(ke), skip_nan_(skipnan) { }
2554
2567
2555
2568
private:
2556
2569
2557
2570
result_type result_ { };
2571
+ size_type compute_size_ { 0 };
2558
2572
const size_type kth_element_;
2559
2573
const bool skip_nan_;
2560
2574
@@ -2612,17 +2626,21 @@ struct MedianVisitor {
2612
2626
operator () (const K &idx_begin, const K &idx_end,
2613
2627
const H &column_begin, const H &column_end) {
2614
2628
2615
- GET_COL_SIZE2
2616
-
2617
- KthValueVisitor<value_type, index_type, A> kv_visitor ( col_s >> 1 ) ;
2618
-
2629
+ const std:: size_t col_s =
2630
+ std::distance (column_begin, column_end);
2631
+ const std:: size_t half = col_s >> 1 ;
2632
+ KthValueVisitor<value_type, index_type, A> kv_visitor (half + 1 );
2619
2633
2620
2634
kv_visitor.pre ();
2621
2635
kv_visitor (idx_begin, idx_end, column_begin, column_end);
2622
2636
kv_visitor.post ();
2623
2637
result_ = kv_visitor.get_result ();
2624
- if (! (col_s & 0x01 )) { // Even
2625
- KthValueVisitor<value_type, I, A> kv_visitor2 ((col_s >> 1 ) + 1 );
2638
+
2639
+ const size_type cs = kv_visitor.get_compute_size ();
2640
+
2641
+ if (! (cs & 0x01 )) { // Even
2642
+ KthValueVisitor<value_type, I, A> kv_visitor2 (
2643
+ cs < col_s ? half + 2 : half);
2626
2644
2627
2645
kv_visitor2.pre ();
2628
2646
kv_visitor2 (idx_begin, idx_end, column_begin, column_end);
0 commit comments