@@ -66,6 +66,19 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
66
66
dilation_h () == 1 && dilation_w () == 1 ,
67
67
" MIOpen convolution does not support dilation for groups > 1." );
68
68
}
69
+
70
+ MIOPEN_ENFORCE (miopenInitConvolutionDescriptor (
71
+ conv_desc_,
72
+ mode_,
73
+ pad_t (),
74
+ pad_l (),
75
+ stride_h (),
76
+ stride_w (),
77
+ dilation_h (),
78
+ dilation_w ()));
79
+
80
+ MIOPEN_ENFORCE (miopenSetConvolutionGroupCount (
81
+ conv_desc_, group_));
69
82
}
70
83
71
84
~MIOPENConvOpBase () {
@@ -78,8 +91,6 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
78
91
}
79
92
80
93
protected:
81
- vector<int64_t > mio_input_dims_;
82
- vector<int64_t > mio_weight_dims_;
83
94
MIOPENWrapper miopen_wrapper_;
84
95
miopenTensorDescriptor_t bottom_desc_;
85
96
miopenTensorDescriptor_t bias_desc_;
@@ -246,59 +257,35 @@ bool MIOPENConvOp::DoRunWithType() {
246
257
" If you set group, the number of output channels should be divisible "
247
258
" by group." );
248
259
249
- bool input_changed = (X.dims () != mio_input_dims_);
250
- bool weight_changed = (Weight.dims () != mio_weight_dims_);
251
-
252
- if (input_changed || weight_changed) {
253
- VLOG (1 ) << " Changing MIOpen descriptor configurations." ;
254
- if (input_changed) {
255
- mio_input_dims_ = X.dims ();
256
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
257
- bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
258
- }
260
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
261
+ bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
259
262
260
- if (weight_changed) {
261
- mio_weight_dims_ = Weight.dims ();
262
- MIOPEN_ENFORCE (miopenInitConvolutionDescriptor (
263
- conv_desc_,
264
- mode_,
265
- pad_t (),
266
- pad_l (),
267
- stride_h (),
268
- stride_w (),
269
- dilation_h (),
270
- dilation_w ()));
271
-
272
- MIOPEN_ENFORCE (miopenSetConvolutionGroupCount (
273
- conv_desc_, group_));
274
-
275
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
276
- weight_desc_,
277
- miopenTypeWrapper<T_W>::type,
278
- M,
279
- C / group_,
280
- kernel_h (),
281
- kernel_w ()));
282
- }
263
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
264
+ weight_desc_,
265
+ miopenTypeWrapper<T_W>::type,
266
+ M,
267
+ C / group_,
268
+ kernel_h (),
269
+ kernel_w ()));
283
270
284
- MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
285
- conv_desc_,
286
- bottom_desc_,
287
- weight_desc_,
288
- &N_out,
289
- &C_out,
290
- &H_out,
291
- &W_out));
271
+ MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
272
+ conv_desc_,
273
+ bottom_desc_,
274
+ weight_desc_,
275
+ &N_out,
276
+ &C_out,
277
+ &H_out,
278
+ &W_out));
292
279
293
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
294
- top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
280
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
281
+ top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
295
282
296
- if (InputSize () == 3 ) {
283
+ if (InputSize () == 3 ) {
297
284
MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
298
285
bias_desc_, miopenTypeWrapper<T_B>::type, 1 , M, 1 , 1 ));
299
- }
286
+ }
300
287
301
- while (!bestAlgoFound_) {
288
+ while (!bestAlgoFound_) {
302
289
miopenConvAlgoPerf_t perf;
303
290
304
291
MIOPEN_ENFORCE (miopenConvolutionForwardGetWorkSpaceSize (
@@ -331,8 +318,8 @@ bool MIOPENConvOp::DoRunWithType() {
331
318
});
332
319
bestAlgoFound_ = true ;
333
320
fwdAlgo_ = perf.fwd_algo ;
334
- }
335
321
}
322
+
336
323
miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
337
324
MIOPEN_ENFORCE (miopenConvolutionForward (
338
325
state->miopen_handle (),
@@ -437,59 +424,36 @@ bool MIOPENConvGradientOp::DoRunWithType() {
437
424
" by group." );
438
425
439
426
bool doBwdDataComputation = (OutputSize () == 3 || (no_bias_ && (OutputSize () == 2 )));
440
- bool input_changed = (X.dims () != mio_input_dims_);
441
- bool weight_changed = (Weight.dims () != mio_weight_dims_);
442
-
443
- if (input_changed || weight_changed) {
444
- VLOG (1 ) << " Changing MIOpen descriptor configurations." ;
445
- if (input_changed) {
446
- mio_input_dims_ = X.dims ();
447
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
448
- bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
449
- }
450
427
451
- if (weight_changed) {
452
- mio_weight_dims_ = Weight.dims ();
453
- MIOPEN_ENFORCE (miopenInitConvolutionDescriptor (
454
- conv_desc_,
455
- mode_,
456
- pad_t (),
457
- pad_l (),
458
- stride_h (),
459
- stride_w (),
460
- dilation_h (),
461
- dilation_w ()));
462
-
463
- MIOPEN_ENFORCE (miopenSetConvolutionGroupCount (
464
- conv_desc_, group_));
428
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
429
+ bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
465
430
466
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
467
- weight_desc_,
468
- miopenTypeWrapper<T_X>::type,
469
- M,
470
- C / group_,
471
- kernel_h (),
472
- kernel_w ()));
473
- }
431
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
432
+ weight_desc_,
433
+ miopenTypeWrapper<T_X>::type,
434
+ M,
435
+ C / group_,
436
+ kernel_h (),
437
+ kernel_w ()));
474
438
475
- MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
476
- conv_desc_,
477
- bottom_desc_,
478
- weight_desc_,
479
- &N_out,
480
- &C_out,
481
- &H_out,
482
- &W_out));
439
+ MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
440
+ conv_desc_,
441
+ bottom_desc_,
442
+ weight_desc_,
443
+ &N_out,
444
+ &C_out,
445
+ &H_out,
446
+ &W_out));
483
447
484
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
485
- top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
448
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
449
+ top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
486
450
487
- if (!no_bias_) {
488
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
489
- bias_desc_, miopenTypeWrapper<T_B>::type, 1 , M, 1 , 1 ));
490
- }
451
+ if (!no_bias_) {
452
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
453
+ bias_desc_, miopenTypeWrapper<T_B>::type, 1 , M, 1 , 1 ));
454
+ }
491
455
492
- while ((!bestDataAlgoFound_) && doBwdDataComputation) {
456
+ while ((!bestDataAlgoFound_) && doBwdDataComputation) {
493
457
miopenConvAlgoPerf_t perf;
494
458
495
459
MIOPEN_ENFORCE (miopenConvolutionBackwardDataGetWorkSpaceSize (
@@ -523,43 +487,43 @@ bool MIOPENConvGradientOp::DoRunWithType() {
523
487
524
488
bestDataAlgoFound_ = true ;
525
489
bwdDataAlgo_ = perf.bwd_data_algo ;
526
- }
490
+ }
527
491
528
- while (!bestWeightAlgoFound_) {
529
- miopenConvAlgoPerf_t perf;
492
+ while (!bestWeightAlgoFound_) {
493
+ miopenConvAlgoPerf_t perf;
530
494
531
- MIOPEN_ENFORCE (miopenConvolutionBackwardWeightsGetWorkSpaceSize (
532
- miopen_wrapper_.inline_miopen_handle (),
533
- top_desc_,
534
- bottom_desc_,
535
- conv_desc_,
536
- weight_desc_,
537
- &bwdWeightWsSize_));
538
- if ((bwdWeightWsSize_ > 0 ) && (bwdWeightWs_ == nullptr )) {
539
- HIP_CHECK (hipMalloc (&bwdWeightWs_, bwdWeightWsSize_));
540
- }
495
+ MIOPEN_ENFORCE (miopenConvolutionBackwardWeightsGetWorkSpaceSize (
496
+ miopen_wrapper_.inline_miopen_handle (),
497
+ top_desc_,
498
+ bottom_desc_,
499
+ conv_desc_,
500
+ weight_desc_,
501
+ &bwdWeightWsSize_));
502
+ if ((bwdWeightWsSize_ > 0 ) && (bwdWeightWs_ == nullptr )) {
503
+ HIP_CHECK (hipMalloc (&bwdWeightWs_, bwdWeightWsSize_));
504
+ }
541
505
542
- miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
543
- MIOPEN_ENFORCE (miopenFindConvolutionBackwardWeightsAlgorithm (
544
- state->miopen_handle (),
545
- top_desc_,
546
- dY.template data <T_DY>(),
547
- bottom_desc_,
548
- X.template data <T_X>(),
549
- conv_desc_,
550
- weight_desc_,
551
- dW->template mutable_data <T_DW>(),
552
- requestAlgoCount_,
553
- &returnedAlgoCount_,
554
- &perf,
555
- bwdWeightWs_,
556
- bwdWeightWsSize_,
557
- false ));
558
- });
559
- bestWeightAlgoFound_ = true ;
560
- bwdWeiAlgo_ = perf.bwd_weights_algo ;
561
- }
506
+ miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
507
+ MIOPEN_ENFORCE (miopenFindConvolutionBackwardWeightsAlgorithm (
508
+ state->miopen_handle (),
509
+ top_desc_,
510
+ dY.template data <T_DY>(),
511
+ bottom_desc_,
512
+ X.template data <T_X>(),
513
+ conv_desc_,
514
+ weight_desc_,
515
+ dW->template mutable_data <T_DW>(),
516
+ requestAlgoCount_,
517
+ &returnedAlgoCount_,
518
+ &perf,
519
+ bwdWeightWs_,
520
+ bwdWeightWsSize_,
521
+ false ));
522
+ });
523
+ bestWeightAlgoFound_ = true ;
524
+ bwdWeiAlgo_ = perf.bwd_weights_algo ;
562
525
}
526
+
563
527
if (doBwdDataComputation) {
564
528
miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
565
529
MIOPEN_ENFORCE (miopenConvolutionBackwardData (
0 commit comments