Skip to content

distributed: guard against pooled lag transforms in DistributedMLForecast#680

Open
simonez-tuidi wants to merge 4 commits into
Nixtla:mainfrom
simonez-tuidi:feature/guard-distributed-pooled-transforms
Open

distributed: guard against pooled lag transforms in DistributedMLForecast#680
simonez-tuidi wants to merge 4 commits into
Nixtla:mainfrom
simonez-tuidi:feature/guard-distributed-pooled-transforms

Conversation

@simonez-tuidi

Copy link
Copy Markdown
Contributor

Summary

Raise a clear NotImplementedError when a pooled lag transform (global_, groupby, or partition_by) is configured on DistributedMLForecast, instead of silently returning incorrect results.

Problem

DistributedMLForecast shards data by the id column and runs a fully independent, deep-copied TimeSeries.fit_transform/predict on each partition:

  • _preprocess_partitions sets partition = dict(by=id_col) for every backend (Spark/Dask/Ray).
  • _preprocess_partition does copy.deepcopy(base_ts) then ts.fit_transform(part, ...), and _predict calls ts.predict(...) per partition.

Plain/local lag transforms survive this because they're computed per series, and a series' rows always stay within one partition. Pooled transforms don't: their aggregates are defined over a population the sharding splits across workers —

  • global_ aggregates over the whole dataset,
  • groupby=[...] aggregates over a whole group (whose series the id-sharding scatters),
  • partition_by rides the same cross-series parent-calendar machinery, which assumes one TimeSeries owns every series.

Each worker only ever sees its own slice, so the feature is computed over the wrong set of series. Crucially, there was no guard — this ran to completion and returned plausible-looking but wrong numbers.

Change

In DistributedMLForecast.__init__, right after _base_ts is built, detect pooled transforms via TimeSeries._get_pooled_tfms() and raise NotImplementedError if any are present. The guard:

  • fires at construction, before any engine is touched, so it fails immediately regardless of backend and needs no cluster;
  • names the offending feature(s) in the message and points users to the local (non-distributed) MLForecast.

Tests

  • test_distributed_rejects_pooled_transforms — parametrized over global_, groupby, and partition_by; asserts each raises NotImplementedError.
  • test_distributed_allows_local_transforms — guards against over-rejection: a plain per-series RollingMean must still construct fine.

Notes

The guard rejects all pooled modes, including local-only partition_by (which is mathematically per-series). I chose the conservative boundary — _get_pooled_tfms() is the codebase's definition of "pooled," and that mode still relies on the parent-calendar + predict-time machinery that has never been validated under per-partition execution. Allowing it later is a one-line change once there's a distributed test backing it.

simonez-tuidi and others added 4 commits June 30, 2026 17:20
…cast

The distributed engines shard data by the id column and run a fully
independent TimeSeries.fit_transform/predict per partition. Pooled
transforms (global_/groupby/partition_by) need cross-series aggregation
over a population the sharding scatters across workers, so each partition
would only see its own slice and silently produce incorrect results.

Raise NotImplementedError at construction (before any engine is touched)
when any pooled transform is configured, naming the offending features,
so it fails loudly instead of returning wrong numbers.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@codspeed-hq

codspeed-hq Bot commented Jun 30, 2026

Copy link
Copy Markdown

Merging this PR will not alter performance

✅ 12 untouched benchmarks


Comparing simonez-tuidi:feature/guard-distributed-pooled-transforms (d0295e7) with main (8c314af)

Open in CodSpeed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant