Skip to content

Commit

Permalink
Support custom datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
franzpoeschel committed May 11, 2023
1 parent 191d05b commit 6c87958
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 4 deletions.
29 changes: 26 additions & 3 deletions include/openPMD/CustomHierarchy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#pragma once

#include "openPMD/IO/AbstractIOHandler.hpp"
#include "openPMD/RecordComponent.hpp"
#include "openPMD/backend/Container.hpp"

#include <iostream>
Expand All @@ -38,7 +39,13 @@ namespace internal
std::set<std::string> paths;
[[nodiscard]] bool ignore(std::string const &name) const;
};
using CustomHierarchyData = ContainerData<CustomHierarchy>;

struct CustomHierarchyData : ContainerData<CustomHierarchy>
{
explicit CustomHierarchyData();

Container<RecordComponent> m_embeddedDatasets;
};
} // namespace internal

class CustomHierarchy : public Container<CustomHierarchy>
Expand All @@ -48,15 +55,29 @@ class CustomHierarchy : public Container<CustomHierarchy>

private:
using Container_t = Container<CustomHierarchy>;
using Data_t = typename Container_t::ContainerData;
static_assert(std::is_same_v<Data_t, internal::CustomHierarchyData>);
using Data_t = internal::CustomHierarchyData;
static_assert(std::is_base_of_v<Container_t::ContainerData, Data_t>);

std::shared_ptr<Data_t> m_customHierarchyData;

void init();

[[nodiscard]] Data_t &get()
{
return *m_customHierarchyData;
}
[[nodiscard]] Data_t const &get() const
{
return *m_customHierarchyData;
}

protected:
CustomHierarchy();
CustomHierarchy(NoInit);

inline void setData(std::shared_ptr<Data_t> data)
{
m_customHierarchyData = data;
Container_t::setData(std::move(data));
}

Expand All @@ -70,5 +91,7 @@ class CustomHierarchy : public Container<CustomHierarchy>

CustomHierarchy &operator=(CustomHierarchy const &) = default;
CustomHierarchy &operator=(CustomHierarchy &&) = default;

Container<RecordComponent> datasets();
};
} // namespace openPMD
1 change: 1 addition & 0 deletions include/openPMD/RecordComponent.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class RecordComponent : public BaseRecordComponent
friend class DynamicMemoryView;
friend class internal::RecordComponentData;
friend class MeshRecordComponent;
friend class CustomHierarchy;

public:
enum class Allocation
Expand Down
3 changes: 3 additions & 0 deletions include/openPMD/backend/Container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ namespace internal
class SeriesData;
template <typename>
class EraseStaleEntries;
struct CustomHierarchyData;

template <
typename T,
Expand Down Expand Up @@ -109,6 +110,8 @@ class Container : virtual public Attributable
template <typename>
friend class internal::EraseStaleEntries;
friend class SeriesIterator;
friend struct internal::CustomHierarchyData;
friend class CustomHierarchy;

protected:
using ContainerData = internal::ContainerData<T, T_key, T_container>;
Expand Down
9 changes: 9 additions & 0 deletions include/openPMD/backend/EmbeddedDataset.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include "openPMD/RecordComponent.hpp"

namespace openPMD
{
class EmbeddedDataset : public RecordComponent
{};
} // namespace openPMD
69 changes: 68 additions & 1 deletion src/CustomHierarchy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@
#include "openPMD/IO/AbstractIOHandler.hpp"
#include "openPMD/IO/Access.hpp"
#include "openPMD/IO/IOTask.hpp"
#include "openPMD/RecordComponent.hpp"
#include "openPMD/backend/Attributable.hpp"

#include <deque>
#include <memory>

namespace openPMD
{
namespace internal
Expand All @@ -33,9 +37,25 @@ namespace internal
{
return paths.find(name) != paths.end();
}

CustomHierarchyData::CustomHierarchyData()
{
/*
* m_embeddeddatasets should point to the same instance of Attributable
* Can only use a non-owning pointer in here in order to avoid shared
* pointer cycles.
* When handing this object out to users, we create a copy that has a
* proper owning pointer (see CustomHierarchy::datasets()).
*/
m_embeddedDatasets.Attributable::setData(
std::shared_ptr<AttributableData>(this, [](auto const *) {}));
}
} // namespace internal

CustomHierarchy::CustomHierarchy() = default;
CustomHierarchy::CustomHierarchy()
{
setData(std::make_shared<Data_t>());
}
CustomHierarchy::CustomHierarchy(NoInit) : Container_t(NoInit())
{}

