Skip to content

Commit 30f2ac8

Browse files
committed
Merge branch 'master' into fix_topk
2 parents 48575dc + d3fe294 commit 30f2ac8

File tree

11 files changed

+207
-37
lines changed

11 files changed

+207
-37
lines changed

aten/src/ATen/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ IF(USE_CUDA AND NOT USE_ROCM)
247247
ENDIF()
248248

249249
IF(USE_ROCM)
250-
### Link in the ROCm libraries BLAS / RNG.
250+
### Link in the ROCm libraries BLAS / RNG .
251251
FIND_LIBRARY(ROCBLAS_LIBRARY rocblas HINTS ${ROCBLAS_PATH}/lib)
252252
FIND_LIBRARY(HIPRAND_LIBRARY hiprand HINTS ${HIPRAND_PATH}/lib)
253253

aten/src/ATen/native/cuda/CuFFTPlanCache.h

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,11 @@ class CuFFTConfig {
149149
// TODO: Figure out why windows fails to compile
150150
// at::optional<std::vector<long long int>> inembed_opt = at::nullopt;
151151
// Then move the following to a helper function.
152+
#ifdef __HIP_PLATFORM_HCC__
153+
std::vector<int> inembed(signal_ndim);
154+
#else
152155
std::vector<long long int> inembed(signal_ndim);
156+
#endif
153157
if (!clone_input) {
154158
auto istrides = input.strides();
155159
auto last_istride = istrides[signal_ndim];
@@ -192,6 +196,37 @@ class CuFFTConfig {
192196
inembed.begin()); // begin of output
193197
}
194198

199+
#ifdef __HIP_PLATFORM_HCC__
200+
201+
hipfftType exec_type;
202+
if (input.type().scalarType() == ScalarType::Float) {
203+
if (complex_input && complex_output) {
204+
exec_type = HIPFFT_C2C;
205+
} else if (complex_input && !complex_output) {
206+
exec_type = HIPFFT_C2R;
207+
} else if (!complex_input && complex_output) {
208+
exec_type = HIPFFT_R2C;
209+
} else {
210+
throw std::runtime_error("hipFFT doesn't support r2r (float)");
211+
}
212+
} else if (input.type().scalarType() == ScalarType::Double) {
213+
if (complex_input && complex_output) {
214+
exec_type = HIPFFT_Z2Z;
215+
} else if (complex_input && !complex_output) {
216+
exec_type = HIPFFT_Z2D;
217+
} else if (!complex_input && complex_output) {
218+
exec_type = HIPFFT_D2Z;
219+
} else {
220+
throw std::runtime_error("hipFFT doesn't support r2r (double)");
221+
}
222+
} else {
223+
std::ostringstream ss;
224+
ss << "hipFFT doesn't support tensor of type: "
225+
<< at::toString(input.type().scalarType());
226+
throw std::runtime_error(ss.str());
227+
}
228+
229+
#else
195230
cudaDataType itype, otype, exec_type;
196231
if (input.type().scalarType() == ScalarType::Float) {
197232
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
@@ -211,6 +246,7 @@ class CuFFTConfig {
211246
<< at::toString(input.type().scalarType());
212247
throw std::runtime_error(ss.str());
213248
}
249+
#endif
214250

215251
// create plan
216252
auto raw_plan_ptr = new cufftHandle();
@@ -229,10 +265,18 @@ class CuFFTConfig {
229265
// by assuming base_istride = base_ostride = 1.
230266
//
231267
// See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
268+
#ifdef __HIP_PLATFORM_HCC__
269+
int sizes = *signal_sizes.data();
270+
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, &sizes,
271+
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
272+
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1,
273+
exec_type, batch, &ws_size_t));
274+
#else
232275
CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
233276
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
234277
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
235278
batch, &ws_size_t, exec_type));
279+
#endif
236280
} else {
237281
// set idist (stride at batch dim)
238282
// set base_istride (stride at innermost dim of signal)
@@ -254,6 +298,19 @@ class CuFFTConfig {
254298
}
255299

256300
// set odist, onembed, base_ostride
301+
#ifdef __HIP_PLATFORM_HCC__
302+
int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim));
303+
std::vector<int> onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1);
304+
int base_ostride = 1;
305+
306+
int sizes = *signal_sizes.data();
307+
int istride = base_istride;
308+
int iidist = idist;
309+
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, &sizes,
310+
inembed.data(), istride, iidist,
311+
onembed.data(), base_ostride, odist,
312+
exec_type, batch, &ws_size_t));
313+
#else
257314
long long int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim));
258315
std::vector<long long int> onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1);
259316
long long int base_ostride = 1;
@@ -262,11 +319,16 @@ class CuFFTConfig {
262319
inembed.data(), base_istride, idist, itype,
263320
onembed.data(), base_ostride, odist, otype,
264321
batch, &ws_size_t, exec_type));
265-
}
322+
#endif
323+
}
266324
ws_size = static_cast<int64_t>(ws_size_t);
267325
}
268326

