diff --git a/pixi.lock b/pixi.lock index 5c96fd3..3667f1f 100644 --- a/pixi.lock +++ b/pixi.lock @@ -4018,7 +4018,7 @@ packages: - pypi: ./ name: neuview version: 2.7.8 - sha256: ba8a929c3d6b1176a93c23591890057cad388f10af92896a470630c69b74ded9 + sha256: 735192b6759f2ca97e794e21ce95354ccb4d20c6dcc2b23e66a9fbecdeab5fec requires_dist: - click>=8.0.0 - jinja2>=3.0.0 diff --git a/pyproject.toml b/pyproject.toml index f985bb6..5fd4268 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ version = "neuview --version" setup-env = "cp .env.example .env && echo 'Created .env file. Please edit it and add your NEUPRINT_TOKEN.'" fill-all = "neuview fill-queue --all" create-list = "neuview create-list" +create-scatter = "neuview create-scatter" [tool.pytest.ini_options] testpaths = [ "test",] @@ -77,6 +78,9 @@ task = "pop-all" [[tool.pixi.tasks.create-all-pages.depends-on]] task = "create-list" +[[tool.pixi.tasks.create-all-pages.depends-on]] +task = "create-scatter" + [[tool.pixi.tasks.create-all-pages.depends-on]] task = "increment-version" @@ -104,6 +108,9 @@ args = [ "",] [[tool.pixi.tasks.subset-medium-no-index.depends-on]] task = "pop-all" +[[tool.pixi.tasks.subset-medium-no-index.depends-on]] +task = "create-scatter" + [tool.pixi.tasks.subset-medium] [[tool.pixi.tasks.subset-medium.depends-on]] task = "subset-medium-no-index" @@ -111,6 +118,9 @@ task = "subset-medium-no-index" [[tool.pixi.tasks.subset-medium.depends-on]] task = "create-list" +[[tool.pixi.tasks.subset-medium.depends-on]] +task = "create-scatter" + [tool.pixi.tasks.subset-small-no-index] [[tool.pixi.tasks.subset-small-no-index.depends-on]] task = "clean-output" @@ -122,6 +132,9 @@ args = [ "config.yaml", "subset-small",] [[tool.pixi.tasks.subset-small-no-index.depends-on]] task = "pop-all" +[[tool.pixi.tasks.subset-small-no-index.depends-on]] +task = "create-scatter" + [tool.pixi.tasks.subset-small] [[tool.pixi.tasks.subset-small.depends-on]] task = "subset-small-no-index" @@ -129,6 +142,9 @@ task = "subset-small-no-index" [[tool.pixi.tasks.subset-small.depends-on]] task = "create-list" +[[tool.pixi.tasks.subset-small.depends-on]] +task = "create-scatter" + [tool.pixi.tasks.increment-version] cmd = "python scripts/increment_version.py" diff --git a/src/neuview/cli.py b/src/neuview/cli.py index f5c1682..3de64b1 100644 --- a/src/neuview/cli.py +++ b/src/neuview/cli.py @@ -10,6 +10,7 @@ import sys from typing import Optional import logging +from pathlib import Path from .commands import ( GeneratePageCommand, @@ -17,6 +18,7 @@ FillQueueCommand, PopCommand, CreateListCommand, + CreateScatterCommand, ) from .services import ServiceContainer from .services.neuron_discovery_service import InspectNeuronTypeCommand @@ -375,5 +377,30 @@ async def run_create_list(): asyncio.run(run_create_list()) +@main.command("create-scatter") +@click.pass_context +def create_scatter(ctx): + """Generate three SVG scatterplots of spatial metrics for optic lobe types.""" + services = setup_services(ctx.obj["config_path"], ctx.obj["verbose"]) + + async def run_create_scatter(): + + await services.scatter_service.create_scatterplots() + + # Print the three scatterplot files that should have been created + scfg = services.scatter_service.scatter_config + scatter_dir = Path(scfg.scatter_dir) + fname = scfg.scatter_fname + + for region in ("ME", "LO", "LOP"): + file_path = scatter_dir / f"{region}_{fname}" + if file_path.exists(): + click.echo(f"✅ Created: {file_path}") + else: + click.echo(f"⚠️ Expected but not found: {file_path}", err=True) + + asyncio.run(run_create_scatter()) + + if __name__ == "__main__": main() diff --git a/src/neuview/commands.py b/src/neuview/commands.py index f058615..cf0e0f6 100644 --- a/src/neuview/commands.py +++ b/src/neuview/commands.py @@ -94,6 +94,17 @@ def __post_init__(self): self.requested_at = datetime.now() +@dataclass +class CreateScatterCommand: + """Command to create svg scatterplots of spatial metrics.""" + + requested_at: Optional[datetime] = None + + def __post_init__(self): + if self.requested_at is None: + self.requested_at = datetime.now() + + @dataclass class DatasetInfo: """Information about the dataset.""" diff --git a/src/neuview/services/cache_service.py b/src/neuview/services/cache_service.py index 1002866..64658f9 100644 --- a/src/neuview/services/cache_service.py +++ b/src/neuview/services/cache_service.py @@ -171,6 +171,9 @@ async def save_neuron_data_to_cache( # Filter ROIs by threshold and clean names (same logic as IndexService) threshold = self.threshold_service.get_roi_filtering_threshold() + threshold_high = self.threshold_service.get_roi_filtering_threshold( + profile_name="roi_filtering_strict" + ) cleaned_roi_summary = [] seen_names = set() @@ -189,8 +192,7 @@ async def save_neuron_data_to_cache( ) if clean_name not in seen_names: seen_names.add(clean_name) - cleaned_roi_summary.append( - { + entry = { "name": clean_name, "pre_percentage": roi["pre_percentage"], "post_percentage": roi["post_percentage"], @@ -198,7 +200,14 @@ async def save_neuron_data_to_cache( "pre_synapses": roi["pre"], "post_synapses": roi["post"], } - ) + if clean_name in ["ME", "LO", "LOP"]: + if (roi["pre_percentage"] >= threshold_high and (roi["pre"]+roi["post"])>50 + or roi["post_percentage"] >= threshold_high and (roi["pre"]+roi["post"])>50): + include_in_scatter = 1 + else: + include_in_scatter = 0 + entry["incl_scatter"] = include_in_scatter + cleaned_roi_summary.append(entry) roi_summary = cleaned_roi_summary @@ -228,8 +237,9 @@ async def save_neuron_data_to_cache( parent_rois = sorted(list(parent_rois_set)) - # Calculate spatial metrics for columns if column ROIs are present - # Currently these are calculated from both L and R instances + # Calculate spatial metrics for columns if column ROIs are present. + # These metrics are calculated using synapses within the ROI from + # both L and R instances. for side in ["L", "R"]: for region in ["ME", "LO", "LOP"]: str_pattern = f"{region}_{side}_col_" diff --git a/src/neuview/services/index_service.py b/src/neuview/services/index_service.py index 4732977..2bf529a 100644 --- a/src/neuview/services/index_service.py +++ b/src/neuview/services/index_service.py @@ -56,17 +56,17 @@ async def create_index(self, command) -> Result[str, str]: return Err(f"Output directory does not exist: {output_dir}") # Discover neuron types from cache or file scanning - neuron_types, scan_time = self._discover_neuron_types(output_dir) + neuron_types, scan_time = self.discover_neuron_types(output_dir) if not neuron_types: return Err("No neuron type HTML files found in output directory") # Initialize connector if needed for database lookups - connector = await self._initialize_connector_if_needed( + connector = await self.initialize_connector_if_needed( neuron_types, output_dir ) # Correct neuron names (convert filenames back to original names) - corrected_neuron_types, cache_performance = self._correct_neuron_names( + corrected_neuron_types, cache_performance = self.correct_neuron_names( neuron_types, connector ) @@ -95,7 +95,7 @@ async def create_index(self, command) -> Result[str, str]: logger.error(f"Failed to create optimized index: {e}") return Err(f"Failed to create index: {str(e)}") - def _discover_neuron_types(self, output_dir: Path) -> tuple: + def discover_neuron_types(self, output_dir: Path) -> tuple: """Discover neuron types from queue file to ensure all are included.""" neuron_types = defaultdict(set) @@ -209,7 +209,7 @@ def _discover_neuron_types(self, output_dir: Path) -> tuple: ) return neuron_types, 0.0 - async def _initialize_connector_if_needed(self, neuron_types, output_dir): + async def initialize_connector_if_needed(self, neuron_types, output_dir): """Initialize database connector only if needed for lookups.""" # Pre-load ROI hierarchy from cache (no database queries if cached) roi_hierarchy_loaded = False @@ -278,7 +278,7 @@ async def _initialize_connector_if_needed(self, neuron_types, output_dir): return connector - def _correct_neuron_names(self, neuron_types, connector): + def correct_neuron_names(self, neuron_types, connector): """Correct neuron names by converting filenames back to original names.""" cached_data_lazy = ( self.cache_manager.get_cached_data_lazy() if self.cache_manager else None diff --git a/src/neuview/services/scatterplot_service.py b/src/neuview/services/scatterplot_service.py new file mode 100644 index 0000000..29b2e79 --- /dev/null +++ b/src/neuview/services/scatterplot_service.py @@ -0,0 +1,525 @@ +""" +Interactive Scatterplot Service + +Simplified service that coordinates other specialized services to create +interactive scatterplot page with plots related to the spatial metrics per type. +""" + +import logging +from pathlib import Path +import pandas as pd +from math import ceil, floor, log10, isfinite +from ..result import Err +from jinja2 import Environment, FileSystemLoader, Template +from ..config import Config +from ..utils import get_templates_dir + +from .index_service import IndexService +from ..visualization.rendering.rendering_config import ScatterConfig + +logger = logging.getLogger(__name__) + + +class ScatterplotService: + """Service for creating scatterplots with markers for all available neuron types.""" + + def __init__(self): + + self.config = Config.load("config.yaml") + self.scatter_config = ScatterConfig() + + if isinstance(self.scatter_config.scatter_dir, str): + self.plot_output_dir = self.scatter_config.scatter_dir + plot_dir = Path(self.plot_output_dir) + plot_dir.mkdir(parents=True, exist_ok=True) + + # Initialize cache manager for neuron type data + self.cache_manager = None + if ( + self.config + and hasattr(self.config, "output") + and hasattr(self.config.output, "directory") + ): + self.output_dir = self.config.output.directory + from ..cache import create_cache_manager + + self.cache_manager = create_cache_manager(self.output_dir) + + async def create_scatterplots(self): + """Create scatterplots of spatial metrics for optic lobe neuron types.""" + + try: + page_generator = ( + None # or a tiny stub object if your constructors assume methods exist + ) + index = IndexService(self.config, page_generator) + + # 3) Use the instance properly + neuron_types, _ = index.discover_neuron_types(Path(self.output_dir)) + if not neuron_types: + return Err("No neuron type HTML files found in output directory") + + # Initialize connector if needed for database lookups + connector = await index.initialize_connector_if_needed( + neuron_types, self.output_dir + ) + + # Correct neuron names (convert filenames back to original names) + corrected_neuron_types, _ = index.correct_neuron_names( + neuron_types, connector + ) + + # Generate scatterplot data for corrected neuron types + plot_data = self._extract_plot_data(corrected_neuron_types) + + # Within loop for side + side = "both" + + for region in ["ME", "LO", "LOP"]: + points = self._extract_points(plot_data, side=side, region=region) + if not points: + raise SystemExit( + f"No points found: ensure values exist for types within {side} {region}." + ) + + ctx = self._prepare( + self.scatter_config, points, side=side, region=region + ) + + template_dir = get_templates_dir() + template_env = Environment(loader=FileSystemLoader(template_dir)) + template = template_env.get_template(self.scatter_config.template_name) + svg_content = template.render(**ctx) + + # Write the index file + svg_path = f"{self.plot_output_dir}/{region}_{self.scatter_config.scatter_fname}" + + # If saving the images - check if they exist first + with open(svg_path, "w", encoding="utf-8") as f: + f.write(svg_content) + + return + + except Exception as e: + logger.error(f"Failed to create scatterplots: {e}") + return Err(f"Failed to create scatterplots: {str(e)}") + + def _extract_plot_data(self, neuron_types): + """Generate plot data from list of neuron types.""" + + cached_data_lazy = ( + self.cache_manager.get_cached_data_lazy() if self.cache_manager else None + ) + + plot_data, cached_count, missing_cache_count = [], 0, 0 + + names = neuron_types.keys() if isinstance(neuron_types, dict) else neuron_types + + for neuron_name in names: + + cache_data = ( + cached_data_lazy.get(neuron_name) + if cached_data_lazy is not None + else None + ) + + entry = { + "name": neuron_name, + "total_count": 0, + "left_count": 0, + "right_count": 0, + "middle_count": 0, + "undefined_count": 0, + "has_undefined": False, + "spatial_metrics": {}, + } + + if cache_data is not None: + # ---- counts ---- + if ( + hasattr(cache_data, "total_count") + and cache_data.total_count is not None + ): + entry["total_count"] = cache_data.total_count + + ssc = {} + if ( + hasattr(cache_data, "soma_side_counts") + and cache_data.soma_side_counts + ): + ssc = cache_data.soma_side_counts + + if isinstance(ssc, dict): + if "left" in ssc and ssc["left"] is not None: + entry["left_count"] = ssc["left"] + if "right" in ssc and ssc["right"] is not None: + entry["right_count"] = ssc["right"] + if "middle" in ssc and ssc["middle"] is not None: + entry["middle_count"] = ssc["middle"] + + undefined_sum = 0 + if "unknown" in ssc and ssc["unknown"] is not None: + undefined_sum += ssc["unknown"] + if "undefined" in ssc and ssc["undefined"] is not None: + undefined_sum += ssc["undefined"] + entry["undefined_count"] = undefined_sum + entry["has_undefined"] = undefined_sum > 0 + + # ---- spatial metrics (raw) ---- + sm = {} + if ( + hasattr(cache_data, "spatial_metrics") + and cache_data.spatial_metrics + ): + sm = cache_data.spatial_metrics + + # ---- roi_summary source (for incl_scatter) ---- + roi_source = {} + if hasattr(cache_data, "roi_summary") and cache_data.roi_summary: + rs = cache_data.roi_summary + if isinstance(rs, dict): + roi_source = rs + elif isinstance(rs, list): + # turn [{'name': 'ME', ...}, ...] into {'ME': {...}, ...} if applicable + tmp = {} + for item in rs: + if isinstance(item, dict) and "name" in item: + nm = item["name"] + tmp[nm] = item + roi_source = tmp + + # ---- write incl_scatter into each side/region dict ---- + # If you only want it under "both", change sides_to_update = ("both",) + sides_to_update = ("both", "L", "R") + for region in ("ME", "LO", "LOP"): + incl_val = None + if isinstance(roi_source, dict) and region in roi_source: + region_src = roi_source[region] + if ( + isinstance(region_src, dict) + and "incl_scatter" in region_src + ): + incl_val = region_src["incl_scatter"] + + for side_key in sides_to_update: + if isinstance(sm, dict): + if side_key not in sm or sm[side_key] is None: + sm[side_key] = {} + side_dict = sm[side_key] + if region not in side_dict or side_dict[region] is None: + side_dict[region] = {} + region_dict = side_dict[region] + if isinstance(region_dict, dict): + region_dict["incl_scatter"] = incl_val + + # finally attach sm + entry["spatial_metrics"] = sm + + logger.debug(f"Used cached data for {neuron_name}") + cached_count += 1 + else: + logger.debug(f"No cached data available for {neuron_name}") + missing_cache_count += 1 + + plot_data.append(entry) + + plot_data.sort(key=lambda x: x["name"]) + + if missing_cache_count > 0: + logger.warning( + f"Plot data generation completed: {len(plot_data)} entries, " + f"{cached_count} with cache, {missing_cache_count} missing cache. " + f"Run 'quickpage generate' to populate cache." + ) + else: + logger.info( + f"Plot data generation completed: {len(plot_data)} entries, all with cached data" + ) + + return plot_data + + def _extract_points(self, plot_data, side, region): + """ + Collate the data points required to make the spatial + metric scatterplots. + """ + pts = [] + for rec in plot_data: + + incl = ( + rec.get("spatial_metrics", {}) + .get(side, {}) + .get(region, {}) + .get("incl_scatter") + ) + + # Only include types that have "incl_scatter" == 1. + # Pass threshold for syn % and syn #. + if incl == 1: + name = rec.get("name", "unknown") + x = rec.get("total_count") + y = ( + rec.get("spatial_metrics", {}) + .get(side, {}) + .get(region, {}) + .get("cell_size") + ) + c = ( + rec.get("spatial_metrics", {}) + .get(side, {}) + .get(region, {}) + .get("coverage") + ) + col_count = ( + rec.get("spatial_metrics", {}) + .get(side, {}) + .get(region, {}) + .get("cols_innervated") + ) + + # require x,y positive for log scales + if x is None or y is None or c is None: + continue + try: + x = float(x) + y = float(y) + c = float(c) + except Exception: + continue + if x <= 0 or y <= 0: + continue + + # Optional data quality filter from prior script + if col_count is not None: + try: + if float(col_count) <= 9: + continue + except Exception: + pass + + pts.append( + { + "name": name, + "x": x, + "y": y, + "coverage": c, + "col_count": ( + float(col_count) if col_count is not None else None + ), + } + ) + return pts + + def _prepare( + self, + config, + points, + side=None, + region=None, + ): + """Compute pixel positions for an SVG scatter plot (color by coverage).""" + + # Range depends on values of "points" + xmin = min(p["x"] for p in points) + xmax = max(p["x"] for p in points) + ymin = min(p["y"] for p in points) + ymax = max(p["y"] for p in points) + + # expand bounds slightly so dots don't sit on the frame (keep >0) + pad_x = xmin * 0.05 + pad_y = ymin * 0.08 + xmin = max(1e-12, xmin - pad_x) + ymin = max(1e-12, ymin - pad_y) + xmax *= 1.05 + ymax *= 1.08 + + lxmin, lxmax = log10(xmin), log10(xmax) + lymin, lymax = log10(ymin), log10(ymax) + dx = lxmax - lxmin + dy = lymax - lymin + + if dx > dy: + # expand Y range to match X span (around geometric center) + cy = (lymin + lymax) / 2.0 + lymin, lymax = cy - dx / 2.0, cy + dx / 2.0 + ymin, ymax = 10**lymin, 10**lymax + elif dy > dx: + # expand X range to match Y span (around geometric center) + cx = (lxmin + lxmax) / 2.0 + lxmin, lxmax = cx - dy / 2.0, cx + dy / 2.0 + xmin, xmax = 10**lxmin, 10**lxmax + + # coverage color scaling with 98th percentile clipping + coverages = [p["coverage"] for p in points] + cmin = min(coverages) + cmax = self._percentile(coverages, 98.0) or max(coverages) + crng = (cmax - cmin) if isfinite(cmax - cmin) and (cmax - cmin) > 0 else 1.0 + + # Inner drawing range to create a visible gap to axes + inner_x0, inner_x1 = config.axis_gap_px, max( + config.axis_gap_px, config.plot_w - config.axis_gap_px + ) + inner_y0, inner_y1 = ( + config.plot_h - config.axis_gap_px, + config.axis_gap_px, + ) # inverted + + def sx(v): + return self._scale_log10(v, xmin, xmax, inner_x0, inner_x1) + + def sy(v): + return self._scale_log10(v, ymin, ymax, inner_y0, inner_y1) + + for p in points: + p["sx"] = sx(p["x"]) + p["sy"] = sy(p["y"]) # SVG y grows downward + # color by coverage (clipped at cmax) + t_raw = (min(p["coverage"], cmax) - cmin) / crng + t = max(0.0, min(1.0, t_raw)) + p["color"] = self._cov_to_rgb(t) + p["r"] = config.marker_size + p["line_width"] = config.marker_line_width + p["type"] = f"{p['name']}" + p["tooltip"] = ( + f"{p['name']} - {region}({side}):\n" + f" {int(p['x'])} cells:\n" + f" cell_size: {p['y']:.2f}\n" + f" coverage: {p['coverage']:.2f}" + ) + + # Reference (anti-diagonal) guide lines under points + col_counts = [p["col_count"] for p in points if p.get("col_count")] + if col_counts: + n_cols_region = max(col_counts) + else: + n_cols_region = 10 ** ((log10(xmin * ymin) + log10(xmax * ymax)) / 4) + + # Add guide lines to scatter plot + multipliers = [0.2, 0.5, 1, 2, 5] + + def guide_width(m): + if m < 0.5 or m > 2: + return 0.25 + elif m != 1: + return 0.4 + else: + return 0.8 + + guide_lines = [] + for m in multipliers: + k = n_cols_region * m # x*y = k + x0_clip = max(xmin, k / ymax) + x1_clip = min(xmax, k / ymin) + if x0_clip >= x1_clip: + continue # out of view + y0 = k / x0_clip + y1 = k / x1_clip + guide_lines.append( + { + "x1": sx(x0_clip), + "y1": sy(y0), + "x2": sx(x1_clip), + "y2": sy(y1), + "w": guide_width(m), + } + ) + + # Precompute pixel tick positions for Jinja (avoid math inside template) + def log_pos_x(t): + return self._scale_log10(t, xmin, xmax, inner_x0, inner_x1) + + def log_pos_y(t): + return self._scale_log10(t, ymin, ymax, inner_y0, inner_y1) + + xtick_data = [{"t": t, "px": log_pos_x(t)} for t in config.xticks] + ytick_data = [{"t": t, "py": log_pos_y(t)} for t in config.yticks] + + ctx = self._prepare_template_variables( + points, guide_lines, config, region, xtick_data, ytick_data, cmin, cmax + ) + + return ctx + + def _prepare_template_variables( + self, points, guide_lines, config, region, xtick_data, ytick_data, cmin, cmax + ): + """Prepare variables for template rendering. + Args: + points: Processed scatter points + guide_lines: Points to draw plot guidelines + config: Scatter configuration + region: Optic lobe region for which to generate plot. ME, LO or LOP. + Returns: + Dictionary of template variables + """ + template_vars = { + "width": config.width, + "height": config.height, + "margin_top": config.margin_top, + "margin_right": config.margin_right, + "margin_bottom": config.margin_bottom, + "margin_left": config.margin_left, + "plot_w": config.plot_w, + "plot_h": config.plot_h, + "cmin": cmin, + "cmax": cmax, + "points": points, + "xtick_data": xtick_data, + "ytick_data": ytick_data, + "guide_lines": guide_lines, + "title": region, + "xlabel": config.xlabel, + "ylabel": config.ylabel, + "legend_label": config.legend_label, + "legend_w": config.legend_w, + } + + return template_vars + + def _log_ticks(self, vmin, vmax): + """Ticks for a log10 axis between (vmin, vmax), inclusive.""" + if vmin <= 0 or vmax <= 0 or vmin >= vmax: + return [] + lo = floor(log10(vmin)) + hi = ceil(log10(vmax)) + return [10**e for e in range(lo, hi + 1)] + + def _scale_log10(self, v, vmin, vmax, a, b): + """Log10 scaling to pixels.""" + lv = log10(v) + lmin = log10(vmin) + lmax = log10(vmax) + if lmax == lmin: + return (a + b) / 2.0 + return a + (lv - lmin) * (b - a) / (lmax - lmin) + + def _lerp(self, a, b, t): + return a + (b - a) * t + + def _cov_to_rgb(self, t): + """ + Map t in [0,1] to a white→dark red gradient. + start = white (255,255,255), end = dark red (~180,0,0) + """ + r0, g0, b0 = 255, 255, 255 + r1, g1, b1 = 180, 0, 0 + r = int(round(self._lerp(r0, r1, t))) + g = int(round(self._lerp(g0, g1, t))) + b = int(round(self._lerp(b0, b1, t))) + return f"rgb({r},{g},{b})" + + def _percentile(self, values, p): + """ + p in [0, 100]. Returns None on no finite data. + Uses pandas.Series.quantile with the right keyword for the installed version. + """ + s = pd.Series(values, dtype="float64").dropna() + if s.empty: + return None + + q = p / 100 + # Prefer the 2.x API if available; fall back to 1.5.x + try: + return float(s.quantile(q, method="linear")) # pandas 2.x + except TypeError: + return float(s.quantile(q, interpolation="linear")) # pandas 1.5.x diff --git a/src/neuview/services/service_container.py b/src/neuview/services/service_container.py index df49aa2..3ce2566 100644 --- a/src/neuview/services/service_container.py +++ b/src/neuview/services/service_container.py @@ -40,6 +40,7 @@ def __init__(self, config, copy_mode: str = "check_exists"): self._cache_service = None self._soma_detection_service = None self._neuron_statistics_service = None + self._scatter_service = None # Phase 3 managers self._template_manager = None @@ -300,6 +301,17 @@ def create(): return self._get_or_create_service("index_service", create) + @property + def scatter_service(self): + """Get or create scatterplot service.""" + + def create(): + from .scatterplot_service import ScatterplotService + + return ScatterplotService() + + return self._get_or_create_service("scatter_service", create) + def cleanup(self): """Clean up services and resources.""" # Close any connections or clean up resources diff --git a/src/neuview/visualization/rendering/rendering_config.py b/src/neuview/visualization/rendering/rendering_config.py index 9f18356..007a4c7 100644 --- a/src/neuview/visualization/rendering/rendering_config.py +++ b/src/neuview/visualization/rendering/rendering_config.py @@ -173,3 +173,90 @@ def to_dict(self) -> Dict[str, Any]: "thresholds": self.thresholds, "layer_thresholds": self.layer_thresholds, } + + +@dataclass +class ScatterConfig: + """ + Configuration for scatterplot rendering. + """ + + # Output configuration + output_format: str = "svg" + save_to_files: bool = True + + # File management + scatter_dir: Optional[Path] = "output/scatter" + scatter_fname = "scatter.svg" + + # Layout configuration + margins: list = (60, 72, 64, 72) + axis_gap_px: int = 10 + + # Marker features + marker_size: int = 4 + marker_line_width: float = 0.5 + + # SVG-specific configuration + template_name: str = "scatterplot.svg.jinja" + + # Content configuration + title: str = "" + xlabel: str = "Population size (no. cells per type)" + ylabel: str = "Cell size (no. columns per cell)" + legend_label: str = "Coverage factor (cells per column)" + + # Data configuration + min_max_data: Optional[Dict[str, Any]] = None + thresholds: Optional[Dict[str, Any]] = None + + top, right, bottom, left = margins + margin_top = top + margin_right = right + margin_bottom = bottom + margin_left = left + width = 460 + height = 460 + plot_w = width - left - right + plot_h = height - top - bottom + + side_px = min(plot_w, plot_h) + plot_w = side_px + plot_h = side_px + + xticks = [1, 10, 100, 1000] + yticks = [1, 10, 100, 1000] + + legend_w = 12 + + def get_template_path(self) -> Optional[Path]: + """Get the full path to the template file.""" + # Templates are now loaded from the built-in templates directory + return get_templates_dir() / self.template_name + + def to_dict(self) -> Dict[str, Any]: + """Convert layout config to dictionary for template rendering.""" + return { + "width": self.width, + "height": self.height, + "xticks": self.xticks, + "yticks": self.yticks, + "marker_size": self.marker_size, + "margin_top": self.top, + "margin_right": self.right, + "margin_bottom": self.bottom, + "margin_left": self.left, + "legend_w": self.legend_w, + "xlabel": self.xlabel, + "ylabel": self.ylabel, + "legend_label": self.legend_label, + "axis_gap_px": self.axis_gap_px, + "plot_h": self.plot_h, + "plot_w": self.plot_w, + } + + def copy(self, **overrides) -> "ScatterConfig": + """Create a copy of this config with optional overrides.""" + from dataclasses import replace + + return replace(self, **overrides) diff --git a/static/js/neuron-page.js b/static/js/neuron-page.js index 0adb10f..f0355c5 100644 --- a/static/js/neuron-page.js +++ b/static/js/neuron-page.js @@ -565,5 +565,141 @@ function initializeAllTooltips() { }, 100); } +function highlightInSvgDocument(doc, neuronType) { + const needleName = String(neuronType || "").trim().toLowerCase(); + if (!needleName) return 0; + console.log(`Needle: ${needleName}.`); + + let candidates = Array.from(doc.querySelectorAll('g.marker')); + if (candidates.length === 0) { + candidates = Array.from(doc.querySelectorAll('circle.dot')) + .map(c => c.closest('g.marker') || c.parentNode) + .filter(Boolean); + } + if (candidates.length === 0) return 0; + const seenSvgs = new WeakSet(); + let hitCount = 0; + + for (const g of candidates) { + const svgEl = g.ownerSVGElement || doc.querySelector('svg'); + if (!svgEl || seenSvgs.has(svgEl)) continue; + + const circle = g.querySelector('circle') || g; + if (!circle) continue; + + const haystack = (circle.getAttribute('data-type') || '').toLowerCase(); + if (!haystack) continue; + console.log(`haystack: ${haystack}.`); + + // Require an exact, case-insensitive name match + if (haystack !== needleName) continue; + + const win = doc.defaultView; + const rect = circle.getBoundingClientRect(); + const evtLike = { + currentTarget: g, + clientX: rect.left + rect.width / 2, + clientY: rect.top + rect.height / 2 + }; + + let usedShowTip = false; + try { + if (win && typeof win.showTip === 'function') { + win.showTip(evtLike); + usedShowTip = true; + } + } catch (_) {} + + if (!usedShowTip) { + try { + g.parentNode && g.parentNode.appendChild(g); + + const baseR = parseFloat(circle.getAttribute('data-base-r') || '4'); + const baseSW = parseFloat(circle.getAttribute('data-base-sw') || + (doc.defaultView?.getComputedStyle(circle).strokeWidth || '0.5')); + circle.setAttribute('r', String(baseR * 3)); + circle.setAttribute('stroke-width', String(baseSW * 3)); + + const tip = doc.getElementById('tooltip'); + const tg = doc.getElementById('tooltip-text-group'); + const bg = doc.getElementById('tooltip-bg'); + if (tip && tg && bg) { + while (tg.firstChild) tg.removeChild(tg.firstChild); + const lines = (circle.getAttribute('data-title') || '') + .split('\n').filter(s => s.trim().length); + const pad = 6, lh = 14; + lines.forEach((line, i) => { + const t = doc.createElementNS('http://www.w3.org/2000/svg', 'text'); + t.setAttribute('x', pad); + t.setAttribute('y', pad + lh + i * lh); + t.setAttribute('class', 'tooltip-text'); + t.textContent = line; + tg.appendChild(t); + }); + const boxW = 350; + const boxH = lines.length * lh + pad * 2; + bg.setAttribute('width', boxW); + bg.setAttribute('height', boxH); + + const svgRect = svgEl.getBoundingClientRect(); + let x = rect.left - svgRect.left + 10; + let y = rect.top - svgRect.top - boxH - 10; + const vbW = svgEl.viewBox?.baseVal?.width || svgRect.width; + if (x + boxW > vbW) x = vbW - boxW - 5; + if (y < 0) y = rect.top - svgRect.top + 10; + + tip.setAttribute('transform', `translate(${x},${y})`); + tip.setAttribute('opacity', '1'); + + const tEl = g.querySelector('title'); + if (tEl) tEl.textContent = ''; + } + } catch (_) {} + } + + seenSvgs.add(svgEl); + hitCount++; + } + + return hitCount; +} + +// Discover SVGs in the page and highlight all of them. +function highlightNeuronAllPlots(neuronType) { + const needle = String(neuronType || '').trim(); + if (!needle) return; + + let total = 0; + + const objects = Array.from(document.querySelectorAll('object[type="image/svg+xml"]')); + for (const obj of objects) { + const run = () => { + try { + const doc = obj.contentDocument; + if (doc) { + const added = highlightInSvgDocument(doc, needle); + total += added; + if (added === 0) { + // helpful debug + console.warn('No match in SVG:', obj.data); + } + } + } catch (e) { + console.warn('Cannot access (likely cross-origin):', obj.data); + } + }; + if (obj.contentDocument && obj.contentDocument.readyState !== 'loading') { + run(); + } else { + obj.addEventListener('load', run, { once: true }); + } + } + + setTimeout(() => { + console.log(`Highlighted ${total} plot(s) for neuron "${needle}".`); + }, 0); +} + + // Initialize responsive navigation initializeResponsiveNavigation(); diff --git a/templates/neuron_page.html.jinja b/templates/neuron_page.html.jinja index 2d7418d..ce368dd 100644 --- a/templates/neuron_page.html.jinja +++ b/templates/neuron_page.html.jinja @@ -17,7 +17,6 @@ {% include "sections/layer_analysis.html.jinja" %} - {% include "sections/eyemaps.html.jinja" %} {% include "sections/neuroglancer.html.jinja" %} diff --git a/templates/scatterplot.svg.jinja b/templates/scatterplot.svg.jinja new file mode 100644 index 0000000..526d1f6 --- /dev/null +++ b/templates/scatterplot.svg.jinja @@ -0,0 +1,225 @@ + + + + + + + + + + + + + + + + {% for tick in xtick_data %} + + + {{ tick.t }} + + {% endfor %} + + + {% for tick in ytick_data %} + + + {{ tick.t }} + + {% endfor %} + + + {% for g in guide_lines %} + + {% endfor %} + + + {% for p in points %} + + + {{ p.tooltip }} + + + {% endfor %} + + + +{{ xlabel }} +{{ ylabel }} + + + {{ title }} + + + + + + + + + + + + {{legend_label}} + + + + >{{ '%.0f' % cmax }} + {{ '%.0f' % cmin }} + + + + + + + + \ No newline at end of file diff --git a/templates/sections/eyemaps.html.jinja b/templates/sections/eyemaps.html.jinja index 3da29bf..faec5f7 100644 --- a/templates/sections/eyemaps.html.jinja +++ b/templates/sections/eyemaps.html.jinja @@ -7,6 +7,7 @@ {%- if soma_side == 'combined' -%} {% include "sections/eyemaps_both.html.jinja" %} + {% include "sections/scatterplots.html.jinja" %} {%- else -%} {% include "sections/eyemaps_single.html.jinja" %} {%- endif -%} diff --git a/templates/sections/neuron_page_scripts.html.jinja b/templates/sections/neuron_page_scripts.html.jinja index df1681b..b085868 100644 --- a/templates/sections/neuron_page_scripts.html.jinja +++ b/templates/sections/neuron_page_scripts.html.jinja @@ -2,7 +2,7 @@ {% include "sections/global_scripts.html.jinja" %} {# -- Load external static JavaScript functions -- #} - +