Saving uboost BDT with tf/keras base estimators #63
Poor picklability of keras is a long-known issue (you can google keras together with the same error). You may be lucky if some of the variables are passed in, e.g., a lambda rather than through calls, but otherwise I'm not sure there is a simple solution.
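The failure can be reproduced without keras: any object that holds a `threading.RLock` (as something inside the keras model evidently does, judging by the `_thread.RLock` error in this issue) cannot be pickled. The classes below are hypothetical stand-ins, not keras classes. They show how `__getstate__`/`__setstate__` let an object leave such an attribute out of its pickled state. For a real keras model the dropped part would carry the trained weights, which would still have to be saved some other way, so this is a sketch of the mechanism rather than a fix.

```python
import pickle
import threading


class PlainHolder:
    """Hypothetical stand-in for an object that owns an RLock (not a keras class)."""

    def __init__(self, weights):
        self.weights = weights
        self._lock = threading.RLock()


class PicklableHolder(PlainHolder):
    """Same object, but the lock is excluded from the pickled state."""

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_lock"]  # RLock objects cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.RLock()  # create a fresh lock after unpickling


def can_pickle(obj):
    """Return True if pickle.dumps succeeds on obj."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


print(can_pickle(PlainHolder([1.0, 2.0])))  # False
restored = pickle.loads(pickle.dumps(PicklableHolder([1.0, 2.0])))
print(restored.weights)  # [1.0, 2.0]
```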
Yes, but keras of course has its own save_model functionality. So do you think it is simply not possible to save a uBoost model that is based on a keras model?
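It all comes down to the picklability of the items. Option 1: ask the keras maintainers why your model is not picklable. Option 2: save the keras estimators separately with keras tools and pickle the rest of the classifier:

```python
import joblib

# saving: take the keras models out of the classifier before pickling it
estimators = clf.estimators_  # list of keras models
# TODO save estimators somehow using keras tools
clf.estimators_ = None
with open('uboost.pkl', 'wb') as f:
    joblib.dump(clf, f)
clf.estimators_ = estimators  # put them back so clf stays usable in this session

# loading
with open('uboost.pkl', 'rb') as f:
    clf = joblib.load(f)
estimators = ...  # TODO load estimators using keras tools
clf.estimators_ = estimators
```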
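The option 2 steps can be wrapped in a pair of helpers so the classifier gets its estimators back even if pickling fails. This is a sketch under assumptions: `save_without_estimators`, `load_without_estimators` and the `save_estimator`/`load_estimator` hooks are hypothetical names, plain `pickle` stands in for `joblib` (`joblib.dump`/`joblib.load` could be used the same way), and the hooks are where the keras saving would go. For a `KerasClassifier` wrapper, one guess is a save hook calling `est.model.save(file_path)` and a load hook calling `tf.keras.models.load_model(file_path)`, but re-wrapping the loaded model in your own `KerasClassifier` subclass is not shown and is untested.

```python
import pickle
from pathlib import Path


def save_without_estimators(clf, path, save_estimator):
    """Pickle clf with clf.estimators_ stored separately via save_estimator(est, file_path)."""
    path = Path(path)
    estimators = clf.estimators_
    est_names = []
    for i, est in enumerate(estimators):
        est_path = path.with_name(f"{path.stem}_estimator{i}")
        save_estimator(est, str(est_path))  # hypothetical hook, e.g. keras model saving
        est_names.append(est_path.name)  # store names relative to the pickle file
    clf.estimators_ = None  # detach the unpicklable estimators
    try:
        with open(path, "wb") as f:
            pickle.dump({"clf": clf, "estimator_files": est_names}, f)
    finally:
        clf.estimators_ = estimators  # always restore the in-memory classifier


def load_without_estimators(path, load_estimator):
    """Inverse of save_without_estimators; load_estimator(file_path) returns one estimator."""
    path = Path(path)
    with open(path, "rb") as f:
        saved = pickle.load(f)
    clf = saved["clf"]
    clf.estimators_ = [load_estimator(str(path.with_name(name)))
                       for name in saved["estimator_files"]]
    return clf
```

Storing only the file names keeps the saved files relocatable, as long as they stay in the same directory as the pickle.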
There is option 3 as well: find a truly sklearn-compatible NN package =)
OK, I will try to implement option 2, thanks!
Hi,
I am trying to use a uBoost BDT to achieve uniform signal efficiency. My base estimator is a Keras model (TensorFlow 2.2), which I have written as a scikit-learn BaseEstimator subclass using tensorflow.keras.wrappers.scikit_learn.KerasClassifier. Training seems to work fine, but I get an error when I try to save the uBoost classifier with pickle/joblib. The error is
TypeError: can't pickle _thread.RLock objects
(full traceback at the bottom; it is mostly a long chain of calls into pickle)
From searching, it seems the error usually has to do with the way TensorFlow is run, but I'm only creating a simple model and fitting it, and all the session handling should be taken care of in this version of tf/keras. Maybe this answer is related: keras-team/keras#8343 (comment)
i.e. perhaps something called on the model leaves an unserializable tensor object behind? As I am using the BDT rather than the classifier, I assume it is not related to any parallel processes either?
Please let me know if you know what is causing the issue or if there is some way I can work around it.
Thanks!
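One way to check the guess that something in the model leaves an unserializable object behind is to try pickling each attribute separately and report which ones fail. `find_unpicklable` is a hypothetical helper, a rough stdlib-only debugging sketch: it only descends into `__dict__` attributes, dicts, lists and tuples, and only up to `max_depth` levels.

```python
import pickle


def find_unpicklable(obj, path="obj", max_depth=3):
    """Return paths of the innermost attributes/items of obj that fail to pickle."""
    try:
        pickle.dumps(obj)
        return []  # this object pickles fine
    except Exception:
        pass  # TypeError, PicklingError, AttributeError, ...: look inside

    children = []
    if max_depth > 0:
        if isinstance(obj, dict):
            items = [(f"{path}[{k!r}]", v) for k, v in obj.items()]
        elif isinstance(obj, (list, tuple)):
            items = [(f"{path}[{i}]", v) for i, v in enumerate(obj)]
        elif hasattr(obj, "__dict__"):
            items = [(f"{path}.{k}", v) for k, v in vars(obj).items()]
        else:
            items = []
        for child_path, child in items:
            children.extend(find_unpicklable(child, child_path, max_depth - 1))
    # If no child is to blame (or we stopped descending), blame obj itself.
    return children or [path]
```

On the fitted classifier this would be called as `find_unpicklable(clf, "clf")`. A keras model has many attributes, so the output may be long and `max_depth` may need adjusting.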