Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make shape an optional attribute for constant components #1661

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/openPMD/RecordComponent.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ class RecordComponent : public BaseRecordComponent
static constexpr char const *const SCALAR = "\vScalar";

protected:
void flush(std::string const &, internal::FlushParams const &);
void
flush(std::string const &, internal::FlushParams const &, bool is_scalar);
void read(bool require_unit_si);

private:
Expand Down
27 changes: 27 additions & 0 deletions include/openPMD/backend/Attributable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
*/
#pragma once

#include "openPMD/Error.hpp"
#include "openPMD/IO/AbstractIOHandler.hpp"
#include "openPMD/ThrowError.hpp"
#include "openPMD/auxiliary/OutOfRangeMsg.hpp"
Expand All @@ -30,6 +31,7 @@
#include <exception>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>
Expand Down Expand Up @@ -113,6 +115,31 @@ namespace internal
return res;
}

inline auto attributes() -> A_MAP &
{
return m_attributes;
}
[[nodiscard]] inline auto attributes() const -> A_MAP const &
{
return m_attributes;
}
[[nodiscard]] inline auto readAttribute(std::string const &name) const
-> Attribute const &
{
if (auto it = m_attributes.find(name); it != m_attributes.end())
{
return it->second;
}
else
{
throw error::ReadError(
error::AffectedObject::Attribute,
error::Reason::NotFound,
std::nullopt,
"Not found: '" + name + "'.");
}
}

