Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 44 additions & 73 deletions src/pypulseq/Sequence/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pypulseq.compress_shape import compress_shape
from pypulseq.decompress_shape import decompress_shape
from pypulseq.event_lib import EventLibrary
from pypulseq.Sequence.grad_check import grad_check
from pypulseq.supported_labels_rf_use import get_supported_labels
from pypulseq.utils.tracing import trace_enabled

Expand Down Expand Up @@ -52,6 +53,7 @@ def set_block(self, block_index: int, *args: SimpleNamespace) -> None:
} # Key-value mapping of index and pairs of gradients/times
extensions = []

rot_event = None
for event in events:
if not isinstance(event, float): # If event is not a block duration
if event.type == 'rf':
Expand Down Expand Up @@ -139,6 +141,18 @@ def set_block(self, block_index: int, *args: SimpleNamespace) -> None:
ext = {'type': self.get_extension_type_ID('TRIGGERS'), 'ref': event_id}
extensions.append(ext)
duration = max(duration, event.delay + event.duration)
elif event.type == 'rotation':
rot_event = event
if hasattr(event, 'id'):
event_id = event.id
else:
event_id = register_rotation_event(self, event)

ext = {
'type': self.get_extension_type_ID('ROTATIONS'),
'ref': event_id
}
extensions.append(ext)
elif event.type in ['labelset', 'labelinc']:
if hasattr(event, 'id'):
label_id = event.id
Expand Down Expand Up @@ -190,79 +204,8 @@ def set_block(self, block_index: int, *args: SimpleNamespace) -> None:
# =========
# PERFORM GRADIENT CHECKS
# =========
for grad_to_check in check_g.values():
if abs(grad_to_check.start[1]) > self.system.max_slew * self.system.grad_raster_time: # noqa: SIM102
if grad_to_check.start[0] > eps:
raise RuntimeError('No delay allowed for gradients which start with a non-zero amplitude')

# Check whether any blocks exist in the sequence
if self.next_free_block_ID > 1:
# Look up the previous block (and the next block in case of a set_block call)
if block_index == self.next_free_block_ID:
# New block inserted
prev_block_index = next(reversed(self.block_events))
next_block_index = None
else:
blocks = list(self.block_events)
try:
# Existing block overwritten
idx = blocks.index(block_index)
prev_block_index = blocks[idx - 1] if idx > 0 else None
next_block_index = blocks[idx + 1] if idx < len(blocks) - 1 else None
except ValueError:
# Inserting a new block with non-contiguous numbering
prev_block_index = next(reversed(self.block_events))
next_block_index = None

# Look up the last gradient value in the previous block
last = 0
if prev_block_index is not None:
prev_id = self.block_events[prev_block_index][grad_to_check.idx]
if prev_id != 0:
prev_lib = self.grad_library.get(prev_id)
prev_type = prev_lib['type']

if prev_type == 't':
last = 0
elif prev_type == 'g':
last = prev_lib['data'][5]

# Check whether the difference between the last gradient value and
# the first value of the new gradient is achievable with the
# specified slew rate.
if abs(last - grad_to_check.start[1]) > self.system.max_slew * self.system.grad_raster_time:
raise RuntimeError('Two consecutive gradients need to have the same amplitude at the connection point')

# Look up the first gradient value in the next block
# (this only happens when using set_block to patch a block)
if next_block_index is not None:
next_id = self.block_events[next_block_index][grad_to_check.idx]
if next_id != 0:
next_lib = self.grad_library.get(next_id)
next_type = next_lib['type']

if next_type == 't':
first = 0
elif next_type == 'g':
first = next_lib['data'][4]
else:
first = 0

# Check whether the difference between the first gradient value
# in the next block and the last value of the new gradient is
# achievable with the specified slew rate.
if abs(first - grad_to_check.stop[1]) > self.system.max_slew * self.system.grad_raster_time:
raise RuntimeError(
'Two consecutive gradients need to have the same amplitude at the connection point'
)
elif abs(grad_to_check.start[1]) > self.system.max_slew * self.system.grad_raster_time:
raise RuntimeError('First gradient in the the first block has to start at 0.')

if (
grad_to_check.stop[1] > self.system.max_slew * self.system.grad_raster_time
and abs(grad_to_check.stop[0] - duration) > 1e-7
):
raise RuntimeError("A gradient that doesn't end at zero needs to be aligned to the block boundary.")
grad_check(self, block_index, check_g, duration, rot_event)


self.block_events[block_index] = new_block
self.block_durations[block_index] = float(duration)
Expand Down Expand Up @@ -410,6 +353,12 @@ def get_block(self, block_index: int) -> SimpleNamespace:
block.trigger[len(block.trigger)] = trigger
else:
block.trigger = {0: trigger}
elif ext_type == "ROTATIONS":
data = self.rotation_library.data[ext_data[1]]
rotation = SimpleNamespace()
rotation.type = "rotation"
rotation.rot_matrix = np.asarray(data).reshape(3, 3)
block.rotation = rotation
elif ext_type in ['LABELSET', 'LABELINC']:
label = SimpleNamespace()
label.type = ext_type.lower()
Expand Down Expand Up @@ -687,3 +636,25 @@ def register_rf_event(self, event: SimpleNamespace) -> Tuple[int, List[int]]:
rf_id = self.rf_library.insert(key_id=0, new_data=data, data_type=use)

return rf_id, shape_IDs


def register_rotation_event(self, event: EventLibrary) -> int:
"""
Parameters
----------
event : SimpleNamespace
Rotation event to be registered.
Returns
-------
int
ID of registered rotation event.
"""
data = tuple(event.rot_matrix.ravel().tolist())
rotation_id, found = self.rotation_library.find_or_insert(new_data=data)

# Clear block cache because Rotation was overwritten
# TODO: Could find only the blocks that are affected by the changes
if self.use_block_cache and found:
self.block_cache.clear()

return rotation_id
Loading
Loading