Skip to content

Commit

Permalink
Support all the methods in go.py and add one more dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed May 19, 2024
1 parent 0d4250d commit 0e45a48
Showing 1 changed file with 34 additions and 33 deletions.
67 changes: 34 additions & 33 deletions examples/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,60 @@
import sys
import os
import subprocess
import re

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

supported_methods = {
'lle': 'Locally Linear Embedding',
'ltsa': 'Local Tangent Space Alignment',
'isomap': 'Isomap',
'mds': 'Multidimensional Scaling',
'pca': 'Principal Component Analysis',
'kpca': 'Kernel Principal Component Analysis',
't-sne': 't-distributed Stochastic Neighborhood Embedding',
'dm': 'Diffusion Map',
}

def generate_data(type, N=1000):
def generate_data(type, N=1000, random_state=None):
rng = np.random.RandomState(random_state)
if type=='swissroll':
tt = np.array((3*np.pi/2)*(1+2*np.random.rand(N)))
height = np.array((np.random.rand(N)-0.5))
tt = np.array((3*np.pi/2)*(1+2*rng.rand(N)))
height = np.array((rng.rand(N)-0.5))
X = np.array([tt*np.cos(tt), 10*height, tt*np.sin(tt)])
return X,tt
return X, tt
if type=='scurve':
tt = np.array((3*np.pi*(np.random.rand(N)-0.5)))
height = np.array((np.random.rand(N)-0.5))
tt = np.array((3*np.pi*(rng.rand(N)-0.5)))
height = np.array((rng.rand(N)-0.5))
X = np.array([np.sin(tt), 10*height, np.sign(tt)*(np.cos(tt)-1)])
return X,tt
return X, tt
if type=='helix':
tt = np.linspace(1,N,N).T / N
tt = tt*2*np.pi
X = np.r_[[(2+np.cos(8*tt))*np.cos(tt)],
[(2+np.cos(8*tt))*np.sin(tt)],
[np.sin(8*tt)]]
return X,tt
[(2+np.cos(8*tt))*np.sin(tt)],
[np.sin(8*tt)]]
return X, tt
if type=='twinpeaks':
X = rng.uniform(-1, 1, size=(N, 2))
tt = np.sin(np.pi * X[:, 0]) * np.tanh(X[:, 1])
tt += 0.1 * rng.normal(size=tt.shape)
X = np.vstack([X.T, tt])
return X, tt

raise Exception('Dataset is not supported')


def embed(data,method):
if method not in supported_methods:
raise Exception('Method is not supported by this script')

input_file = 'tapkee_input_data'
output_file = 'tapkee_output_data'
np.savetxt(input_file, data.T,delimiter=',')
tapkee_binary = 'bin/tapkee'
runner_string = '%s -i %s -o %s -m %s -k 20 --precompute --verbose --transpose-output --benchmark' % (tapkee_binary, input_file, output_file, method)
runner_string = '%s -i %s -o %s -m %s -k 20 --precompute --debug --verbose --transpose-output --benchmark' % (tapkee_binary, input_file, output_file, method)
print('-- To reproduce this use the following command', runner_string)
output = subprocess.check_output(runner_string, shell=True)
process = subprocess.run(runner_string, shell=True, capture_output=True, text=True)

if match := re.search(r'Parameter dimension reduction method = \[([a-zA-Z0-9() ]+)\]', process.stderr):
used_method = match.group(1)
else:
used_method = ''


embedded_data = np.loadtxt(output_file, delimiter=',')
os.remove(input_file)
os.remove(output_file)
return embedded_data
return embedded_data, used_method

def plot(data, embedded_data, colors='m', method=None):
fig = plt.figure()
Expand All @@ -72,22 +73,22 @@ def plot(data, embedded_data, colors='m', method=None):
ax.scatter(embedded_data[0], embedded_data[1], c=colors, cmap=plt.cm.Spectral, s=5)
plt.axis('tight')
plt.axis('off')
plt.title('Embedding' + (' with ' + method) if method else '', fontsize=9)
plt.title('Embedding' + (' with ' + method) if method else '', fontsize=9, wrap=True)

plt.show()

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Graphical example of dimension reduction with Tapkee.')
parser.add_argument('dataset', type=str, nargs=1, help='A dataset to embed. One of the following: %s' % str(['swissroll', 'scurve', 'helix']))
parser.add_argument('method', type=str, nargs=1, help='A method to use. One of the following %s' % str(list(supported_methods.keys())))
parser.add_argument('dataset', type=str, nargs=1, help='A dataset to embed. One of the following: %s' % str(['swissroll', 'scurve', 'helix', 'twinpeaks']))
parser.add_argument('method', type=str, nargs=1, help='A method to use. Any of the methods supported by Tapkee')
args = parser.parse_args()

dataset = args.dataset[0]
method = args.method[0]
print('-- Loading %s data' % dataset)
data, colors = generate_data(dataset)
print('-- Embedding %s data with %s' % (dataset,method))
embedded_data = embed(data, method)
print('-- Embedding %s data with %s' % (dataset, method))
embedded_data, used_method = embed(data, method)
print('-- Plotting embedded data')
plot(data, embedded_data, colors, supported_methods[method])
plot(data, embedded_data, colors, used_method)

0 comments on commit 0e45a48

Please sign in to comment.