Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/nbla/cuda/cudnn/function/max_pooling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ class MaxPoolingCudaCudnn
: BasePoolingCudaCudnn<typename MaxPooling<T>::base_pooling_type>(
ctx, kernel, stride, ignore_border, pad, channel_last) {}
string name() override { return "MaxPoolingCudaCudnn"; }
// NOTE: With an unknown reason, creating this class derived from
// `BasePoolingCudaCudnn<MaxPooling<T>>` gave a compile error. So I decided to
// derive it from BasePooling class which seems to succeed, but the problem is
// that it doesn't implement `copy()` function. I copy & paste the copy
// function found in the MaxPooling class although it's ugly.
shared_ptr<Function> copy() const override {
return create_MaxPooling(this->ctx_, this->kernel_, this->stride_,
this->ignore_border_, this->pad_,
this->channel_last_);
}
cudnnPoolingMode_t mode() const override { return CUDNN_POOLING_MAX; }
};
}
Expand Down