Skip to content

Commit a0d4056

Browse files
authored
Merge pull request #230 from ROCmSoftwarePlatform/revert-229-af/mioconv_fixes
Revert "[Caffe2] MIOpen dims change check"
2 parents 1f7bb65 + 1639014 commit a0d4056

File tree

1 file changed

+93
-129
lines changed

1 file changed

+93
-129
lines changed

caffe2/operators/hip/conv_op_miopen.cc

Lines changed: 93 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,19 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
6666
dilation_h() == 1 && dilation_w() == 1,
6767
"MIOpen convolution does not support dilation for groups > 1.");
6868
}
69+
70+
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
71+
conv_desc_,
72+
mode_,
73+
pad_t(),
74+
pad_l(),
75+
stride_h(),
76+
stride_w(),
77+
dilation_h(),
78+
dilation_w()));
79+
80+
MIOPEN_ENFORCE(miopenSetConvolutionGroupCount(
81+
conv_desc_, group_));
6982
}
7083

7184
~MIOPENConvOpBase() {
@@ -78,8 +91,6 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
7891
}
7992

8093
protected:
81-
vector<int64_t> mio_input_dims_;
82-
vector<int64_t> mio_weight_dims_;
8394
MIOPENWrapper miopen_wrapper_;
8495
miopenTensorDescriptor_t bottom_desc_;
8596
miopenTensorDescriptor_t bias_desc_;
@@ -246,59 +257,35 @@ bool MIOPENConvOp::DoRunWithType() {
246257
"If you set group, the number of output channels should be divisible "
247258
"by group.");
248259

249-
bool input_changed = (X.dims() != mio_input_dims_);
250-
bool weight_changed = (Weight.dims() != mio_weight_dims_);
251-
252-
if (input_changed || weight_changed) {
253-
VLOG(1) << "Changing MIOpen descriptor configurations.";
254-
if (input_changed) {
255-
mio_input_dims_ = X.dims();
256-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
257-
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
258-
}
260+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
261+
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
259262

260-
if (weight_changed) {
261-
mio_weight_dims_ = Weight.dims();
262-
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
263-
conv_desc_,
264-
mode_,
265-
pad_t(),
266-
pad_l(),
267-
stride_h(),
268-
stride_w(),
269-
dilation_h(),
270-
dilation_w()));
271-
272-
MIOPEN_ENFORCE(miopenSetConvolutionGroupCount(
273-
conv_desc_, group_));
274-
275-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
276-
weight_desc_,
277-
miopenTypeWrapper<T_W>::type,
278-
M,
279-
C / group_,
280-
kernel_h(),
281-
kernel_w()));
282-
}
263+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
264+
weight_desc_,
265+
miopenTypeWrapper<T_W>::type,
266+
M,
267+
C / group_,
268+
kernel_h(),
269+
kernel_w()));
283270

284-
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
285-
conv_desc_,
286-
bottom_desc_,
287-
weight_desc_,
288-
&N_out,
289-
&C_out,
290-
&H_out,
291-
&W_out));
271+
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
272+
conv_desc_,
273+
bottom_desc_,
274+
weight_desc_,
275+
&N_out,
276+
&C_out,
277+
&H_out,
278+
&W_out));
292279

293-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
294-
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
280+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
281+
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
295282

296-
if (InputSize() == 3) {
283+
if (InputSize() == 3) {
297284
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
298285
bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1));
299-
}
286+
}
300287

301-
while (!bestAlgoFound_) {
288+
while (!bestAlgoFound_) {
302289
miopenConvAlgoPerf_t perf;
303290

304291
MIOPEN_ENFORCE(miopenConvolutionForwardGetWorkSpaceSize(
@@ -331,8 +318,8 @@ bool MIOPENConvOp::DoRunWithType() {
331318
});
332319
bestAlgoFound_ = true;
333320
fwdAlgo_ = perf.fwd_algo;
334-
}
335321
}
322+
336323
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
337324
MIOPEN_ENFORCE(miopenConvolutionForward(
338325
state->miopen_handle(),
@@ -437,59 +424,36 @@ bool MIOPENConvGradientOp::DoRunWithType() {
437424
"by group.");
438425

439426
bool doBwdDataComputation = (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2)));
440-
bool input_changed = (X.dims() != mio_input_dims_);
441-
bool weight_changed = (Weight.dims() != mio_weight_dims_);
442-
443-
if (input_changed || weight_changed) {
444-
VLOG(1) << "Changing MIOpen descriptor configurations.";
445-
if (input_changed) {
446-
mio_input_dims_ = X.dims();
447-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
448-
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
449-
}
450427

451-
if (weight_changed) {
452-
mio_weight_dims_ = Weight.dims();
453-
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
454-
conv_desc_,
455-
mode_,
456-
pad_t(),
457-
pad_l(),
458-
stride_h(),
459-
stride_w(),
460-
dilation_h(),
461-
dilation_w()));
462-
463-
MIOPEN_ENFORCE(miopenSetConvolutionGroupCount(
464-
conv_desc_, group_));
428+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
429+
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
465430

466-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
467-
weight_desc_,
468-
miopenTypeWrapper<T_X>::type,
469-
M,
470-
C / group_,
471-
kernel_h(),
472-
kernel_w()));
473-
}
431+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
432+
weight_desc_,
433+
miopenTypeWrapper<T_X>::type,
434+
M,
435+
C / group_,
436+
kernel_h(),
437+
kernel_w()));
474438

475-
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
476-
conv_desc_,
477-
bottom_desc_,
478-
weight_desc_,
479-
&N_out,
480-
&C_out,
481-
&H_out,
482-
&W_out));
439+
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
440+
conv_desc_,
441+
bottom_desc_,
442+
weight_desc_,
443+
&N_out,
444+
&C_out,
445+
&H_out,
446+
&W_out));
483447

484-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
485-
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
448+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
449+
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
486450

487-
if (!no_bias_) {
488-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
489-
bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1));
490-
}
451+
if (!no_bias_) {
452+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
453+
bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1));
454+
}
491455

492-
while ((!bestDataAlgoFound_) && doBwdDataComputation) {
456+
while ((!bestDataAlgoFound_) && doBwdDataComputation) {
493457
miopenConvAlgoPerf_t perf;
494458

495459
MIOPEN_ENFORCE(miopenConvolutionBackwardDataGetWorkSpaceSize(
@@ -523,43 +487,43 @@ bool MIOPENConvGradientOp::DoRunWithType() {
523487

524488
bestDataAlgoFound_ = true;
525489
bwdDataAlgo_ = perf.bwd_data_algo;
526-
}
490+
}
527491

528-
while (!bestWeightAlgoFound_) {
529-
miopenConvAlgoPerf_t perf;
492+
while (!bestWeightAlgoFound_) {
493+
miopenConvAlgoPerf_t perf;
530494

531-
MIOPEN_ENFORCE(miopenConvolutionBackwardWeightsGetWorkSpaceSize(
532-
miopen_wrapper_.inline_miopen_handle(),
533-
top_desc_,
534-
bottom_desc_,
535-
conv_desc_,
536-
weight_desc_,
537-
&bwdWeightWsSize_));
538-
if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) {
539-
HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_));
540-
}
495+
MIOPEN_ENFORCE(miopenConvolutionBackwardWeightsGetWorkSpaceSize(
496+
miopen_wrapper_.inline_miopen_handle(),
497+
top_desc_,
498+
bottom_desc_,
499+
conv_desc_,
500+
weight_desc_,
501+
&bwdWeightWsSize_));
502+
if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) {
503+
HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_));
504+
}
541505

542-
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
543-
MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm(
544-
state->miopen_handle(),
545-
top_desc_,
546-
dY.template data<T_DY>(),
547-
bottom_desc_,
548-
X.template data<T_X>(),
549-
conv_desc_,
550-
weight_desc_,
551-
dW->template mutable_data<T_DW>(),
552-
requestAlgoCount_,
553-
&returnedAlgoCount_,
554-
&perf,
555-
bwdWeightWs_,
556-
bwdWeightWsSize_,
557-
false));
558-
});
559-
bestWeightAlgoFound_ = true;
560-
bwdWeiAlgo_ = perf.bwd_weights_algo;
561-
}
506+
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
507+
MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm(
508+
state->miopen_handle(),
509+
top_desc_,
510+
dY.template data<T_DY>(),
511+
bottom_desc_,
512+
X.template data<T_X>(),
513+
conv_desc_,
514+
weight_desc_,
515+
dW->template mutable_data<T_DW>(),
516+
requestAlgoCount_,
517+
&returnedAlgoCount_,
518+
&perf,
519+
bwdWeightWs_,
520+
bwdWeightWsSize_,
521+
false));
522+
});
523+
bestWeightAlgoFound_ = true;
524+
bwdWeiAlgo_ = perf.bwd_weights_algo;
562525
}
526+
563527
if (doBwdDataComputation) {
564528
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
565529
MIOPEN_ENFORCE(miopenConvolutionBackwardData(

0 commit comments

Comments
 (0)