Skip to content

Make NAN handling in the SCAL kernels depend on the dummy2 parameter #4807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions kernel/arm/scal.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,22 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
if ( (n <= 0) || (inc_x <= 0))
return(0);

if (dummy2 == 0) {
while(j < n)
{

while(j < n)
{
if ( da == 0.0 )
x[i]=0.0;
else
x[i] = da * x[i] ;

i += inc_x ;
j++;
}
} else {

while(j < n)
{

if ( da == 0.0 )
if (!isnan(x[i]) && !isinf(x[i])) {
Expand All @@ -59,6 +72,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
i += inc_x ;
j++;

}
}
return 0;

Expand Down
11 changes: 8 additions & 3 deletions kernel/arm64/scal.S
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define X_COPY x5 /* X vector address */
#define INC_X x4 /* X stride */
#define I x1 /* loop variable */

#define FLAG x9
/*******************************************************************************
* Macro definitions
*******************************************************************************/
Expand Down Expand Up @@ -168,9 +168,14 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmp N, xzr
ble .Lscal_kernel_L999

//fcmp DA, #0.0
//beq .Lscal_kernel_zero
ldr FLAG, [sp]
cmp FLAG, #1
beq .Lscal_kernel_nansafe

fcmp DA, #0.0
beq .Lscal_kernel_zero

.Lscal_kernel_nansafe:
cmp INC_X, #1
bne .Lscal_kernel_S_BEGIN

Expand Down
34 changes: 31 additions & 3 deletions kernel/power/dscal.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ static void dscal_kernel_8_zero (BLASLONG n, FLOAT *x)

for( i=0; i<n; i+=8 )
{
x[0] = alpha;
x[1] = alpha;
x[2] = alpha;
x[3] = alpha;
x[4] = alpha;
x[5] = alpha;
x[6] = alpha;
x[7] = alpha;
#if 0
if(isfinite(x[0]))
x[0] = alpha;
else
Expand Down Expand Up @@ -106,7 +115,8 @@ static void dscal_kernel_8_zero (BLASLONG n, FLOAT *x)
else
x[7] = NAN;
x+=8;
}
#endif
}

}

Expand All @@ -130,6 +140,11 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
if ( n >= 16 )
{
BLASLONG align = ((32 - ((uintptr_t)x & (uintptr_t)0x1F)) >> 3) & 0x3;
if (dummy2 == 0)
for (j = 0; j < align; j++) {
x [j] = 0.0;
}
else
for (j = 0; j < align; j++) {
if (isfinite(x[j]))
x[j] = 0.0;
Expand All @@ -151,7 +166,13 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
j=n1;
}
#endif

if (dummy2 == 0)
while(j < n)
{
x[j]=0.0;
j++;
}
else
while(j < n)
{
if (!isfinite(x[j]))
Expand Down Expand Up @@ -202,7 +223,14 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS

if ( da == 0.0 )
{

if (dummy2 == 0)
while(j < n)
{
x[i]=0.0;
i += inc_x;
j++;
}
else
while(j < n)
{
if (!isfinite(x[i]))
Expand Down
13 changes: 10 additions & 3 deletions kernel/power/scal.S
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,23 @@
#ifndef __64BIT__
#define X r6
#define INCX r7
#define FLAG r11
#else
#define X r7
#define INCX r8
#define FLAG r12
#endif
#endif

#if defined(_AIX) || defined(__APPLE__)
#if !defined(__64BIT__) && defined(DOUBLE)
#define X r8
#define INCX r9
#define FLAG r13
#else
#define X r7
#define INCX r8
#define FLAG r12
#endif
#endif

Expand All @@ -84,9 +88,12 @@
cmpwi cr0, N, 0
blelr- cr0

// fcmpu cr0, FZERO, ALPHA
// bne- cr0, LL(A1I1)
b LL(A1I1)
fcmpu cr0, FZERO, ALPHA
bne- cr0, LL(A1I1)

ld FLAG, 48+64+8(SP)
cmpwi cr0, FLAG, 1
beq- cr0, LL(A1I1)

cmpwi cr0, INCX, SIZE
bne- cr0, LL(A0IN)
Expand Down
44 changes: 40 additions & 4 deletions kernel/power/sscal.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,24 @@ static void sscal_kernel_16_zero( BLASLONG n, FLOAT *x )

for( i=0; i<n; i+=8 )
{
if (isfinite(x[0]))
x[0] = alpha;
x[1] = alpha;
x[2] = alpha;
x[3] = alpha;
x[4] = alpha;
x[5] = alpha;
x[6] = alpha;
x[7] = alpha;
x[8] = alpha;
x[9] = alpha;
x[10] = alpha;
x[11] = alpha;
x[12] = alpha;
x[13] = alpha;
x[14] = alpha;
x[15] = alpha;
#if 0
if (isfinite(x[0]))
x[0] = alpha;
else
x[0] = NAN;
Expand Down Expand Up @@ -107,7 +124,8 @@ static void sscal_kernel_16_zero( BLASLONG n, FLOAT *x )
else
x[7] = NAN;
x+=8;
}
#endif
}

}

Expand All @@ -132,6 +150,11 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
if ( n >= 32 )
{
BLASLONG align = ((32 - ((uintptr_t)x & (uintptr_t)0x1F)) >> 2) & 0x7;
if (dummy2 == 0)
for (j = 0; j < align; j++){
x[j] = 0.0;
}
else
for (j = 0; j < align; j++) {
if (isfinite(x[j]))
x[j] = 0.0;
Expand All @@ -153,9 +176,15 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
j=n1;
}
#endif

if (dummy2 == 0)
while(j < n)
{
x[j] = 0.0;
j++;
}
else
while(j < n)
{
if (isfinite(x[j]))
x[j]=0.0;
else
Expand Down Expand Up @@ -204,7 +233,14 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS

if ( da == 0.0 )
{

if (dummy2 == 0)
while(j < n)
{
x[i]=0.0;
i += inc_x;
j++;
}
else
while(j < n)
{
if (isfinite(x[i]))
Expand Down
18 changes: 15 additions & 3 deletions kernel/riscv64/scal.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
if ( (n <= 0) || (inc_x <= 0))
return(0);


while(j < n)
{
if (dummy2 == 0) {
while(j < n)
{

if ( da == 0.0 )
if (isfinite(x[i]))
Expand All @@ -57,7 +57,19 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS

i += inc_x ;
j++;
}
} else {
while(j < n)
{

if ( da == 0.0 )
x[i]=0.0;
else
x[i] = da * x[i] ;

i += inc_x ;
j++;
}
}
return 0;

Expand Down
4 changes: 2 additions & 2 deletions kernel/riscv64/scal_rvv.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
FLOAT_V_T v0;

if(inc_x == 1) {
if(da == 0.0) {
if(dummy2 == 0 && da == 0.0) {
int gvl = VSETVL_MAX;
v0 = VFMVVF_FLOAT(0.0, gvl);
for (size_t vl; n > 0; n -= vl, x += vl) {
Expand All @@ -75,7 +75,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
} else {
BLASLONG stride_x = inc_x * sizeof(FLOAT);

if(da == 0.0) {
if(dummy2 == 0 && da == 0.0) {
int gvl = VSETVL_MAX;
v0 = VFMVVF_FLOAT(0.0, gvl);
for (size_t vl; n > 0; n -= vl, x += vl*inc_x) {
Expand Down
4 changes: 2 additions & 2 deletions kernel/riscv64/scal_vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
FLOAT_V_T v0, v1;
unsigned int gvl = 0;
if(inc_x == 1){
if (0){ //if(da == 0.0){
if(dummy2 == 0 && da == 0.0){
memset(&x[0], 0, n * sizeof(FLOAT));
}else{
gvl = VSETVL(n);
Expand All @@ -96,7 +96,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
}
}
}else{
if (0) { //if(da == 0.0){
if(dummy2 == 0 && da == 0.0){
BLASLONG stride_x = inc_x * sizeof(FLOAT);
BLASLONG ix = 0;
gvl = VSETVL(n);
Expand Down
9 changes: 7 additions & 2 deletions kernel/x86/scal.S
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,24 @@
#ifdef XDOUBLE
movl 44(%esp),%edi
movl 48(%esp),%esi
movl 64(%esp),%ecx
#elif defined(DOUBLE)
movl 36(%esp),%edi
movl 40(%esp),%esi
movl 56(%esp),%ecx
#else
movl 32(%esp),%edi
movl 36(%esp),%esi
movl 54(%esp),%ecx
#endif

ftst
fnstsw %ax
andb $68, %ah
// je .L300 # Alpha != ZERO
jmp .L300
je .L300 # Alpha != ZERO

cmpl $1,%ecx # dummy2 flag
je .L300

/* Alpha == ZERO */
cmpl $1,%esi
Expand Down
8 changes: 7 additions & 1 deletion kernel/x86_64/scal_atom.S
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@
#ifdef WINDOWS_ABI
movq 40(%rsp), X
movq 48(%rsp), INCX

movq 64(%rsp), %r9
movaps %xmm3, %xmm0
#else
movq 24(%rsp), %r9
#endif

SAVEREGISTERS
Expand All @@ -73,6 +75,10 @@
lea (, INCX, SIZE), INCX
comisd %xmm0, %xmm1
jne .L100
jp .L100

cmpq $1, %r9
je .L100

/* Alpha == ZERO */
cmpq $SIZE, INCX
Expand Down
6 changes: 5 additions & 1 deletion kernel/x86_64/scal_sse.S
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@
#ifdef WINDOWS_ABI
movq 40(%rsp), X
movq 48(%rsp), INCX

movq 64(%rsp), %r9
movaps %xmm3, %xmm0
#else
movq 24(%rsp), %r9
#endif

SAVEREGISTERS
Expand All @@ -76,6 +78,8 @@
shufps $0, %xmm0, %xmm0

jne .L100 # Alpha != ZERO

cmpq $1, %r9
je .L100
/* Alpha == ZERO */
cmpq $SIZE, INCX
Expand Down
Loading
Loading