Skip to content

Commit

Permalink
Run code formatters.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandervsokol committed Sep 25, 2024
1 parent bb39e23 commit dc96687
Show file tree
Hide file tree
Showing 23 changed files with 157 additions and 116 deletions.
1 change: 0 additions & 1 deletion cl/runtime/backend/core/ui_app_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from dataclasses import dataclass
from typing import List
from typing import Optional

from cl.runtime import Context
from cl.runtime.backend.core.app_theme import AppTheme
from cl.runtime.backend.core.tab_info import TabInfo
Expand Down
1 change: 0 additions & 1 deletion cl/runtime/file/csv_file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Any
from typing import Dict
from typing import Type

from cl.runtime import Context
from cl.runtime.file.reader import Reader
from cl.runtime.records.protocols import RecordProtocol
Expand Down
7 changes: 3 additions & 4 deletions cl/runtime/plots/confusion_matrix_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from cl.runtime import Context
from cl.runtime.plots.confusion_matrix_plot_style import ConfusionMatrixPlotStyle
from cl.runtime.plots.confusion_matrix_plot_style_key import ConfusionMatrixPlotStyleKey
Expand Down Expand Up @@ -58,15 +57,15 @@ def create_figure(self) -> plt.Figure:
# TODO: consider moving
data, annotation_text = self._create_confusion_matrix()

theme = 'dark_background' if style.dark_theme else 'default'
theme = "dark_background" if style.dark_theme else "default"

with plt.style.context(theme):
fig, axes = plt.subplots()

cmap = LinearSegmentedColormap.from_list('rg', ["g", "y", "r"], N=256)
cmap = LinearSegmentedColormap.from_list("rg", ["g", "y", "r"], N=256)

im = MatplotlibUtil.heatmap(data.values, data.index.tolist(), data.columns.tolist(), ax=axes, cmap=cmap)
MatplotlibUtil.annotate_heatmap(im, labels=annotation_text, textcolors='black', size=style.label_font_size)
MatplotlibUtil.annotate_heatmap(im, labels=annotation_text, textcolors="black", size=style.label_font_size)

# Set figure and axes labels
axes.set_xlabel(self.x_label)
Expand Down
6 changes: 2 additions & 4 deletions cl/runtime/plots/group_bar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@

from dataclasses import dataclass
from typing import List

import numpy as np
from matplotlib import pyplot as plt

from cl.runtime import Context
from cl.runtime.plots.group_bar_plot_style import GroupBarPlotStyle
from cl.runtime.plots.group_bar_plot_style_key import GroupBarPlotStyleKey
Expand Down Expand Up @@ -57,7 +55,7 @@ def create_figure(self) -> plt.Figure:
# Load style object
style = Context.current().load_one(GroupBarPlotStyle, self.style)

theme = 'dark_background' if style.dark_theme else 'default'
theme = "dark_background" if style.dark_theme else "default"

with plt.style.context(theme):
fig = plt.figure()
Expand All @@ -80,7 +78,7 @@ def create_figure(self) -> plt.Figure:
space = 1 / (len(self.bar_labels) + 1)

for i, (bar_label, bar_shift) in enumerate(zip(self.bar_labels, bar_shifts)):
data = self.values[i * len(self.group_labels): (i + 1) * len(self.group_labels)]
data = self.values[i * len(self.group_labels) : (i + 1) * len(self.group_labels)]
axes.bar(x_ticks + space * bar_shift, data, space, label=bar_label)

axes.set_xticks(x_ticks, self.group_labels)
Expand Down
1 change: 0 additions & 1 deletion cl/runtime/plots/group_bar_plot_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from abc import ABC
from dataclasses import dataclass

from cl.runtime.plots.group_bar_plot_style_key import GroupBarPlotStyleKey
from cl.runtime.records.record_mixin import RecordMixin

Expand Down
9 changes: 4 additions & 5 deletions cl/runtime/plots/heat_map_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from cl.runtime import Context
from cl.runtime.plots.heat_map_plot_style import HeatMapPlotStyle
from cl.runtime.plots.heat_map_plot_style_key import HeatMapPlotStyleKey
Expand Down Expand Up @@ -59,19 +58,19 @@ def create_figure(self) -> plt.Figure:
# Load style object
style = Context.current().load_one(HeatMapPlotStyle, self.style)

theme = 'dark_background' if style.dark_theme else 'default'
theme = "dark_background" if style.dark_theme else "default"

with plt.style.context(theme):
fig, axes = plt.subplots()

shape = (len(self.row_labels), len(self.col_labels))

data = np.abs(
np.reshape(np.asarray(self.received_values), shape) - np.reshape(np.asarray(self.expected_values),
shape)
np.reshape(np.asarray(self.received_values), shape)
- np.reshape(np.asarray(self.expected_values), shape)
)

