Skip to content

Commit 0b573a0

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Remove more code paths that can create >1 module module groups.
Make HloModuleGroup::push_back() private. PiperOrigin-RevId: 814460024
1 parent 23c4c38 commit 0b573a0

File tree

6 files changed

+47
-289
lines changed

6 files changed

+47
-289
lines changed

xla/hlo/ir/hlo_module_group.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,9 @@ namespace xla {
3434
// concurrently across different devices.
3535
class HloModuleGroup {
3636
public:
37-
// Construct an empty module group.
38-
explicit HloModuleGroup(absl::string_view name) : name_(name) {}
39-
4037
// Construct a module group containing a single module.
4138
explicit HloModuleGroup(std::unique_ptr<HloModule> module);
4239

43-
// Construct a module group containing any number of modules.
44-
HloModuleGroup(absl::string_view name,
45-
absl::Span<std::unique_ptr<HloModule>> modules);
46-
HloModuleGroup(absl::string_view name,
47-
std::vector<std::unique_ptr<HloModule>>&& modules);
48-
4940
HloModuleGroup(const HloModuleGroup& other) = delete;
5041
HloModuleGroup(HloModuleGroup&& other) = default;
5142
HloModuleGroup& operator=(const HloModuleGroup& other) = delete;
@@ -57,9 +48,6 @@ class HloModuleGroup {
5748
// Returns a module at a particular index.
5849
HloModule& module(int index) const { return *module_ptrs_.at(index); }
5950

60-
// Add a module to the back of vector of modules in the group.
61-
void push_back(std::unique_ptr<HloModule> module);
62-
6351
// Replaces the existing module at the given index with the given module. The
6452
// existing module is discarded.
6553
void ReplaceModule(int index, std::unique_ptr<HloModule> module);
@@ -105,6 +93,19 @@ class HloModuleGroup {
10593
}
10694

10795
private:
96+
// Construct an empty module group.
97+
explicit HloModuleGroup(absl::string_view name) : name_(name) {}
98+
99+
// Construct a module group containing any number of modules.
100+
HloModuleGroup(absl::string_view name,
101+
absl::Span<std::unique_ptr<HloModule>> modules);
102+
HloModuleGroup(absl::string_view name,
103+
std::vector<std::unique_ptr<HloModule>>&& modules);
104+
105+
// Add a module to the back of vector of modules in the group. Private
106+
// because we no longer want to support > 1 module per group.
107+
void push_back(std::unique_ptr<HloModule> module);
108+
108109
std::string name_;
109110

110111
// Vector of modules as std::unique_ptrs.

xla/hlo/pass/hlo_pass_fix_test.cc

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -182,85 +182,5 @@ TEST_F(HloPassFixTest, RunModuleToNonDefaultEarlyExit) {
182182
EXPECT_EQ(root->literal().GetFirstElement<int32_t>(), 20);
183183
}
184184

185-
TEST_F(HloPassFixTest, RunModuleGroupToFixedPoint) {
186-
constexpr absl::string_view kModule0 = R"(
187-
HloModule First
188-
189-
ENTRY main {
190-
ROOT c = s32[] constant(5)
191-
}
192-
)";
193-
194-
constexpr absl::string_view kModule1 = R"(
195-
HloModule Second
196-
197-
ENTRY main {
198-
ROOT c = s32[] constant(3)
199-
}
200-
)";
201-
202-
constexpr absl::string_view kModule2 = R"(
203-
HloModule Second
204-
205-
ENTRY main {
206-
ROOT c = s32[] constant(0)
207-
}
208-
)";
209-
210-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module0,
211-
ParseAndReturnVerifiedModule(kModule0));
212-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module1,
213-
ParseAndReturnVerifiedModule(kModule1));
214-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module2,
215-
ParseAndReturnVerifiedModule(kModule2));
216-
HloModuleGroup module_group("group");
217-
module_group.push_back(std::move(module0));
218-
module_group.push_back(std::move(module1));
219-
module_group.push_back(std::move(module2));
220-
221-
HloPassFix<DecrementPositiveConstants> pass;
222-
TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.RunOnModuleGroup(&module_group));
223-
EXPECT_TRUE(changed);
224-
HloInstruction* root0 =
225-
module_group.module(0).entry_computation()->root_instruction();
226-
ASSERT_EQ(root0->opcode(), HloOpcode::kConstant);
227-
EXPECT_EQ(root0->literal().GetFirstElement<int32_t>(), 0);
228-
HloInstruction* root1 =
229-
module_group.module(1).entry_computation()->root_instruction();
230-
ASSERT_EQ(root1->opcode(), HloOpcode::kConstant);
231-
EXPECT_EQ(root1->literal().GetFirstElement<int32_t>(), 0);
232-
HloInstruction* root2 =
233-
module_group.module(2).entry_computation()->root_instruction();
234-
ASSERT_EQ(root2->opcode(), HloOpcode::kConstant);
235-
EXPECT_EQ(root2->literal().GetFirstElement<int32_t>(), 0);
236-
}
237-
238-
TEST_F(HloPassFixTest, OscillationsStillTerminate) {
239-
constexpr absl::string_view kModule = R"(
240-
HloModule Oscillating
241-
242-
ENTRY main {
243-
a = f32[4] parameter(0)
244-
b = f32[4] parameter(1)
245-
ROOT c = f32[4] add(a, b)
246-
}
247-
)";
248-
249-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
250-
ParseAndReturnVerifiedModule(kModule));
251-
HloPassFix<FlipAddSubtract> pass;
252-
253-
// We expect this to terminate and report that the module did not change.
254-
TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
255-
EXPECT_FALSE(changed);
256-
257-
// But don't lie when crash_on_hlo_pass_silent_hlo_change is set.
258-
module->mutable_config()
259-
.mutable_debug_options()
260-
.set_xla_unsupported_crash_on_hlo_pass_silent_hlo_change(true);
261-
TF_ASSERT_OK_AND_ASSIGN(changed, pass.Run(module.get()));
262-
EXPECT_TRUE(changed);
263-
}
264-
265185
} // namespace
266186
} // namespace xla