327+
#ifdef __HIP_PLATFORM_HCC__
328+
cufftHandle &plan() const { return *plan_ptr.get(); }
329+
#else
269330
const cufftHandle &plan() const { return *plan_ptr.get(); }
331+
#endif
270332

271333
bool should_clone_input() const { return clone_input; }
272334

aten/src/ATen/native/cuda/CuFFTUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ static inline std::string _cudaGetErrorEnum(cufftResult error)
4949
return "CUFFT_NO_WORKSPACE";
5050
case CUFFT_NOT_IMPLEMENTED:
5151
return "CUFFT_NOT_IMPLEMENTED";
52+
#ifndef __HIP_PLATFORM_HCC__
5253
case CUFFT_LICENSE_ERROR:
5354
return "CUFFT_LICENSE_ERROR";
55+
#endif
5456
case CUFFT_NOT_SUPPORTED:
5557
return "CUFFT_NOT_SUPPORTED";
5658
default:

aten/src/ATen/native/cuda/SpectralOps.cu

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,45 @@ static inline Tensor _run_cufft(
190190
CUFFT_CHECK(cufftSetWorkArea(plan, ws.data_ptr()));
191191

192192
// run
193+
#ifdef __HIP_PLATFORM_HCC__
194+
if (input.type().scalarType() == ScalarType::Float) {
195+
if (complex_input && complex_output) {
196+
CUFFT_CHECK(hipfftExecC2C(plan, static_cast<hipfftComplex*>(input.data_ptr()),
197+
static_cast<hipfftComplex*>(output.data_ptr()),
198+
inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD));
199+
} else if (complex_input && !complex_output) {
200+
CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(input.data_ptr()),
201+
static_cast<hipfftReal*>(output.data_ptr())));
202+
} else if (!complex_input && complex_output) {
203+
CUFFT_CHECK(hipfftExecR2C(plan, static_cast<hipfftReal*>(input.data_ptr()),
204+
static_cast<hipfftComplex*>(output.data_ptr())));
205+
} else {
206+
throw std::runtime_error("hipFFT doesn't support r2r (float)");
207+
}
208+
} else if (input.type().scalarType() == ScalarType::Double) {
209+
if (complex_input && complex_output) {
210+
CUFFT_CHECK(hipfftExecZ2Z(plan, static_cast<hipfftDoubleComplex*>(input.data_ptr()),
211+
static_cast<hipfftDoubleComplex*>(output.data_ptr()),
212+
inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD));
213+
} else if (complex_input && !complex_output) {
214+
CUFFT_CHECK(hipfftExecZ2D(plan, static_cast<hipfftDoubleComplex*>(input.data_ptr()),
215+
static_cast<hipfftDoubleReal*>(output.data_ptr())));
216+
} else if (!complex_input && complex_output) {
217+
CUFFT_CHECK(hipfftExecD2Z(plan, static_cast<hipfftDoubleReal*>(input.data_ptr()),
218+
static_cast<hipfftDoubleComplex*>(output.data_ptr())));
219+
} else {
220+
throw std::runtime_error("hipFFT doesn't support r2r (double)");
221+
}
222+
} else {
223+
std::ostringstream ss;
224+
ss << "hipFFT doesn't support tensor of type: "
225+
<< at::toString(input.type().scalarType());
226+
throw std::runtime_error(ss.str());
227+
}
228+
#else
193229
CUFFT_CHECK(cufftXtExec(plan, input.data_ptr(), output.data_ptr(),
194230
inverse ? CUFFT_INVERSE : CUFFT_FORWARD));
231+
#endif
195232

196233
// rescale if needed by normalized flag or inverse transform
197234
auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1];

cmake/Dependencies.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ endif()
562562
if(USE_ROCM)
563563
include_directories(SYSTEM ${HIP_PATH}/include)
564564
include_directories(SYSTEM ${ROCBLAS_PATH}/include)
565+
include_directories(SYSTEM ${ROCFFT_PATH}/include)
565566
include_directories(SYSTEM ${HIPSPARSE_PATH}/include)
566567
include_directories(SYSTEM ${HIPRAND_PATH}/include)
567568
include_directories(SYSTEM ${ROCRAND_PATH}/include)

