diff --git a/ndarray-linalg/src/trace.rs b/ndarray-linalg/src/trace.rs index 3020a9a5..feb119f2 100644 --- a/ndarray-linalg/src/trace.rs +++ b/ndarray-linalg/src/trace.rs @@ -4,7 +4,6 @@ use ndarray::*; use std::iter::Sum; use super::error::*; -use super::layout::*; use super::types::*; pub trait Trace { @@ -20,7 +19,13 @@ where type Output = A; fn trace(&self) -> Result { - let (n, _) = self.square_layout()?.size(); + let n = match self.is_square() { + true => Ok(self.nrows()), + false => Err(LinalgError::NotSquare { + rows: self.nrows() as i32, + cols: self.ncols() as i32, + }), + }?; Ok((0..n as usize).map(|i| self[(i, i)]).sum()) } }