|
1 | 1 | import copy
|
2 | 2 | from functools import partial
|
3 |
| -from math import isnan |
| 3 | +from math import isnan, sqrt |
4 | 4 |
|
5 | 5 | import dask.array as da
|
6 | 6 | import numba as nb
|
@@ -420,6 +420,178 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply'):
|
420 | 420 | return result
|
421 | 421 |
|
422 | 422 |
|
| 423 | +@cuda.jit |
| 424 | +def _focal_min_cuda(data, kernel, out): |
| 425 | + i, j = cuda.grid(2) |
| 426 | + |
| 427 | + delta_rows = kernel.shape[0] // 2 |
| 428 | + delta_cols = kernel.shape[1] // 2 |
| 429 | + |
| 430 | + data_rows, data_cols = data.shape |
| 431 | + |
| 432 | + if i < delta_rows or i >= data_rows - delta_rows or \ |
| 433 | + j < delta_cols or j >= data_cols - delta_cols: |
| 434 | + return |
| 435 | + |
| 436 | + s = data[i, j] |
| 437 | + for k in range(kernel.shape[0]): |
| 438 | + for h in range(kernel.shape[1]): |
| 439 | + i_k = i + k - delta_rows |
| 440 | + j_h = j + h - delta_cols |
| 441 | + if (i_k >= 0) and (i_k < data_rows) and (j_h >= 0) and (j_h < data_cols): |
| 442 | + if (kernel[k, h] != 0) and s > data[i_k, j_h]: |
| 443 | + s = data[i_k, j_h] |
| 444 | + out[i, j] = s |
| 445 | + |
| 446 | + |
| 447 | +@cuda.jit |
| 448 | +def _focal_max_cuda(data, kernel, out): |
| 449 | + i, j = cuda.grid(2) |
| 450 | + |
| 451 | + delta_rows = kernel.shape[0] // 2 |
| 452 | + delta_cols = kernel.shape[1] // 2 |
| 453 | + |
| 454 | + data_rows, data_cols = data.shape |
| 455 | + |
| 456 | + if i < delta_rows or i >= data_rows - delta_rows or \ |
| 457 | + j < delta_cols or j >= data_cols - delta_cols: |
| 458 | + return |
| 459 | + |
| 460 | + s = data[i, j] |
| 461 | + for k in range(kernel.shape[0]): |
| 462 | + for h in range(kernel.shape[1]): |
| 463 | + i_k = i + k - delta_rows |
| 464 | + j_h = j + h - delta_cols |
| 465 | + if (i_k >= 0) and (i_k < data_rows) and (j_h >= 0) and (j_h < data_cols): |
| 466 | + if (kernel[k, h] != 0) and s < data[i_k, j_h]: |
| 467 | + s = data[i_k, j_h] |
| 468 | + out[i, j] = s |
| 469 | + |
| 470 | + |
| 471 | +def _focal_range_cupy(data, kernel): |
| 472 | + focal_min = _focal_stats_func_cupy(data, kernel, _focal_min_cuda) |
| 473 | + focal_max = _focal_stats_func_cupy(data, kernel, _focal_max_cuda) |
| 474 | + out = focal_max - focal_min |
| 475 | + return out |
| 476 | + |
| 477 | + |
| 478 | +@cuda.jit |
| 479 | +def _focal_std_cuda(data, kernel, out): |
| 480 | + i, j = cuda.grid(2) |
| 481 | + |
| 482 | + delta_rows = kernel.shape[0] // 2 |
| 483 | + delta_cols = kernel.shape[1] // 2 |
| 484 | + |
| 485 | + data_rows, data_cols = data.shape |
| 486 | + |
| 487 | + if i < delta_rows or i >= data_rows - delta_rows or \ |
| 488 | + j < delta_cols or j >= data_cols - delta_cols: |
| 489 | + return |
| 490 | + |
| 491 | + sum_squares = 0 |
| 492 | + sum = 0 |
| 493 | + count = 0 |
| 494 | + for k in range(kernel.shape[0]): |
| 495 | + for h in range(kernel.shape[1]): |
| 496 | + i_k = i + k - delta_rows |
| 497 | + j_h = j + h - delta_cols |
| 498 | + if (i_k >= 0) and (i_k < data_rows) and (j_h >= 0) and (j_h < data_cols): |
| 499 | + sum_squares += (kernel[k, h]*data[i_k, j_h])**2 |
| 500 | + sum += kernel[k, h]*data[i_k, j_h] |
| 501 | + count += kernel[k, h] |
| 502 | + squared_sum = sum**2 |
| 503 | + out[i, j] = sqrt((sum_squares - squared_sum/count) / count) |
| 504 | + |
| 505 | + |
| 506 | +@cuda.jit |
| 507 | +def _focal_var_cuda(data, kernel, out): |
| 508 | + i, j = cuda.grid(2) |
| 509 | + |
| 510 | + delta_rows = kernel.shape[0] // 2 |
| 511 | + delta_cols = kernel.shape[1] // 2 |
| 512 | + |
| 513 | + data_rows, data_cols = data.shape |
| 514 | + |
| 515 | + if i < delta_rows or i >= data_rows - delta_rows or \ |
| 516 | + j < delta_cols or j >= data_cols - delta_cols: |
| 517 | + return |
| 518 | + |
| 519 | + sum_squares = 0 |
| 520 | + sum = 0 |
| 521 | + count = 0 |
| 522 | + for k in range(kernel.shape[0]): |
| 523 | + for h in range(kernel.shape[1]): |
| 524 | + i_k = i + k - delta_rows |
| 525 | + j_h = j + h - delta_cols |
| 526 | + if (i_k >= 0) and (i_k < data_rows) and (j_h >= 0) and (j_h < data_cols): |
| 527 | + sum_squares += (kernel[k, h]*data[i_k, j_h])**2 |
| 528 | + sum += kernel[k, h]*data[i_k, j_h] |
| 529 | + count += kernel[k, h] |
| 530 | + squared_sum = sum**2 |
| 531 | + out[i, j] = (sum_squares - squared_sum/count) / count |
| 532 | + |
| 533 | + |
| 534 | +def _focal_mean_cupy(data, kernel): |
| 535 | + out = convolve_2d(data, kernel / kernel.sum()) |
| 536 | + return out |
| 537 | + |
| 538 | + |
| 539 | +def _focal_sum_cupy(data, kernel): |
| 540 | + out = convolve_2d(data, kernel) |
| 541 | + return out |
| 542 | + |
| 543 | + |
| 544 | +def _focal_stats_func_cupy(data, kernel, func=_focal_max_cuda): |
| 545 | + out = cupy.empty(data.shape, dtype='f4') |
| 546 | + out[:, :] = cupy.nan |
| 547 | + griddim, blockdim = cuda_args(data.shape) |
| 548 | + func[griddim, blockdim](data, kernel, cupy.asarray(out)) |
| 549 | + return out |
| 550 | + |
| 551 | + |
| 552 | +def _focal_stats_cupy(agg, kernel, stats_funcs): |
| 553 | + _stats_cupy_mapper = dict( |
| 554 | + mean=_focal_mean_cupy, |
| 555 | + sum=_focal_sum_cupy, |
| 556 | + range=_focal_range_cupy, |
| 557 | + max=lambda *args: _focal_stats_func_cupy(*args, func=_focal_max_cuda), |
| 558 | + min=lambda *args: _focal_stats_func_cupy(*args, func=_focal_min_cuda), |
| 559 | + std=lambda *args: _focal_stats_func_cupy(*args, func=_focal_std_cuda), |
| 560 | + var=lambda *args: _focal_stats_func_cupy(*args, func=_focal_var_cuda), |
| 561 | + ) |
| 562 | + stats_aggs = [] |
| 563 | + for stats in stats_funcs: |
| 564 | + data = agg.data.astype(cupy.float32) |
| 565 | + stats_data = _stats_cupy_mapper[stats](data, kernel) |
| 566 | + stats_agg = xr.DataArray( |
| 567 | + stats_data, |
| 568 | + dims=agg.dims, |
| 569 | + coords=agg.coords, |
| 570 | + attrs=agg.attrs |
| 571 | + ) |
| 572 | + stats_aggs.append(stats_agg) |
| 573 | + stats = xr.concat(stats_aggs, pd.Index(stats_funcs, name='stats')) |
| 574 | + return stats |
| 575 | + |
| 576 | + |
| 577 | +def _focal_stats_cpu(agg, kernel, stats_funcs): |
| 578 | + _function_mapping = { |
| 579 | + 'mean': _calc_mean, |
| 580 | + 'max': _calc_max, |
| 581 | + 'min': _calc_min, |
| 582 | + 'range': _calc_range, |
| 583 | + 'std': _calc_std, |
| 584 | + 'var': _calc_var, |
| 585 | + 'sum': _calc_sum |
| 586 | + } |
| 587 | + stats_aggs = [] |
| 588 | + for stats in stats_funcs: |
| 589 | + stats_agg = apply(agg, kernel, func=_function_mapping[stats]) |
| 590 | + stats_aggs.append(stats_agg) |
| 591 | + stats = xr.concat(stats_aggs, pd.Index(stats_funcs, name='stats')) |
| 592 | + return stats |
| 593 | + |
| 594 | + |
423 | 595 | def focal_stats(agg,
|
424 | 596 | kernel,
|
425 | 597 | stats_funcs=[
|
@@ -480,24 +652,25 @@ def focal_stats(agg,
|
480 | 652 | * stats (stats) object 'min' 'sum'
|
481 | 653 | Dimensions without coordinates: dim_0, dim_1
|
482 | 654 | """
|
| 655 | + # validate raster |
| 656 | + if not isinstance(agg, DataArray): |
| 657 | + raise TypeError("`agg` must be instance of DataArray") |
483 | 658 |
|
484 |
| - _function_mapping = { |
485 |
| - 'mean': _calc_mean, |
486 |
| - 'max': _calc_max, |
487 |
| - 'min': _calc_min, |
488 |
| - 'range': _calc_range, |
489 |
| - 'std': _calc_std, |
490 |
| - 'var': _calc_var, |
491 |
| - 'sum': _calc_sum |
492 |
| - } |
| 659 | + if agg.ndim != 2: |
| 660 | + raise ValueError("`agg` must be 2D") |
493 | 661 |
|
494 |
| - stats_aggs = [] |
495 |
| - for stats in stats_funcs: |
496 |
| - stats_agg = apply(agg, kernel, func=_function_mapping[stats]) |
497 |
| - stats_aggs.append(stats_agg) |
| 662 | + # Validate the kernel |
| 663 | + kernel = custom_kernel(kernel) |
498 | 664 |
|
499 |
| - stats = xr.concat(stats_aggs, pd.Index(stats_funcs, name='stats')) |
500 |
| - return stats |
| 665 | + mapper = ArrayTypeFunctionMapping( |
| 666 | + numpy_func=_focal_stats_cpu, |
| 667 | + cupy_func=_focal_stats_cupy, |
| 668 | + dask_func=_focal_stats_cpu, |
| 669 | + dask_cupy_func=lambda *args: not_implemented_func( |
| 670 | + *args, messages='focal_stats() does not support dask with cupy backed DataArray.'), |
| 671 | + ) |
| 672 | + result = mapper(agg)(agg, kernel, stats_funcs) |
| 673 | + return result |
501 | 674 |
|
502 | 675 |
|
503 | 676 | @ngjit
|
|
0 commit comments