Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3270,15 +3270,15 @@ std::shared_ptr<DataType> run_end_encoded(std::shared_ptr<DataType> run_end_type
std::shared_ptr<DataType> sparse_union(FieldVector child_fields,
std::vector<int8_t> type_codes) {
if (type_codes.empty()) {
type_codes = internal::Iota(static_cast<int8_t>(child_fields.size()));
type_codes = internal::Iota<int8_t>(0, child_fields.size());
}
return std::make_shared<SparseUnionType>(std::move(child_fields),
std::move(type_codes));
}
std::shared_ptr<DataType> dense_union(FieldVector child_fields,
std::vector<int8_t> type_codes) {
if (type_codes.empty()) {
type_codes = internal::Iota(static_cast<int8_t>(child_fields.size()));
type_codes = internal::Iota<int8_t>(0, child_fields.size());
}
return std::make_shared<DenseUnionType>(std::move(child_fields), std::move(type_codes));
}
Expand Down Expand Up @@ -3310,7 +3310,7 @@ std::shared_ptr<DataType> sparse_union(const ArrayVector& children,
std::vector<std::string> field_names,
std::vector<int8_t> type_codes) {
if (type_codes.empty()) {
type_codes = internal::Iota(static_cast<int8_t>(children.size()));
type_codes = internal::Iota<int8_t>(0, children.size());
}
auto fields = FieldsFromArraysAndNames(std::move(field_names), children);
return sparse_union(std::move(fields), std::move(type_codes));
Expand All @@ -3320,7 +3320,7 @@ std::shared_ptr<DataType> dense_union(const ArrayVector& children,
std::vector<std::string> field_names,
std::vector<int8_t> type_codes) {
if (type_codes.empty()) {
type_codes = internal::Iota(static_cast<int8_t>(children.size()));
type_codes = internal::Iota<int8_t>(0, children.size());
}
auto fields = FieldsFromArraysAndNames(std::move(field_names), children);
return dense_union(std::move(fields), std::move(type_codes));
Expand Down
30 changes: 30 additions & 0 deletions cpp/src/arrow/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2191,6 +2191,36 @@ TEST(TestUnionType, Basics) {
ASSERT_EQ(ty6->child_ids(), child_ids2);
}

TEST(TestUnionType, MaxTypeCode) {
std::vector<std::shared_ptr<Field>> fields;
for (int32_t i = 0; i <= UnionType::kMaxTypeCode; i++) {
fields.push_back(field(std::to_string(i), int32()));
}

std::vector<int8_t> type_codes(fields.size());
std::iota(type_codes.begin(), type_codes.end(), 0);

auto t1 = checked_pointer_cast<UnionType>(dense_union(fields, type_codes));
ASSERT_EQ(t1->type_codes().size(), UnionType::kMaxTypeCode + 1);
ASSERT_EQ(t1->child_ids().size(), UnionType::kMaxTypeCode + 1);

auto t2 = checked_pointer_cast<UnionType>(dense_union(fields));
ASSERT_EQ(t2->type_codes().size(), UnionType::kMaxTypeCode + 1);
ASSERT_EQ(t2->child_ids().size(), UnionType::kMaxTypeCode + 1);

AssertTypeEqual(*t1, *t2);

auto t3 = checked_pointer_cast<UnionType>(sparse_union(fields, type_codes));
ASSERT_EQ(t3->type_codes().size(), UnionType::kMaxTypeCode + 1);
ASSERT_EQ(t3->child_ids().size(), UnionType::kMaxTypeCode + 1);

auto t4 = checked_pointer_cast<UnionType>(sparse_union(fields));
ASSERT_EQ(t4->type_codes().size(), UnionType::kMaxTypeCode + 1);
ASSERT_EQ(t4->child_ids().size(), UnionType::kMaxTypeCode + 1);

AssertTypeEqual(*t3, *t4);
}

TEST(TestDictionaryType, Basics) {
auto value_type = int32();

Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/util/range.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ std::vector<T> Iota(T length) {
return Iota(static_cast<T>(0), length);
}

/// Create a vector containing the values from start with length elements
template <typename T>
std::vector<T> Iota(T start, size_t length) {
std::vector<T> result(length);
std::iota(result.begin(), result.end(), start);
return result;
}

/// Create a range from a callable which takes a single index parameter
/// and returns the value of iterator on each call and a length.
/// Only iterators obtained from the same range should be compared, the
Expand Down
Loading