Skip to content

Commit 3d2d039

Browse files
added stat params inside MapperComplex
1 parent 62439da commit 3d2d039

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

example/ex_mapper.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66

77
X = np.loadtxt("inputs/human")
88

9-
109
print("Mapper computation with point cloud")
1110
mapper = MapperComplex(inp="point cloud", filters=X[:,[2,0]], filter_bnds=np.array([[np.nan,np.nan],[np.nan,np.nan]]), resolutions=np.array([np.nan,np.nan]), gains=np.array([0.33,0.33]), colors=X[:,2:3]).fit(X)
12-
print(mapper.mapper_.get_filtration())
13-
11+
print(list(mapper.mapper_.get_filtration()))
1412

1513
print("Mapper computation with pairwise distances only")
1614
X = pairwise_distances(X)
1715
mapper = MapperComplex(inp="distance matrix", filters=X[:,[2,0]], filter_bnds=np.array([[np.nan,np.nan],[np.nan,np.nan]]), resolutions=np.array([np.nan,np.nan]), gains=np.array([0.33,0.33]), colors=np.max(X, axis=1)[:,np.newaxis]).fit(X)
18-
print(mapper.mapper_.get_filtration())
16+
print(list(mapper.mapper_.get_filtration()))

sklearn_tda/clustering.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class MapperComplex(BaseEstimator, TransformerMixin):
5757
"""
5858
This is a class for computing Mapper simplicial complexes on point clouds or distance matrices.
5959
"""
60-
def __init__(self, filters, filter_bnds, colors, resolutions, gains, inp="point cloud", clustering=DBSCAN(), mask=0):
60+
def __init__(self, filters, filter_bnds, colors, resolutions, gains, inp="point cloud", clustering=DBSCAN(), mask=0, N=100, beta=0., C=10.):
6161
"""
6262
Constructor for the MapperComplex class.
6363
@@ -70,12 +70,15 @@ def __init__(self, filters, filter_bnds, colors, resolutions, gains, inp="point
7070
gains (numpy array of shape num_filters containing doubles in [0,1]): gain of each filter, ie overlap percentage of the intervals covering each filter image.
7171
clustering (class): clustering class (default sklearn.cluster.DBSCAN()). Common clustering classes can be found in the scikit-learn library (such as AgglomerativeClustering for instance).
7272
mask (int): threshold on the size of the Mapper nodes (default 0). Any node associated to a subpopulation with less than **mask** points will be removed.
73+
N (int): subsampling iterations (default 100). See http://www.jmlr.org/papers/volume19/17-291/17-291.pdf for details.
74+
beta (double): exponent parameter (default 0.). See http://www.jmlr.org/papers/volume19/17-291/17-291.pdf for details.
75+
C (double): constant parameter (default 10.). See http://www.jmlr.org/papers/volume19/17-291/17-291.pdf for details.
7376
7477
mapper_ (gudhi SimplexTree): Mapper simplicial complex computed after calling the fit() method
7578
node_info_ (dictionary): various information associated to the nodes of the Mapper.
7679
"""
7780
self.filters, self.filter_bnds, self.resolutions, self.gains, self.colors, self.clustering = filters, filter_bnds, resolutions, gains, colors, clustering
78-
self.input, self.mask = inp, mask
81+
self.input, self.mask, self.N, self.beta, self.C = inp, mask, N, beta, C
7982

8083
def get_optimal_parameters_for_agglomerative_clustering(self, X, beta=0., C=10., N=100):
8184
"""
@@ -121,7 +124,7 @@ def fit(self, X, y=None):
121124

122125
# If some resolutions are not specified, automatically compute them
123126
if np.any(np.isnan(self.resolutions)) or self.clustering is None:
124-
delta, resolutions = self.get_optimal_parameters_for_agglomerative_clustering(X=X, beta=0., C=10, N=100)
127+
delta, resolutions = self.get_optimal_parameters_for_agglomerative_clustering(X=X, beta=self.beta, C=self.C, N=self.N)
125128
if self.clustering is None:
126129
if self.input == "point cloud":
127130
self.clustering = AgglomerativeClustering(n_clusters=None, linkage="single", distance_threshold=delta, affinity="euclidean")

0 commit comments

Comments
 (0)