xla/hlo/pass/hlo_pass_pipeline_test.cc

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ limitations under the License.
2727
#include "absl/status/status.h"
2828
#include "absl/status/statusor.h"
2929
#include "absl/strings/string_view.h"
30-
#include "absl/types/span.h"
3130
#include "xla/hlo/ir/hlo_computation.h"
3231
#include "xla/hlo/ir/hlo_instruction.h"
3332
#include "xla/hlo/ir/hlo_module.h"
@@ -50,15 +49,10 @@ using ::testing::StrEq;
5049

5150
class HloPassPipelineTest : public HloHardwareIndependentTestBase {
5251
protected:
53-
absl::StatusOr<HloModuleGroup> ParseModuleGroup(
54-
absl::Span<const std::string> hlo_strings) {
55-
HloModuleGroup group(TestName());
56-
for (const std::string& hlo_string : hlo_strings) {
57-
TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
58-
ParseAndReturnVerifiedModule(hlo_string));
59-
group.push_back(std::move(module));
60-
}
61-
return group;
52+
absl::StatusOr<HloModuleGroup> ParseModuleGroup(std::string hlo_string) {
53+
TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
54+
ParseAndReturnVerifiedModule(hlo_string));
55+
return HloModuleGroup(std::move(module));
6256
}
6357
};
6458

@@ -276,36 +270,22 @@ ENTRY main {
276270
ROOT baz = f32[] multiply(a, b)
277271
}
278272
)";
279-
const std::string module_1_str = R"(
280-
HloModule MixedPipeline.0
281-
282-
ENTRY main {
283-
a = f32[] parameter(0)
284-
b = f32[] parameter(1)
285-
ROOT foo = f32[] multiply(a, b)
286-
}
287-
)";
288-
289273
TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
290-
ParseModuleGroup({module_0_str, module_1_str}));
274+
ParseModuleGroup(module_0_str));
291275

292276
HloPassPipeline pipeline(TestName());
293277
pipeline.AddPass<BazToQuxModuleGroupPass>();
294278
pipeline.AddPass<FooToBarModulePass>();
295279

296280
HloInstruction* root0 =
297281
module_group.module(0).entry_computation()->root_instruction();
298-
HloInstruction* root1 =
299-
module_group.module(1).entry_computation()->root_instruction();
300282
EXPECT_EQ(root0->name(), "baz");
301-
EXPECT_EQ(root1->name(), "foo");
302283

303284
TF_ASSERT_OK_AND_ASSIGN(bool changed,
304285
pipeline.RunOnModuleGroup(&module_group));
305286
EXPECT_TRUE(changed);
306287

307288
EXPECT_EQ(root0->name(), "qux");
308-
EXPECT_EQ(root1->name(), "bar");
309289
}
310290

311291
TEST_F(HloPassPipelineTest, InvariantChecker) {
@@ -383,15 +363,13 @@ ENTRY main {
383363

384364
// Test that metadata is set when a module group goes through a pass pipeline.
385365
TEST_F(HloPassPipelineTest, SetHloModuleMetadata) {
386-
HloModuleGroup module_group(TestName());
387-
module_group.push_back(CreateNewVerifiedModule());
388-
module_group.push_back(CreateNewVerifiedModule());
366+
HloModuleGroup module_group(CreateNewVerifiedModule());
389367

390368
HloPassPipeline pipeline(TestName());
391369
pipeline.AddPass<BazToQuxModuleGroupPass>();
392370
pipeline.AddPass<FooToBarModulePass>();
393371
TF_ASSERT_OK(pipeline.RunOnModuleGroup(&module_group).status());
394-
ASSERT_THAT(module_group.modules(), SizeIs(2));
372+
ASSERT_THAT(module_group.modules(), SizeIs(1));
395373

396374
std::vector<std::string> pass_names = {"pipeline-start", "baz2qux",
397375
"foo2bar"};
@@ -410,8 +388,7 @@ TEST_F(HloPassPipelineTest, SetHloModuleMetadata) {
410388
EXPECT_FALSE(pass_metadata.module_changed());
411389
EXPECT_EQ(pass_metadata.module_id(), module->unique_id());
412390
EXPECT_THAT(pass_metadata.module_group_module_ids(),
413-
ElementsAre(module_group.module(0).unique_id(),
414-
module_group.module(1).unique_id()));
391+
ElementsAre(module_group.module(0).unique_id()));
415392
EXPECT_GT(pass_metadata.start_timestamp_usec(), 0);
416393
EXPECT_LE(pass_metadata.start_timestamp_usec(),
417394
pass_metadata.end_timestamp_usec());

xla/service/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,10 +1838,7 @@ xla_cc_test(
18381838
xla_cc_test(
18391839
name = "hlo_module_group_test",
18401840
srcs = ["hlo_module_group_test.cc"],
1841-
# TODO(b/148211710) Test fails in OSS.
1842-
tags = ["no_oss"],
18431841
deps = [
1844-
":hlo_module_group_metadata",
18451842
":hlo_proto_cc",
18461843
"//xla/hlo/ir:hlo_module_group",
18471844
"//xla/hlo/testlib:hlo_hardware_independent_test_base",

0 commit comments

Comments
 (0)