cmake/public/LoadHIP.cmake

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ ELSE()
3838
SET(ROCBLAS_PATH $ENV{ROCBLAS_PATH})
3939
ENDIF()
4040

41+
# ROCFFT_PATH
42+
IF(NOT DEFINED ENV{ROCFFT_PATH})
43+
SET(ROCBLAS_PATH ${ROCM_PATH}/rocfft)
44+
ELSE()
45+
SET(ROCFFT_PATH $ENV{ROCFFT_PATH})
46+
ENDIF()
47+
4148
# HIPSPARSE_PATH
4249
IF(NOT DEFINED ENV{HIPSPARSE_PATH})
4350
SET(HIPSPARSE_PATH ${ROCM_PATH}/hcsparse)
@@ -106,11 +113,13 @@ IF(HIP_FOUND)
106113
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
107114
set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen)
108115
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
116+
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
109117
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)
110118

111119
find_package(rocrand REQUIRED)
112120
find_package(hiprand REQUIRED)
113121
find_package(rocblas REQUIRED)
122+
find_package(rocfft REQUIRED)
114123
find_package(miopen REQUIRED)
115124
#find_package(hipsparse REQUIRED)
116125

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,7 @@ def run(self):
920920
rocm_include_path = '/opt/rocm/include'
921921
hcc_include_path = '/opt/rocm/hcc/include'
922922
rocblas_include_path = '/opt/rocm/rocblas/include'
923+
rocfft_include_path = '/opt/rocm/rocfft/include'
923924
hipsparse_include_path = '/opt/rocm/hcsparse/include'
924925
hiprand_include_path = '/opt/rocm/hiprand/include'
925926
rocrand_include_path = '/opt/rocm/rocrand/include'
@@ -928,6 +929,7 @@ def run(self):
928929
include_dirs.append(rocm_include_path)
929930
include_dirs.append(hcc_include_path)
930931
include_dirs.append(rocblas_include_path)
932+
include_dirs.append(rocfft_include_path)
931933
include_dirs.append(hipsparse_include_path)
932934
include_dirs.append(hiprand_include_path)
933935
include_dirs.append(rocrand_include_path)

