diff --git a/README.rst b/README.rst index ddc0a9f..e731cfe 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,6 @@ -=========== +================================================ cryoSBI - Simulation-based Inference for Cryo-EM -=========== +================================================ .. start-badges @@ -134,7 +134,20 @@ The training config file should be a json file with the following structure: "THETA_SHIFT": 25, "THETA_SCALE": 25, "BATCH_SIZE": 256 - } + } + +Loading the posterior after training +------------------------------------ +After training the estimator, loading it in Python can be done with the load_estimator in the estimator_utils module. + +.. code:: python + + import cryo_sbi.utils.estimator_utils as est_utils + posterior = est_utils.load_estimator( + config_file_path="path_to_config_file", + estimator_path="path_to_estimator_file", + device="cuda" + ) Inference --------- @@ -161,3 +174,47 @@ We can quickly generate a histogram with 50 bins with the following piece of cod plt.hist(samples[:, idx_image].flatten(), np.linspace(0, simulator.max_index, 50)) In this case the x-axis is just the index of the structures in increasing order. + +Latent space +------------ + +Computing the latent features for simulated or experimental particles can be done using the compute_latent_repr function in the estimator_utils module. The function needs a trained posterior estimator and images and computes the latent representation for each image. + +.. code:: python + + import cryo_sbi.utils.estimator_utils as est_utils + latent_vecs = est_utils.compute_latent_repr( + compute_latent_repr( + estimator=posterior, + images=images, + batch_size=100, + device="cuda", + ) + +After we computed the latent representation for the images, one possible way to visualize the latent space is to use `UMAP `_ . UMAP generates a two-dimensional representation of the latent space, which should allow us to analyze its important features. + +.. code:: python + + import umap + reducer = umap.UMAP(metric="euclidian", n_components=2, n_neighbors=50) + embedding = reducer.fit_transform(latent_vecs.numpy()) + +We can quickly visualize the 2d latent space with matplotlib. + +.. code:: python + + import matplotlib.pyplot as plt + plt.scatter( + embedding[:, 0], + embedding[:, 1], + ) + + + + + + + + + +