28
28
#include " QEC/IR/QECDialect.h"
29
29
#include " QEC/Transforms/Patterns.h"
30
30
#include " Quantum/IR/QuantumOps.h"
31
- #include " mlir/Support/LogicalResult.h"
32
31
33
32
using namespace mlir ;
34
33
using namespace catalyst ;
@@ -52,9 +51,9 @@ const llvm::StringMap<GateConversion> gateMap = {{"H", {{"Z", "X", "Z"}, M_PI /
52
51
{" CNOT" , {{" Z" , " X" }, M_PI / 4 }}};
53
52
54
53
// Get the Pauli operators and theta for a given gate
55
- template < typename OriginOp> GateConversion getPauliOperators (OriginOp * op)
54
+ GateConversion getPauliOperators (CustomOp op)
56
55
{
57
- mlir::StringRef opName = op-> getGateName ();
56
+ mlir::StringRef opName = op. getGateName ();
58
57
auto gateConversion = gateMap.find (opName);
59
58
60
59
if (gateConversion == gateMap.end ()) {
@@ -65,41 +64,39 @@ template <typename OriginOp> GateConversion getPauliOperators(OriginOp *op)
65
64
return gateConversion->second ;
66
65
}
67
66
68
- template < typename OriginOp, typename LoweredQECOp>
69
- LoweredQECOp convertCustomOpToPPR (OriginOp * op, ConversionPatternRewriter &rewriter)
67
+ // Convert a CustomOp to a PPRotationOp
68
+ PPRotationOp convertCustomOpToPPR (CustomOp op, ConversionPatternRewriter &rewriter)
70
69
{
71
- auto loc = op-> getLoc ();
70
+ auto loc = op. getLoc ();
72
71
auto gateConversion = getPauliOperators (op);
73
72
74
73
if (gateConversion.pauliOperators .empty ()) {
75
- return PPRotationOp () ;
74
+ return nullptr ;
76
75
}
77
76
78
77
ArrayAttr pauliProduct = rewriter.getStrArrayAttr (gateConversion.pauliOperators );
79
- ValueRange inQubits = op-> getInQubits ();
80
- TypeRange outQubitsTypes = op-> getOutQubits ().getTypes ();
78
+ ValueRange inQubits = op. getInQubits ();
79
+ TypeRange outQubitsTypes = op. getOutQubits ().getType ();
81
80
Value thetaValue = rewriter.create <arith::ConstantOp>(
82
81
loc, rewriter.getF64Type (), rewriter.getF64FloatAttr (gateConversion.theta ));
83
82
84
- // Create a new PPRotationOp with the Pauli operators and theta
85
83
auto pprOp =
86
- rewriter.create <LoweredQECOp >(loc, outQubitsTypes, pauliProduct, thetaValue, inQubits);
84
+ rewriter.create <PPRotationOp >(loc, outQubitsTypes, pauliProduct, thetaValue, inQubits);
87
85
88
- // Replace the original operation with the new PPRotationOp
89
86
return pprOp;
90
87
}
91
88
92
- template < typename OriginOp, typename LoweredQECOp>
93
- LoweredQECOp convertMeasureOpToPPM (OriginOp * op, ConversionPatternRewriter &rewriter)
89
+ // Convert a MeasureOp to a PPMeasurementOp
90
+ PPMeasurementOp convertMeasureOpToPPM (MeasureOp op, ConversionPatternRewriter &rewriter)
94
91
{
95
- auto loc = op-> getLoc ();
92
+ auto loc = op. getLoc ();
96
93
97
- // pauli product is always Z
94
+ // Pauli product is always Z
98
95
ArrayAttr pauliProduct = rewriter.getStrArrayAttr ({" Z" });
99
- ValueRange inQubits = op-> getInQubit ();
100
- TypeRange outQubitsTypes = op-> getResults ().getType ();
96
+ ValueRange inQubits = op. getInQubit ();
97
+ TypeRange outQubitsTypes = op. getOutQubit ().getType ();
101
98
102
- auto ppmOp = rewriter.create <LoweredQECOp >(loc, outQubitsTypes, pauliProduct, inQubits);
99
+ auto ppmOp = rewriter.create <PPMeasurementOp >(loc, outQubitsTypes, pauliProduct, inQubits);
103
100
104
101
return ppmOp;
105
102
}
@@ -117,13 +114,12 @@ struct QECOpLowering : public ConversionPattern {
117
114
Operation *loweredOp = nullptr ;
118
115
119
116
// cast to OriginOp
120
- if (llvm::isa<quantum::CustomOp>(op)) {
121
- auto originOp = llvm::cast<quantum::CustomOp>(op);
122
- loweredOp = convertCustomOpToPPR<CustomOp, PPRotationOp>(&originOp, rewriter);
123
-
124
- }else if (llvm::isa<quantum::MeasureOp>(op)) {
125
- auto originOp = llvm::cast<quantum::MeasureOp>(op);
126
- loweredOp = convertMeasureOpToPPM<MeasureOp, PPMeasurementOp>(&originOp, rewriter);
117
+ if (isa<quantum::CustomOp>(op)) {
118
+ auto originOp = cast<quantum::CustomOp>(op);
119
+ loweredOp = convertCustomOpToPPR (originOp, rewriter);
120
+ }else if (isa<quantum::MeasureOp>(op)) {
121
+ auto originOp = cast<quantum::MeasureOp>(op);
122
+ loweredOp = convertMeasureOpToPPM (originOp, rewriter);
127
123
}
128
124
129
125
if (!loweredOp) {
@@ -132,12 +128,9 @@ struct QECOpLowering : public ConversionPattern {
132
128
133
129
rewriter.replaceOp (op, loweredOp);
134
130
return success ();
135
-
136
- return failure ();
137
131
}
138
132
};
139
133
140
- // TODO: add more lowering patterns here. e.g. StaticCustomOp, UnitaryCustomOp, etc.
141
134
using CustomOpLowering = QECOpLowering<quantum::CustomOp, qec::PPRotationOp>;
142
135
using MeasureOpLowering = QECOpLowering<quantum::MeasureOp, qec::PPMeasurementOp>;
143
136
0 commit comments