diff --git a/src/cuda/tile/_passes/token_order.py b/src/cuda/tile/_passes/token_order.py index e6ce856..a7abf22 100644 --- a/src/cuda/tile/_passes/token_order.py +++ b/src/cuda/tile/_passes/token_order.py @@ -10,7 +10,7 @@ from cuda.tile._ir.type import TupleTy, TokenTy from cuda.tile._memory_model import MemoryOrder -from cuda.tile._exception import Loc +from cuda.tile._exception import Loc, TileInternalError from cuda.tile._ir.ir import Function, Block, IRContext, Var, Operation from cuda.tile._ir.ops import ( Assign, Break, BuildTuple, CarriedVariables, Continue, EndBranch, IfElse, @@ -150,15 +150,16 @@ def get_memory_effects(cur_op): if isinstance(cur_op, LoadMemoryOperation): effect = MemoryEffect.LOAD - else: - assert isinstance(cur_op, StoreMemoryOperation) + elif isinstance(cur_op, StoreMemoryOperation): effect = MemoryEffect.STORE + else: + raise TileInternalError(f"Unexpected MemoryOperation type: {type(cur_op)}") has_acquire_order = False if isinstance(cur_op, (TileAtomicCAS, TileAtomicRMW)): - has_acquire_order = memory_order_has_acquire(op.memory_order) + has_acquire_order = memory_order_has_acquire(cur_op.memory_order) - return MemoryEffects({alias_result[_get_input_var(op).name]: effect}, has_acquire_order) + return MemoryEffects({alias_result[_get_input_var(cur_op).name]: effect}, has_acquire_order) blk_mem_effects = EMPTY_MEMORY_EFFECTS for op in block.operations: