
Commit 40359ce

Automatic Layout Management (#20718)
* Automatic Layout Management
  Originally authored by Dawid Tracz <[email protected]>
* Fix clang-format
* Fix clang-format in mshadow
* Print layout name instead of a number
* Generalize NHWC target layout to other dimensions
* Change layout optimization API
* Add layout optimization tests
* Add backward check to tests
* Generalize tests to 1..3 spatial dims
* Add NWC layout to ConvolutionParams
* Enable layout optimization tests only with cuDNN

Co-authored-by: Vladimir Cherepanov <[email protected]>
1 parent f60c1d2 commit 40359ce

File tree

23 files changed: +737, -6 lines changed


3rdparty/mshadow/mshadow/base.h

Lines changed: 60 additions & 0 deletions
@@ -496,6 +496,8 @@ const int index_type_flag = DataType<lapack_index_t>::kFlag;
 
 /*! layout flag */
 enum LayoutFlag {
+  kUNKNOWN = -1,
+
   kNCHW = 0,
   kNHWC,
   kCHWN,
@@ -509,6 +511,64 @@ enum LayoutFlag {
   kCDHWN
 };
 
+inline LayoutFlag layoutFlag(std::string layoutstr) {
+  switch (layoutstr.length()) {
+    case 4:
+      if (layoutstr == "NHWC")
+        return kNHWC;
+      if (layoutstr == "NCHW")
+        return kNCHW;
+      if (layoutstr == "CHWN")
+        return kCHWN;
+      return kUNKNOWN;
+    case 3:
+      if (layoutstr == "NWC")
+        return kNWC;
+      if (layoutstr == "NCW")
+        return kNCW;
+      if (layoutstr == "CWN")
+        return kCWN;
+      return kUNKNOWN;
+    case 5:
+      if (layoutstr == "NDHWC")
+        return kNDHWC;
+      if (layoutstr == "NCDHW")
+        return kNCDHW;
+      if (layoutstr == "CDHWN")
+        return kCDHWN;
+      return kUNKNOWN;
+    default:
+      return kUNKNOWN;
+  }
+}
+
+inline std::string toString(LayoutFlag layout) {
+  switch (layout) {
+    case kUNKNOWN:
+      return "";
+    case kNCHW:
+      return "NCHW";
+    case kNHWC:
+      return "NHWC";
+    case kCHWN:
+      return "CHWN";
+    case kNCW:
+      return "NCW";
+    case kNWC:
+      return "NWC";
+    case kCWN:
+      return "CWN";
+    case kNCDHW:
+      return "NCDHW";
+    case kNDHWC:
+      return "NDHWC";
+    case kCDHWN:
+      return "CDHWN";
+    default:
+      return "";
+  }
+}
+
 template<int layout>
 struct LayoutType;

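The two helpers above map between layout strings and LayoutFlag values, which is what lets the rest of the change print a layout name instead of a number. A minimal round-trip sketch (not part of the diff; it assumes the mshadow headers are on the include path and that the helpers sit in the mshadow namespace alongside LayoutFlag):

// Sketch only: round-tripping layout names through the new mshadow helpers.
#include <cassert>
#include <iostream>
#include "mshadow/base.h"

int main() {
  // Parse a layout string into a LayoutFlag; unrecognized strings map to kUNKNOWN.
  mshadow::LayoutFlag flag = mshadow::layoutFlag("NHWC");
  assert(flag == mshadow::kNHWC);
  assert(mshadow::layoutFlag("HWCN") == mshadow::kUNKNOWN);

  // toString is the inverse for known flags, so layouts can be logged by name.
  std::cout << mshadow::toString(flag) << std::endl;  // prints "NHWC"
  return 0;
}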
3rdparty/mshadow/mshadow/tensor.h

Lines changed: 91 additions & 0 deletions
@@ -390,6 +390,97 @@ inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layou
   return dst2;
 }
 
+/*!
+ * \brief returns axes of transpose operation
+ *        that needs to be performed between src layout and dst
+ * \param src_layout input layout
+ * \param dst_layout output layout
+ * \return vector of required type describing axes of a transpose operation
+ */
+template <typename dim_t>
+inline std::vector<dim_t> getTranspAxes(const LayoutFlag src_layout, const LayoutFlag dst_layout) {
+  auto apply = [](const std::vector<dim_t>& v, const std::vector<dim_t>& op) {
+    CHECK_EQ(v.size(), op.size()) << "Layout ndims does not match";
+    std::vector<dim_t> ret(v.size());
+    for (size_t i = 0; i < v.size(); i++) {
+      ret[i] = v[op[i]];
+    }
+    return ret;
+  };
+  std::vector<dim_t> axes;
+  // transpose from `case` to ND?H?WC
+  switch (src_layout) {
+    case kUNKNOWN:
+      LOG(FATAL) << "Unknown source layout";
+      break;
+    case kNHWC:
+      axes = std::vector<dim_t>({0, 1, 2, 3});
+      break;
+    case kNCHW:
+      axes = std::vector<dim_t>({0, 2, 3, 1});
+      break;
+    case kCHWN:
+      axes = std::vector<dim_t>({3, 1, 2, 0});
+      break;
+    case kNWC:
+      axes = std::vector<dim_t>({0, 1, 2});
+      break;
+    case kNCW:
+      axes = std::vector<dim_t>({0, 2, 1});
+      break;
+    case kCWN:
+      axes = std::vector<dim_t>({2, 1, 0});
+      break;
+    case kNDHWC:
+      axes = std::vector<dim_t>({0, 1, 2, 3, 4});
+      break;
+    case kNCDHW:
+      axes = std::vector<dim_t>({0, 2, 3, 4, 1});
+      break;
+    case kCDHWN:
+      axes = std::vector<dim_t>({4, 1, 2, 3, 0});
+      break;
+    default:
+      LOG(FATAL) << "Invalid source layout " << src_layout;
+  }
+  // transpose from ND?H?WC to `case`
+  switch (dst_layout) {
+    case kUNKNOWN:
+      LOG(FATAL) << "Unknown destination layout";
+      break;
+    case kNHWC:
+      axes = apply(axes, {0, 1, 2, 3});
+      break;
+    case kNCHW:
+      axes = apply(axes, {0, 3, 1, 2});
+      break;
+    case kCHWN:
+      axes = apply(axes, {3, 1, 2, 0});
+      break;
+    case kNWC:
+      axes = apply(axes, {0, 1, 2});
+      break;
+    case kNCW:
+      axes = apply(axes, {0, 2, 1});
+      break;
+    case kCWN:
+      axes = apply(axes, {2, 1, 0});
+      break;
+    case kNDHWC:
+      axes = apply(axes, {0, 1, 2, 3, 4});
+      break;
+    case kNCDHW:
+      axes = apply(axes, {0, 4, 1, 2, 3});
+      break;
+    case kCDHWN:
+      axes = apply(axes, {4, 1, 2, 3, 0});
+      break;
+    default:
+      LOG(FATAL) << "Invalid destination layout " << dst_layout;
+  }
+  return axes;
+}
+
 /*!
  * \brief computation stream structure, used for asynchronous computations
  */

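getTranspAxes builds the permutation in two steps: from the source layout to the channels-last reference order (N[D][H]WC), then from that reference order to the destination layout, composing the two with the apply lambda. A minimal usage sketch (not part of the diff; it assumes the function is callable via the mshadow namespace and that a shape is permuted as out[i] = in[axes[i]]):

// Sketch only: axes needed to transpose an NCHW tensor into NHWC.
#include <cassert>
#include <vector>
#include "mshadow/tensor.h"

int main() {
  std::vector<int> axes = mshadow::getTranspAxes<int>(mshadow::kNCHW, mshadow::kNHWC);
  assert((axes == std::vector<int>{0, 2, 3, 1}));

  // Applying the permutation to a shape by hand: out[i] = in[axes[i]].
  std::vector<int> nchw = {8, 3, 224, 224};            // N, C, H, W
  std::vector<int> nhwc(nchw.size());
  for (size_t i = 0; i < nchw.size(); ++i)
    nhwc[i] = nchw[axes[i]];
  assert((nhwc == std::vector<int>{8, 224, 224, 3}));  // N, H, W, C
  return 0;
}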
include/mxnet/c_api.h

Lines changed: 10 additions & 0 deletions
@@ -3161,6 +3161,16 @@ MXNET_DLL int MXCUDAProfilerStart();
  */
 MXNET_DLL int MXCUDAProfilerStop();
 
+/*!
+ * \brief Turns on or off Layout Optimization
+ */
+MXNET_DLL int MXSetOptimizeLayout(bool val);
+
+/*!
+ * \brief Get current Layout Optimization status
+ */
+MXNET_DLL int MXGetOptimizeLayout(bool* val);
+
 #ifdef __cplusplus
 }
 #endif // __cplusplus

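These declarations pair with the implementations added in src/c_api/c_api.cc further down. A minimal calling sketch (not part of the diff; it assumes the usual MXNet C API convention of returning 0 on success):

// Sketch only: toggling layout optimization through the C API and reading it back.
#include <cassert>
#include <mxnet/c_api.h>

int main() {
  // Enable automatic layout optimization.
  assert(MXSetOptimizeLayout(true) == 0);

  // Query the current status.
  bool enabled = false;
  assert(MXGetOptimizeLayout(&enabled) == 0);
  assert(enabled);
  return 0;
}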
python/mxnet/amp/amp.py

Lines changed: 6 additions & 2 deletions
@@ -307,7 +307,7 @@ def warn_if_model_exists():
         return
 
 def init(target_dtype='float16', target_precision_ops=None,
-         conditional_fp32_ops=None, fp32_ops=None):
+         conditional_fp32_ops=None, fp32_ops=None, layout_optimization=False):
     """Initialize AMP (automatic mixed precision).
 
     This needs to be done before model creation.
@@ -333,7 +333,11 @@ def init(target_dtype='float16', target_precision_ops=None,
     assert target_dtype in ['float16', np.float16, 'bfloat16', bfloat16], \
         "AMP currently supports only float16 or bfloat16 as a target_dtype"
     _amp_initialized = True
-    logging.info("Using AMP")
+    log_msg = "Using AMP"
+    if layout_optimization:
+        log_msg += "\n - layout optimization: enabled"
+        check_call(_LIB.MXSetOptimizeLayout(ctypes.c_bool(True)))
+    logging.info(log_msg)
     if target_dtype == "bfloat16":
         target_dtype = bfloat16
     else:

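This is the user-facing switch: passing layout_optimization=True to amp.init() enables layout optimization through MXSetOptimizeLayout before the model is created. A minimal usage sketch in Python, matching this file's language (not part of the diff; it assumes the package-level mxnet.amp import path of the 2.x Python API):

# Sketch only: enabling layout optimization together with AMP.
# amp.init() must run before model creation; with layout_optimization=True
# it also calls _LIB.MXSetOptimizeLayout(True) under the hood.
from mxnet import amp

amp.init(target_dtype='float16', layout_optimization=True)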
src/c_api/c_api.cc

Lines changed: 13 additions & 0 deletions
@@ -55,6 +55,7 @@
 #include "../operator/tvmop/op_module.h"
 #include "../operator/subgraph/partitioner/custom_subgraph_property.h"
 #include "../operator/subgraph/subgraph_property.h"
+#include "../common/alm.h"
 #include "../common/utils.h"
 #include "../profiler/profiler.h"
 #include "../serialization/cnpy.h"
@@ -4004,3 +4005,15 @@ int MXCUDAProfilerStop() {
 #endif
   API_END();
 }
+
+int MXSetOptimizeLayout(bool val) {
+  API_BEGIN();
+  mxnet::alm::ALMParams::get().optimize = val;
+  API_END();
+}
+
+int MXGetOptimizeLayout(bool* val) {
+  API_BEGIN();
+  *val = mxnet::alm::ALMParams::get().optimize;
+  API_END();
+}
