Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/neuron_proofreader/proposal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def relabel_nodes(self):

# Reset graph
self.clear()
self.xyz_to_edge = dict()
for (i, j) in old_edge_ids:
edge_id = (int(old_to_new[i]), int(old_to_new[j]))
self._add_edge(edge_id, edge_attrs[(i, j)])
Expand Down
29 changes: 26 additions & 3 deletions src/neuron_proofreader/split_proofreading/split_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def add_graph(
gt_pointer,
pred_pointer,
img_path,
metadata_path=None,
segmentation_path=None
):
"""
Expand All @@ -97,12 +98,14 @@ def add_graph(
Path to predicted SWC files to be loaded.
img_path : str
Path to the raw image associated with the graph.
metadata_path : str, optional
Path to a JSON metadata file containing the chunk bounding box used to clip fragments. Default is None.
segmentation_path : str
Path to the segmentation mask associated with the graph.
"""
# Add graph
gt_graph = self.load_graph(gt_pointer, is_gt=True)
self.graphs[key] = self.load_graph(pred_pointer)
self.graphs[key] = self.load_graph(pred_pointer, metadata_path)
self.graphs[key].generate_proposals(
self.config.search_radius, gt_graph=gt_graph
)
Expand All @@ -117,7 +120,7 @@ def add_graph(
segmentation_path=segmentation_path
)

def load_graph(self, swc_pointer, is_gt=False):
def load_graph(self, swc_pointer, is_gt=False, metadata_path=None):
"""
Loads a graph by reading and processing SWC files specified by
"swc_pointer".
Expand All @@ -126,6 +129,8 @@ def load_graph(self, swc_pointer, is_gt=False):
----------
swc_pointer : str
Path to SWC files to be loaded.
metadata_path : str, optional
Path to a JSON metadata file containing the chunk bounding box used to clip fragments. Default is None.

Returns
-------
Expand All @@ -139,8 +144,9 @@ def load_graph(self, swc_pointer, is_gt=False):
)
graph.load(swc_pointer)

# Filter doubles (if applicable)
# Post process fragments
if not is_gt:
self.clip_fragments(graph, metadata_path)
geometry_util.remove_doubles(graph, 200)
return graph

Expand Down Expand Up @@ -198,6 +204,23 @@ def get_next_key(self, samplers):
return keys[0]

# --- Helpers ---
@staticmethod
def clip_fragments(graph, metadata_path):
    """
    Removes nodes from "graph" that fall outside the bounding box read
    from a metadata JSON file, then relabels the remaining nodes.

    Parameters
    ----------
    graph : object
        Graph to be clipped in place. Must support "nodes", "get_voxel",
        "remove_nodes_from" and "relabel_nodes".
    metadata_path : str or None
        Cloud path to a JSON file with "chunk_origin" and "chunk_shape"
        entries. If None, the graph is left unchanged.
    """
    # Callers pass metadata_path=None by default; nothing to clip then
    if metadata_path is None:
        return

    # Extract bounding box (reversed -- presumably zyx <-> xyz axis
    # ordering; TODO confirm against metadata schema)
    bucket_name, path = util.parse_cloud_path(metadata_path)
    metadata = util.read_json_from_gcs(bucket_name, path)
    origin = metadata["chunk_origin"][::-1]
    shape = metadata["chunk_shape"][::-1]

    # Clip graph by removing out-of-bounds nodes
    nodes = list()
    for i in graph.nodes:
        voxel = graph.get_voxel(i)
        if not img_util.is_contained(voxel - origin, shape):
            nodes.append(i)
    graph.remove_nodes_from(nodes)
    graph.relabel_nodes()

def n_proposals(self):
"""
Counts the number of proposals in the dataset.
Expand Down
72 changes: 69 additions & 3 deletions src/neuron_proofreader/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,28 @@ def read_zip(zip_file, path):
return f.read().decode("utf-8")


def update_txt(path, text, verbose=True):
    """
    Appends "text" as a new line to the file at "path" and optionally
    prints it to stdout.

    Parameters
    ----------
    path : str
        Path to txt file where the text will be appended.
    text : str
        Text to be written to the file.
    verbose : bool, optional
        Indication of whether to printout text. Default is True.
    """
    # Echo to stdout (if applicable)
    if verbose:
        print(text)

    # Append a newline-terminated line to the file
    with open(path, "a") as f:
        f.write(f"{text}\n")


def write_json(path, contents):
"""
Writes "contents" to a JSON file at "path".
Expand Down Expand Up @@ -342,6 +364,28 @@ def write_txt(path, contents):


# --- GCS utils ---
def check_gcs_file_exists(bucket_name, path):
    """
    Checks whether a blob exists at the given path within a GCS bucket.

    Parameters
    ----------
    bucket_name : str
        Name of bucket to be checked.
    path : str
        Path to be checked.

    Returns
    -------
    bool
        Indication of whether the path exists.
    """
    bucket = storage.Client().bucket(bucket_name)
    return bucket.blob(path).exists()


def is_gcs_path(path):
"""
Checks if the path is a GCS path.
Expand All @@ -359,7 +403,7 @@ def is_gcs_path(path):
return path.startswith("gs://")


def list_gcs_filenames(bucket_name, prefix, extension):
def list_gcs_filenames(bucket_name, prefix, extension=""):
"""
Lists all files in a GCS bucket with the given extension.

Expand All @@ -369,8 +413,8 @@ def list_gcs_filenames(bucket_name, prefix, extension):
Name of bucket to be searched.
prefix : str
Path to location within bucket to be searched.
extension : str
File extension of filenames to be listed.
extension : str, optional
File extension of filenames to be listed. Default is an empty string.

Returns
-------
Expand Down Expand Up @@ -416,6 +460,28 @@ def list_gcs_subdirectories(bucket_name, prefix):
return subdirs


def read_json_from_gcs(bucket_name, blob_path):
    """
    Reads and parses a JSON file stored in a GCS bucket.

    Parameters
    ----------
    bucket_name : str
        Name of the GCS bucket containing the JSON file.
    blob_path : str
        Path to the JSON file within the GCS bucket.

    Returns
    -------
    dict
        Parsed JSON content as a Python dictionary.
    """
    bucket = storage.Client().bucket(bucket_name)
    contents = bucket.blob(blob_path).download_as_text()
    return json.loads(contents)


# --- S3 utils ---
def is_s3_path(path):
"""
Expand Down
Loading