diff --git a/source/module_base/math_sphbes.cpp b/source/module_base/math_sphbes.cpp index 5e7f41de54..00326f928e 100644 --- a/source/module_base/math_sphbes.cpp +++ b/source/module_base/math_sphbes.cpp @@ -808,7 +808,7 @@ void Sphbes::dsphbesj(const int n, } } -void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros) +void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros, const bool return_all) { assert( n > 0 ); assert( l >= 0 ); @@ -818,10 +818,22 @@ void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros) // This property enables us to use bracketing method recursively // to find all zeros of j_l from the zeros of j_0. - // if l is odd , j_0 --> j_1 --> j_3 --> j_5 --> ... - // if l is even, j_0 --> j_2 --> j_4 --> j_6 --> ... - - int nz = n + (l+1)/2; // number of effective zeros in buffer + // If return_all is true, zeros of j_0, j_1, ..., j_l will all be returned + // such that zeros[l*n+i] is the i-th zero of j_l. As such, it is required + // that the array "zeros" has a size of (l+1)*n. + // + // If return_all is false, only the zeros of j_l will be returned + // and "zeros" is merely required to have a size of n. + // Note that in this case the bracketing method can be applied with a stride + // of 2 instead of 1: + // j_0 --> j_1 --> j_3 --> j_5 --> ... --> j_l (odd l) + // j_0 --> j_2 --> j_4 --> j_6 --> ... --> j_l (even l) + + // Every recursion step reduces the number of zeros by 1. + // If return_all is true, one needs to start with n+l zeros of j_0 + // to ensure n zeros of j_l; otherwise with a stride of 2 one only + // needs to start with n+(l+1)/2 zeros of j_0 + int nz = n + ( return_all ? l : (l+1)/2 ); double* buffer = new double[nz]; // zeros of j_0 = sin(x)/x is just n*pi @@ -831,27 +843,34 @@ void Sphbes::sphbes_zeros(const int l, const int n, double* const zeros) buffer[i] = (i+1) * PI; } - int ll = 1; + int ll; // active l auto jl = [&ll] (double x) { return sphbesj(ll, x); }; - - if (l % 2 == 1) + int stride; + std::function copy_if_needed; + int offset = 0; // keeps track of the position in zeros for next copy (used when return_all == true) + if (return_all) { - for (int i = 0; i < nz-1; i++) - { - buffer[i] = illinois(jl, buffer[i], buffer[i+1], 1e-15, 50); - } - --nz; + copy_if_needed = [&](){ std::copy(buffer, buffer + n, zeros + offset); offset += n; }; + stride = 1; + ll = 1; + } + else + { + copy_if_needed = [](){}; + stride = 2; + ll = 2 - l % 2; } - for (ll = 2 + l%2; ll <= l; ll += 2, --nz) + for (; ll <= l; ll += stride, --nz) { + copy_if_needed(); for (int i = 0; i < nz-1; i++) { buffer[i] = illinois(jl, buffer[i], buffer[i+1], 1e-15, 50); } } - std::copy(buffer, buffer + n, zeros); + std::copy(buffer, buffer + n, zeros + offset); delete[] buffer; } diff --git a/source/module_base/math_sphbes.h b/source/module_base/math_sphbes.h index c654847a5d..7aa9c78a48 100644 --- a/source/module_base/math_sphbes.h +++ b/source/module_base/math_sphbes.h @@ -126,13 +126,18 @@ class Sphbes * This function computes the first n positive zeros of the l-th order * spherical Bessel function of the first kind. * - * @param[in] l order of the spherical Bessel function - * @param[in] n number of zeros to be computed - * @param[out] zeros on exit, contains the first n positive zeros in ascending order + * @param[in] l (maximum) order of the spherical Bessel function + * @param[in] n number of zeros to be computed (for each j_l if return_all is true) + * @param[out] zeros on exit, contains the positive zeros. + * @param[in] return_all if true, return all zeros from j_0 to j_l such that zeros[l*n+i] + * is the i-th zero of j_l. If false, return only the first n zeros of j_l. + * + * @note The size of array "zeros" must be at least (l+1)*n if return_all is true, and n otherwise. */ static void sphbes_zeros(const int l, const int n, - double* const zeros + double* const zeros, + bool return_all = false ); private: diff --git a/source/module_base/test/math_sphbes_test.cpp b/source/module_base/test/math_sphbes_test.cpp index 521d4dc2f4..e72c6e289c 100644 --- a/source/module_base/test/math_sphbes_test.cpp +++ b/source/module_base/test/math_sphbes_test.cpp @@ -352,15 +352,27 @@ TEST_F(Sphbes, Zeros) int lmax = 20; int nzeros = 500; - double* zeros = new double[nzeros]; + double* zeros = new double[nzeros*(lmax+1)]; for (int l = 0; l <= lmax; ++l) { - ModuleBase::Sphbes::sphbes_zeros(l, nzeros, zeros); + ModuleBase::Sphbes::sphbes_zeros(l, nzeros, zeros, false); for (int i = 0; i < nzeros; ++i) { EXPECT_LT(std::abs(ModuleBase::Sphbes::sphbesj(l, zeros[i])), 1e-14); } } + + + ModuleBase::Sphbes::sphbes_zeros(lmax, nzeros, zeros, true); + for (int l = 0; l <= lmax; ++l) + { + for (int i = 0; i < nzeros; ++i) + { + EXPECT_LT(std::abs(ModuleBase::Sphbes::sphbesj(l, zeros[l*nzeros+i])), 1e-14); + } + } + + delete[] zeros; } TEST_F(Sphbes, ZerosOld)