Skip to content

Commit 74704a4

Browse files
committed
Updated the form of function
1 parent ee1d837 commit 74704a4

File tree

2 files changed

+23
-30
lines changed

2 files changed

+23
-30
lines changed

mlir/include/QEC/Transforms/Patterns.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 Xanadu Quantum Technologies Inc.
1+
// Copyright 2025 Xanadu Quantum Technologies Inc.
22

33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

mlir/lib/QEC/Transforms/QECLowering.cpp

+22-29
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
#include "QEC/IR/QECDialect.h"
2929
#include "QEC/Transforms/Patterns.h"
3030
#include "Quantum/IR/QuantumOps.h"
31-
#include "mlir/Support/LogicalResult.h"
3231

3332
using namespace mlir;
3433
using namespace catalyst;
@@ -52,9 +51,9 @@ const llvm::StringMap<GateConversion> gateMap = {{"H", {{"Z", "X", "Z"}, M_PI /
5251
{"CNOT", {{"Z", "X"}, M_PI / 4}}};
5352

5453
// Get the Pauli operators and theta for a given gate
55-
template <typename OriginOp> GateConversion getPauliOperators(OriginOp *op)
54+
GateConversion getPauliOperators(CustomOp op)
5655
{
57-
mlir::StringRef opName = op->getGateName();
56+
mlir::StringRef opName = op.getGateName();
5857
auto gateConversion = gateMap.find(opName);
5958

6059
if (gateConversion == gateMap.end()) {
@@ -65,41 +64,39 @@ template <typename OriginOp> GateConversion getPauliOperators(OriginOp *op)
6564
return gateConversion->second;
6665
}
6766

68-
template <typename OriginOp, typename LoweredQECOp>
69-
LoweredQECOp convertCustomOpToPPR(OriginOp *op, ConversionPatternRewriter &rewriter)
67+
// Convert a CustomOp to a PPRotationOp
68+
PPRotationOp convertCustomOpToPPR(CustomOp op, ConversionPatternRewriter &rewriter)
7069
{
71-
auto loc = op->getLoc();
70+
auto loc = op.getLoc();
7271
auto gateConversion = getPauliOperators(op);
7372

7473
if (gateConversion.pauliOperators.empty()) {
75-
return PPRotationOp();
74+
return nullptr;
7675
}
7776

7877
ArrayAttr pauliProduct = rewriter.getStrArrayAttr(gateConversion.pauliOperators);
79-
ValueRange inQubits = op->getInQubits();
80-
TypeRange outQubitsTypes = op->getOutQubits().getTypes();
78+
ValueRange inQubits = op.getInQubits();
79+
TypeRange outQubitsTypes = op.getOutQubits().getType();
8180
Value thetaValue = rewriter.create<arith::ConstantOp>(
8281
loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(gateConversion.theta));
8382

84-
// Create a new PPRotationOp with the Pauli operators and theta
8583
auto pprOp =
86-
rewriter.create<LoweredQECOp>(loc, outQubitsTypes, pauliProduct, thetaValue, inQubits);
84+
rewriter.create<PPRotationOp>(loc, outQubitsTypes, pauliProduct, thetaValue, inQubits);
8785

88-
// Replace the original operation with the new PPRotationOp
8986
return pprOp;
9087
}
9188

92-
template <typename OriginOp, typename LoweredQECOp>
93-
LoweredQECOp convertMeasureOpToPPM(OriginOp *op, ConversionPatternRewriter &rewriter)
89+
// Convert a MeasureOp to a PPMeasurementOp
90+
PPMeasurementOp convertMeasureOpToPPM(MeasureOp op, ConversionPatternRewriter &rewriter)
9491
{
95-
auto loc = op->getLoc();
92+
auto loc = op.getLoc();
9693

97-
// pauli product is always Z
94+
// Pauli product is always Z
9895
ArrayAttr pauliProduct = rewriter.getStrArrayAttr({"Z"});
99-
ValueRange inQubits = op->getInQubit();
100-
TypeRange outQubitsTypes = op->getResults().getType();
96+
ValueRange inQubits = op.getInQubit();
97+
TypeRange outQubitsTypes = op.getOutQubit().getType();
10198

102-
auto ppmOp = rewriter.create<LoweredQECOp>(loc, outQubitsTypes, pauliProduct, inQubits);
99+
auto ppmOp = rewriter.create<PPMeasurementOp>(loc, outQubitsTypes, pauliProduct, inQubits);
103100

104101
return ppmOp;
105102
}
@@ -117,13 +114,12 @@ struct QECOpLowering : public ConversionPattern {
117114
Operation *loweredOp = nullptr;
118115

119116
// cast to OriginOp
120-
if (llvm::isa<quantum::CustomOp>(op)) {
121-
auto originOp = llvm::cast<quantum::CustomOp>(op);
122-
loweredOp = convertCustomOpToPPR<CustomOp, PPRotationOp>(&originOp, rewriter);
123-
124-
}else if (llvm::isa<quantum::MeasureOp>(op)) {
125-
auto originOp = llvm::cast<quantum::MeasureOp>(op);
126-
loweredOp = convertMeasureOpToPPM<MeasureOp, PPMeasurementOp>(&originOp, rewriter);
117+
if (isa<quantum::CustomOp>(op)) {
118+
auto originOp = cast<quantum::CustomOp>(op);
119+
loweredOp = convertCustomOpToPPR(originOp, rewriter);
120+
}else if (isa<quantum::MeasureOp>(op)) {
121+
auto originOp = cast<quantum::MeasureOp>(op);
122+
loweredOp = convertMeasureOpToPPM(originOp, rewriter);
127123
}
128124

129125
if (!loweredOp) {
@@ -132,12 +128,9 @@ struct QECOpLowering : public ConversionPattern {
132128

133129
rewriter.replaceOp(op, loweredOp);
134130
return success();
135-
136-
return failure();
137131
}
138132
};
139133

140-
// TODO: add more lowering patterns here. e.g. StaticCustomOp, UnitaryCustomOp, etc.
141134
using CustomOpLowering = QECOpLowering<quantum::CustomOp, qec::PPRotationOp>;
142135
using MeasureOpLowering = QECOpLowering<quantum::MeasureOp, qec::PPMeasurementOp>;
143136

0 commit comments

Comments
 (0)