Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select a range of rows in a DataTable #3821

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/textual/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,15 @@ def style(self) -> Style:
def style(self, style: Style) -> None:
self._style = style

@property
def has_modifier(self) -> bool:
"""Check if any modifier keys are pressed.

Returns:
True if any modifier keys are pressed.
"""
return self.shift or self.meta or self.ctrl

def get_content_offset(self, widget: Widget) -> Offset | None:
"""Get offset within a widget's content area, or None if offset is not in content (i.e. padding or border).

Expand Down
165 changes: 116 additions & 49 deletions src/textual/widgets/_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rich.segment import Segment
from rich.style import Style
from rich.text import Text, TextType
from typing_extensions import Literal, Self, TypeAlias
from typing_extensions import Literal, Self, TypeAlias, Union

from .. import events
from .._segment_tools import line_crop
Expand Down Expand Up @@ -335,6 +335,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
)
"""The coordinate of the `DataTable` that is being hovered."""

row_range_coordinates: Reactive[range[int] | None] = Reactive(
None, repaint=False, always_update=True
)
"""If a row range is selected, this is start and ending rows."""

class CellHighlighted(Message):
"""Posted when the cursor moves to highlight a new cell.

Expand Down Expand Up @@ -465,6 +470,35 @@ def control(self) -> DataTable:
"""Alias for the data table."""
return self.data_table

class RowsSelected(Message):
"""Posted when a range of rows are selected.

This message is only posted when the
`cursor_type` is set to `"row"`. Can be handled using
`on_data_table_rows_selected` in a subclass of `DataTable` or in a parent
widget in the DOM.
"""

def __init__(
self, data_table: DataTable, rows: range[int], row_keys: list[RowKey]
) -> None:
self.data_table = data_table
"""The data table."""
self.rows: range[int] = rows
"""The y-coordinate of the cursor that made the selection."""
self.row_keys: list[RowKey] = row_keys
"""The key of the row that was selected."""
super().__init__()

def __rich_repr__(self) -> rich.repr.Result:
yield "cursor_rows", self.cursor_rows
yield "row_keys", self.row_keys

@property
def control(self) -> DataTable:
"""Alias for the data table."""
return self.data_table

class ColumnHighlighted(Message):
"""Posted when a column is highlighted.

Expand Down Expand Up @@ -1096,10 +1130,19 @@ def watch_cursor_coordinate(
# scrolling because it may be animated.
self.call_after_refresh(self._scroll_cursor_into_view)

def watch_row_range_coordinates(
self, old_coordinates: range[int] | None, new_coordinates: range[int] | None
) -> None:
if old_coordinates != new_coordinates:
for row_index in old_coordinates or []:
self.refresh_row(row_index)
for row_index in new_coordinates or []:
self.refresh_row(row_index)

def move_cursor(
self,
*,
row: int | None = None,
row: int | range[int] | None = None,
column: int | None = None,
animate: bool = False,
) -> None:
Expand All @@ -1122,7 +1165,11 @@ def move_cursor(

cursor_row, cursor_column = self.cursor_coordinate
if row is not None:
cursor_row = row
if isinstance(row, range):
cursor_row = row.stop
self.row_range_coordinates = row
else:
cursor_row = row
if column is not None:
cursor_column = column
destination = Coordinate(cursor_row, cursor_column)
Expand Down Expand Up @@ -1308,10 +1355,6 @@ def _update_dimensions(self, new_rows: Iterable[RowKey]) -> None:
# so that we can cache this rendering for later.
if auto_height_rows:
render_cell = self._render_cell # This method renders & caches.
should_highlight = self._should_highlight
cursor_type = self.cursor_type
cursor_location = self.cursor_coordinate
hover_location = self.hover_coordinate
Comment on lines -1311 to -1314
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to be straight with you: this PR will need to be reviewed by Darren and/or Will but changes like this decrease the probability that it gets merged.

I see you deleted these assignments from here and then changed things like _render_cell to check the hover or cursor type there.
This is a small optimisation to reduce the number of attribute lookups: here, we're checking the cursor type once.
If you do this inside _render_cell, you'll access that attribute once for each new row.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay thanks for the heads up.
I isolated this refactor to 4354760, so it should be reasonably easy to test performance before/after. I'm curious how much of a different it will make.

To be sure: I think I'm doing the same number of comparisons, it's just e.g. self.cursor_type == "row" instead of type_of_cursor == "row".

base_style = self.rich_style
fixed_style = self.get_component_styles(
"datatable--fixed"
Expand All @@ -1328,19 +1371,10 @@ def _update_dimensions(self, new_rows: Iterable[RowKey]) -> None:
# that were rendered with the wrong height and append the missing padding.
rendered_cells: list[tuple[SegmentLines, int, int]] = []
for column_index, column in enumerate(ordered_columns):
style = fixed_style if column_index < fixed_columns else row_style
cell_location = Coordinate(row_index, column_index)
style = fixed_style if column_index < fixed_columns else row_style
rendered_cell = render_cell(
row_index,
column_index,
style,
column.get_render_width(self),
cursor=should_highlight(
cursor_location, cell_location, cursor_type
),
hover=should_highlight(
hover_location, cell_location, cursor_type
),
cell_location, style, column.get_render_width(self)
)
cell_height = len(rendered_cell)
rendered_cells.append(
Expand Down Expand Up @@ -1887,12 +1921,9 @@ def _get_row_renderables(self, row_index: int) -> RowRenderables:

def _render_cell(
self,
row_index: int,
column_index: int,
location: Coordinate,
base_style: Style,
width: int,
cursor: bool = False,
hover: bool = False,
) -> SegmentLines:
"""Render the given cell.

