-
Notifications
You must be signed in to change notification settings - Fork 92
New tabpfn model #240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
New tabpfn model #240
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds new TabPFN (Prior-data Fitted Networks) model functionality to Driverless AI, introducing both an outlier detection transformer and a supervised model for classification and regression tasks. The implementation includes GPU-accelerated inference, SHAP-based explanations, and support for datasets with up to 10,000 samples.
Key changes:
- Implements TabPFN-based unsupervised outlier detection with density estimation via permutation sampling
- Adds TabPFN supervised model with support for regression, binary, and multi-class classification (including many-class scenarios via codebook-based decomposition)
- Includes SAGE/Shapley value computation for global feature importance and local explanations
Reviewed changes
Copilot reviewed 1 out of 1 changed files in this pull request and generated 20 comments.
| File | Description |
|---|---|
transformers/outliers/tabpfn_outlier.py |
Implements TabPFN-based outlier detection transformer that calculates negative log probability scores using conditional probability estimation across feature permutations |
models/algorithms/tabpfn_model.py |
Implements TabPFN supervised model with regression/classification support, many-class handling via codebook decomposition, and SHAP-based explanations using SAGE simulation |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

https://github.com/h2oai/h2oai/issues/34827
Due to the complexity of TabPFN, automation regression is skipped, manual testing results show down below