Skip to content

Commit 4300120

Browse files
committed
Implement Sparse view (using std::map for now)
1 parent e4103a3 commit 4300120

File tree

3 files changed

+184
-0
lines changed

3 files changed

+184
-0
lines changed

Diff for: src/sparse_list.hpp

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#ifndef CPPVIEWS_SRC_SPARSE_LIST_HPP_
2+
#define CPPVIEWS_SRC_SPARSE_LIST_HPP_
3+
4+
#include "list.hpp"
5+
#include "util/intseq.hpp"
6+
#include "util/iterator.hpp"
7+
8+
#include <array>
9+
#include <cstdint>
10+
#include <map>
11+
12+
namespace v {
13+
14+
// For the MakeList overload, we need to somehow encode the block sizes.
15+
template<unsigned dim_count>
16+
struct SparseListTag {};
17+
18+
template<typename T, unsigned dims>
19+
using SparseHashList = List<void, dims, kListOpMutable, void, T>;
20+
21+
template<class MapType>
22+
class SparseListForwardIter
23+
: public DefaultIterator<SparseListForwardIter<MapType>,
24+
std::forward_iterator_tag,
25+
typename MapType::mapped_type>,
26+
public View<typename MapType::mapped_type>::IteratorBase {
27+
28+
V_DEFAULT_ITERATOR_DERIVED_HEAD(SparseListForwardIter);
29+
template<typename> friend class SparseListForwardIter;
30+
31+
using Enabler = typename SparseListForwardIter::Enabler;
32+
using DataType = typename MapType::mapped_type;
33+
34+
public:
35+
template<typename DataType2>
36+
SparseListForwardIter(const SparseListForwardIter<DataType2>& other,
37+
EnableIfIterConvertible<DataType2, DataType,
38+
Enabler> = Enabler()) {}
39+
40+
SparseListForwardIter(typename MapType::iterator it)
41+
: it_(std::move(it)) {}
42+
43+
protected:
44+
V_DEF_VIEW_ITER_IS_EQUAL(DataType, SparseListForwardIter);
45+
46+
template<typename DataType2>
47+
bool IsEqual(const SparseListForwardIter<DataType2>& other) const {
48+
return true; // conforms to C++ spec. (this->it_ == other.it_ is no better)
49+
}
50+
51+
void Increment() override { ++it_; }
52+
53+
DataType& ref() const override { return it_->second; }
54+
55+
private:
56+
typename MapType::iterator it_;
57+
};
58+
59+
template<typename T, unsigned dims>
60+
class List<void, dims, kListOpMutable, void, T>
61+
: public ListBase<T, dims> {
62+
63+
using ListBaseType = ListBase<T, dims>;
64+
using typename ListBaseType::SizeArray;
65+
using ListBaseType::kDims;
66+
67+
// template<typename E>
68+
// struct Get2ndTransformer {
69+
// auto operator=(const E& e) -> decltype(e.second) { return e.second; }
70+
// };
71+
72+
using MapType = std::map<std::array<size_t, dims>, T>;
73+
74+
public:
75+
using ForwardIterator = SparseListForwardIter<MapType>;
76+
77+
// TODO use SFINAE to enable only if dims == sizeof...(Sizes)
78+
template<typename... Sizes>
79+
explicit List(SparseListTag<sizeof...(Sizes)>,
80+
T* default_value,
81+
Sizes&&... sizes)
82+
: ListBaseType(std::forward<Sizes>(sizes)...) {
83+
}
84+
85+
friend List MakeList(List&& list) { return std::forward<List>(list); }
86+
87+
template<typename... Indexes>
88+
const T& operator()(Indexes&&... indexes) const {
89+
// TODO this does not type-check if used as mutable
90+
return static_cast<const decltype(this)>(this)->get(
91+
std::forward<Indexes>(indexes)...);
92+
}
93+
94+
template<typename... Indexes>
95+
const T& get(Indexes&&... indexes) const {
96+
try {
97+
return map_.at(SizeArray{{std::forward<Indexes>(indexes)...}});
98+
} catch (const std::out_of_range&) {
99+
return *default_val_;
100+
}
101+
}
102+
103+
T& get(SizeArray&& indexes) const override { return map_[indexes]; }
104+
105+
const List& values() const override { return *this; }
106+
107+
size_t nondefault_count() const { return map_.size(); }
108+
109+
// non-polymorphic iterators (no dynamic allocation)
110+
111+
ForwardIterator begin() const { return ForwardIterator(map_.begin()); }
112+
ForwardIterator end() const { return ForwardIterator(map_.end()); }
113+
114+
protected:
115+
116+
// polymorphic iterators
117+
118+
typename View<T>::Iterator iterator_begin() const override {
119+
return new ForwardIterator(map_.begin());
120+
}
121+
122+
// overrides View<DataType>::Iterator to provide O(1) time complexity
123+
typename View<T>::Iterator iterator_end() const override {
124+
return new ForwardIterator(map_.end());
125+
}
126+
127+
private:
128+
mutable MapType map_;
129+
T* default_val_;
130+
};
131+
132+
template<typename T, unsigned dims, typename... Sizes>
133+
auto MakeList(SparseListTag<dims>, T* default_value, Sizes... sizes)
134+
#define V_LIST_TYPE SparseHashList<T, sizeof...(sizes)>
135+
-> V_LIST_TYPE {
136+
return V_LIST_TYPE(SparseListTag<dims>(), default_value, sizes...); }
137+
#undef V_LIST_TYPE
138+
139+
} // namespace v
140+
141+
#endif /* CPPVIEWS_SRC_SPARSE_LIST_HPP_ */

Diff for: test/Makefile.am

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ unittest_all_SOURCES = test.cpp \
1616
diag_test.cpp \
1717
diag_chain_test.cpp \
1818
list_test.cpp \
19+
sparse_list_test.cpp \
1920
bench/util/sparse_matrix_test.cpp
2021

2122
speedtest_all_SOURCES = test.cpp \

Diff for: test/sparse_list_test.cpp

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#include "../src/sparse_list.hpp"
2+
#include "test.hpp"
3+
4+
// typedef ::testing::Types<
5+
// std::integral_constant<unsigned, 3>,
6+
// std::integral_constant<size_t, 4>,
7+
// std::integral_constant<unsigned, 1> > Rhs;
8+
// TYPED_TEST_CASE(SparseListTest, Rhs);
9+
10+
11+
TEST(SparseListTest, 1D) {
12+
// auto l = MakeList(SparseListTag<1>(), 5);
13+
int zero;
14+
SparseHashList<int, 1> sl(SparseListTag<1>(), &zero, 5);
15+
16+
EXPECT_EQ((std::array<size_t, 1>{{5}}), sl.sizes());
17+
18+
sl.get({3}) = 30;
19+
EXPECT_EQ(30, sl.get(3));
20+
EXPECT_EQ(1, sl.nondefault_count());
21+
22+
sl.get(2);
23+
sl(1);
24+
EXPECT_EQ(1, sl.nondefault_count());
25+
26+
auto it = sl.begin();
27+
EXPECT_EQ(sl.begin(), it);
28+
EXPECT_EQ(30, *it);
29+
EXPECT_EQ(sl.end(), ++it);
30+
31+
// TODO complete
32+
}
33+
34+
TEST(SparseListTest, 2D) {
35+
double zero;
36+
auto sm = MakeList(SparseListTag<2>(), &zero, 2, 3);
37+
38+
sm.get({2, 3}) = 4.5;
39+
EXPECT_EQ(4.5, sm(2, 3));
40+
41+
// TODO complete
42+
}

0 commit comments

Comments
 (0)