diff --git a/src/pypulseq/Sequence/block.py b/src/pypulseq/Sequence/block.py index ec4d761e..e9f4cb38 100644 --- a/src/pypulseq/Sequence/block.py +++ b/src/pypulseq/Sequence/block.py @@ -9,6 +9,7 @@ from pypulseq.compress_shape import compress_shape from pypulseq.decompress_shape import decompress_shape from pypulseq.event_lib import EventLibrary +from pypulseq.Sequence.grad_check import grad_check from pypulseq.supported_labels_rf_use import get_supported_labels from pypulseq.utils.tracing import trace_enabled @@ -52,6 +53,7 @@ def set_block(self, block_index: int, *args: SimpleNamespace) -> None: } # Key-value mapping of index and pairs of gradients/times extensions = [] + rot_event = None for event in events: if not isinstance(event, float): # If event is not a block duration if event.type == 'rf': @@ -139,6 +141,18 @@ def set_block(self, block_index: int, *args: SimpleNamespace) -> None: ext = {'type': self.get_extension_type_ID('TRIGGERS'), 'ref': event_id} extensions.append(ext) duration = max(duration, event.delay + event.duration) + elif event.type == 'rotation': + rot_event = event + if hasattr(event, 'id'): + event_id = event.id + else: + event_id = register_rotation_event(self, event) + + ext = { + 'type': self.get_extension_type_ID('ROTATIONS'), + 'ref': event_id + } + extensions.append(ext) elif event.type in ['labelset', 'labelinc']: if hasattr(event, 'id'): label_id = event.id @@ -190,79 +204,8 @@ def set_block(self, block_index: int, *args: SimpleNamespace) -> None: # ========= # PERFORM GRADIENT CHECKS # ========= - for grad_to_check in check_g.values(): - if abs(grad_to_check.start[1]) > self.system.max_slew * self.system.grad_raster_time: # noqa: SIM102 - if grad_to_check.start[0] > eps: - raise RuntimeError('No delay allowed for gradients which start with a non-zero amplitude') - - # Check whether any blocks exist in the sequence - if self.next_free_block_ID > 1: - # Look up the previous block (and the next block in case of a set_block call) - if block_index == self.next_free_block_ID: - # New block inserted - prev_block_index = next(reversed(self.block_events)) - next_block_index = None - else: - blocks = list(self.block_events) - try: - # Existing block overwritten - idx = blocks.index(block_index) - prev_block_index = blocks[idx - 1] if idx > 0 else None - next_block_index = blocks[idx + 1] if idx < len(blocks) - 1 else None - except ValueError: - # Inserting a new block with non-contiguous numbering - prev_block_index = next(reversed(self.block_events)) - next_block_index = None - - # Look up the last gradient value in the previous block - last = 0 - if prev_block_index is not None: - prev_id = self.block_events[prev_block_index][grad_to_check.idx] - if prev_id != 0: - prev_lib = self.grad_library.get(prev_id) - prev_type = prev_lib['type'] - - if prev_type == 't': - last = 0 - elif prev_type == 'g': - last = prev_lib['data'][5] - - # Check whether the difference between the last gradient value and - # the first value of the new gradient is achievable with the - # specified slew rate. - if abs(last - grad_to_check.start[1]) > self.system.max_slew * self.system.grad_raster_time: - raise RuntimeError('Two consecutive gradients need to have the same amplitude at the connection point') - - # Look up the first gradient value in the next block - # (this only happens when using set_block to patch a block) - if next_block_index is not None: - next_id = self.block_events[next_block_index][grad_to_check.idx] - if next_id != 0: - next_lib = self.grad_library.get(next_id) - next_type = next_lib['type'] - - if next_type == 't': - first = 0 - elif next_type == 'g': - first = next_lib['data'][4] - else: - first = 0 - - # Check whether the difference between the first gradient value - # in the next block and the last value of the new gradient is - # achievable with the specified slew rate. - if abs(first - grad_to_check.stop[1]) > self.system.max_slew * self.system.grad_raster_time: - raise RuntimeError( - 'Two consecutive gradients need to have the same amplitude at the connection point' - ) - elif abs(grad_to_check.start[1]) > self.system.max_slew * self.system.grad_raster_time: - raise RuntimeError('First gradient in the the first block has to start at 0.') - - if ( - grad_to_check.stop[1] > self.system.max_slew * self.system.grad_raster_time - and abs(grad_to_check.stop[0] - duration) > 1e-7 - ): - raise RuntimeError("A gradient that doesn't end at zero needs to be aligned to the block boundary.") + grad_check(self, block_index, check_g, duration, rot_event) + self.block_events[block_index] = new_block self.block_durations[block_index] = float(duration) @@ -410,6 +353,12 @@ def get_block(self, block_index: int) -> SimpleNamespace: block.trigger[len(block.trigger)] = trigger else: block.trigger = {0: trigger} + elif ext_type == "ROTATIONS": + data = self.rotation_library.data[ext_data[1]] + rotation = SimpleNamespace() + rotation.type = "rotation" + rotation.rot_matrix = np.asarray(data).reshape(3, 3) + block.rotation = rotation elif ext_type in ['LABELSET', 'LABELINC']: label = SimpleNamespace() label.type = ext_type.lower() @@ -687,3 +636,25 @@ def register_rf_event(self, event: SimpleNamespace) -> Tuple[int, List[int]]: rf_id = self.rf_library.insert(key_id=0, new_data=data, data_type=use) return rf_id, shape_IDs + + +def register_rotation_event(self, event: EventLibrary) -> int: + """ + Parameters + ---------- + event : SimpleNamespace + Rotation event to be registered. + Returns + ------- + int + ID of registered rotation event. + """ + data = tuple(event.rot_matrix.ravel().tolist()) + rotation_id, found = self.rotation_library.find_or_insert(new_data=data) + + # Clear block cache because Rotation was overwritten + # TODO: Could find only the blocks that are affected by the changes + if self.use_block_cache and found: + self.block_cache.clear() + + return rotation_id diff --git a/src/pypulseq/Sequence/grad_check.py b/src/pypulseq/Sequence/grad_check.py new file mode 100644 index 00000000..89c7ac49 --- /dev/null +++ b/src/pypulseq/Sequence/grad_check.py @@ -0,0 +1,356 @@ +import numpy as np + +from pypulseq import eps + + +def grad_check(self, block_index, check_g, duration, rot_event): + """ + Check if connection to the previous block is correct. + + Parameters + ---------- + block_index : int + Current block index. + check_g : SimpleNamespace + Structure containing current gradient start and end (t, g) values for each + axis. + duration : float + Current block duration. + rot_event : SimpleNamespace + Current block rotation event. + + Raises + ------ + RuntimeError + If either 1) initial block start with non-zero amplitude; + 2) gradients starting with non-zero amplitude have a delay; + 3) gradients starting with non-zero amplitude have different initial + amplitude value than the previous block at connecting point; + 4) gradients ending with non-zero amplitude are not aligned with block raster; + 4) gradients ending with non-zero amplitude are not aligned have different initial + amplitude value than the next block at connecting point. + + """ + if not self.rotation_library.data: + grad_check_norot(self, block_index, check_g, duration) + else: + grad_check_rot(self, block_index, check_g, duration, rot_event) + + +def grad_check_norot(self, block_index, check_g, duration): + """ + Check continuity of adjacent gradient events in absence of rotations. + + Parameters + ---------- + block_index : int + Current block index. + check_g : SimpleNamespace + Structure containing current gradient start and end (t, g) values for each + axis. + duration : float + Current block duration. + + Raises + ------ + RuntimeError + If either 1) initial block start with non-zero amplitude; + 2) gradients starting with non-zero amplitude have a delay; + 3) gradients starting with non-zero amplitude have different initial + amplitude value than the previous block at connecting point; + 4) gradients ending with non-zero amplitude are not aligned with block raster; + 4) gradients ending with non-zero amplitude are not aligned have different initial + amplitude value than the next block at connecting point. + + """ + for grad_to_check in check_g.values(): + if ( + abs(grad_to_check.start[1]) + > self.system.max_slew * self.system.grad_raster_time + ): # noqa: SIM102 + if grad_to_check.start[0] > eps: + raise RuntimeError( + "No delay allowed for gradients which start with a non-zero amplitude" + ) + + # Check whether any blocks exist in the sequence + if self.next_free_block_ID > 1: + # Look up the previous block (and the next block in case of a set_block call) + if block_index == self.next_free_block_ID: + # New block inserted + prev_block_index = next(reversed(self.block_events)) + next_block_index = None + else: + blocks = list(self.block_events) + try: + # Existing block overwritten + idx = blocks.index(block_index) + prev_block_index = blocks[idx - 1] if idx > 0 else None + next_block_index = ( + blocks[idx + 1] if idx < len(blocks) - 1 else None + ) + except ValueError: + # Inserting a new block with non-contiguous numbering + prev_block_index = next(reversed(self.block_events)) + next_block_index = None + + # Look up the last gradient value in the previous block + last = 0 + if prev_block_index is not None: + prev_id = self.block_events[prev_block_index][grad_to_check.idx] + if prev_id != 0: + prev_lib = self.grad_library.get(prev_id) + prev_type = prev_lib["type"] + + if prev_type == "t": + last = 0 + elif prev_type == "g": + last = prev_lib["data"][5] + + # Check whether the difference between the last gradient value and + # the first value of the new gradient is achievable with the + # specified slew rate. + if ( + abs(last - grad_to_check.start[1]) + > self.system.max_slew * self.system.grad_raster_time + ): + raise RuntimeError( + "Two consecutive gradients need to have the same amplitude at the connection point" + ) + + # Look up the first gradient value in the next block + # (this only happens when using set_block to patch a block) + if next_block_index is not None: + next_id = self.block_events[next_block_index][grad_to_check.idx] + if next_id != 0: + next_lib = self.grad_library.get(next_id) + next_type = next_lib["type"] + + if next_type == "t": + first = 0 + elif next_type == "g": + first = next_lib["data"][4] + else: + first = 0 + + # Check whether the difference between the first gradient value + # in the next block and the last value of the new gradient is + # achievable with the specified slew rate. + if ( + abs(first - grad_to_check.stop[1]) + > self.system.max_slew * self.system.grad_raster_time + ): + raise RuntimeError( + "Two consecutive gradients need to have the same amplitude at the connection point" + ) + elif ( + abs(grad_to_check.start[1]) + > self.system.max_slew * self.system.grad_raster_time + ): + raise RuntimeError( + "First gradient in the the first block has to start at 0." + ) + + # Check if gradients, which do not end at 0, are as long as the block itself. + if ( + abs(grad_to_check.stop[1]) + > self.system.max_slew * self.system.grad_raster_time + and abs(grad_to_check.stop[0] - duration) > 1e-7 + ): + raise RuntimeError( + "A gradient that doesn't end at zero needs to be aligned to the block boundary." + ) + + +def grad_check_rot(self, block_index, check_g, duration, rot_event): + """ + Check continuity of adjacent gradient events in presence of rotations. + + Parameters + ---------- + block_index : int + Current block index. + check_g : SimpleNamespace + Structure containing current gradient start and end (t, g) values for each + axis. + duration : float + Current block duration. + rot_event : SimpleNamespace + Current block rotation event. + + Raises + ------ + RuntimeError + If either 1) initial block start with non-zero amplitude; + 2) gradients starting with non-zero amplitude have a delay; + 3) gradients starting with non-zero amplitude have different initial + amplitude value than the previous block at connecting point; + 4) gradients ending with non-zero amplitude are not aligned with block raster; + 4) gradients ending with non-zero amplitude are not aligned have different initial + amplitude value than the next block at connecting point. + + """ + for grad_to_check in check_g.values(): + # Check beginning of gradient event + if ( + abs(grad_to_check.start[1]) + > self.system.max_slew * self.system.grad_raster_time + ): # noqa: SIM102 + if grad_to_check.start[0] > eps: + raise RuntimeError( + "No delay allowed for gradients which start with a non-zero amplitude" + ) + + # Check whether any blocks exist in the sequence + if self.next_free_block_ID > 1: + current_has_rot = rot_event is not None + previous_has_rot = False + next_has_rot = False + + # Rotation extension ID + rot_type_id = self.get_extension_type_ID("ROTATIONS") + + # Get indexes of previous and next blocks + if block_index == self.next_free_block_ID: + # New block inserted + prev_block_index = next(reversed(self.block_events)) + next_block_index = None + else: + blocks = list(self.block_events) + try: + # Existing block overwritten + idx = blocks.index(block_index) + prev_block_index = blocks[idx - 1] if idx > 0 else None + next_block_index = blocks[idx + 1] if idx < len(blocks) - 1 else None + except ValueError: + # Inserting a new block with non-contiguous numbering + prev_block_index = next(reversed(self.block_events)) + next_block_index = None + + # 1) Comparison with previous block + prev_grad_last = np.zeros(3, dtype=np.float32) + if prev_block_index is not None: + # Look up the last gradient value in the previous block + for grad_to_check in check_g.values(): + prev_id = self.block_events[prev_block_index][grad_to_check.idx] + if prev_id != 0: + prev_lib = self.grad_library.get(prev_id) + prev_type = prev_lib["type"] + + # Trapezoids end with zeros, + # so we cannot have the same amplitude + # as the initial value of current block + if prev_type == "g": + prev_grad_last[grad_to_check.idx - 2] = prev_lib["data"][5] + + # Get previous block rotation matrix + ext_id = self.block_events[prev_block_index][-1] + while ext_id and not previous_has_rot: + try: + ext = self.extensions_library.data.get(ext_id) + if ext[0] == rot_type_id: + previous_has_rot = True + previous_rotmat = np.asarray( + self.rotation_library.data.get(ext[1]) + ).reshape((3, 3)) + else: + ext_id = ext[-1] + except KeyError: + ext_id = 0 + + # Rotate previous gradient + if previous_has_rot: + prev_grad_last = previous_rotmat @ prev_grad_last + + # Look up the first gradient value in current block + curr_grad_first = np.zeros(3, dtype=np.float32) + for grad_to_check in check_g.values(): + curr_grad_first[grad_to_check.idx - 2] = grad_to_check.start[1] + + # Rotate current gradient + if current_has_rot: + curr_grad_first = rot_event.rot_matrix @ curr_grad_first + + # Compare current block with previous + if any( + abs(curr_grad_first - prev_grad_last) + > self.system.max_slew * self.system.grad_raster_time + ): + raise RuntimeError( + f"Error in block {block_index}: Two consecutive gradients need to have the same amplitude at the connection point." + ) + + # 2) Comparison with next block + if next_block_index is not None: + + # Look up the last gradient value in the previous block + next_grad_first = np.zeros(3, dtype=np.float32) + for grad_to_check in check_g.values(): + next_id = self.block_events[next_block_index][grad_to_check.idx] + if next_id != 0: + next_lib = self.grad_library.get(next_id) + next_type = next_lib["type"] + + # Trapezoids start with zeros, + # so we cannot have the same amplitude + # as the final value of current block + if next_type == "g": + next_grad_first[grad_to_check.idx - 2] = next_lib["data"][4] + + # Get next block rotation matrix + ext_id = self.block_events[next_block_index][-1] + while ext_id and not next_has_rot: + try: + ext = self.extensions_library.data.get(ext_id) + if ext[0] == rot_type_id: + next_has_rot = True + next_rotmat = np.asarray( + self.rotation_library.data.get(ext[1]) + ).reshape((3, 3)) + else: + ext_id = ext[-1] + except KeyError: + ext_id = 0 + + # Rotate next gradient + if next_has_rot: + next_grad_first = next_rotmat @ next_grad_first + + # Look up the last gradient value in current block + curr_grad_last = np.zeros(3, dtype=np.float32) + for grad_to_check in check_g.values(): + curr_grad_last[grad_to_check.idx - 2] = grad_to_check.stop[1] + + # Rotate current gradient + if current_has_rot: + curr_grad_last = rot_event.rot_matrix @ curr_grad_last + + # Compare current block with next + if any( + abs(curr_grad_last - next_grad_first) + > self.system.max_slew * self.system.grad_raster_time + ): + raise RuntimeError( + f"Error in block {block_index}: Two consecutive gradients need to have the same amplitude at the connection point." + ) + else: + for grad_to_check in check_g.values(): + # Check beginning of gradient event + if ( + abs(grad_to_check.start[1]) + > self.system.max_slew * self.system.grad_raster_time + ): + raise RuntimeError( + "First gradient in the the first block has to start at 0." + ) + + # Check if gradients, which do not end at 0, are as long as the block itself. + for grad_to_check in check_g.values(): + if ( + abs(grad_to_check.stop[1]) + > self.system.max_slew * self.system.grad_raster_time + ): + if abs(grad_to_check.stop[0] - duration) > 1e-7: + raise RuntimeError( + "A gradient that doesn't end at zero needs to be aligned to the block boundary." + ) diff --git a/src/pypulseq/Sequence/read_seq.py b/src/pypulseq/Sequence/read_seq.py index ae313d1f..5de43708 100644 --- a/src/pypulseq/Sequence/read_seq.py +++ b/src/pypulseq/Sequence/read_seq.py @@ -50,6 +50,7 @@ def read(self, path: str, detect_rf_use: bool = False, remove_duplicates: bool = self.rf_library = EventLibrary() self.shape_library = EventLibrary() self.trigger_library = EventLibrary() + self.rotation_library = EventLibrary() # Raster times self.grad_raster_time = self.system.grad_raster_time @@ -166,6 +167,12 @@ def read(self, path: str, detect_rf_use: bool = False, remove_duplicates: bool = extension_id = int(section[18:]) self.set_extension_string_ID('TRIGGERS', extension_id) self.trigger_library = __read_events(input_file, (1, 1, 1e-6, 1e-6), event_library=self.trigger_library) + elif section[:19] == "extension ROTATIONS": + extension_id = int(section[19:]) + self.set_extension_string_ID("ROTATIONS", extension_id) + self.rotation_library = __read_events( + input_file, (1, 1, 1, 1, 1, 1, 1, 1, 1), event_library=self.rotation_library + ) elif section[:18] == 'extension LABELSET': extension_id = int(section[18:]) self.set_extension_string_ID('LABELSET', extension_id) diff --git a/src/pypulseq/Sequence/sequence.py b/src/pypulseq/Sequence/sequence.py index 0d1d249a..213ce268 100644 --- a/src/pypulseq/Sequence/sequence.py +++ b/src/pypulseq/Sequence/sequence.py @@ -34,6 +34,7 @@ from pypulseq.Sequence.write_seq import write as write_seq from pypulseq.supported_labels_rf_use import get_supported_labels from pypulseq.utils.cumsum import cumsum +from pypulseq.utils.rotate_ndarray import rotate_ndarray from pypulseq.utils.tracing import format_trace, trace, trace_enabled major, minor, revision = __version__.split('.')[:3] @@ -68,6 +69,7 @@ def __init__(self, system: Union[Opts, None] = None, use_block_cache: bool = Tru self.rf_library = EventLibrary() self.shape_library = EventLibrary(numpy_data=True) self.trigger_library = EventLibrary() + self.rotation_library = EventLibrary() # Library of rotation events # ========= # OTHER @@ -1025,6 +1027,7 @@ def plot( ) grad_channels = ['gx', 'gy', 'gz'] + waveform = {} for x in range(len(grad_channels)): # Gradients if getattr(block, grad_channels[x], None) is not None: grad = getattr(block, grad_channels[x]) @@ -1032,7 +1035,9 @@ def plot( # We extend the shape by adding the first and the last points in an effort of making the # display a bit less confusing... time = grad.delay + np.array([0, *grad.tt, grad.shape_dur]) - waveform = g_factor * np.array((grad.first, *grad.waveform, grad.last)) + waveform[grad_channels[x]] = g_factor * np.array( + (grad.first, *grad.waveform, grad.last) + ) else: time = np.array( cumsum( @@ -1043,8 +1048,18 @@ def plot( grad.fall_time, ) ) - waveform = g_factor * grad.amplitude * np.array([0, 0, 1, 1, 0]) - fig2_subplots[x].plot(t_factor * (t0 + time), waveform) + waveform[grad_channels[x]] = ( + g_factor * grad.amplitude * np.array([0, 0, 1, 1, 0]) + ) + + # rotate current block gradients + if hasattr(block, "rotation"): + waveform = rotate_ndarray(waveform, block.rotation.rot_matrix) + + for x in range(len(grad_channels)): # Gradients + if grad_channels[x] in waveform: + fig2_subplots[x].plot(t_factor * (t0 + time), waveform[grad_channels[x]]) + t0 += self.block_durations[block_counter] grad_plot_labels = ['x', 'y', 'z'] @@ -1105,6 +1120,9 @@ def register_label_event(self, event: SimpleNamespace) -> int: def register_rf_event(self, event: SimpleNamespace) -> Tuple[int, List[int]]: return block.register_rf_event(self, event) + + def register_rotation_event(self, event: SimpleNamespace) -> Tuple[int, List[int]]: + return block.rotation(self, event) def remove_duplicates(self, in_place: bool = False) -> Self: """ @@ -1430,6 +1448,7 @@ def waveforms(self, append_RF: bool = False, time_range: Union[List[float], None for block_counter in blocks: block = self.get_block(block_counter) + shape_tmp = {} for j in range(len(grad_channels)): grad = getattr(block, grad_channels[j]) @@ -1449,8 +1468,7 @@ def waveforms(self, append_RF: bool = False, time_range: Union[List[float], None # https://github.com/pulseq/pulseq/blob/master/matlab/%2Bmr/restoreAdditionalShapeSamples.m out_len[j] += len(grad.tt) + 2 - shape_pieces[j].append( - np.array( + shape_tmp[grad_channels[j]] = np.array( [ curr_dur + grad.delay @@ -1458,17 +1476,14 @@ def waveforms(self, append_RF: bool = False, time_range: Union[List[float], None np.concatenate(([grad.first], grad.waveform, [grad.last])), ] ) - ) else: # Extended trapezoid out_len[j] += len(grad.tt) - shape_pieces[j].append( - np.array( + shape_tmp[grad_channels[j]] = np.array( [ curr_dur + grad.delay + grad.tt, grad.waveform, ] ) - ) else: if abs(grad.flat_time) > eps: out_len[j] += 4 @@ -1483,7 +1498,7 @@ def waveforms(self, append_RF: bool = False, time_range: Union[List[float], None grad.amplitude * np.array([0, 1, 1, 0]), ) ) - shape_pieces[j].append(_temp) + shape_tmp[grad_channels[j]] = _temp else: if abs(grad.rise_time) > eps and abs(grad.fall_time) > eps: out_len[j] += 3 @@ -1493,7 +1508,7 @@ def waveforms(self, append_RF: bool = False, time_range: Union[List[float], None grad.amplitude * np.array([0, 1, 0]), ) ) - shape_pieces[j].append(_temp) + shape_tmp[grad_channels[j]] = _temp else: if abs(grad.amplitude) > eps: print( @@ -1501,7 +1516,18 @@ def waveforms(self, append_RF: bool = False, time_range: Union[List[float], None block_counter ) ) + + # rotate current block gradients + if hasattr(block, "rotation"): + time_tmp = [shape_tmp[k][0] for k in shape_tmp.keys()][0] + grad_tmp = {k : shape_tmp[k][1] for k in shape_tmp.keys()} + grad_tmp = rotate_ndarray(grad_tmp, block.rotation.rot_matrix) + shape_tmp = {k : np.vstack((time_tmp, grad_tmp[k])) for k in grad_tmp.keys()} + for j in range(len(grad_channels)): + if grad_channels[j] in shape_tmp: + shape_pieces[j].append(shape_tmp[grad_channels[j]]) + if block.rf is not None: # RF rf = block.rf if append_RF: @@ -1671,13 +1697,15 @@ def waveforms_export(self, time_range=(0, np.inf)) -> dict: rf_signal_centers = np.concatenate((rf_signal_centers, [rf[ic]])) grad_channels = ['gx', 'gy', 'gz'] + g_t = {} + g = {} for x in range(len(grad_channels)): # Check each gradient channel: x, y, and z if getattr(block, grad_channels[x]) is not None: # If this channel is on in current block grad = getattr(block, grad_channels[x]) if grad.type == 'grad': # Arbitrary gradient option # In place unpacking of grad.t with the starred expression - g_t = ( + g_t[grad_channels[x]] = ( t0 + grad.delay + [ @@ -1686,26 +1714,34 @@ def waveforms_export(self, time_range=(0, np.inf)) -> dict: grad.t[-1] + grad.t[1] - grad.t[0], ] ) - g = 1e-3 * np.array((grad.first, *grad.waveform, grad.last)) + g[grad_channels[x]] = 1e-3 * np.array((grad.first, *grad.waveform, grad.last)) else: # Trapezoid gradient option - g_t = cumsum( + g_t[grad_channels[x]] = cumsum( t0, grad.delay, grad.rise_time, grad.flat_time, grad.fall_time, ) - g = 1e-3 * grad.amplitude * np.array([0, 0, 1, 1, 0]) - - if grad.channel == 'x': - gx_t_all = np.concatenate((gx_t_all, g_t)) - gx_all = np.concatenate((gx_all, g)) - elif grad.channel == 'y': - gy_t_all = np.concatenate((gy_t_all, g_t)) - gy_all = np.concatenate((gy_all, g)) - elif grad.channel == 'z': - gz_t_all = np.concatenate((gz_t_all, g_t)) - gz_all = np.concatenate((gz_all, g)) + g[grad_channels[x]] = 1e-3 * grad.amplitude * np.array([0, 0, 1, 1, 0]) + + # rotate current block gradients + if hasattr(block, "rotation"): + t_tmp = list(g_t.values())[0] + g = rotate_ndarray(g, block.rotation.rot_matrix) + g_t = {k : t_tmp for k in g.keys()} + + for ch in grad_channels: + if ch in g: + if ch == "gx": + gx_t_all = np.concatenate((gx_t_all, g_t[ch])) + gx_all = np.concatenate((gx_all, g[ch])) + elif ch == "gy": + gy_t_all = np.concatenate((gy_t_all, g_t[ch])) + gy_all = np.concatenate((gy_all, g[ch])) + elif ch == "gz": + gz_t_all = np.concatenate((gz_t_all, g_t[ch])) + gz_all = np.concatenate((gz_all, g[ch])) t0 += self.block_durations[block_counter] # "Current time" gets updated to end of block just examined diff --git a/src/pypulseq/Sequence/write_seq.py b/src/pypulseq/Sequence/write_seq.py index ce9fc911..95931725 100644 --- a/src/pypulseq/Sequence/write_seq.py +++ b/src/pypulseq/Sequence/write_seq.py @@ -186,6 +186,22 @@ def write(self, file_name: Union[str, Path], create_signature, remove_duplicates s = id_format_str.format(k, *np.round(self.trigger_library.data[k] * np.array([1, 1, 1e6, 1e6]))) output_file.write(s) output_file.write('\n') + + if len(self.rotation_library.data) != 0: + output_file.write( + "# Extension specification for rotation events:\n" + ) + output_file.write("# id RotMat[0][0] RotMat[0][1] RotMat[0][2] RotMat[1][0] RotMat[1][1] RotMat[1][2] RotMat[2][0] RotMat[2][1] RotMat[2][2]\n") + output_file.write( + f'extension ROTATIONS {self.get_extension_type_ID("ROTATIONS")}\n' + ) + id_format_str = "{:.0f} {:12g} {:12g} {:12g} {:12g} {:12g} {:12g} {:12g} {:12g} {:12g}\n" # Refer lines 20-21 + for k in self.rotation_library.data: + s = id_format_str.format( + k, *self.rotation_library.data[k] + ) + output_file.write(s) + output_file.write("\n") if len(self.label_set_library.data) != 0: labels = get_supported_labels() diff --git a/src/pypulseq/__init__.py b/src/pypulseq/__init__.py index 58587d08..841a009d 100644 --- a/src/pypulseq/__init__.py +++ b/src/pypulseq/__init__.py @@ -55,6 +55,7 @@ def round_half_up(n, decimals=0): from pypulseq.make_trapezoid import make_trapezoid from pypulseq.sigpy_pulse_opts import SigpyPulseOpts from pypulseq.make_trigger import make_trigger +from pypulseq.make_rotation import make_rotation from pypulseq.opts import Opts from pypulseq.points_to_waveform import points_to_waveform from pypulseq.rotate import rotate diff --git a/src/pypulseq/make_rotation.py b/src/pypulseq/make_rotation.py new file mode 100644 index 00000000..c258f5c8 --- /dev/null +++ b/src/pypulseq/make_rotation.py @@ -0,0 +1,27 @@ +from types import SimpleNamespace + +import numpy as np + +def make_rotation(rot_matrix: np.ndarray) -> SimpleNamespace: + """ + Create a rotation event to instruct the interpreter to rotate + the gx, gy and gz waveforms according to the given rotation matrix. + + See also `pypulseq.Sequence.sequence.Sequence.add_block()`. + + Parameters + ---------- + rot_matrix : np.ndarray + Rotation matrix of shape (3, 3). + + Returns + ------- + rotation : SimpleNamespace + Rotation event. + + """ + rotation = SimpleNamespace() + rotation.type = "rotation" + rotation.rot_matrix = rot_matrix + + return rotation \ No newline at end of file diff --git a/src/pypulseq/utils/rotate_ndarray.py b/src/pypulseq/utils/rotate_ndarray.py new file mode 100644 index 00000000..6970f643 --- /dev/null +++ b/src/pypulseq/utils/rotate_ndarray.py @@ -0,0 +1,44 @@ +# Rotate gradient waveforms according to provided rotation matrix + +import numpy as np +import copy + +def rotate_ndarray(grad, rot_matrix): + grad_channels = ["gx", "gy", "gz"] + grad = copy.deepcopy(grad) + + # get length of gradient waveforms + wave_length = [] + for ch in grad_channels: + if ch in grad: + wave_length.append(len(grad[ch])) + + assert (np.unique(wave_length) != 0).sum() == 1, "All the waveform along different channels must have the same length" + + wave_length = np.unique(wave_length) + wave_length = wave_length[wave_length != 0].item() + + # create zero-filled waveforms for empty gradient channels + for ch in grad_channels: + if ch in grad: + grad[ch] = grad[ch].squeeze() + else: + grad[ch] = np.zeros(wave_length) + + # stack matrix + grad_mat = np.stack((grad["gx"], grad["gy"], grad["gz"]), axis=0) # (3, wave_length) + + # apply rotation + grad_mat = rot_matrix @ grad_mat + + # put back in dictionary + for j in range(3): + ch = grad_channels[j] + grad[ch] = grad_mat[j] + + # remove all zero waveforms + for ch in grad_channels: + if np.allclose(grad[ch], 0.0): + grad.pop(ch) + + return grad \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 96cc9fa5..a392433f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ from pathlib import Path +from types import SimpleNamespace import numpy as np import pytest +from _pytest.python_api import ApproxBase # this is currently not used, but might be useful in the future @@ -38,3 +40,68 @@ def compare(file1, file2): assert line1 == line2 return compare + + +class Approx(ApproxBase): + """ + Extension of pytest.approx that also handles approximate equality + recursively within dicts, lists, tuples, and SimpleNamespace + """ + + def __repr__(self): + return str(self.expected) + + def __eq__(self, actual): + # if type(actual) != type(self.expected): + # return False + if isinstance(self.expected, dict): + if set(self.expected.keys()) != set(actual.keys()): + return False + + for k in self.expected: + if actual[k] != Approx(self.expected[k], rel=self.rel, abs=self.abs, nan_ok=self.nan_ok): + return False + return True + elif isinstance(self.expected, (list, tuple)): + if len(self.expected) != len(actual): + return False + + for e, a in zip(self.expected, actual): + if a != Approx(e, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok): + return False + return True + elif isinstance(self.expected, SimpleNamespace): + return actual.__dict__ == Approx(self.expected.__dict__, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) + else: + return actual == pytest.approx(self.expected, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) + + def _repr_compare(self, actual): + # if type(actual) != type(self.expected): + # return [f'Actual and expected types do not match: {type(actual)} != {type(self.expected)}'] + if isinstance(self.expected, dict): + if set(self.expected.keys()) != set(actual.keys()): + return [f'Actual and expected keys do not match: {set(actual.keys())} != {set(self.expected.keys())}'] + + r = [] + for k in self.expected: + approx_obj = Approx(self.expected[k], rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) + if actual[k] != approx_obj: + r += [f'{k} does not match:'] + r += [f' {x}' for x in approx_obj._repr_compare(actual[k])] + return r + elif isinstance(self.expected, (list, tuple)): + if len(self.expected) != len(actual): + return [f'Actual and expected lengths do not match: {len(actual)} != {len(self.expected)}'] + r = [] + for i, (e, a) in enumerate(zip(self.expected, actual)): + approx_obj = Approx(e, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) + if a != approx_obj: + r += [f'Index {i} does not match:'] + r += [f' {x}' for x in approx_obj._repr_compare(a)] + return r + elif isinstance(self.expected, SimpleNamespace): + return Approx(self.expected.__dict__, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)._repr_compare( + actual.__dict__ + ) + else: + return pytest.approx(self.expected, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)._repr_compare(actual) diff --git a/tests/expected_output/seq_make_radial.seq b/tests/expected_output/seq_make_radial.seq new file mode 100644 index 00000000..f0f3a89c --- /dev/null +++ b/tests/expected_output/seq_make_radial.seq @@ -0,0 +1,88 @@ +# Pulseq sequence file +# Created by PyPulseq + +[VERSION] +major 1 +minor 4 +revision 2 + +[DEFINITIONS] +AdcRasterTime 1e-07 +BlockDurationRaster 1e-05 +GradientRasterTime 1e-05 +RadiofrequencyRasterTime 1e-06 +TotalDuration 0.01098 + +# Format of blocks: +# NUM DUR RF GX GY GZ ADC EXT +[BLOCKS] + 1 100 1 0 0 0 0 0 + 2 83 0 1 0 0 0 1 + 3 100 1 0 0 0 0 0 + 4 83 0 1 0 0 0 2 + 5 100 1 0 0 0 0 0 + 6 83 0 1 0 0 0 3 + 7 100 1 0 0 0 0 0 + 8 83 0 1 0 0 0 4 + 9 100 1 0 0 0 0 0 +10 83 0 1 0 0 0 5 +11 100 1 0 0 0 0 0 +12 83 0 1 0 0 0 1 + +# Format of RF events: +# id amplitude mag_id phase_id time_shape_id delay freq phase +# .. Hz .... .... .... us Hz rad +[RF] +1 250 1 2 3 0 0 0 + +# Format of trapezoid gradients: +# id amplitude rise flat fall delay +# .. Hz/m us us us us +[TRAP] + 1 1.69492e+06 240 350 240 0 + +# Format of extension lists: +# id type ref next_id +# next_id of 0 terminates the list +# Extension list is followed by extension specifications +[EXTENSIONS] +1 1 1 0 +2 1 2 0 +3 1 3 0 +4 1 4 0 +5 1 5 0 + +# Extension specification for rotation events: +# id RotMat[0][0] RotMat[0][1] RotMat[0][2] RotMat[1][0] RotMat[1][1] RotMat[1][2] RotMat[2][0] RotMat[2][1] RotMat[2][2] +extension ROTATIONS 1 +1 1 -0 0 0 1 0 0 0 1 +2 0.866025 -0.5 0 0.5 0.866025 0 0 0 1 +3 0.707107 -0.707107 0 0.707107 0.707107 0 0 0 1 +4 0.5 -0.866025 0 0.866025 0.5 0 0 0 1 +5 6.12323e-17 -1 0 1 6.12323e-17 0 0 0 1 + +# Sequence Shapes +[SHAPES] + +shape_id 1 +num_samples 2 +1 +1 + +shape_id 2 +num_samples 2 +0 +0 + +shape_id 3 +num_samples 2 +0 +1000 + + +[SIGNATURE] +# This is the hash of the Pulseq file, calculated right before the [SIGNATURE] section was added +# It can be reproduced/verified with md5sum if the file trimmed to the position right above [SIGNATURE] +# The new line character preceding [SIGNATURE] BELONGS to the signature (and needs to be stripped away for recalculating/verification) +Type md5 +Hash f876672d1be9882ad32aa6ff5441d8ba diff --git a/tests/test_block.py b/tests/test_block.py index 98865feb..b321160f 100644 --- a/tests/test_block.py +++ b/tests/test_block.py @@ -1,6 +1,8 @@ import pypulseq as pp import pytest +from scipy.spatial.transform import Rotation as R + # Gradient definitions used in tests gx_trap = pp.make_trapezoid('x', area=1000, duration=1e-3) gx_extended = pp.make_extended_trapezoid('x', amplitudes=[0, 100000, 0], times=[0, 1e-4, 2e-4]) @@ -11,6 +13,10 @@ gx_allhigh = pp.make_extended_trapezoid('x', amplitudes=[100000, 100000, 100000], times=[0, 1e-4, 2e-4]) delay = pp.make_delay(1e-3) +# Rotations +rotmat = pp.make_rotation(R.from_euler('z', 90.0, degrees=True).as_matrix()) +eye = pp.make_rotation(R.from_euler('z', 0.0, degrees=True).as_matrix()) + ## Test gradient continuity checks in add_block @@ -171,6 +177,245 @@ def test_gradient_continuity_setblock7(): seq.set_block(7, gx_startshigh) assert list(seq.block_events.keys()) == [10, 5, 7] - + # TODO: Add other block functionality tests + +# Rotations +def test_gradient_continuity_rot1(): + # Trap followed by extended gradient: No error + seq = pp.Sequence() + seq.add_block(gx_trap, eye) + seq.add_block(gx_extended, eye) + seq.add_block(gx_trap, eye) + + +def test_gradient_continuity_rot2(): + # Trap followed by non-zero gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_trap, eye) + seq.add_block(gx_startshigh, eye) # raises + + +def test_gradient_continuity_rot3(): + # Gradient starts at non-zero in first block + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_startshigh, eye) # raises + + +def test_gradient_continuity_rot4(): + # Gradient starts and ends at non-zero + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(delay) + seq.add_block(gx_allhigh, eye) + + +def test_gradient_continuity_rot5(): + # Gradient starts at zero and has a delay: No error + seq = pp.Sequence() + seq.add_block(gx_extended_delay, eye) + + +def test_gradient_continuity_rot6(): + # Gradient starts at non-zero in other blocks + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(delay) + seq.add_block(gx_startshigh, eye) # raises + + +def test_gradient_continuity_rot7(): + # Gradient ends high and is followed by empty block + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(delay) # raises + + +def test_gradient_continuity_rot8(): + # Gradient ends high and is followed by trapezoid + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_trap) # raises + + +def test_gradient_continuity_rot9(): + # Gradient ends high and is followed by connecting gradient: No error + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_startshigh, eye) + + +def test_gradient_continuity_rot10(): + # Gradient in last block ends high: No error, this is caught by seq.write() + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + + +def test_gradient_continuity_rot11(): + # Non-zero, but non-connecting gradients + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_startshigh2, eye) + +def test_gradient_continuity_rot12(): + # Non-zero, both grad are rotated by the same angle: No error + seq = pp.Sequence() + seq.add_block(gx_endshigh, rotmat) + seq.add_block(gx_startshigh, rotmat) + +def test_gradient_continuity_rot13(): + # Non-zero, new grad has different rotation from previous + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_startshigh, rotmat) + +def test_gradient_continuity_rot14(): + # Non-zero, new grad has different rotation from previous + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, rotmat) + seq.add_block(gx_startshigh, eye) + +## Test gradient continuity checks in set_block + + +def test_gradient_continuity_setblock_rot1(): + # Use set_block to insert gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(delay) + seq.add_block(delay) + seq.add_block(delay) + + seq.set_block(1, gx_startshigh, eye) + + +def test_gradient_continuity_setblock_rot2(): + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(delay) + seq.add_block(delay) + seq.add_block(delay) + + seq.set_block(2, gx_startshigh, eye) + + +def test_gradient_continuity_setblock_rot3(): + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(delay) + seq.add_block(delay) + seq.add_block(delay) + + seq.set_block(3, gx_startshigh, eye) + + +def test_gradient_continuity_setblock_rot4(): + # Overwrite valid gradient with empty block + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_allhigh, eye) + seq.add_block(gx_startshigh, eye) + + seq.set_block(2, delay) + + +def test_gradient_continuity_setblock_rot5(): + # Overwrite valid gradient with gradient that is valid on one side + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_allhigh, eye) + seq.add_block(gx_startshigh, eye) + + seq.set_block(2, gx_startshigh, eye) + + +def test_gradient_continuity_setblock_rot6(): + # Add new gradient with non-contiguous block numbering + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_allhigh, eye) + seq.add_block(gx_startshigh, eye) + + seq.set_block(6, gx_startshigh, eye) + + +def test_gradient_continuity_setblock_rot7(): + # Valid sequence with non-contiguous block numbering + seq = pp.Sequence() + seq.set_block(10, gx_endshigh, eye) + seq.set_block(5, gx_allhigh, eye) + seq.set_block(7, gx_startshigh, eye) + + assert list(seq.block_events.keys()) == [10, 5, 7] + + +def test_gradient_continuity_setblock_rot8(): + # Overwrite valid gradient with rotated gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_allhigh, eye) + seq.add_block(gx_startshigh, eye) + + seq.set_block(1, gx_endshigh, rotmat) + +def test_gradient_continuity_setblock_rot9(): + # Overwrite valid gradient with rotated gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_allhigh, eye) + seq.add_block(gx_startshigh, eye) + + seq.set_block(2, gx_allhigh, rotmat) + +def test_gradient_continuity_setblock_rot10(): + # Overwrite valid gradient with rotated gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh, eye) + seq.add_block(gx_allhigh, eye) + seq.add_block(gx_startshigh, eye) + + seq.set_block(3, gx_startshigh, rotmat) + +def test_gradient_continuity_setblock_rot11(): + # Overwrite valid non-rotated gradient with rotated gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh) + seq.add_block(gx_allhigh) + seq.add_block(gx_startshigh) + + seq.set_block(1, gx_endshigh, rotmat) + +def test_gradient_continuity_setblock_rot12(): + # Overwrite valid gradient with rotated gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh) + seq.add_block(gx_allhigh) + seq.add_block(gx_startshigh) + + seq.set_block(2, gx_allhigh, rotmat) + +def test_gradient_continuity_setblock_rot13(): + # Overwrite valid gradient with rotated gradient + with pytest.raises(RuntimeError): + seq = pp.Sequence() + seq.add_block(gx_endshigh) + seq.add_block(gx_allhigh) + seq.add_block(gx_startshigh) + + seq.set_block(3, gx_startshigh, rotmat) diff --git a/tests/test_rotation_extension.py b/tests/test_rotation_extension.py new file mode 100644 index 00000000..f29845c9 --- /dev/null +++ b/tests/test_rotation_extension.py @@ -0,0 +1,384 @@ +import os +import math +from pathlib import Path + +import numpy as np +import numpy.testing as npt + +import pytest +from unittest.mock import patch + +from pypulseq import Sequence +import pypulseq as pp + +from conftest import Approx + + +expected_output_path = Path(__file__).parent / "expected_output" + + +# Rotation Matrix creation routine +def rotation_matrix(angle): + # angle in degrees + theta = np.deg2rad(angle) + + # R[0] = (R[0][0], R[0][1], R[0][2]) + R0 = np.asarray((np.cos(theta), -np.sin(theta), 0.0)) + + # R[1] = (R[1][0], R[1][1], R[1][2]) + R1 = np.asarray((np.sin(theta), np.cos(theta), 0.0)) + + # R[2] = (R[2][0], R[2][1], R[2][2]) + R2 = np.asarray((0.0, 0.0, 1.0)) + + return np.stack((R0, R1, R2), axis=0) + + +# Basic sequence with 0, 30, 45, 60, 90 deg +def seq_make_radial(): + # init sequence + seq = Sequence() + + # init rf pulse + rf = pp.make_block_pulse(math.pi / 2, duration=1e-3) + + # init readout + gread = pp.make_trapezoid("x", area=1000) + + # init angle list + theta = np.asarray((0.0, 30.0, 45.0, 60.0, 90.0)) + rot = [rotation_matrix(th) for th in theta] + + # build sequence + for n in range(len(theta)): + seq.add_block(rf) + + # add readouts + seq.add_block(gread, pp.make_rotation(rot[n])) + + # add 0 again + seq.add_block(rf) + + # add readouts + seq.add_block(gread, pp.make_rotation(rot[0])) + + return seq + + +def seq_make_radial_norotext(): + # init sequence + seq = Sequence() + + # init rf pulse + rf = pp.make_block_pulse(math.pi / 2, duration=1e-3) + + # init readout + gread = pp.make_trapezoid("x", area=1000) + + # init angle list + theta = np.asarray((0.0, 30.0, 45.0, 60.0, 90.0)) + theta = np.deg2rad(theta) + + # build sequence + for n in range(len(theta)): + seq.add_block(rf) + + # add readouts + seq.add_block(*pp.rotate(gread, angle=theta[n], axis="z")) + + # add 0 again + seq.add_block(rf) + + # add readouts + seq.add_block(*pp.rotate(gread, angle=theta[0], axis="z")) + + return seq + + +# Test utils.rotate actually performs rotation (for plotting) +def test_rotate_utils(): + wavein = {"gx": np.linspace(-10, 10, 5)} + + # rotate by 0 deg + R = rotation_matrix(0.0) + waveout = pp.utils.rotate_ndarray.rotate_ndarray(wavein, R) + npt.assert_allclose(wavein["gx"], waveout["gx"]) + assert "gy" not in waveout + assert "gz" not in waveout + + # rotate by 45 deg + R = rotation_matrix(45.0) + waveout = pp.utils.rotate_ndarray.rotate_ndarray(wavein, R) + npt.assert_allclose(wavein["gx"] * (2**0.5 / 2), waveout["gx"]) + npt.assert_allclose(wavein["gx"] * (2**0.5 / 2), waveout["gy"]) + assert "gz" not in waveout + + # rotate by 90 deg + R = rotation_matrix(90.0) + waveout = pp.utils.rotate_ndarray.rotate_ndarray(wavein, R) + assert "gx" not in waveout + npt.assert_allclose(wavein["gx"], waveout["gy"]) + assert "gz" not in waveout + + +# Test sequence +def test_sequence(): + seq = seq_make_radial() + blocks = np.stack(list(seq.block_events.values()), axis=1) + + # check rf + rf = ( + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + ) # alternate between no rf and gaussian pulse + npt.assert_allclose(blocks[1], rf) + + # check gradients + g = ( + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + ) # alternate between no readout and radial spoke + npt.assert_allclose(blocks[2], g) + + # check extension (we have 0=no rot, 1=0.0deg, 2=30.0deg, 3=45.0deg, 4=60.0deg, 5=90.0deg) + ext = ( + 0, + 1, + 0, + 2, + 0, + 3, + 0, + 4, + 0, + 5, + 0, + 1, + ) # last event re-use ROTATION[1] = 90deg rotation about z + npt.assert_allclose(blocks[6], ext) + + # verify that the only extension is ROTATIONS (id=1, string=[ROTATIONS]) + assert len(seq.extension_numeric_idx) == 1 + assert seq.extension_numeric_idx[0] == 1 + assert len(seq.extension_string_idx) == 1 + assert seq.extension_string_idx[0] == "ROTATIONS" + + assert len(seq.extensions_library.data) == 5 + npt.assert_allclose(seq.extensions_library.data[1], (1, 1, 0)) + npt.assert_allclose(seq.extensions_library.data[2], (1, 2, 0)) + npt.assert_allclose(seq.extensions_library.data[3], (1, 3, 0)) + npt.assert_allclose(seq.extensions_library.data[4], (1, 4, 0)) + npt.assert_allclose(seq.extensions_library.data[5], (1, 5, 0)) + + # verify that rotation_events 1-5 contains the correct matrix + for n in (1, 3, 5, 7, 9, 11): + b = seq.get_block(n) + assert b.rf is not None + assert b.gx is None + assert b.gy is None + assert b.gz is None + assert hasattr(b, "rotation") is False + + for n in (2, 4, 6, 8, 10, 12): + b = seq.get_block(n) + assert b.rf is None + assert b.gx is not None + assert b.gy is None + assert b.gz is None + assert hasattr(b, "rotation") is True + + npt.assert_allclose(seq.get_block(2).rotation.rot_matrix, rotation_matrix(0.0)) + npt.assert_allclose(seq.get_block(4).rotation.rot_matrix, rotation_matrix(30.0)) + npt.assert_allclose(seq.get_block(6).rotation.rot_matrix, rotation_matrix(45.0)) + npt.assert_allclose(seq.get_block(8).rotation.rot_matrix, rotation_matrix(60.0)) + npt.assert_allclose(seq.get_block(10).rotation.rot_matrix, rotation_matrix(90.0)) + npt.assert_allclose(seq.get_block(12).rotation.rot_matrix, rotation_matrix(0.0)) + + +# Test again explicit gradient rotation +def test_vs_rotate(): + seq = seq_make_radial() + seq2 = seq_make_radial_norotext() + + # test waveforms() + waveforms = seq.waveforms() + waveforms2 = seq2.waveforms() + + assert len(waveforms) == len(waveforms2) + for n in range(len(waveforms)): + npt.assert_allclose(waveforms[n], waveforms2[n]) + + # test waveforms_and_times() + waveforms_and_times = seq.waveforms_and_times() + waveforms_and_times2 = seq2.waveforms_and_times() + + assert len(waveforms_and_times) == len(waveforms_and_times2) + for n in range(len(waveforms_and_times)): + assert len(waveforms_and_times[n]) == len(waveforms_and_times[n]) + for m in range(len(waveforms_and_times[n])): + npt.assert_allclose(waveforms_and_times[n][m], waveforms_and_times2[n][m]) + + # test waveform_export + waveforms_export = seq.waveforms_export() + waveforms_export2 = seq2.waveforms_export() + + for k in waveforms_export.keys(): + if type(waveforms_export[k]) == np.ndarray: + npt.assert_allclose(waveforms_export[k], waveforms_export2[k]) + else: + assert waveforms_export[k] == waveforms_export2[k] + + # test k-space + kspace = seq.calculate_kspace() + kspace2 = seq2.calculate_kspace() + + for n in range(len(kspace)): + npt.assert_allclose(kspace[n], kspace2[n]) + + +# This "test" rewrites the expected .seq output files when SAVE_EXPECTED is +# set in the environment variables. +# E.g. in a unix-based system, run: SAVE_EXPECTED=1 pytest test_sequence.py +@pytest.mark.skipif( + not os.environ.get("SAVE_EXPECTED"), + reason="Only save sequence files when requested", +) +def test_sequence_save_expected(): + + # Generate sequence and write to file + seq = seq_make_radial() + seq.write(expected_output_path / "seq_make_radial.seq") + + +# Test whether a sequence can be plotted. +@patch("matplotlib.pyplot.show") +def test_plot(mock_show): + seq = seq_make_radial() + + seq.plot() + seq.plot(show_blocks=True) + + +# Test whether the sequence is the approximately the same after writing a .seq +# file and reading it back in. +def test_sequence_writeread(tmp_path, compare_seq_file): + output_filename = tmp_path / "seq_make_radial.seq" + + # Generate sequence + seq = seq_make_radial() + + # Write sequence to file + seq.write(output_filename) + + # Check if written sequence file matches expected sequence file + compare_seq_file(output_filename, expected_output_path / "seq_make_radial.seq") + + # Read written sequence file back in + seq2 = pp.Sequence(system=seq.system) + seq2.read(output_filename) + + # Clean up written sequence file + output_filename.unlink() + + # Test for approximate equality of all blocks + assert set(seq2.block_events.keys()) == set(seq.block_events.keys()) + for block_counter in seq.block_events: + assert seq2.get_block(block_counter) == Approx( + seq.get_block(block_counter), abs=1e-6, rel=1e-5 + ), f"Block {block_counter} does not match" + + # Test for approximate equality of all gradient waveforms + for a, b in zip(seq2.get_gradients(), seq.get_gradients()): + if a == None and b == None: + continue + if a == None or b == None: + assert False + + assert a.x == Approx(b.x, abs=1e-3, rel=1e-3) + assert a.c == Approx(b.c, abs=1e-3, rel=1e-3) + + # Test for approximate equality of kspace calculation + assert seq2.calculate_kspace() == Approx( + seq.calculate_kspace(), abs=1e-2, nan_ok=True + ) + + # Test whether labels are the same + labels_seq = seq.evaluate_labels(evolution="blocks") + labels_seq2 = seq2.evaluate_labels(evolution="blocks") + + assert ( + labels_seq.keys() == labels_seq2.keys() + ), "Sequences do not contain the same set of labels" + + for label in labels_seq: + assert ( + labels_seq[label] == labels_seq2[label] + ).all(), f"Label {label} does not match" + + +# Test whether the sequence is approximately the same after recreating it by +# getting all blocks with get_block and inserting them into a new sequence +# with add_block. +def test_sequence_recreate(tmp_path): + # Generate sequence + seq = seq_make_radial() + + # Insert blocks from sequence into a new sequence + seq2 = pp.Sequence(system=seq.system) + for b in seq.block_events: + seq2.add_block(seq.get_block(b)) + + # Test for approximate equality of all blocks + assert set(seq2.block_events.keys()) == set(seq.block_events.keys()) + for block_counter in seq.block_events: + assert seq2.get_block(block_counter) == Approx( + seq.get_block(block_counter), abs=1e-6, rel=1e-5 + ), f"Block {block_counter} does not match" + + # Test for approximate equality of all gradient waveforms + for a, b in zip(seq2.get_gradients(), seq.get_gradients()): + if a == None and b == None: + continue + if a == None or b == None: + assert False + + assert a.x == Approx(b.x, abs=1e-4, rel=1e-4) + assert a.c == Approx(b.c, abs=1e-4, rel=1e-4) + + # Test for approximate equality of kspace calculation + assert seq2.calculate_kspace() == Approx( + seq.calculate_kspace(), abs=1e-6, nan_ok=True + ) + + # Test whether labels are the same + labels_seq = seq.evaluate_labels(evolution="blocks") + labels_seq2 = seq2.evaluate_labels(evolution="blocks") + + assert ( + labels_seq.keys() == labels_seq2.keys() + ), "Sequences do not contain the same set of labels" + + for label in labels_seq: + assert ( + labels_seq[label] == labels_seq2[label] + ).all(), f"Label {label} does not match" diff --git a/tests/test_sequence.py b/tests/test_sequence.py index f75b7414..2fe763df 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,81 +2,17 @@ import math import os from pathlib import Path -from types import SimpleNamespace from unittest.mock import patch import matplotlib.pyplot as plt import pypulseq as pp import pytest -from _pytest.python_api import ApproxBase -from pypulseq import Sequence -expected_output_path = Path(__file__).parent / 'expected_output' +from pypulseq import Sequence +from conftest import Approx -class Approx(ApproxBase): - """ - Extension of pytest.approx that also handles approximate equality - recursively within dicts, lists, tuples, and SimpleNamespace - """ - - def __repr__(self): - return str(self.expected) - - def __eq__(self, actual): - # if type(actual) != type(self.expected): - # return False - if isinstance(self.expected, dict): - if set(self.expected.keys()) != set(actual.keys()): - return False - - for k in self.expected: - if actual[k] != Approx(self.expected[k], rel=self.rel, abs=self.abs, nan_ok=self.nan_ok): - return False - return True - elif isinstance(self.expected, (list, tuple)): - if len(self.expected) != len(actual): - return False - - for e, a in zip(self.expected, actual): - if a != Approx(e, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok): - return False - return True - elif isinstance(self.expected, SimpleNamespace): - return actual.__dict__ == Approx(self.expected.__dict__, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) - else: - return actual == pytest.approx(self.expected, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) - - def _repr_compare(self, actual): - # if type(actual) != type(self.expected): - # return [f'Actual and expected types do not match: {type(actual)} != {type(self.expected)}'] - if isinstance(self.expected, dict): - if set(self.expected.keys()) != set(actual.keys()): - return [f'Actual and expected keys do not match: {set(actual.keys())} != {set(self.expected.keys())}'] - - r = [] - for k in self.expected: - approx_obj = Approx(self.expected[k], rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) - if actual[k] != approx_obj: - r += [f'{k} does not match:'] - r += [f' {x}' for x in approx_obj._repr_compare(actual[k])] - return r - elif isinstance(self.expected, (list, tuple)): - if len(self.expected) != len(actual): - return [f'Actual and expected lengths do not match: {len(actual)} != {len(self.expected)}'] - r = [] - for i, (e, a) in enumerate(zip(self.expected, actual)): - approx_obj = Approx(e, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) - if a != approx_obj: - r += [f'Index {i} does not match:'] - r += [f' {x}' for x in approx_obj._repr_compare(a)] - return r - elif isinstance(self.expected, SimpleNamespace): - return Approx(self.expected.__dict__, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)._repr_compare( - actual.__dict__ - ) - else: - return pytest.approx(self.expected, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)._repr_compare(actual) +expected_output_path = Path(__file__).parent / 'expected_output' # Dummy sequence which contains only gaussian pulses with different parameters