From 466da3eb5c5b21d33603aba15b53e4c39b2da3ef Mon Sep 17 00:00:00 2001 From: Ashwin Patil Date: Thu, 4 Aug 2022 19:24:17 +0000 Subject: [PATCH 1/5] initial version of heatmap --- .../queries/mssentinel/kql_sent_heatmap.yaml | 76 +++++++ msticpy/vis/heatmap.py | 196 ++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 msticpy/data/queries/mssentinel/kql_sent_heatmap.yaml create mode 100644 msticpy/vis/heatmap.py diff --git a/msticpy/data/queries/mssentinel/kql_sent_heatmap.yaml b/msticpy/data/queries/mssentinel/kql_sent_heatmap.yaml new file mode 100644 index 000000000..d276a40c5 --- /dev/null +++ b/msticpy/data/queries/mssentinel/kql_sent_heatmap.yaml @@ -0,0 +1,76 @@ +metadata: + version: 1 + description: Heatmap - Query to return weekly data of various data sources which can be used to plot heatmap + data_environments: [LogAnalytics] + data_families: ['MultiDataSource'] + tags: ['Heatmap','host','ip','alert','network'] +defaults: + parameters: + table: + description: Table name + type: str + end: + description: Query end time + type: datetime + where_clause: + description: Optional additional filter clauses + type: str + default: '' + add_query_items: + description: Additional query clauses + type: str + default: '' + timestampcolumn: + description: Timestamp field to use from source dataset + type: str + default: 'TimeGenerated' + aggregatefunction: + description: Aggregation functions to use - count(), sum(), avg() etc + type: str + default: 'count()' +sources: + get_weekly_heatmap_unpivot: + description: Retrieves data across the week of a given table in 4 column unpivot view which can also be used to plot heatmap + args: + query: ' + let end = datetime({end}); + let start = end - 7d; + {table} + {where_clause} + | where {timestampcolumn} >= startofday(start) + | where {timestampcolumn} <= startofday(end) + | extend HourOfLogin = hourofday({timestampcolumn}), DayNumberofWeek = dayofweek({timestampcolumn}) , Date = format_datetime(TimeGenerated, "yyyy-MM-dd") + | extend DayofWeek = case( + DayNumberofWeek == "00:00:00", "Sunday", + DayNumberofWeek == "1.00:00:00", "Monday", + DayNumberofWeek == "2.00:00:00", "Tuesday", + DayNumberofWeek == "3.00:00:00", "Wednesday", + DayNumberofWeek == "4.00:00:00", "Thursday", + DayNumberofWeek == "5.00:00:00", "Friday", + DayNumberofWeek == "6.00:00:00", "Saturday","InvalidTimeStamp") + | summarize Total={aggregatefunction} by Date, DayofWeek, HourOfLogin + | sort by Date asc, HourOfLogin asc + {add_query_items}' + get_weekly_heatmap_pivot: + description: Retrieves data across the week of a given table in Pivot view which can be used to plot heatmap + args: + query: ' + let end = datetime({end}); + let start = end - 7d; + {table} + {where_clause} + | where {timestampcolumn} >= startofday(start) + | where {timestampcolumn} <= startofday(end) + | extend HourOfLogin = toint(hourofday({timestampcolumn})), DayNumberofWeek = dayofweek({timestampcolumn}) , Date = format_datetime(TimeGenerated, "yyyy-MM-dd") + | extend DayofWeek = case( + DayNumberofWeek == "00:00:00", "Sunday", + DayNumberofWeek == "1.00:00:00", "Monday", + DayNumberofWeek == "2.00:00:00", "Tuesday", + DayNumberofWeek == "3.00:00:00", "Wednesday", + DayNumberofWeek == "4.00:00:00", "Thursday", + DayNumberofWeek == "5.00:00:00", "Friday", + DayNumberofWeek == "6.00:00:00", "Saturday","InvalidTimeStamp") + | evaluate pivot(HourOfLogin, {aggregatefunction}, DayofWeek, Date) + | project-reorder Date, DayofWeek, * granny-asc + | sort by Date asc + {add_query_items}' diff --git a/msticpy/vis/heatmap.py b/msticpy/vis/heatmap.py new file mode 100644 index 000000000..b92ec8685 --- /dev/null +++ b/msticpy/vis/heatmap.py @@ -0,0 +1,196 @@ +"""Bokeh heatmap plot.""" +from math import pi +from typing import List, Optional, Union + +import attr +import numpy as np +import pandas as pd +from bokeh.io import output_notebook, reset_output, show +from bokeh.layouts import row +from bokeh.models import HoverTool, LayoutDOM, BasicTicker, ColorBar, LinearColorMapper, PrintfTickFormatter +from bokeh.plotting import figure + +from .._version import VERSION +from ..common.utility import check_kwargs + +__version__ = VERSION +__author__ = "Ashwin Patil" + +@attr.s(auto_attribs=True) +class PlotParams: + """Plot params for heatmap.""" + + title: Optional[str] = "Heatmap" + x: Optional[str] = None + x_col: Optional[str] = None + y: Optional[str] = None + y_col: Optional[str] = None + height: int =400 + width: int = 800 + color_pallette: Optional[List]= ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] + value_col: Optional[str] = 'Total' + sort: Optional[Union[str, bool]] = None + sort_x: Optional[Union[str, bool]] = None + sort_y: Optional[Union[str, bool]] = None + hide: bool = False + font_size: Optional[int] = None + max_label_font_size: int = 11 + major_label_text_font_size: str = "7px" + + @property + def x_column(self) -> Optional[str]: + """Return the current x column value.""" + return self.x or self.x_col + + @property + def y_column(self) -> Optional[str]: + """Return the current y column value.""" + return self.y or self.y_col + + @classmethod + def field_list(cls) -> List[str]: + """Return field names as a list.""" + return list(attr.fields_dict(cls).keys()) + +def plot_heatmap(data: pd.DataFrame, **kwargs) -> LayoutDOM: + """ + Plot data as a heatmap. + + Parameters + ---------- + data : pd.DataFrame + The data to plot. + x : str + Column to plot on the x (horizontal) axis + x_col : str + Alias for 'x' + y : str + Column to plot on the y (vertical) axis + y_col : str + Alias for 'y' + title : str, optional + Custom title, default is 'Intersection plot' + value_col : str, optional + Column from the DataFrame used to categorize heatmap. Default is Total. + height : int, optional + The plot height. Default is 700 + width : int + The plot width. Default is 900 + color_pallette : List, optional + The color pallette of the heatmap, default is custom list ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] + sort : Union[str, bool], optional + Sorts the labels of both axes, default is None. + Acceptable values are: + 'asc' (or string starting with 'asc') - Sort ascending + 'desc' (or string starting with 'asc') - Sort descending + False or None (no sort) + True - Sort ascending + sort_x : str, optional + Sorts the labels of the x axis (takes precedence over `sort`), + default is None. + Acceptable values are: + 'asc' (or string starting with 'asc') - Sort ascending + 'desc' (or string starting with 'asc') - Sort descending + False or None (no sort) + True - Sort ascending + sort_y : str, optional + Sorts the labels of the y axis (takes precedence over `sort`), + default is None. + Acceptable values are: + 'asc' (or string starting with 'asc') - Sort ascending + 'desc' (or string starting with 'asc') - Sort descending + False or None (no sort) + True - Sort ascending + hide : bool, optional + Creates and returns but does not display the plot, default + is False. + font_size : int, optional + Manually specify the font size for axis labels, in points, + the default is to automatically calculate a size based on the + number of items in each axis. + max_label_font_size : int, optional + The maximum size, in points, of the X and Y labels, default is 11. + + + Returns + ------- + LayoutDOM + The Bokeh plot + + """ + # Process/extract parameters + check_kwargs(kwargs, PlotParams.field_list()) + param = PlotParams(**kwargs) + + if not param.x_column and not param.y_column: + raise ValueError("Must supply `x` and `y` column parameters.") + + reset_output() + output_notebook() + + x_range, y_range = _sort_days_hours(data, param.x_column, param.y_column) + + plot = figure( + title=param.title, + x_axis_location="above", + x_range=x_range, + y_range=y_range, + plot_width=param.width, + plot_height=param.height, + tools=["wheel_zoom", "box_zoom", "pan", "reset", "save"], + toolbar_location="above", + ) + + tool_tips = [ + (param.x_column, f"@{param.x_column} @{param.y_column}:00"), + (param.value_col, f"@{param.value_col}") + ] + plot.add_tools(HoverTool(tooltips=tool_tips)) + + mapper, color_bar = _create_colorbar(data, param) + + plot.rect(x=param.y_column, y=param.x_column, width=1, height=1, + source=data, + fill_color={'field':param.value_col, 'transform': mapper}, + line_color=None) + + plot.add_layout(color_bar, 'right') + + _set_plot_params(plot) + + if not param.hide: + show(plot) + return plot + + +def _set_plot_params(plot): + plot.title.text_font_size = "15pt" + plot.outline_line_color = None + plot.xgrid.visible = True + plot.ygrid.visible = True + plot.grid.grid_line_color = None + plot.grid.grid_line_alpha = 0.1 + plot.axis.axis_line_color = None + plot.axis.major_tick_line_color = None + plot.axis.major_label_standoff = 0 + plot.xaxis.major_label_orientation = pi / 3 + +def _sort_days_hours(data: pd.DataFrame, week_column: str, hour_column: str): + """Sort the Week days and hour of day if required.""" + dayofweek = list(data[week_column].unique()) + hourofday = list(data[hour_column].astype(str).unique()) + correct_days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday'] + correct_hours = hours = [f"{hr}" for hr in range(0, 24)] + days = {name:val for val, name in enumerate(correct_days)} + hours = {name:val for val, name in enumerate(correct_hours)} + sorted_days = sorted(dayofweek, key=days.get, reverse=True) + sorted_hours = sorted(hourofday, key=hours.get) + return sorted_hours, sorted_days + +def _create_colorbar(data: pd.DataFrame, param: PlotParams): + mapper = LinearColorMapper(palette=param.color_pallette, low=data[param.value_col].min(), high=data[param.value_col].max()) + color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size=param.major_label_text_font_size, + ticker=BasicTicker(desired_num_ticks=len(param.color_pallette)), + formatter=PrintfTickFormatter(format="%d"), + label_standoff=6, border_line_color=None) + return mapper, color_bar \ No newline at end of file From b827095912f768b8dc9707cfde4e58468f7739bb Mon Sep 17 00:00:00 2001 From: Ashwin Patil Date: Tue, 9 Aug 2022 20:55:59 +0000 Subject: [PATCH 2/5] pandas accessor --- msticpy/vis/mp_pandas_plot.py | 71 +++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/msticpy/vis/mp_pandas_plot.py b/msticpy/vis/mp_pandas_plot.py index ef3cd0ef9..06fb77ffe 100644 --- a/msticpy/vis/mp_pandas_plot.py +++ b/msticpy/vis/mp_pandas_plot.py @@ -14,6 +14,7 @@ from ..common.exceptions import MsticpyUserError from ..transform.network import GraphType, df_to_networkx from ..vis.network_plot import plot_nx_graph +from ..vis.heatmap import plot_heatmap from .entity_graph_tools import EntityGraph, req_alert_cols, req_inc_cols from .foliummap import plot_map from .matrix_plot import plot_matrix @@ -616,3 +617,73 @@ def network( edge_attrs=edge_attrs, **kwargs, ) + + # pylint: disable=too-many-arguments + def heatmap(data: pd.DataFrame, **kwargs) -> LayoutDOM: + """ + Plot data as a heatmap. + + Parameters + ---------- + data : pd.DataFrame + The data to plot. + x : str + Column to plot on the x (horizontal) axis + x_col : str + Alias for 'x' + y : str + Column to plot on the y (vertical) axis + y_col : str + Alias for 'y' + title : str, optional + Custom title, default is 'Intersection plot' + value_col : str, optional + Column from the DataFrame used to categorize heatmap. Default is Total. + height : int, optional + The plot height. Default is 700 + width : int + The plot width. Default is 900 + color_pallette : List, optional + The color pallette of the heatmap, default is custom list ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] + sort : Union[str, bool], optional + Sorts the labels of both axes, default is None. + Acceptable values are: + 'asc' (or string starting with 'asc') - Sort ascending + 'desc' (or string starting with 'asc') - Sort descending + False or None (no sort) + True - Sort ascending + sort_x : str, optional + Sorts the labels of the x axis (takes precedence over `sort`), + default is None. + Acceptable values are: + 'asc' (or string starting with 'asc') - Sort ascending + 'desc' (or string starting with 'asc') - Sort descending + False or None (no sort) + True - Sort ascending + sort_y : str, optional + Sorts the labels of the y axis (takes precedence over `sort`), + default is None. + Acceptable values are: + 'asc' (or string starting with 'asc') - Sort ascending + 'desc' (or string starting with 'asc') - Sort descending + False or None (no sort) + True - Sort ascending + hide : bool, optional + Creates and returns but does not display the plot, default + is False. + font_size : int, optional + Manually specify the font size for axis labels, in points, + the default is to automatically calculate a size based on the + number of items in each axis. + max_label_font_size : int, optional + The maximum size, in points, of the X and Y labels, default is 11. + + + Returns + ------- + LayoutDOM + The Bokeh plot + + """ + + return plot_heatmap(data=self._df, **kwargs) \ No newline at end of file From 35a5f3431c03025d94e094634aee17b5573faa7f Mon Sep 17 00:00:00 2001 From: Ashwin Patil Date: Thu, 17 Nov 2022 16:57:52 +0000 Subject: [PATCH 3/5] correcting function definition --- msticpy/vis/mp_pandas_plot.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/msticpy/vis/mp_pandas_plot.py b/msticpy/vis/mp_pandas_plot.py index 06fb77ffe..1a686cff2 100644 --- a/msticpy/vis/mp_pandas_plot.py +++ b/msticpy/vis/mp_pandas_plot.py @@ -619,14 +619,12 @@ def network( ) # pylint: disable=too-many-arguments - def heatmap(data: pd.DataFrame, **kwargs) -> LayoutDOM: + def heatmap(self, **kwargs) -> LayoutDOM: """ Plot data as a heatmap. Parameters ---------- - data : pd.DataFrame - The data to plot. x : str Column to plot on the x (horizontal) axis x_col : str From d00f0948d79dc91641861c2794ea3a000dcd5aa8 Mon Sep 17 00:00:00 2001 From: Ashwin Patil Date: Mon, 27 Mar 2023 05:32:19 +0000 Subject: [PATCH 4/5] fixing linting errors --- .pre-commit-config.yaml | 2 +- msticpy/vis/heatmap.py | 98 +++++++++++++++++++++++++---------- msticpy/vis/mp_pandas_plot.py | 9 ++-- 3 files changed, 76 insertions(+), 33 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index db1d0bc35..fc8833bc5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: - --max-line-length=90 - --exclude=tests,test*.py - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/msticpy/vis/heatmap.py b/msticpy/vis/heatmap.py index b92ec8685..33de394a8 100644 --- a/msticpy/vis/heatmap.py +++ b/msticpy/vis/heatmap.py @@ -3,11 +3,16 @@ from typing import List, Optional, Union import attr -import numpy as np import pandas as pd from bokeh.io import output_notebook, reset_output, show -from bokeh.layouts import row -from bokeh.models import HoverTool, LayoutDOM, BasicTicker, ColorBar, LinearColorMapper, PrintfTickFormatter +from bokeh.models import ( + BasicTicker, + ColorBar, + HoverTool, + LayoutDOM, + LinearColorMapper, + PrintfTickFormatter, +) from bokeh.plotting import figure from .._version import VERSION @@ -16,6 +21,7 @@ __version__ = VERSION __author__ = "Ashwin Patil" + @attr.s(auto_attribs=True) class PlotParams: """Plot params for heatmap.""" @@ -25,10 +31,20 @@ class PlotParams: x_col: Optional[str] = None y: Optional[str] = None y_col: Optional[str] = None - height: int =400 + height: int = 400 width: int = 800 - color_pallette: Optional[List]= ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] - value_col: Optional[str] = 'Total' + color_pallette: Optional[List] = [ + "#75968f", + "#a5bab7", + "#c9d9d3", + "#e2e2e2", + "#dfccce", + "#ddb7b1", + "#cc7878", + "#933b41", + "#550b1d", + ] + value_col: Optional[str] = "Total" sort: Optional[Union[str, bool]] = None sort_x: Optional[Union[str, bool]] = None sort_y: Optional[Union[str, bool]] = None @@ -52,6 +68,7 @@ def field_list(cls) -> List[str]: """Return field names as a list.""" return list(attr.fields_dict(cls).keys()) + def plot_heatmap(data: pd.DataFrame, **kwargs) -> LayoutDOM: """ Plot data as a heatmap. @@ -77,7 +94,9 @@ def plot_heatmap(data: pd.DataFrame, **kwargs) -> LayoutDOM: width : int The plot width. Default is 900 color_pallette : List, optional - The color pallette of the heatmap, default is custom list ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] + The color pallette of the heatmap, default is custom list + ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", + "#cc7878", "#933b41", "#550b1d"] sort : Union[str, bool], optional Sorts the labels of both axes, default is None. Acceptable values are: @@ -128,7 +147,7 @@ def plot_heatmap(data: pd.DataFrame, **kwargs) -> LayoutDOM: reset_output() output_notebook() - x_range, y_range = _sort_days_hours(data, param.x_column, param.y_column) + x_range, y_range = _sort_days_hours(data, param.x_column, param.y_column) plot = figure( title=param.title, @@ -143,18 +162,23 @@ def plot_heatmap(data: pd.DataFrame, **kwargs) -> LayoutDOM: tool_tips = [ (param.x_column, f"@{param.x_column} @{param.y_column}:00"), - (param.value_col, f"@{param.value_col}") + (param.value_col, f"@{param.value_col}"), ] plot.add_tools(HoverTool(tooltips=tool_tips)) mapper, color_bar = _create_colorbar(data, param) - plot.rect(x=param.y_column, y=param.x_column, width=1, height=1, + plot.rect( + x=param.y_column, + y=param.x_column, + width=1, + height=1, source=data, - fill_color={'field':param.value_col, 'transform': mapper}, - line_color=None) - - plot.add_layout(color_bar, 'right') + fill_color={"field": param.value_col, "transform": mapper}, + line_color=None, + ) + + plot.add_layout(color_bar, "right") _set_plot_params(plot) @@ -175,22 +199,40 @@ def _set_plot_params(plot): plot.axis.major_label_standoff = 0 plot.xaxis.major_label_orientation = pi / 3 -def _sort_days_hours(data: pd.DataFrame, week_column: str, hour_column: str): + +def _sort_days_hours(data: pd.DataFrame, day_column: str, hour_column: str): """Sort the Week days and hour of day if required.""" - dayofweek = list(data[week_column].unique()) + dayofweek = list(data[day_column].unique()) hourofday = list(data[hour_column].astype(str).unique()) - correct_days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday'] + correct_days = [ + "Monday", + "Tuesday", + "Wednesday", + "Thursday", + "Friday", + "Saturday", + "Sunday", + ] correct_hours = hours = [f"{hr}" for hr in range(0, 24)] - days = {name:val for val, name in enumerate(correct_days)} - hours = {name:val for val, name in enumerate(correct_hours)} + days = {name: val for val, name in enumerate(correct_days)} + hours = {name: val for val, name in enumerate(correct_hours)} sorted_days = sorted(dayofweek, key=days.get, reverse=True) sorted_hours = sorted(hourofday, key=hours.get) - return sorted_hours, sorted_days - -def _create_colorbar(data: pd.DataFrame, param: PlotParams): - mapper = LinearColorMapper(palette=param.color_pallette, low=data[param.value_col].min(), high=data[param.value_col].max()) - color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size=param.major_label_text_font_size, - ticker=BasicTicker(desired_num_ticks=len(param.color_pallette)), - formatter=PrintfTickFormatter(format="%d"), - label_standoff=6, border_line_color=None) - return mapper, color_bar \ No newline at end of file + return sorted_hours, sorted_days + + +def _create_colorbar(data: pd.DataFrame, param: PlotParams): + mapper = LinearColorMapper( + palette=param.color_pallette, + low=data[param.value_col].min(), + high=data[param.value_col].max(), + ) + color_bar = ColorBar( + color_mapper=mapper, + major_label_text_font_size=param.major_label_text_font_size, + ticker=BasicTicker(desired_num_ticks=len(param.color_pallette)), + formatter=PrintfTickFormatter(format="%d"), + label_standoff=6, + border_line_color=None, + ) + return mapper, color_bar diff --git a/msticpy/vis/mp_pandas_plot.py b/msticpy/vis/mp_pandas_plot.py index 99733df2b..5a28c4b28 100644 --- a/msticpy/vis/mp_pandas_plot.py +++ b/msticpy/vis/mp_pandas_plot.py @@ -13,8 +13,8 @@ from .._version import VERSION from ..common.exceptions import MsticpyUserError from ..transform.network import GraphType, df_to_networkx -from ..vis.network_plot import plot_nx_graph from ..vis.heatmap import plot_heatmap +from ..vis.network_plot import plot_nx_graph from .entity_graph_tools import EntityGraph, req_alert_cols, req_inc_cols from .foliummap import plot_map from .matrix_plot import plot_matrix @@ -642,7 +642,9 @@ def heatmap(self, **kwargs) -> LayoutDOM: width : int The plot width. Default is 900 color_pallette : List, optional - The color pallette of the heatmap, default is custom list ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] + The color pallette of the heatmap, default is custom list + ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", + "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] sort : Union[str, bool], optional Sorts the labels of both axes, default is None. Acceptable values are: @@ -683,5 +685,4 @@ def heatmap(self, **kwargs) -> LayoutDOM: The Bokeh plot """ - - return plot_heatmap(data=self._df, **kwargs) \ No newline at end of file + return plot_heatmap(data=self._df, **kwargs) From d9e6f061877328e23cf31c885646baeb654950ea Mon Sep 17 00:00:00 2001 From: Ashwin Patil Date: Mon, 27 Mar 2023 21:24:52 +0000 Subject: [PATCH 5/5] fixing mypy errors --- msticpy/vis/heatmap.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/msticpy/vis/heatmap.py b/msticpy/vis/heatmap.py index 33de394a8..4eca53c4c 100644 --- a/msticpy/vis/heatmap.py +++ b/msticpy/vis/heatmap.py @@ -33,7 +33,7 @@ class PlotParams: y_col: Optional[str] = None height: int = 400 width: int = 800 - color_pallette: Optional[List] = [ + color_pallette: List[str] = [ "#75968f", "#a5bab7", "#c9d9d3", @@ -54,14 +54,20 @@ class PlotParams: major_label_text_font_size: str = "7px" @property - def x_column(self) -> Optional[str]: + def x_column(self) -> str: """Return the current x column value.""" - return self.x or self.x_col + x_column = self.x or self.x_col + if x_column is None: + raise TypeError("Please supply value for x_column") + return x_column @property - def y_column(self) -> Optional[str]: + def y_column(self) -> str: """Return the current y column value.""" - return self.y or self.y_col + y_column = self.y or self.y_col + if y_column is None: + raise TypeError("Please supply value for y_column") + return y_column @classmethod def field_list(cls) -> List[str]: @@ -213,11 +219,11 @@ def _sort_days_hours(data: pd.DataFrame, day_column: str, hour_column: str): "Saturday", "Sunday", ] - correct_hours = hours = [f"{hr}" for hr in range(0, 24)] + correct_hours = [f"{hr}" for hr in range(0, 24)] days = {name: val for val, name in enumerate(correct_days)} hours = {name: val for val, name in enumerate(correct_hours)} - sorted_days = sorted(dayofweek, key=days.get, reverse=True) - sorted_hours = sorted(hourofday, key=hours.get) + sorted_days = sorted(dayofweek, key=days.get, reverse=True) # type: ignore[arg-type] + sorted_hours = sorted(hourofday, key=hours.get) # type: ignore[arg-type] return sorted_hours, sorted_days