Skip to content

Commit

Permalink
Raise error if fit is not implemented in class inheriting KineticMode…
Browse files Browse the repository at this point in the history
…l; make mask the second argument of fit
  • Loading branch information
bilgelm committed Aug 8, 2024
1 parent 0beca66 commit 8c9397f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/dynamicpet/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,13 @@ def kineticmodel(
km.fit(mask=petmask_img_mat)
case "srtmlammertsma1996":
km = SRTMLammertsma1996(reftac, pet_img)
km.fit(weight_by=weight_by, mask=petmask_img_mat)
km.fit(mask=petmask_img_mat, weight_by=weight_by)
case "srtmzhou2003":
km = SRTMZhou2003(reftac, pet_img)
km.fit(
mask=petmask_img_mat,
integration_type=integration_type,
weight_by=weight_by,
mask=petmask_img_mat,
fwhm=fwhm,
)
case "kinfitr.srtm":
Expand Down
4 changes: 2 additions & 2 deletions src/dynamicpet/kineticmodel/kineticmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def __init__(
self.parameters: dict[str, NumpyRealNumberArray] = {}

@abstractmethod
def fit(self) -> None:
def fit(self, mask: NumpyRealNumberArray | None = None) -> None:
"""Estimate model parameters."""
# implementation should update self.parameters
pass
raise NotImplementedError

def get_parameter(self, param_name: str) -> SpatialImage | NumpyRealNumberArray:
"""Get a fitted parameter.
Expand Down
4 changes: 2 additions & 2 deletions src/dynamicpet/kineticmodel/srtm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def get_param_names(cls) -> list[str]:

def fit(
self,
weight_by: WEIGHT_OPTS | NumpyRealNumberArray | None = None,
mask: NumpyRealNumberArray | None = None,
weight_by: WEIGHT_OPTS | NumpyRealNumberArray | None = None,
) -> None:
"""Estimate model parameters.
Expand Down Expand Up @@ -168,9 +168,9 @@ def get_param_names(cls) -> list[str]:

def fit( # noqa: max-complexity: 12
self,
mask: NumpyRealNumberArray | None = None,
integration_type: INTEGRATION_TYPE_OPTS = "trapz",
weight_by: WEIGHT_OPTS | NumpyRealNumberArray | None = "frame_duration",
mask: NumpyRealNumberArray | None = None,
fwhm: RealNumber | list[RealNumber] | None = None,
) -> None:
"""Estimate model parameters.
Expand Down

0 comments on commit 8c9397f

Please sign in to comment.