Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/substrait/common/Io.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ enum class PlanFileFormat {
* amount of memory that it consumed on disk.
*
* \param input_filename The filename containing the plan to convert.
* \param force_binary If true, the plan will be opened as a binary file.
* Required on Windows to avoid text mode line-ending translation.
* \return If loading was successful, returns a plan. If loading was not
* successful this is a status containing a list of parse errors in the status's
* message.
*/
absl::StatusOr<::substrait::proto::Plan> loadPlan(
std::string_view input_filename);
std::string_view input_filename,
bool force_binary = false);

/*
* \brief Writes the provided plan to disk.
Expand Down
6 changes: 4 additions & 2 deletions src/substrait/common/Io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ PlanFileFormat detectFormat(std::string_view content) {
} // namespace

absl::StatusOr<::substrait::proto::Plan> loadPlan(
std::string_view input_filename) {
auto contentOrError = textplan::readFromFile(input_filename.data());
std::string_view input_filename,
bool forceBinary) {
auto contentOrError =
textplan::readFromFile(input_filename.data(), forceBinary);
if (!contentOrError.ok()) {
return contentOrError.status();
}
Expand Down
11 changes: 10 additions & 1 deletion src/substrait/common/tests/IoTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class SaveAndLoadTestFixture : public ::testing::TestWithParam<PlanFileFormat> {
std::filesystem::path("my_temp_dir"))
.string();

if (!std::filesystem::create_directory(testFileDirectory_)) {
std::filesystem::create_directory(testFileDirectory_);
if (!std::filesystem::exists(testFileDirectory_)) {
ASSERT_TRUE(false) << "Failed to create temporary directory.";
testFileDirectory_.clear();
}
Expand Down Expand Up @@ -87,7 +88,15 @@ TEST_P(SaveAndLoadTestFixture, SaveAndLoad) {
auto status = ::io::substrait::savePlan(plan, tempFilename, encoding);
ASSERT_TRUE(status.ok()) << "Save failed.\n" << status;

#ifdef _WIN32
// Windows cannot rely on io::substrait::loadPlan to detect the file format,
// since it needs to a-priori specify how the file should be loaded.
bool forceBinary = encoding == PlanFileFormat::kBinary;
auto result = ::io::substrait::loadPlan(tempFilename, forceBinary);
#else
auto result = ::io::substrait::loadPlan(tempFilename);
#endif

ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status();
ASSERT_THAT(
*result,
Expand Down
15 changes: 11 additions & 4 deletions src/substrait/textplan/converter/LoadBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,24 @@ class StringErrorCollector : public google::protobuf::io::ErrorCollector {

} // namespace

absl::StatusOr<std::string> readFromFile(std::string_view msgPath) {
std::ifstream textFile(std::string{msgPath});
if (textFile.fail()) {
absl::StatusOr<std::string> readFromFile(
std::string_view msgPath,
bool forceBinary) {
std::ifstream file;
if (forceBinary)
file.open(std::string{msgPath}, std::ios::binary);
else
file.open(std::string{msgPath}, std::ios::in);

if (file.fail()) {
auto currDir = std::filesystem::current_path().string();
return absl::ErrnoToStatus(
errno,
fmt::format(
"Failed to open file {} when running in {}", msgPath, currDir));
}
std::stringstream buffer;
buffer << textFile.rdbuf();
buffer << file.rdbuf();
return buffer.str();
}

Expand Down
7 changes: 5 additions & 2 deletions src/substrait/textplan/converter/LoadBinary.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ class Plan;

namespace io::substrait::textplan {

// Read the contents of a file from disk.
absl::StatusOr<std::string> readFromFile(std::string_view msgPath);
// Read the contents of a file from disk. 'forceBinary' enables file reading in
// binary mode.
absl::StatusOr<std::string> readFromFile(
std::string_view msgPath,
bool forceBinary = false);

// Reads a plan from a json-encoded text proto.
// Returns a list of errors if the file cannot be parsed.
Expand Down
18 changes: 6 additions & 12 deletions src/substrait/textplan/converter/SaveBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,18 @@ namespace io::substrait::textplan {
absl::Status savePlanToBinary(
const ::substrait::proto::Plan& plan,
std::string_view output_filename) {
int outputFileDescriptor =
creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE);
if (outputFileDescriptor == -1) {
return absl::ErrnoToStatus(
errno,
// Open file in binary mode and get its file descriptor
std::ofstream of(std::string{output_filename}, std::ios::binary);
if (!of) {
return absl::InternalError(
fmt::format("Failed to open file {} for writing", output_filename));
}
auto stream =
new google::protobuf::io::FileOutputStream(outputFileDescriptor);

if (!plan.SerializeToZeroCopyStream(stream)) {
if (!plan.SerializeToOstream(&of)) {
return ::absl::UnknownError("Failed to write plan to stream.");
}

if (!stream->Close()) {
return absl::AbortedError("Failed to close file descriptor.");
}
delete stream;
of.close();
return absl::OkStatus();
}

Expand Down
35 changes: 19 additions & 16 deletions src/substrait/textplan/tests/ParseResultMatchers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,31 +149,34 @@ class HasSymbolsMatcher {

bool MatchAndExplain(const ParseResult& result, std::ostream* listener)
const {
auto actualSymbols = symbolNames(result.getSymbolTable().getSymbols());
// Note: Need set or sorted vector for set_difference.
auto actualSymbolsSorted =
symbolNames(result.getSymbolTable().getSymbols());
std::sort(actualSymbolsSorted.begin(), actualSymbolsSorted.end());
std::vector<std::string> extraSymbols;
auto expectedSymbolsSorted = expectedSymbols_;
std::sort(expectedSymbolsSorted.begin(), expectedSymbolsSorted.end());
if (listener != nullptr) {
std::vector<std::string> extraSymbols(actualSymbols.size());
auto end = std::set_difference(
actualSymbols.begin(),
actualSymbols.end(),
expectedSymbols_.begin(),
expectedSymbols_.end(),
extraSymbols.begin());
extraSymbols.resize(end - extraSymbols.begin());
actualSymbolsSorted.begin(),
actualSymbolsSorted.end(),
expectedSymbolsSorted.begin(),
expectedSymbolsSorted.end(),
std::back_inserter(extraSymbols));
if (!extraSymbols.empty()) {
*listener << std::endl << " with extra symbols: ";
for (const auto& symbol : extraSymbols) {
*listener << " \"" << symbol << "\"";
}
}

std::vector<std::string> missingSymbols(expectedSymbols_.size());
std::vector<std::string> missingSymbols;
end = std::set_difference(
expectedSymbols_.begin(),
expectedSymbols_.end(),
actualSymbols.begin(),
actualSymbols.end(),
missingSymbols.begin());
missingSymbols.resize(end - missingSymbols.begin());
expectedSymbolsSorted.begin(),
expectedSymbolsSorted.end(),
actualSymbolsSorted.begin(),
actualSymbolsSorted.end(),
std::back_inserter(missingSymbols));
if (!missingSymbols.empty()) {
if (!extraSymbols.empty()) {
*listener << ", and missing symbols: ";
Expand All @@ -185,7 +188,7 @@ class HasSymbolsMatcher {
}
}
}
return actualSymbols == expectedSymbols_;
return actualSymbolsSorted == expectedSymbolsSorted;
}

void DescribeTo(std::ostream* os) const {
Expand Down
Loading