Skip to content

Commit 84a66b6

Browse files
authored
focal_stats(): gpu case (#709)
* focal stats gpu case * flake8 * add tests
1 parent dad5abb commit 84a66b6

File tree

2 files changed

+228
-46
lines changed

2 files changed

+228
-46
lines changed

xrspatial/focal.py

Lines changed: 189 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
from functools import partial
3-
from math import isnan
3+
from math import isnan, sqrt
44

55
import dask.array as da
66
import numba as nb
@@ -420,6 +420,178 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply'):
420420
return result
421421

422422

423+
@cuda.jit
def _focal_min_cuda(data, kernel, out):
    """CUDA kernel: minimum of the kernel-selected neighbourhood of a cell.

    Each thread handles one cell ``(i, j)`` of ``data``. Only cells whose
    kernel entry is non-zero take part in the minimum. Threads whose window
    would extend past the edge of ``data`` return without writing, so those
    cells of ``out`` keep whatever value the caller stored there.

    Parameters
    ----------
    data : 2D array
        Input values.
    kernel : 2D array
        Neighbourhood definition; non-zero entries select a cell.
    out : 2D array
        Output buffer indexed like ``data``; written in place.
    """
    i, j = cuda.grid(2)

    delta_rows = kernel.shape[0] // 2
    delta_cols = kernel.shape[1] // 2

    data_rows, data_cols = data.shape

    # Skip threads outside the array and cells whose window is not fully
    # inside ``data``.
    if i < delta_rows or i >= data_rows - delta_rows or \
            j < delta_cols or j >= data_cols - delta_cols:
        return

    # The centre cell may only seed the result when the kernel selects it.
    # Otherwise (e.g. a ring-shaped kernel) the first selected cell becomes
    # the seed, so an excluded centre never leaks into the minimum.
    found = kernel[delta_rows, delta_cols] != 0
    s = data[i, j]

    # The guard above keeps the whole window in bounds, so the neighbour
    # indices need no further range check.
    for k in range(kernel.shape[0]):
        for h in range(kernel.shape[1]):
            if kernel[k, h] != 0:
                value = data[i + k - delta_rows, j + h - delta_cols]
                if not found:
                    s = value
                    found = True
                elif s > value:
                    s = value

    # With an all-zero kernel there is nothing to reduce: leave ``out`` as
    # the caller initialised it.
    if found:
        out[i, j] = s
445+
446+
447+
@cuda.jit
def _focal_max_cuda(data, kernel, out):
    """CUDA kernel: maximum of the kernel-selected neighbourhood of a cell.

    Each thread handles one cell ``(i, j)`` of ``data``. Only cells whose
    kernel entry is non-zero take part in the maximum. Threads whose window
    would extend past the edge of ``data`` return without writing, so those
    cells of ``out`` keep whatever value the caller stored there.

    Parameters
    ----------
    data : 2D array
        Input values.
    kernel : 2D array
        Neighbourhood definition; non-zero entries select a cell.
    out : 2D array
        Output buffer indexed like ``data``; written in place.
    """
    i, j = cuda.grid(2)

    delta_rows = kernel.shape[0] // 2
    delta_cols = kernel.shape[1] // 2

    data_rows, data_cols = data.shape

    # Skip threads outside the array and cells whose window is not fully
    # inside ``data``.
    if i < delta_rows or i >= data_rows - delta_rows or \
            j < delta_cols or j >= data_cols - delta_cols:
        return

    # The centre cell may only seed the result when the kernel selects it.
    # Otherwise (e.g. a ring-shaped kernel) the first selected cell becomes
    # the seed, so an excluded centre never leaks into the maximum.
    found = kernel[delta_rows, delta_cols] != 0
    s = data[i, j]

    # The guard above keeps the whole window in bounds, so the neighbour
    # indices need no further range check.
    for k in range(kernel.shape[0]):
        for h in range(kernel.shape[1]):
            if kernel[k, h] != 0:
                value = data[i + k - delta_rows, j + h - delta_cols]
                if not found:
                    s = value
                    found = True
                elif s < value:
                    s = value

    # With an all-zero kernel there is nothing to reduce: leave ``out`` as
    # the caller initialised it.
    if found:
        out[i, j] = s
469+
470+
471+
def _focal_range_cupy(data, kernel):
    """Return the focal range: focal maximum minus focal minimum per cell."""
    lowest = _focal_stats_func_cupy(data, kernel, _focal_min_cuda)
    highest = _focal_stats_func_cupy(data, kernel, _focal_max_cuda)
    return highest - lowest
476+
477+
478+
@cuda.jit
def _focal_std_cuda(data, kernel, out):
    """CUDA kernel: kernel-weighted standard deviation around a cell.

    Each thread handles one cell ``(i, j)`` of ``data`` and computes
    ``sqrt(sum(w * x**2) / sum(w) - (sum(w * x) / sum(w))**2)`` over the
    window, where ``w`` is the kernel entry and ``x`` the data value. For a
    0/1 kernel this is the population standard deviation of the selected
    cells. Threads whose window would extend past the edge of ``data``
    return without writing, so those cells of ``out`` keep whatever value
    the caller stored there.

    Parameters
    ----------
    data : 2D array
        Input values.
    kernel : 2D array
        Per-cell weights; zero entries exclude a cell entirely.
    out : 2D array
        Output buffer indexed like ``data``; written in place.
    """
    i, j = cuda.grid(2)

    delta_rows = kernel.shape[0] // 2
    delta_cols = kernel.shape[1] // 2

    data_rows, data_cols = data.shape

    # Skip threads outside the array and cells whose window is not fully
    # inside ``data``.
    if i < delta_rows or i >= data_rows - delta_rows or \
            j < delta_cols or j >= data_cols - delta_cols:
        return

    sum_squares = 0.0
    total = 0.0
    count = 0.0
    # The guard above keeps the whole window in bounds, so the neighbour
    # indices need no further range check.
    for k in range(kernel.shape[0]):
        for h in range(kernel.shape[1]):
            weight = kernel[k, h]
            # Skip excluded cells outright: 0 * NaN (or 0 * inf) is NaN and
            # would otherwise poison the accumulators.
            if weight != 0:
                value = data[i + k - delta_rows, j + h - delta_cols]
                # The weight multiplies x**2 once; squaring the weight as
                # well would not yield a variance for non-binary kernels.
                sum_squares += weight * value * value
                total += weight * value
                count += weight

    # NOTE(review): a kernel whose weights sum to zero divides by zero
    # here -- confirm the intended result for that case.
    variance = (sum_squares - total * total / count) / count
    # Rounding can push the variance of a (nearly) constant window slightly
    # below zero, which would turn the square root into NaN.
    if variance < 0:
        variance = 0.0
    out[i, j] = sqrt(variance)
504+
505+
506+
@cuda.jit
def _focal_var_cuda(data, kernel, out):
    """CUDA kernel: kernel-weighted variance around a cell.

    Each thread handles one cell ``(i, j)`` of ``data`` and computes
    ``sum(w * x**2) / sum(w) - (sum(w * x) / sum(w))**2`` over the window,
    where ``w`` is the kernel entry and ``x`` the data value. For a 0/1
    kernel this is the population variance of the selected cells. Threads
    whose window would extend past the edge of ``data`` return without
    writing, so those cells of ``out`` keep whatever value the caller stored
    there.

    Parameters
    ----------
    data : 2D array
        Input values.
    kernel : 2D array
        Per-cell weights; zero entries exclude a cell entirely.
    out : 2D array
        Output buffer indexed like ``data``; written in place.
    """
    i, j = cuda.grid(2)

    delta_rows = kernel.shape[0] // 2
    delta_cols = kernel.shape[1] // 2

    data_rows, data_cols = data.shape

    # Skip threads outside the array and cells whose window is not fully
    # inside ``data``.
    if i < delta_rows or i >= data_rows - delta_rows or \
            j < delta_cols or j >= data_cols - delta_cols:
        return

    sum_squares = 0.0
    total = 0.0
    count = 0.0
    # The guard above keeps the whole window in bounds, so the neighbour
    # indices need no further range check.
    for k in range(kernel.shape[0]):
        for h in range(kernel.shape[1]):
            weight = kernel[k, h]
            # Skip excluded cells outright: 0 * NaN (or 0 * inf) is NaN and
            # would otherwise poison the accumulators.
            if weight != 0:
                value = data[i + k - delta_rows, j + h - delta_cols]
                # The weight multiplies x**2 once; squaring the weight as
                # well would not yield a variance for non-binary kernels.
                sum_squares += weight * value * value
                total += weight * value
                count += weight

    # NOTE(review): a kernel whose weights sum to zero divides by zero
    # here -- confirm the intended result for that case.
    variance = (sum_squares - total * total / count) / count
    # Rounding can push the variance of a (nearly) constant window slightly
    # below zero; clamp so the result is never negative.
    if variance < 0:
        variance = 0.0
    out[i, j] = variance
532+
533+
534+
def _focal_mean_cupy(data, kernel):
    """Return the focal mean: ``data`` convolved with the kernel divided by its sum."""
    weights = kernel / kernel.sum()
    return convolve_2d(data, weights)
537+
538+
539+
def _focal_sum_cupy(data, kernel):
    """Return the focal sum: ``data`` convolved with the unmodified kernel."""
    return convolve_2d(data, kernel)
542+
543+
544+
def _focal_stats_func_cupy(data, kernel, func=_focal_max_cuda):
    """Launch the CUDA kernel ``func`` over ``data`` and return its output.

    The float32 result buffer starts out filled with NaN, so any cell that
    ``func`` does not write stays NaN.
    """
    result = cupy.full(data.shape, cupy.nan, dtype='f4')
    grid_dim, block_dim = cuda_args(data.shape)
    func[grid_dim, block_dim](data, kernel, result)
    return result
550+
551+
552+
def _focal_stats_cupy(agg, kernel, stats_funcs):
    """Compute the requested focal statistics for a CuPy-backed DataArray.

    Parameters
    ----------
    agg : xarray.DataArray
        Input raster; its ``.data`` is converted with
        ``astype(cupy.float32)`` before any statistic is computed.
    kernel : array
        Neighbourhood passed unchanged to each statistic implementation.
    stats_funcs : sequence of str
        Statistic names; each must be one of 'mean', 'sum', 'range', 'max',
        'min', 'std' or 'var'.

    Returns
    -------
    xarray.DataArray
        One layer per requested statistic (same dims, coords and attrs as
        ``agg``), concatenated along a new ``stats`` dimension labelled
        with ``stats_funcs``.

    Raises
    ------
    KeyError
        If a name in ``stats_funcs`` is not a supported statistic.
    """
    _stats_cupy_mapper = dict(
        mean=_focal_mean_cupy,
        sum=_focal_sum_cupy,
        range=_focal_range_cupy,
        max=partial(_focal_stats_func_cupy, func=_focal_max_cuda),
        min=partial(_focal_stats_func_cupy, func=_focal_min_cuda),
        std=partial(_focal_stats_func_cupy, func=_focal_std_cuda),
        var=partial(_focal_stats_func_cupy, func=_focal_var_cuda),
    )

    # The float32 conversion does not depend on the statistic, so do it
    # once instead of copying the whole array for every entry.
    # NOTE(review): assumes no statistic implementation modifies its input
    # in place -- confirm for convolve_2d.
    data = agg.data.astype(cupy.float32)

    stats_aggs = []
    for stats_name in stats_funcs:
        stats_data = _stats_cupy_mapper[stats_name](data, kernel)
        stats_aggs.append(
            xr.DataArray(
                stats_data,
                dims=agg.dims,
                coords=agg.coords,
                attrs=agg.attrs,
            )
        )
    return xr.concat(stats_aggs, pd.Index(stats_funcs, name='stats'))
575+
576+
577+
def _focal_stats_cpu(agg, kernel, stats_funcs):
    """Compute the requested focal statistics through ``apply``.

    Each name in ``stats_funcs`` is looked up in the table below and run
    with ``apply(agg, kernel, func=...)``. The per-statistic results are
    concatenated along a new ``stats`` dimension labelled with
    ``stats_funcs``. An unknown name raises ``KeyError``.
    """
    calc_by_name = {
        'mean': _calc_mean,
        'max': _calc_max,
        'min': _calc_min,
        'range': _calc_range,
        'std': _calc_std,
        'var': _calc_var,
        'sum': _calc_sum,
    }
    layers = [
        apply(agg, kernel, func=calc_by_name[name])
        for name in stats_funcs
    ]
    return xr.concat(layers, pd.Index(stats_funcs, name='stats'))
593+
594+
423595
def focal_stats(agg,
424596
kernel,
425597
stats_funcs=[
@@ -480,24 +652,25 @@ def focal_stats(agg,
480652
* stats (stats) object 'min' 'sum'
481653
Dimensions without coordinates: dim_0, dim_1
482654
"""
655+
# validate raster
656+
if not isinstance(agg, DataArray):
657+
raise TypeError("`agg` must be instance of DataArray")
483658

484-
_function_mapping = {
485-
'mean': _calc_mean,
486-
'max': _calc_max,
487-
'min': _calc_min,
488-
'range': _calc_range,
489-
'std': _calc_std,
490-
'var': _calc_var,
491-
'sum': _calc_sum
492-
}
659+
if agg.ndim != 2:
660+
raise ValueError("`agg` must be 2D")
493661

494-
stats_aggs = []
495-
for stats in stats_funcs:
496-
stats_agg = apply(agg, kernel, func=_function_mapping[stats])
497-
stats_aggs.append(stats_agg)
662+
# Validate the kernel
663+
kernel = custom_kernel(kernel)
498664

499-
stats = xr.concat(stats_aggs, pd.Index(stats_funcs, name='stats'))
500-
return stats
665+
mapper = ArrayTypeFunctionMapping(
666+
numpy_func=_focal_stats_cpu,
667+
cupy_func=_focal_stats_cupy,
668+
dask_func=_focal_stats_cpu,
669+
dask_cupy_func=lambda *args: not_implemented_func(
670+
*args, messages='focal_stats() does not support dask with cupy backed DataArray.'),
671+
)
672+
result = mapper(agg)(agg, kernel, stats_funcs)
673+
return result
501674

502675

503676
@ngjit

xrspatial/tests/test_focal.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -322,44 +322,43 @@ def test_apply_dask_numpy(data_apply):
322322
@pytest.fixture
323323
def data_focal_stats():
324324
data = np.arange(16).reshape(4, 4)
325-
cellsize = (1, 1)
326-
kernel = circle_kernel(*cellsize, 1.5)
325+
kernel = custom_kernel(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 0]]))
327326
expected_result = np.asarray([
328327
# mean
329-
[[1.66666667, 2., 3., 4.],
330-
[4.25, 5., 6., 6.75],
331-
[8.25, 9., 10., 10.75],
332-
[11., 12., 13., 13.33333333]],
328+
[[0, 1, 2, 3.],
329+
[4, 2.5, 3.5, 4.5],
330+
[8, 6.5, 7.5, 8.5],
331+
[12, 10.5, 11.5, 12.5]],
333332
# max
334-
[[4., 5., 6., 7.],
335-
[8., 9., 10., 11.],
336-
[12., 13., 14., 15.],
337-
[13., 14., 15., 15.]],
333+
[[0, 1, 2, 3.],
334+
[4, 5, 6, 7.],
335+
[8, 9, 10, 11.],
336+
[12, 13, 14, 15.]],
338337
# min
339-
[[0., 0., 1., 2.],
340-
[0., 1., 2., 3.],
341-
[4., 5., 6., 7.],
342-
[8., 9., 10., 11.]],
338+
[[0, 1, 2, 3.],
339+
[4, 0, 1, 2.],
340+
[8, 4, 5, 6.],
341+
[12, 8, 9, 10.]],
343342
# range
344-
[[4., 5., 5., 5.],
345-
[8., 8., 8., 8.],
346-
[8., 8., 8., 8.],
347-
[5., 5., 5., 4.]],
343+
[[0, 0, 0, 0.],
344+
[0, 5, 5, 5.],
345+
[0, 5, 5, 5.],
346+
[0, 5, 5, 5.]],
348347
# std
349-
[[1.69967317, 1.87082869, 1.87082869, 2.1602469],
350-
[2.86138079, 2.60768096, 2.60768096, 2.86138079],
351-
[2.86138079, 2.60768096, 2.60768096, 2.86138079],
352-
[2.1602469, 1.87082869, 1.87082869, 1.69967317]],
348+
[[0, 0, 0, 0.],
349+
[0, 2.5, 2.5, 2.5],
350+
[0, 2.5, 2.5, 2.5],
351+
[0, 2.5, 2.5, 2.5]],
353352
# var
354-
[[2.88888889, 3.5, 3.5, 4.66666667],
355-
[8.1875, 6.8, 6.8, 8.1875],
356-
[8.1875, 6.8, 6.8, 8.1875],
357-
[4.66666667, 3.5, 3.5, 2.88888889]],
353+
[[0, 0, 0, 0.],
354+
[0, 6.25, 6.25, 6.25],
355+
[0, 6.25, 6.25, 6.25],
356+
[0, 6.25, 6.25, 6.25]],
358357
# sum
359-
[[5., 8., 12., 12.],
360-
[17., 25., 30., 27.],
361-
[33., 45., 50., 43.],
362-
[33., 48., 52., 40.]]
358+
[[0, 1, 2, 3.],
359+
[4, 5, 7, 9.],
360+
[8, 13, 15, 17.],
361+
[12, 21, 23, 25.]]
363362
])
364363
return data, kernel, expected_result
365364

@@ -383,6 +382,16 @@ def test_focal_stats_dask_numpy(data_focal_stats):
383382
)
384383

385384

385+
@cuda_and_cupy_available
def test_focal_stats_gpu(data_focal_stats):
    """focal_stats() on a CuPy-backed raster matches the fixture's expected values."""
    values, kernel, expected = data_focal_stats
    raster = create_test_raster(values, backend='cupy')
    computed = focal_stats(raster, kernel)
    general_output_checks(
        raster,
        computed,
        verify_attrs=False,
        expected_results=expected,
    )
393+
394+
386395
@pytest.fixture
387396
def data_hotspots():
388397
data = np.asarray([

0 commit comments

Comments
 (0)