-
Notifications
You must be signed in to change notification settings - Fork 0
Refactor FunctionalCPD to use dedicated Tabular and LinearGaussian adapters #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| import numpy as np | ||
|
|
||
| from pgmpy.factors.continuous import LinearGaussianCPD | ||
|
|
||
|
|
||
class LinearGaussianAdapter:
    """
    Adapter that fits data into a `LinearGaussianCPD`.

    Parameters
    ----------
    variable : hashable
        Name of the column in the training data that is modelled.
    estimator : {"MLE", "OLS", None}, optional
        Estimation method tag. Any other value makes `fit` raise.
    parents : list, optional
        Names of the parent (evidence) columns. ``None`` means no parents.

    Attributes
    ----------
    fitted_cpd_ : LinearGaussianCPD
        Set by `fit`; absent until the adapter has been fitted.
    """

    def __init__(self, variable, estimator=None, parents=None):
        self.variable = variable
        self.estimator = estimator
        self.parents = parents if parents is not None else []

    def fit(self, data):
        """
        Estimate the intercept, coefficients and residual std from `data`.

        Parameters
        ----------
        data : pandas.DataFrame-like
            Must support ``data[column].values`` for `variable` and
            ``data[list_of_columns].values`` for `parents`.

        Returns
        -------
        LinearGaussianAdapter
            ``self``, with `fitted_cpd_` populated.

        Raises
        ------
        ValueError
            If the estimator tag is unsupported or `data` has no rows.
        """
        if self.estimator not in ("MLE", "OLS", None):
            raise ValueError(f"For linear tag, MLE/OLS is supported. Got {self.estimator}")

        target_data = data[self.variable].values
        if len(target_data) == 0:
            # A mean / least-squares fit over zero rows would only yield NaNs.
            raise ValueError(f"Cannot fit a LinearGaussianCPD for '{self.variable}': data has no rows.")

        if not self.parents:
            # No parents: the model reduces to a plain Gaussian.
            beta = [np.mean(target_data)]
            std = np.std(target_data)
        else:
            evidence_data = data[self.parents].values
            # Prepend a column of ones so beta[0] is the intercept.
            X = np.c_[np.ones(evidence_data.shape[0]), evidence_data]

            beta, residuals, _rank, _singular = np.linalg.lstsq(X, target_data, rcond=None)
            if len(residuals) > 0:
                variance = residuals[0] / len(target_data)
            else:
                # lstsq returns an empty residual array for rank-deficient or
                # under-determined systems; recompute the mean squared error.
                variance = np.mean((target_data - X @ beta) ** 2)
            std = np.sqrt(variance)

        self.fitted_cpd_ = LinearGaussianCPD(variable=self.variable, beta=beta, std=std, evidence=self.parents)
        return self

    def __repr__(self):
        if not hasattr(self, "fitted_cpd_"):
            return f"<LinearGaussianAdapter(variable='{self.variable}', status='unfitted') at {hex(id(self))}>"

        cpd = self.fitted_cpd_
        evidence = list(cpd.evidence)

        # Mean expression: intercept followed by one "coef*parent" term per parent.
        terms = [f"{cpd.beta[0]:.3f}"]
        terms.extend(f"{cpd.beta[i + 1]:.3f}*{parent}" for i, parent in enumerate(evidence))
        beta_str = " + ".join(terms)

        # Omit the conditioning bar entirely when there is no evidence, so the
        # output reads "P(X)" rather than "P(X | )".
        condition_str = " | " + ", ".join(str(parent) for parent in evidence) if evidence else ""

        return (
            f"<LinearGaussianAdapter representing P({cpd.variable}{condition_str}) "
            f"~ N({beta_str}, std={cpd.std:.3f}) at {hex(id(self))}>"
        )
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,60 @@ | ||||||||
| from pgmpy.estimators import MaximumLikelihoodEstimator as MLE | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The file uses numpy features (e.g.,
Suggested change
|
||||||||
| from pgmpy.factors.discrete import TabularCPD | ||||||||
|
|
||||||||
|
|
||||||||
class TabularAdapter:
    """
    Adapter that fits data into a `TabularCPD`.

    Parameters
    ----------
    variable : hashable
        Name of the column in the training data that is modelled.
    estimator : "MLE" or MaximumLikelihoodEstimator
        Estimation method tag. Any other value makes `fit` raise.
    parents : list, optional
        Names of the parent (evidence) columns. ``None`` means no parents.

    Attributes
    ----------
    fitted_cpd_ : TabularCPD
        Set by `fit`; absent until the adapter has been fitted.
    """

    def __init__(self, variable, estimator, parents=None):
        self.variable = variable
        self.estimator = estimator
        self.parents = parents if parents is not None else []

    def fit(self, data):
        """
        Estimate the conditional probability table from observed counts.

        States of the variable and of each parent are the sorted unique
        non-missing values of the corresponding column. Rows with a missing
        value in the variable or in any parent do not contribute to the
        counts. A parent configuration with no usable observations gets a
        uniform distribution over the variable's states.

        Parameters
        ----------
        data : pandas.DataFrame
            Training data containing `variable` and every column in `parents`.

        Returns
        -------
        TabularAdapter
            ``self``, with `fitted_cpd_` populated.

        Raises
        ------
        ValueError
            If the estimator tag is unsupported, or if the variable or one of
            the parents has no non-missing observations.
        """
        if self.estimator not in ("MLE", MLE):
            raise ValueError("For tabular tag, only MLE estimator is currently supported.")

        variable_states = sorted(data[self.variable].dropna().unique())
        if not variable_states:
            # Without this guard the normalisation below divides by zero.
            raise ValueError(f"Cannot fit a TabularCPD for '{self.variable}': no non-missing observations.")
        n_states = len(variable_states)

        if not self.parents:
            counts = data[self.variable].value_counts().reindex(variable_states, fill_value=0)
            probs = counts.values / counts.values.sum()
            values = [[prob] for prob in probs]
            self.fitted_cpd_ = TabularCPD(variable=self.variable, variable_card=n_states, values=values)
            return self

        # Local import: this module does not import pandas at the top level.
        import pandas as pd

        parent_states = [sorted(data[parent].dropna().unique()) for parent in self.parents]
        for parent, states in zip(self.parents, parent_states):
            if not states:
                raise ValueError(f"Cannot fit a TabularCPD for '{self.variable}': parent '{parent}' has no non-missing observations.")

        # Every combination of parent states, first parent varying slowest.
        # Reindexing onto it adds the configurations that never occur in the
        # data, so the table always has prod(evidence_card) columns.
        if len(self.parents) == 1:
            configurations = pd.Index(parent_states[0], name=self.parents[0])
        else:
            configurations = pd.MultiIndex.from_product(parent_states, names=self.parents)

        counts = (
            data.groupby(self.parents + [self.variable])
            .size()
            .unstack(self.variable, fill_value=0)
            .reindex(index=configurations, columns=variable_states, fill_value=0)
        )

        # Rows = variable states, columns = parent configurations.
        table = counts.to_numpy(dtype=float).T
        totals = table.sum(axis=0)

        # An all-zero column is not a distribution; fall back to uniform.
        unseen = totals == 0
        table[:, unseen] = 1.0
        totals[unseen] = n_states

        self.fitted_cpd_ = TabularCPD(
            variable=self.variable,
            variable_card=n_states,
            values=table / totals,
            evidence=self.parents,
            evidence_card=[len(states) for states in parent_states],
        )
        return self

    def __repr__(self):
        if not hasattr(self, "fitted_cpd_"):
            return f"<TabularAdapter(variable='{self.variable}', status='unfitted') at {hex(id(self))}>"

        cpd = self.fitted_cpd_
        var_str = f"<TabularAdapter representing P({cpd.variable}:{cpd.variable_card}"

        # The first entry of variables/cardinality is the variable itself;
        # the remaining entries are the evidence variables.
        pairs = list(zip(cpd.variables[1:], cpd.cardinality[1:]))
        if pairs:
            evidence_str = " | " + ", ".join(f"{var}:{card}" for var, card in pairs)
        else:
            evidence_str = ""

        return var_str + evidence_str + f") at {hex(id(self))}>"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current `__repr__` format for a fitted model without evidence variables results in `P(variable | )`, which is slightly awkward. It would be cleaner to omit the `|` when there are no evidence variables, similar to how it's handled in `TabularAdapter`.