Expand All @@ -1907,6 +1938,8 @@ def _render_cell(
Returns:
A list of segments per line.
"""
row_index, column_index = location

is_header_cell = row_index == -1
is_row_label_cell = column_index == -1

Expand All @@ -1921,6 +1954,9 @@ def _render_cell(
else:
row_key = self._row_locations.get_key(row_index)

cursor = self._should_highlight("cursor", location)
hover = self._should_highlight("hover", location)

column_key = self._column_locations.get_key(column_index)
cell_cache_key: CellCacheKey = (
row_key,
Expand Down Expand Up @@ -2080,7 +2116,6 @@ def _render_line_in_row(
if cache_key in self._row_render_cache:
return self._row_render_cache[cache_key]

should_highlight = self._should_highlight
render_cell = self._render_cell
header_style = self.get_component_styles("datatable--header").rich_style

Expand All @@ -2096,12 +2131,9 @@ def _render_line_in_row(
# The width of the row label is updated again on idle
cell_location = Coordinate(row_index, -1)
label_cell_lines = render_cell(
row_index,
-1,
cell_location,
header_style,
width=self._row_label_column_width,
cursor=should_highlight(cursor_location, cell_location, cursor_type),
hover=should_highlight(hover_location, cell_location, cursor_type),
)[line_no]
fixed_row.append(label_cell_lines)

Expand All @@ -2116,14 +2148,9 @@ def _render_line_in_row(
):
cell_location = Coordinate(row_index, column_index)
fixed_cell_lines = render_cell(
row_index,
column_index,
cell_location,
fixed_style,
column.get_render_width(self),
cursor=should_highlight(
cursor_location, cell_location, cursor_type
),
hover=should_highlight(hover_location, cell_location, cursor_type),
)[line_no]
fixed_row.append(fixed_cell_lines)

Expand All @@ -2133,12 +2160,9 @@ def _render_line_in_row(
for column_index, column in enumerate(self.ordered_columns):
cell_location = Coordinate(row_index, column_index)
cell_lines = render_cell(
row_index,
column_index,
cell_location,
row_style,
column.get_render_width(self),
cursor=should_highlight(cursor_location, cell_location, cursor_type),
hover=should_highlight(hover_location, cell_location, cursor_type),
)[line_no]
scrollable_row.append(cell_lines)

Expand Down Expand Up @@ -2269,9 +2293,8 @@ def render_line(self, y: int) -> Strip:

def _should_highlight(
self,
cursor: Coordinate,
what: Union[Literal["cursor"], Literal["hover"]],
target_cell: Coordinate,
type_of_cursor: CursorType,
) -> bool:
"""Determine if the given cell should be highlighted because of the cursor.

Expand All @@ -2286,13 +2309,19 @@ def _should_highlight(
Returns:
Whether or not the given cell should be highlighted.
"""
if type_of_cursor == "cell":
cursor = self.cursor_coordinate if what == "cursor" else self.hover_coordinate

if self.cursor_type == "cell":
return cursor == target_cell
elif type_of_cursor == "row":
cursor_row, _ = cursor
elif self.cursor_type == "row":
row_range_coordinates = self.row_range_coordinates
cell_row, _ = target_cell
return cursor_row == cell_row
elif type_of_cursor == "column":
if row_range_coordinates is None:
cursor_row, _ = cursor
return cursor_row == cell_row
else:
return cell_row in row_range_coordinates
elif self.cursor_type == "column":
_, cursor_column = cursor
_, cell_column = target_cell
return cursor_column == cell_column
Expand Down Expand Up @@ -2406,8 +2435,22 @@ def _scroll_cursor_into_view(self, animate: bool = False) -> None:
top, _, _, left = fixed_offset

if self.cursor_type == "row":
x, y, width, height = self._get_row_region(self.cursor_row)
region = Region(int(self.scroll_x) + left, y, width - left, height)
if self.row_range_coordinates is None:
x, y, width, height = self._get_row_region(self.cursor_row)
region = Region(int(self.scroll_x) + left, y, width - left, height)
else:
x, y, width, height = self._get_row_region(self.cursor_row)
start_row = self.row_range_coordinates.start
stop_row = self.row_range_coordinates.stop
x_start, y_start, width_start, width_end = self._get_row_region(
start_row
)
x_stop, y_stop, width_stop, height_stop = self._get_row_region(stop_row)
height = y_stop + height_stop - y_start
region = Region(
int(self.scroll_x) + left, y_start, width - left, height
)

elif self.cursor_type == "column":
x, y, width, height = self._get_column_region(self.cursor_column)
region = Region(x, int(self.scroll_y) + top, width, height - top)
Expand Down Expand Up @@ -2458,6 +2501,7 @@ async def _on_click(self, event: events.Click) -> None:
)
self.post_message(message)
elif self.show_cursor and self.cursor_type != "none":
self._update_row_range_coordinates(event, row_index)
# Only post selection events if there is a visible row/col/cell cursor.
self.cursor_coordinate = Coordinate(row_index, column_index)
self._post_selected_message()
Expand Down Expand Up @@ -2586,10 +2630,33 @@ def _post_selected_message(self):
)
)
elif cursor_type == "row":
row_index, _ = cursor_coordinate
row_key, _ = cell_key
self.post_message(DataTable.RowSelected(self, row_index, row_key))
if self.row_range_coordinates is None:
row_index, _ = cursor_coordinate
row_key, _ = cell_key
self.post_message(DataTable.RowSelected(self, row_index, row_key))
else:
row_keys: list[RowKey] = []
for row_index in self.row_range_coordinates:
row_key = self._row_locations.get_key(row_index)
row_keys.append(row_key)
self.post_message(DataTable.RowSelected(self, row_index, row_key))

self.post_message(
DataTable.RowsSelected(self, self.row_range_coordinates, row_keys)
)

elif cursor_type == "column":
_, column_index = cursor_coordinate
_, column_key = cell_key
self.post_message(DataTable.ColumnSelected(self, column_index, column_key))

def _update_row_range_coordinates(
self, event: events.Click, new_row_index: int
) -> None:
if self.cursor_type != "row" or not event.has_modifier:
self.row_range_coordinates = None
return

first = self.cursor_coordinate.row
second = new_row_index
self.row_range_coordinates = range(min(first, second), max(first, second) + 1)
Loading
Loading