test/test_cuda.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,19 +330,26 @@ def tmp(t):
330330
('kthvalue', small_3d_unique, lambda t: [3, -1], 'neg_dim'),
331331
('lerp', small_3d, lambda t: [small_3d(t), 0.3],'', types, False, "skipIfHalfTensor"),
332332
('max', small_3d_unique, lambda t: [],'', types, False, "skipIfHalfTensor"),
333-
('max', small_3d_unique, lambda t: [1], 'dim'),
334-
('max', small_3d_unique, lambda t: [-1], 'neg_dim'),
333+
('max', small_3d_unique, lambda t: [1], 'dim', types, False,
334+
"skipIfByteTensor;skipIfCharTensor;skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor;skipIfIntTensor;skipIfLongTensor;skipIfShortTensor"),
335+
('max', small_3d_unique, lambda t: [-1], 'neg_dim', types, False,
336+
"skipIfByteTensor;skipIfCharTensor;skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor;skipIfIntTensor;skipIfLongTensor;skipIfShortTensor"),
335337
('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
336338
('min', small_3d_unique, lambda t: [],'', types, False, "skipIfHalfTensor"),
337-
('min', small_3d_unique, lambda t: [1], 'dim'),
338-
('min', small_3d_unique, lambda t: [-1], 'neg_dim'),
339+
('min', small_3d_unique, lambda t: [1], 'dim', types, False,
340+
"skipIfByteTensor;skipIfCharTensor;skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor;skipIfIntTensor;skipIfLongTensor;skipIfShortTensor"),
341+
('min', small_3d_unique, lambda t: [-1], 'neg_dim', types, False,
342+
"skipIfByteTensor;skipIfCharTensor;skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor;skipIfIntTensor;skipIfLongTensor;skipIfShortTensor"),
339343
('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
340344
('mean', small_3d, lambda t: [], '', types, False, "skipIfHalfTensor"),
341345
('mean', small_3d, lambda t: [-1], 'neg_dim', types, False, "skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor"),
342346
('mean', small_3d, lambda t: [1], 'dim', types, False, "skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor"),
343-
('mode', small_3d, lambda t: [],),
344-
('mode', small_3d, lambda t: [1], 'dim'),
345-
('mode', small_3d, lambda t: [-1], 'neg_dim'),
347+
('mode', small_3d, lambda t: [],'', types, False,
348+
"skipIfByteTensor;skipIfCharTensor;skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor;skipIfIntTensor;skipIfLongTensor;skipIfShortTensor"),
349+
('mode', small_3d, lambda t: [1], 'dim', types, False,
350+
"skipIfByteTensor;skipIfCharTensor;skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor;skipIfIntTensor;skipIfLongTensor;skipIfShortTensor"),
351+
('mode', small_3d, lambda t: [-1], 'neg_dim', types, False,
352+
"skipIfByteTensor;skipIfCharTensor;skipIfDoubleTensor;skipIfFloatTensor;skipIfHalfTensor;skipIfIntTensor;skipIfLongTensor;skipIfShortTensor"),
346353
('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.1, 10), lambda t: [1], '2d_p=1', float_types_no_half, False, "skipIfDoubleTensor;skipIfFloatTensor"),
347354
('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.6, 10), lambda t: [2], '2d_p=2', float_types_no_half, False, "skipIfDoubleTensor;skipIfFloatTensor"),
348355
('remainder', small_3d, lambda t: [3], 'value', types, False, "skipIfHalfTensor"),
@@ -924,6 +931,7 @@ def test_broadcast_cpu(self):
924931
def test_broadcast_gpu(self):
925932
self._test_broadcast(torch.randn(5, 5).cuda())
926933

934+
@skipIfRocm
927935
def test_min_max_nan(self):
928936
tests = [(lambda x: x.min(), 'min'),
929937
(lambda x: x.max(), 'max'),
@@ -1656,6 +1664,7 @@ def test_btrisolve(self):
16561664
def test_dim_reduction(self):
16571665
TestTorch._test_dim_reduction(self, lambda t: t.cuda())
16581666

1667+
@skipIfRocm
16591668
def test_tensor_gather(self):
16601669
TestTorch._test_gather(self, lambda t: t.cuda(), False)
16611670

@@ -1669,6 +1678,7 @@ def test_tensor_scatterAdd(self):
16691678
def test_tensor_scatterFill(self):
16701679
TestTorch._test_scatter_base(self, lambda t: t.cuda(), 'scatter_', True, test_bounds=False)
16711680

1681+
@skipIfRocm
16721682
def test_min_max_inits(self):
16731683
# Testing if THC_reduceAll received the correct index initialization.
16741684
# This affects the result of THC_reduceAll operations at extreme values

tools/amd_build/disabled_features.yaml

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -97,27 +97,6 @@
9797
"struct mtgp32_kernel_params": "mtgp32_kernel_params"
9898
}
9999
},
100-
{
101-
"path": "aten/src/ATen/native/cuda/CuFFTUtils.h",
102-
"s_constants": {
103-
"#include <cufft.h>": "",
104-
"#include <cufftXt.h>": ""
105-
}
106-
},
107-
{
108-
"path": "aten/src/ATen/native/cuda/CuFFTPlanCache.h",
109-
"s_constants": {
110-
"#include <cufft.h>": "",
111-
"#include <cufftXt.h>": ""
112-
}
113-
},
114-
{
115-
"path": "aten/src/ATen/native/cuda/SpectralOps.cu",
116-
"s_constants": {
117-
"#include <cufft.h>": "",
118-
"#include <cufftXt.h>": ""
119-
}
120-
},
121100
{
122101
"path": "aten/src/ATen/native/cuda/RoiPooling.cu",
123102
"s_constants": {
@@ -142,9 +121,6 @@
142121
}
143122
],
144123
"disabled_modules": [
145-
"aten/src/ATen/native/cuda/CuFFTUtils.h",
146-
"aten/src/ATen/native/cuda/CuFFTPlanCache.h",
147-
"aten/src/ATen/native/cuda/SpectralOps.cu",
148124
],
149125
"disabled_functions": [
150126
{

tools/amd_build/pyHIPIFY/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@
5050
API_SPARSE = 40
5151
API_RAND = 41
5252
API_LAST = 42
53+
API_FFT = 43
5354

5455
HIP_UNSUPPORTED = 43
5556
API_PYTORCH = 1337
56-
API_CAFFE2 = 1338
57+
API_CAFFE2 = 1338

0 commit comments

Comments
 (0)