distributed: guard against pooled lag transforms in DistributedMLForecast#680
Open
simonez-tuidi wants to merge 4 commits into
Open
distributed: guard against pooled lag transforms in DistributedMLForecast#680simonez-tuidi wants to merge 4 commits into
simonez-tuidi wants to merge 4 commits into
Conversation
…cast The distributed engines shard data by the id column and run a fully independent TimeSeries.fit_transform/predict per partition. Pooled transforms (global_/groupby/partition_by) need cross-series aggregation over a population the sharding scatters across workers, so each partition would only see its own slice and silently produce incorrect results. Raise NotImplementedError at construction (before any engine is touched) when any pooled transform is configured, naming the offending features, so it fails loudly instead of returning wrong numbers. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Raise a clear
NotImplementedErrorwhen a pooled lag transform (global_,groupby, orpartition_by) is configured onDistributedMLForecast, instead of silently returning incorrect results.Problem
DistributedMLForecastshards data by the id column and runs a fully independent, deep-copiedTimeSeries.fit_transform/predicton each partition:_preprocess_partitionssetspartition = dict(by=id_col)for every backend (Spark/Dask/Ray)._preprocess_partitiondoescopy.deepcopy(base_ts)thents.fit_transform(part, ...), and_predictcallsts.predict(...)per partition.Plain/local lag transforms survive this because they're computed per series, and a series' rows always stay within one partition. Pooled transforms don't: their aggregates are defined over a population the sharding splits across workers —
global_aggregates over the whole dataset,groupby=[...]aggregates over a whole group (whose series the id-sharding scatters),partition_byrides the same cross-series parent-calendar machinery, which assumes oneTimeSeriesowns every series.Each worker only ever sees its own slice, so the feature is computed over the wrong set of series. Crucially, there was no guard — this ran to completion and returned plausible-looking but wrong numbers.
Change
In
DistributedMLForecast.__init__, right after_base_tsis built, detect pooled transforms viaTimeSeries._get_pooled_tfms()and raiseNotImplementedErrorif any are present. The guard:MLForecast.Tests
test_distributed_rejects_pooled_transforms— parametrized overglobal_,groupby, andpartition_by; asserts each raisesNotImplementedError.test_distributed_allows_local_transforms— guards against over-rejection: a plain per-seriesRollingMeanmust still construct fine.Notes
The guard rejects all pooled modes, including local-only
partition_by(which is mathematically per-series). I chose the conservative boundary —_get_pooled_tfms()is the codebase's definition of "pooled," and that mode still relies on the parent-calendar + predict-time machinery that has never been validated under per-partition execution. Allowing it later is a one-line change once there's a distributed test backing it.