Skip to content

Commit 63a4f3d

Browse files
committed
adding ability to run trinmfk without having to run nmfk first
1 parent 42b5e63 commit 63a4f3d

16 files changed

+125
-111
lines changed

Diff for: TELF/factorization/TriNMFk.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import numpy as np
2020
import warnings
2121
import scipy.sparse
22-
import numpy as np
2322
import os
23+
from pathlib import Path
2424

2525
try:
2626
import cupy as cp
@@ -87,6 +87,7 @@ class TriNMFk():
8787
def __init__(self,
8888
experiment_name="TriNMFk",
8989
nmfk_params={},
90+
save_path = "TriNMFk",
9091
nmf_verbose=False,
9192
use_gpu=False,
9293
n_jobs=-1,
@@ -108,6 +109,8 @@ def __init__(self,
108109
Name used for the experiment. Default is "TriNMFk".
109110
nmfk_params : str, optional
110111
Parameters for NMFk. See documentation for NMFk for the options.
112+
save_path : str, optional
113+
Used for save location when NMFk fit is not performed first, and TriNMFk fit is done.
111114
nmf_verbose : bool, optional
112115
If True, shows progress in each NMF operation. The default is False.
113116
use_gpu : bool, optional
@@ -151,8 +154,9 @@ def __init__(self,
151154
self.nmfk_fit = False
152155
self.pruned = pruned
153156
self.transpose = transpose
154-
self.save_path = "",
157+
self.save_path = save_path
155158
self.verbose = verbose
159+
self.save_path_full = ""
156160

157161
# organize n_jobs
158162
n_jobs, self.use_gpu = organize_n_jobs(use_gpu, n_jobs)
@@ -206,7 +210,7 @@ def fit_nmfk(self, X, Ks, note=""):
206210

207211
# Do NMFk
208212
nmfk_results = self.nmfk.fit(X, Ks, self.experiment_name, note)
209-
self.save_path = os.path.join(self.nmfk.save_path, self.nmfk.experiment_name)
213+
self.save_path_full = self.nmfk.save_path_full
210214

211215
# Do nmfk here
212216
self.nmfk_fit = True
@@ -215,7 +219,8 @@ def fit_nmfk(self, X, Ks, note=""):
215219

216220
def fit_tri_nmfk(self, X, k1k2:tuple):
217221
"""
218-
Factorize the input matrix ``X``, after applying ``fit_nmfk()`` to select the ``Wk`` and ``Hk``, to factorize the given matrix with ``k1k2=(Wk, Hk)``.
222+
Factorize the input matrix ``X``.\n
223+
after applying ``fit_nmfk()`` to select the ``Wk`` and ``Hk``, to factorize the given matrix with ``k1k2=(Wk, Hk)``.
219224
220225
Parameters
221226
----------
@@ -233,8 +238,20 @@ def fit_tri_nmfk(self, X, k1k2:tuple):
233238

234239

235240
if not self.nmfk_fit:
236-
warnings.warn("NMFk needs to be fit first. Use fit_nmfk function!")
237-
return
241+
name = (
242+
str(self.experiment_name)
243+
+ "_"
244+
+ str(self.n_iters)
245+
+ "iters_"
246+
+ str(self.n_inits)
247+
+ "inits"
248+
)
249+
self.save_path_full = os.path.join(self.save_path, name)
250+
try:
251+
if not Path(self.save_path_full).is_dir():
252+
Path(self.save_path_full).mkdir(parents=True)
253+
except Exception as e:
254+
print(e)
238255

239256
if self.transpose:
240257
if isinstance(X, np.ndarray):
@@ -306,7 +323,7 @@ def fit_tri_nmfk(self, X, k1k2:tuple):
306323

307324
# save the results
308325
np.savez_compressed(
309-
self.save_path
326+
self.save_path_full
310327
+ "/WSH"
311328
+ "_k="
312329
+ str(k1k2)

Diff for: docs/TELF.factorization.html

+4-2
Original file line numberDiff line numberDiff line change
@@ -760,14 +760,15 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Link to this headi
760760
others to do so.</p>
761761
<dl class="py class">
762762
<dt class="sig sig-object py" id="TELF.factorization.TriNMFk.TriNMFk">
763-
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">TELF.factorization.TriNMFk.</span></span><span class="sig-name descname"><span class="pre">TriNMFk</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">experiment_name</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'TriNMFk'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nmfk_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">{}</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nmf_verbose</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_gpu</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_jobs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_consensus_stopping</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">alpha</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">(0,</span> <span class="pre">0)</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_iters</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_inits</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">10</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pruned</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transpose</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">verbose</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/TELF/factorization/TriNMFk.html#TriNMFk"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#TELF.factorization.TriNMFk.TriNMFk" title="Link to this definition">#</a></dt>
763+
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">TELF.factorization.TriNMFk.</span></span><span class="sig-name descname"><span class="pre">TriNMFk</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">experiment_name</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'TriNMFk'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nmfk_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">{}</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">save_path</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'TriNMFk'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nmf_verbose</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_gpu</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_jobs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_consensus_stopping</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">alpha</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">(0,</span> <span class="pre">0)</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_iters</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_inits</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">10</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pruned</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transpose</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">verbose</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/TELF/factorization/TriNMFk.html#TriNMFk"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#TELF.factorization.TriNMFk.TriNMFk" title="Link to this definition">#</a></dt>
764764
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">object</span></code></p>
765765
<p>TriNMFk is a Non-negative Matrix Factorization module with the capability to do automatic model determination for both estimating the number of latent patterns (<code class="docutils literal notranslate"><span class="pre">Wk</span></code>) and clusters (<code class="docutils literal notranslate"><span class="pre">Hk</span></code>).</p>
766766
<dl class="field-list simple">
767767
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
768768
<dd class="field-odd"><ul class="simple">
769769
<li><p><strong>experiment_name</strong> (<em>str</em><em>, </em><em>optional</em>) – Name used for the experiment. Default is “TriNMFk”.</p></li>
770770
<li><p><strong>nmfk_params</strong> (<em>str</em><em>, </em><em>optional</em>) – Parameters for NMFk. See documentation for NMFk for the options.</p></li>
771+
<li><p><strong>save_path</strong> (<em>str</em><em>, </em><em>optional</em>) – Used for save location when NMFk fit is not performed first, and TriNMFk fit is done.</p></li>
771772
<li><p><strong>nmf_verbose</strong> (<em>bool</em><em>, </em><em>optional</em>) – If True, shows progress in each NMF operation. The default is False.</p></li>
772773
<li><p><strong>use_gpu</strong> (<em>bool</em><em>, </em><em>optional</em>) – If True, uses GPU for operations. The default is True.</p></li>
773774
<li><p><strong>n_jobs</strong> (<em>int</em><em>, </em><em>optional</em>) – Number of parallel jobs. Use -1 to use all available resources. The default is 1.</p></li>
@@ -819,7 +820,8 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Link to this headi
819820
<dl class="py method">
820821
<dt class="sig sig-object py" id="TELF.factorization.TriNMFk.TriNMFk.fit_tri_nmfk">
821822
<span class="sig-name descname"><span class="pre">fit_tri_nmfk</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">X</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">k1k2</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">tuple</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/TELF/factorization/TriNMFk.html#TriNMFk.fit_tri_nmfk"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#TELF.factorization.TriNMFk.TriNMFk.fit_tri_nmfk" title="Link to this definition">#</a></dt>
822-
<dd><p>Factorize the input matrix <code class="docutils literal notranslate"><span class="pre">X</span></code>, after applying <code class="docutils literal notranslate"><span class="pre">fit_nmfk()</span></code> to select the <code class="docutils literal notranslate"><span class="pre">Wk</span></code> and <code class="docutils literal notranslate"><span class="pre">Hk</span></code>, to factorize the given matrix with <code class="docutils literal notranslate"><span class="pre">k1k2=(Wk,</span> <span class="pre">Hk)</span></code>.</p>
823+
<dd><p>Factorize the input matrix <code class="docutils literal notranslate"><span class="pre">X</span></code>.</p>
824+
<p>after applying <code class="docutils literal notranslate"><span class="pre">fit_nmfk()</span></code> to select the <code class="docutils literal notranslate"><span class="pre">Wk</span></code> and <code class="docutils literal notranslate"><span class="pre">Hk</span></code>, to factorize the given matrix with <code class="docutils literal notranslate"><span class="pre">k1k2=(Wk,</span> <span class="pre">Hk)</span></code>.</p>
823825
<dl class="field-list simple">
824826
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
825827
<dd class="field-odd"><ul class="simple">

0 commit comments

Comments
 (0)