Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 9f21cdd

Browse files
leezusxjscience
authored andcommitted
RNNOp only call cuda/cudnn if GPU ctx is requested (#16632)
1 parent c130cc9 commit 9f21cdd

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/operator/rnn-inl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,8 @@ class RNNOp {
422422
init_mem_ = false;
423423
reserve_mem_size_ = 0;
424424
#endif
425+
426+
if (ctx_.dev_type == kGPU) {
425427
#if MXNET_USE_CUDNN == 1
426428
init_cudnn_ = false;
427429
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
@@ -505,6 +507,7 @@ class RNNOp {
505507
LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
506508
}
507509
#endif // MXNET_USE_CUDNN == 1
510+
}
508511

509512
if (ctx_.dev_type == kCPU) {
510513
this->init_space_ = false;
@@ -523,6 +526,7 @@ class RNNOp {
523526
}
524527

525528
~RNNOp() {
529+
if (ctx_.dev_type == kGPU) {
526530
#if MXNET_USE_CUDNN == 1
527531
CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
528532
CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
@@ -557,6 +561,7 @@ class RNNOp {
557561
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
558562
#endif // MXNET_USE_CUDNN_GE_7200
559563
#endif // MXNET_USE_CUDNN
564+
}
560565
}
561566

562567
void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,

0 commit comments

Comments
 (0)