cmap = LinearSegmentedColormap.from_list('rg', ["g", "y", "r"], N=256)
cmap = LinearSegmentedColormap.from_list("rg", ["g", "y", "r"], N=256)

im = MatplotlibUtil.heatmap(data, self.row_labels, self.col_labels, ax=axes, cmap=cmap)

Expand Down
36 changes: 18 additions & 18 deletions cl/runtime/plots/matplotlib_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np

from matplotlib import pyplot as plt
from matplotlib.image import AxesImage

Expand Down Expand Up @@ -55,30 +55,28 @@ def heatmap(data: np.ndarray, row_labels: List[str], col_labels: List[str], ax=N
ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

# Let the horizontal axes labeling appear on top.
ax.tick_params(top=True, bottom=False,
labeltop=True, labelbottom=False)
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
rotation_mode="anchor")
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

# Turn spines off and create white grid.
ax.spines[:].set_visible(False)

ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
ax.tick_params(which="minor", bottom=False, left=False)

return im

@staticmethod
def annotate_heatmap(
im: AxesImage,
labels: List[List[str]],
textcolors: Union[str, Tuple[str]] = ("black", "white"),
threshold: Optional[float] = None,
**textkw
im: AxesImage,
labels: List[List[str]],
textcolors: Union[str, Tuple[str]] = ("black", "white"),
threshold: Optional[float] = None,
**textkw,
):
"""
A function to annotate a heatmap.
Expand Down Expand Up @@ -109,8 +107,7 @@ def annotate_heatmap(

# Set default alignment to center, but allow it to be
# overwritten by textkw.
kw = dict(horizontalalignment="center",
verticalalignment="center")
kw = dict(horizontalalignment="center", verticalalignment="center")
kw.update(textkw)

# Loop over the data and create a `Text` for each "pixel".
Expand All @@ -119,8 +116,11 @@ def annotate_heatmap(
for i in range(data.shape[0]):
for j in range(data.shape[1]):
kw.update(
color=textcolors[int(im.norm(data[i, j]) < threshold)] if isinstance(textcolors, tuple)
else textcolors,
color=(
textcolors[int(im.norm(data[i, j]) < threshold)]
if isinstance(textcolors, tuple)
else textcolors
),
)
text = im.axes.text(j, i, labels[i][j], **kw)
texts.append(text)
Expand Down
6 changes: 4 additions & 2 deletions cl/runtime/primitive/ordered_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ def validate(cls, value: UUID) -> None:

# Check type
if (value_type_name := type(value).__name__) != "UUID":
raise RuntimeError(f"Method 'OrderedUuid.datetime_of' received object of '{value_type_name}' "
f"type while 'UUID' was expected.")
raise RuntimeError(
f"Method 'OrderedUuid.datetime_of' received object of '{value_type_name}' "
f"type while 'UUID' was expected."
)

# Check version
if value.version != 7:
Expand Down
18 changes: 11 additions & 7 deletions cl/runtime/routers/storage/record_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,17 @@ def get_record(cls, request: RecordRequest) -> RecordResponse:
record = data_source.load_one(record_type, deserialized_key)

# Get type declarations based on the actual record type
type_decl_dict = TypeResponseUtil.get_type(
TypeRequest(
name=type(record).__name__,
module=request.module,
user="root",
),
) if record is not None else dict() # Handle not found record
type_decl_dict = (
TypeResponseUtil.get_type(
TypeRequest(
name=type(record).__name__,
module=request.module,
user="root",
),
)
if record is not None
else dict()
) # Handle not found record

# TODO: Optimize speed using dacite or similar library

Expand Down
1 change: 0 additions & 1 deletion cl/runtime/serialization/dict_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Type
from typing import cast
from inflection import camelize

from cl.runtime.backend.core.base_type_info import BaseTypeInfo
from cl.runtime.backend.core.tab_info import TabInfo
from cl.runtime.records.protocols import is_key
Expand Down
22 changes: 12 additions & 10 deletions cl/runtime/serialization/string_value_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import re
from enum import Enum
from enum import IntEnum
from typing import Any, Dict, Final
from typing import Any
from typing import Dict
from typing import Final
from cl.runtime.records.protocols import is_key


Expand Down Expand Up @@ -62,9 +64,7 @@ class StringValueCustomType(IntEnum):
"""Key type."""


CUSTOM_TYPE_VALUE_TO_NAME: Final[Dict[StringValueCustomType, str]] = {
StringValueCustomType.dict: "json"
}
CUSTOM_TYPE_VALUE_TO_NAME: Final[Dict[StringValueCustomType, str]] = {StringValueCustomType.dict: "json"}
"""Enum value to name mapping."""

CUSTOM_TYPE_NAME_TO_VALUE: Final[Dict[str, StringValueCustomType]] = {
Expand All @@ -84,9 +84,9 @@ def add_type_prefix(cls, value: str, type_: StringValueCustomType | None) -> str
return value

# Check type name in alias mapping
type_name = type_name_alias if (
(type_name_alias := CUSTOM_TYPE_VALUE_TO_NAME.get(type_)) is not None
) else type_.name
type_name = (
type_name_alias if ((type_name_alias := CUSTOM_TYPE_VALUE_TO_NAME.get(type_)) is not None) else type_.name
)

type_prefix = f"```{type_name} "
return type_prefix + value
Expand Down Expand Up @@ -115,9 +115,11 @@ def parse(cls, value: str) -> (str, StringValueCustomType | None):
value_without_prefix = value.removeprefix(f"```{value_custom_type} ")

# Check custom type in alias mapping
value_custom_type = custom_type if (
(custom_type := CUSTOM_TYPE_NAME_TO_VALUE.get(value_custom_type)) is not None
) else StringValueCustomType[value_custom_type]
value_custom_type = (
custom_type
if ((custom_type := CUSTOM_TYPE_NAME_TO_VALUE.get(value_custom_type)) is not None)
else StringValueCustomType[value_custom_type]
)

return value_without_prefix, value_custom_type
else:
Expand Down
6 changes: 4 additions & 2 deletions cl/runtime/settings/plotly_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def __post_init__(self):
self.max_test_plots = int(self.show_limit)

if not isinstance(self.show_limit, int):
raise RuntimeError(f"{type(self).__name__} field 'show_limit' must be"
f"an int or a string that can be converted to an int.")
raise RuntimeError(
f"{type(self).__name__} field 'show_limit' must be"
f"an int or a string that can be converted to an int."
)

@classmethod
def get_prefix(cls) -> str:
Expand Down
18 changes: 10 additions & 8 deletions cl/runtime/testing/regression_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import inspect
import os
from dataclasses import dataclass
from typing import Any, Literal
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Iterable
from typing import Literal
from typing import cast
import inflection
import yaml
Expand Down Expand Up @@ -82,11 +83,11 @@ class RegressionGuard:
"""Output file extension (format), defaults to '.txt'"""

def __init__(
self,
*,
ext: str = None,
channel: str | None = None,
test_function_pattern: str | None = None,
self,
*,
ext: str = None,
channel: str | None = None,
test_function_pattern: str | None = None,
):
"""
Initialize the regression guard, optionally specifying channel.
Expand Down Expand Up @@ -239,8 +240,9 @@ def verify(self, *, silent: bool = True) -> bool:
diff_path = self._get_file_path("diff")