private:
/**
* The attributes defined by this Attributable.
Expand Down
3 changes: 2 additions & 1 deletion include/openPMD/backend/MeshRecordComponent.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class MeshRecordComponent : public RecordComponent
MeshRecordComponent();
MeshRecordComponent(NoInit);
void read();
void flush(std::string const &, internal::FlushParams const &);
void
flush(std::string const &, internal::FlushParams const &, bool is_scalar);

public:
~MeshRecordComponent() override = default;
Expand Down
3 changes: 3 additions & 0 deletions src/Iteration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,9 @@ void Iteration::readMeshes(std::string const &meshesPath)
IOHandler()->enqueue(IOTask(&m, aList));
IOHandler()->flush(internal::defaultFlushParams);

// Find constant scalar meshes. shape generally required for meshes,
// shape also required for scalars.
// https://github.com/openPMD/openPMD-standard/pull/289
auto att_begin = aList.attributes->begin();
auto att_end = aList.attributes->end();
auto value = std::find(att_begin, att_end, "value");
Expand Down
17 changes: 11 additions & 6 deletions src/Mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,14 @@ void Mesh::flush_impl(
auto &m = get();
if (m.m_datasetDefined)
{
T_RecordComponent::flush(SCALAR, flushParams);
T_RecordComponent::flush(
SCALAR, flushParams, /* is_scalar = */ true);
}
else
{
for (auto &comp : *this)
comp.second.flush(comp.first, flushParams);
comp.second.flush(
comp.first, flushParams, /* is_scalar = */ false);
}
}
else
Expand All @@ -237,7 +239,7 @@ void Mesh::flush_impl(
if (scalar())
{
MeshRecordComponent &mrc = *this;
mrc.flush(name, flushParams);
mrc.flush(name, flushParams, /* is_scalar = */ true);
}
else
{
Expand All @@ -247,20 +249,23 @@ void Mesh::flush_impl(
for (auto &comp : *this)
{
comp.second.parent() = &this->writable();
comp.second.flush(comp.first, flushParams);
comp.second.flush(
comp.first, flushParams, /* is_scalar = */ false);
}
}
}
else
{
if (scalar())
{
T_RecordComponent::flush(name, flushParams);
T_RecordComponent::flush(
name, flushParams, /* is_scalar = */ true);
}
else
{
for (auto &comp : *this)
comp.second.flush(comp.first, flushParams);
comp.second.flush(
comp.first, flushParams, /* is_scalar = */ false);
}
}
flushAttributes(flushParams);
Expand Down
5 changes: 3 additions & 2 deletions src/ParticleSpecies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ void ParticleSpecies::read()
auto att_begin = aList.attributes->begin();
auto att_end = aList.attributes->end();
auto value = std::find(att_begin, att_end, "value");
auto shape = std::find(att_begin, att_end, "shape");
if (value != att_end && shape != att_end)
// @todo see this comment:
// https://github.com/openPMD/openPMD-standard/pull/289#issuecomment-2407263974
if (value != att_end)
{
RecordComponent &rc = r;
IOHandler()->enqueue(IOTask(&rc, pOpen));
Expand Down
17 changes: 11 additions & 6 deletions src/Record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ void Record::flush_impl(
{
if (scalar())
{
T_RecordComponent::flush(SCALAR, flushParams);
T_RecordComponent::flush(
SCALAR, flushParams, /* is_scalar = */ true);
}
else
{
for (auto &comp : *this)
comp.second.flush(comp.first, flushParams);
comp.second.flush(
comp.first, flushParams, /* is_scalar = */ false);
}
}
else
Expand All @@ -65,7 +67,7 @@ void Record::flush_impl(
if (scalar())
{
RecordComponent &rc = *this;
rc.flush(name, flushParams);
rc.flush(name, flushParams, /* is_scalar = */ true);
}
else
{
Expand All @@ -75,7 +77,8 @@ void Record::flush_impl(
for (auto &comp : *this)
{
comp.second.parent() = getWritable(this);
comp.second.flush(comp.first, flushParams);
comp.second.flush(
comp.first, flushParams, /* is_scalar = */ false);
}
}
}
Expand All @@ -84,12 +87,14 @@ void Record::flush_impl(

if (scalar())
{
T_RecordComponent::flush(name, flushParams);
T_RecordComponent::flush(
name, flushParams, /* is_scalar = */ true);
}
else
{
for (auto &comp : *this)
comp.second.flush(comp.first, flushParams);
comp.second.flush(
comp.first, flushParams, /* is_scalar = */ false);
}
}

Expand Down
116 changes: 86 additions & 30 deletions src/RecordComponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,19 @@ RecordComponent &RecordComponent::resetDataset(Dataset d)
throw error::WrongAPIUsage(
"[RecordComponent] Must set specific datatype.");
}
// if( d.extent.empty() )
// throw std::runtime_error("Dataset extent must be at least 1D.");
if (d.extent.empty())
throw std::runtime_error("Dataset extent must be at least 1D.");
if (d.empty())
{
if (d.extent.empty())
{
throw error::Internal(
"A zero-dimensional dataset is not to be considered empty, but "
"undefined. This is an internal safeguard against future "
"changes that might not consider this.");
}
return makeEmpty(std::move(d));
}

rc.m_isEmpty = false;
if (written())
Expand Down Expand Up @@ -232,7 +241,9 @@ bool RecordComponent::empty() const
}

void RecordComponent::flush(
std::string const &name, internal::FlushParams const &flushParams)
std::string const &name,
internal::FlushParams const &flushParams,
bool is_scalar)
{
auto &rc = get();
if (flushParams.flushLevel == FlushLevel::SkeletonOnly)
Expand Down Expand Up @@ -275,6 +286,28 @@ void RecordComponent::flush(
{
setUnitSI(1);
}
auto constant_component_write_shape = [&]() {
if (is_scalar)
{
// Must write shape in any case:
// 1. Non-scalar constant components can be distinguished from
// normal components by checking if the backend reports a
// group or a dataset. This does not work for scalar constant
// components, so the parser needs to check if the attributes
// value and shape are there. If they're not, the group is
// not considered as a constant component.
// 2. Scalar constant components are required to write the shape
// by standard anyway since the standard requires that at
// least one component in a record have a shape. For scalars,
// there is only one component, so it must have a shape.
return true;
}
auto extent = getExtent();
return !extent.empty() &&
std::none_of(extent.begin(), extent.end(), [](auto val) {
return val == Dataset::JOINED_DIMENSION;
});
};
if (!written())
{
if (constant())
Expand All @@ -294,16 +327,20 @@ void RecordComponent::flush(
Operation::WRITE_ATT>::ChangesOverSteps::IfPossible;
}
IOHandler()->enqueue(IOTask(this, aWrite));
aWrite.name = "shape";
Attribute a(getExtent());
aWrite.dtype = a.dtype;
aWrite.resource = a.getResource();
if (isVBased)
if (constant_component_write_shape())
{
aWrite.changesOverSteps = Parameter<
Operation::WRITE_ATT>::ChangesOverSteps::IfPossible;

aWrite.name = "shape";
Attribute a(getExtent());
aWrite.dtype = a.dtype;
aWrite.resource = a.getResource();
if (isVBased)
{
aWrite.changesOverSteps = Parameter<
Operation::WRITE_ATT>::ChangesOverSteps::IfPossible;
}
IOHandler()->enqueue(IOTask(this, aWrite));
}
IOHandler()->enqueue(IOTask(this, aWrite));
}
else
{
Expand All @@ -321,6 +358,13 @@ void RecordComponent::flush(
{
if (constant())
{
if (!constant_component_write_shape())
{
throw error::WrongAPIUsage(
"Extended constant component from a previous shape to "
"one that cannot be written (empty or with joined "
"dimension).");
}
bool isVBased = retrieveSeries().iterationEncoding() ==
IterationEncoding::variableBased;
Parameter<Operation::WRITE_ATT> aWrite;
Expand Down Expand Up @@ -385,28 +429,35 @@ namespace
};
} // namespace

inline void breakpoint()
{}

void RecordComponent::readBase(bool require_unit_si)
{
using DT = Datatype;
// auto & rc = get();
Parameter<Operation::READ_ATT> aRead;
auto &rc = get();

if (constant() && !empty())
{
aRead.name = "value";
IOHandler()->enqueue(IOTask(this, aRead));
IOHandler()->flush(internal::defaultFlushParams);
readAttributes(ReadMode::FullyReread);

Attribute a(*aRead.resource);
DT dtype = *aRead.dtype;
auto read_constant =
[&]() // comment for forcing clang-format into putting a newline here
{
Attribute a = rc.readAttribute("value");
DT dtype = a.dtype;
setWritten(false, Attributable::EnqueueAsynchronously::No);
switchNonVectorType<MakeConstant>(dtype, *this, a);
setWritten(true, Attributable::EnqueueAsynchronously::No);

aRead.name = "shape";
IOHandler()->enqueue(IOTask(this, aRead));
IOHandler()->flush(internal::defaultFlushParams);
a = Attribute(*aRead.resource);
if (!containsAttribute("shape"))
{
setWritten(false, Attributable::EnqueueAsynchronously::No);
resetDataset(Dataset(dtype, {}));
setWritten(true, Attributable::EnqueueAsynchronously::No);

return;
}

a = rc.attributes().at("shape");
Extent e;

// uint64_t check
Expand All @@ -416,7 +467,7 @@ void RecordComponent::readBase(bool require_unit_si)
else
{
std::ostringstream oss;
oss << "Unexpected datatype (" << *aRead.dtype
oss << "Unexpected datatype (" << a.dtype
<< ") for attribute 'shape' (" << determineDatatype<uint64_t>()
<< " aka uint64_t)";
throw error::ReadError(
Expand All @@ -429,9 +480,13 @@ void RecordComponent::readBase(bool require_unit_si)
setWritten(false, Attributable::EnqueueAsynchronously::No);
resetDataset(Dataset(dtype, e));
setWritten(true, Attributable::EnqueueAsynchronously::No);
}
};

readAttributes(ReadMode::FullyReread);
if (constant() && !empty())
{
breakpoint();
read_constant();
}

if (require_unit_si)
{
Expand All @@ -445,16 +500,17 @@ void RecordComponent::readBase(bool require_unit_si)
"'" +
myPath().openPMDPath() + "'.");
}
if (!getAttribute("unitSI").getOptional<double>().has_value())
if (auto attr = getAttribute("unitSI");
!attr.getOptional<double>().has_value())
{
throw error::ReadError(
error::AffectedObject::Attribute,
error::Reason::UnexpectedContent,
{},
"Unexpected Attribute datatype for 'unitSI' (expected double, "
"found " +
datatypeToString(Attribute(*aRead.resource).dtype) +
") in '" + myPath().openPMDPath() + "'.");
datatypeToString(attr.dtype) + ") in '" +
myPath().openPMDPath() + "'.");
}
}
}
Expand Down
Loading
Loading