Expand All @@ -49,7 +69,10 @@ void CustomHierarchy::read(internal::MeshesParticlesPath const &mpp)
Attributable::readAttributes(ReadMode::FullyReread);
Parameter<Operation::LIST_PATHS> pList;
IOHandler()->enqueue(IOTask(this, pList));
Parameter<Operation::LIST_DATASETS> dList;
IOHandler()->enqueue(IOTask(this, dList));
IOHandler()->flush(internal::defaultFlushParams);
std::deque<std::string> constantComponentsPushback;
for (auto const &path : *pList.paths)
{
if (mpp.ignore(path))
Expand All @@ -61,6 +84,39 @@ void CustomHierarchy::read(internal::MeshesParticlesPath const &mpp)
auto subpath = this->operator[](path);
IOHandler()->enqueue(IOTask(&subpath, pOpen));
subpath.read(mpp);
if (subpath.size() == 0 && subpath.containsAttribute("shape") &&
subpath.containsAttribute("value"))
{
// This is not a group, but a constant record component
// Writable::~Writable() will deal with removing this from the
// backend again.
std::cout << "IS CONSTANT COMPONENT: " << path << std::endl;
constantComponentsPushback.push_back(path);
container().erase(path);
}
}
auto &data = get();
for (auto const &path : *dList.datasets)
{
auto &rc = data.m_embeddedDatasets[path];
Parameter<Operation::OPEN_DATASET> dOpen;
dOpen.name = path;
IOHandler()->enqueue(IOTask(&rc, dOpen));
IOHandler()->flush(internal::defaultFlushParams);
rc.written() = false;
rc.resetDataset(Dataset(*dOpen.dtype, *dOpen.extent));
rc.written() = true;
rc.read();
}

for (auto const &path : constantComponentsPushback)
{
auto &rc = data.m_embeddedDatasets[path];
Parameter<Operation::OPEN_PATH> pOpen;
pOpen.path = path;
IOHandler()->enqueue(IOTask(&rc, pOpen));
rc.get().m_isConstant = true;
rc.read();
}
}

Expand All @@ -82,6 +138,17 @@ void CustomHierarchy::flush(
}
subpath.flush(name, flushParams);
}
for (auto &[name, dataset] : get().m_embeddedDatasets)
{
dataset.flush(name, flushParams);
}
flushAttributes(flushParams);
}

Container<RecordComponent> CustomHierarchy::datasets()
{
Container<RecordComponent> res = get().m_embeddedDatasets;
res.Attributable::setData(m_customHierarchyData);
return res;
}
} // namespace openPMD
52 changes: 52 additions & 0 deletions test/CoreTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ TEST_CASE("custom_hierarchies", "[core]")
REQUIRE(read.iterations[0].size() == 2);
REQUIRE(read.iterations[0].count("custom") == 1);
REQUIRE(read.iterations[0].count("no_attributes") == 1);
REQUIRE(read.iterations[0]["custom"].size() == 1);
REQUIRE(read.iterations[0]["custom"].count("hierarchy") == 1);
REQUIRE(read.iterations[0]["custom"]["hierarchy"].size() == 0);
REQUIRE(read.iterations[0]["no_attributes"].size() == 0);
REQUIRE(
read.iterations[0]["custom"]
.getAttribute("string")
Expand All @@ -190,6 +194,54 @@ TEST_CASE("custom_hierarchies", "[core]")
.getAttribute("number")
.get<int>() == 3);
read.close();

write = Series(filePath, Access::READ_WRITE);
{
write.iterations[0]["custom"]["hierarchy"];
write.iterations[0]["custom"].datasets()["emptyDataset"].makeEmpty(
Datatype::FLOAT, 3);
write.iterations[0]["custom"]["hierarchy"].setAttribute("number", 3);
write.iterations[0]["no_attributes"];
auto iteration_level_ds =
write.iterations[0].datasets()["iteration_level_dataset"];
iteration_level_ds.resetDataset({Datatype::INT, {10}});
std::vector<int> data(10, 5);
iteration_level_ds.storeChunk(data);
write.close();
}

read = Series(filePath, Access::READ_ONLY);
{
REQUIRE(read.iterations[0].size() == 2);
REQUIRE(read.iterations[0].count("custom") == 1);
REQUIRE(read.iterations[0].count("no_attributes") == 1);
REQUIRE(read.iterations[0]["custom"].size() == 1);
REQUIRE(read.iterations[0]["custom"].count("hierarchy") == 1);
REQUIRE(read.iterations[0]["custom"]["hierarchy"].size() == 0);
REQUIRE(read.iterations[0]["no_attributes"].size() == 0);

REQUIRE(read.iterations[0].datasets().size() == 1);
REQUIRE(read.iterations[0]["custom"].datasets().size() == 1);
REQUIRE(
read.iterations[0]["custom"]["hierarchy"].datasets().size() == 0);
REQUIRE(read.iterations[0]["no_attributes"].datasets().size() == 0);

auto iteration_level_ds =
read.iterations[0].datasets()["iteration_level_dataset"];
REQUIRE(iteration_level_ds.getDatatype() == Datatype::INT);
REQUIRE(iteration_level_ds.getExtent() == Extent{10});
auto loaded_chunk = iteration_level_ds.loadChunk<int>();
iteration_level_ds.seriesFlush();
for (size_t i = 0; i < 10; ++i)
{
REQUIRE(loaded_chunk.get()[i] == 5);
}

auto constant_dataset =
read.iterations[0]["custom"].datasets()["emptyDataset"];
REQUIRE(constant_dataset.getDatatype() == Datatype::FLOAT);
REQUIRE(constant_dataset.getExtent() == Extent{0, 0, 0});
}
}

TEST_CASE("myPath", "[core]")
Expand Down

0 comments on commit 6c87958

Please sign in to comment.