if not os.path.exists(received_path):
raise RuntimeError(f"Regression guard error, cannot verify because "
f"received file {received_path} does not yet exist.")
raise RuntimeError(
f"Regression guard error, cannot verify because " f"received file {received_path} does not yet exist."
)

if os.path.exists(expected_path):
# Expected file exists, compare
Expand Down
16 changes: 6 additions & 10 deletions cl/runtime/testing/stack_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def is_inside_test(cls, *, test_module_pattern: str | None = None) -> bool:

@classmethod
def get_base_dir(
cls,
*,
allow_missing: bool = False,
test_function_pattern: str | None = None,
cls,
*,
allow_missing: bool = False,
test_function_pattern: str | None = None,
) -> str:
"""
Return module_dir/test_module/test_function or module_dir/test_module/test_class/test_method,
Expand All @@ -61,9 +61,7 @@ def get_base_dir(
defaults to 'test_*'
"""
return cls._get_base_dir_or_path(
dot_delimited=False,
allow_missing=allow_missing,
test_function_pattern=test_function_pattern
dot_delimited=False, allow_missing=allow_missing, test_function_pattern=test_function_pattern
)

@classmethod
Expand All @@ -86,9 +84,7 @@ def get_base_path(
defaults to 'test_*'
"""
return cls._get_base_dir_or_path(
dot_delimited=True,
allow_missing=allow_missing,
test_function_pattern=test_function_pattern
dot_delimited=True, allow_missing=allow_missing, test_function_pattern=test_function_pattern
)

@classmethod
Expand Down
4 changes: 0 additions & 4 deletions stubs/cl/runtime/config/stub_runtime_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from dataclasses import dataclass

from cl.runtime import handler
from cl.runtime.backend.core.ui_app_state import UiAppState
from cl.runtime.backend.core.user_key import UserKey
Expand Down Expand Up @@ -107,6 +106,3 @@ def configure_plots(self) -> None:
bar_plot.bar_labels = ["Bar 1", "Bar 2"]
bar_plot.values = [85.5, 92]
Context.current().save_one(bar_plot)



1 change: 0 additions & 1 deletion stubs/cl/runtime/views/stub_viewers_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,3 @@ def pdf_viewer(self) -> BinaryContent:
content = file.read()
pdf_content = BinaryContent(content=content, content_type=BinaryContentTypeEnum.Pdf)
return pdf_content

Loading

0 comments on commit dc96687

Please sign in to comment.