diff --git a/Cargo.lock b/Cargo.lock index 02efaf6ca2..687c2bcf2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1498,7 +1498,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1515,7 +1515,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1538,7 +1538,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "bitflags 2.9.0", "bytemuck", @@ -1561,7 +1561,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "bytemuck", "cubecl-common", @@ -1575,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "bytemuck", "cubecl-common", @@ -1592,7 +1592,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "bytemuck", "cubecl-common", @@ -1619,7 +1619,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1637,7 +1637,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" 
dependencies = [ "bytemuck", "cubecl-common", @@ -1653,7 +1653,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "cubecl-common", "darling", @@ -1668,7 +1668,7 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "darling", "proc-macro2", @@ -1679,7 +1679,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1695,7 +1695,7 @@ dependencies = [ [[package]] name = "cubecl-random" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "cubecl-common", "cubecl-core", @@ -1710,7 +1710,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1723,7 +1723,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "async-channel", "bytemuck", @@ -1747,7 +1747,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "bitflags 2.9.0", "cubecl-common", @@ -1763,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "cubecl-core", "cubecl-runtime", @@ 
-1774,7 +1774,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.6.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=cfcbc082c8fbe4285e8d01d587132ff6720295fa#cfcbc082c8fbe4285e8d01d587132ff6720295fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=6ebfb95975cf10979f163ab04b1c86b75fbd9a4d#6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index b93a750d42..f602080961 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,11 @@ license = "MIT OR Apache-2.0" readme = "README.md" version = "0.18.0" +[workspace.lints.clippy] +suspicious = "deny" +perf = "deny" +# pedantic = "deny" + [workspace.dependencies] atomic_float = "1" bytemuck = "1.21.0" @@ -156,9 +161,9 @@ portable-atomic = { version = "1.11.0" } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "cfcbc082c8fbe4285e8d01d587132ff6720295fa" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "cfcbc082c8fbe4285e8d01d587132ff6720295fa" } -cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "cfcbc082c8fbe4285e8d01d587132ff6720295fa" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" } +cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6ebfb95975cf10979f163ab04b1c86b75fbd9a4d" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml index bf4d1386c7..92aec1808c 100644 --- a/crates/burn-autodiff/Cargo.toml +++ b/crates/burn-autodiff/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-autodiff" documentation = "https://docs.rs/burn-autodiff" version.workspace = true +[lints] +workspace = true + [features] default = ["std"] export_tests = ["burn-tensor-testgen"] diff --git a/crates/burn-candle/Cargo.toml b/crates/burn-candle/Cargo.toml index 6c9ecc4670..d959abd290 100644 --- a/crates/burn-candle/Cargo.toml +++ b/crates/burn-candle/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-candle" documentation = "https://docs.rs/burn-candle" version.workspace = true +[lints] +workspace = true + [features] default = ["std"] std = [] diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index b6a70c0332..22bf99d6ef 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -55,6 +55,7 @@ pub enum CandleDevice { impl CandleDevice { /// Create a Cuda device with the given index. /// The index is the index of the Cuda device in the list of all Cuda devices found on the system. + #[must_use] pub fn cuda(index: usize) -> Self { CandleDevice::Cuda(CudaDevice { device: candle_core::CudaDevice::new(index).unwrap(), @@ -64,6 +65,7 @@ impl CandleDevice { /// Create a Metal device with the given index. 
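Editor's note on the lint changes above: the new `[workspace.lints.clippy]` table denies the `suspicious` and `perf` groups for every crate that opts in with `[lints] workspace = true`, and the `#[must_use]` attributes added throughout this patch make silently dropping a return value a compiler warning. A minimal illustration of the attribute's effect, using a hypothetical function rather than code from this patch:

```rust
// Hypothetical constructor mirroring the pattern used in this patch.
#[must_use]
fn generate_id() -> u64 {
    42
}

fn main() {
    // Triggers the `unused_must_use` warning: the result is silently dropped.
    generate_id();

    // Fine: the value is consumed (or explicitly discarded).
    let _id = generate_id();
}
```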
/// The index is the index of the Metal device in the list of all Metal devices found on the system. + #[must_use] pub fn metal(index: usize) -> Self { CandleDevice::Metal(MetalDevice { device: candle_core::MetalDevice::new(index).unwrap(), diff --git a/crates/burn-candle/src/ops/base.rs b/crates/burn-candle/src/ops/base.rs index 775be254ed..975a1c2999 100644 --- a/crates/burn-candle/src/ops/base.rs +++ b/crates/burn-candle/src/ops/base.rs @@ -90,7 +90,7 @@ pub fn slice(tensor: CandleTensor, ranges: &[std::ops::Range]) -> CandleT for (i, range) in ranges.iter().enumerate().take(ranges.len()) { narrow_tensor = narrow_tensor .narrow(i, range.start, range.end - range.start) - .unwrap() + .unwrap(); } CandleTensor::new(narrow_tensor) } diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index da9be55735..338a99175f 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -136,15 +136,15 @@ impl FloatTensorOps for Candle } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - let lhs_contiguous = if !lhs.tensor.is_contiguous() { - lhs.tensor.contiguous().unwrap() - } else { + let lhs_contiguous = if lhs.tensor.is_contiguous() { lhs.tensor - }; - let rhs_contiguous = if !rhs.tensor.is_contiguous() { - rhs.tensor.contiguous().unwrap() } else { + lhs.tensor.contiguous().unwrap() + }; + let rhs_contiguous = if rhs.tensor.is_contiguous() { rhs.tensor + } else { + rhs.tensor.contiguous().unwrap() }; CandleTensor::new(lhs_contiguous.broadcast_matmul(&rhs_contiguous).unwrap()) } diff --git a/crates/burn-candle/src/tensor.rs b/crates/burn-candle/src/tensor.rs index aabc88f2a7..b60960c4bd 100644 --- a/crates/burn-candle/src/tensor.rs +++ b/crates/burn-candle/src/tensor.rs @@ -31,6 +31,7 @@ impl TensorMetadata for CandleTensor { impl CandleTensor { /// Create a new tensor. + #[must_use] pub fn new(tensor: candle_core::Tensor) -> Self { Self { tensor } } @@ -45,6 +46,7 @@ impl CandleTensor { /// # Returns /// /// A new tensor. + #[must_use] pub fn from_data(data: TensorData, device: CandleDevice) -> Self { let candle_shape: candle_core::Shape = data.shape.clone().into(); let tensor = candle_core::Tensor::from_slice( diff --git a/crates/burn-common/Cargo.toml b/crates/burn-common/Cargo.toml index 5c3995b65b..ee41895c3d 100644 --- a/crates/burn-common/Cargo.toml +++ b/crates/burn-common/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-common" documentation = "https://docs.rs/burn-common" version.workspace = true +[lints] +workspace = true + [features] default = ["std", "cubecl-common/default"] std = ["cubecl-common/std"] diff --git a/crates/burn-common/src/id.rs b/crates/burn-common/src/id.rs index b16f5c3390..af90b18310 100644 --- a/crates/burn-common/src/id.rs +++ b/crates/burn-common/src/id.rs @@ -5,6 +5,7 @@ pub struct IdGenerator {} impl IdGenerator { /// Generates a new ID. + #[must_use] pub fn generate() -> u64 { // Generate a random u64 (18,446,744,073,709,551,615 combinations) let random_bytes: [u8; 8] = gen_random(); diff --git a/crates/burn-common/src/lib.rs b/crates/burn-common/src/lib.rs index a2f674c502..7be56c674b 100644 --- a/crates/burn-common/src/lib.rs +++ b/crates/burn-common/src/lib.rs @@ -34,6 +34,7 @@ pub mod tensor { /// of all dimensions greater than `k`. /// /// This means that strides increase as you move from the rightmost to the leftmost dimension. 
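Editor's note: the contiguity helpers documented here (and shown just below) follow the usual row-major convention. A simplified sketch of the semantics, not the crate's exact implementation:

```rust
// Row-major strides: the stride of dimension k is the product of all
// dimension sizes to its right.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for k in (0..shape.len().saturating_sub(1)).rev() {
        strides[k] = strides[k + 1] * shape[k + 1];
    }
    strides
}

// Simplified contiguity check: the strides match the canonical row-major ones.
fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    shape.is_empty() || strides == contiguous_strides(shape).as_slice()
}

fn main() {
    // A [2, 3, 4] tensor stored row-major has strides [12, 4, 1].
    assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
    assert!(is_contiguous(&[2, 3, 4], &[12, 4, 1]));
    // A permuted view of the same storage is not contiguous.
    assert!(!is_contiguous(&[4, 3, 2], &[1, 4, 12]));
}
```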
+ #[must_use] pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { if shape.is_empty() { return true; @@ -52,6 +53,7 @@ pub mod tensor { /// /// In a contiguous row-major tensor, the stride for each dimension /// equals the product of all dimension sizes to its right. + #[must_use] pub fn contiguous_strides(shape: &[usize]) -> Vec { let mut strides = Vec::with_capacity(shape.len()); let mut current = 1; diff --git a/crates/burn-common/src/network.rs b/crates/burn-common/src/network.rs index a5dd149613..baeca48d1a 100644 --- a/crates/burn-common/src/network.rs +++ b/crates/burn-common/src/network.rs @@ -18,6 +18,7 @@ pub mod downloader { /// A vector of bytes containing the downloaded file data. #[cfg(feature = "std")] #[tokio::main(flavor = "current_thread")] + #[must_use] pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec { // Get file from web let mut response = Client::new().get(url).send().await.unwrap(); @@ -34,7 +35,7 @@ pub mod downloader { .with_key( "eta", |state: &ProgressState, w: &mut dyn std::fmt::Write| { - write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap(); }, ) .progress_chars("▬ "), diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index c167b81e78..51a863d190 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-core" version.workspace = true +[lints] +workspace = true + [features] dataset = ["burn-dataset"] default = [ diff --git a/crates/burn-cubecl-fusion/Cargo.toml b/crates/burn-cubecl-fusion/Cargo.toml index 9ce73675b1..4c29f78358 100644 --- a/crates/burn-cubecl-fusion/Cargo.toml +++ b/crates/burn-cubecl-fusion/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl-fusion" version.workspace = true +[lints] +workspace = true + [features] default = ["autotune", "std", "cubecl/default"] autotune = [] diff --git a/crates/burn-cubecl-fusion/src/elemwise/builder.rs b/crates/burn-cubecl-fusion/src/elemwise/builder.rs index 6006cd4b8b..20d4d5e5a0 100644 --- a/crates/burn-cubecl-fusion/src/elemwise/builder.rs +++ b/crates/burn-cubecl-fusion/src/elemwise/builder.rs @@ -56,7 +56,7 @@ impl OptimizationBuilder> for ElementWiseBuilder } fn reset(&mut self) { - self.builder.reset() + self.builder.reset(); } fn status(&self) -> burn_fusion::OptimizationStatus { diff --git a/crates/burn-cubecl-fusion/src/elemwise/optimization.rs b/crates/burn-cubecl-fusion/src/elemwise/optimization.rs index 1c4757a799..64451bc3b7 100644 --- a/crates/burn-cubecl-fusion/src/elemwise/optimization.rs +++ b/crates/burn-cubecl-fusion/src/elemwise/optimization.rs @@ -123,6 +123,6 @@ fn elemwise_fuse( let length = ref_len(inputs, outputs, &locals, config); if pos < length { - fuse_on_write::(inputs, outputs, &mut locals, pos, values, args, config) + fuse_on_write::(inputs, outputs, &mut locals, pos, values, args, config); } } diff --git a/crates/burn-cubecl-fusion/src/matmul/optimization.rs b/crates/burn-cubecl-fusion/src/matmul/optimization.rs index 81f2af8a2f..baf2bc6fd3 100644 --- a/crates/burn-cubecl-fusion/src/matmul/optimization.rs +++ b/crates/burn-cubecl-fusion/src/matmul/optimization.rs @@ -319,13 +319,15 @@ impl FusedMatmul { lhs_shape[..lhs_shape.len() - 2].to_vec(), rhs_shape[..rhs_shape.len() - 2].to_vec(), ), - lhs_layout: match lhs_transposed { - true => 
components::MatrixLayout::ColMajor, - false => components::MatrixLayout::RowMajor, + lhs_layout: if lhs_transposed { + components::MatrixLayout::ColMajor + } else { + components::MatrixLayout::RowMajor }, - rhs_layout: match rhs_transposed { - true => components::MatrixLayout::ColMajor, - false => components::MatrixLayout::RowMajor, + rhs_layout: if rhs_transposed { + components::MatrixLayout::ColMajor + } else { + components::MatrixLayout::RowMajor }, lhs_line_size, rhs_line_size, @@ -353,7 +355,7 @@ impl FusedMatmul { problem, plane_size, ) { - Ok(_) => Ok(()), + Ok(()) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), } } @@ -365,7 +367,7 @@ impl FusedMatmul { problem, plane_size, ) { - Ok(_) => Ok(()), + Ok(()) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), } } diff --git a/crates/burn-cubecl-fusion/src/reduce/builder.rs b/crates/burn-cubecl-fusion/src/reduce/builder.rs index ba65366969..91b1330fa5 100644 --- a/crates/burn-cubecl-fusion/src/reduce/builder.rs +++ b/crates/burn-cubecl-fusion/src/reduce/builder.rs @@ -181,7 +181,7 @@ impl OptimizationBuilder> for ReduceBuilder { _ => { self.on_elemwise_read(operation); } - }; + } } else if let OperationIr::NumericInt(_, op) = operation { match op { NumericOperationIr::SumDim(op) => { @@ -211,7 +211,7 @@ impl OptimizationBuilder> for ReduceBuilder { _ => { self.on_elemwise_read(operation); } - }; + } } else { self.on_elemwise_read(operation); } @@ -261,7 +261,7 @@ impl OptimizationBuilder> for ReduceBuilder { properties.score += 1; } else { properties.ready = false; - }; + } properties } diff --git a/crates/burn-cubecl-fusion/src/reduce/optimization.rs b/crates/burn-cubecl-fusion/src/reduce/optimization.rs index d8606b8de4..4ad17b30f6 100644 --- a/crates/burn-cubecl-fusion/src/reduce/optimization.rs +++ b/crates/burn-cubecl-fusion/src/reduce/optimization.rs @@ -86,6 +86,7 @@ pub struct FusedReduce { } impl FusedReduce { + #[must_use] pub fn with_strategy(&self, strategy: ReduceStrategy) -> Self { Self { input: self.input.clone(), @@ -323,17 +324,18 @@ impl TraceRunner for FusedReduce { .map(|(i, s)| if i == self.axis { 1 } else { *s as u32 }) .product(); - let line_mode = match self.axis == config_read.rank as usize - 1 { - true => LineMode::Parallel, - false => LineMode::Perpendicular, + let line_mode = if self.axis == config_read.rank as usize - 1 { + LineMode::Parallel + } else { + LineMode::Perpendicular }; let config_reduce = ReduceConfig { cube_count: CubeCount::new_single(), cube_dim: CubeDim::new_single(), line_mode, - line_size_input: config_read.width as u32, - line_size_output: config_write.width as u32, + line_size_input: u32::from(config_read.width), + line_size_output: u32::from(config_write.width), bound_checks: false, bound_checks_inner: if strategy.use_planes { BoundChecksInner::Branch @@ -407,7 +409,7 @@ fn launch_reduce_mixed_precision( ReduceInstruction::Min => ReduceFnConfig::Min, ReduceInstruction::MaxAbs => ReduceFnConfig::MaxAbs, }; - launch_reduce::(kwargs, config, dtype_input, dtype_output, dtype_acc) + launch_reduce::(kwargs, config, dtype_input, dtype_output, dtype_acc); } fn launch_reduce( diff --git a/crates/burn-cubecl-fusion/src/shared/builder.rs b/crates/burn-cubecl-fusion/src/shared/builder.rs index af74bad9be..6921d06e0d 100644 --- a/crates/burn-cubecl-fusion/src/shared/builder.rs +++ b/crates/burn-cubecl-fusion/src/shared/builder.rs @@ -76,7 +76,7 @@ impl OptimizationBuilder for FuseOptimizationBuilder { self.status = OptimizationStatus::Closed; return; } - }; + } 
self.status = OptimizationStatus::Open; self.num_ops += 1; @@ -181,10 +181,10 @@ impl FuseOptimizationBuilder { input }); - let args = if !is_success { - return None; - } else { + let args = if is_success { args.map(|arg| arg.unwrap()) + } else { + return None; }; let current_output_shape = core::mem::take(&mut self.current_output_shape); diff --git a/crates/burn-cubecl-fusion/src/shared/io.rs b/crates/burn-cubecl-fusion/src/shared/io.rs index 932bf82bd2..6b780e81b4 100644 --- a/crates/burn-cubecl-fusion/src/shared/io.rs +++ b/crates/burn-cubecl-fusion/src/shared/io.rs @@ -1,4 +1,11 @@ -use super::{DYN_ELEM_ID, ir::*, tensor::GlobalTensor}; +use super::{ + DYN_ELEM_ID, + ir::{ + Arg, FuseBlockConfig, FusePrecision, GlobalArgs, LayoutInfo, LocalArgs, RefLayout, + VirtualLayout, + }, + tensor::GlobalTensor, +}; use cubecl::{ intrinsic, ir::{ExpandElement, Variable}, @@ -27,7 +34,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!global.broadcasted && line_size != config.width as u32] { + if comptime![!global.broadcasted && line_size != u32::from(config.width)] { read_input_aligned(inputs, locals, pos, ref_pos, layout, config, None) } else { read_input(inputs, locals, pos, ref_pos, layout, config, None) @@ -68,7 +75,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!broadcasted && line_size != config.width as u32] { + if comptime![!broadcasted && line_size != u32::from(config.width)] { read_input_aligned( inputs, locals, @@ -101,7 +108,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!broadcasted && line_size != config.width as u32] { + if comptime![!broadcasted && line_size != u32::from(config.width)] { read_input_aligned( inputs, locals, @@ -188,16 +195,16 @@ pub fn read_input_aligned( #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, ) -> Line { - let mut result: Line = Line::::empty(comptime![config.width as u32]); + let mut result: Line = Line::::empty(comptime![u32::from(config.width)]); let tensor = inputs.tensors.index(pos); match comptime![transform.clone()] { Some(Transform::Reshape(shape)) => { // Very brute force, not really efficient, but not easy to optimize and not a very // frequent workflow. 
- let ref_pos = ref_pos * comptime![config.width as u32]; + let ref_pos = ref_pos * comptime![u32::from(config.width)]; #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0u32..comptime!(u32::from(config.width)) { let index = reshaped_index( inputs, locals, @@ -216,7 +223,7 @@ pub fn read_input_aligned( let stride = tensor.tensor.stride(comptime![i]); #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0u32..comptime!(u32::from(config.width)) { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } @@ -226,7 +233,7 @@ pub fn read_input_aligned( get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform); let stride = tensor.tensor.stride(comptime![config.rank - 1]); #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0u32..comptime!(u32::from(config.width)) { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } @@ -309,7 +316,7 @@ pub fn write( } Arg::Local(pos, precision) => match comptime![precision] { FusePrecision::F32 | FusePrecision::Flex32 => { - locals.l_f32.insert(pos, Line::cast_from(value)) + locals.l_f32.insert(pos, Line::cast_from(value)); } FusePrecision::F16 => locals.l_f16.insert(pos, Line::cast_from(value)), FusePrecision::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)), diff --git a/crates/burn-cubecl-fusion/src/shared/kernel.rs b/crates/burn-cubecl-fusion/src/shared/kernel.rs index 9f89907a17..e6e4e011b8 100644 --- a/crates/burn-cubecl-fusion/src/shared/kernel.rs +++ b/crates/burn-cubecl-fusion/src/shared/kernel.rs @@ -1,7 +1,13 @@ use crate::shared::DYN_ELEM_ID; -use super::io::*; -use super::ir::*; +use super::io::{ + global_offset, global_stride, read, read_input, read_scalar_shape, reverse_index, + swap_dims_transform, write, +}; +use super::ir::{ + Arg, BinaryFuseArgs, FuseBlockConfig, FuseOp, GlobalArgs, LayoutInfo, LocalArgs, RefLayout, + UnaryFuseArgs, VirtualLayout, +}; use cubecl::prelude::*; #[cube] @@ -49,11 +55,11 @@ pub fn fuse_on_read( let value = read::(inputs, outputs, locals, read_pos, arg, config); let value_line_size = value.line_size(); - let output_line_size = comptime!(config.width as u32); + let output_line_size = comptime!(u32::from(config.width)); // We currently don't support broadcasting __across__ blocks. 
if comptime!(value_line_size != output_line_size) { - let mut tmp = Line::::empty(comptime!(config.width as u32)); + let mut tmp = Line::::empty(comptime!(u32::from(config.width))); comptime!( assert_eq!(value_line_size, 1, "The input line_size must be 1 or the same as the config width."); ); @@ -61,7 +67,7 @@ pub fn fuse_on_read( let val = value[0]; #[unroll] - for i in 0..comptime!(config.width as u32) { + for i in 0..comptime!(u32::from(config.width)) { tmp[i] = val; } @@ -205,67 +211,67 @@ fn fuse( match op { FuseOp::Add(op) => { - add::>(inputs, outputs, locals, pos, op, config) + add::>(inputs, outputs, locals, pos, op, config); } FuseOp::Div(op) => { - div::>(inputs, outputs, locals, pos, op, config) + div::>(inputs, outputs, locals, pos, op, config); } FuseOp::Sub(op) => { - sub::>(inputs, outputs, locals, pos, op, config) + sub::>(inputs, outputs, locals, pos, op, config); } FuseOp::Mul(op) => { - mul::>(inputs, outputs, locals, pos, op, config) + mul::>(inputs, outputs, locals, pos, op, config); } FuseOp::Powf(op) => { - powf::>(inputs, outputs, locals, pos, op, config) + powf::>(inputs, outputs, locals, pos, op, config); } FuseOp::Erf(op) => { - erf::>(inputs, outputs, locals, pos, op, config) + erf::>(inputs, outputs, locals, pos, op, config); } FuseOp::Sqrt(op) => { - sqrt::>(inputs, outputs, locals, pos, op, config) + sqrt::>(inputs, outputs, locals, pos, op, config); } FuseOp::Abs(op) => { - abs::>(inputs, outputs, locals, pos, op, config) + abs::>(inputs, outputs, locals, pos, op, config); } FuseOp::Log(op) => { - log::>(inputs, outputs, locals, pos, op, config) + log::>(inputs, outputs, locals, pos, op, config); } FuseOp::Log1p(op) => { - log1p::>(inputs, outputs, locals, pos, op, config) + log1p::>(inputs, outputs, locals, pos, op, config); } FuseOp::Recip(op) => { - recip::>(inputs, outputs, locals, pos, op, config) + recip::>(inputs, outputs, locals, pos, op, config); } FuseOp::Assign(op) => { - assign::>(inputs, outputs, locals, pos, op, config) + assign::>(inputs, outputs, locals, pos, op, config); } FuseOp::Exp(op) => { - exp::>(inputs, outputs, locals, pos, op, config) + exp::>(inputs, outputs, locals, pos, op, config); } FuseOp::Cos(op) => { - cos::>(inputs, outputs, locals, pos, op, config) + cos::>(inputs, outputs, locals, pos, op, config); } FuseOp::Sin(op) => { - sin::>(inputs, outputs, locals, pos, op, config) + sin::>(inputs, outputs, locals, pos, op, config); } FuseOp::Tanh(op) => { - tanh::>(inputs, outputs, locals, pos, op, config) + tanh::>(inputs, outputs, locals, pos, op, config); } FuseOp::Equal(op) => { - equal::>(inputs, outputs, locals, pos, op, config) + equal::>(inputs, outputs, locals, pos, op, config); } FuseOp::Greater(op) => { - greater::>(inputs, outputs, locals, pos, op, config) + greater::>(inputs, outputs, locals, pos, op, config); } FuseOp::GreaterEqual(op) => greater_equal::>( inputs, outputs, locals, pos, op, config, ), FuseOp::Lower(op) => { - lower::>(inputs, outputs, locals, pos, op, config) + lower::>(inputs, outputs, locals, pos, op, config); } FuseOp::LowerEqual(op) => { - lower_equal::>(inputs, outputs, locals, pos, op, config) + lower_equal::>(inputs, outputs, locals, pos, op, config); } FuseOp::ConditionalAssign { cond, diff --git a/crates/burn-cubecl-fusion/src/shared/settings.rs b/crates/burn-cubecl-fusion/src/shared/settings.rs index e3ba867beb..92f5f4179c 100644 --- a/crates/burn-cubecl-fusion/src/shared/settings.rs +++ b/crates/burn-cubecl-fusion/src/shared/settings.rs @@ -21,9 +21,9 @@ pub struct FuseSettings { 
#[derive(Clone, Copy, Debug, Serialize, Deserialize)] /// How vectorization is handled during fusion. pub enum VectorizationSetting { - /// The biggest line_size possible will be used. + /// The biggest `line_size` possible will be used. Activated, - /// Equivalent to using line_size of one. + /// Equivalent to using `line_size` of one. Deactivated, /// This is a good setting when a block processes values calculated from a previous block. SmallerOrEqualThanPreviousBlock, diff --git a/crates/burn-cubecl-fusion/src/shared/tensor.rs b/crates/burn-cubecl-fusion/src/shared/tensor.rs index 7f63125cf8..91c20b8128 100644 --- a/crates/burn-cubecl-fusion/src/shared/tensor.rs +++ b/crates/burn-cubecl-fusion/src/shared/tensor.rs @@ -267,7 +267,7 @@ impl LaunchArg for GlobalTensor { impl ArgSettings for GlobalTensorArg<'_, R> { fn register(&self, launcher: &mut KernelLauncher) { - launcher.register_tensor(&self.tensor) + launcher.register_tensor(&self.tensor); } } diff --git a/crates/burn-cubecl-fusion/src/shared/trace/base.rs b/crates/burn-cubecl-fusion/src/shared/trace/base.rs index 42dc3dd2db..9caf6762ab 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/base.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/base.rs @@ -49,7 +49,7 @@ impl TuneOutput { TuneOutput::Checked { handles } => match other { TuneOutput::UnChecked(..) => {} TuneOutput::Checked { handles: o } => { - for (k, v) in o.into_iter() { + for (k, v) in o { handles.insert(k, v); } } @@ -73,7 +73,7 @@ impl cubecl::tune::AutotuneOutput for TuneOutput { ) = (self, &other) { let mut num_checked = 0; - for (id, (shape, handle)) in handles_ref.iter() { + for (id, (shape, handle)) in handles_ref { if let Some((shape_other, other)) = handles.get(id) { assert_eq!( handle.strides, other.strides, @@ -88,7 +88,7 @@ impl cubecl::tune::AutotuneOutput for TuneOutput { match handle.dtype { DType::F64 => { - data_ref.assert_approx_eq::(&data_other, Tolerance::permissive()) + data_ref.assert_approx_eq::(&data_other, Tolerance::permissive()); } DType::F32 => { data_ref.assert_approx_eq::(&data_other, Tolerance::permissive()) diff --git a/crates/burn-cubecl-fusion/src/shared/trace/block.rs b/crates/burn-cubecl-fusion/src/shared/trace/block.rs index 6e7d0a5e4f..04684d6a1a 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/block.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/block.rs @@ -61,16 +61,15 @@ impl FuseBlockBuilder { _ => precision, }; - let out = match self.locals.get(precision, tensor.id) { - Some(local) => local, - None => { - let out = self.locals.create(precision, tensor.id); + let out = if let Some(local) = self.locals.get(precision, tensor.id) { + local + } else { + let out = self.locals.create(precision, tensor.id); - self.outputs.insert(precision_output, tensor.clone()); - resources.outputs.insert(precision_output, tensor.clone()); + self.outputs.insert(precision_output, tensor.clone()); + resources.outputs.insert(precision_output, tensor.clone()); - out - } + out }; Some(out) @@ -90,35 +89,32 @@ impl FuseBlockBuilder { _ => precision, }; - let arg = match self.locals.get(precision, tensor.id) { - Some(local) => { - resources.inputs.update(tensor); - // An input can be an output of a previously fused operation. - // We need to flag the new status for the tensor. - resources.outputs.update(tensor); - self.outputs.update(tensor); + let arg = if let Some(local) = self.locals.get(precision, tensor.id) { + resources.inputs.update(tensor); + // An input can be an output of a previously fused operation. 
+ // We need to flag the new status for the tensor. + resources.outputs.update(tensor); + self.outputs.update(tensor); - local - } - None => { - let new_input = resources.inputs.insert(precision_input, tensor.clone()); - let out = self.locals.create(precision, tensor.id); - let input = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); - - let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) { - e.insert(Vec::with_capacity(1)); - self.reads.get_mut(&tensor.id).unwrap() - } else { - self.reads.get_mut(&tensor.id).unwrap() - }; - - reads.push(FuseOp::Assign(UnaryFuseArgs { - input, - out: out.clone(), - })); - - out - } + local + } else { + let new_input = resources.inputs.insert(precision_input, tensor.clone()); + let out = self.locals.create(precision, tensor.id); + let input = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); + + let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) { + e.insert(Vec::with_capacity(1)); + self.reads.get_mut(&tensor.id).unwrap() + } else { + self.reads.get_mut(&tensor.id).unwrap() + }; + + reads.push(FuseOp::Assign(UnaryFuseArgs { + input, + out: out.clone(), + })); + + out }; Some(arg) @@ -499,17 +495,17 @@ impl FuseBlockBuilder { }; // For all operators, mark their local tensor id in the proper set. - for (_, ops) in self.reads.iter() { + for ops in self.reads.values() { for op in ops { mark_op(op); } } - for op in self.ops.iter() { + for op in &self.ops { mark_op(op); } - for arg in self.outputs_unhandled.iter() { + for arg in &self.outputs_unhandled { mark(arg, &mut local_tensor_ids_output); } @@ -554,7 +550,7 @@ impl LocalVariablePool { } fn get_any_precision(&self, tensor_id: TensorId) -> Option { - for (precision, indexes) in self.values.iter() { + for (precision, indexes) in &self.values { if let Some(index) = indexes.get(&tensor_id) { return Some(Arg::Local(*index, *precision)); } diff --git a/crates/burn-cubecl-fusion/src/shared/trace/builder.rs b/crates/burn-cubecl-fusion/src/shared/trace/builder.rs index 7bc5b8998f..7b06087f98 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/builder.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/builder.rs @@ -44,7 +44,7 @@ impl FuseTraceBuilder { pub fn num_ops_fused(&self) -> u32 { let mut num_ops_fused = 0; - for (block, _) in self.blocks_previous.iter() { + for (block, _) in &self.blocks_previous { num_ops_fused += block.ops.len(); } @@ -63,7 +63,7 @@ impl FuseTraceBuilder { // but should never return less. pub fn estimate_bindings(&self) -> u32 { let mut estimation = 1; // Metadata takes one. - for b in self.blocks_previous.iter() { + for b in &self.blocks_previous { estimation += b.0.estimate_num_outputs(&self.resources); } estimation += self.block_current.estimate_num_outputs(&self.resources); @@ -97,9 +97,10 @@ impl FuseTraceBuilder { /// /// It is therefore the responsibility of the operation to read the given tensor. 
pub fn input_unhandled(&mut self, tensor: &TensorIr) -> Arg { - if self.resources.indexed.contains_key(&tensor.id) { - panic!("Can't add a new input that is already used in an index operation"); - } + assert!( + !self.resources.indexed.contains_key(&tensor.id), + "Can't add a new input that is already used in an index operation" + ); let precision = tensor.dtype.into(); @@ -132,7 +133,7 @@ impl FuseTraceBuilder { pub fn input_indexed(&mut self, tensor: &TensorIr) -> Option { if let Some(val) = self.resources.indexed.get(&tensor.id) { return Some(val.clone()); - }; + } if self.resources.inputs.get(tensor.id).is_some() { return None; @@ -202,7 +203,7 @@ impl FuseTraceBuilder { let mut offset = 0; - for (block, shape_ref) in self.blocks_previous.iter() { + for (block, shape_ref) in &self.blocks_previous { offset += register_block(block, shape_ref, offset); } register_block(&self.block_current, &shape_ref, offset); diff --git a/crates/burn-cubecl-fusion/src/shared/trace/executor.rs b/crates/burn-cubecl-fusion/src/shared/trace/executor.rs index 434c0385b2..30c7056b1f 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/executor.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/executor.rs @@ -52,7 +52,7 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { scalars: &mut ScalarIds, ) -> Result, ExecutionError> { let mut num_writes = 0; - for b in plan.blocks.iter() { + for b in &plan.blocks { num_writes += b.writes.len(); } @@ -88,7 +88,7 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { let reference = match block_plan.reference { ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout), ReferenceSelection::VirtualShape { original, .. } => { - RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width as u32)) + RefLayout::Virtual(VirtualLayout::Shape(original, u32::from(block_plan.width))) } ReferenceSelection::SwapDims { original, dims } => { RefLayout::Virtual(VirtualLayout::SwapDims(original, dims)) @@ -113,7 +113,7 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { } } - for op in block.ops.iter() { + for op in &block.ops { ops.push(op.clone()); } @@ -146,7 +146,7 @@ fn register_inputs<'h, R: Runtime>( handle_inputs: &'h [HandleInput], inputs: &mut GlobalArgsLaunch<'h, R>, ) { - for hi in handle_inputs.iter() { + for hi in handle_inputs { let arg = hi.handle.as_tensor_arg(&hi.global_shape, hi.vectorization); inputs.tensors.push(GlobalTensorArg::new( arg, @@ -161,7 +161,7 @@ fn register_outputs<'s, BT: CubeElement, R: Runtime>( outputs: &mut GlobalArgsLaunch<'s, R>, #[allow(unused_variables)] tune_output: &mut TuneOutput, ) { - for item in handle_outputs.iter() { + for item in handle_outputs { match item { HandleOutput::Alias { input_pos, @@ -315,7 +315,7 @@ fn register_scalars<'h, R: Runtime>( if let TensorView::Reshape { reshaped, .. } = relative { let global = context.tensors.get(reshaped).unwrap(); - for shape in global.shape.iter() { + for shape in &global.shape { inputs.reshapes.push(ScalarArg::new(*shape as u32)); } } diff --git a/crates/burn-cubecl-fusion/src/shared/trace/input.rs b/crates/burn-cubecl-fusion/src/shared/trace/input.rs index 84b23f1959..07b729bfe1 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/input.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/input.rs @@ -70,7 +70,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> { { let mut is_a_view = false; // For each view we try to see if it's not possible to set it as a reference input. 
- for view in self.resources.views.iter() { + for view in &self.resources.views { for (block_plan, block) in plan.blocks.iter_mut().zip(self.blocks) { is_a_view = is_a_view || Self::analyze_view(pos, tensor_relative, block, block_plan, view); @@ -194,7 +194,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> { return true; } } - }; + } false } diff --git a/crates/burn-cubecl-fusion/src/shared/trace/output.rs b/crates/burn-cubecl-fusion/src/shared/trace/output.rs index b077eab377..3b4a795d33 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/output.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/output.rs @@ -94,7 +94,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { let mut outputs = Vec::new(); core::mem::swap(&mut outputs, &mut self.outputs_sorted); - for output in outputs.into_iter() { + for output in outputs { let tensor_global = context .tensors .get(&output.tensor_relative.id) @@ -154,14 +154,14 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { } for (i, block) in plan.blocks.iter_mut().enumerate() { - if !block.reference.is_found() { + if block.reference.is_found() { + Self::add_layout_info_inputs(block, &plan.handle_inputs); + } else { Self::select_reference_from_inputs( self.blocks[i].settings.ref_layout, block, &plan.handle_inputs, ); - } else { - Self::add_layout_info_inputs(block, &plan.handle_inputs); } } } @@ -204,9 +204,9 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { RefLayoutSetting::Any => set_ref_as_concrete(block), RefLayoutSetting::OnlyContiguous => { if is_contiguous(&reference.global_shape, &reference.handle.strides) { - set_ref_as_concrete(block) + set_ref_as_concrete(block); } else { - set_ref_as_virtual(block) + set_ref_as_virtual(block); } } } @@ -227,14 +227,14 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { InputReference::Reshaped { reshape_pos } => { block.reference = ReferenceSelection::Reshaped { reshape_pos }; } - }; + } } else { block.reference = ReferenceSelection::NotFound; } } fn add_layout_info_inputs(block: &mut BlockPlan<'_>, handle_inputs: &[HandleInput]) { - for hi in handle_inputs.iter() { + for hi in handle_inputs { if let ReferenceSelection::Concrete { strides, shape, .. } | ReferenceSelection::VirtualShape { strides, shape, .. } = &block.reference { @@ -285,8 +285,9 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { && pi.strides == strides && block.reference.compatible_strides_for_inplace(strides) }) - .map(|(pos, _)| OutputKind::Inplace { input_pos: pos }) - .unwrap_or(OutputKind::Normal); + .map_or(OutputKind::Normal, |(pos, _)| OutputKind::Inplace { + input_pos: pos, + }); (kind, block_idx) } @@ -304,7 +305,12 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { let potential_inplace = block.potential_inplaces.remove(input_index); let handle_input = plan.handle_inputs.get(potential_inplace.input_pos).unwrap(); - if !block.reference.is_found() { + if block.reference.is_found() { + // Already validated, necessary for correctness. + if let Some(FuseOp::Assign(op)) = block.writes.get_mut(&output.tensor_relative.id) { + op.out.add_layout_info(LayoutInfo::SameAsRef); + } + } else { let index_input = self .resources .inputs @@ -321,18 +327,13 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { for op in ops.iter_mut() { if let FuseOp::Assign(op) = op { op.input.add_layout_info(LayoutInfo::IsRef); - }; + } } } if let Some(FuseOp::Assign(op)) = block.writes.get_mut(&output.tensor_relative.id) { op.out.add_layout_info(LayoutInfo::IsRef); - }; - } else { - // Already validated, necessary for correctness. 
- if let Some(FuseOp::Assign(op)) = block.writes.get_mut(&output.tensor_relative.id) { - op.out.add_layout_info(LayoutInfo::SameAsRef); - }; + } } context @@ -382,7 +383,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { // Sometimes outputs that are manually handled don't have any write registered. if let Some(FuseOp::Assign(op)) = block.writes.get_mut(&output.tensor_relative.id) { op.out.add_layout_info(LayoutInfo::IsRef); - }; + } } else if let ReferenceSelection::Concrete { shape: ref_shape, strides: ref_strides, @@ -394,7 +395,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { block.writes.get_mut(&output.tensor_relative.id).unwrap() { op.out.add_layout_info(LayoutInfo::SameAsRef); - }; + } } } diff --git a/crates/burn-cubecl-fusion/src/shared/trace/plan.rs b/crates/burn-cubecl-fusion/src/shared/trace/plan.rs index 641a8fcbb8..7be27ee401 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/plan.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/plan.rs @@ -114,7 +114,7 @@ impl LaunchPlan<'_, R> { let mut rank = 0; let mut blocks = Vec::with_capacity(fuse_blocks.len()); - for b in fuse_blocks.iter() { + for b in fuse_blocks { rank = usize::max(b.shape_ref.len(), rank); let block = BlockPlan { reference: ReferenceSelection::Searching, diff --git a/crates/burn-cubecl-fusion/src/shared/trace/runner.rs b/crates/burn-cubecl-fusion/src/shared/trace/runner.rs index b6b4d0fd4b..132b83da6b 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/runner.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/runner.rs @@ -54,7 +54,7 @@ pub trait Vectorization { ref_elem, max, axis, - ) + ); } } @@ -126,18 +126,18 @@ fn vectorization_default<'a, R: Runtime>( let shape_axis = original.shape[original.shape.len() - 1]; for s in R::line_size_elem(ref_elem) { - if !multi_reads { - // The last dimension should be a multiple of the vector size or broadcated. - if reshape_axis % s as usize == 0 && s <= max { - return Vect::Aligned(s); - } - } else { + if multi_reads { // Since the original tensor must share the same vectorization factor as the // reshaped tensor, they must have compatible shapes when both are access // independently. if reshape_axis % s as usize == 0 && shape_axis % s as usize == 0 && s <= max { return Vect::Aligned(s); } + } else { + // The last dimension should be a multiple of the vector size or broadcated. + if reshape_axis % s as usize == 0 && s <= max { + return Vect::Aligned(s); + } } } @@ -220,7 +220,7 @@ fn multi_reads_vectorization_update( original: TensorId, vect: Vect, ) { - if let Some(ori_vect) = vectorizations.get(&original).cloned() { + if let Some(ori_vect) = vectorizations.get(&original).copied() { match ori_vect { Vect::Broadcasted => { // keep the original as is. 
@@ -230,11 +230,11 @@ fn multi_reads_vectorization_update( vectorizations.insert(original, Vect::Aligned(1)); } Vect::Aligned(new) => { - let val = if new != ori { 1 } else { new }; + let val = if new == ori { new } else { 1 }; vectorizations.insert(original, Vect::Aligned(val)); } }, - }; + } } else { vectorizations.insert(original, vect); } diff --git a/crates/burn-cubecl-fusion/src/shared/trace/vectorization.rs b/crates/burn-cubecl-fusion/src/shared/trace/vectorization.rs index 226bbc0c1c..195ff0d1c6 100644 --- a/crates/burn-cubecl-fusion/src/shared/trace/vectorization.rs +++ b/crates/burn-cubecl-fusion/src/shared/trace/vectorization.rs @@ -37,8 +37,8 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { ) { let has_multiple_read = |tensor: &TensorId| { let mut read_count = 0; - for block in plan.blocks.iter() { - read_count += block.reads.get(tensor).map(|a| a.len()).unwrap_or(0); + for block in &plan.blocks { + read_count += block.reads.get(tensor).map_or(0, std::vec::Vec::len); } read_count > 1 }; @@ -69,7 +69,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { let mut ref_elem = (Elem::UInt(UIntKind::U64), 8); - for r in plan.global_inputs.iter() { + for r in &plan.global_inputs { let elem: Elem = r.dtype.into(); let elem_size = elem.size(); @@ -77,7 +77,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { ref_elem = (elem, elem_size); } } - for r in plan.global_outputs.iter() { + for r in &plan.global_outputs { let elem: Elem = r.dtype.into(); let elem_size = elem.size(); diff --git a/crates/burn-cubecl-fusion/src/tune.rs b/crates/burn-cubecl-fusion/src/tune.rs index aa5794ec45..9a66c368f6 100644 --- a/crates/burn-cubecl-fusion/src/tune.rs +++ b/crates/burn-cubecl-fusion/src/tune.rs @@ -30,7 +30,7 @@ pub struct TuneInput { /// The wrapper removes the context lifetime. /// /// For it to be correct, the context must not be used after the invocation of the -/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are +/// [`cubecl::tune::LocalTuner::execute`] function. This is the case, since autotune functions are /// tuned using a cloned version of the input; therefore, a fork of the context will be used to find /// the best kernel to use, which can be async. enum UnsafeTuneContext { @@ -71,7 +71,7 @@ impl UnsafeTuneContext { // It is necessary for the lifetime. 
#[allow(clippy::unnecessary_cast)] - Self::Original(ptr as *mut Context<'static, _>) + Self::Original(ptr.cast::>()) } fn get(&self) -> TuneContext<'static, R> { diff --git a/crates/burn-cubecl/Cargo.toml b/crates/burn-cubecl/Cargo.toml index 53be661c82..9c9056d11a 100644 --- a/crates/burn-cubecl/Cargo.toml +++ b/crates/burn-cubecl/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl" version.workspace = true +[lints] +workspace = true + [features] autotune = ["burn-cubecl-fusion?/autotune"] autotune-checks = ["burn-cubecl-fusion?/autotune-checks"] diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index 993af7d3ef..5ba3c3890a 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda" version.workspace = true +[lints] +workspace = true + [features] autotune = ["burn-cubecl/autotune"] autotune-checks = ["burn-cubecl/autotune-checks"] diff --git a/crates/burn-dataset/Cargo.toml b/crates/burn-dataset/Cargo.toml index c9fc1b43fd..2c328ee18f 100644 --- a/crates/burn-dataset/Cargo.toml +++ b/crates/burn-dataset/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-dataset" documentation = "https://docs.rs/burn-dataset" version.workspace = true +[lints] +workspace = true + [features] default = ["sqlite-bundled"] doc = ["default"] diff --git a/crates/burn-dataset/src/dataset/fake.rs b/crates/burn-dataset/src/dataset/fake.rs index c27f8cf0d3..4e13737ed0 100644 --- a/crates/burn-dataset/src/dataset/fake.rs +++ b/crates/burn-dataset/src/dataset/fake.rs @@ -8,6 +8,7 @@ pub struct FakeDataset { impl> FakeDataset { /// Create a new fake dataset with the given size. + #[must_use] pub fn new(size: usize) -> Self { let mut items = Vec::with_capacity(size); for _ in 0..size { diff --git a/crates/burn-dataset/src/dataset/in_memory.rs b/crates/burn-dataset/src/dataset/in_memory.rs index 13854cd0d1..b567832b50 100644 --- a/crates/burn-dataset/src/dataset/in_memory.rs +++ b/crates/burn-dataset/src/dataset/in_memory.rs @@ -15,6 +15,7 @@ pub struct InMemDataset { impl InMemDataset { /// Creates a new in memory dataset from the given items. + #[must_use] pub fn new(items: Vec) -> Self { InMemDataset { items } } diff --git a/crates/burn-dataset/src/dataset/sqlite.rs b/crates/burn-dataset/src/dataset/sqlite.rs index ad3230d619..69f61f1210 100644 --- a/crates/burn-dataset/src/dataset/sqlite.rs +++ b/crates/burn-dataset/src/dataset/sqlite.rs @@ -62,11 +62,11 @@ impl From<&'static str> for SqliteDatasetError { } } -/// This struct represents a dataset where all items are stored in an SQLite database. -/// Each instance of this struct corresponds to a specific table within the SQLite database, +/// This struct represents a dataset where all items are stored in an `SQLite` database. +/// Each instance of this struct corresponds to a specific table within the `SQLite` database, /// and allows for interaction with the data stored in the table in a structured and typed manner. /// -/// The SQLite database must contain a table with the same name as the `split` field. This table should +/// The `SQLite` database must contain a table with the same name as the `split` field. This table should /// have a primary key column named `row_id`, which is used to index the rows in the table. 
The `row_id` /// should start at 1, while the corresponding dataset `index` should start at 0, i.e., `row_id` = `index` + 1. /// @@ -82,7 +82,7 @@ impl From<&'static str> for SqliteDatasetError { /// /// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table /// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields -/// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using +/// that cannot be mapped to a `SQLite` type, such as nested structs, vectors, etc. The serialization is done using /// [MessagePack](https://msgpack.org/). /// /// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate @@ -100,7 +100,7 @@ pub struct SqliteDataset { } impl SqliteDataset { - /// Initializes a `SqliteDataset` from a SQLite database file and a split name. + /// Initializes a `SqliteDataset` from a `SQLite` database file and a split name. pub fn from_db_file>(db_file: P, split: &str) -> Result { // Create a connection pool let conn_pool = create_conn_pool(&db_file, false)?; @@ -130,7 +130,7 @@ impl SqliteDataset { }) } - /// Returns true if table has two columns: row_id (integer) and item (blob). + /// Returns true if table has two columns: `row_id` (integer) and item (blob). /// /// This is used to determine if the table is row serialized or not. fn check_if_row_serialized( @@ -170,23 +170,25 @@ impl SqliteDataset { columns.push(column?); } - if columns.len() != 2 { - Ok(false) - } else { + if columns.len() == 2 { // Check if the column names and types match the expected values Ok(columns[0].name == "row_id" && columns[0].ty == "integer" && columns[1].name == "item" && columns[1].ty == "blob") + } else { + Ok(false) } } /// Get the database file name. + #[must_use] pub fn db_file(&self) -> PathBuf { self.db_file.clone() } /// Get the split name. + #[must_use] pub fn split(&self) -> &str { self.split.as_str() } @@ -278,7 +280,7 @@ fn create_conn_pool>( Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool) } -/// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets. +/// The `SqliteDatasetStorage` struct represents a `SQLite` database for storing datasets. /// It consists of an optional name, a database file path, and a base directory for storage. #[derive(Clone, Debug)] pub struct SqliteDatasetStorage { @@ -293,6 +295,7 @@ impl SqliteDatasetStorage { /// # Arguments /// /// * `name` - A string slice that holds the name of the dataset. + #[must_use] pub fn from_name(name: &str) -> Self { SqliteDatasetStorage { name: Some(name.to_string()), @@ -329,6 +332,7 @@ impl SqliteDatasetStorage { /// # Returns /// /// * A boolean value indicating whether the file exists or not. + #[must_use] pub fn exists(&self) -> bool { self.db_file().exists() } @@ -338,13 +342,13 @@ impl SqliteDatasetStorage { /// # Returns /// /// * A `PathBuf` instance representing the file path. 
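Editor's note: for context on the `SqliteDataset` API touched above, a hedged usage sketch. The item type, file name, and split are hypothetical; the schema rules are the ones described in the doc comments.

```rust
use burn_dataset::{Dataset, SqliteDataset};
use serde::{Deserialize, Serialize};

// Hypothetical item type: either each field maps to a column of the `train`
// table, or the whole struct is stored MessagePack-encoded in a single
// `item` BLOB column, as explained in the doc comment above.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TextItem {
    text: String,
    label: i64,
}

fn main() {
    let dataset: SqliteDataset<TextItem> =
        SqliteDataset::from_db_file("data.db", "train").expect("readable dataset");

    // `row_id` starts at 1 while the dataset index starts at 0, so index 0
    // reads the row with `row_id` = 1.
    if let Some(first) = dataset.get(0) {
        println!("first item: {first:?} out of {}", dataset.len());
    }
}
```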
+ #[must_use] pub fn db_file(&self) -> PathBuf { - match &self.db_file { - Some(db_file) => db_file.clone(), - None => { - let name = sanitize(self.name.as_ref().expect("Name is not set")); - Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db")) - } + if let Some(db_file) = &self.db_file { + db_file.clone() + } else { + let name = sanitize(self.name.as_ref().expect("Name is not set")); + Self::base_dir(self.base_dir.clone()).join(format!("{name}.db")) } } @@ -357,18 +361,18 @@ impl SqliteDatasetStorage { /// # Returns /// /// * A `PathBuf` instance representing the base directory. + #[must_use] pub fn base_dir(base_dir: Option) -> PathBuf { - match base_dir { - Some(base_dir) => base_dir, - None => { - let home_dir = dirs::home_dir().expect("Could not get home directory"); + if let Some(base_dir) = base_dir { + base_dir + } else { + let home_dir = dirs::home_dir().expect("Could not get home directory"); - home_dir.join(".cache").join("burn-dataset") - } + home_dir.join(".cache").join("burn-dataset") } } - /// Provides a writer instance for the SQLite dataset. + /// Provides a writer instance for the `SQLite` dataset. /// /// # Arguments /// @@ -384,7 +388,7 @@ impl SqliteDatasetStorage { SqliteDatasetWriter::new(self.db_file(), overwrite) } - /// Provides a reader instance for the SQLite dataset. + /// Provides a reader instance for the `SQLite` dataset. /// /// # Arguments /// @@ -397,15 +401,13 @@ impl SqliteDatasetStorage { where I: Clone + Send + Sync + Serialize + DeserializeOwned, { - if !self.exists() { - panic!("The database file does not exist"); - } + assert!(self.exists(), "The database file does not exist"); SqliteDataset::from_db_file(self.db_file(), split) } } -/// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets. +/// This `SqliteDatasetWriter` struct is a `SQLite` database writer dedicated to storing datasets. /// It retains the current writer's state and its database connection. /// /// Being thread-safe, this writer can be concurrently used across multiple threads. @@ -544,7 +546,7 @@ where pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?; // Insert the serialized item into the database - let insert_statement = format!("insert into {split} (item) values (?)", split = split); + let insert_statement = format!("insert into {split} (item) values (?)"); conn.execute(insert_statement.as_str(), [serialized_item])?; // Get the primary key of the last inserted row and convert to index (row_id-1) @@ -613,7 +615,7 @@ where /// Runs a pragma update and ignores the `ExecuteReturnedResults` error. /// -/// Sometimes ExecuteReturnedResults is returned when running a pragma update. This is not an error +/// Sometimes `ExecuteReturnedResults` is returned when running a pragma update. This is not an error /// and can be ignored. This function runs the pragma update and ignores the error if it is /// `ExecuteReturnedResults`. 
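Editor's note: the helper documented here wraps `rusqlite` pragma updates. A rough sketch of the shape such a wrapper can take; this is an assumption, not code copied from the crate, and the exact `pragma_update` arguments depend on the `rusqlite` version.

```rust
use rusqlite::{Connection, Error};

// Run a pragma update but treat `ExecuteReturnedResults` as success, since
// some pragmas report their new value as a result row.
fn pragma_update_ignoring_results(
    conn: &Connection,
    pragma_name: &str,
    pragma_value: &str,
) -> Result<(), Error> {
    match conn.pragma_update(None, pragma_name, pragma_value) {
        Ok(()) | Err(Error::ExecuteReturnedResults) => Ok(()),
        Err(err) => Err(err),
    }
}
```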
fn pragma_update_with_error_handling( @@ -684,7 +686,7 @@ mod tests { let mut match_count = 0; for (_index, result) in indices.iter().zip(results.iter()) { if let Some(_val) = result { - match_count += 1 + match_count += 1; } } @@ -822,7 +824,7 @@ mod tests { (0..record_count).into_par_iter().for_each(|index: i64| { let thread_id: std::thread::ThreadId = std::thread::current().id(); let sample = Complex { - column_str: format!("test_{:?}_{}", thread_id, index), + column_str: format!("test_{thread_id:?}_{index}"), column_bytes: vec![index as u8, 2, 3], column_int: index, column_bool: true, diff --git a/crates/burn-dataset/src/source/huggingface/downloader.rs b/crates/burn-dataset/src/source/huggingface/downloader.rs index 48e41c7a04..b8cf237fcb 100644 --- a/crates/burn-dataset/src/source/huggingface/downloader.rs +++ b/crates/burn-dataset/src/source/huggingface/downloader.rs @@ -70,6 +70,7 @@ pub struct HuggingfaceDatasetLoader { impl HuggingfaceDatasetLoader { /// Create a huggingface dataset loader. + #[must_use] pub fn new(name: &str) -> Self { Self { name: name.to_string(), @@ -87,6 +88,7 @@ impl HuggingfaceDatasetLoader { /// The subset name must be one of the subsets listed in the dataset page. /// /// If no subset names are listed, then do not use this method. + #[must_use] pub fn with_subset(mut self, subset: &str) -> Self { self.subset = Some(subset.to_string()); self @@ -95,6 +97,7 @@ impl HuggingfaceDatasetLoader { /// Specify a base directory to store the dataset. /// /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`. + #[must_use] pub fn with_base_dir(mut self, base_dir: &str) -> Self { self.base_dir = Some(base_dir.into()); self @@ -103,6 +106,7 @@ impl HuggingfaceDatasetLoader { /// Specify a huggingface token to download datasets behind authentication. /// /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens) + #[must_use] pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self { self.huggingface_token = Some(huggingface_token.to_string()); self @@ -111,6 +115,7 @@ impl HuggingfaceDatasetLoader { /// Specify a huggingface cache directory to store the downloaded datasets. /// /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`. + #[must_use] pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self { self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string()); self @@ -119,8 +124,9 @@ impl HuggingfaceDatasetLoader { /// Specify a relative path to a subset of a dataset. This is used in some datasets for the /// manual steps of dataset download process. /// - /// Unless you've encountered a ManualDownloadError + /// Unless you've encountered a `ManualDownloadError` /// when loading your dataset you probably don't have to worry about this setting. + #[must_use] pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self { self.huggingface_data_dir = Some(huggingface_data_dir.to_string()); self @@ -129,6 +135,7 @@ impl HuggingfaceDatasetLoader { /// Specify whether or not to trust remote code. /// /// If not specified, trust remote code is set to true. 
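
Every `with_*` method on this loader now carries `#[must_use]`: because the builders consume `self` and return the updated value, ignoring the return value silently drops the configuration. A self-contained illustration with a stand-in type (not the real `HuggingfaceDatasetLoader`):

    struct Loader {
        name: String,
        subset: Option<String>,
    }

    impl Loader {
        #[must_use]
        fn new(name: &str) -> Self {
            Self { name: name.to_string(), subset: None }
        }

        // Consumes `self` and returns the updated builder; #[must_use] makes
        // the compiler warn if the result is accidentally discarded.
        #[must_use]
        fn with_subset(mut self, subset: &str) -> Self {
            self.subset = Some(subset.to_string());
            self
        }
    }

    fn main() {
        let loader = Loader::new("cifar10").with_subset("plain_text");
        println!("name = {}, subset = {:?}", loader.name, loader.subset);
    }
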
+ #[must_use] pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self { self.trust_remote_code = trust_remote_code; self @@ -162,7 +169,7 @@ impl HuggingfaceDatasetLoader { let db_file_name = if let Some(subset) = self.subset.clone() { format!("{}-{}.db", name, sanitize(subset.as_str())) } else { - format!("{}.db", name) + format!("{name}.db") }; let db_file = base_dir.join(db_file_name); @@ -263,7 +270,7 @@ fn check_python_version_is_3(python: &str) -> bool { /// get python3 name `python` `python3` or `py` fn get_python_name() -> Result<&'static str, ImporterError> { let python_name_list = ["python3", "python", "py"]; - for python_name in python_name_list.iter() { + for python_name in &python_name_list { if check_python_version_is_3(python_name) { return Ok(python_name); } @@ -298,7 +305,7 @@ fn install_python_deps(base_dir: &Path) -> Result { let mut handle = command.spawn().unwrap(); handle.wait().map_err(|err| { - ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)) + ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")) })?; // Check if the venv environment can be used successfully." if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { @@ -321,9 +328,9 @@ fn install_python_deps(base_dir: &Path) -> Result { // Spawn the pip install process and wait for it to complete. let mut handle = command.spawn().unwrap(); - handle.wait().map_err(|err| { - ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err)) - })?; + handle + .wait() + .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?; Ok(venv_python_path) } diff --git a/crates/burn-dataset/src/transform/composed.rs b/crates/burn-dataset/src/transform/composed.rs index 8f26bd5976..5e3314107a 100644 --- a/crates/burn-dataset/src/transform/composed.rs +++ b/crates/burn-dataset/src/transform/composed.rs @@ -13,7 +13,7 @@ where { fn get(&self, index: usize) -> Option { let mut current_index = 0; - for dataset in self.datasets.iter() { + for dataset in &self.datasets { if index < dataset.len() + current_index { return dataset.get(index - current_index); } @@ -23,7 +23,7 @@ where } fn len(&self) -> usize { let mut total = 0; - for dataset in self.datasets.iter() { + for dataset in &self.datasets { total += dataset.len(); } total diff --git a/crates/burn-dataset/src/transform/partial.rs b/crates/burn-dataset/src/transform/partial.rs index c8bd53f08b..d60377c25b 100644 --- a/crates/burn-dataset/src/transform/partial.rs +++ b/crates/burn-dataset/src/transform/partial.rs @@ -72,9 +72,10 @@ mod tests { let mut items_original_2 = HashSet::new(); let mut items_partial = HashSet::new(); dataset_original.iter().enumerate().for_each(|(i, item)| { - match i >= 10 { - true => items_original_2.insert(item), - false => items_original_1.insert(item), + if i >= 10 { + items_original_2.insert(item) + } else { + items_original_1.insert(item) }; }); @@ -99,9 +100,10 @@ mod tests { let mut items_partial = HashSet::new(); dataset_original.iter().enumerate().for_each(|(i, item)| { - match !(10..20).contains(&i) { - true => items_original_2.insert(item), - false => items_original_1.insert(item), + if !(10..20).contains(&i) { + items_original_2.insert(item) + } else { + items_original_1.insert(item) }; }); diff --git a/crates/burn-dataset/src/transform/sampler.rs b/crates/burn-dataset/src/transform/sampler.rs index fb154894dd..76728c83ea 100644 --- a/crates/burn-dataset/src/transform/sampler.rs +++ 
b/crates/burn-dataset/src/transform/sampler.rs @@ -1,6 +1,6 @@ use crate::Dataset; use rand::{Rng, SeedableRng, distr::Uniform, rngs::StdRng, seq::IteratorRandom}; -use std::{marker::PhantomData, ops::DerefMut, sync::Mutex}; +use std::{marker::PhantomData, sync::Mutex}; /// Sample items from a dataset. /// @@ -62,7 +62,7 @@ where fn index(&self) -> usize { let mut state = self.state.lock().unwrap(); - match state.deref_mut() { + match &mut *state { SamplerState::WithReplacement(rng) => { rng.sample(Uniform::new(0, self.dataset.len()).unwrap()) } diff --git a/crates/burn-dataset/src/transform/window.rs b/crates/burn-dataset/src/transform/window.rs index e60009f7ab..b295cfb930 100644 --- a/crates/burn-dataset/src/transform/window.rs +++ b/crates/burn-dataset/src/transform/window.rs @@ -174,7 +174,7 @@ mod tests { let dataset = InMemDataset::new(items.clone()); let expected = items .windows(3) - .map(|x| x.to_vec()) + .map(<[i32]>::to_vec) .collect::>>(); let result = dataset.windows(3).collect::>>(); @@ -188,7 +188,7 @@ mod tests { let dataset = InMemDataset::new(items.clone()); let expected = items .windows(3) - .map(|x| x.to_vec()) + .map(<[i32]>::to_vec) .collect::>>(); let result = WindowsDataset::new(dataset, 3) diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index f6a8201600..1d6876c8dc 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -85,7 +85,7 @@ pub struct SegmentationMask { /// Object detection bounding box annotation. #[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] pub struct BoundingBox { - /// Coordinates in [x_min, y_min, width, height] format. + /// Coordinates in [`x_min`, `y_min`, width, height] format. pub coords: [f32; 4], /// Box class label. @@ -266,8 +266,7 @@ fn parse_coco_bbox_annotations( if bbox_coords.len() < BBOX_MIN_NUM_VALUES { return Err(ImageLoaderError::ParsingError(format!( - "not enough bounding box coordinates in annotation for image {}", - image_id + "not enough bounding box coordinates in annotation for image {image_id}" ))); } @@ -334,8 +333,8 @@ fn parse_coco_images>( .unwrap_or_else(|| AnnotationRaw::BoundingBoxes(Vec::new())); images.push(ImageDatasetItemRaw { - annotation, image_path, + annotation, }); } } @@ -534,7 +533,7 @@ impl ImageFolderDataset { items.push(ImageDatasetItemRaw::new( image_path, AnnotationRaw::Label(label), - )) + )); } // Sort class names @@ -639,7 +638,7 @@ impl ImageFolderDataset { /// # Arguments /// /// * `annotations_json` - Path to the JSON file containing annotations in COCO format (for - /// example instances_train2017.json). + /// example `instances_train2017.json`). /// /// * `images_path` - Path containing the images matching the annotations JSON. 
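
The error-message hunks above, both here and in the Hugging Face downloader, switch to inlined format arguments (`{image_id}` instead of `{}` plus a trailing argument), available since Rust 1.58. The two forms produce identical strings:

    fn main() {
        let image_id = 42;

        let positional = format!(
            "not enough bounding box coordinates in annotation for image {}",
            image_id
        );
        let inlined =
            format!("not enough bounding box coordinates in annotation for image {image_id}");

        assert_eq!(positional, inlined);
    }
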
/// @@ -650,9 +649,9 @@ impl ImageFolderDataset { images_path: I, ) -> Result { let file = fs::File::open(annotations_json) - .map_err(|e| ImageLoaderError::IOError(format!("Failed to open annotations: {}", e)))?; + .map_err(|e| ImageLoaderError::IOError(format!("Failed to open annotations: {e}")))?; let json: Value = serde_json::from_reader(file).map_err(|e| { - ImageLoaderError::ParsingError(format!("Failed to parse annotations: {}", e)) + ImageLoaderError::ParsingError(format!("Failed to parse annotations: {e}")) })?; let classes = parse_coco_classes(&json)?; @@ -683,7 +682,10 @@ impl ImageFolderDataset { let dataset = InMemDataset::new(items); // Class names to index map - let classes = classes.iter().map(|c| c.as_ref()).collect::>(); + let classes = classes + .iter() + .map(std::convert::AsRef::as_ref) + .collect::>(); let classes_map: HashMap<_, _> = classes .into_iter() .enumerate() @@ -701,12 +703,12 @@ impl ImageFolderDataset { /// Check if extension is supported. fn check_extension>(extension: &S) -> Result { let extension = extension.as_ref(); - if !SUPPORTED_FILES.contains(&extension) { + if SUPPORTED_FILES.contains(&extension) { + Ok(extension.to_string()) + } else { Err(ImageLoaderError::InvalidFileExtensionError( extension.to_string(), )) - } else { - Ok(extension.to_string()) } } } @@ -826,7 +828,7 @@ mod tests { #[test] #[should_panic] pub fn pixel_depth_try_into_u8_invalid() { - let _: u8 = PixelDepth::U16(u8::MAX as u16 + 1).try_into().unwrap(); + let _: u8 = PixelDepth::U16(u16::from(u8::MAX) + 1).try_into().unwrap(); } #[test] @@ -839,7 +841,7 @@ mod tests { #[test] #[should_panic] pub fn pixel_depth_try_into_u16_invalid() { - let _: u16 = PixelDepth::F32(u16::MAX as f32).try_into().unwrap(); + let _: u16 = PixelDepth::F32(f32::from(u16::MAX)).try_into().unwrap(); } #[test] diff --git a/crates/burn-dataset/src/vision/mnist.rs b/crates/burn-dataset/src/vision/mnist.rs index a502c21499..21cfadfcc1 100644 --- a/crates/burn-dataset/src/vision/mnist.rs +++ b/crates/burn-dataset/src/vision/mnist.rs @@ -51,7 +51,7 @@ impl Mapper for BytesToImage { for (i, pixel) in item.image_bytes.iter().enumerate() { let x = i % WIDTH; let y = i / HEIGHT; - image_array[y][x] = *pixel as f32; + image_array[y][x] = f32::from(*pixel); } MnistItem { @@ -83,11 +83,13 @@ impl Dataset for MnistDataset { impl MnistDataset { /// Creates a new train dataset. + #[must_use] pub fn train() -> Self { Self::new("train") } /// Creates a new test dataset. 
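
The pixel-depth tests and the MNIST byte-to-float mapper above replace `as` casts with `From` conversions wherever the widening is lossless; unlike `as`, `From` cannot silently truncate or wrap. A short self-contained sketch of the conversions involved:

    fn main() {
        let pixel: u8 = 200;

        // u8 -> u16 and u8 -> f32 are lossless, so `From` is preferred over `as`.
        let widened: u16 = u16::from(pixel);
        let as_float: f32 = f32::from(pixel);
        assert_eq!(widened, 200);
        assert_eq!(as_float, 200.0);

        // u16 -> f32 is also lossless: every u16 value is exactly representable.
        assert_eq!(f32::from(u16::MAX), 65535.0);
    }
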
+ #[must_use] pub fn test() -> Self { Self::new("test") } @@ -139,8 +141,8 @@ impl MnistDataset { MnistDataset::download_file(TEST_IMAGES, &split_dir); MnistDataset::download_file(TEST_LABELS, &split_dir); } - _ => panic!("Invalid split specified {}", split), - }; + _ => panic!("Invalid split specified {split}"), + } split_dir } @@ -191,7 +193,7 @@ impl MnistDataset { buf_images .chunks(WIDTH * HEIGHT) - .map(|chunk| chunk.to_vec()) + .map(<[u8]>::to_vec) .collect() } diff --git a/crates/burn-derive/Cargo.toml b/crates/burn-derive/Cargo.toml index 2aa48bbb08..8b67e204e5 100644 --- a/crates/burn-derive/Cargo.toml +++ b/crates/burn-derive/Cargo.toml @@ -10,6 +10,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-derive" version.workspace = true +[lints] +workspace = true + [lib] proc-macro = true diff --git a/crates/burn-derive/src/config/analyzer_struct.rs b/crates/burn-derive/src/config/analyzer_struct.rs index 42bfd98b0a..806b5aa1b7 100644 --- a/crates/burn-derive/src/config/analyzer_struct.rs +++ b/crates/burn-derive/src/config/analyzer_struct.rs @@ -38,15 +38,15 @@ impl ConfigStructAnalyzer { fn names(&self) -> Vec { let mut names = Vec::new(); - for field in self.fields_required.iter() { + for field in &self.fields_required { names.push(field.clone()); } - for field in self.fields_option.iter() { + for field in &self.fields_option { names.push(field.clone()); } - for (field, _) in self.fields_default.iter() { + for (field, _) in &self.fields_default { names.push(field.clone()); } @@ -56,7 +56,7 @@ impl ConfigStructAnalyzer { fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec { let mut name_types = Vec::new(); - for field in names.iter() { + for field in names { let name = field.ident(); let ty = &field.field.ty; @@ -152,7 +152,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { let mut body = quote! {}; let mut names = Vec::new(); - for field in self.fields_required.iter() { + for field in &self.fields_required { let name = field.ident(); let ty = &field.field.ty; @@ -164,7 +164,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { }); } - for field in self.fields_option.iter() { + for field in &self.fields_option { let name = field.ident(); body.extend(quote! { @@ -172,7 +172,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { }); } - for (field, attribute) in self.fields_default.iter() { + for (field, attribute) in &self.fields_default { let name = field.ident(); let value = &attribute.value; match value { @@ -188,7 +188,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { #name: #value, }); } - }; + } } let body = quote! { @@ -205,7 +205,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { fn gen_builder_fns(&self) -> TokenStream { let mut body = quote! {}; - for (field, _) in self.fields_default.iter() { + for (field, _) in &self.fields_default { let name = field.ident(); let doc = field.doc().unwrap_or_else(|| { quote! 
{ @@ -224,7 +224,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { }); } - for field in self.fields_option.iter() { + for field in &self.fields_option { let name = field.ident(); let ty = &field.field.ty; let fn_name = Ident::new(&format!("with_{name}"), name.span()); diff --git a/crates/burn-derive/src/module/codegen.rs b/crates/burn-derive/src/module/codegen.rs index 911df3390c..441baabdaf 100644 --- a/crates/burn-derive/src/module/codegen.rs +++ b/crates/burn-derive/src/module/codegen.rs @@ -44,7 +44,7 @@ pub(crate) fn generate_module_standard( let clone_fn = codegen.gen_clone(); let record = codegen.record_codegen(); - let record_name = Ident::new(format!("{}Record", name).as_str(), name.span()); + let record_name = Ident::new(format!("{name}Record").as_str(), name.span()); let record_type = record.gen_record_type(&record_name, &generics.module); let (generics_module, generics_ty_module, generics_where_module) = @@ -121,10 +121,10 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream { let mut generics_module = ast.generics.clone(); let mut generics_module_autodiff = ast.generics.clone(); - for param in backend.params.into_iter() { + for param in backend.params { generics_module.params.push(param); } - for param in backend_ad.params.into_iter() { + for param in backend_ad.params { generics_module_autodiff.params.push(param); } let (generics_module, _, _) = generics_module.split_for_impl(); diff --git a/crates/burn-derive/src/module/codegen_enum.rs b/crates/burn-derive/src/module/codegen_enum.rs index a82876e4ef..5fe2533192 100644 --- a/crates/burn-derive/src/module/codegen_enum.rs +++ b/crates/burn-derive/src/module/codegen_enum.rs @@ -178,12 +178,12 @@ impl EnumModuleCodegen { { let mut match_arms = quote! {}; - for variant in self.variants.iter() { + for variant in &self.variants { let name = &variant.ident; let arm_pattern = quote! {Self::#name(module)}; let arm_code = func(name.clone()); - match_arms.extend(quote! {#arm_pattern => #arm_code,}) + match_arms.extend(quote! {#arm_pattern => #arm_code,}); } quote! { diff --git a/crates/burn-derive/src/module/codegen_struct.rs b/crates/burn-derive/src/module/codegen_struct.rs index ed146cadfd..7206313eca 100644 --- a/crates/burn-derive/src/module/codegen_struct.rs +++ b/crates/burn-derive/src/module/codegen_struct.rs @@ -206,7 +206,7 @@ impl StructModuleCodegen { let mut body = quote! {}; let mut names = Vec::new(); - for field in self.fields.iter() { + for field in &self.fields { let name = field.ident(); names.push(name.clone()); @@ -222,7 +222,7 @@ impl StructModuleCodegen { { let mut body = quote! 
{}; - for field in self.fields.iter() { + for field in &self.fields { body.extend(func(field.ident())); } diff --git a/crates/burn-derive/src/module/display.rs b/crates/burn-derive/src/module/display.rs index 4c799f7159..bc8a8d0ddc 100644 --- a/crates/burn-derive/src/module/display.rs +++ b/crates/burn-derive/src/module/display.rs @@ -56,7 +56,7 @@ pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { } syn::Fields::Unnamed(unnamed_fields) => { let field_names = (0..unnamed_fields.unnamed.len()).map(|i| { - syn::Ident::new(&format!("_{}", i), proc_macro2::Span::call_site()) + syn::Ident::new(&format!("_{i}"), proc_macro2::Span::call_site()) }); let field_prints = field_names.clone().map(|field_name| { diff --git a/crates/burn-derive/src/module/record_enum.rs b/crates/burn-derive/src/module/record_enum.rs index dda1911735..40c891bde8 100644 --- a/crates/burn-derive/src/module/record_enum.rs +++ b/crates/burn-derive/src/module/record_enum.rs @@ -17,7 +17,7 @@ impl ModuleRecordCodegen for EnumModuleRecordCodegen { let vis = &self.vis; // Capture the Record enum variant types - for variant in self.variants.iter() { + for variant in &self.variants { let ty = &variant.ty; let name = &variant.ident; diff --git a/crates/burn-derive/src/module/record_struct.rs b/crates/burn-derive/src/module/record_struct.rs index 0f4af32059..4833bde99a 100644 --- a/crates/burn-derive/src/module/record_struct.rs +++ b/crates/burn-derive/src/module/record_struct.rs @@ -16,7 +16,7 @@ impl ModuleRecordCodegen for StructModuleRecordCodegen { let mut fields = quote! {}; let vis = &self.vis; - for field in self.fields.iter() { + for field in &self.fields { let ty = &field.field.ty; let name = &field.field.ident; diff --git a/crates/burn-derive/src/record/codegen.rs b/crates/burn-derive/src/record/codegen.rs index c9f251eda2..e6f316176f 100644 --- a/crates/burn-derive/src/record/codegen.rs +++ b/crates/burn-derive/src/record/codegen.rs @@ -29,7 +29,7 @@ impl RecordCodegen { let param: syn::Generics = parse_quote! { }; let mut generics = self.ty.generics.clone(); - for param in param.params.into_iter() { + for param in param.params { generics.params.push(param); } @@ -87,7 +87,7 @@ impl RecordCodegen { fn record_item_generics(&self) -> Generics { let param: syn::Generics = parse_quote! { }; let mut generics = self.ty.generics.clone(); - for param in param.params.into_iter() { + for param in param.params { generics.params.push(param); } @@ -122,7 +122,7 @@ struct RecordType { impl RecordType { fn from_ast(ast: &syn::DeriveInput) -> Self { let name = ast.ident.clone(); - let item = Ident::new(format!("{}Item", name).as_str(), name.span()); + let item = Ident::new(format!("{name}Item").as_str(), name.span()); let has_backend = ast .generics .type_params() diff --git a/crates/burn-derive/src/record/item/codegen.rs b/crates/burn-derive/src/record/item/codegen.rs index ef2b5bc200..cc91f1ec02 100644 --- a/crates/burn-derive/src/record/item/codegen.rs +++ b/crates/burn-derive/src/record/item/codegen.rs @@ -12,7 +12,7 @@ pub(crate) trait RecordItemCodegen { generics: &Generics, has_backend: bool, ) -> TokenStream; - /// Generate the into_item function. + /// Generate the `into_item` function. fn gen_into_item(&self, item_name: &Ident) -> TokenStream; /// Generate the from item function. 
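
Across these burn-derive hunks, `for x in collection.iter()` and `collection.iter_mut()` loops become `for x in &collection` and `&mut collection`; the borrowed forms desugar to exactly the same iterators and are the idiomatic spelling. A self-contained comparison:

    fn main() {
        let fields = vec!["weight", "bias"];
        let mut counts = vec![0u32; fields.len()];

        // Shared iteration: the two loops are equivalent.
        for field in fields.iter() {
            println!("explicit iter: {field}");
        }
        for field in &fields {
            println!("borrowed:      {field}");
        }

        // Mutable iteration: likewise equivalent.
        for count in counts.iter_mut() {
            *count += 1;
        }
        for count in &mut counts {
            *count += 1;
        }
        assert_eq!(counts, vec![2, 2]);
    }
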
fn gen_from_item(&self) -> TokenStream; diff --git a/crates/burn-derive/src/record/item/codegen_enum.rs b/crates/burn-derive/src/record/item/codegen_enum.rs index 213bbcfbc7..125ff41e9a 100644 --- a/crates/burn-derive/src/record/item/codegen_enum.rs +++ b/crates/burn-derive/src/record/item/codegen_enum.rs @@ -30,7 +30,7 @@ impl RecordItemCodegen for EnumRecordItemCodegen { let vis = &self.vis; // Capture the Record enum variant types and names to transpose them in RecordItem - for variant in self.variants.iter() { + for variant in &self.variants { let ty = &variant.ty; let name = &variant.ident; @@ -47,13 +47,13 @@ impl RecordItemCodegen for EnumRecordItemCodegen { let bound = bounds.to_string(); // Capture the type's generics and bounds in where clauses - let (generics, generics_where) = if !has_backend { - let mut generics = generics.clone(); - let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; - generics.params.push(syn::GenericParam::Type(param)); + let (generics, generics_where) = if has_backend { let (generics, _, generics_where) = generics.split_for_impl(); (quote! { #generics }, quote! { #generics_where }) } else { + let mut generics = generics.clone(); + let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; + generics.params.push(syn::GenericParam::Type(param)); let (generics, _, generics_where) = generics.split_for_impl(); (quote! { #generics }, quote! { #generics_where }) }; @@ -74,7 +74,7 @@ impl RecordItemCodegen for EnumRecordItemCodegen { fn gen_into_item(&self, _item_name: &Ident) -> TokenStream { let mut into_item_match_arms = quote! {}; - for variant in self.variants.iter() { + for variant in &self.variants { let name = &variant.ident; into_item_match_arms.extend(quote! { @@ -94,7 +94,7 @@ impl RecordItemCodegen for EnumRecordItemCodegen { fn gen_from_item(&self) -> TokenStream { let mut from_item_match_arms = quote! {}; - for variant in self.variants.iter() { + for variant in &self.variants { let name = &variant.ident; from_item_match_arms.extend(quote! { diff --git a/crates/burn-derive/src/record/item/codegen_struct.rs b/crates/burn-derive/src/record/item/codegen_struct.rs index 3d93a88449..ef3a9b1533 100644 --- a/crates/burn-derive/src/record/item/codegen_struct.rs +++ b/crates/burn-derive/src/record/item/codegen_struct.rs @@ -31,7 +31,7 @@ impl RecordItemCodegen for StructRecordItemCodegen { let mut bounds = quote! {}; let vis = &self.vis; - for field in self.fields.iter() { + for field in &self.fields { let ty = &field.field.ty; let name = &field.field.ident; @@ -46,13 +46,13 @@ impl RecordItemCodegen for StructRecordItemCodegen { } let bound = bounds.to_string(); - let (generics, generics_where) = if !has_backend { - let mut generics = generics.clone(); - let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; - generics.params.push(syn::GenericParam::Type(param)); + let (generics, generics_where) = if has_backend { let (generics, _, generics_where) = generics.split_for_impl(); (quote! { #generics }, quote! { #generics_where }) } else { + let mut generics = generics.clone(); + let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; + generics.params.push(syn::GenericParam::Type(param)); let (generics, _, generics_where) = generics.split_for_impl(); (quote! { #generics }, quote! { #generics_where }) }; @@ -72,7 +72,7 @@ impl RecordItemCodegen for StructRecordItemCodegen { fn gen_into_item(&self, item_name: &Ident) -> TokenStream { let mut body_into_item = quote! 
{}; - for field in self.fields.iter() { + for field in &self.fields { let name = &field.field.ident; body_into_item.extend(quote! { @@ -92,7 +92,7 @@ impl RecordItemCodegen for StructRecordItemCodegen { fn gen_from_item(&self) -> TokenStream { let mut body_from_item = quote! {}; - for field in self.fields.iter() { + for field in &self.fields { let name = &field.field.ident; body_from_item.extend(quote! { diff --git a/crates/burn-derive/src/shared/enum_variant.rs b/crates/burn-derive/src/shared/enum_variant.rs index 5c059bf0f1..7d34e992aa 100644 --- a/crates/burn-derive/src/shared/enum_variant.rs +++ b/crates/burn-derive/src/shared/enum_variant.rs @@ -59,11 +59,12 @@ pub(crate) fn parse_variants(ast: &syn::DeriveInput) -> Vec { let mut variants = Vec::new(); if let syn::Data::Enum(enum_data) = &ast.data { - for variant in enum_data.variants.iter() { - if variant.fields.len() != 1 { - // No support for unit variants or variants with multiple fields - panic!("Enums are only supported for one field type") - } + for variant in &enum_data.variants { + // No support for unit variants or variants with multiple fields + assert!( + (variant.fields.len() == 1), + "Enums are only supported for one field type" + ); let field = variant.fields.iter().next().unwrap(); diff --git a/crates/burn-derive/src/shared/field.rs b/crates/burn-derive/src/shared/field.rs index 274cabcc15..210eea84b7 100644 --- a/crates/burn-derive/src/shared/field.rs +++ b/crates/burn-derive/src/shared/field.rs @@ -94,12 +94,12 @@ pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec { match &ast.data { syn::Data::Struct(struct_data) => { - for field in struct_data.fields.iter() { + for field in &struct_data.fields { fields.push(field.clone()); } } syn::Data::Enum(_) => panic!("Only struct can be derived"), syn::Data::Union(_) => panic!("Only struct can be derived"), - }; + } fields } diff --git a/crates/burn-derive/src/shared/generics.rs b/crates/burn-derive/src/shared/generics.rs index abe4a15832..4cf8cd13cc 100644 --- a/crates/burn-derive/src/shared/generics.rs +++ b/crates/burn-derive/src/shared/generics.rs @@ -43,7 +43,7 @@ impl GenericsHelper { - The default backend trait is `burn::tensor::backend::Backend`. - Any backend trait is supported."; - for param in self.generics.params.iter() { + for param in &self.generics.params { if let syn::GenericParam::Type(ty) = ¶m { if ty.ident == "B" { let bound = ty diff --git a/crates/burn-fusion/Cargo.toml b/crates/burn-fusion/Cargo.toml index 3630c2e750..3b8c071a50 100644 --- a/crates/burn-fusion/Cargo.toml +++ b/crates/burn-fusion/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-fusion" documentation = "https://docs.rs/burn-fusion" version.workspace = true +[lints] +workspace = true + [features] default = ["std"] std = ["serde/std"] diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index a75e349862..37a7d309d7 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -90,7 +90,7 @@ pub struct OptimizationProperties { /// the speed and efficiency of the computational graph. It doesn't mean that all registered /// operations should be fused, but that another way of executing them is more efficient. /// -/// Also, it is important to return (OptimizationStatus::Closed) when no more registered operation can +/// Also, it is important to return (`OptimizationStatus::Closed`) when no more registered operation can /// improve the performance. 
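
An explorer drives several such builders at once and stops only when every one of them reports `Closed`. The sketch below restates that bookkeeping with local stand-ins for the trait and the status enum; the stopping condition shown is an assumption for illustration, not a copy of the crate's exact logic.

    #[derive(Clone, Copy)]
    enum OptimizationStatus {
        Open,
        Closed,
    }

    trait OptimizationBuilder {
        fn status(&self) -> OptimizationStatus;
    }

    // A builder stub with a fixed status, enough to exercise the loop below.
    struct Fixed(OptimizationStatus);

    impl OptimizationBuilder for Fixed {
        fn status(&self) -> OptimizationStatus {
            self.0
        }
    }

    // Keep exploring only while at least one builder is still open.
    fn still_optimizing(builders: &[Box<dyn OptimizationBuilder>]) -> bool {
        let mut num_stopped = 0;
        for builder in builders {
            if let OptimizationStatus::Closed = builder.status() {
                num_stopped += 1;
            }
        }
        num_stopped < builders.len()
    }

    fn main() {
        let builders: Vec<Box<dyn OptimizationBuilder>> = vec![
            Box::new(Fixed(OptimizationStatus::Open)),
            Box::new(Fixed(OptimizationStatus::Closed)),
        ];
        assert!(still_optimizing(&builders));
    }
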
pub trait OptimizationBuilder: Send { /// Register a new [tensor operation](OperationIr). diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index c999a78362..29058139e3 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -43,7 +43,7 @@ where { self.server .lock() - .register(streams, repr, Box::new(operation)) + .register(streams, repr, Box::new(operation)); } fn drain(&self) { diff --git a/crates/burn-fusion/src/fusion.rs b/crates/burn-fusion/src/fusion.rs index 4d488e5e30..6fff924b2b 100644 --- a/crates/burn-fusion/src/fusion.rs +++ b/crates/burn-fusion/src/fusion.rs @@ -3,7 +3,7 @@ use burn_tensor::backend::{DeviceId, DeviceOps}; use crate::{Client, FusionDevice, FusionRuntime, client::FusionClient}; -use std::{any::Any, collections::HashMap, ops::DerefMut}; +use std::{any::Any, collections::HashMap}; /// Type alias for [representation backend handle](burn_ir::BackendIr::Handle). pub type Handle = ::Handle; @@ -34,19 +34,18 @@ impl FusionClientLocator { Self::register_inner::(client_id, client, &mut clients); } - match clients.deref_mut() { - Some(clients) => match clients.get(&client_id) { - Some(client) => { + match &mut *clients { + Some(clients) => { + if let Some(client) = clients.get(&client_id) { let client: &Client = client.downcast_ref().unwrap(); client.clone() - } - None => { + } else { let client = Client::::new(device.clone()); let any = Box::new(client.clone()); clients.insert(client_id, any); client } - }, + } _ => unreachable!(), } } @@ -61,9 +60,10 @@ impl FusionClientLocator { } if let Some(clients) = clients { - if clients.contains_key(&key) { - panic!("Client already created for device {:?}", key); - } + assert!( + !clients.contains_key(&key), + "Client already created for device {key:?}" + ); clients.insert(key, Box::new(client)); } diff --git a/crates/burn-fusion/src/lib.rs b/crates/burn-fusion/src/lib.rs index c43f3be98a..2547623dd7 100644 --- a/crates/burn-fusion/src/lib.rs +++ b/crates/burn-fusion/src/lib.rs @@ -20,7 +20,7 @@ mod ops; mod server; mod tensor; -pub(crate) use server::*; +pub(crate) use server::FusionServer; pub use backend::*; pub use fusion::*; diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index 2d9bd6b2f5..71c7a36293 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -17,13 +17,13 @@ pub(crate) fn check_binary_op(desc: BinaryOpIr) -> Result Result<(), BinaryOpError> { - if lhs.dtype != rhs.dtype { + if lhs.dtype == rhs.dtype { + Ok(()) + } else { Err(BinaryOpError::DTypeMismatch { lhs: lhs.dtype, rhs: rhs.dtype, }) - } else { - Ok(()) } } diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 695c6985ce..5e4caa7122 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -310,14 +310,17 @@ impl BoolTensorOps for Fusion { let streams = tensors.iter().map(|t| t.stream).collect::>(); shape[dim] = 0; - for tensor in tensors.iter() { + for tensor in &tensors { shape[dim] += tensor.shape[dim]; } let out = client.tensor_uninitialized(shape, B::BoolElem::dtype()); let desc = CatOpIr { - tensors: tensors.into_iter().map(|t| t.into_ir()).collect(), + tensors: tensors + .into_iter() + .map(super::super::tensor::FusionTensor::into_ir) + .collect(), dim, out: out.to_ir_out(), }; diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index ebe9bad7e3..42e8ccdd97 100644 
--- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -7,7 +7,13 @@ use crate::{ stream::{StreamId, execution::Operation}, unary_float_ops, }; -use burn_ir::*; +use burn_ir::{ + BaseOperationIr, BinaryOpIr, CatOpIr, ClampOpIr, ExpandOpIr, FlipOpIr, FloatOperationIr, + GatherOpIr, HandleContainer, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, NumericOperationIr, + OperationIr, PermuteOpIr, RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, RepeatDimOpIr, + ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, SliceAssignOpIr, SliceOpIr, + SwapDimsOpIr, TensorIr, UnaryOpIr, +}; use burn_tensor::{ Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor, binary_ops_shape}, @@ -1632,14 +1638,17 @@ impl FloatTensorOps for Fusion { let streams = tensors.iter().map(|tensor| tensor.stream).collect(); let mut shape: Vec = tensor_first.shape.clone(); shape[dim] = 0; - for tensor in tensors.iter() { + for tensor in &tensors { shape[dim] += tensor.shape[dim]; } let out = client.tensor_uninitialized(shape, dtype); let desc = CatOpIr { - tensors: tensors.into_iter().map(|t| t.into_ir()).collect(), + tensors: tensors + .into_iter() + .map(super::super::tensor::FusionTensor::into_ir) + .collect(), dim, out: out.to_ir_out(), }; diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index f7640b746e..76fc99f4e7 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -5,7 +5,13 @@ use crate::{ stream::{StreamId, execution::Operation}, unary_int_ops, }; -use burn_ir::*; +use burn_ir::{ + BaseOperationIr, BinaryOpIr, CatOpIr, ClampOpIr, ExpandOpIr, FlipOpIr, GatherOpIr, + HandleContainer, InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr, + NumericOperationIr, OperationIr, PermuteOpIr, RandomOpIr, ReduceDimOpIr, + ReduceDimWithIndicesOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, + SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, +}; use burn_tensor::{ Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps, binary_ops_shape}, @@ -525,14 +531,17 @@ impl IntTensorOps for Fusion { let streams = tensors.iter().map(|tensor| tensor.stream).collect(); let mut shape: Vec = tensor_first.shape.clone(); shape[dim] = 0; - for tensor in tensors.iter() { + for tensor in &tensors { shape[dim] += tensor.shape[dim]; } let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = CatOpIr { - tensors: tensors.into_iter().map(|t| t.into_ir()).collect(), + tensors: tensors + .into_iter() + .map(super::super::tensor::FusionTensor::into_ir) + .collect(), dim, out: out.to_ir_out(), }; diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index c536fcfab8..1a9014ba44 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -1,5 +1,13 @@ use crate::{Fusion, FusionBackend, client::FusionClient, stream::execution::Operation}; -use burn_ir::*; +use burn_ir::{ + AdaptiveAvgPool1dBackwardOpIr, AdaptiveAvgPool1dOpIr, AdaptiveAvgPool2dBackwardOpIr, + AdaptiveAvgPool2dOpIr, AvgPool1dBackwardOpIr, AvgPool1dOpIr, AvgPool2dBackwardOpIr, + AvgPool2dOpIr, Conv1dOpIr, Conv2dOpIr, Conv3dOpIr, ConvTranspose1dOpIr, ConvTranspose2dOpIr, + ConvTranspose3dOpIr, DeformConv2dBackwardOpIr, DeformConv2dOpIr, HandleContainer, + 
InterpolateBackwardOpIr, InterpolateOpIr, MaxPool1dOpIr, MaxPool1dWithIndicesBackwardOpIr, + MaxPool1dWithIndicesOpIr, MaxPool2dOpIr, MaxPool2dWithIndicesBackwardOpIr, + MaxPool2dWithIndicesOpIr, ModuleOperationIr, OperationIr, +}; use burn_tensor::{ Element, ops::{ @@ -69,7 +77,7 @@ impl ModuleOps> for Fusion { let description = Conv1dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -133,7 +141,7 @@ impl ModuleOps> for Fusion { let desc = Conv2dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -209,8 +217,8 @@ impl ModuleOps> for Fusion { x: x.into_ir(), offset: offset.into_ir(), weight: weight.into_ir(), - mask: mask.map(|mask| mask.into_ir()), - bias: bias.map(|bias| bias.into_ir()), + mask: mask.map(super::super::tensor::FusionTensor::into_ir), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -315,15 +323,19 @@ impl ModuleOps> for Fusion { x: x.into_ir(), offset: offset.into_ir(), weight: weight.into_ir(), - mask: mask.map(|mask| mask.into_ir()), - bias: bias.map(|bias| bias.into_ir()), + mask: mask.map(super::super::tensor::FusionTensor::into_ir), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out_grad: output_grad.into_ir(), input_grad: input_grad.to_ir_out(), offset_grad: offset_grad.to_ir_out(), weight_grad: weight_grad.to_ir_out(), - mask_grad: mask_grad.as_ref().map(|mask_grad| mask_grad.to_ir_out()), - bias_grad: bias_grad.as_ref().map(|bias_grad| bias_grad.to_ir_out()), + mask_grad: mask_grad + .as_ref() + .map(super::super::tensor::FusionTensor::to_ir_out), + bias_grad: bias_grad + .as_ref() + .map(super::super::tensor::FusionTensor::to_ir_out), }; let streams = match (stream_4, stream_5) { @@ -403,7 +415,7 @@ impl ModuleOps> for Fusion { let desc = Conv3dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -462,7 +474,7 @@ impl ModuleOps> for Fusion { let desc = ConvTranspose1dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -529,7 +541,7 @@ impl ModuleOps> for Fusion { let desc = ConvTranspose2dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -610,7 +622,7 @@ impl ModuleOps> for Fusion { let desc = ConvTranspose3dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::FusionTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 62f163fc47..c2e001e1ed 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -85,7 +85,10 @@ impl QTensorOps for Fusion { tensor: tensor.into_ir(), qparams: QuantizationParametersIr { scale: qparams.scale.clone().into_ir(), - offset: 
qparams.offset.clone().map(|x| x.into_ir()), + offset: qparams + .offset + .clone() + .map(super::super::tensor::FusionTensor::into_ir), }, scheme: *scheme, out: out.to_ir_out(), diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 1e912bc782..fc59b0540d 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -29,11 +29,11 @@ where operation: Box>, ) { self.streams - .register(streams, repr, operation, &mut self.handles) + .register(streams, repr, operation, &mut self.handles); } pub fn drain_stream(&mut self, id: StreamId) { - self.streams.drain(&mut self.handles, id) + self.streams.drain(&mut self.handles, id); } pub fn create_empty_handle(&mut self) -> Arc { diff --git a/crates/burn-fusion/src/stream/base.rs b/crates/burn-fusion/src/stream/base.rs index 13a9bf1e1e..360f13b8de 100644 --- a/crates/burn-fusion/src/stream/base.rs +++ b/crates/burn-fusion/src/stream/base.rs @@ -30,6 +30,7 @@ impl Default for OperationQueue { impl OperationQueue { /// Create a new empty queue. + #[must_use] pub fn new() -> Self { Self { global: Vec::new(), @@ -56,11 +57,13 @@ impl OperationQueue { } /// The size of the queue. + #[must_use] pub fn len(&self) -> usize { self.global.len() } /// If the queue is empty. + #[must_use] pub fn is_empty(&self) -> bool { self.len() == 0 } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index e32cb2a170..dbd9442afb 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -1,4 +1,18 @@ -use burn_ir::*; +use burn_ir::{ + AdaptiveAvgPool1dBackwardOpIr, AdaptiveAvgPool1dOpIr, AdaptiveAvgPool2dBackwardOpIr, + AdaptiveAvgPool2dOpIr, AvgPool1dBackwardOpIr, AvgPool1dOpIr, AvgPool2dBackwardOpIr, + AvgPool2dOpIr, BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ClampOpIr, Conv1dOpIr, + Conv2dOpIr, Conv3dOpIr, ConvTranspose1dOpIr, ConvTranspose2dOpIr, ConvTranspose3dOpIr, + CustomOpIr, DeformConv2dBackwardOpIr, DeformConv2dOpIr, DequantizeOpIr, EmbeddingBackwardOpIr, + EmbeddingOpIr, ExpandOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, HandleContainer, + InitOperationIr, IntOperationIr, InterpolateBackwardOpIr, InterpolateOpIr, MaskFillOpIr, + MaskWhereOpIr, MaxPool1dOpIr, MaxPool1dWithIndicesBackwardOpIr, MaxPool1dWithIndicesOpIr, + MaxPool2dOpIr, MaxPool2dWithIndicesBackwardOpIr, MaxPool2dWithIndicesOpIr, ModuleOperationIr, + NumericOperationIr, OperationIr, PermuteOpIr, QuantizationParametersIr, QuantizeOpIr, + RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, + SelectAssignOpIr, SelectOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorId, TensorIr, + UnaryOpIr, +}; use burn_tensor::{DType, Element, ElementConversion}; use half::{bf16, f16}; use hashbrown::HashMap; @@ -121,6 +135,7 @@ impl ContextOwned { } /// Fork the context again. + #[must_use] pub fn fork(&self) -> ContextOwned { ContextOwned { tensors: self.tensors.clone(), @@ -142,6 +157,7 @@ impl ContextOwned { impl Context<'_, H> { /// Fork the context into an [owned context](ContextOwned). + #[must_use] pub fn fork(&self) -> ContextOwned { ContextOwned { tensors: self.tensors.clone(), @@ -1094,7 +1110,7 @@ impl RelativeOps for TensorIr { // We can create relative shapes by mapping each shape found to an ID, which is a `usize`. 
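
Relative shapes are built by interning dimension values: the first time a size is seen it receives the next free ID, and later occurrences (as in the loop that continues just below) reuse that ID. A simplified, self-contained sketch of the idea; the real converter also tracks tensor IDs and scalar values:

    use std::collections::HashMap;

    // Map every distinct dimension value to a stable relative id, so streams
    // with the same structure but different concrete sizes compare equal.
    fn to_relative(shape: &[usize], global2relative: &mut HashMap<usize, usize>) -> Vec<usize> {
        let mut relative = Vec::with_capacity(shape.len());
        for dim in shape {
            let next_id = global2relative.len();
            let id = *global2relative.entry(*dim).or_insert(next_id);
            relative.push(id);
        }
        relative
    }

    fn main() {
        let mut mapping = HashMap::new();
        assert_eq!(to_relative(&[32, 128, 128], &mut mapping), vec![0, 1, 1]);
        // A second shape reuses the ids already assigned to 32 and 128.
        assert_eq!(to_relative(&[32, 128, 64], &mut mapping), vec![0, 1, 2]);
    }
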
let mut relative_shape = Vec::with_capacity(self.shape.len()); - for dim in self.shape.iter() { + for dim in &self.shape { if let Some(dim_id) = converter.shapes_global2relative.get(dim) { // We already saw that dim value before, so we retrieve its ID. relative_shape.push(*dim_id); diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index b30420fa12..dc40da1908 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -31,10 +31,10 @@ impl OperationQueue { ) { match &mut store.get_mut_unchecked(id).strategy { ExecutionStrategy::Optimization(optimization) => { - self.execute_optimization(handles, optimization) + self.execute_optimization(handles, optimization); } ExecutionStrategy::Operations => self.execute_operations(handles), - }; + } } /// Execute the optimization (fused operations) and remove all the corresponding @@ -80,7 +80,7 @@ impl OperationQueue { self.relative.clear(); self.converter.clear(); - for node in self.global.iter() { + for node in &self.global { let relative = node.to_relative(&mut self.converter); self.relative.push(relative); } diff --git a/crates/burn-fusion/src/stream/execution/explorer.rs b/crates/burn-fusion/src/stream/execution/explorer.rs index 44d11a45d0..8bc370138b 100644 --- a/crates/burn-fusion/src/stream/execution/explorer.rs +++ b/crates/burn-fusion/src/stream/execution/explorer.rs @@ -69,7 +69,7 @@ impl Explorer { /// Reset the state of the explorer to the provided list of operations. pub(crate) fn reset(&mut self, operations: &[OperationIr]) { - for operation in self.builders.iter_mut() { + for operation in &mut self.builders { operation.reset(); } self.num_explored = 0; @@ -86,7 +86,7 @@ impl Explorer { let index = operations.len() - 1 - i; let relative = &operations[index]; - for builder in self.builders.iter_mut() { + for builder in &mut self.builders { builder.register(relative); } self.num_explored += 1; @@ -103,9 +103,9 @@ impl Explorer { fn still_optimizing(optimizations: &[Box>]) -> bool { let mut num_stopped = 0; - for optimization in optimizations.iter() { + for optimization in optimizations { if let OptimizationStatus::Closed = optimization.status() { - num_stopped += 1 + num_stopped += 1; } } diff --git a/crates/burn-fusion/src/stream/execution/policy.rs b/crates/burn-fusion/src/stream/execution/policy.rs index 7e74fd0b88..44ecf91fe1 100644 --- a/crates/burn-fusion/src/stream/execution/policy.rs +++ b/crates/burn-fusion/src/stream/execution/policy.rs @@ -75,11 +75,10 @@ impl Policy { operations: &[OperationIr], mode: ExecutionMode, ) -> Action { - if self.num_operations < operations.len() { - panic!( - "Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed." - ); - } + assert!( + (self.num_operations >= operations.len()), + "Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed." 
+ ); if let Some((id, _length)) = self.found { return Action::Execute(id); @@ -124,7 +123,7 @@ impl Policy { fn check_candidates(&mut self, store: &ExecutionPlanStore) { let mut candidates_to_remove = Vec::new(); - for candidate in self.candidates.iter() { + for candidate in &self.candidates { match candidate.state { ValidatorState::Found { size } => { let item = store.get_unchecked(candidate.id); @@ -149,7 +148,7 @@ impl Policy { candidates_to_remove.push(candidate.id); } ValidatorState::Validating => {} - }; + } } let mut updated_candidates = Vec::new(); @@ -162,8 +161,8 @@ impl Policy { } fn check_availables(&mut self) { - for available in self.availables.iter() { - for trigger in available.triggers.iter() { + for available in &self.availables { + for trigger in &available.triggers { match trigger { TriggerValidator::OnOperations { matching, @@ -222,12 +221,12 @@ impl Policy { return Action::Defer; } - for available in self.availables.iter() { + for available in &self.availables { if available.size == operations.len() { return Action::Defer; } - for trigger in available.triggers.iter() { + for trigger in &available.triggers { if let TriggerValidator::OnOperations { matching, progress: _, @@ -244,13 +243,13 @@ impl Policy { } fn action_sync(&self, operations: &[OperationIr], store: &ExecutionPlanStore) -> Action { - for available in self.availables.iter() { + for available in &self.availables { if available.size == operations.len() { return Action::Execute(available.id); } } - for candidate in self.candidates.iter() { + for candidate in &self.candidates { let item = store.get_unchecked(candidate.id); if item.operations.len() == operations.len() { diff --git a/crates/burn-fusion/src/stream/execution/processor.rs b/crates/burn-fusion/src/stream/execution/processor.rs index c4bec71e83..2e82c85176 100644 --- a/crates/burn-fusion/src/stream/execution/processor.rs +++ b/crates/burn-fusion/src/stream/execution/processor.rs @@ -71,7 +71,7 @@ impl Processor { segment.execute(id, store); self.reset(store, segment.operations()); } - }; + } } } @@ -132,7 +132,7 @@ impl Processor { self.policy.reset(); // Reset the policy state with the remaining operations - for operation in operations.iter() { + for operation in operations { self.policy.update(store, operation); } } diff --git a/crates/burn-fusion/src/stream/execution/tests.rs b/crates/burn-fusion/src/stream/execution/tests.rs index fcbddb8ab1..ba1248d5cd 100644 --- a/crates/burn-fusion/src/stream/execution/tests.rs +++ b/crates/burn-fusion/src/stream/execution/tests.rs @@ -452,11 +452,10 @@ impl OptimizationBuilder for TestOptimizationBuilder { if self.actual.len() < self.expected_operations.len() { let operations = &self.expected_operations[0..self.actual.len()]; - return match self.actual == operations { - // Still optimizing. - true => OptimizationStatus::Open, - // Never gonna be possible on that stream. 
- false => OptimizationStatus::Closed, + return if self.actual == operations { + OptimizationStatus::Open + } else { + OptimizationStatus::Closed }; } @@ -512,7 +511,7 @@ impl StreamSegment for TestSegment<'_> { self.operations.drain(0..optimization.size); } ExecutionStrategy::Operations => self.operations.clear(), - }; + } self.executed.push(id); } diff --git a/crates/burn-fusion/src/stream/execution/validator.rs b/crates/burn-fusion/src/stream/execution/validator.rs index 88f9d30dcd..09be507804 100644 --- a/crates/burn-fusion/src/stream/execution/validator.rs +++ b/crates/burn-fusion/src/stream/execution/validator.rs @@ -54,15 +54,14 @@ impl OperationsValidator { ValidatorState::Found { size: _ } => return, ValidatorState::Invalidated => return, ValidatorState::Validating => {} - }; + } let item = store.get(self.id); - let operation_candidate = match item.get(added_position) { - Some(val) => val, - None => { - self.state = ValidatorState::Invalidated; - return; - } + let operation_candidate = if let Some(val) = item.get(added_position) { + val + } else { + self.state = ValidatorState::Invalidated; + return; }; if operation_candidate != added { diff --git a/crates/burn-fusion/src/stream/multi.rs b/crates/burn-fusion/src/stream/multi.rs index 92e877cc74..205a5f4277 100644 --- a/crates/burn-fusion/src/stream/multi.rs +++ b/crates/burn-fusion/src/stream/multi.rs @@ -34,15 +34,14 @@ impl MultiStream { ) { let id = self.resolve_streams(streams, handles, &repr); - let stream = match self.streams.get_mut(&id) { - Some(stream) => stream, - None => { - let stream = Stream::new(self.device.clone()); - self.streams.insert(id, stream); - self.streams - .get_mut(&id) - .expect("Just added, so should be included in the hashmap.") - } + let stream = if let Some(stream) = self.streams.get_mut(&id) { + stream + } else { + let stream = Stream::new(self.device.clone()); + self.streams.insert(id, stream); + self.streams + .get_mut(&id) + .expect("Just added, so should be included in the hashmap.") }; stream.queue.add(repr, operation); @@ -159,7 +158,7 @@ impl StreamSegment for Segment<'_, R> { } fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore) { - self.queue.execute(id, self.handles, store) + self.queue.execute(id, self.handles, store); } } diff --git a/crates/burn-fusion/src/stream/store/base.rs b/crates/burn-fusion/src/stream/store/base.rs index 8b9534b12b..56db9ad1b9 100644 --- a/crates/burn-fusion/src/stream/store/base.rs +++ b/crates/burn-fusion/src/stream/store/base.rs @@ -53,9 +53,10 @@ impl ExecutionPlanStore { } pub fn add(&mut self, exploration: ExecutionPlan) -> ExecutionPlanId { - if exploration.operations.is_empty() { - panic!("Can't add an empty optimization."); - } + assert!( + !exploration.operations.is_empty(), + "Can't add an empty optimization." 
+ ); let id = self.plans.len(); log::trace!( diff --git a/crates/burn-fusion/src/stream/store/index.rs b/crates/burn-fusion/src/stream/store/index.rs index 3dce4f9010..00f5704308 100644 --- a/crates/burn-fusion/src/stream/store/index.rs +++ b/crates/burn-fusion/src/stream/store/index.rs @@ -49,7 +49,7 @@ impl ExecutionPlanIndex { match query { InsertQuery::NewPlan { operations, id } => { if let Some(operation) = operations.first() { - self.insert_new_operation(operation, id) + self.insert_new_operation(operation, id); } } } @@ -81,26 +81,24 @@ impl ExecutionPlanIndex { /// Update the index for an execution plan starting with operation `ops` fn insert_new_operation(&mut self, ops: &OperationIr, new_id: ExecutionPlanId) { let key = self.operation_key(ops); - let values = match self.mapping.get_mut(&key) { - Some(val) => val, - None => { - // New starter ops. - let index = self.starters.len(); - self.starters.push(vec![new_id]); - self.mapping.insert(key, vec![(ops.clone(), index)]); - - return; - } + let values = if let Some(val) = self.mapping.get_mut(&key) { + val + } else { + // New starter ops. + let index = self.starters.len(); + self.starters.push(vec![new_id]); + self.mapping.insert(key, vec![(ops.clone(), index)]); + + return; }; - let (_, index) = match values.iter_mut().find(|value| &value.0 == ops) { - Some(val) => val, - None => { - // New with hash collision. - let index = self.starters.len(); - self.starters.push(vec![new_id]); - values.push((ops.clone(), index)); - return; - } + let (_, index) = if let Some(val) = values.iter_mut().find(|value| &value.0 == ops) { + val + } else { + // New with hash collision. + let index = self.starters.len(); + self.starters.push(vec![new_id]); + values.push((ops.clone(), index)); + return; }; // New optimization for an existing starter. diff --git a/crates/burn-import/Cargo.toml b/crates/burn-import/Cargo.toml index c860a2082d..b037d8a37e 100644 --- a/crates/burn-import/Cargo.toml +++ b/crates/burn-import/Cargo.toml @@ -14,6 +14,9 @@ version.workspace = true default-run = "onnx2burn" +[lints] +workspace = true + [features] default = ["onnx", "pytorch", "safetensors"] onnx = ["burn-ndarray"] diff --git a/crates/burn-ir/Cargo.toml b/crates/burn-ir/Cargo.toml index 8bd6a2b385..8c5a2f0923 100644 --- a/crates/burn-ir/Cargo.toml +++ b/crates/burn-ir/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-ir" documentation = "https://docs.rs/burn-ir" version.workspace = true +[lints] +workspace = true + [features] default = ["std"] std = ["burn-tensor/std"] diff --git a/crates/burn-ir/src/handle.rs b/crates/burn-ir/src/handle.rs index 56be1ac95f..1f7b6a4e23 100644 --- a/crates/burn-ir/src/handle.rs +++ b/crates/burn-ir/src/handle.rs @@ -22,10 +22,11 @@ pub struct HandleContainer { impl HandleContainer { /// Fork the container, useful for autotune. 
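
Forking clones every handle into a fresh map so autotune gets an isolated copy it can consume without touching the original, as the hunk just below shows. A stripped-down, self-contained version of the same copy with stand-in types (the real container holds backend handles rather than strings):

    use std::collections::HashMap;

    #[derive(Clone)]
    struct Handle(String);

    struct Container {
        handles: HashMap<u64, Handle>,
    }

    impl Container {
        // Clone each handle into a new map; the fork can then be mutated or
        // consumed independently of the original.
        fn fork(&self) -> Self {
            let mut handles = HashMap::with_capacity(self.handles.len());
            for (id, handle) in &self.handles {
                handles.insert(*id, handle.clone());
            }
            Self { handles }
        }
    }

    fn main() {
        let mut original = Container { handles: HashMap::new() };
        original.handles.insert(0, Handle("buffer-0".into()));

        let fork = original.fork();
        assert_eq!(fork.handles.get(&0).map(|h| h.0.as_str()), Some("buffer-0"));
    }
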
+ #[must_use] pub fn fork(&self) -> Self { let mut handles = HashMap::with_capacity(self.handles.len()); - for (id, handle) in self.handles.iter() { + for (id, handle) in &self.handles { handles.insert(*id, handle.clone()); } @@ -57,7 +58,8 @@ pub enum Handle { } impl HandleContainer { - /// Create a new HandleContainer + /// Create a new `HandleContainer` + #[must_use] pub fn new() -> Self { Self { handles: HashMap::new(), @@ -87,7 +89,7 @@ impl HandleContainer { let (id, handle) = self .handles .remove_entry(id) - .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", id)); + .unwrap_or_else(|| panic!("Should have handle for tensor {id:?}")); match handle { Handle::Existing(handle) => match status { @@ -97,11 +99,10 @@ impl HandleContainer { } TensorStatus::ReadWrite => handle, TensorStatus::NotInit => panic!( - "Cannot get uninitialized tensor {:?}. Tensor exist but with wrong status", - id + "Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status" ), }, - Handle::NotInit => panic!("Cannot get uninitialized handle {:?}.", id), + Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."), } } diff --git a/crates/burn-ir/src/operation.rs b/crates/burn-ir/src/operation.rs index 8859496cdb..786a4f596a 100644 --- a/crates/burn-ir/src/operation.rs +++ b/crates/burn-ir/src/operation.rs @@ -29,6 +29,7 @@ pub struct CustomOpIr { impl CustomOpIr { /// Create a new custom operation intermediate representation. + #[must_use] pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self { Self { id: id.to_owned(), @@ -38,6 +39,7 @@ impl CustomOpIr { } /// Cast the intermediate representation, and get the in and output tensors. + #[must_use] pub fn as_fixed( &self, ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) { @@ -94,7 +96,7 @@ pub enum FloatOperationIr { Log1p(UnaryOpIr), /// Operation corresponding to [erf](burn_tensor::ops::FloatTensorOps::float_erf). Erf(UnaryOpIr), - /// Operation corresponding to [powf_scalar](burn_tensor::ops::FloatTensorOps::float_powf_scalar). + /// Operation corresponding to [`powf_scalar`](burn_tensor::ops::FloatTensorOps::float_powf_scalar). PowfScalar(ScalarOpIr), /// Operation corresponding to [sqrt](burn_tensor::ops::FloatTensorOps::float_sqrt). Sqrt(UnaryOpIr), @@ -110,7 +112,7 @@ pub enum FloatOperationIr { Floor(UnaryOpIr), /// Operation corresponding to [ceil](burn_tensor::ops::FloatTensorOps::float_ceil). Ceil(UnaryOpIr), - /// Operation corresponding to [into_int](burn_tensor::ops::FloatTensorOps::float_into_int). + /// Operation corresponding to [`into_int`](burn_tensor::ops::FloatTensorOps::float_into_int). IntoInt(UnaryOpIr), /// Operation corresponding to [matmul](burn_tensor::ops::FloatTensorOps::float_matmul). Matmul(BinaryOpIr), @@ -129,7 +131,7 @@ pub enum FloatOperationIr { pub enum ModuleOperationIr { /// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding). Embedding(EmbeddingOpIr), - /// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward). + /// Operation corresponding to [`embedding_backward`](burn_tensor::ops::ModuleOps::embedding_backward). EmbeddingBackward(EmbeddingBackwardOpIr), /// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d). Conv1d(Conv1dOpIr), @@ -137,9 +139,9 @@ pub enum ModuleOperationIr { Conv2d(Conv2dOpIr), /// Operation corresponding to [conv3d](burn_tensor::ops::ModuleOps::conv3d). 
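
The doc hunks below wrap snake_case names in backticks inside intra-doc links so rustdoc renders them as code; the link targets themselves are unchanged (likely prompted by clippy's `doc_markdown` lint). A tiny illustration with made-up item names:

    pub struct Reducer;

    impl Reducer {
        /// Reduces over every element.
        ///
        /// See also [`max_abs_dim`](Reducer::max_abs_dim); the backticks only
        /// affect rendering, the link still resolves to the method.
        pub fn max_abs(&self) {}

        /// Reduces over a single dimension.
        pub fn max_abs_dim(&self) {}
    }

    fn main() {
        Reducer.max_abs();
        Reducer.max_abs_dim();
    }
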
Conv3d(Conv3dOpIr), - /// Operation corresponding to [deform_conv2d](burn_tensor::ops::ModuleOps::deform_conv2d) + /// Operation corresponding to [`deform_conv2d`](burn_tensor::ops::ModuleOps::deform_conv2d) DeformableConv2d(Box), - /// Operation corresponding to [deform_conv2d_backward](burn_tensor::ops::ModuleOps::deform_conv2d_backward) + /// Operation corresponding to [`deform_conv2d_backward`](burn_tensor::ops::ModuleOps::deform_conv2d_backward) DeformableConv2dBackward(Box), /// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d). ConvTranspose1d(ConvTranspose1dOpIr), @@ -211,9 +213,9 @@ pub enum BaseOperationIr { /// Operation corresponding to: /// - /// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims). - /// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims). - /// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims). + /// Float => [`swap_dims`](burn_tensor::ops::FloatTensorOps::float_swap_dims). + /// Int => [`swap_dims`](burn_tensor::ops::IntTensorOps::int_swap_dims). + /// Bool => [`swap_dims`](burn_tensor::ops::BoolTensorOps::bool_swap_dims). SwapDims(SwapDimsOpIr), /// Operation corresponding to: @@ -500,13 +502,13 @@ pub enum NumericOperationIr { MinDim(ReduceDimOpIr), /// Operation corresponding to: /// - /// Float => [max_abs](burn_tensor::ops::FloatTensorOps::float_max_abs). - /// Int => [max_abs](burn_tensor::ops::IntTensorOps::int_max_abs). + /// Float => [`max_abs`](burn_tensor::ops::FloatTensorOps::float_max_abs). + /// Int => [`max_abs`](burn_tensor::ops::IntTensorOps::int_max_abs). MaxAbs(UnaryOpIr), /// Operation corresponding to: /// - /// Float => [max_abs dim](burn_tensor::ops::FloatTensorOps::float_max_abs_dim). - /// Int => [max_abs dim](burn_tensor::ops::IntTensorOps::int_max_abs_dim). + /// Float => [`max_abs` dim](burn_tensor::ops::FloatTensorOps::float_max_abs_dim). + /// Int => [`max_abs` dim](burn_tensor::ops::IntTensorOps::int_max_abs_dim). MaxAbsDim(ReduceDimOpIr), /// Operation corresponding to: /// @@ -1372,6 +1374,7 @@ pub struct InterpolateBackwardOpIr { impl OperationIr { /// Cleanup the remaining tensor handles that have not been used. + #[must_use] pub fn nodes(&self) -> Vec<&TensorIr> { match self { OperationIr::BaseFloat(repr) => repr.nodes(), diff --git a/crates/burn-ir/src/tensor.rs b/crates/burn-ir/src/tensor.rs index 966d7928ed..1b67d22bb1 100644 --- a/crates/burn-ir/src/tensor.rs +++ b/crates/burn-ir/src/tensor.rs @@ -27,10 +27,10 @@ pub enum TensorStatus { /// /// A tensor that is used multiple times has its status updated for each operation. /// -/// 1. Status::NotInit -/// 2. Status::ReadOnly -/// 3. Status::ReadOnly -/// 4. Status::ReadWrite +/// 1. `Status::NotInit` +/// 2. `Status::ReadOnly` +/// 3. `Status::ReadOnly` +/// 4. `Status::ReadWrite` #[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct TensorIr { /// The [tensor id](TensorId). @@ -45,6 +45,7 @@ pub struct TensorIr { impl TensorId { /// Create a new tensor id. 
+ #[must_use] pub fn new(value: u64) -> Self { Self { value } } diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index 3b93e30aee..eceff8a210 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-ndarray" version.workspace = true +[lints] +workspace = true + [features] default = ["std", "simd"] doc = ["default"] diff --git a/crates/burn-ndarray/src/lib.rs b/crates/burn-ndarray/src/lib.rs index a0caac2e8c..ac536288b4 100644 --- a/crates/burn-ndarray/src/lib.rs +++ b/crates/burn-ndarray/src/lib.rs @@ -22,7 +22,7 @@ mod tensor; pub use backend::*; pub use element::*; -pub(crate) use sharing::*; +pub(crate) use sharing::UnsafeSharedRef; pub use tensor::*; extern crate alloc; diff --git a/crates/burn-ndarray/src/ops/adaptive_avgpool.rs b/crates/burn-ndarray/src/ops/adaptive_avgpool.rs index 09beed5348..6a86ec2b95 100644 --- a/crates/burn-ndarray/src/ops/adaptive_avgpool.rs +++ b/crates/burn-ndarray/src/ops/adaptive_avgpool.rs @@ -45,7 +45,7 @@ pub(crate) fn adaptive_avg_pool2d( output[[b, c, h, w]] = sum_val / count.elem(); } } - }) + }); }); NdArrayTensor::new(output.into_dyn().into_shared()) @@ -86,7 +86,7 @@ pub(crate) fn adaptive_avg_pool2d_backward( } } } - }) + }); }); NdArrayTensor::new(output_grad.into_dyn().into_shared()) diff --git a/crates/burn-ndarray/src/ops/avgpool.rs b/crates/burn-ndarray/src/ops/avgpool.rs index 1b10012b05..12ab287b3e 100644 --- a/crates/burn-ndarray/src/ops/avgpool.rs +++ b/crates/burn-ndarray/src/ops/avgpool.rs @@ -64,7 +64,7 @@ pub(crate) fn avg_pool2d( output[[b, c, oh, ow]] = sum_val / count; } } - }) + }); }); NdArrayTensor::new(output.into_dyn().into_shared()) @@ -110,9 +110,10 @@ pub(crate) fn avg_pool2d_backward( let ih_end = usize::min(ih_end, x_height + padding_height); let iw_end = usize::min(iw_end, x_width + padding_width); - let count = match count_include_pad { - true => kernel_width * kernel_height, - false => (ih_end - ih_start) * (iw_end - iw_start), + let count = if count_include_pad { + kernel_width * kernel_height + } else { + (ih_end - ih_start) * (iw_end - iw_start) }; for ih in ih_start..ih_end { @@ -126,7 +127,7 @@ pub(crate) fn avg_pool2d_backward( } } } - }) + }); }); NdArrayTensor::new(output_grad.into_dyn().into_shared()) diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index 6988ed8a57..5fd2f9d6cd 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -573,13 +573,13 @@ where ); let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices); - if shape_value != shape_indices { - panic!( - "Invalid dimension: the shape of the index tensor should be the same as the value \ + assert!( + !(shape_value != shape_indices), + "Invalid dimension: the shape of the index tensor should be the same as the value \ tensor: Index {:?} value {:?}", - shape_indices.dims, shape_value.dims - ); - } + shape_indices.dims, + shape_value.dims + ); let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array; let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array; @@ -641,13 +641,13 @@ where let mut batch_size = 1; for i in 0..ndims - 1 { - if shape_tensor.dims[i] != shape_indices.dims[i] { - panic!( - "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \ + assert!( + (shape_tensor.dims[i] == shape_indices.dims[i]), + 
"Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \ {:?}", - shape_tensor.dims, shape_indices.dims - ); - } + shape_tensor.dims, + shape_indices.dims + ); batch_size *= shape_indices.dims[i]; } @@ -714,10 +714,7 @@ where f64 ); - tensor.array.mapv_inplace(|x| match x < min { - true => min, - false => x, - }); + tensor.array.mapv_inplace(|x| if x < min { min } else { x }); tensor } @@ -740,10 +737,7 @@ where f64 ); - tensor.array.mapv_inplace(|x| match x > max { - true => max, - false => x, - }); + tensor.array.mapv_inplace(|x| if x > max { max } else { x }); tensor } @@ -766,12 +760,14 @@ where f64 ); - tensor.array.mapv_inplace(|x| match x < min { - true => min, - false => match x > max { - true => max, - false => x, - }, + tensor.array.mapv_inplace(|x| { + if x < min { + min + } else if x > max { + max + } else { + x + } }); tensor @@ -823,7 +819,10 @@ where pub(crate) fn abs(tensor: NdArrayTensor) -> NdArrayTensor { let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64); - let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared(); + let array = tensor + .array + .mapv_into(super::super::element::ExpElement::abs_elem) + .into_shared(); NdArrayTensor::new(array) } diff --git a/crates/burn-ndarray/src/ops/bool_tensor.rs b/crates/burn-ndarray/src/ops/bool_tensor.rs index 397f54486a..cd7f6600af 100644 --- a/crates/burn-ndarray/src/ops/bool_tensor.rs +++ b/crates/burn-ndarray/src/ops/bool_tensor.rs @@ -91,7 +91,7 @@ impl BoolTensorOp fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor { new_tensor_float!(NdArrayTensor { - array: tensor.array.mapv(|a| (a as i32).elem()).into_shared(), + array: tensor.array.mapv(|a| i32::from(a).elem()).into_shared(), }) } diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index 9bb5ef7abe..674da6ab76 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -36,9 +36,7 @@ fn deform_im2col_kernel( for kernel_y in 0..kernel_h { for kernel_x in 0..kernel_w { - let mask_value = mask - .map(|it| it[[kernel_y, kernel_x]]) - .unwrap_or_else(|| F::from_elem(1.0)); + let mask_value = mask.map_or_else(|| F::from_elem(1.0), |it| it[[kernel_y, kernel_x]]); let offset = offset.slice(s![kernel_y, kernel_x, ..]); let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) @@ -158,7 +156,7 @@ pub(crate) fn deform_conv2d( let columns = deform_im2col( input.view(), offset.view(), - mask.as_ref().map(|it| it.view()), + mask.as_ref().map(ndarray::ArrayBase::view), args, out_dims, (kernel_h, kernel_w), @@ -260,7 +258,11 @@ pub mod backward { use atomic_float::AtomicF32; use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; - use super::*; + use super::{ + Array4, ArrayView2, ArrayView3, Axis, DeformConvOptions, FloatNdArrayElement, + NdArrayTensor, TensorMetadata, Zip, bilinear_interpolate, deform_im2col, iter_par, matmul, + run_par, s, + }; pub(crate) type DeformConv2dBackward = ( NdArrayTensor, @@ -321,7 +323,7 @@ pub mod backward { input.view(), weight, offset.view(), - mask.as_ref().map(|it| it.view()), + mask.as_ref().map(ndarray::ArrayBase::view), out_grad.view(), &args, (kernel_h, kernel_w), @@ -330,7 +332,7 @@ pub mod backward { let weight_grad = compute_weight_grad( input.view(), offset.view(), - mask.as_ref().map(|it| it.view()), + mask.as_ref().map(ndarray::ArrayBase::view), out_grad.view(), args, (kernel_h, kernel_w), @@ -615,7 +617,7 @@ pub mod backward { #[cfg(feature = "std")] 
run_par!(|| { iter_par!(Zip::indexed(columns)) - .for_each(|(args0, args1)| compute_for_each(args0, args1)) + .for_each(|(args0, args1)| compute_for_each(args0, args1)); }); #[cfg(not(feature = "std"))] diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index dcc53df83e..a5e0bfa572 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -268,7 +268,10 @@ impl IntTensorOps fn int_into_float(tensor: NdArrayTensor) -> FloatTensor { new_tensor_float!(NdArrayTensor { - array: tensor.array.mapv(|a| a.elem()).into_shared() + array: tensor + .array + .mapv(burn_tensor::ElementConversion::elem) + .into_shared() }) } diff --git a/crates/burn-ndarray/src/ops/interpolate.rs b/crates/burn-ndarray/src/ops/interpolate.rs index 03c835c7e4..f52f91baab 100644 --- a/crates/burn-ndarray/src/ops/interpolate.rs +++ b/crates/burn-ndarray/src/ops/interpolate.rs @@ -75,10 +75,10 @@ pub(crate) fn nearest_interpolate_backward( let ih = start_index(oh, output_height, input_height); let iw = start_index(ow, output_width, input_width); - output_grad[[b, c, ih, iw]] += grad.array[[b, c, oh, ow]] + output_grad[[b, c, ih, iw]] += grad.array[[b, c, oh, ow]]; } } - }) + }); }); NdArrayTensor::new(output_grad.into_dyn().into_shared()) diff --git a/crates/burn-ndarray/src/ops/matmul.rs b/crates/burn-ndarray/src/ops/matmul.rs index 951ce61c27..ddbb4bf739 100644 --- a/crates/burn-ndarray/src/ops/matmul.rs +++ b/crates/burn-ndarray/src/ops/matmul.rs @@ -60,7 +60,7 @@ where &rhs_slice, beta, &mut out_slice, - ) + ); } }); @@ -82,7 +82,7 @@ impl Strides { fn unflatten(&self, linear_index: usize) -> Vec { let mut coord = Vec::with_capacity(self.strides.len()); let mut rem = linear_index; - for stride in self.strides.iter() { + for stride in &self.strides { coord.push(rem / stride); rem %= stride; } @@ -113,18 +113,20 @@ impl Strides { /// one dim is equal to 1 is broadcast.) fn output_shape(lsh: &Shape, rsh: &Shape) -> (Shape, Strides, Strides, Strides) { let ndims = lsh.num_dims(); - if ndims < 2 { - panic!("Matrix multiplication requires an array with at least 2 dimensions."); - } + assert!( + (ndims >= 2), + "Matrix multiplication requires an array with at least 2 dimensions." + ); // Fetch matrix dimensions and check compatibility. let l_rows = lsh.dims[ndims - 2]; let l_cols = lsh.dims[ndims - 1]; let r_rows = rsh.dims[ndims - 2]; let r_cols = rsh.dims[ndims - 1]; - if l_cols != r_rows { - panic!("Dimensions are incompatible for matrix multiplication."); - } + assert!( + (l_cols == r_rows), + "Dimensions are incompatible for matrix multiplication." + ); // Set matrix dimensions of the output shape. let mut osh = vec![0; ndims]; osh[ndims - 2] = l_rows; @@ -253,7 +255,7 @@ mod tests { Strides::new(vec![3, 1, 0]), Strides::new(vec![12, 4, 1]) ) - ) + ); } #[test] diff --git a/crates/burn-ndarray/src/ops/maxpool.rs b/crates/burn-ndarray/src/ops/maxpool.rs index b7f8e776e3..9fef795af4 100644 --- a/crates/burn-ndarray/src/ops/maxpool.rs +++ b/crates/burn-ndarray/src/ops/maxpool.rs @@ -63,7 +63,7 @@ pub(crate) fn max_pool2d( output[[b, c, oh, ow]] = max_val; } } - }) + }); }); NdArrayTensor::new(output.into_dyn().into_shared()) @@ -133,7 +133,7 @@ pub(crate) fn max_pool2d_with_indices ModuleOps()) }; let s~N = Op::apply_vec(s~N, rhs); // Store a full vector at the same position as the input. 
Cast is safe because `Out` is // size and align compatible - unsafe { vstore_unaligned(&mut elem[N] as *mut _ as *mut Out, s~N) }; + unsafe { vstore_unaligned((&raw mut elem[N]).cast::(), s~N) }; }); } @@ -409,11 +409,11 @@ unsafe fn binary_scalar_slice_inplace< // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. - let s0 = unsafe { vload(elem as *const _ as *const T) }; + let s0 = unsafe { vload(std::ptr::from_ref(elem).cast::()) }; let s0 = Op::apply_vec(s0, rhs); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible - unsafe { vstore(elem as *mut _ as *mut Out, s0) }; + unsafe { vstore(std::ptr::from_mut(elem).cast::(), s0) }; } } diff --git a/crates/burn-ndarray/src/ops/simd/cmp.rs b/crates/burn-ndarray/src/ops/simd/cmp.rs index 2502441642..a6b524a999 100644 --- a/crates/burn-ndarray/src/ops/simd/cmp.rs +++ b/crates/burn-ndarray/src/ops/simd/cmp.rs @@ -212,7 +212,7 @@ fn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( .zip(chunks_rhs.remainder()) .zip(chunks_out.into_remainder()) { - *out = Op::apply(*lhs, *rhs) + *out = Op::apply(*lhs, *rhs); } } @@ -229,7 +229,10 @@ mod elemwise { use bytemuck::cast; use macerator::vload; - use super::*; + use super::{ + ArrayD, NdArrayElement, NdArrayTensor, PhantomData, Scalar, Simd, SimdCmpOp, Vector, + is_accelerated, seq, should_use_simd, vload_unaligned, + }; pub fn try_cmp_scalar_simd>( input: NdArrayTensor, @@ -324,7 +327,7 @@ mod elemwise { .iter() .zip(chunks_out.into_remainder()) { - *out = Op::apply(*input, rhs) + *out = Op::apply(*input, rhs); } } @@ -351,11 +354,11 @@ mod elemwise { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. - let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; + let s~N = unsafe { vload((&raw const elem[N]).cast::()) }; let s~N = Op::apply_vec(s~N, rhs); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible - unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) }; + unsafe { T::mask_store_as_bool((&raw mut elem[N]).cast::(), s~N) }; }); } @@ -363,12 +366,12 @@ mod elemwise { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. - let s0 = unsafe { vload(elem as *const _ as *const T) }; + let s0 = unsafe { vload(std::ptr::from_ref(elem).cast::()) }; let s0 = Op::apply_vec(s0, rhs); // Store a full vector at the same position as the input. 
Cast is safe because `Out` is // size and align compatible - unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) }; + unsafe { T::mask_store_as_bool(std::ptr::from_mut(elem).cast::(), s0) }; } } } diff --git a/crates/burn-ndarray/src/ops/simd/conv.rs b/crates/burn-ndarray/src/ops/simd/conv.rs index 42430d9b70..6e6fd66377 100644 --- a/crates/burn-ndarray/src/ops/simd/conv.rs +++ b/crates/burn-ndarray/src/ops/simd/conv.rs @@ -107,28 +107,28 @@ fn conv2d( match (padded, strided, grouped) { (true, true, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } (true, false, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } (false, true, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } (false, false, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } (true, true, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } (true, false, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } (false, true, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } (false, false, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) + conv2d_launch::(x, w, &bias, &mut out, &options, ob); } } }); diff --git a/crates/burn-ndarray/src/ops/simd/maxpool.rs b/crates/burn-ndarray/src/ops/simd/maxpool.rs index 7490ebeab5..fe43a695bb 100644 --- a/crates/burn-ndarray/src/ops/simd/maxpool.rs +++ b/crates/burn-ndarray/src/ops/simd/maxpool.rs @@ -69,7 +69,10 @@ mod nhwc { use crate::ops::simd::lanes; - use super::*; + use super::{ + Array4, Element, MinMax, NdArrayTensor, TensorMetadata, UnsafeSharedRef, VOrd, + iter_range_par, run_par, s, + }; // Until you can use associated constants as array size, we need to hardcode this. // The most common config (x86-v3) has 16 registers, so use half of them for accumulators. diff --git a/crates/burn-ndarray/src/ops/simd/unary.rs b/crates/burn-ndarray/src/ops/simd/unary.rs index 18d41fcb1c..3f13c185e7 100644 --- a/crates/burn-ndarray/src/ops/simd/unary.rs +++ b/crates/burn-ndarray/src/ops/simd/unary.rs @@ -182,7 +182,7 @@ fn unary_slice< .iter() .zip(chunks_out.into_remainder()) { - *out = Op::apply(*input) + *out = Op::apply(*input); } } @@ -212,11 +212,11 @@ unsafe fn unary_slice_inplace< // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. - let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; + let s~N = unsafe { vload((&raw const elem[N]).cast::()) }; let s~N = Op::apply_vec::(s~N); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible - unsafe { vstore(&mut elem[N] as *mut _ as *mut Out, s~N) }; + unsafe { vstore((&raw mut elem[N]).cast::(), s~N) }; }); } @@ -224,11 +224,11 @@ unsafe fn unary_slice_inplace< // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. 
- let s0 = unsafe { vload(elem as *const _ as *const T) }; + let s0 = unsafe { vload(std::ptr::from_ref(elem).cast::()) }; let s0 = Op::apply_vec::(s0); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible - unsafe { vstore(elem as *mut _ as *mut Out, s0) }; + unsafe { vstore(std::ptr::from_mut(elem).cast::(), s0) }; } } diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index 4c649424c2..41e42a4972 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -314,7 +314,10 @@ impl FloatTensorO fn float_exp(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor| { - let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared(); + let array = tensor + .array + .mapv_into(super::super::element::ExpElement::exp_elem) + .into_shared(); NdArrayTensor::new(array) }) @@ -322,7 +325,10 @@ impl FloatTensorO fn float_log(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor| { - let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared(); + let array = tensor + .array + .mapv_into(super::super::element::ExpElement::log_elem) + .into_shared(); NdArrayTensor::new(array) }) @@ -338,7 +344,10 @@ impl FloatTensorO fn float_log1p(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor| { - let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared(); + let array = tensor + .array + .mapv_into(super::super::element::ExpElement::log1p_elem) + .into_shared(); NdArrayTensor::new(array) }) @@ -366,7 +375,10 @@ impl FloatTensorO fn float_sqrt(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor| { - let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared(); + let array = tensor + .array + .mapv_into(super::super::element::ExpElement::sqrt_elem) + .into_shared(); NdArrayTensor::new(array) }) @@ -374,7 +386,10 @@ impl FloatTensorO fn float_abs(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor| { - let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared(); + let array = tensor + .array + .mapv_into(super::super::element::ExpElement::abs_elem) + .into_shared(); NdArrayTensor::new(array) }) @@ -512,7 +527,7 @@ impl FloatTensorO fn float_into_int(tensor: FloatTensor) -> NdArrayTensor { execute_with_float_dtype!(tensor, E => |tensor: NdArrayTensor| { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); + let array = tensor.array.mapv(burn_tensor::ElementConversion::elem).into_shared(); NdArrayTensor { array } }) } @@ -545,7 +560,10 @@ impl FloatTensorO fn cast( tensor: &NdArrayTensor, ) -> NdArrayTensor { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); + let array = tensor + .array + .mapv(burn_tensor::ElementConversion::elem) + .into_shared(); NdArrayTensor { array } } diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 3156bcd205..71e3098ec7 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -192,7 +192,7 @@ macro_rules! execute_with_float_dtype { mod utils { use burn_common::tensor::is_contiguous; - use super::*; + use super::{Element, NdArrayTensor, TensorData, TensorMetadata, Vec}; impl NdArrayTensor where @@ -302,6 +302,7 @@ where E: Element, { /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData). 
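Note: the SIMD hunks above replace double `as` pointer casts (`&x as *const _ as *const T`) with `std::ptr::from_ref(x).cast()` or raw-reference syntax (`&raw const x`). A tiny sketch of the cast style alone, without the SIMD loads, using a plain `u32` to `u8` reinterpretation as a stand-in:

```rust
// Sketch: `std::ptr::from_ref(...).cast::<T>()` instead of `as *const _ as *const T`.
fn first_byte(value: &u32) -> u8 {
    // Before: `value as *const _ as *const u8`
    let p: *const u8 = std::ptr::from_ref(value).cast::<u8>();
    // SAFETY: a u32 is 4 bytes and at least byte-aligned, so reading one u8 is in bounds.
    unsafe { *p }
}

fn main() {
    let x: u32 = 0x2A;
    // On a little-endian target the low byte comes first.
    if cfg!(target_endian = "little") {
        assert_eq!(first_byte(&x), 0x2A);
    }
}
```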
+ #[must_use] pub fn from_data(mut data: TensorData) -> NdArrayTensor { let shape = mem::take(&mut data.shape); @@ -328,6 +329,7 @@ pub struct NdArrayQTensor { impl NdArrayQTensor { /// Returns the quantization strategy, including quantization parameters, for the given tensor. + #[must_use] pub fn strategy(&self) -> QuantizationStrategy { match self.scheme { QuantScheme { diff --git a/crates/burn-no-std-tests/Cargo.toml b/crates/burn-no-std-tests/Cargo.toml index a59f1d45a3..1ba04d3028 100644 --- a/crates/burn-no-std-tests/Cargo.toml +++ b/crates/burn-no-std-tests/Cargo.toml @@ -10,6 +10,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-no-std-tests" version.workspace = true +[lints] +workspace = true + [dependencies] # ** Please make sure all dependencies support no_std ** diff --git a/crates/burn-no-std-tests/src/mlp.rs b/crates/burn-no-std-tests/src/mlp.rs index ca016b86a2..e6f808b70e 100644 --- a/crates/burn-no-std-tests/src/mlp.rs +++ b/crates/burn-no-std-tests/src/mlp.rs @@ -56,7 +56,7 @@ impl Mlp { pub fn forward(&self, input: Tensor) -> Tensor { let mut x = input; - for linear in self.linears.iter() { + for linear in &self.linears { x = linear.forward(x); x = self.dropout.forward(x); x = self.activation.forward(x); diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index 5fafbc7841..b4685608cb 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router-rem documentation = "https://docs.rs/burn-router-remote" version.workspace = true +[lints] +workspace = true + [features] default = [] doc = [] diff --git a/crates/burn-remote/src/server/base.rs b/crates/burn-remote/src/server/base.rs index d9f48cbf29..eabb41fed6 100644 --- a/crates/burn-remote/src/server/base.rs +++ b/crates/burn-remote/src/server/base.rs @@ -120,12 +120,11 @@ impl WsServer { loop { let packet = socket.recv().await; - let msg = match packet { - Some(msg) => msg, - None => { - log::info!("Still no message"); - continue; - } + let msg = if let Some(msg) = packet { + msg + } else { + log::info!("Still no message"); + continue; }; if let Ok(ws::Message::Binary(bytes)) = msg { @@ -142,13 +141,13 @@ impl WsServer { break; } - let (stream, connection_id, task) = match self.state.stream(&mut session_id, task) { - Some(val) => val, - None => { + let (stream, connection_id, task) = + if let Some(val) = self.state.stream(&mut session_id, task) { + val + } else { log::info!("Ops session activated {session_id:?}"); continue; - } - }; + }; match task { ComputeTask::RegisterOperation(op) => { @@ -170,7 +169,7 @@ impl WsServer { } else { log::info!("Not a binary message, closing, received {msg:?}"); break; - }; + } } log::info!("Closing connection"); diff --git a/crates/burn-remote/src/server/processor.rs b/crates/burn-remote/src/server/processor.rs index c988b56026..2e93e19060 100644 --- a/crates/burn-remote/src/server/processor.rs +++ b/crates/burn-remote/src/server/processor.rs @@ -27,7 +27,7 @@ impl Processor { let (sender, rec) = std::sync::mpsc::sync_channel(1); std::thread::spawn(move || { - for item in rec.iter() { + for item in &rec { match item { ProcessorTask::RegisterOperation(op) => { runner.register(*op); diff --git a/crates/burn-remote/src/server/session.rs b/crates/burn-remote/src/server/session.rs index 6c8c88178f..3a46846cf0 100644 --- a/crates/burn-remote/src/server/session.rs +++ 
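Note: several hunks here (the handle fork, `mlp.rs`, `processor.rs`) iterate collections by reference, `for x in &xs`, instead of calling `.iter()`. A minimal sketch on a hypothetical `Mlp` stand-in:

```rust
// Sketch: iterate by reference rather than via an explicit `.iter()` call.
struct Mlp {
    layer_widths: Vec<usize>,
}

impl Mlp {
    fn param_count(&self) -> usize {
        let mut total = 0;
        // Before: `for w in self.layer_widths.iter() { ... }`
        for w in &self.layer_widths {
            total += w;
        }
        total
    }
}

fn main() {
    let mlp = Mlp { layer_widths: vec![4, 8, 2] };
    assert_eq!(mlp.param_count(), 14);
}
```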
b/crates/burn-remote/src/server/session.rs @@ -120,13 +120,12 @@ impl Session { /// Select the current [stream](Stream) based on the given task. fn select(&mut self, stream_id: StreamId) -> Stream { // We return the stream. - match self.streams.get(&stream_id) { - Some(stream) => stream.clone(), - None => { - let stream = Stream::::new(self.runner.clone(), self.sender.clone()); - self.streams.insert(stream_id, stream.clone()); - stream - } + if let Some(stream) = self.streams.get(&stream_id) { + stream.clone() + } else { + let stream = Stream::::new(self.runner.clone(), self.sender.clone()); + self.streams.insert(stream_id, stream.clone()); + stream } } diff --git a/crates/burn-remote/src/server/stream.rs b/crates/burn-remote/src/server/stream.rs index 7aea85db00..2d3af88a13 100644 --- a/crates/burn-remote/src/server/stream.rs +++ b/crates/burn-remote/src/server/stream.rs @@ -37,13 +37,13 @@ impl Stream { pub fn register_tensor(&self, tensor_id: TensorId, data: TensorData) { self.compute_sender .send(ProcessorTask::RegisterTensor(tensor_id, data)) - .unwrap() + .unwrap(); } pub fn register_orphan(&self, tensor_id: TensorId) { self.compute_sender .send(ProcessorTask::RegisterOrphan(tensor_id)) - .unwrap() + .unwrap(); } pub fn read_tensor(&self, id: ConnectionId, desc: TensorIr) { diff --git a/crates/burn-rocm/Cargo.toml b/crates/burn-rocm/Cargo.toml index bf8a846b8c..ddd1a7b76d 100644 --- a/crates/burn-rocm/Cargo.toml +++ b/crates/burn-rocm/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-rocm" documentation = "https://docs.rs/burn-rocm" version.workspace = true +[lints] +workspace = true + [features] default = ["fusion", "burn-cubecl/default", "cubecl/default"] fusion = ["burn-fusion", "burn-cubecl/fusion"] diff --git a/crates/burn-router/Cargo.toml b/crates/burn-router/Cargo.toml index 0a1df0e976..63b1d30dab 100644 --- a/crates/burn-router/Cargo.toml +++ b/crates/burn-router/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router" documentation = "https://docs.rs/burn-router" version.workspace = true +[lints] +workspace = true + [features] default = ["std"] std = ["burn-tensor/std", "burn-common/std", "burn-ir/std"] diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index 473f608107..4430501faf 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -62,7 +62,7 @@ impl Backend for BackendRouter { } fn seed(seed: u64) { - set_seed(seed) + set_seed(seed); } fn sync(device: &Self::Device) { diff --git a/crates/burn-router/src/client/base.rs b/crates/burn-router/src/client/base.rs index 7ea3179063..9b72b1bafe 100644 --- a/crates/burn-router/src/client/base.rs +++ b/crates/burn-router/src/client/base.rs @@ -1,9 +1,6 @@ use alloc::{boxed::Box, vec::Vec}; use burn_common::future::DynFut; -use core::{ - ops::DerefMut, - sync::atomic::{AtomicBool, AtomicU64, Ordering}, -}; +use core::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use hashbrown::HashMap; use spin::Mutex; @@ -34,11 +31,11 @@ pub trait RunnerClient: Clone + Send + Sync + Sized { fn read_tensor(&self, tensor: TensorIr) -> DynFut; /// Sync the runner, ensure that all computations are finished. fn sync(&self); - /// Create a new [RouterTensor] from the tensor data. + /// Create a new [`RouterTensor`] from the tensor data. fn register_tensor_data(&self, data: TensorData) -> RouterTensor; - /// Create a new [RouterTensor] with no resources associated. 
+ /// Create a new [`RouterTensor`] with no resources associated. fn register_empty_tensor(&self, shape: Vec, dtype: DType) -> RouterTensor; - /// Create a new float [RouterTensor] with no resources associated. + /// Create a new float [`RouterTensor`] with no resources associated. fn register_float_tensor(&self, shape: Vec, dtype: FloatDType) -> RouterTensor; /// Get the current device used by all operations handled by this client. fn device(&self) -> Self::Device; @@ -76,7 +73,7 @@ fn get_seed() -> Option { fn new_client(device: &R::Device) -> Client { let client = R::init_client(device); if let Some(seed) = get_seed() { - client.seed(seed) + client.seed(seed); } client } @@ -102,19 +99,18 @@ impl RunnerClientLocator { Self::register_inner::(client_id, client, &mut clients); } - match clients.deref_mut() { - Some(clients) => match clients.get(&client_id) { - Some(client) => { + match &mut *clients { + Some(clients) => { + if let Some(client) = clients.get(&client_id) { let client: &Client = client.downcast_ref().unwrap(); client.clone() - } - None => { + } else { let client = new_client::(device); let any = Box::new(client.clone()); clients.insert(client_id, any); client } - }, + } _ => unreachable!(), } } @@ -129,9 +125,10 @@ impl RunnerClientLocator { } if let Some(clients) = clients { - if clients.contains_key(&key) { - panic!("Client already created for device {:?}", key); - } + assert!( + !clients.contains_key(&key), + "Client already created for device {key:?}" + ); clients.insert(key, Box::new(client)); } diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index 865e145013..2e7db9170d 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -270,13 +270,16 @@ impl BoolTensorOps for BackendRouter { // Calculate the output shape let mut shape = tensor_first.shape.clone(); shape[dim] = 0; - for tensor in tensors.iter() { + for tensor in &tensors { shape[dim] += tensor.shape[dim]; } let out = client.register_empty_tensor(shape, dtype); let desc = CatOpIr { - tensors: tensors.into_iter().map(|t| t.into_ir()).collect(), + tensors: tensors + .into_iter() + .map(super::super::tensor::RouterTensor::into_ir) + .collect(), dim, out: out.to_ir_out(), }; diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index 6db8e0ec34..0fa3696d17 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -1150,13 +1150,16 @@ impl FloatTensorOps for BackendRouter { // Calculate the output shape let mut shape = tensor_first.shape.clone(); shape[dim] = 0; - for tensor in tensors.iter() { + for tensor in &tensors { shape[dim] += tensor.shape[dim]; } let out = client.register_empty_tensor(shape, tensor_first.dtype); let desc = CatOpIr { - tensors: tensors.into_iter().map(|tensor| tensor.into_ir()).collect(), + tensors: tensors + .into_iter() + .map(super::super::tensor::RouterTensor::into_ir) + .collect(), dim, out: out.to_ir_out(), }; diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index 76e6e87428..4eb530f6b4 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -278,13 +278,16 @@ impl IntTensorOps for BackendRouter { // Calculate the output shape let mut shape = tensor_first.shape.clone(); shape[dim] = 0; - for tensor in tensors.iter() { + for tensor in &tensors { shape[dim] += tensor.shape[dim]; } let out = client.register_empty_tensor(shape, dtype); let desc = 
CatOpIr { - tensors: tensors.into_iter().map(|t| t.into_ir()).collect(), + tensors: tensors + .into_iter() + .map(super::super::tensor::RouterTensor::into_ir) + .collect(), dim, out: out.to_ir_out(), }; @@ -689,7 +692,7 @@ impl IntTensorOps for BackendRouter { // Get the runtime client on which to register the operation for execution. let client = get_client::(device); let dtype = IntElem::::dtype(); - let out = client.register_empty_tensor(shape.dims.to_vec(), dtype); + let out = client.register_empty_tensor(shape.dims.clone(), dtype); client.register(OperationIr::NumericInt( dtype, diff --git a/crates/burn-router/src/ops/op_module.rs b/crates/burn-router/src/ops/op_module.rs index 332dcbd602..caae0a936c 100644 --- a/crates/burn-router/src/ops/op_module.rs +++ b/crates/burn-router/src/ops/op_module.rs @@ -46,7 +46,7 @@ impl ModuleOps for BackendRouter { let desc = Conv1dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -84,7 +84,7 @@ impl ModuleOps for BackendRouter { let desc = Conv2dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -129,7 +129,7 @@ impl ModuleOps for BackendRouter { let desc = Conv3dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -161,7 +161,7 @@ impl ModuleOps for BackendRouter { let desc = ConvTranspose1dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -203,7 +203,7 @@ impl ModuleOps for BackendRouter { let desc = ConvTranspose2dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -259,7 +259,7 @@ impl ModuleOps for BackendRouter { let desc = ConvTranspose3dOpIr { x: x.into_ir(), weight: weight.into_ir(), - bias: bias.map(|bias| bias.into_ir()), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -741,8 +741,8 @@ impl ModuleOps for BackendRouter { x: x.into_ir(), offset: offset.into_ir(), weight: weight.into_ir(), - mask: mask.map(|mask| mask.into_ir()), - bias: bias.map(|bias| bias.into_ir()), + mask: mask.map(super::super::tensor::RouterTensor::into_ir), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out: out.to_ir_out(), }; @@ -779,15 +779,19 @@ impl ModuleOps for BackendRouter { x: x.into_ir(), offset: offset.into_ir(), weight: weight.into_ir(), - mask: mask.map(|mask| mask.into_ir()), - bias: bias.map(|bias| bias.into_ir()), + mask: mask.map(super::super::tensor::RouterTensor::into_ir), + bias: bias.map(super::super::tensor::RouterTensor::into_ir), options: options.into(), out_grad: output_grad.into_ir(), input_grad: input_grad.to_ir_out(), offset_grad: offset_grad.to_ir_out(), weight_grad: weight_grad.to_ir_out(), - mask_grad: mask_grad.as_ref().map(|mask_grad| mask_grad.to_ir_out()), - bias_grad: bias_grad.as_ref().map(|bias_grad| bias_grad.to_ir_out()), + mask_grad: mask_grad + .as_ref() + 
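Note: the `burn-router` hunks replace single-call closures such as `.map(|t| t.into_ir())` with method paths like `RouterTensor::into_ir`. A standalone sketch of the same rewrite on hypothetical `Tensor`/`Ir` types:

```rust
// Sketch: pass a method path to `map` when the closure only forwards to one method.
#[derive(Debug, PartialEq)]
struct Ir(u64);

struct Tensor(u64);

impl Tensor {
    fn into_ir(self) -> Ir {
        Ir(self.0)
    }
}

fn main() {
    let tensors = vec![Tensor(1), Tensor(2)];
    // Before: `.map(|t| t.into_ir())`
    let irs: Vec<Ir> = tensors.into_iter().map(Tensor::into_ir).collect();
    assert_eq!(irs, vec![Ir(1), Ir(2)]);
}
```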
.map(super::super::tensor::RouterTensor::to_ir_out), + bias_grad: bias_grad + .as_ref() + .map(super::super::tensor::RouterTensor::to_ir_out), }; client.register(OperationIr::Module( diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index ae6e59f4cd..e330e10686 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -28,7 +28,7 @@ impl RunnerContext { fn free_orphans(&mut self) { // Passing an empty "remaining" tensor identifiers will remove the orphan handles from the container - self.handles.free_orphans(&[]) + self.handles.free_orphans(&[]); } /// Set a tensor handle to be removed. @@ -86,13 +86,13 @@ impl Runner { if dtype.is_float() { let tensor = B::float_from_data(data, &self.device); - ctx.handles.register_float_tensor::(&id, tensor) + ctx.handles.register_float_tensor::(&id, tensor); } else if dtype.is_int() { let tensor = B::int_from_data(data, &self.device); - ctx.handles.register_int_tensor::(&id, tensor) + ctx.handles.register_int_tensor::(&id, tensor); } else if dtype.is_bool() { let tensor = B::bool_from_data(data, &self.device); - ctx.handles.register_bool_tensor::(&id, tensor) + ctx.handles.register_bool_tensor::(&id, tensor); } else if let DType::QFloat(_) = dtype { todo!(); } @@ -109,13 +109,13 @@ impl Runner { if dtype.is_float() { let tensor = B::float_from_data(data, &self.device); - ctx.handles.register_float_tensor::(&id, tensor) + ctx.handles.register_float_tensor::(&id, tensor); } else if dtype.is_int() { let tensor = B::int_from_data(data, &self.device); - ctx.handles.register_int_tensor::(&id, tensor) + ctx.handles.register_int_tensor::(&id, tensor); } else if dtype.is_bool() { let tensor = B::bool_from_data(data, &self.device); - ctx.handles.register_bool_tensor::(&id, tensor) + ctx.handles.register_bool_tensor::(&id, tensor); } else if let DType::QFloat(_) = dtype { todo!(); } @@ -211,7 +211,7 @@ impl RunnerClient for Runner { handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Equal(desc) => { - binary_float_cmp_ops!(handles, desc, B::float_equal) + binary_float_cmp_ops!(handles, desc, B::float_equal); } BaseOperationIr::RepeatDim(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); @@ -286,7 +286,7 @@ impl RunnerClient for Runner { handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Equal(desc) => { - binary_int_cmp_ops!(handles, desc, B::int_equal) + binary_int_cmp_ops!(handles, desc, B::int_equal); } BaseOperationIr::RepeatDim(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); @@ -388,37 +388,37 @@ impl RunnerClient for Runner { }, OperationIr::NumericFloat(_dtype, op) => match op { NumericOperationIr::Add(desc) => { - binary_float_ops!(handles, desc, B::float_add) + binary_float_ops!(handles, desc, B::float_add); } NumericOperationIr::AddScalar(desc) => { - scalar_float_ops!(handles, desc, B::float_add_scalar) + scalar_float_ops!(handles, desc, B::float_add_scalar); } NumericOperationIr::Sub(desc) => { - binary_float_ops!(handles, desc, B::float_sub) + binary_float_ops!(handles, desc, B::float_sub); } NumericOperationIr::SubScalar(desc) => { - scalar_float_ops!(handles, desc, B::float_sub_scalar) + scalar_float_ops!(handles, desc, B::float_sub_scalar); } NumericOperationIr::Div(desc) => { - binary_float_ops!(handles, desc, B::float_div) + binary_float_ops!(handles, desc, B::float_div); } NumericOperationIr::DivScalar(desc) => { - scalar_float_ops!(handles, desc, B::float_div_scalar) + scalar_float_ops!(handles, desc, 
B::float_div_scalar); } NumericOperationIr::Rem(desc) => { - binary_float_ops!(handles, desc, B::float_remainder) + binary_float_ops!(handles, desc, B::float_remainder); } NumericOperationIr::RemScalar(desc) => { - scalar_float_ops!(handles, desc, B::float_remainder_scalar) + scalar_float_ops!(handles, desc, B::float_remainder_scalar); } NumericOperationIr::Mul(desc) => { - binary_float_ops!(handles, desc, B::float_mul) + binary_float_ops!(handles, desc, B::float_mul); } NumericOperationIr::MulScalar(desc) => { - scalar_float_ops!(handles, desc, B::float_mul_scalar) + scalar_float_ops!(handles, desc, B::float_mul_scalar); } NumericOperationIr::Abs(desc) => { - unary_float_ops!(handles, desc, B::float_abs) + unary_float_ops!(handles, desc, B::float_abs); } NumericOperationIr::Ones(desc) => { let shape = Shape::from(desc.shape.clone()); @@ -481,58 +481,58 @@ impl RunnerClient for Runner { handles.register_float_tensor::(&desc.out.id, output); } NumericOperationIr::MeanDim(desc) => { - reduce_float_dim_ops!(handles, desc, B::float_mean_dim) + reduce_float_dim_ops!(handles, desc, B::float_mean_dim); } NumericOperationIr::Mean(desc) => { - unary_float_ops!(handles, desc, B::float_mean) + unary_float_ops!(handles, desc, B::float_mean); } NumericOperationIr::Sum(desc) => { - unary_float_ops!(handles, desc, B::float_sum) + unary_float_ops!(handles, desc, B::float_sum); } NumericOperationIr::SumDim(desc) => { - reduce_float_dim_ops!(handles, desc, B::float_sum_dim) + reduce_float_dim_ops!(handles, desc, B::float_sum_dim); } NumericOperationIr::Prod(desc) => { - unary_float_ops!(handles, desc, B::float_prod) + unary_float_ops!(handles, desc, B::float_prod); } NumericOperationIr::ProdDim(desc) => { - reduce_float_dim_ops!(handles, desc, B::float_prod_dim) + reduce_float_dim_ops!(handles, desc, B::float_prod_dim); } NumericOperationIr::EqualElem(desc) => { - scalar_float_cmp_ops!(handles, desc, B::float_equal_elem) + scalar_float_cmp_ops!(handles, desc, B::float_equal_elem); } NumericOperationIr::Greater(desc) => { - binary_float_cmp_ops!(handles, desc, B::float_greater) + binary_float_cmp_ops!(handles, desc, B::float_greater); } NumericOperationIr::GreaterElem(desc) => { - scalar_float_cmp_ops!(handles, desc, B::float_greater_elem) + scalar_float_cmp_ops!(handles, desc, B::float_greater_elem); } NumericOperationIr::GreaterEqual(desc) => { - binary_float_cmp_ops!(handles, desc, B::float_greater_equal) + binary_float_cmp_ops!(handles, desc, B::float_greater_equal); } NumericOperationIr::GreaterEqualElem(desc) => { - scalar_float_cmp_ops!(handles, desc, B::float_greater_equal_elem) + scalar_float_cmp_ops!(handles, desc, B::float_greater_equal_elem); } NumericOperationIr::Lower(desc) => { - binary_float_cmp_ops!(handles, desc, B::float_lower) + binary_float_cmp_ops!(handles, desc, B::float_lower); } NumericOperationIr::LowerElem(desc) => { - scalar_float_cmp_ops!(handles, desc, B::float_lower_elem) + scalar_float_cmp_ops!(handles, desc, B::float_lower_elem); } NumericOperationIr::LowerEqual(desc) => { - binary_float_cmp_ops!(handles, desc, B::float_lower_equal) + binary_float_cmp_ops!(handles, desc, B::float_lower_equal); } NumericOperationIr::LowerEqualElem(desc) => { - scalar_float_cmp_ops!(handles, desc, B::float_lower_equal_elem) + scalar_float_cmp_ops!(handles, desc, B::float_lower_equal_elem); } NumericOperationIr::ArgMax(desc) => { - reduce_float2int_dim_ops!(handles, desc, B::float_argmax) + reduce_float2int_dim_ops!(handles, desc, B::float_argmax); } NumericOperationIr::ArgMin(desc) => { - 
reduce_float2int_dim_ops!(handles, desc, B::float_argmin) + reduce_float2int_dim_ops!(handles, desc, B::float_argmin); } NumericOperationIr::Max(desc) => { - unary_float_ops!(handles, desc, B::float_max) + unary_float_ops!(handles, desc, B::float_max); } NumericOperationIr::MaxDimWithIndices(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); @@ -549,19 +549,19 @@ impl RunnerClient for Runner { handles.register_int_tensor::(&desc.out_indices.id, output_idx); } NumericOperationIr::Min(desc) => { - unary_float_ops!(handles, desc, B::float_min) + unary_float_ops!(handles, desc, B::float_min); } NumericOperationIr::MaxDim(desc) => { - reduce_float_dim_ops!(handles, desc, B::float_max_dim) + reduce_float_dim_ops!(handles, desc, B::float_max_dim); } NumericOperationIr::MinDim(desc) => { - reduce_float_dim_ops!(handles, desc, B::float_min_dim) + reduce_float_dim_ops!(handles, desc, B::float_min_dim); } NumericOperationIr::MaxAbs(desc) => { - unary_float_ops!(handles, desc, B::float_max_abs) + unary_float_ops!(handles, desc, B::float_max_abs); } NumericOperationIr::MaxAbsDim(desc) => { - reduce_float_dim_ops!(handles, desc, B::float_max_abs_dim) + reduce_float_dim_ops!(handles, desc, B::float_max_abs_dim); } NumericOperationIr::Clamp(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); @@ -571,42 +571,42 @@ impl RunnerClient for Runner { } NumericOperationIr::IntRandom(_) => unreachable!(), NumericOperationIr::Powf(desc) => { - binary_float_ops!(handles, desc, B::float_powf) + binary_float_ops!(handles, desc, B::float_powf); } }, OperationIr::NumericInt(_dtype, op) => match op { NumericOperationIr::Add(desc) => { - binary_int_ops!(handles, desc, B::int_add) + binary_int_ops!(handles, desc, B::int_add); } NumericOperationIr::AddScalar(desc) => { - scalar_int_ops!(handles, desc, B::int_add_scalar) + scalar_int_ops!(handles, desc, B::int_add_scalar); } NumericOperationIr::Sub(desc) => { - binary_int_ops!(handles, desc, B::int_sub) + binary_int_ops!(handles, desc, B::int_sub); } NumericOperationIr::SubScalar(desc) => { - scalar_int_ops!(handles, desc, B::int_sub_scalar) + scalar_int_ops!(handles, desc, B::int_sub_scalar); } NumericOperationIr::Div(desc) => { - binary_int_ops!(handles, desc, B::int_div) + binary_int_ops!(handles, desc, B::int_div); } NumericOperationIr::DivScalar(desc) => { - scalar_int_ops!(handles, desc, B::int_div_scalar) + scalar_int_ops!(handles, desc, B::int_div_scalar); } NumericOperationIr::Rem(desc) => { - binary_int_ops!(handles, desc, B::int_remainder) + binary_int_ops!(handles, desc, B::int_remainder); } NumericOperationIr::RemScalar(desc) => { - scalar_int_ops!(handles, desc, B::int_remainder_scalar) + scalar_int_ops!(handles, desc, B::int_remainder_scalar); } NumericOperationIr::Mul(desc) => { - binary_int_ops!(handles, desc, B::int_mul) + binary_int_ops!(handles, desc, B::int_mul); } NumericOperationIr::MulScalar(desc) => { - scalar_int_ops!(handles, desc, B::int_mul_scalar) + scalar_int_ops!(handles, desc, B::int_mul_scalar); } NumericOperationIr::Abs(desc) => { - unary_int_ops!(handles, desc, B::int_abs) + unary_int_ops!(handles, desc, B::int_abs); } NumericOperationIr::Ones(desc) => { let shape = Shape::from(desc.shape.clone()); @@ -669,58 +669,58 @@ impl RunnerClient for Runner { handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::MeanDim(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_mean_dim) + reduce_int_dim_ops!(handles, desc, B::int_mean_dim); } NumericOperationIr::Mean(desc) => { - 
unary_int_ops!(handles, desc, B::int_mean) + unary_int_ops!(handles, desc, B::int_mean); } NumericOperationIr::Sum(desc) => { - unary_int_ops!(handles, desc, B::int_sum) + unary_int_ops!(handles, desc, B::int_sum); } NumericOperationIr::SumDim(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_sum_dim) + reduce_int_dim_ops!(handles, desc, B::int_sum_dim); } NumericOperationIr::Prod(desc) => { - unary_int_ops!(handles, desc, B::int_prod) + unary_int_ops!(handles, desc, B::int_prod); } NumericOperationIr::ProdDim(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_prod_dim) + reduce_int_dim_ops!(handles, desc, B::int_prod_dim); } NumericOperationIr::EqualElem(desc) => { - scalar_int_cmp_ops!(handles, desc, B::int_equal_elem) + scalar_int_cmp_ops!(handles, desc, B::int_equal_elem); } NumericOperationIr::Greater(desc) => { - binary_int_cmp_ops!(handles, desc, B::int_greater) + binary_int_cmp_ops!(handles, desc, B::int_greater); } NumericOperationIr::GreaterElem(desc) => { - scalar_int_cmp_ops!(handles, desc, B::int_greater_elem) + scalar_int_cmp_ops!(handles, desc, B::int_greater_elem); } NumericOperationIr::GreaterEqual(desc) => { - binary_int_cmp_ops!(handles, desc, B::int_greater_equal) + binary_int_cmp_ops!(handles, desc, B::int_greater_equal); } NumericOperationIr::GreaterEqualElem(desc) => { - scalar_int_cmp_ops!(handles, desc, B::int_greater_equal_elem) + scalar_int_cmp_ops!(handles, desc, B::int_greater_equal_elem); } NumericOperationIr::Lower(desc) => { - binary_int_cmp_ops!(handles, desc, B::int_lower) + binary_int_cmp_ops!(handles, desc, B::int_lower); } NumericOperationIr::LowerElem(desc) => { - scalar_int_cmp_ops!(handles, desc, B::int_lower_elem) + scalar_int_cmp_ops!(handles, desc, B::int_lower_elem); } NumericOperationIr::LowerEqual(desc) => { - binary_int_cmp_ops!(handles, desc, B::int_lower_equal) + binary_int_cmp_ops!(handles, desc, B::int_lower_equal); } NumericOperationIr::LowerEqualElem(desc) => { - scalar_int_cmp_ops!(handles, desc, B::int_lower_equal_elem) + scalar_int_cmp_ops!(handles, desc, B::int_lower_equal_elem); } NumericOperationIr::ArgMax(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_argmax) + reduce_int_dim_ops!(handles, desc, B::int_argmax); } NumericOperationIr::ArgMin(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_argmin) + reduce_int_dim_ops!(handles, desc, B::int_argmin); } NumericOperationIr::Max(desc) => { - unary_int_ops!(handles, desc, B::int_max) + unary_int_ops!(handles, desc, B::int_max); } NumericOperationIr::MaxDimWithIndices(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); @@ -737,19 +737,19 @@ impl RunnerClient for Runner { handles.register_int_tensor::(&desc.out_indices.id, output_idx); } NumericOperationIr::Min(desc) => { - unary_int_ops!(handles, desc, B::int_min) + unary_int_ops!(handles, desc, B::int_min); } NumericOperationIr::MaxDim(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_max_dim) + reduce_int_dim_ops!(handles, desc, B::int_max_dim); } NumericOperationIr::MinDim(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_min_dim) + reduce_int_dim_ops!(handles, desc, B::int_min_dim); } NumericOperationIr::MaxAbs(desc) => { - unary_int_ops!(handles, desc, B::int_max_abs) + unary_int_ops!(handles, desc, B::int_max_abs); } NumericOperationIr::MaxAbsDim(desc) => { - reduce_int_dim_ops!(handles, desc, B::int_max_abs_dim) + reduce_int_dim_ops!(handles, desc, B::int_max_abs_dim); } NumericOperationIr::Clamp(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); @@ -791,10 +791,10 @@ impl 
RunnerClient for Runner { handles.register_bool_tensor::(&desc.out.id, output); } BoolOperationIr::And(desc) => { - binary_bool_ops!(handles, desc, B::bool_and) + binary_bool_ops!(handles, desc, B::bool_and); } BoolOperationIr::Or(desc) => { - binary_bool_ops!(handles, desc, B::bool_or) + binary_bool_ops!(handles, desc, B::bool_or); } }, OperationIr::Int(op) => match op { @@ -805,75 +805,75 @@ impl RunnerClient for Runner { handles.register_float_tensor::(&desc.out.id, output); } IntOperationIr::BitwiseAnd(desc) => { - binary_int_ops!(handles, desc, B::bitwise_and) + binary_int_ops!(handles, desc, B::bitwise_and); } IntOperationIr::BitwiseAndScalar(desc) => { - scalar_int_ops!(handles, desc, B::bitwise_and_scalar) + scalar_int_ops!(handles, desc, B::bitwise_and_scalar); } IntOperationIr::BitwiseOr(desc) => { - binary_int_ops!(handles, desc, B::bitwise_or) + binary_int_ops!(handles, desc, B::bitwise_or); } IntOperationIr::BitwiseOrScalar(desc) => { - scalar_int_ops!(handles, desc, B::bitwise_or_scalar) + scalar_int_ops!(handles, desc, B::bitwise_or_scalar); } IntOperationIr::BitwiseXor(desc) => { - binary_int_ops!(handles, desc, B::bitwise_xor) + binary_int_ops!(handles, desc, B::bitwise_xor); } IntOperationIr::BitwiseXorScalar(desc) => { - scalar_int_ops!(handles, desc, B::bitwise_xor_scalar) + scalar_int_ops!(handles, desc, B::bitwise_xor_scalar); } IntOperationIr::BitwiseNot(desc) => { - unary_int_ops!(handles, desc, B::bitwise_not) + unary_int_ops!(handles, desc, B::bitwise_not); } IntOperationIr::BitwiseLeftShift(desc) => { - binary_int_ops!(handles, desc, B::bitwise_left_shift) + binary_int_ops!(handles, desc, B::bitwise_left_shift); } IntOperationIr::BitwiseRightShift(desc) => { - binary_int_ops!(handles, desc, B::bitwise_right_shift) + binary_int_ops!(handles, desc, B::bitwise_right_shift); } IntOperationIr::BitwiseLeftShiftScalar(desc) => { - scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar) + scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar); } IntOperationIr::BitwiseRightShiftScalar(desc) => { - scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar) + scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar); } }, OperationIr::Float(_dtype, op) => match op { FloatOperationIr::Exp(desc) => { - unary_float_ops!(handles, desc, B::float_exp) + unary_float_ops!(handles, desc, B::float_exp); } FloatOperationIr::Log(desc) => { - unary_float_ops!(handles, desc, B::float_log) + unary_float_ops!(handles, desc, B::float_log); } FloatOperationIr::Log1p(desc) => { - unary_float_ops!(handles, desc, B::float_log1p) + unary_float_ops!(handles, desc, B::float_log1p); } FloatOperationIr::Erf(desc) => { - unary_float_ops!(handles, desc, B::float_erf) + unary_float_ops!(handles, desc, B::float_erf); } FloatOperationIr::PowfScalar(desc) => { - scalar_float_ops!(handles, desc, B::float_powf_scalar) + scalar_float_ops!(handles, desc, B::float_powf_scalar); } FloatOperationIr::Sqrt(desc) => { - unary_float_ops!(handles, desc, B::float_sqrt) + unary_float_ops!(handles, desc, B::float_sqrt); } FloatOperationIr::Cos(desc) => { - unary_float_ops!(handles, desc, B::float_cos) + unary_float_ops!(handles, desc, B::float_cos); } FloatOperationIr::Sin(desc) => { - unary_float_ops!(handles, desc, B::float_sin) + unary_float_ops!(handles, desc, B::float_sin); } FloatOperationIr::Tanh(desc) => { - unary_float_ops!(handles, desc, B::float_tanh) + unary_float_ops!(handles, desc, B::float_tanh); } FloatOperationIr::Round(desc) => { - unary_float_ops!(handles, desc, B::float_round) + 
unary_float_ops!(handles, desc, B::float_round); } FloatOperationIr::Floor(desc) => { - unary_float_ops!(handles, desc, B::float_floor) + unary_float_ops!(handles, desc, B::float_floor); } FloatOperationIr::Ceil(desc) => { - unary_float_ops!(handles, desc, B::float_ceil) + unary_float_ops!(handles, desc, B::float_ceil); } FloatOperationIr::IntoInt(desc) => { let tensor = handles.get_float_tensor::(&desc.input); @@ -882,7 +882,7 @@ impl RunnerClient for Runner { handles.register_int_tensor::(&desc.out.id, output); } FloatOperationIr::Matmul(desc) => { - binary_float_ops!(handles, desc, B::float_matmul) + binary_float_ops!(handles, desc, B::float_matmul); } FloatOperationIr::Random(desc) => { let shape = Shape::from(desc.out.shape.clone()); @@ -891,7 +891,7 @@ impl RunnerClient for Runner { handles.register_float_tensor::(&desc.out.id, output); } FloatOperationIr::Recip(desc) => { - unary_float_ops!(handles, desc, B::float_recip) + unary_float_ops!(handles, desc, B::float_recip); } FloatOperationIr::Quantize(_) => todo!(), FloatOperationIr::Dequantize(_) => todo!(), @@ -1276,7 +1276,7 @@ impl RunnerClient for Runner { } fn register_orphan(&self, id: &TensorId) { - self.context.lock().unwrap().drop_tensor_handle(*id) + self.context.lock().unwrap().drop_tensor_handle(*id); } fn sync(&self) { @@ -1285,6 +1285,6 @@ impl RunnerClient for Runner { } fn seed(&self, seed: u64) { - B::seed(seed) + B::seed(seed); } } diff --git a/crates/burn-tch/Cargo.toml b/crates/burn-tch/Cargo.toml index c3e140879e..25bfc50164 100644 --- a/crates/burn-tch/Cargo.toml +++ b/crates/burn-tch/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tch" version.workspace = true +[lints] +workspace = true + [features] default = ["std"] std = [] diff --git a/crates/burn-tch/build.rs b/crates/burn-tch/build.rs index b7293bad26..b6829af9ff 100644 --- a/crates/burn-tch/build.rs +++ b/crates/burn-tch/build.rs @@ -80,10 +80,10 @@ impl SystemInfo { _ => {} } if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { - libtorch_include_dirs.push(PathBuf::from(path)) + libtorch_include_dirs.push(PathBuf::from(path)); } if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { - libtorch_lib_dir = Some(PathBuf::from(path)) + libtorch_lib_dir = Some(PathBuf::from(path)); } } match cxx11_abi { @@ -92,12 +92,10 @@ impl SystemInfo { } } else { let libtorch = Self::prepare_libtorch_dir(os)?; - let includes = env_var_rerun("LIBTORCH_INCLUDE") - .map(PathBuf::from) - .unwrap_or_else(|_| libtorch.clone()); - let lib = env_var_rerun("LIBTORCH_LIB") - .map(PathBuf::from) - .unwrap_or_else(|_| libtorch.clone()); + let includes = + env_var_rerun("LIBTORCH_INCLUDE").map_or_else(|_| libtorch.clone(), PathBuf::from); + let lib = + env_var_rerun("LIBTORCH_LIB").map_or_else(|_| libtorch.clone(), PathBuf::from); libtorch_include_dirs.push(includes.join("include")); libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include")); if lib.ends_with("lib") { @@ -144,7 +142,7 @@ impl SystemInfo { } else { "src/cuda_hack/fake_cuda_dependency.cpp" }; - println!("cargo:rerun-if-changed={}", cuda_dependency); + println!("cargo:rerun-if-changed={cuda_dependency}"); match self.os { Os::Linux | Os::Macos => { @@ -169,12 +167,12 @@ impl SystemInfo { .files(&[cuda_dependency]) .compile("burn-tch"); } - }; + } } fn make_cpu() { let cuda_dependency = "src/cuda_hack/fake_cuda_dependency.cpp"; - println!("cargo:rerun-if-changed={}", cuda_dependency); + 
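Note: the bulk of the `runner.rs` changes adds a trailing semicolon to macro invocations used as statements inside match arms, so each arm is an explicit statement rather than a block tail expression. A minimal sketch with a hypothetical logging macro:

```rust
// Sketch: terminate side-effect macro calls in match arms with a semicolon.
macro_rules! log_op {
    ($name:expr) => {
        println!("running {}", $name)
    };
}

fn dispatch(op: u8) {
    match op {
        0 => {
            log_op!("add"); // before: `log_op!("add")` with no trailing semicolon
        }
        _ => {
            log_op!("other");
        }
    }
}

fn main() {
    dispatch(0);
}
```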
println!("cargo:rerun-if-changed={cuda_dependency}"); let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); @@ -197,7 +195,7 @@ impl SystemInfo { .files(&[cuda_dependency]) .compile("tch"); } - }; + } } } @@ -229,14 +227,14 @@ fn main() { if gpu_found { fs::write(check_file, "#[allow(clippy::no_effect)]\n()").unwrap(); } else { - let message = if !found_dir { - r#"Could not find libtorch dir. + let message = if found_dir { + "No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device" + } else { + r"Could not find libtorch dir. If you are trying to use the automatically downloaded version, the path is not directly available on Windows. Instead, try setting the `LIBTORCH` environment variable for the manual download instructions. - If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it)."# - } else { - "No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device" + If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it)." }; fs::write(check_file, format!("panic!(\"{message}\")")).unwrap(); } diff --git a/crates/burn-tch/src/ops/activation.rs b/crates/burn-tch/src/ops/activation.rs index 4293882831..b3b486f3fa 100644 --- a/crates/burn-tch/src/ops/activation.rs +++ b/crates/burn-tch/src/ops/activation.rs @@ -3,7 +3,7 @@ use burn_tensor::ops::ActivationOps; impl ActivationOps for LibTorch { fn relu(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()) + tensor.unary_ops(|mut tensor| tensor.relu_(), tch::Tensor::relu) } fn gelu(tensor: TchTensor) -> TchTensor { @@ -21,7 +21,7 @@ impl ActivationOps for LibTorch { } fn sigmoid(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid()) + tensor.unary_ops(|mut tensor| tensor.sigmoid_(), tch::Tensor::sigmoid) } fn log_sigmoid(tensor: TchTensor) -> TchTensor { diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 82a7605e94..4bdcbb68d8 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -51,7 +51,7 @@ impl TchOps { let tch_shape = TchShape::from(tensor.shape()); // Copy the input tensor if we can't mutate it. 
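The activation-op hunks above replace closures that only forward to a method with the method path itself, for example `tch::Tensor::relu` instead of `|tensor| tensor.relu()`. The same rewrite on plain strings, purely for illustration:

```rust
fn main() {
    let raw = vec!["  relu ", " gelu", "sigmoid  "];

    // Before: a closure that only forwards to a method call.
    let trimmed_closure: Vec<&str> = raw.iter().copied().map(|s| s.trim()).collect();

    // After: pass the method path directly; behaviour is identical.
    let trimmed_path: Vec<&str> = raw.iter().copied().map(str::trim).collect();

    assert_eq!(trimmed_closure, trimmed_path);
}
```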
- let tensor_original: TchTensor = tensor.unary_ops(|tensor| tensor, |tensor| tensor.copy()); + let tensor_original: TchTensor = tensor.unary_ops(|tensor| tensor, tch::Tensor::copy); let tensor_original = tensor_original.tensor; let mut tensor = tensor_original.view_(tch_shape.dims); @@ -124,7 +124,7 @@ impl TchOps { rhs, |lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.eq_tensor(rhs), + tch::Tensor::eq_tensor, ) } @@ -141,7 +141,7 @@ impl TchOps { rhs, |lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.greater_tensor(rhs), + tch::Tensor::greater_tensor, ) } @@ -158,7 +158,7 @@ impl TchOps { rhs, |lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.greater_equal_tensor(rhs), + tch::Tensor::greater_equal_tensor, ) } @@ -179,7 +179,7 @@ impl TchOps { rhs, |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.less_tensor(rhs), + tch::Tensor::less_tensor, ) } @@ -196,7 +196,7 @@ impl TchOps { rhs, |lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool), - |lhs, rhs| lhs.less_equal_tensor(rhs), + tch::Tensor::less_equal_tensor, ) } @@ -410,7 +410,7 @@ impl TchOps { } pub fn sign(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign()) + tensor.unary_ops(|mut tensor| tensor.sign_(), tch::Tensor::sign) } pub fn expand(tensor: TchTensor, shape: Shape) -> TchTensor { diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index 0182f65253..3a20a24cb3 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ b/crates/burn-tch/src/ops/bool_tensor.rs @@ -74,9 +74,9 @@ impl BoolTensorOps for LibTorch { TchTensor::binary_ops_tensor( lhs, rhs, - |lhs, rhs| lhs.logical_and_(rhs), + tch::Tensor::logical_and_, |lhs, rhs| rhs.logical_and_(lhs), - |lhs, rhs| lhs.logical_and(rhs), + tch::Tensor::logical_and, ) } @@ -84,9 +84,9 @@ impl BoolTensorOps for LibTorch { TchTensor::binary_ops_tensor( lhs, rhs, - |lhs, rhs| lhs.logical_or_(rhs), + tch::Tensor::logical_or_, |lhs, rhs| rhs.logical_or_(lhs), - |lhs, rhs| lhs.logical_or(rhs), + tch::Tensor::logical_or, ) } diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 303fbcceb1..591d3b0524 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -328,7 +328,7 @@ impl IntTensorOps for LibTorch { } fn int_abs(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) + tensor.unary_ops(|mut tensor| tensor.abs_(), tch::Tensor::abs) } fn int_into_float(tensor: TchTensor) -> TchTensor { diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index 0ab85c94ef..61a9f8e227 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -332,58 +332,58 @@ impl FloatTensorOps for LibTorch { } fn float_exp(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp()) + tensor.unary_ops(|mut tensor| tensor.exp_(), tch::Tensor::exp) } fn float_log(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log()) + 
tensor.unary_ops(|mut tensor| tensor.log_(), tch::Tensor::log) } fn float_log1p(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p()) + tensor.unary_ops(|mut tensor| tensor.log1p_(), tch::Tensor::log1p) } fn float_powf_scalar(tensor: TchTensor, value: f32) -> TchTensor { tensor.unary_ops( - |mut tensor| tensor.f_pow_(value as f64).unwrap(), - |tensor| tensor.pow_tensor_scalar(value as f64), + |mut tensor| tensor.f_pow_(f64::from(value)).unwrap(), + |tensor| tensor.pow_tensor_scalar(f64::from(value)), ) } fn float_sqrt(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt()) + tensor.unary_ops(|mut tensor| tensor.sqrt_(), tch::Tensor::sqrt) } fn float_abs(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) + tensor.unary_ops(|mut tensor| tensor.abs_(), tch::Tensor::abs) } fn float_cos(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos()) + tensor.unary_ops(|mut tensor| tensor.cos_(), tch::Tensor::cos) } fn float_sin(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin()) + tensor.unary_ops(|mut tensor| tensor.sin_(), tch::Tensor::sin) } fn float_tanh(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh()) + tensor.unary_ops(|mut tensor| tensor.tanh_(), tch::Tensor::tanh) } fn float_round(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round()) + tensor.unary_ops(|mut tensor| tensor.round_(), tch::Tensor::round) } fn float_floor(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor()) + tensor.unary_ops(|mut tensor| tensor.floor_(), tch::Tensor::floor) } fn float_ceil(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil()) + tensor.unary_ops(|mut tensor| tensor.ceil_(), tch::Tensor::ceil) } fn float_erf(tensor: TchTensor) -> TchTensor { - tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf()) + tensor.unary_ops(|mut tensor| tensor.erf_(), tch::Tensor::erf) } fn float_cat(tensors: Vec, dim: usize) -> TchTensor { diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index 7cfa359d1e..6cfe3f6f71 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -34,6 +34,7 @@ pub enum Storage { impl Storage { /// Check if the storage can be used inplace. + #[must_use] pub fn can_mut(&self) -> bool { match self { Storage::View { @@ -47,6 +48,7 @@ impl Storage { } /// Get the whole buffer reference. + #[must_use] pub fn buffer_ref(&self) -> &StorageRef { match self { Storage::View { @@ -98,8 +100,9 @@ impl TchTensor { /// Create a new tensor. /// /// Note that if the tensor was created from an operation that may reuse the same tensor - /// storage as the parent, you should use [from_existing](TchTensor::from_existing) + /// storage as the parent, you should use [`from_existing`](TchTensor::from_existing) /// instead. + #[must_use] pub fn new(tensor: tch::Tensor) -> Self { #[allow(clippy::arc_with_non_send_sync)] let storage = Storage::Owned { @@ -113,6 +116,7 @@ impl TchTensor { /// /// If the child tensor shared the same storage as its parent, it will be cloned, effectively /// tracking how much tensors point to the same memory space. 
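The `float_powf_scalar` change above swaps `value as f64` for `f64::from(value)`; since `f64` implements `From<f32>`, the widening conversion can be spelled without `as`. A tiny self-contained check:

```rust
fn main() {
    let exponent: f32 = 2.5;

    let widened_as = exponent as f64;        // cast syntax, flagged by clippy::cast_lossless
    let widened_from = f64::from(exponent);  // explicit, infallible conversion

    assert_eq!(widened_as, widened_from);
    println!("2.0^{} = {}", widened_from, 2.0_f64.powf(widened_from));
}
```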
+ #[must_use] pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self { let storage_child = tensor.data_ptr(); let mut is_a_new_tensor = true; @@ -133,20 +137,22 @@ impl TchTensor { is_a_new_tensor = false; } } - }; + } - let storage = match is_a_new_tensor { - true => Storage::Owned { + let storage = if is_a_new_tensor { + Storage::Owned { #[allow(clippy::arc_with_non_send_sync)] buffer_ref: Arc::new(storage_child), - }, - false => storage_parent.clone(), + } + } else { + storage_parent.clone() }; Self { tensor, storage } } /// Create a tensor that uses a part of its parent tensor such as slice and narrow. + #[must_use] pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self { let storage = Storage::View { buffer_ref: storage_parent.buffer_ref().clone(), @@ -168,6 +174,7 @@ impl TchTensor { /// /// Returns `true` if the tensor's stride does not contain zero (no broadcasting) /// and the storage can be mutated. + #[must_use] pub fn can_mut(&self) -> bool { let stride_contains_zero = self.tensor.stride().contains(&0); @@ -256,7 +263,7 @@ impl Clone for TchTensor { } } -/// A shape that can be used by LibTorch. +/// A shape that can be used by `LibTorch`. #[derive(Debug)] pub struct TchShape { /// The shape's dimensions. @@ -290,6 +297,7 @@ impl TchTensor { /// # Returns /// /// A new tensor. + #[must_use] pub fn from_data(data: TensorData, device: tch::Device) -> Self { let shape_tch = TchShape::from(data.shape.as_slice()); let tensor = tch::Tensor::from_slice(data.as_slice::().unwrap()).to(device); @@ -310,6 +318,7 @@ impl TchTensor { /// # Returns /// /// A new empty tensor. + #[must_use] pub fn empty(shape: Shape, device: LibTorchDevice) -> Self { let shape_tch = TchShape::from(shape); let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into())); @@ -329,6 +338,7 @@ pub struct TchQTensor { impl TchQTensor { /// Returns the quantization strategy, including quantization parameters, for the given tensor. + #[must_use] pub fn strategy(&self) -> QuantizationStrategy { match &self.scheme { QuantScheme { diff --git a/crates/burn-tensor-testgen/Cargo.toml b/crates/burn-tensor-testgen/Cargo.toml index 189bab58fd..f784ecfb0c 100644 --- a/crates/burn-tensor-testgen/Cargo.toml +++ b/crates/burn-tensor-testgen/Cargo.toml @@ -8,6 +8,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor-testgen" version.workspace = true +[lints] +workspace = true + [lib] proc-macro = true diff --git a/crates/burn-tensor-testgen/src/lib.rs b/crates/burn-tensor-testgen/src/lib.rs index 2d8fd4c17a..e4534fefe3 100644 --- a/crates/burn-tensor-testgen/src/lib.rs +++ b/crates/burn-tensor-testgen/src/lib.rs @@ -49,7 +49,7 @@ pub fn might_panic(args: TokenStream, input: TokenStream) -> TokenStream { // Extract the expected panic reason let mut expected_reason = None; - for arg in args.args.iter() { + for arg in &args.args { if let Meta::NameValue(MetaNameValue { path, value, .. 
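The `#[must_use]` attributes added above mark constructors and builder methods whose return value should never be silently dropped. A minimal sketch of the effect, with a hypothetical `Shape` type rather than the burn-tch one:

```rust
#[derive(Debug, PartialEq)]
struct Shape {
    dims: Vec<usize>,
}

impl Shape {
    // With `#[must_use]`, calling `Shape::new(..)` and discarding the result
    // produces a compiler warning instead of passing silently.
    #[must_use]
    fn new(dims: Vec<usize>) -> Self {
        Self { dims }
    }
}

fn main() {
    // Shape::new(vec![2, 3]);   // would warn: unused return value of `Shape::new`
    let shape = Shape::new(vec![2, 3]);
    assert_eq!(shape, Shape { dims: vec![2, 3] });
}
```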
}) = arg { if path.is_ident("reason") { if let Expr::Lit(lit) = value { diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index bca7aba1d7..04f5327282 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor" version.workspace = true +[lints] +workspace = true + [features] cubecl = ["dep:cubecl"] cubecl-cuda = ["cubecl", "cubecl/cuda"] diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index d46c0091ea..175912aeee 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-train" documentation = "https://docs.rs/burn-train" version.workspace = true +[lints] +workspace = true + [features] default = ["sys-metrics", "tui"] doc = ["default"] diff --git a/crates/burn-train/src/checkpoint/async_checkpoint.rs b/crates/burn-train/src/checkpoint/async_checkpoint.rs index dbd17a29eb..6fb1f2a133 100644 --- a/crates/burn-train/src/checkpoint/async_checkpoint.rs +++ b/crates/burn-train/src/checkpoint/async_checkpoint.rs @@ -26,7 +26,7 @@ where B: Backend, { fn run(self) { - for item in self.receiver.iter() { + for item in &self.receiver { match item { Message::Restore(epoch, device, callback) => { let record = self.checkpointer.restore(epoch, &device); @@ -45,7 +45,7 @@ where Message::End => { return; } - }; + } } } } @@ -104,7 +104,7 @@ where if let Ok(record) = receiver.recv() { return record; - }; + } Err(CheckpointerError::Unknown("Channel error.".to_string())) } diff --git a/crates/burn-train/src/checkpoint/file.rs b/crates/burn-train/src/checkpoint/file.rs index be3caa9e21..acd732727d 100644 --- a/crates/burn-train/src/checkpoint/file.rs +++ b/crates/burn-train/src/checkpoint/file.rs @@ -77,7 +77,7 @@ where ); if std::path::Path::new(&file_to_remove).exists() { - log::info!("Removing checkpoint {}", file_to_remove); + log::info!("Removing checkpoint {file_to_remove}"); std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?; } diff --git a/crates/burn-train/src/checkpoint/strategy/composed.rs b/crates/burn-train/src/checkpoint/strategy/composed.rs index 8029c9ed78..e1fe323dfa 100644 --- a/crates/burn-train/src/checkpoint/strategy/composed.rs +++ b/crates/burn-train/src/checkpoint/strategy/composed.rs @@ -28,6 +28,7 @@ impl ComposedCheckpointingStrategyBuilder { } /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy). + #[must_use] pub fn build(self) -> ComposedCheckpointingStrategy { ComposedCheckpointingStrategy::new(self.strategies) } @@ -42,6 +43,7 @@ impl ComposedCheckpointingStrategy { } /// Create a new builder which help compose multiple /// [checkpointing strategies](CheckpointingStrategy). 
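The loop rewrites above (`for arg in &args.args`, `for item in &self.receiver`) borrow the collection directly instead of calling `.iter()`; both forms go through the same `IntoIterator` implementation. Illustrated with a plain `Vec`:

```rust
fn main() {
    let names = vec!["accuracy", "loss", "recall"];

    // Explicit form (what the old code did).
    let mut lengths_explicit = Vec::new();
    for name in names.iter() {
        lengths_explicit.push(name.len());
    }

    // Borrow form (what the patch uses).
    let mut lengths_borrowed = Vec::new();
    for name in &names {
        lengths_borrowed.push(name.len());
    }

    assert_eq!(lengths_explicit, lengths_borrowed);
}
```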
+ #[must_use] pub fn builder() -> ComposedCheckpointingStrategyBuilder { ComposedCheckpointingStrategyBuilder::default() } @@ -86,7 +88,7 @@ impl CheckpointingStrategy for ComposedCheckpointingStrategy { actions.push(CheckpointingAction::Save); } - for epoch in epochs_to_check.into_iter() { + for epoch in epochs_to_check { let mut num_true = 0; for i in 0..self.strategies.len() { if self diff --git a/crates/burn-train/src/learner/application_logger.rs b/crates/burn-train/src/learner/application_logger.rs index b2772998d7..6d13dc1e5d 100644 --- a/crates/burn-train/src/learner/application_logger.rs +++ b/crates/burn-train/src/learner/application_logger.rs @@ -52,10 +52,10 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller { } let hook = std::panic::take_hook(); - let file_path = self.path.to_owned(); + let file_path = self.path.clone(); std::panic::set_hook(Box::new(move |info| { - log::error!("PANIC => {}", info); + log::error!("PANIC => {info}"); eprintln!( "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \ '{}'\n=============", diff --git a/crates/burn-train/src/learner/base.rs b/crates/burn-train/src/learner/base.rs index ee9b8ce58a..9b74520133 100644 --- a/crates/burn-train/src/learner/base.rs +++ b/crates/burn-train/src/learner/base.rs @@ -115,6 +115,7 @@ pub struct TrainingInterrupter { impl TrainingInterrupter { /// Create a new instance. + #[must_use] pub fn new() -> Self { Self::default() } @@ -124,7 +125,8 @@ impl TrainingInterrupter { self.state.store(true, Ordering::Relaxed); } - /// True if .stop() has been called. + /// True if .`stop()` has been called. + #[must_use] pub fn should_stop(&self) -> bool { self.state.load(Ordering::Relaxed) } diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs index eca758d68f..32205e8a1d 100644 --- a/crates/burn-train/src/learner/builder.rs +++ b/crates/burn-train/src/learner/builder.rs @@ -127,7 +127,7 @@ where self } - /// Update the checkpointing_strategy. + /// Update the `checkpointing_strategy`. pub fn with_checkpointing_strategy(mut self, strategy: CS) -> Self where CS: CheckpointingStrategy + 'static, @@ -177,6 +177,7 @@ where /// /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation` /// amount. + #[must_use] pub fn grads_accumulation(mut self, accumulation: usize) -> Self { self.grad_accumulation = Some(accumulation); self @@ -207,24 +208,28 @@ where } /// The number of epochs the training should last. + #[must_use] pub fn num_epochs(mut self, num_epochs: usize) -> Self { self.num_epochs = num_epochs; self } /// Run the training loop on multiple devices. + #[must_use] pub fn devices(mut self, devices: Vec) -> Self { self.devices = devices; self } /// The epoch from which the training must resume. + #[must_use] pub fn checkpoint(mut self, checkpoint: usize) -> Self { self.checkpoint = Some(checkpoint); self } /// Provides a handle that can be used to interrupt training. + #[must_use] pub fn interrupter(&self) -> TrainingInterrupter { self.interrupter.clone() } @@ -242,6 +247,7 @@ where /// By default, Rust logs are captured and written into /// `experiment.log`. If disabled, standard Rust log handling /// will apply. + #[must_use] pub fn with_application_logger( mut self, logger: Option>, @@ -279,6 +285,7 @@ where /// Enable the training summary report. /// /// The summary will be displayed at the end of `.fit()`. 
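The logging changes above inline variables into the format string ("PANIC => {info}" instead of "PANIC => {}", info), which is the pattern applied throughout this patch. A small equivalence check:

```rust
fn main() {
    let iteration = 42;
    let loss = 0.0317_f64;

    // Before: positional arguments listed after the format string.
    let old_style = format!("iteration {} finished with loss {:.4}", iteration, loss);

    // After: variables captured directly inside the braces.
    let new_style = format!("iteration {iteration} finished with loss {loss:.4}");

    assert_eq!(old_style, new_style);
}
```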
+ #[must_use] pub fn summary(mut self) -> Self { self.summary = true; self @@ -314,7 +321,7 @@ where { if self.tracing_logger.is_some() { if let Err(e) = self.tracing_logger.as_ref().unwrap().install() { - log::warn!("Failed to install the experiment logger: {}", e); + log::warn!("Failed to install the experiment logger: {e}"); } } let renderer = self diff --git a/crates/burn-train/src/learner/early_stopping.rs b/crates/burn-train/src/learner/early_stopping.rs index 6562206a53..e7280d0960 100644 --- a/crates/burn-train/src/learner/early_stopping.rs +++ b/crates/burn-train/src/learner/early_stopping.rs @@ -32,14 +32,14 @@ pub struct MetricEarlyStoppingStrategy { impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy { fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool { - let current_value = - match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) { - Some(value) => value, - None => { - log::warn!("Can't find metric for early stopping."); - return false; - } - }; + let current_value = if let Some(value) = + store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) + { + value + } else { + log::warn!("Can't find metric for early stopping."); + return false; + }; let is_best = match self.direction { Direction::Lowest => current_value < self.best_value, @@ -205,7 +205,7 @@ mod tests { let mut epoch = 1; for (points, should_start, comment) in data { - for point in points.iter() { + for point in *points { process_train(&mut processor, *point, epoch); } end_epoch(&mut processor, epoch); diff --git a/crates/burn-train/src/learner/epoch.rs b/crates/burn-train/src/learner/epoch.rs index 831e5fc75d..2e8ca21e95 100644 --- a/crates/burn-train/src/learner/epoch.rs +++ b/crates/burn-train/src/learner/epoch.rs @@ -111,7 +111,7 @@ impl TrainEpoch { while let Some(item) = iterator.next() { iteration += 1; let lr = scheduler.step(); - log::info!("Iteration {}", iteration); + log::info!("Iteration {iteration}"); let progress = iterator.progress(); let item = model.step(item); diff --git a/crates/burn-train/src/learner/step/train.rs b/crates/burn-train/src/learner/step/train.rs index c6f1ba09bc..a22d5248e4 100644 --- a/crates/burn-train/src/learner/step/train.rs +++ b/crates/burn-train/src/learner/step/train.rs @@ -56,7 +56,7 @@ where sender_output.send(output).unwrap(); } Err(_err) => { - log::info!("Closing thread on device {:?}", device); + log::info!("Closing thread on device {device:?}"); break; } } @@ -80,7 +80,7 @@ where /// /// # Returns /// - /// MultiDevicesTrainStep instance. + /// `MultiDevicesTrainStep` instance. 
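The early-stopping hunk above replaces a `match` on an `Option` with an `if let ... else` expression whose `else` branch logs and returns early. A minimal analogue; `find_metric` here is a hypothetical lookup, not the event-store method:

```rust
// Hypothetical metric lookup used only for this sketch.
fn find_metric(name: &str) -> Option<f64> {
    if name == "loss" { Some(0.25) } else { None }
}

fn should_stop(name: &str, best: f64) -> bool {
    // The `else` branch diverges (early return), so the binding stays simple.
    let current = if let Some(value) = find_metric(name) {
        value
    } else {
        eprintln!("Can't find metric for early stopping.");
        return false;
    };

    current < best
}

fn main() {
    assert!(should_stop("loss", 0.5));
    assert!(!should_stop("accuracy", 0.5));
}
```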
pub fn new(devices: &[B::Device]) -> Self where TI: Send + 'static, diff --git a/crates/burn-train/src/learner/summary.rs b/crates/burn-train/src/learner/summary.rs index 723daa632a..cdc8042a4b 100644 --- a/crates/burn-train/src/learner/summary.rs +++ b/crates/burn-train/src/learner/summary.rs @@ -141,10 +141,10 @@ impl Display for LearnerSummary { let split_valid = "Valid"; let max_split_len = "Split".len().max(split_train.len()).max(split_valid.len()); let mut max_metric_len = "Metric".len(); - for metric in self.metrics.train.iter() { + for metric in &self.metrics.train { max_metric_len = max_metric_len.max(metric.name.len()); } - for metric in self.metrics.valid.iter() { + for metric in &self.metrics.valid { max_metric_len = max_metric_len.max(metric.name.len()); } @@ -187,15 +187,15 @@ impl Display for LearnerSummary { fn fmt_val(val: f64) -> String { if val < 1e-2 { // Use scientific notation for small values which would otherwise be truncated - format!("{:<9.3e}", val) + format!("{val:<9.3e}") } else { - format!("{:<9.3}", val) + format!("{val:<9.3}") } } let mut write_metrics_summary = |metrics: &[MetricSummary], split: &str| -> std::fmt::Result { - for metric in metrics.iter() { + for metric in metrics { if metric.entries.is_empty() { continue; // skip metrics with no recorded values } diff --git a/crates/burn-train/src/learner/train_val.rs b/crates/burn-train/src/learner/train_val.rs index d29a2690b3..9f98e19dae 100644 --- a/crates/burn-train/src/learner/train_val.rs +++ b/crates/burn-train/src/learner/train_val.rs @@ -50,7 +50,7 @@ impl TrainOutput { /// # Notes /// /// To be used with the [Learner](Learner) struct, the struct which implements this trait must -/// also implement the [AutodiffModule] trait, which is done automatically with the +/// also implement the [`AutodiffModule`] trait, which is done automatically with the /// [Module](burn_core::module::Module) derive. pub trait TrainStep { /// Runs the training step, which executes the forward and backward passes. 
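The `fmt_val` helper above now names the variable inside the width/precision specifier (`{val:<9.3e}`). A standalone copy with a quick check of both branches; the expected strings assume Rust's default `LowerExp`/`Display` float formatting:

```rust
fn fmt_val(val: f64) -> String {
    if val < 1e-2 {
        // Scientific notation, left-aligned in a 9-character field, 3 decimals.
        format!("{val:<9.3e}")
    } else {
        format!("{val:<9.3}")
    }
}

fn main() {
    assert_eq!(fmt_val(0.0001234), "1.234e-4 ");
    assert_eq!(fmt_val(12.3456), "12.346   ");
}
```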
@@ -160,7 +160,7 @@ impl Learner { self.grad_accumulation, ); - for epoch in starting_epoch..self.num_epochs + 1 { + for epoch in starting_epoch..=self.num_epochs { if self.devices.len() > 1 { (self.model, self.optim) = epoch_train.run_multi_device::( self.model, @@ -169,7 +169,7 @@ impl Learner { &mut self.event_processor, self.devices.clone(), &self.interrupter, - ) + ); } else { (self.model, self.optim) = epoch_train.run::( self.model, @@ -216,7 +216,7 @@ impl Learner { if let Some(summary) = self.summary { match summary.init() { Ok(summary) => { - println!("{}", summary.with_model(self.model.to_string())) + println!("{}", summary.with_model(self.model.to_string())); } Err(err) => log::error!("Could not retrieve learner summary:\n{err}"), } diff --git a/crates/burn-train/src/logger/async_logger.rs b/crates/burn-train/src/logger/async_logger.rs index c659098b1e..fa6128b447 100644 --- a/crates/burn-train/src/logger/async_logger.rs +++ b/crates/burn-train/src/logger/async_logger.rs @@ -23,7 +23,7 @@ where L: Logger, { fn run(mut self) { - for item in self.receiver.iter() { + for item in &self.receiver { match item { Message::Log(item) => { self.logger.log(item); diff --git a/crates/burn-train/src/logger/metric.rs b/crates/burn-train/src/logger/metric.rs index 9424e90c16..59ea4edd4f 100644 --- a/crates/burn-train/src/logger/metric.rs +++ b/crates/burn-train/src/logger/metric.rs @@ -81,7 +81,7 @@ impl FileMetricLogger { } fn epoch_directory(&self, epoch: usize) -> PathBuf { - let name = format!("{}{}", EPOCH_PREFIX, epoch); + let name = format!("{EPOCH_PREFIX}{epoch}"); self.directory.join(name) } @@ -103,20 +103,19 @@ impl MetricLogger for FileMetricLogger { let key = &item.name; let value = &item.serialize; - let logger = match self.loggers.get_mut(key) { - Some(val) => val, - None => { - self.create_directory(self.epoch); + let logger = if let Some(val) = self.loggers.get_mut(key) { + val + } else { + self.create_directory(self.epoch); - let file_path = self.file_path(key, self.epoch); - let logger = FileLogger::new(file_path); - let logger = AsyncLogger::new(logger); + let file_path = self.file_path(key, self.epoch); + let logger = FileLogger::new(file_path); + let logger = AsyncLogger::new(logger); - self.loggers.insert(key.clone(), logger); - self.loggers - .get_mut(key) - .expect("Can get the previously saved logger.") - } + self.loggers.insert(key.clone(), logger); + self.loggers + .get_mut(key) + .expect("Can get the previously saved logger.") }; logger.log(value.clone()); @@ -129,7 +128,7 @@ impl MetricLogger for FileMetricLogger { fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { if let Some(value) = self.loggers.get(name) { - value.sync() + value.sync(); } let file_path = self.file_path(name, epoch); @@ -171,6 +170,7 @@ pub struct InMemoryMetricLogger { impl InMemoryMetricLogger { /// Create a new in-memory metric logger. + #[must_use] pub fn new() -> Self { Self::default() } @@ -188,7 +188,7 @@ impl MetricLogger for InMemoryMetricLogger { } fn end_epoch(&mut self, _epoch: usize) { - for (_, values) in self.values.iter_mut() { + for values in self.values.values_mut() { values.push(InMemoryLogger::default()); } } diff --git a/crates/burn-train/src/metric/acc.rs b/crates/burn-train/src/metric/acc.rs index 229a2b0351..81fa8e4d28 100644 --- a/crates/burn-train/src/metric/acc.rs +++ b/crates/burn-train/src/metric/acc.rs @@ -23,11 +23,13 @@ pub struct AccuracyInput { impl AccuracyMetric { /// Creates the metric. 
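The training loop above now iterates `starting_epoch..=self.num_epochs` instead of `starting_epoch..self.num_epochs + 1`. Both ranges cover the same epochs; the inclusive form simply states the upper bound directly:

```rust
fn main() {
    let starting_epoch = 1;
    let num_epochs = 4;

    // Exclusive range with a manual `+ 1` versus the inclusive range.
    let old: Vec<usize> = (starting_epoch..num_epochs + 1).collect();
    let new: Vec<usize> = (starting_epoch..=num_epochs).collect();

    assert_eq!(old, new);
    assert_eq!(new, vec![1, 2, 3, 4]);
}
```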
+ #[must_use] pub fn new() -> Self { Self::default() } /// Sets the pad token. + #[must_use] pub fn with_pad_token(mut self, index: usize) -> Self { self.pad_token = Some(index); self @@ -74,7 +76,7 @@ impl Metric for AccuracyMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { diff --git a/crates/burn-train/src/metric/auroc.rs b/crates/burn-train/src/metric/auroc.rs index c4d9c91a51..341dbf60a3 100644 --- a/crates/burn-train/src/metric/auroc.rs +++ b/crates/burn-train/src/metric/auroc.rs @@ -23,6 +23,7 @@ pub struct AurocInput { impl AurocMetric { /// Creates the metric. + #[must_use] pub fn new() -> Self { Self::default() } @@ -35,9 +36,9 @@ impl AurocMetric { // Early return if we don't have both positive and negative samples if n_pos == 0 || n_pos == n { if n_pos == 0 { - log::warn!("Metric cannot be computed because all target values are negative.") + log::warn!("Metric cannot be computed because all target values are negative."); } else { - log::warn!("Metric cannot be computed because all target values are positive.") + log::warn!("Metric cannot be computed because all target values are positive."); } return 0.0; } @@ -95,7 +96,7 @@ impl Metric for AurocMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { diff --git a/crates/burn-train/src/metric/base.rs b/crates/burn-train/src/metric/base.rs index 1ba82159c4..8f4c5d9075 100644 --- a/crates/burn-train/src/metric/base.rs +++ b/crates/burn-train/src/metric/base.rs @@ -21,6 +21,7 @@ pub struct MetricMetadata { impl MetricMetadata { /// Fake metric metadata #[cfg(test)] + #[must_use] pub fn fake() -> Self { Self { progress: Progress { @@ -132,11 +133,13 @@ impl NumericEntry { } /// Format a float with the given precision. Will use scientific notation if necessary. 
+#[must_use] pub fn format_float(float: f64, precision: usize) -> String { let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0); - match scientific_notation_threshold >= float { - true => format!("{float:.precision$e}"), - false => format!("{float:.precision$}"), + if scientific_notation_threshold >= float { + format!("{float:.precision$e}") + } else { + format!("{float:.precision$}") } } diff --git a/crates/burn-train/src/metric/cpu_temp.rs b/crates/burn-train/src/metric/cpu_temp.rs index e390b2e4f4..5191afb440 100644 --- a/crates/burn-train/src/metric/cpu_temp.rs +++ b/crates/burn-train/src/metric/cpu_temp.rs @@ -11,6 +11,7 @@ pub struct CpuTemperature { impl CpuTemperature { /// Creates a new CPU temp metric + #[must_use] pub fn new() -> Self { Self { temp_celsius: 0., @@ -34,9 +35,10 @@ impl Metric for CpuTemperature { Err(_) => self.temp_celsius = f32::NAN, } - let formatted = match self.temp_celsius.is_nan() { - true => format!("{}: NaN °C", self.name()), - false => format!("{}: {:.2} °C", self.name(), self.temp_celsius), + let formatted = if self.temp_celsius.is_nan() { + format!("{}: NaN °C", self.name()) + } else { + format!("{}: {:.2} °C", self.name(), self.temp_celsius) }; let raw = format!("{:.2}", self.temp_celsius); @@ -52,6 +54,6 @@ impl Metric for CpuTemperature { impl Numeric for CpuTemperature { fn value(&self) -> f64 { - self.temp_celsius as f64 + f64::from(self.temp_celsius) } } diff --git a/crates/burn-train/src/metric/cpu_use.rs b/crates/burn-train/src/metric/cpu_use.rs index 088143fd83..8cb8852006 100644 --- a/crates/burn-train/src/metric/cpu_use.rs +++ b/crates/burn-train/src/metric/cpu_use.rs @@ -13,6 +13,7 @@ pub struct CpuUse { impl CpuUse { /// Creates a new CPU metric + #[must_use] pub fn new() -> Self { let mut sys = System::new(); let current = Self::refresh(&mut sys); @@ -32,7 +33,7 @@ impl CpuUse { let cpus = sys.cpus(); let num_cpus = cpus.len(); - let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64; + let use_percentage = f64::from(cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage())); use_percentage / num_cpus as f64 } diff --git a/crates/burn-train/src/metric/fbetascore.rs b/crates/burn-train/src/metric/fbetascore.rs index b9869a4220..8a3fb4e262 100644 --- a/crates/burn-train/src/metric/fbetascore.rs +++ b/crates/burn-train/src/metric/fbetascore.rs @@ -31,6 +31,7 @@ impl FBetaScoreMetric { /// * `beta` - Positive real factor to weight recall's importance. /// * `threshold` - The threshold to transform a probability into a binary prediction. #[allow(dead_code)] + #[must_use] pub fn binary(beta: f64, threshold: f64) -> Self { Self { config: ClassificationMetricConfig { @@ -51,6 +52,7 @@ impl FBetaScoreMetric { /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] + #[must_use] pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { @@ -72,6 +74,7 @@ impl FBetaScoreMetric { /// * `threshold` - The threshold to transform a probability into a binary prediction. /// * `class_reduction` - [Class reduction](ClassReduction) type. 
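For reference, the rewritten `format_float` helper from the hunk above, shown standalone with a small usage check (the expected strings rely on Rust's default float formatting and are not part of the patch):

```rust
/// Format a float with the given precision, switching to scientific notation
/// once the value drops below 10^-(precision - 1).
#[must_use]
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    if scientific_notation_threshold >= float {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}

fn main() {
    // With precision 2 the threshold is 0.1, so smaller values switch notation.
    assert_eq!(format_float(0.05, 2), "5.00e-2");
    assert_eq!(format_float(0.5, 2), "0.50");
}
```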
#[allow(dead_code)] + #[must_use] pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { @@ -97,7 +100,7 @@ impl FBetaScoreMetric { let nan_mask = aggregated_metric.is_nan(); aggregated_metric = aggregated_metric .clone() - .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + .select(0, nan_mask.bool_not().argwhere().squeeze(1)); } aggregated_metric.mean() } @@ -129,7 +132,7 @@ impl Metric for FBetaScoreMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { @@ -169,7 +172,7 @@ mod tests { let mut metric = FBetaScoreMetric::binary(beta, threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[rstest] @@ -191,7 +194,7 @@ mod tests { let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[rstest] @@ -209,7 +212,7 @@ mod tests { let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[test] diff --git a/crates/burn-train/src/metric/hamming.rs b/crates/burn-train/src/metric/hamming.rs index 79b38a2faf..2a17269eee 100644 --- a/crates/burn-train/src/metric/hamming.rs +++ b/crates/burn-train/src/metric/hamming.rs @@ -22,17 +22,20 @@ pub struct HammingScoreInput { impl HammingScore { /// Creates the metric. + #[must_use] pub fn new() -> Self { Self::default() } /// Sets the threshold. + #[must_use] pub fn with_threshold(mut self, threshold: f32) -> Self { self.threshold = threshold; self } /// Sets the sigmoid activation function usage. + #[must_use] pub fn with_sigmoid(mut self, sigmoid: bool) -> Self { self.sigmoid = sigmoid; self @@ -81,7 +84,7 @@ impl Metric for HammingScore { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { diff --git a/crates/burn-train/src/metric/iteration.rs b/crates/burn-train/src/metric/iteration.rs index a07328b646..4bdb2c6f36 100644 --- a/crates/burn-train/src/metric/iteration.rs +++ b/crates/burn-train/src/metric/iteration.rs @@ -13,6 +13,7 @@ pub struct IterationSpeedMetric { impl IterationSpeedMetric { /// Create the metric. 
+ #[must_use] pub fn new() -> Self { Self::default() } @@ -21,13 +22,12 @@ impl IterationSpeedMetric { impl Metric for IterationSpeedMetric { type Input = (); - fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> MetricEntry { - let raw = match self.instant { - Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(), - None => { - self.instant = Some(std::time::Instant::now()); - 0.0 - } + fn update(&mut self, (): &Self::Input, metadata: &MetricMetadata) -> MetricEntry { + let raw = if let Some(val) = self.instant { + metadata.iteration as f64 / val.elapsed().as_secs_f64() + } else { + self.instant = Some(std::time::Instant::now()); + 0.0 }; self.state.update( diff --git a/crates/burn-train/src/metric/learning_rate.rs b/crates/burn-train/src/metric/learning_rate.rs index 4cb8c8647b..49bd84ee1d 100644 --- a/crates/burn-train/src/metric/learning_rate.rs +++ b/crates/burn-train/src/metric/learning_rate.rs @@ -11,6 +11,7 @@ pub struct LearningRateMetric { impl LearningRateMetric { /// Creates a new learning rate metric. + #[must_use] pub fn new() -> Self { Self { state: NumericMetricState::new(), @@ -35,7 +36,7 @@ impl Metric for LearningRateMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { diff --git a/crates/burn-train/src/metric/loss.rs b/crates/burn-train/src/metric/loss.rs index b0d04fd96d..e3eb8099a7 100644 --- a/crates/burn-train/src/metric/loss.rs +++ b/crates/burn-train/src/metric/loss.rs @@ -21,6 +21,7 @@ pub struct LossInput { impl LossMetric { /// Create the metric. + #[must_use] pub fn new() -> Self { Self::default() } @@ -48,7 +49,7 @@ impl Metric for LossMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { diff --git a/crates/burn-train/src/metric/memory_use.rs b/crates/burn-train/src/metric/memory_use.rs index 217e69c28f..6b33f47655 100644 --- a/crates/burn-train/src/metric/memory_use.rs +++ b/crates/burn-train/src/metric/memory_use.rs @@ -15,6 +15,7 @@ pub struct CpuMemory { impl CpuMemory { /// Creates a new memory metric + #[must_use] pub fn new() -> Self { let mut metric = Self { last_refresh: Instant::now(), diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 0aea5bab1a..1041ec2c86 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -26,6 +26,7 @@ impl PrecisionMetric { /// /// * `threshold` - The threshold to transform a probability into a binary prediction. #[allow(dead_code)] + #[must_use] pub fn binary(threshold: f64) -> Self { Self { config: ClassificationMetricConfig { @@ -44,6 +45,7 @@ impl PrecisionMetric { /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] + #[must_use] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { @@ -63,6 +65,7 @@ impl PrecisionMetric { /// * `threshold` - The threshold to transform a probability into a binary value. /// * `class_reduction` - [Class reduction](ClassReduction) type. 
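`IterationSpeedMetric::update` above now destructures its unused unit input directly in the parameter list (`(): &Self::Input`). A tiny illustration of that signature shape, using a hypothetical `Metric` trait and `Counter` type rather than the burn-train ones:

```rust
trait Metric {
    type Input;
    fn update(&mut self, input: &Self::Input) -> usize;
}

struct Counter {
    count: usize,
}

impl Metric for Counter {
    type Input = ();

    // Destructuring `()` in the signature documents that the input carries no
    // data, instead of binding it to an ignored `_` name.
    fn update(&mut self, (): &Self::Input) -> usize {
        self.count += 1;
        self.count
    }
}

fn main() {
    let mut counter = Counter { count: 0 };
    assert_eq!(counter.update(&()), 1);
    assert_eq!(counter.update(&()), 2);
}
```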
#[allow(dead_code)] + #[must_use] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { @@ -87,7 +90,7 @@ impl PrecisionMetric { let nan_mask = aggregated_metric.is_nan(); aggregated_metric = aggregated_metric .clone() - .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + .select(0, nan_mask.bool_not().argwhere().squeeze(1)); } aggregated_metric.mean() } @@ -114,7 +117,7 @@ impl Metric for PrecisionMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { @@ -153,7 +156,7 @@ mod tests { let mut metric = PrecisionMetric::binary(threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[rstest] @@ -170,7 +173,7 @@ mod tests { let mut metric = PrecisionMetric::multiclass(top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[rstest] @@ -185,7 +188,7 @@ mod tests { let mut metric = PrecisionMetric::multilabel(threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[test] diff --git a/crates/burn-train/src/metric/processor/full.rs b/crates/burn-train/src/metric/processor/full.rs index 4ab4798275..34c8b1929a 100644 --- a/crates/burn-train/src/metric/processor/full.rs +++ b/crates/burn-train/src/metric/processor/full.rs @@ -52,7 +52,7 @@ impl EventProcessor for FullEventProcessor { .into_iter() .for_each(|(entry, value)| { self.renderer - .update_train(MetricState::Numeric(entry, value)) + .update_train(MetricState::Numeric(entry, value)); }); self.renderer.render_train(progress); @@ -90,7 +90,7 @@ impl EventProcessor for FullEventProcessor { .into_iter() .for_each(|(entry, value)| { self.renderer - .update_valid(MetricState::Numeric(entry, value)) + .update_valid(MetricState::Numeric(entry, value)); }); self.renderer.render_valid(progress); diff --git a/crates/burn-train/src/metric/processor/metrics.rs b/crates/burn-train/src/metric/processor/metrics.rs index 68b1c43b46..2bf84dd9af 100644 --- a/crates/burn-train/src/metric/processor/metrics.rs +++ b/crates/burn-train/src/metric/processor/metrics.rs @@ -29,7 +29,7 @@ impl Metrics { T::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); - self.train.push(Box::new(metric)) + self.train.push(Box::new(metric)); } /// Register a validation metric. @@ -38,7 +38,7 @@ impl Metrics { V::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); - self.valid.push(Box::new(metric)) + self.valid.push(Box::new(metric)); } /// Register a numeric training metric. @@ -49,7 +49,7 @@ impl Metrics { T::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); - self.train_numeric.push(Box::new(metric)) + self.train_numeric.push(Box::new(metric)); } /// Register a numeric validation metric. 
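The precision, recall, and F-beta metrics above drop NaN entries before averaging (the `nan_mask.bool_not().argwhere()` select on the aggregated tensor). The same idea on a plain slice, purely for illustration; `mean_ignoring_nan` is not a function from the codebase:

```rust
// Average only the non-NaN entries of a slice.
fn mean_ignoring_nan(values: &[f64]) -> f64 {
    let finite: Vec<f64> = values.iter().copied().filter(|v| !v.is_nan()).collect();
    finite.iter().sum::<f64>() / finite.len() as f64
}

fn main() {
    let per_class = [0.8, f64::NAN, 0.6];
    assert!((mean_ignoring_nan(&per_class) - 0.7).abs() < 1e-9);
}
```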
@@ -60,7 +60,7 @@ impl Metrics { V::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); - self.valid_numeric.push(Box::new(metric)) + self.valid_numeric.push(Box::new(metric)); } /// Update the training information from the training item. @@ -72,12 +72,12 @@ impl Metrics { let mut entries = Vec::with_capacity(self.train.len()); let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); - for metric in self.train.iter_mut() { + for metric in &mut self.train { let state = metric.update(item, metadata); entries.push(state); } - for metric in self.train_numeric.iter_mut() { + for metric in &mut self.train_numeric { let (state, value) = metric.update(item, metadata); entries_numeric.push((state, value)); } @@ -94,12 +94,12 @@ impl Metrics { let mut entries = Vec::with_capacity(self.valid.len()); let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); - for metric in self.valid.iter_mut() { + for metric in &mut self.valid { let state = metric.update(item, metadata); entries.push(state); } - for metric in self.valid_numeric.iter_mut() { + for metric in &mut self.valid_numeric { let (state, value) = metric.update(item, metadata); entries_numeric.push((state, value)); } @@ -109,20 +109,20 @@ impl Metrics { /// Signal the end of a training epoch. pub(crate) fn end_epoch_train(&mut self) { - for metric in self.train.iter_mut() { + for metric in &mut self.train { metric.clear(); } - for metric in self.train_numeric.iter_mut() { + for metric in &mut self.train_numeric { metric.clear(); } } /// Signal the end of a validation epoch. pub(crate) fn end_epoch_valid(&mut self) { - for metric in self.valid.iter_mut() { + for metric in &mut self.valid { metric.clear(); } - for metric in self.valid_numeric.iter_mut() { + for metric in &mut self.valid_numeric { metric.clear(); } } @@ -180,7 +180,7 @@ where } fn clear(&mut self) { - self.metric.clear() + self.metric.clear(); } } @@ -195,6 +195,6 @@ where } fn clear(&mut self) { - self.metric.clear() + self.metric.clear(); } } diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs index b698806b92..c21fa3379e 100644 --- a/crates/burn-train/src/metric/recall.rs +++ b/crates/burn-train/src/metric/recall.rs @@ -26,6 +26,7 @@ impl RecallMetric { /// /// * `threshold` - The threshold to transform a probability into a binary prediction. #[allow(dead_code)] + #[must_use] pub fn binary(threshold: f64) -> Self { Self { config: ClassificationMetricConfig { @@ -44,6 +45,7 @@ impl RecallMetric { /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] + #[must_use] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { @@ -63,6 +65,7 @@ impl RecallMetric { /// * `threshold` - The threshold to transform a probability into a binary prediction. /// * `class_reduction` - [Class reduction](ClassReduction) type. 
#[allow(dead_code)] + #[must_use] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { @@ -87,7 +90,7 @@ impl RecallMetric { let nan_mask = aggregated_metric.is_nan(); aggregated_metric = aggregated_metric .clone() - .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + .select(0, nan_mask.bool_not().argwhere().squeeze(1)); } aggregated_metric.mean() } @@ -113,7 +116,7 @@ impl Metric for RecallMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { @@ -151,7 +154,7 @@ mod tests { let mut metric = RecallMetric::binary(threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[rstest] @@ -168,7 +171,7 @@ mod tests { let mut metric = RecallMetric::multiclass(top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[rstest] @@ -183,7 +186,7 @@ mod tests { let mut metric = RecallMetric::multilabel(threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) - .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); } #[test] diff --git a/crates/burn-train/src/metric/state.rs b/crates/burn-train/src/metric/state.rs index ae4ed9a6c1..92c1b4a711 100644 --- a/crates/burn-train/src/metric/state.rs +++ b/crates/burn-train/src/metric/state.rs @@ -30,12 +30,14 @@ impl FormatOptions { } /// Specify the metric unit. + #[must_use] pub fn unit(mut self, unit: &str) -> Self { self.unit = Some(unit.to_string()); self } /// Specify the floating point precision. + #[must_use] pub fn precision(mut self, precision: usize) -> Self { self.precision = Some(precision); self @@ -44,6 +46,7 @@ impl FormatOptions { impl NumericMetricState { /// Create a new [numeric metric state](NumericMetricState). + #[must_use] pub fn new() -> Self { Self { sum: 0.0, diff --git a/crates/burn-train/src/metric/store/aggregate.rs b/crates/burn-train/src/metric/store/aggregate.rs index c96fa2c144..b90c9aed7e 100644 --- a/crates/burn-train/src/metric/store/aggregate.rs +++ b/crates/burn-train/src/metric/store/aggregate.rs @@ -36,7 +36,7 @@ impl NumericMetricsAggregate { match logger.read_numeric(name, epoch) { Ok(points) => return Ok(points), Err(err) => errors.push(err), - }; + } } Err(errors.join(" ")) diff --git a/crates/burn-train/src/metric/store/client.rs b/crates/burn-train/src/metric/store/client.rs index 74ba83ab74..c84a9edd9a 100644 --- a/crates/burn-train/src/metric/store/client.rs +++ b/crates/burn-train/src/metric/store/client.rs @@ -40,6 +40,7 @@ impl EventStoreClient { } /// Find the epoch following the given criteria from the collected data. + #[must_use] pub fn find_epoch( &self, name: &str, @@ -60,11 +61,12 @@ impl EventStoreClient { match receiver.recv() { Ok(value) => value, - Err(err) => panic!("Event store thread crashed: {:?}", err), + Err(err) => panic!("Event store thread crashed: {err:?}"), } } /// Find the metric value for the current epoch following the given criteria. 
+ #[must_use] pub fn find_metric( &self, name: &str, @@ -85,7 +87,7 @@ impl EventStoreClient { match receiver.recv() { Ok(value) => value, - Err(err) => panic!("Event store thread crashed: {:?}", err), + Err(err) => panic!("Event store thread crashed: {err:?}"), } } } @@ -101,7 +103,7 @@ where C: EventStore, { fn run(mut self) { - for item in self.receiver.iter() { + for item in &self.receiver { match item { Message::End => { return; diff --git a/crates/burn-train/src/metric/top_k_acc.rs b/crates/burn-train/src/metric/top_k_acc.rs index edaf294545..272b588f21 100644 --- a/crates/burn-train/src/metric/top_k_acc.rs +++ b/crates/burn-train/src/metric/top_k_acc.rs @@ -22,14 +22,15 @@ pub struct TopKAccuracyMetric { /// The [top-k accuracy metric](TopKAccuracyMetric) input type. #[derive(new)] pub struct TopKAccuracyInput { - /// The outputs (batch_size, num_classes) + /// The outputs (`batch_size`, `num_classes`) outputs: Tensor, - /// The labels (batch_size) + /// The labels (`batch_size`) targets: Tensor, } impl TopKAccuracyMetric { /// Creates the metric. + #[must_use] pub fn new(k: usize) -> Self { Self { k, @@ -38,6 +39,7 @@ impl TopKAccuracyMetric { } /// Sets the pad token. + #[must_use] pub fn with_pad_token(mut self, index: usize) -> Self { self.pad_token = Some(index); self @@ -88,7 +90,7 @@ impl Metric for TopKAccuracyMetric { } fn clear(&mut self) { - self.state.reset() + self.state.reset(); } fn name(&self) -> String { diff --git a/crates/burn-train/src/renderer/base.rs b/crates/burn-train/src/renderer/base.rs index 02270bc08e..c133b0f813 100644 --- a/crates/burn-train/src/renderer/base.rs +++ b/crates/burn-train/src/renderer/base.rs @@ -71,6 +71,7 @@ pub struct TrainingProgress { impl TrainingProgress { /// Creates a new empty training progress. 
+ #[must_use] pub fn none() -> Self { Self { progress: Progress { diff --git a/crates/burn-train/src/renderer/cli.rs b/crates/burn-train/src/renderer/cli.rs index 18cf34c35a..efaaafeb70 100644 --- a/crates/burn-train/src/renderer/cli.rs +++ b/crates/burn-train/src/renderer/cli.rs @@ -17,10 +17,10 @@ impl MetricsRenderer for CliMetricsRenderer { fn update_valid(&mut self, _state: MetricState) {} fn render_train(&mut self, item: TrainingProgress) { - println!("{:?}", item); + println!("{item:?}"); } fn render_valid(&mut self, item: TrainingProgress) { - println!("{:?}", item); + println!("{item:?}"); } } diff --git a/crates/burn-train/src/renderer/tui/full_history.rs b/crates/burn-train/src/renderer/tui/full_history.rs index 3c2e4e90e7..a74465db2a 100644 --- a/crates/burn-train/src/renderer/tui/full_history.rs +++ b/crates/burn-train/src/renderer/tui/full_history.rs @@ -120,7 +120,7 @@ impl FullHistoryPoints { self.max_y = y; } if y < self.min_y { - self.min_y = y + self.min_y = y; } self.points.push((x, y)); @@ -192,10 +192,10 @@ mod tests { chart.update_max_sample_valid(0.6); for i in 0..100 { - chart.push_train(i as f64); + chart.push_train(f64::from(i)); } for i in 0..60 { - chart.push_valid(i as f64); + chart.push_valid(f64::from(i)); } let expected_train = vec![ diff --git a/crates/burn-train/src/renderer/tui/metric_numeric.rs b/crates/burn-train/src/renderer/tui/metric_numeric.rs index d0392c51e5..911aa4236c 100644 --- a/crates/burn-train/src/renderer/tui/metric_numeric.rs +++ b/crates/burn-train/src/renderer/tui/metric_numeric.rs @@ -87,7 +87,7 @@ impl NumericMetricsState { } if let Some(num_sample_train) = self.num_samples_train { - for (_, (_recent, full)) in self.data.iter_mut() { + for (_recent, full) in self.data.values_mut() { let ratio = progress.progress.items_total as f64 / num_sample_train as f64; full.update_max_sample_valid(ratio); } @@ -98,9 +98,10 @@ impl NumericMetricsState { /// Create a view to display the numeric metrics. 
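The numeric-metrics hunk above iterates `self.data.values_mut()` instead of destructuring `(_, value)` pairs out of `iter_mut()`. A standalone version of that pattern on a plain `HashMap`:

```rust
use std::collections::HashMap;

fn main() {
    let mut charts: HashMap<String, Vec<f64>> = HashMap::new();
    charts.insert("loss".to_string(), vec![0.9, 0.7]);
    charts.insert("accuracy".to_string(), vec![0.5, 0.6]);

    // Only the values are needed, so skip the keys entirely.
    for points in charts.values_mut() {
        points.push(0.0);
    }

    assert!(charts.values().all(|points| points.len() == 3));
}
```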
pub(crate) fn view(&self) -> NumericMetricView<'_> { - match self.names.is_empty() { - true => NumericMetricView::None, - false => NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind), + if self.names.is_empty() { + NumericMetricView::None + } else { + NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind) } } @@ -164,13 +165,23 @@ impl NumericMetricsState { Axis::default() .style(Style::default().fg(Color::DarkGray)) .title("Iteration") - .labels(axes.labels_x.clone().into_iter().map(|s| s.bold())) + .labels( + axes.labels_x + .clone() + .into_iter() + .map(ratatui::prelude::Stylize::bold), + ) .bounds(axes.bounds_x), ) .y_axis( Axis::default() .style(Style::default().fg(Color::DarkGray)) - .labels(axes.labels_y.clone().into_iter().map(|s| s.bold())) + .labels( + axes.labels_y + .clone() + .into_iter() + .map(ratatui::prelude::Stylize::bold), + ) .bounds(axes.bounds_y), ) } @@ -230,6 +241,6 @@ impl NumericMetricView<'_> { frame.render_widget(chart, chunks[2]); } Self::None => {} - }; + } } } diff --git a/crates/burn-train/src/renderer/tui/popup.rs b/crates/burn-train/src/renderer/tui/popup.rs index 83b5dc84be..1f54bf0b0b 100644 --- a/crates/burn-train/src/renderer/tui/popup.rs +++ b/crates/burn-train/src/renderer/tui/popup.rs @@ -67,7 +67,7 @@ impl PopupState { } } } - }; + } if reset { *self = Self::Empty; diff --git a/crates/burn-train/src/renderer/tui/recent_history.rs b/crates/burn-train/src/renderer/tui/recent_history.rs index ac91d60888..1e592db525 100644 --- a/crates/burn-train/src/renderer/tui/recent_history.rs +++ b/crates/burn-train/src/renderer/tui/recent_history.rs @@ -121,7 +121,7 @@ impl RecentHistoryPoints { self.max_y = y; } if y < self.min_y { - self.min_y = y + self.min_y = y; } self.points.push((x, y)); } @@ -141,7 +141,7 @@ impl RecentHistoryPoints { } if *y == self.max_y { - update_y_max = true + update_y_max = true; } if *y == self.min_y { update_y_min = true; diff --git a/crates/burn-train/src/renderer/tui/renderer.rs b/crates/burn-train/src/renderer/tui/renderer.rs index 31602e988c..49c9c7d5eb 100644 --- a/crates/burn-train/src/renderer/tui/renderer.rs +++ b/crates/burn-train/src/renderer/tui/renderer.rs @@ -56,7 +56,7 @@ impl MetricsRenderer for TuiMetricsRenderer { self.metrics_numeric.push_train(entry.name.clone(), value); self.metrics_text.update_train(entry); } - }; + } } fn update_valid(&mut self, state: MetricState) { @@ -68,7 +68,7 @@ impl MetricsRenderer for TuiMetricsRenderer { self.metrics_numeric.push_valid(entry.name.clone(), value); self.metrics_text.update_valid(entry); } - }; + } } fn render_train(&mut self, item: TrainingProgress) { @@ -92,6 +92,7 @@ impl MetricsRenderer for TuiMetricsRenderer { impl TuiMetricsRenderer { /// Create a new terminal UI renderer. + #[must_use] pub fn new(interuptor: TrainingInterrupter, checkpoint: Option) -> Self { let mut stdout = io::stdout(); execute!(stdout, EnterAlternateScreen).unwrap(); @@ -125,6 +126,7 @@ impl TuiMetricsRenderer { } /// Set the renderer to persistent mode. 
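Several hunks in this patch (the numeric metric view above, `format_float`, the CPU temperature formatting) replace `match some_bool { true => .., false => .. }` with an ordinary `if`/`else`. A minimal illustration with a hypothetical `describe` helper:

```rust
fn describe(names: &[&str]) -> String {
    // Before: match names.is_empty() { true => .., false => .. }
    if names.is_empty() {
        "no metrics registered".to_string()
    } else {
        format!("{} metrics registered", names.len())
    }
}

fn main() {
    assert_eq!(describe(&[]), "no metrics registered");
    assert_eq!(describe(&["loss", "accuracy"]), "2 metrics registered");
}
```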
+ #[must_use] pub fn persistent(mut self) -> Self { self.persistent = true; self @@ -148,20 +150,19 @@ impl TuiMetricsRenderer { self.terminal.draw(|frame| { let size = frame.area(); - match self.popup.view() { - Some(view) => view.render(frame, size), - None => { - let view = MetricsView::new( - self.metrics_numeric.view(), - self.metrics_text.view(), - self.progress.view(), - ControlsView, - self.status.view(), - ); - - view.render(frame, size); - } - }; + if let Some(view) = self.popup.view() { + view.render(frame, size) + } else { + let view = MetricsView::new( + self.metrics_numeric.view(), + self.metrics_text.view(), + self.progress.view(), + ControlsView, + self.status.view(), + ); + + view.render(frame, size); + } })?; Ok(()) @@ -245,7 +246,7 @@ impl TuiMetricsRenderer { self.draw().ok(); } Err(err) => { - eprintln!("Error reading event: {}", err); + eprintln!("Error reading event: {err}"); break; } _ => continue, @@ -261,7 +262,7 @@ impl TuiMetricsRenderer { if self.previous_panic_hook.is_some() { if self.persistent { if let Err(err) = self.handle_post_training() { - eprintln!("Error in post-training handling: {}", err); + eprintln!("Error in post-training handling: {err}"); } } diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index 49a9982203..a6e8f2f705 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -14,6 +14,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-vision" version.workspace = true +[lints] +workspace = true + [features] candle = ["burn-candle"] diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index 2ee105aaca..f0101e8d02 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu" version.workspace = true +[lints] +workspace = true + [features] autotune = ["burn-cubecl/autotune"] autotune-checks = ["burn-cubecl/autotune-checks"] diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 5ca0f12354..0cbde7ce48 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -55,7 +55,7 @@ pub use cubecl::wgpu::vulkan::VkSpirvCompiler; /// /// # Notes /// -/// This version of the wgpu backend uses [burn_fusion] to compile and optimize streams of tensor +/// This version of the wgpu backend uses [`burn_fusion`] to compile and optimize streams of tensor /// operations for improved performance. /// /// You can disable the `fusion` feature flag to remove that functionality, which might be diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 13cf42a182..d27ddca6bc 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -12,6 +12,9 @@ repository = "https://github.com/tracel-ai/burn" rust-version = "1.85" version.workspace = true +[lints] +workspace = true + [features] default = [ "burn-core/default", diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index ad7e41ec37..e8d9ded151 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -26,7 +26,7 @@ //! with an ergonomic dashboard, and run inference everywhere from embedded devices to large GPU clusters. //! //! Burn was built from the ground up with training and inference in mind. It's also worth noting how Burn, -//! in comparison to frameworks like PyTorch, simplifies the transition from training to deployment, +//! 
in comparison to frameworks like `PyTorch`, simplifies the transition from training to deployment, //! eliminating the need for code changes. //! //! ## Backends @@ -42,8 +42,8 @@ //! //! - WGPU (WebGPU): Cross-Platform GPU Backend //! - Candle: Backend using the Candle bindings -//! - LibTorch: Backend using the LibTorch bindings -//! - NdArray: Backend using the NdArray primitive as data structure +//! - `LibTorch`: Backend using the `LibTorch` bindings +//! - `NdArray`: Backend using the `NdArray` primitive as data structure //! - Autodiff: Backend decorator that brings backpropagation to any backend //! - Fusion: Backend decorator that brings kernel fusion to backends that support it //! @@ -70,19 +70,19 @@ //! - `metrics`: Includes system info metrics (CPU/GPU usage, etc.) //! - Dataset //! - `dataset`: Includes a datasets library -//! - `audio`: Enables audio datasets (SpeechCommandsDataset) -//! - `sqlite`: Stores datasets in SQLite database -//! - `sqlite_bundled`: Use bundled version of SQLite -//! - `vision`: Enables vision datasets (MnistDataset) +//! - `audio`: Enables audio datasets (`SpeechCommandsDataset`) +//! - `sqlite`: Stores datasets in `SQLite` database +//! - `sqlite_bundled`: Use bundled version of `SQLite` +//! - `vision`: Enables vision datasets (`MnistDataset`) //! - Backends //! - `wgpu`: Makes available the WGPU backend //! - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler //! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler //! - `cuda`: Makes available the CUDA backend -//! - `rocm`: Makes available the ROCm backend +//! - `rocm`: Makes available the `ROCm` backend //! - `candle`: Makes available the Candle backend -//! - `tch`: Makes available the LibTorch backend -//! - `ndarray`: Makes available the NdArray backend +//! - `tch`: Makes available the `LibTorch` backend +//! - `ndarray`: Makes available the `NdArray` backend //! - Backend specifications //! - `accelerate`: If supported, Accelerate will be used //! - `blas-netlib`: If supported, Blas Netlib will be use @@ -93,7 +93,7 @@ //! - Backend decorators //! - `autodiff`: Makes available the Autodiff backend //! - Others: -//! - `std`: Activates the standard library (deactivate for no_std) +//! - `std`: Activates the standard library (deactivate for `no_std`) //! - `server`: Enables the remote server. //! - `network`: Enables network utilities (currently, only a file downloader with progress bar) //! - `experimental-named-tensor`: Enables named tensors (experimental) diff --git a/crates/onnx-ir/Cargo.toml b/crates/onnx-ir/Cargo.toml index bc720ebda0..1c3fec0443 100644 --- a/crates/onnx-ir/Cargo.toml +++ b/crates/onnx-ir/Cargo.toml @@ -12,6 +12,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/onnx-ir" documentation = "https://docs.rs/onnx-ir" version.workspace = true +[lints] +workspace = true + [dependencies] bytemuck = { workspace = true } diff --git a/crates/onnx-ir/src/coalesce.rs b/crates/onnx-ir/src/coalesce.rs index 6dcaea47b5..2738fed6b5 100644 --- a/crates/onnx-ir/src/coalesce.rs +++ b/crates/onnx-ir/src/coalesce.rs @@ -26,11 +26,9 @@ pub fn coalesce( /// This function converts a Gemm node into a Linear node /// -/// PyTorch and other frameworks use Gemm node to represent Linear layer. +/// `PyTorch` and other frameworks use Gemm node to represent Linear layer. 
pub(crate) fn convert_gemm_to_linear(node: &mut Node) { - if node.outputs.len() != 1 { - panic!("Gemm node must have 1 output"); - } + assert!((node.outputs.len() == 1), "Gemm node must have 1 output"); let straight_linear = match ( node.attrs.get("alpha"), node.attrs.get("beta"), @@ -121,20 +119,18 @@ fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec transposed } -/// This function converts a MatMul node into a Linear node if possible. +/// This function converts a `MatMul` node into a Linear node if possible. /// -/// PyTorch and other frameworks use MatMul node to represent Linear layer. +/// `PyTorch` and other frameworks use `MatMul` node to represent Linear layer. /// /// This function also converts the following Add node into a Linear node if possible. -/// Add node is used to represent bias in PyTorch. +/// Add node is used to represent bias in `PyTorch`. pub(crate) fn convert_matmul_to_linear( node: &mut Node, iter_mut: &mut Peekable>, graph_data: &GraphData, ) { - if node.inputs.len() != 2 { - panic!("MatMul node must have 2 inputs"); - } + assert!((node.inputs.len() == 2), "MatMul node must have 2 inputs"); // if the second input does not have a value, it is not a weight, then proceed to the next node if node.inputs[1].value.is_none() { diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index 0a7e828eb2..9e90654cec 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -120,10 +120,7 @@ impl GraphData { if let Some(init_arg) = self.initializers.get(proto_str) { init_arg.clone() } else { - log::warn!( - "Input {} not found, should only happen when peeking", - proto_str - ); + log::warn!("Input {proto_str} not found, should only happen when peeking"); Argument::new(proto_str.to_string()) } } @@ -132,7 +129,7 @@ impl GraphData { } } - /// Mark the graph_inputs to a node as passed, unless they are also initializers + /// Mark the `graph_inputs` to a node as passed, unless they are also initializers fn mark_input_passed(&mut self, node: &Node) { // we have to double map the inputs because the input might be replaced by an initializer node.inputs.iter().for_each(|node_input| { @@ -157,7 +154,7 @@ impl GraphData { log::debug!("adding node {:?}", &node.name); self.mark_input_passed(&node); let mut out_count = 1; - for output in node.outputs.iter_mut() { + for output in &mut node.outputs { self.input_name_map.insert( output.name.clone(), IOEntry::Node(self.processed_nodes.len(), out_count - 1), @@ -280,7 +277,7 @@ impl OnnxGraphBuilder { } else if self.constants_types.contains(&node.node_type) { log::debug!("checking node {} for constants", &node.name); for input in node.inputs.iter_mut().skip(1) { - log::debug!("checking input {:?} for const", input); + log::debug!("checking input {input:?} for const"); if let Some(const_idx) = self.constants_map.get(&input.name) { let constant = &graph_data.processed_nodes[*const_idx]; log::debug!( @@ -354,8 +351,9 @@ impl OnnxGraphBuilder { /// /// * If the file cannot be opened or read /// * If the ONNX model cannot be parsed -/// * If the model uses an unsupported opset version (must be >= MIN_OPSET_VERSION) +/// * If the model uses an unsupported opset version (must be >= `MIN_OPSET_VERSION`) /// * If the nodes in the graph are not topologically sorted +#[must_use] pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { log::info!("Parsing ONNX file: {}", onnx_path.display()); @@ -366,14 +364,12 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { 
Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file"); // Check opset versions - must be >= MIN_OPSET_VERSION - if !verify_opsets(&onnx_model.opset_import, MIN_OPSET_VERSION) { - panic!( - "Unsupported ONNX opset version. This implementation requires opset {} or higher. \ + assert!( + verify_opsets(&onnx_model.opset_import, MIN_OPSET_VERSION), + "Unsupported ONNX opset version. This implementation requires opset {MIN_OPSET_VERSION} or higher. \ Please upgrade your model using the ONNX shape inference tool. \ - See documentation (https://burn.dev/burn-book/import/onnx-model.html) for details.", - MIN_OPSET_VERSION - ); - } + See documentation (https://burn.dev/burn-book/import/onnx-model.html) for details." + ); // ONNX nodes must be topologically sorted per spec: // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs @@ -489,6 +485,7 @@ impl TopologicalSortable for Vec { } /// Get the value of a constant node from its attributes +#[must_use] pub fn convert_constant_value(node: &Node) -> Argument { // A value can be stored in any of these attributes let keys = [ diff --git a/crates/onnx-ir/src/ir.rs b/crates/onnx-ir/src/ir.rs index 4b474ae596..03ec92ab33 100644 --- a/crates/onnx-ir/src/ir.rs +++ b/crates/onnx-ir/src/ir.rs @@ -121,14 +121,17 @@ impl Default for ArgType { } impl ArgType { + #[must_use] pub fn is_scalar(&self) -> bool { matches!(self, Self::Scalar(_)) } + #[must_use] pub fn is_tensor(&self) -> bool { matches!(self, Self::Tensor(_)) } /// returns the rank (dimension) of the Arg + #[must_use] pub fn rank(&self) -> usize { match self { ArgType::Scalar(_) => 0, @@ -138,6 +141,7 @@ impl ArgType { } /// returns the element type of the Arg + #[must_use] pub fn elem_type(&self) -> &ElementType { match self { ArgType::Scalar(s) => s, @@ -148,6 +152,7 @@ impl ArgType { } impl Argument { + #[must_use] pub fn new(name: String) -> Self { Self { name, @@ -174,6 +179,7 @@ pub struct TensorData { impl TensorData { /// The element type of the tensor inferred from the data. 
+ #[must_use] pub fn elem_type(&self) -> ElementType { match &self.data { Data::Bool(_) | Data::Bools(_) => ElementType::Bool, @@ -490,7 +496,7 @@ fn trunc(v: &[T]) -> String { if i > BEGIN_INDEX { s.push_str(", "); } - s.push_str(&format!("{}", item)); + s.push_str(&format!("{item}")); if i > MAX_LEN { s.push_str(", ..."); break; @@ -511,18 +517,19 @@ impl fmt::Debug for Data { Data::Int64s(v) => write!(f, "Int64s({})", trunc(v)), Data::Strings(v) => write!(f, "Strings({})", trunc(v)), Data::Bools(v) => write!(f, "Bools({})", trunc(v)), - Data::Float16(v) => write!(f, "Float16({})", v), - Data::Float32(v) => write!(f, "Float32({})", v), - Data::Float64(v) => write!(f, "Float64({})", v), - Data::Int32(v) => write!(f, "Int32({})", v), - Data::Int64(v) => write!(f, "Int64({})", v), - Data::String(v) => write!(f, "String({})", v), - Data::Bool(v) => write!(f, "Bool({})", v), + Data::Float16(v) => write!(f, "Float16({v})"), + Data::Float32(v) => write!(f, "Float32({v})"), + Data::Float64(v) => write!(f, "Float64({v})"), + Data::Int32(v) => write!(f, "Int32({v})"), + Data::Int64(v) => write!(f, "Int64({v})"), + Data::String(v) => write!(f, "String({v})"), + Data::Bool(v) => write!(f, "Bool({v})"), } } } impl Data { + #[must_use] pub fn into_scalar(self) -> Self { match self { Data::Float16s(data) => { @@ -556,15 +563,17 @@ impl Data { _ => self, } } + #[must_use] pub fn into_f16(self) -> f16 { match self { Data::Float16(elem) => elem, Data::Float32(elem) => f16::from_f32(elem), Data::Float64(elem) => f16::from_f64(elem), - _ => panic!("Cannot convert {:?} to f16", self), + _ => panic!("Cannot convert {self:?} to f16"), } } + #[must_use] pub fn into_f32(self) -> f32 { match self { Data::Float16(elem) => elem.to_f32(), @@ -573,22 +582,24 @@ impl Data { Data::Int32(elem) => elem as f32, Data::Int64(elem) => elem as f32, Data::Float32s(elem) if elem.len() == 1 => elem[0], - _ => panic!("Cannot convert {:?} to f32", self), + _ => panic!("Cannot convert {self:?} to f32"), } } + #[must_use] pub fn into_f64(self) -> f64 { match self { Data::Float16(elem) => elem.to_f64(), - Data::Float32(elem) => elem as f64, + Data::Float32(elem) => f64::from(elem), Data::Float64(elem) => elem, - Data::Int32(elem) => elem as f64, + Data::Int32(elem) => f64::from(elem), Data::Int64(elem) => elem as f64, Data::Float64s(elem) if elem.len() == 1 => elem[0], - _ => panic!("Cannot convert {:?} to f64", self), + _ => panic!("Cannot convert {self:?} to f64"), } } + #[must_use] pub fn into_i32(self) -> i32 { match self { Data::Int32(elem) => elem, @@ -597,34 +608,37 @@ impl Data { Data::Float64(elem) => elem as i32, Data::Float32s(elem) if elem.len() == 1 => elem[0] as i32, Data::Int32s(elem) if elem.len() == 1 => elem[0], - _ => panic!("Cannot convert {:?} to i32", self), + _ => panic!("Cannot convert {self:?} to i32"), } } + #[must_use] pub fn into_i64(self) -> i64 { match self { - Data::Int32(elem) => elem as i64, + Data::Int32(elem) => i64::from(elem), Data::Int64(elem) => elem, Data::Float32(elem) => elem as i64, Data::Float64(elem) => elem as i64, Data::Int64s(elem) if elem.len() == 1 => elem[0], - _ => panic!("Cannot convert {:?} to i64", self), + _ => panic!("Cannot convert {self:?} to i64"), } } + #[must_use] pub fn into_bool(self) -> bool { if let Data::Bool(elem) = self { elem } else { - panic!("Expected Bool, got {:?}", self); + panic!("Expected Bool, got {self:?}"); } } + #[must_use] pub fn into_string(self) -> String { if let Data::String(elem) = self { elem } else { - panic!("Expected String, got {:?}", self); + 
panic!("Expected String, got {self:?}"); } } @@ -633,159 +647,175 @@ impl Data { Data::Float16s(elem) => elem, Data::Float32s(elem) => elem.into_iter().map(f16::from_f32).collect(), Data::Float64s(elem) => elem.into_iter().map(f16::from_f64).collect(), - _ => panic!("Cannot convert {:?} to Vec", self), + _ => panic!("Cannot convert {self:?} to Vec"), } } + #[must_use] pub fn into_f32s(self) -> Vec { match self { - Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f32()).collect(), + Data::Float16s(elem) => elem.into_iter().map(half::f16::to_f32).collect(), Data::Float32s(elem) => elem, Data::Float64s(elem) => elem.into_iter().map(|x| x as f32).collect(), Data::Int32s(elem) => elem.into_iter().map(|x| x as f32).collect(), Data::Int64s(elem) => elem.into_iter().map(|x| x as f32).collect(), - _ => panic!("Cannot convert {:?} to Vec", self), + _ => panic!("Cannot convert {self:?} to Vec"), } } + #[must_use] pub fn into_f64s(self) -> Vec { match self { - Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f64()).collect(), - Data::Float32s(elem) => elem.into_iter().map(|x| x as f64).collect(), + Data::Float16s(elem) => elem.into_iter().map(half::f16::to_f64).collect(), + Data::Float32s(elem) => elem.into_iter().map(f64::from).collect(), Data::Float64s(elem) => elem, - Data::Int32s(elem) => elem.into_iter().map(|x| x as f64).collect(), + Data::Int32s(elem) => elem.into_iter().map(f64::from).collect(), Data::Int64s(elem) => elem.into_iter().map(|x| x as f64).collect(), - _ => panic!("Cannot convert {:?} to Vec", self), + _ => panic!("Cannot convert {self:?} to Vec"), } } + #[must_use] pub fn into_i32s(self) -> Vec { match self { Data::Int32s(elem) => elem, Data::Int64s(elem) => elem.into_iter().map(|x| x as i32).collect(), Data::Float32s(elem) => elem.into_iter().map(|x| x as i32).collect(), Data::Float64s(elem) => elem.into_iter().map(|x| x as i32).collect(), - _ => panic!("Cannot convert {:?} to Vec", self), + _ => panic!("Cannot convert {self:?} to Vec"), } } + #[must_use] pub fn into_i64s(self) -> Vec { match self { - Data::Int32s(elem) => elem.into_iter().map(|x| x as i64).collect(), + Data::Int32s(elem) => elem.into_iter().map(i64::from).collect(), Data::Int64s(elem) => elem, Data::Float32s(elem) => elem.into_iter().map(|x| x as i64).collect(), Data::Float64s(elem) => elem.into_iter().map(|x| x as i64).collect(), - _ => panic!("Cannot convert {:?} to Vec", self), + _ => panic!("Cannot convert {self:?} to Vec"), } } + #[must_use] pub fn into_usizes(self) -> Vec { match self { Data::Int32s(elem) => elem.into_iter().map(|x| x as usize).collect(), Data::Int64s(elem) => elem.into_iter().map(|x| x as usize).collect(), Data::Float32s(elem) => elem.into_iter().map(|x| x as usize).collect(), Data::Float64s(elem) => elem.into_iter().map(|x| x as usize).collect(), - _ => panic!("Cannot convert {:?} to Vec", self), + _ => panic!("Cannot convert {self:?} to Vec"), } } + #[must_use] pub fn into_bools(self) -> Vec { if let Data::Bools(elem) = self { elem } else { - panic!("Expected Bools, got {:?}", self); + panic!("Expected Bools, got {self:?}"); } } + #[must_use] pub fn into_strings(self) -> Vec { if let Data::Strings(elem) = self { elem } else { - panic!("Expected Strings, got {:?}", self); + panic!("Expected Strings, got {self:?}"); } } } impl AttributeValue { + #[must_use] pub fn into_f32(self) -> f32 { if let AttributeValue::Float32(elem) = self { elem } else { - panic!("Expected Float32, got {:?}", self); + panic!("Expected Float32, got {self:?}"); } } + #[must_use] pub fn into_i32(self) -> 
i32 { if let AttributeValue::Int64(elem) = self { elem as i32 } else { - panic!("Expected Int32, got {:?}", self); + panic!("Expected Int32, got {self:?}"); } } + #[must_use] pub fn into_i64(self) -> i64 { if let AttributeValue::Int64(elem) = self { elem } else { - panic!("Expected Int64, got {:?}", self); + panic!("Expected Int64, got {self:?}"); } } + #[must_use] pub fn into_string(self) -> String { if let AttributeValue::String(elem) = self { elem } else { - panic!("Expected String, got {:?}", self); + panic!("Expected String, got {self:?}"); } } + #[must_use] pub fn into_tensor(self) -> TensorData { if let AttributeValue::Tensor(elem) = self { elem } else { - panic!("Expected Tensor, got {:?}", self); + panic!("Expected Tensor, got {self:?}"); } } + #[must_use] pub fn into_f32s(self) -> Vec { if let AttributeValue::Float32s(elem) = self { elem } else { - panic!("Expected Float32s, got {:?}", self); + panic!("Expected Float32s, got {self:?}"); } } + #[must_use] pub fn into_i64s(self) -> Vec { if let AttributeValue::Int64s(elem) = self { elem } else { - panic!("Expected Int64s, got {:?}", self); + panic!("Expected Int64s, got {self:?}"); } } + #[must_use] pub fn into_strings(self) -> Vec { if let AttributeValue::Strings(elem) = self { elem } else { - panic!("Expected Strings, got {:?}", self); + panic!("Expected Strings, got {self:?}"); } } + #[must_use] pub fn into_tensors(self) -> Vec { if let AttributeValue::Tensors(elem) = self { elem } else { - panic!("Expected Tensors, got {:?}", self); + panic!("Expected Tensors, got {self:?}"); } } } -/// Convert AttributeValue to an Argument +/// Convert `AttributeValue` to an Argument impl From for Argument { fn from(attr: AttributeValue) -> Argument { // "" is used as a placeholder for the name // TODO dt review this empty string placeholder; it came up a few times in the issues - let name = "".to_string(); + let name = String::new(); match attr { AttributeValue::Float32(value) => Argument { @@ -889,6 +919,7 @@ impl From for Argument { } impl Argument { + #[must_use] pub fn into_tensor(self) -> Option { if let ArgType::Tensor(_) = self.ty { self.value diff --git a/crates/onnx-ir/src/node/argmax.rs b/crates/onnx-ir/src/node/argmax.rs index 153eeac7b8..80790a04e7 100644 --- a/crates/onnx-ir/src/node/argmax.rs +++ b/crates/onnx-ir/src/node/argmax.rs @@ -1,16 +1,16 @@ use crate::ir::{ArgType, ElementType, Node, TensorType}; /// Create argmax config from the attributes of the node +#[must_use] pub fn argmax_config(node: &Node) -> usize { let mut axis: i64 = 0; // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Argmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } + assert!( + (node.inputs.len() == 1), + "Argmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); // extract the shape of the input tensor let tensor = match node.inputs.first().unwrap().clone().ty { @@ -19,26 +19,23 @@ pub fn argmax_config(node: &Node) -> usize { }; // extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "axis" => axis = value.clone().into_i64(), "select_last_index" => { // not all params are supported in burn if value.clone().into_i64() != 0 { log::warn!( - "only select_last_index=0 is supported for argmax in burn. Ignoring supplied value (got {:?})", - value + "only select_last_index=0 is supported for argmax in burn. 
Ignoring supplied value (got {value:?})" ); } } "keepdims" => { // not all params are supported in burn - if value.clone().into_i64() != 1 { - panic!( - "Only keepdims=1 is supported for argmax in burn (got {:?})", - value - ); - } + assert!( + (value.clone().into_i64() == 1), + "Only keepdims=1 is supported for argmax in burn (got {value:?})" + ); } _ => {} } @@ -52,13 +49,14 @@ pub fn argmax_config(node: &Node) -> usize { axis as usize } -/// Update output rank for ArgMax (same as input rank). +/// Update output rank for `ArgMax` (same as input rank). pub fn argmax_update_outputs(node: &mut Node) { log::debug!("ArgMax rank inference for node {}", node.name); - if node.inputs.len() != 1 { - panic!("ArgMax: multiple inputs are not supported"); - } + assert!( + (node.inputs.len() == 1), + "ArgMax: multiple inputs are not supported" + ); let tensor = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor, _ => panic!("Only tensor input is valid"), diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs index 73d6d70291..8146c043c1 100644 --- a/crates/onnx-ir/src/node/avg_pool1d.rs +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -2,7 +2,7 @@ use crate::{ir::Node, node::padding::padding_config_1d}; use super::padding::PaddingConfig1d; -/// Configuration for AvgPool1d operations extracted from ONNX nodes +/// Configuration for `AvgPool1d` operations extracted from ONNX nodes #[derive(Debug, Clone)] pub struct AvgPool1dConfig { /// Kernel size @@ -16,7 +16,8 @@ pub struct AvgPool1dConfig { } impl AvgPool1dConfig { - /// Create a new AvgPool1dConfig + /// Create a new `AvgPool1dConfig` + #[must_use] pub fn new( kernel_size: usize, stride: usize, @@ -32,7 +33,8 @@ impl AvgPool1dConfig { } } -/// Create an AvgPool1dConfig from the attributes of the node +/// Create an `AvgPool1dConfig` from the attributes of the node +#[must_use] pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { let mut kernel_shape = Vec::new(); let mut strides = vec![1]; @@ -40,7 +42,7 @@ pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { let mut count_include_pad: i64 = 0; let mut ceil_mode: i64 = 0; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), @@ -60,9 +62,7 @@ pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { ); assert_eq!(strides.len(), 1, "AvgPool1d: stride must have length 1"); - if ceil_mode == 1 { - panic!("ceil_mode is not supported"); - } + assert!((ceil_mode != 1), "ceil_mode is not supported"); let padding = padding_config_1d(&pads); diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs index bc6e5a0f73..8c75949209 100644 --- a/crates/onnx-ir/src/node/avg_pool2d.rs +++ b/crates/onnx-ir/src/node/avg_pool2d.rs @@ -1,7 +1,7 @@ use crate::ir::Node; use crate::node::padding::{PaddingConfig2d, padding_config_2d}; -/// Configuration for AvgPool2d operations +/// Configuration for `AvgPool2d` operations #[derive(Debug, Clone)] pub struct AvgPool2dConfig { /// Kernel size [height, width] @@ -15,7 +15,8 @@ pub struct AvgPool2dConfig { } impl AvgPool2dConfig { - /// Create a new AvgPool2dConfig + /// Create a new `AvgPool2dConfig` + #[must_use] pub fn new( kernel_size: [usize; 2], strides: [usize; 2], @@ -31,7 +32,8 @@ impl AvgPool2dConfig { } } -/// Create a AvgPool2dConfig from the attributes of the node +/// Create a `AvgPool2dConfig` from the attributes of the 
node +#[must_use] pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { let mut kernel_shape = Vec::new(); let mut strides = vec![1, 1]; @@ -39,7 +41,7 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { let mut count_include_pad: i64 = 0; let mut ceil_mode: i64 = 0; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), @@ -52,9 +54,7 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { } } - if ceil_mode == 1 { - panic!("ceil_mode is not supported"); - } + assert!((ceil_mode != 1), "ceil_mode is not supported"); let padding = padding_config_2d(&pads); diff --git a/crates/onnx-ir/src/node/batch_norm.rs b/crates/onnx-ir/src/node/batch_norm.rs index aaa0165c8f..66da8540cb 100644 --- a/crates/onnx-ir/src/node/batch_norm.rs +++ b/crates/onnx-ir/src/node/batch_norm.rs @@ -1,6 +1,6 @@ use crate::ir::Node; -/// Configuration for BatchNorm operations +/// Configuration for `BatchNorm` operations #[derive(Debug, Clone)] pub struct BatchNormConfig { /// Number of features (channels) @@ -12,7 +12,8 @@ pub struct BatchNormConfig { } impl BatchNormConfig { - /// Create a new BatchNormConfig + /// Create a new `BatchNormConfig` + #[must_use] pub fn new(num_features: usize, epsilon: f64, momentum: f64) -> Self { Self { num_features, @@ -22,7 +23,8 @@ impl BatchNormConfig { } } -/// Create a BatchNormConfig from the attributes of the node +/// Create a `BatchNormConfig` from the attributes of the node +#[must_use] pub fn batch_norm_config(node: &Node) -> BatchNormConfig { let weight_shape = node.inputs[1] .value @@ -36,7 +38,7 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig { let mut epsilon = 0f32; let mut momentum = 0f32; - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "momentum" => momentum = value.clone().into_f32(), "epsilon" => epsilon = value.clone().into_f32(), @@ -44,7 +46,7 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig { } } - BatchNormConfig::new(num_features, epsilon as f64, momentum as f64) + BatchNormConfig::new(num_features, f64::from(epsilon), f64::from(momentum)) } #[cfg(test)] diff --git a/crates/onnx-ir/src/node/cast.rs b/crates/onnx-ir/src/node/cast.rs index b3dd289de1..492d926fcd 100644 --- a/crates/onnx-ir/src/node/cast.rs +++ b/crates/onnx-ir/src/node/cast.rs @@ -4,9 +4,10 @@ use protobuf::Enum; /// Update output type for Cast operations, preserving rank. 
pub fn cast_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Cast: multiple inputs are not supported"); - } + assert!( + (node.inputs.len() == 1), + "Cast: multiple inputs are not supported" + ); let input = &mut node.inputs[0]; let output = &mut node.outputs[0]; @@ -70,7 +71,7 @@ mod tests { #[test] fn test_cast_float_to_int64() { - let mut node = create_test_node(2, DataType::INT64.value() as i64); + let mut node = create_test_node(2, i64::from(DataType::INT64.value())); cast_update_outputs(&mut node); match &node.outputs[0].ty { @@ -84,7 +85,7 @@ mod tests { #[test] fn test_cast_scalar_handling() { - let mut node = create_test_node(0, DataType::BOOL.value() as i64); + let mut node = create_test_node(0, i64::from(DataType::BOOL.value())); cast_update_outputs(&mut node); match &node.outputs[0].ty { @@ -105,7 +106,7 @@ mod tests { #[test] #[should_panic(expected = "Cast: multiple inputs are not supported")] fn test_cast_multiple_inputs() { - let mut node = create_test_node(2, DataType::INT64.value() as i64); + let mut node = create_test_node(2, i64::from(DataType::INT64.value())); node.inputs.push(Argument { name: "extra".to_string(), ty: ArgType::Tensor(TensorType { @@ -121,7 +122,7 @@ mod tests { #[test] fn test_cast_scalar_to_bool() { - let mut node = create_scalar_test_node(DataType::BOOL.value() as i64); + let mut node = create_scalar_test_node(i64::from(DataType::BOOL.value())); cast_update_outputs(&mut node); match &node.outputs[0].ty { diff --git a/crates/onnx-ir/src/node/clip.rs b/crates/onnx-ir/src/node/clip.rs index e33bf381f0..6f86b50493 100644 --- a/crates/onnx-ir/src/node/clip.rs +++ b/crates/onnx-ir/src/node/clip.rs @@ -1,19 +1,20 @@ use crate::ir::{Data, Node}; +#[must_use] pub fn clip_config(node: &Node) -> (Option, Option) { let mut min_result: Option = None; let mut max_result: Option = None; // For Clip Opset 6+ , the min and max values are attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "min" => { - let min = value.clone().into_f32() as f64; + let min = f64::from(value.clone().into_f32()); min_result = Some(min); } "max" => { let max = value.clone().into_f32(); - max_result = Some(max as f64); + max_result = Some(f64::from(max)); } _ => {} } @@ -28,8 +29,8 @@ pub fn clip_config(node: &Node) -> (Option, Option) { if min_result.is_none() && min.is_some() { let min = min.unwrap().data.into_scalar(); min_result = match min { - Data::Float16(min) => Some(f32::from(min) as f64), - Data::Float32(min) => Some(min as f64), + Data::Float16(min) => Some(f64::from(f32::from(min))), + Data::Float32(min) => Some(f64::from(min)), Data::Float64(min) => Some(min), _ => panic!("Clip: only float min is supported"), }; @@ -38,17 +39,18 @@ pub fn clip_config(node: &Node) -> (Option, Option) { if max_result.is_none() && max.is_some() { let max = max.unwrap().data.into_scalar(); max_result = match max { - Data::Float16(max) => Some(f32::from(max) as f64), - Data::Float32(max) => Some(max as f64), + Data::Float16(max) => Some(f64::from(f32::from(max))), + Data::Float32(max) => Some(f64::from(max)), Data::Float64(max) => Some(max), _ => panic!("Clip: only float max is supported"), }; } } - if min_result.is_none() && max_result.is_none() { - panic!("Clip: min and max values must be either attributes or inputs"); - } + assert!( + !(min_result.is_none() && max_result.is_none()), + "Clip: min and max values must be either attributes or inputs" + ); (min_result, max_result) } diff --git 
a/crates/onnx-ir/src/node/concat.rs b/crates/onnx-ir/src/node/concat.rs index 71a4ee328c..0ce06431bf 100644 --- a/crates/onnx-ir/src/node/concat.rs +++ b/crates/onnx-ir/src/node/concat.rs @@ -25,6 +25,7 @@ pub fn concat_update_outputs(node: &mut Node) { } /// Create concat config from the attributes of the node +#[must_use] pub fn concat_config(node: &Node) -> usize { // the axis is the last dimension (Default: 1 per ONNX spec) let mut axis: i64 = 1; @@ -36,9 +37,9 @@ pub fn concat_config(node: &Node) -> usize { }; // extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { if key.as_str() == "axis" { - axis = value.clone().into_i64() + axis = value.clone().into_i64(); } } diff --git a/crates/onnx-ir/src/node/constant.rs b/crates/onnx-ir/src/node/constant.rs index 2af9aabb04..bbbb33986f 100644 --- a/crates/onnx-ir/src/node/constant.rs +++ b/crates/onnx-ir/src/node/constant.rs @@ -60,7 +60,7 @@ pub fn constant_update_outputs(node: &mut Node) { static_shape: None, }) } - ty => panic!("Constant value of {:?} is not supported", ty), + ty => panic!("Constant value of {ty:?} is not supported"), }, None => panic!("Constant node must have a value attribute"), }; diff --git a/crates/onnx-ir/src/node/constant_of_shape.rs b/crates/onnx-ir/src/node/constant_of_shape.rs index 723a788562..e13e1fb05a 100644 --- a/crates/onnx-ir/src/node/constant_of_shape.rs +++ b/crates/onnx-ir/src/node/constant_of_shape.rs @@ -1,14 +1,12 @@ use crate::ir::{ArgType, ElementType, Node, TensorType}; -/// Updates the output rank for a ConstantOfShape node based on the rank of its input. +/// Updates the output rank for a `ConstantOfShape` node based on the rank of its input. pub fn constant_of_shape_update_output(node: &mut Node) { log::debug!("ConstantOfShape rank inference for node {}", node.name); - let value_type = node - .attrs - .get("value") - .map(|v| v.clone().into_tensor().elem_type()) - .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32 + let value_type = node.attrs.get("value").map_or(ElementType::Float32, |v| { + v.clone().into_tensor().elem_type() + }); // If not given, defaults to 0 as float32 log::debug!( "ConstantOfShape value type for {}: {:?}", node.name, diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs index 3ab2dcf1ef..35981bf428 100644 --- a/crates/onnx-ir/src/node/conv1d.rs +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -24,8 +24,9 @@ pub struct Conv1dConfig { } impl Conv1dConfig { - /// Create a new Conv1dConfig + /// Create a new `Conv1dConfig` #[allow(clippy::too_many_arguments)] + #[must_use] pub fn new( channels_in: usize, channels_out: usize, @@ -41,15 +42,16 @@ impl Conv1dConfig { channels_out, kernel_size, stride, - padding, dilation, groups, bias, + padding, } } } -/// Create a Conv1dConfig from the attributes of the node +/// Create a `Conv1dConfig` from the attributes of the node +#[must_use] pub fn conv1d_config(curr: &Node) -> Conv1dConfig { let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec let mut strides = vec![1]; @@ -67,7 +69,7 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig { // check if the bias is present let bias = curr.inputs.len() == 3; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs index 
6715069e26..12a563e6f0 100644 --- a/crates/onnx-ir/src/node/conv2d.rs +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -21,7 +21,8 @@ pub struct Conv2dConfig { } impl Conv2dConfig { - /// Create a new Conv2dConfig + /// Create a new `Conv2dConfig` + #[must_use] pub fn new( channels: [usize; 2], kernel_size: [usize; 2], @@ -43,7 +44,8 @@ impl Conv2dConfig { } } -/// Create a Conv2dConfig from the attributes of the node +/// Create a `Conv2dConfig` from the attributes of the node +#[must_use] pub fn conv2d_config(curr: &Node) -> Conv2dConfig { let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec let mut strides = vec![1, 1]; @@ -61,7 +63,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { // check if the bias is present let bias = curr.inputs.len() == 3; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), diff --git a/crates/onnx-ir/src/node/conv3d.rs b/crates/onnx-ir/src/node/conv3d.rs index 043ec6a5d6..5373ff96bb 100644 --- a/crates/onnx-ir/src/node/conv3d.rs +++ b/crates/onnx-ir/src/node/conv3d.rs @@ -22,6 +22,7 @@ pub struct Conv3dConfig { impl Conv3dConfig { /// Create a new configuration for a Conv3d. + #[must_use] pub fn new( channels: [usize; 2], kernel_size: [usize; 3], @@ -43,7 +44,8 @@ impl Conv3dConfig { } } -/// Create a Conv3dConfig from the attributes of the node +/// Create a `Conv3dConfig` from the attributes of the node +#[must_use] pub fn conv3d_config(curr: &Node) -> Conv3dConfig { let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec let mut strides = vec![1, 1, 1]; @@ -61,7 +63,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { // check if the bias is present let bias = curr.inputs.len() == 3; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs index 961de340cc..8005ae2229 100644 --- a/crates/onnx-ir/src/node/conv_transpose1d.rs +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -1,6 +1,6 @@ use crate::ir::Node; -/// Configuration for ConvTranspose1d operations extracted from ONNX nodes +/// Configuration for `ConvTranspose1d` operations extracted from ONNX nodes #[derive(Debug, Clone)] pub struct ConvTranspose1dConfig { /// Input channels @@ -24,8 +24,9 @@ pub struct ConvTranspose1dConfig { } impl ConvTranspose1dConfig { - /// Create a new ConvTranspose1dConfig + /// Create a new `ConvTranspose1dConfig` #[allow(clippy::too_many_arguments)] + #[must_use] pub fn new( channels_in: usize, channels_out: usize, @@ -42,16 +43,17 @@ impl ConvTranspose1dConfig { channels_out, kernel_size, stride, - padding, dilation, groups, bias, + padding, padding_out, } } } -/// Create a ConvTranspose1dConfig from the attributes of the node +/// Create a `ConvTranspose1dConfig` from the attributes of the node +#[must_use] pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { let mut kernel_shape = Vec::new(); // Default to empty vector let mut stride = vec![1]; // Default stride to 1 @@ -61,7 +63,7 @@ pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { let mut output_padding = vec![0]; // Default output padding to 0 // Extract attributes - for (key, value) in 
curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => stride = value.clone().into_i64s(), @@ -74,12 +76,10 @@ pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { } // Check the pads are symmetric - if pads.len() != 2 || pads[0] != pads[1] { - panic!( - "Asymmetric padding is not supported for ConvTranspose1d: {:?}", - pads - ); - } + assert!( + !(pads.len() != 2 || pads[0] != pads[1]), + "Asymmetric padding is not supported for ConvTranspose1d: {pads:?}" + ); let weight_shape = curr.inputs[1] .value diff --git a/crates/onnx-ir/src/node/conv_transpose2d.rs b/crates/onnx-ir/src/node/conv_transpose2d.rs index f813ee9136..772f599c73 100644 --- a/crates/onnx-ir/src/node/conv_transpose2d.rs +++ b/crates/onnx-ir/src/node/conv_transpose2d.rs @@ -1,6 +1,6 @@ use crate::ir::Node; -/// Configuration for ConvTranspose2d operations. +/// Configuration for `ConvTranspose2d` operations. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConvTranspose2dConfig { /// Input and output channels [in, out]. @@ -22,8 +22,9 @@ pub struct ConvTranspose2dConfig { } impl ConvTranspose2dConfig { - /// Create a new configuration for a ConvTranspose2d. + /// Create a new configuration for a `ConvTranspose2d`. #[allow(clippy::too_many_arguments)] + #[must_use] pub fn new( channels: [usize; 2], kernel_size: [usize; 2], @@ -47,7 +48,8 @@ impl ConvTranspose2dConfig { } } -/// Create a ConvTranspose2dConfig from the attributes of the node +/// Create a `ConvTranspose2dConfig` from the attributes of the node +#[must_use] pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { let mut kernel_shape = Vec::new(); // Default to empty vector let mut stride = vec![1, 1]; // Default stride to 1 @@ -57,7 +59,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { let mut output_padding = vec![0, 0]; // Default output padding to 0 // Extract attributes - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => stride = value.clone().into_i64s(), diff --git a/crates/onnx-ir/src/node/conv_transpose3d.rs b/crates/onnx-ir/src/node/conv_transpose3d.rs index 288800776e..6d4df0b36f 100644 --- a/crates/onnx-ir/src/node/conv_transpose3d.rs +++ b/crates/onnx-ir/src/node/conv_transpose3d.rs @@ -1,6 +1,6 @@ use crate::ir::Node; -/// Configuration for ConvTranspose3d operations. +/// Configuration for `ConvTranspose3d` operations. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConvTranspose3dConfig { /// Input and output channels [in, out]. @@ -22,8 +22,9 @@ pub struct ConvTranspose3dConfig { } impl ConvTranspose3dConfig { - /// Create a new configuration for a ConvTranspose3d. + /// Create a new configuration for a `ConvTranspose3d`. 
#[allow(clippy::too_many_arguments)] + #[must_use] pub fn new( channels: [usize; 2], kernel_size: [usize; 3], @@ -47,7 +48,8 @@ impl ConvTranspose3dConfig { } } -/// Create a ConvTranspose3dConfig from the attributes of the node +/// Create a `ConvTranspose3dConfig` from the attributes of the node +#[must_use] pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { let mut kernel_shape = Vec::new(); // Default to empty vector let mut stride = vec![1, 1, 1]; // Default stride to 1 @@ -57,7 +59,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { let mut output_padding = vec![0, 0, 0]; // Default output padding to 0 // Extract attributes - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => stride = value.clone().into_i64s(), diff --git a/crates/onnx-ir/src/node/dropout.rs b/crates/onnx-ir/src/node/dropout.rs index 05b3edb2d2..d99a0264d5 100644 --- a/crates/onnx-ir/src/node/dropout.rs +++ b/crates/onnx-ir/src/node/dropout.rs @@ -8,23 +8,26 @@ pub struct DropoutConfig { } impl DropoutConfig { - /// Create a new DropoutConfig + /// Create a new `DropoutConfig` + #[must_use] pub fn new(prob: f64) -> Self { Self { prob } } } -/// Create a DropoutConfig from an attribute and state of the node +/// Create a `DropoutConfig` from an attribute and state of the node +#[must_use] pub fn dropout_config(node: &Node) -> DropoutConfig { // Opset 7 and older store probability as an attribute if node.attrs.contains_key("ratio") { let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); - return DropoutConfig::new(prob as f64); + return DropoutConfig::new(f64::from(prob)); } - if node.inputs.len() < 2 { - panic!("Dropout configuration must have at least 2 inputs"); - } + assert!( + (node.inputs.len() >= 2), + "Dropout configuration must have at least 2 inputs" + ); let ratio = node.inputs[1] .value @@ -35,7 +38,7 @@ pub fn dropout_config(node: &Node) -> DropoutConfig { let prob = match ratio { Data::Float16(ratio) => f64::from(f32::from(ratio)), - Data::Float32(ratio) => ratio as f64, + Data::Float32(ratio) => f64::from(ratio), Data::Float64(ratio) => ratio, _ => panic!("Dropout ratio must be a float"), }; diff --git a/crates/onnx-ir/src/node/expand.rs b/crates/onnx-ir/src/node/expand.rs index a78377de4d..bb57851336 100644 --- a/crates/onnx-ir/src/node/expand.rs +++ b/crates/onnx-ir/src/node/expand.rs @@ -62,10 +62,11 @@ pub enum ExpandShape { Runtime(Argument), } -/// Creates an ExpandShape configuration from the given Node. +/// Creates an `ExpandShape` configuration from the given Node. /// /// Extracts shape information from the node's second input to determine /// whether to use static or runtime shape expansion. +#[must_use] pub fn expand_config(node: &Node) -> ExpandShape { match &node.inputs[1].ty { ArgType::Tensor(tensor) => { diff --git a/crates/onnx-ir/src/node/flatten.rs b/crates/onnx-ir/src/node/flatten.rs index 394f01dc17..72bc1bcb66 100644 --- a/crates/onnx-ir/src/node/flatten.rs +++ b/crates/onnx-ir/src/node/flatten.rs @@ -2,9 +2,10 @@ use crate::ir::{ArgType, Node, TensorType}; /// Update output type for Flatten operation (rank 2). 
pub fn flatten_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Flatten: multiple inputs are not supported"); - } + assert!( + (node.inputs.len() == 1), + "Flatten: multiple inputs are not supported" + ); let tensor = node .inputs .iter() @@ -21,18 +22,18 @@ pub fn flatten_update_outputs(node: &mut Node) { }); } -/// Create a FlattenConfig from the attributes of the node +/// Create a `FlattenConfig` from the attributes of the node +#[must_use] pub fn flatten_config(curr: &Node) -> usize { // the begin dimension is the first dimension (Default: 1 per ONNX spec) let mut axis: i64 = 1; // check if the node has only one input - if curr.inputs.len() != 1 { - panic!( - "Flatten: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } + assert!( + (curr.inputs.len() == 1), + "Flatten: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); // extract the shape of the input tensor let tensor = match curr.inputs.first().unwrap().clone().ty { @@ -41,17 +42,16 @@ pub fn flatten_config(curr: &Node) -> usize { }; // check if the input tensor has at least 2 dimensions - if tensor.rank < 2 { - panic!( - "Flatten: input tensor must have at least 2 dimensions (got {:?})", - tensor.rank - ); - } + assert!( + (tensor.rank >= 2), + "Flatten: input tensor must have at least 2 dimensions (got {:?})", + tensor.rank + ); // extract the attributes - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { if key.as_str() == "axis" { - axis = value.clone().into_i64() + axis = value.clone().into_i64(); } } diff --git a/crates/onnx-ir/src/node/gather.rs b/crates/onnx-ir/src/node/gather.rs index d785f35652..e53ea38180 100644 --- a/crates/onnx-ir/src/node/gather.rs +++ b/crates/onnx-ir/src/node/gather.rs @@ -4,9 +4,10 @@ use crate::ir::{ArgType, ElementType, Node, TensorType}; pub fn gather_update_outputs(node: &mut Node) { log::debug!("Gather rank inference for node {}", node.name); - if node.inputs.len() != 2 { - panic!("Gather requires two inputs: data and indices"); - } + assert!( + (node.inputs.len() == 2), + "Gather requires two inputs: data and indices" + ); let indices_rank = match &node.inputs[1].ty { ArgType::Tensor(tensor) => tensor.rank, @@ -69,31 +70,33 @@ pub fn gather_update_outputs(node: &mut Node) { ); } } - ty => panic!("Only tensor/shape input is valid but received: {:?}", ty), + ty => panic!("Only tensor/shape input is valid but received: {ty:?}"), } } -/// Create a GatherConfig from the attributes of the node +/// Create a `GatherConfig` from the attributes of the node +#[must_use] pub fn gather_config(curr: &Node) -> usize { // Default: 0 per ONNX spec let mut dim: i64 = 0; // check if the node has only one input - if curr.inputs.len() != 2 { - panic!("Gather: index tensor must be present"); - } + assert!( + (curr.inputs.len() == 2), + "Gather: index tensor must be present" + ); // extract the shape of the input tensor let input_dim = match curr.inputs.first().unwrap().clone().ty { ArgType::Tensor(tensor) => tensor.rank as i64, ArgType::Shape(_shape) => 1, // Shape is always 1-D - other => panic!("Only tensor or shape input is valid, got {:?}", other), + other => panic!("Only tensor or shape input is valid, got {other:?}"), }; // extract the attributes - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { if key.as_str() == "axis" { - dim = value.clone().into_i64() + dim = value.clone().into_i64(); } } diff --git a/crates/onnx-ir/src/node/gemm.rs b/crates/onnx-ir/src/node/gemm.rs index 
b81d5ce436..56a9f57cf9 100644 --- a/crates/onnx-ir/src/node/gemm.rs +++ b/crates/onnx-ir/src/node/gemm.rs @@ -34,13 +34,14 @@ pub fn gemm_output_shape(node: &mut Node) { }); } +#[must_use] pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { let mut alpha: f32 = 1.0; let mut beta: f32 = 1.0; let mut trans_a: i64 = 0; let mut trans_b: i64 = 0; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "alpha" => alpha = value.clone().into_f32(), "beta" => beta = value.clone().into_f32(), diff --git a/crates/onnx-ir/src/node/hard_sigmoid.rs b/crates/onnx-ir/src/node/hard_sigmoid.rs index 7b1a517f00..eb77bbca6e 100644 --- a/crates/onnx-ir/src/node/hard_sigmoid.rs +++ b/crates/onnx-ir/src/node/hard_sigmoid.rs @@ -1,14 +1,15 @@ use crate::ir::Node; -/// Create a HardSigmoidConfig from the alpha and beta attributes of the node +/// Create a `HardSigmoidConfig` from the alpha and beta attributes of the node +#[must_use] pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { let mut alpha = 0.2; let mut beta = 0.5; - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { - "alpha" => alpha = value.clone().into_f32() as f64, - "beta" => beta = value.clone().into_f32() as f64, + "alpha" => alpha = f64::from(value.clone().into_f32()), + "beta" => beta = f64::from(value.clone().into_f32()), _ => {} } } diff --git a/crates/onnx-ir/src/node/instance_norm.rs b/crates/onnx-ir/src/node/instance_norm.rs index a7862c6879..8a80049f08 100644 --- a/crates/onnx-ir/src/node/instance_norm.rs +++ b/crates/onnx-ir/src/node/instance_norm.rs @@ -1,6 +1,6 @@ use crate::ir::Node; -/// Configuration for InstanceNorm operations +/// Configuration for `InstanceNorm` operations #[derive(Debug, Clone)] pub struct InstanceNormConfig { /// Number of features (channels) @@ -10,7 +10,8 @@ pub struct InstanceNormConfig { } impl InstanceNormConfig { - /// Create a new InstanceNormConfig + /// Create a new `InstanceNormConfig` + #[must_use] pub fn new(num_features: usize, epsilon: f64) -> Self { Self { num_features, @@ -19,7 +20,8 @@ impl InstanceNormConfig { } } -/// Create a InstanceNormConfig from the attributes of the node +/// Create a `InstanceNormConfig` from the attributes of the node +#[must_use] pub fn instance_norm_config(node: &Node) -> InstanceNormConfig { log::debug!("... 
=> '{:?}'", &node.inputs[1]); let weight_shape = node.inputs[1] @@ -32,14 +34,14 @@ pub fn instance_norm_config(node: &Node) -> InstanceNormConfig { let num_features = weight_shape[0]; let mut epsilon = 1e-5; - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "epsilon" => epsilon = value.clone().into_f32(), _ => panic!("Unexpected attribute for InstanceNorm: {key}"), } } - InstanceNormConfig::new(num_features, epsilon as f64) + InstanceNormConfig::new(num_features, f64::from(epsilon)) } #[cfg(test)] diff --git a/crates/onnx-ir/src/node/layer_norm.rs b/crates/onnx-ir/src/node/layer_norm.rs index e97e0d085a..324cbf8eed 100644 --- a/crates/onnx-ir/src/node/layer_norm.rs +++ b/crates/onnx-ir/src/node/layer_norm.rs @@ -1,6 +1,6 @@ use crate::ir::Node; -/// Configuration for LayerNorm operations +/// Configuration for `LayerNorm` operations #[derive(Debug, Clone)] pub struct LayerNormConfig { /// Number of features/model dimension @@ -10,7 +10,8 @@ pub struct LayerNormConfig { } impl LayerNormConfig { - /// Create a new LayerNormConfig + /// Create a new `LayerNormConfig` + #[must_use] pub fn new(d_model: usize) -> Self { Self { d_model, @@ -19,13 +20,15 @@ impl LayerNormConfig { } /// Set the epsilon value + #[must_use] pub fn with_epsilon(mut self, epsilon: f64) -> Self { self.epsilon = epsilon; self } } -/// Create a LayerNormConfig from the attributes of the node +/// Create a `LayerNormConfig` from the attributes of the node +#[must_use] pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { let weight_shape = node.inputs[1] .value @@ -42,7 +45,7 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { let mut axis = -1; let mut epsilon = 1e-5; - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "axis" => axis = value.clone().into_i64(), "epsilon" => epsilon = value.clone().into_f32(), @@ -51,12 +54,13 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { } } - if axis != -1 && axis != weight_shape.len() as i64 - 1 { - panic!("LayerNorm: normalization is only supported on the last axis right now") - } + assert!( + !(axis != -1 && axis != weight_shape.len() as i64 - 1), + "LayerNorm: normalization is only supported on the last axis right now" + ); ( - LayerNormConfig::new(num_features).with_epsilon(epsilon as f64), + LayerNormConfig::new(num_features).with_epsilon(f64::from(epsilon)), stash_type == 1, ) } diff --git a/crates/onnx-ir/src/node/leaky_relu.rs b/crates/onnx-ir/src/node/leaky_relu.rs index ebeff2464a..7bcc3178f8 100644 --- a/crates/onnx-ir/src/node/leaky_relu.rs +++ b/crates/onnx-ir/src/node/leaky_relu.rs @@ -1,12 +1,13 @@ use crate::ir::Node; -/// Create a LeakyReluConfig from the alpha attribute of the node +/// Create a `LeakyReluConfig` from the alpha attribute of the node +#[must_use] pub fn leaky_relu_config(node: &Node) -> f64 { let mut alpha = 0.01; - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { if key.as_str() == "alpha" { - alpha = value.clone().into_f32() as f64 + alpha = f64::from(value.clone().into_f32()); } } diff --git a/crates/onnx-ir/src/node/linear.rs b/crates/onnx-ir/src/node/linear.rs index a9afa761ac..2829e8cccb 100644 --- a/crates/onnx-ir/src/node/linear.rs +++ b/crates/onnx-ir/src/node/linear.rs @@ -12,7 +12,8 @@ pub struct LinearConfig { } impl LinearConfig { - /// Create a new LinearConfig + /// Create a new `LinearConfig` + #[must_use] pub fn new(d_input: usize, d_output: usize) -> Self { 
Self { d_input, @@ -22,6 +23,7 @@ impl LinearConfig { } /// Set whether bias is used + #[must_use] pub fn with_bias(mut self, bias: bool) -> Self { self.bias = bias; self @@ -47,11 +49,10 @@ pub fn linear_update_outputs(node: &mut Node) { } } -/// Create a LinearConfig from the attributes of the node +/// Create a `LinearConfig` from the attributes of the node +#[must_use] pub fn linear_config(node: &Node) -> LinearConfig { - if node.inputs.len() < 2 { - panic!("Linear: missing weight tensor"); - } + assert!((node.inputs.len() >= 2), "Linear: missing weight tensor"); let weight_shape = node.inputs[1] .value @@ -61,12 +62,11 @@ pub fn linear_config(node: &Node) -> LinearConfig { .clone(); // check if the weight tensor has at least 2 dimensions - if weight_shape.len() < 2 { - panic!( - "Linear: weight tensor must have at least 2 dimensions (got {:?})", - weight_shape.len() - ); - } + assert!( + (weight_shape.len() >= 2), + "Linear: weight tensor must have at least 2 dimensions (got {:?})", + weight_shape.len() + ); let (in_size, out_size) = (weight_shape[0], weight_shape[1]); diff --git a/crates/onnx-ir/src/node/log_softmax.rs b/crates/onnx-ir/src/node/log_softmax.rs index c06e4b74e1..a4590b60bf 100644 --- a/crates/onnx-ir/src/node/log_softmax.rs +++ b/crates/onnx-ir/src/node/log_softmax.rs @@ -1,17 +1,17 @@ use crate::ir::{ArgType, Node}; -/// Create log_softmax config from the attributes of the node +/// Create `log_softmax` config from the attributes of the node +#[must_use] pub fn log_softmax_config(node: &Node) -> usize { // the axis is the last dimension (Default: 1 per ONNX spec) let mut axis: i64 = -1; // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "LogSoftmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } + assert!( + (node.inputs.len() == 1), + "LogSoftmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); // extract the shape of the input tensor let tensor = match node.inputs.first().unwrap().clone().ty { @@ -20,9 +20,9 @@ pub fn log_softmax_config(node: &Node) -> usize { }; // extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { if key.as_str() == "axis" { - axis = value.clone().into_i64() + axis = value.clone().into_i64(); } } diff --git a/crates/onnx-ir/src/node/matmul.rs b/crates/onnx-ir/src/node/matmul.rs index 403f5d4942..bc2a975894 100644 --- a/crates/onnx-ir/src/node/matmul.rs +++ b/crates/onnx-ir/src/node/matmul.rs @@ -1,7 +1,7 @@ use crate::ir::{ArgType, Node, TensorType}; use core::cmp::max; -/// Update output rank for MatMul based on input ranks. +/// Update output rank for `MatMul` based on input ranks. 
pub fn matmul_update_outputs(node: &mut Node) { log::debug!("MatMul rank inference for node {}", node.name); diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs index 6112e7fba1..c46d1bb2e4 100644 --- a/crates/onnx-ir/src/node/max_pool1d.rs +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -2,7 +2,7 @@ use crate::{ir::Node, node::padding::padding_config_1d}; use super::padding::PaddingConfig1d; -/// Configuration for MaxPool1d operations extracted from ONNX nodes +/// Configuration for `MaxPool1d` operations extracted from ONNX nodes #[derive(Debug, Clone)] pub struct MaxPool1dConfig { /// Kernel size @@ -16,7 +16,8 @@ pub struct MaxPool1dConfig { } impl MaxPool1dConfig { - /// Create a new MaxPool1dConfig + /// Create a new `MaxPool1dConfig` + #[must_use] pub fn new(kernel_size: usize) -> Self { Self { kernel_size, @@ -27,32 +28,36 @@ impl MaxPool1dConfig { } /// Set the stride + #[must_use] pub fn with_stride(mut self, stride: usize) -> Self { self.stride = stride; self } /// Set the padding configuration + #[must_use] pub fn with_padding(mut self, padding: PaddingConfig1d) -> Self { self.padding = padding; self } /// Set the dilation + #[must_use] pub fn with_dilation(mut self, dilation: usize) -> Self { self.dilation = dilation; self } } -/// Create a MaxPool1dConfig from the attributes of the node +/// Create a `MaxPool1dConfig` from the attributes of the node +#[must_use] pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { let mut kernel_shape = Vec::new(); let mut stride = vec![1]; let mut pads = vec![0, 0]; let mut dilation = vec![1]; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => stride = value.clone().into_i64s(), diff --git a/crates/onnx-ir/src/node/max_pool2d.rs b/crates/onnx-ir/src/node/max_pool2d.rs index 9883f86b6a..997b1bd0a6 100644 --- a/crates/onnx-ir/src/node/max_pool2d.rs +++ b/crates/onnx-ir/src/node/max_pool2d.rs @@ -1,7 +1,7 @@ use crate::ir::Node; use crate::node::padding::{PaddingConfig2d, padding_config_2d}; -/// Configuration for MaxPool2d operations +/// Configuration for `MaxPool2d` operations #[derive(Debug, Clone)] pub struct MaxPool2dConfig { /// Kernel size [height, width] @@ -15,7 +15,8 @@ pub struct MaxPool2dConfig { } impl MaxPool2dConfig { - /// Create a new MaxPool2dConfig + /// Create a new `MaxPool2dConfig` + #[must_use] pub fn new(kernel_size: [usize; 2]) -> Self { Self { kernel_size, @@ -26,32 +27,36 @@ impl MaxPool2dConfig { } /// Set the strides + #[must_use] pub fn with_strides(mut self, strides: [usize; 2]) -> Self { self.strides = strides; self } /// Set the padding configuration + #[must_use] pub fn with_padding(mut self, padding: PaddingConfig2d) -> Self { self.padding = padding; self } /// Set the dilation + #[must_use] pub fn with_dilation(mut self, dilation: [usize; 2]) -> Self { self.dilation = dilation; self } } -/// Create a MaxPool2dConfig from the attributes of the node +/// Create a `MaxPool2dConfig` from the attributes of the node +#[must_use] pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { let mut kernel_shape = Vec::new(); let mut strides = vec![1, 1]; let mut pads = vec![0, 0, 0, 0]; let mut dilations = vec![1, 1]; - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), diff --git 
a/crates/onnx-ir/src/node/one_hot.rs b/crates/onnx-ir/src/node/one_hot.rs index 1ccff14d45..8dc3d52b06 100644 --- a/crates/onnx-ir/src/node/one_hot.rs +++ b/crates/onnx-ir/src/node/one_hot.rs @@ -1,5 +1,6 @@ use crate::ir::{ArgType, Node, TensorType}; +#[must_use] pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { let depth = curr.inputs[1] .value @@ -18,13 +19,12 @@ pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { let axis = curr .attrs .get("axis") - .map(|val| val.clone().into_i64()) - .unwrap_or(-1); + .map_or(-1, |val| val.clone().into_i64()); (depth as usize, values.try_into().unwrap(), axis) } -/// Update output rank for OneHot (input rank + 1). +/// Update output rank for `OneHot` (input rank + 1). pub fn one_hot_output_shape(node: &mut Node) { log::debug!("OneHot rank inference for node {}", node.name); diff --git a/crates/onnx-ir/src/node/pad.rs b/crates/onnx-ir/src/node/pad.rs index c1eceda76f..fa3d4325a1 100644 --- a/crates/onnx-ir/src/node/pad.rs +++ b/crates/onnx-ir/src/node/pad.rs @@ -10,6 +10,7 @@ pub struct PadConfig { } impl PadConfig { + #[must_use] pub fn new(pads: Vec, constant_value: f32) -> Self { PadConfig { pads, @@ -18,7 +19,8 @@ impl PadConfig { } } -/// Creates a PadConfig from the node attributes and inputs. +/// Creates a `PadConfig` from the node attributes and inputs. +#[must_use] pub fn pad_config(node: &Node) -> PadConfig { fn get_pads_input(node: &Node) -> Vec { if node.inputs.len() <= 1 { @@ -31,12 +33,8 @@ pub fn pad_config(node: &Node) -> PadConfig { } } fn get_pads(node: &Node) -> Vec { - if node.inputs.is_empty() { - panic!("Pad: must provide data as input") - } - if node.inputs.len() >= 4 { - panic!("Pad: axes input is not supported") - } + assert!(!node.inputs.is_empty(), "Pad: must provide data as input"); + assert!((node.inputs.len() < 4), "Pad: axes input is not supported"); let input_dim = match &node.inputs.first().unwrap().ty { ArgType::Tensor(tensor) => tensor.rank, @@ -49,7 +47,7 @@ pub fn pad_config(node: &Node) -> PadConfig { .map(|x| x as usize) .collect(); - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "pads" => { pads = value @@ -57,35 +55,37 @@ pub fn pad_config(node: &Node) -> PadConfig { .into_i64s() .iter() .map(|&x| { - if x < 0 { - panic!("Pad: Negative pad is not supported"); - } + assert!((x >= 0), "Pad: Negative pad is not supported"); x as usize }) - .collect() + .collect(); } "mode" => { let mode = value.clone().into_string(); - if mode != "constant" { - panic!("only constant mode is supported, given mode is {}", mode); - } + assert!( + (mode == "constant"), + "only constant mode is supported, given mode is {mode}" + ); } _ => {} } } - if pads.is_empty() { - panic!("Pad: pads should be given as attribute or as input"); - } + assert!( + !pads.is_empty(), + "Pad: pads should be given as attribute or as input" + ); - if pads.len() != input_dim * 2 { - panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); - } + assert!( + (pads.len() == input_dim * 2), + "Pad: pads should be a 1D tensor of shape [2 * num_axes]" + ); // TODO: Burn's pad should support 1D tensor - if input_dim < 2 { - panic!("Pad: input tensor should be rank 2 or higher"); - } + assert!( + (input_dim >= 2), + "Pad: input tensor should be rank 2 or higher" + ); let left_index = input_dim - 1; let top_index = input_dim - 2; @@ -94,11 +94,10 @@ pub fn pad_config(node: &Node) -> PadConfig { let index_list = [left_index, top_index, right_index, bottom_index]; for (index, &item) 
in pads.iter().enumerate() { - if !index_list.contains(&index) && item != 0 { - panic!( - "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" - ); - } + assert!( + index_list.contains(&index) || item == 0, + "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" + ); } let left = pads[left_index]; diff --git a/crates/onnx-ir/src/node/padding.rs b/crates/onnx-ir/src/node/padding.rs index 3fba910027..279211e7aa 100644 --- a/crates/onnx-ir/src/node/padding.rs +++ b/crates/onnx-ir/src/node/padding.rs @@ -13,7 +13,7 @@ impl fmt::Display for PaddingConfig1d { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { PaddingConfig1d::Valid => write!(f, "Valid"), - PaddingConfig1d::Explicit(size) => write!(f, "Explicit({})", size), + PaddingConfig1d::Explicit(size) => write!(f, "Explicit({size})"), } } } @@ -36,7 +36,8 @@ impl fmt::Display for PaddingConfig1d { /// # Remarks /// /// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +/// and not used when the padding is specified as a string, e.g. "`SAME_UPPER`". +#[must_use] pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { let [left, right] = [pads[0], pads[1]]; @@ -52,7 +53,7 @@ pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { PaddingConfig1d::Explicit(left as usize) } else { // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); + panic!("Padding configuration ({pads:?}) not supported"); } } @@ -70,7 +71,7 @@ impl fmt::Display for PaddingConfig2d { match self { PaddingConfig2d::Valid => write!(f, "Valid"), PaddingConfig2d::Explicit(width, height) => { - write!(f, "Explicit({}, {})", width, height) + write!(f, "Explicit({width}, {height})") } } } @@ -90,7 +91,7 @@ impl fmt::Display for PaddingConfig3d { match self { PaddingConfig3d::Valid => write!(f, "Valid"), PaddingConfig3d::Explicit(width, height, depth) => { - write!(f, "Explicit({}, {}, {})", width, height, depth) + write!(f, "Explicit({width}, {height}, {depth})") } } } @@ -114,7 +115,8 @@ impl fmt::Display for PaddingConfig3d { /// # Remarks /// /// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +/// and not used when the padding is specified as a string, e.g. "`SAME_UPPER`". +#[must_use] pub fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { let [top, left, bottom, right] = [pads[0], pads[1], pads[2], pads[3]]; @@ -128,7 +130,7 @@ pub fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { PaddingConfig2d::Explicit(top as usize, left as usize) } else { // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); + panic!("Padding configuration ({pads:?}) not supported"); } } @@ -150,7 +152,8 @@ pub fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { /// # Remarks /// /// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +/// and not used when the padding is specified as a string, e.g. "`SAME_UPPER`". 
+#[must_use] pub fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { let [front, top, left, back, bottom, right] = [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; @@ -165,7 +168,7 @@ pub fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { PaddingConfig3d::Explicit(front as usize, top as usize, left as usize) } else { // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); + panic!("Padding configuration ({pads:?}) not supported"); } } diff --git a/crates/onnx-ir/src/node/random.rs b/crates/onnx-ir/src/node/random.rs index c441cea478..0771f5f16e 100644 --- a/crates/onnx-ir/src/node/random.rs +++ b/crates/onnx-ir/src/node/random.rs @@ -6,11 +6,9 @@ use protobuf::Enum; pub fn random_update_output(node: &mut Node) { log::debug!("Random rank inference for node {}", node.name); - let dtype = node - .attrs - .get("dtype") - .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) - .unwrap_or(DataType::FLOAT); + let dtype = node.attrs.get("dtype").map_or(DataType::FLOAT, |val| { + DataType::from_i32(val.clone().into_i32()).unwrap() + }); log::debug!("Random dtype for {}: {:?}", node.name, dtype); let shape = node @@ -47,7 +45,7 @@ mod tests { fn create_test_node(dtype: i32, shape: Vec) -> Node { NodeBuilder::new(NodeType::RandomNormal, "test_random") .output_tensor_f32("output", 0, None) // Rank 0 will be updated - .attr_int("dtype", dtype as i64) + .attr_int("dtype", i64::from(dtype)) .attr_ints("shape", shape) .build() } diff --git a/crates/onnx-ir/src/node/random_like.rs b/crates/onnx-ir/src/node/random_like.rs index 25a09b8d70..c77db725cf 100644 --- a/crates/onnx-ir/src/node/random_like.rs +++ b/crates/onnx-ir/src/node/random_like.rs @@ -2,15 +2,13 @@ use crate::ir::{ArgType, ElementType, Node, TensorType}; use crate::protos::tensor_proto::DataType; use protobuf::Enum; -/// Update output rank for RandomLike operations based on input rank. +/// Update output rank for `RandomLike` operations based on input rank. 
pub fn random_like_update_output(node: &mut Node) { log::debug!("RandomLike rank inference for node {}", node.name); - let dtype = node - .attrs - .get("dtype") - .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) - .unwrap_or(DataType::FLOAT); + let dtype = node.attrs.get("dtype").map_or(DataType::FLOAT, |val| { + DataType::from_i32(val.clone().into_i32()).unwrap() + }); log::debug!("RandomLike dtype for {}: {:?}", node.name, dtype); let elem_type = match dtype { @@ -46,7 +44,7 @@ mod tests { NodeBuilder::new(NodeType::RandomNormalLike, "test_random_like") .input_tensor_f32("input", input_rank, static_shape) .output_tensor_f32("output", 0, None) // Rank 0 will be updated - .attr_int("dtype", dtype as i64) + .attr_int("dtype", i64::from(dtype)) .build() } diff --git a/crates/onnx-ir/src/node/range.rs b/crates/onnx-ir/src/node/range.rs index 8a21ca5086..0db1567b59 100644 --- a/crates/onnx-ir/src/node/range.rs +++ b/crates/onnx-ir/src/node/range.rs @@ -4,9 +4,11 @@ use crate::ir::{ArgType, ElementType, Node, TensorType}; pub fn range_update_outputs(node: &mut Node) { log::debug!("Range rank inference for node {}", node.name); - if node.inputs.len() != 3 { - panic!("Range: expected 3 inputs, found {}", node.inputs.len()); - } + assert!( + (node.inputs.len() == 3), + "Range: expected 3 inputs, found {}", + node.inputs.len() + ); log::debug!( "Range operation always produces rank 1 tensor for {}", node.name diff --git a/crates/onnx-ir/src/node/reduce_max.rs b/crates/onnx-ir/src/node/reduce_max.rs index 2cae7dfa99..912bae820d 100644 --- a/crates/onnx-ir/src/node/reduce_max.rs +++ b/crates/onnx-ir/src/node/reduce_max.rs @@ -1,6 +1,7 @@ use crate::ir::{ArgType, AttributeValue, Node, TensorType}; -/// Create a ReduceMaxConfig from the attributes of the node +/// Create a `ReduceMaxConfig` from the attributes of the node +#[must_use] pub fn reduce_max_config(node: &Node) -> Option { let mut axes = Vec::new(); let mut keepdims = 1; @@ -11,7 +12,7 @@ pub fn reduce_max_config(node: &Node) -> Option { }; // Extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "axes" => axes = value.clone().into_i64s(), "keepdims" => keepdims = value.clone().into_i64(), @@ -19,18 +20,21 @@ pub fn reduce_max_config(node: &Node) -> Option { } } - if axes.len() > 1 { - panic!("ReduceMax: reducing on multiple dimensions is not supported") - } + assert!( + (axes.len() <= 1), + "ReduceMax: reducing on multiple dimensions is not supported" + ); - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMax: axes must be provided with keepdims") - } + assert!( + !(axes.is_empty() && keepdims == 1), + "ReduceMax: axes must be provided with keepdims" + ); - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMax: the reduce operation must preserve the reduced dimension") - } + // Not supported in Burn + assert!( + axes.is_empty() || keepdims != 0, + "ReduceMax: the reduce operation must preserve the reduced dimension" + ); if axes.is_empty() { None @@ -45,13 +49,14 @@ pub fn reduce_max_config(node: &Node) -> Option { } } -/// Update output rank for ReduceMax based on axes. +/// Update output rank for `ReduceMax` based on axes. 
pub fn reduce_max_update_outputs(node: &mut Node) { log::debug!("ReduceMax rank inference for node {}", node.name); - if node.inputs.len() != 1 { - panic!("ReduceMax: multiple inputs are not supported"); - } + assert!( + (node.inputs.len() == 1), + "ReduceMax: multiple inputs are not supported" + ); let tensor = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor, _ => panic!("Only tensor input is valid"), diff --git a/crates/onnx-ir/src/node/reduce_mean.rs b/crates/onnx-ir/src/node/reduce_mean.rs index 81849d2504..80383f8877 100644 --- a/crates/onnx-ir/src/node/reduce_mean.rs +++ b/crates/onnx-ir/src/node/reduce_mean.rs @@ -1,6 +1,7 @@ use crate::ir::{ArgType, AttributeValue, Node, TensorType}; -/// Create a ReduceMeanConfig from the attributes of the node +/// Create a `ReduceMeanConfig` from the attributes of the node +#[must_use] pub fn reduce_mean_config(node: &Node) -> Option { let mut axes = Vec::new(); let mut keepdims = 1; @@ -11,7 +12,7 @@ pub fn reduce_mean_config(node: &Node) -> Option { }; // Extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "axes" => axes = value.clone().into_i64s(), "keepdims" => keepdims = value.clone().into_i64(), @@ -19,18 +20,21 @@ pub fn reduce_mean_config(node: &Node) -> Option { } } - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } + assert!( + (axes.len() <= 1), + "ReduceMean: reducing on multiple dimensions is not supported" + ); - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } + assert!( + !(axes.is_empty() && keepdims == 1), + "ReduceMean: axes must be provided with keepdims" + ); - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must preserve the reduced dimension") - } + // Not supported in Burn + assert!( + axes.is_empty() || keepdims != 0, + "ReduceMean: the reduce operation must preserve the reduced dimension" + ); if axes.is_empty() { None @@ -45,13 +49,14 @@ pub fn reduce_mean_config(node: &Node) -> Option { } } -/// Update output rank for ReduceMean based on axes. +/// Update output rank for `ReduceMean` based on axes. 
pub fn reduce_mean_update_outputs(node: &mut Node) { log::debug!("ReduceMean rank inference for node {}", node.name); - if node.inputs.len() != 1 { - panic!("ReduceMean: multiple inputs are not supported"); - } + assert!( + (node.inputs.len() == 1), + "ReduceMean: multiple inputs are not supported" + ); let tensor = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor, _ => panic!("Only tensor input is valid"), diff --git a/crates/onnx-ir/src/node/reduce_min.rs b/crates/onnx-ir/src/node/reduce_min.rs index 494079454f..1254cf54a4 100644 --- a/crates/onnx-ir/src/node/reduce_min.rs +++ b/crates/onnx-ir/src/node/reduce_min.rs @@ -1,6 +1,7 @@ use crate::ir::{ArgType, AttributeValue, Node, TensorType}; -/// Create a ReduceMinConfig from the attributes of the node +/// Create a `ReduceMinConfig` from the attributes of the node +#[must_use] pub fn reduce_min_config(node: &Node) -> Option { let mut axes = Vec::new(); let mut keepdims = 1; @@ -11,7 +12,7 @@ pub fn reduce_min_config(node: &Node) -> Option { }; // Extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "axes" => axes = value.clone().into_i64s(), "keepdims" => keepdims = value.clone().into_i64(), @@ -19,17 +20,20 @@ pub fn reduce_min_config(node: &Node) -> Option { } } - if axes.len() > 1 { - panic!("ReduceMin: reducing on multiple dimensions is not supported") - } + assert!( + (axes.len() <= 1), + "ReduceMin: reducing on multiple dimensions is not supported" + ); - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMin: axes must be provided with keepdims") - } + assert!( + !(axes.is_empty() && keepdims == 1), + "ReduceMin: axes must be provided with keepdims" + ); - if !axes.is_empty() && keepdims == 0 { - panic!("ReduceMin: the reduce operation must preserve the reduced dimension") - } + assert!( + axes.is_empty() || keepdims != 0, + "ReduceMin: the reduce operation must preserve the reduced dimension" + ); if axes.is_empty() { None @@ -43,13 +47,14 @@ pub fn reduce_min_config(node: &Node) -> Option { } } -/// Update output rank for ReduceMin based on axes. +/// Update output rank for `ReduceMin` based on axes. 
pub fn reduce_min_update_outputs(node: &mut Node) { log::debug!("ReduceMin rank inference for node {}", node.name); - if node.inputs.len() != 1 { - panic!("ReduceMin: multiple inputs are not supported"); - } + assert!( + (node.inputs.len() == 1), + "ReduceMin: multiple inputs are not supported" + ); let tensor = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor, _ => panic!("Only tensor input is valid"), diff --git a/crates/onnx-ir/src/node/reduce_prod.rs b/crates/onnx-ir/src/node/reduce_prod.rs index 62196a7b4b..8698a0e62d 100644 --- a/crates/onnx-ir/src/node/reduce_prod.rs +++ b/crates/onnx-ir/src/node/reduce_prod.rs @@ -1,6 +1,7 @@ use crate::ir::{ArgType, AttributeValue, Node, TensorType}; -/// Create a ReduceProdConfig from the attributes of the node +/// Create a `ReduceProdConfig` from the attributes of the node +#[must_use] pub fn reduce_prod_config(node: &Node) -> Option { let mut axes = Vec::new(); let mut keepdims = 1; @@ -11,7 +12,7 @@ pub fn reduce_prod_config(node: &Node) -> Option { }; // Extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "axes" => axes = value.clone().into_i64s(), "keepdims" => keepdims = value.clone().into_i64(), @@ -20,18 +21,21 @@ pub fn reduce_prod_config(node: &Node) -> Option { } } - if axes.len() > 1 { - panic!("ReduceProd: reducing on multiple dimensions is not supported") - } + assert!( + (axes.len() <= 1), + "ReduceProd: reducing on multiple dimensions is not supported" + ); - if axes.is_empty() && keepdims == 1 { - panic!("ReduceProd: axes must be provided with keepdims") - } + assert!( + !(axes.is_empty() && keepdims == 1), + "ReduceProd: axes must be provided with keepdims" + ); - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceProd: the reduce operation must preserve the reduced dimension") - } + // Not supported in Burn + assert!( + axes.is_empty() || keepdims != 0, + "ReduceProd: the reduce operation must preserve the reduced dimension" + ); if axes.is_empty() { None @@ -46,13 +50,14 @@ pub fn reduce_prod_config(node: &Node) -> Option { } } -/// Update output rank for ReduceProd based on axes. +/// Update output rank for `ReduceProd` based on axes. 
pub fn reduce_prod_update_outputs(node: &mut Node) { log::debug!("ReduceProd rank inference for node {}", node.name); - if node.inputs.len() != 1 { - panic!("ReduceProd: multiple inputs are not supported"); - } + assert!( + (node.inputs.len() == 1), + "ReduceProd: multiple inputs are not supported" + ); let tensor = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor, _ => panic!("Only tensor input is valid"), diff --git a/crates/onnx-ir/src/node/reduce_sum.rs b/crates/onnx-ir/src/node/reduce_sum.rs index 6b2fae1283..3b8ae5262f 100644 --- a/crates/onnx-ir/src/node/reduce_sum.rs +++ b/crates/onnx-ir/src/node/reduce_sum.rs @@ -1,6 +1,7 @@ use crate::ir::{ArgType, AttributeValue, Data, Node, TensorType}; -/// Create a ReduceSumConfig from the attributes of the node +/// Create a `ReduceSumConfig` from the attributes of the node +#[must_use] pub fn reduce_sum_config(node: &Node) -> Option { let mut axes = Vec::new(); let mut keepdims = 1; @@ -11,7 +12,7 @@ pub fn reduce_sum_config(node: &Node) -> Option { }; // Extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "keepdims" => keepdims = value.clone().into_i64(), "axes" => axes = value.clone().into_i64s(), @@ -29,18 +30,21 @@ pub fn reduce_sum_config(node: &Node) -> Option { axes = value.clone().data.into_i64s(); } - if axes.len() > 1 { - panic!("ReduceSum: reducing on multiple dimensions is not supported") - } + assert!( + (axes.len() <= 1), + "ReduceSum: reducing on multiple dimensions is not supported" + ); - if axes.is_empty() && keepdims == 1 { - panic!("ReduceSum: axes must be provided with keepdims") - } + assert!( + !(axes.is_empty() && keepdims == 1), + "ReduceSum: axes must be provided with keepdims" + ); - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceSum: the reduce operation must preserve the reduced dimension") - } + // Not supported in Burn + assert!( + axes.is_empty() || keepdims != 0, + "ReduceSum: the reduce operation must preserve the reduced dimension" + ); if axes.is_empty() { None @@ -55,7 +59,7 @@ pub fn reduce_sum_config(node: &Node) -> Option { } } -/// Update output rank for ReduceSum based on axes. +/// Update output rank for `ReduceSum` based on axes. 
pub fn reduce_sum_update_outputs(node: &mut Node) { log::debug!("ReduceSum rank inference for node {}", node.name); diff --git a/crates/onnx-ir/src/node/reshape.rs b/crates/onnx-ir/src/node/reshape.rs index b48b4176b4..6f1dbce14d 100644 --- a/crates/onnx-ir/src/node/reshape.rs +++ b/crates/onnx-ir/src/node/reshape.rs @@ -6,21 +6,20 @@ pub fn reshape_update_outputs(node: &mut Node) { let shape = if node.inputs.len() == 2 { log::debug!("Reshape node {} has shape as second input", node.name); - match &node.inputs[1].value { - Some(value) => match &value.data { + if let Some(value) = &node.inputs[1].value { + match &value.data { Data::Int64s(shape) => { log::debug!("Reshape node {} has constant shape: {:?}", node.name, shape); Some(shape.clone()) } _ => panic!("Reshape: invalid input types"), - }, - None => { - log::debug!( - "Reshape node {} has dynamic shape as second input", - node.name - ); - None } + } else { + log::debug!( + "Reshape node {} has dynamic shape as second input", + node.name + ); + None } } else { log::debug!("Reshape node {} using shape from attributes", node.name); @@ -50,10 +49,11 @@ pub fn reshape_update_outputs(node: &mut Node) { }); } +#[must_use] pub fn reshape_config(node: &Node) -> Vec { let mut allowzero = 0; - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "allowzero" => allowzero = value.clone().into_i64(), "shape" => {} // This can be used when shape is not provided as input - handled elsewhere @@ -63,14 +63,13 @@ pub fn reshape_config(node: &Node) -> Vec { // Burn does not support zero size shape (0 means false in ONNX) // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) - if allowzero != 0 { - panic!("Zero shape size is not supported"); - } + assert!((allowzero == 0), "Zero shape size is not supported"); // TODO: check "shape" attribute - if node.inputs.len() != 2 || node.inputs[1].value.is_none() { - panic!("Reshape: shape tensor must be present for {:?}", node); - } + assert!( + !(node.inputs.len() != 2 || node.inputs[1].value.is_none()), + "Reshape: shape tensor must be present for {node:?}" + ); match &node.inputs[1].value { Some(TensorData { data, shape, .. }) => { diff --git a/crates/onnx-ir/src/node/resize.rs b/crates/onnx-ir/src/node/resize.rs index 4d43e58824..6e857a6c90 100644 --- a/crates/onnx-ir/src/node/resize.rs +++ b/crates/onnx-ir/src/node/resize.rs @@ -1,7 +1,8 @@ use crate::ir::{ArgType, Node, TensorData}; +#[must_use] pub fn resize_config(node: &Node) -> (String, Vec, Vec) { - let mut mode: String = "".to_string(); + let mut mode: String = String::new(); let mut scales: Vec; let mut sizes: Vec; @@ -22,7 +23,7 @@ pub fn resize_config(node: &Node) -> (String, Vec, Vec) { // However, some attributes are important to be checked and we are checking // against the default values of the attributes. 
// TODO revisit this when we have more Resize operators in the model - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "antialias" => assert_eq!( value.clone().into_i32(), @@ -31,7 +32,7 @@ pub fn resize_config(node: &Node) -> (String, Vec, Vec) { ), "axes" => panic!("Resize: custom axes attribute is not supported"), "coordinate_transformation_mode" => { - log::warn!("Resize: coordinate_transformation_mode is ignored") + log::warn!("Resize: coordinate_transformation_mode is ignored"); } "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"), @@ -50,7 +51,7 @@ pub fn resize_config(node: &Node) -> (String, Vec, Vec) { value.clone().into_string().to_lowercase(), "stretch", "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported" - ) + ); } "mode" => mode = value.clone().into_string().to_lowercase(), "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"), @@ -99,30 +100,27 @@ pub fn resize_config(node: &Node) -> (String, Vec, Vec) { }) .unwrap_or_default(); - if mode.is_empty() { - panic!("Resize: mode attribute is required") - } + assert!(!mode.is_empty(), "Resize: mode attribute is required"); - if !roi.is_empty() { - panic!("Resize: roi input is not supported") - } + assert!(roi.is_empty(), "Resize: roi input is not supported"); - if scales.is_empty() && sizes.is_empty() { - panic!("Resize: either scales or sizes input is required") - } + assert!( + !(scales.is_empty() && sizes.is_empty()), + "Resize: either scales or sizes input is required" + ); if !scales.is_empty() { assert!(scales.len() == input.rank); // ignore the fist two items from scales // because they are the batch and channel dimensions - scales = scales.iter().skip(2).cloned().collect(); + scales = scales.iter().skip(2).copied().collect(); } if !sizes.is_empty() { assert!(sizes.len() == input.rank); // ignore the fist two items from sizes // because they are the batch and channel dimensions - sizes = sizes.iter().skip(2).cloned().collect(); + sizes = sizes.iter().skip(2).copied().collect(); } (mode, scales, sizes) diff --git a/crates/onnx-ir/src/node/shape.rs b/crates/onnx-ir/src/node/shape.rs index 1d4a9770e1..2e7adee196 100644 --- a/crates/onnx-ir/src/node/shape.rs +++ b/crates/onnx-ir/src/node/shape.rs @@ -1,12 +1,12 @@ use crate::ir::{ArgType, Node}; +#[must_use] pub fn shape_config(curr: &Node) -> (usize, usize) { - if curr.inputs.len() != 1 { - panic!( - "Shape: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } + assert!( + (curr.inputs.len() == 1), + "Shape: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); // Extract the shape of the input tensor let tensor = match curr.inputs.first().unwrap().clone().ty { @@ -19,7 +19,7 @@ pub fn shape_config(curr: &Node) -> (usize, usize) { let mut end_dim: i64 = tensor.rank as i64; // Extract the attributes - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "start" => start_dim = value.clone().into_i64(), "end" => end_dim = value.clone().into_i64(), @@ -40,9 +40,10 @@ pub fn shape_config(curr: &Node) -> (usize, usize) { /// Update output type for Shape operation (rank 1). 
pub fn shape_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Shape: multiple inputs are not supported: {:?}", node); - } + assert!( + (node.inputs.len() == 1), + "Shape: multiple inputs are not supported: {node:?}" + ); let (start, end) = shape_config(node); let dim = end - start; log::debug!( diff --git a/crates/onnx-ir/src/node/slice.rs b/crates/onnx-ir/src/node/slice.rs index f1441db485..af22c9997d 100644 --- a/crates/onnx-ir/src/node/slice.rs +++ b/crates/onnx-ir/src/node/slice.rs @@ -5,6 +5,7 @@ use crate::ir::{ArgType, Data, Node, TensorData}; /// /// Note: we leave the negative indices as is, but we need to handle them properly when slicing /// during the actual slicing operation using the dynamic shape information. +#[must_use] pub fn slice_config(node: &Node) -> Vec> { /// Extracts int64 values from a node's input at the specified index. /// Returns an empty vector if the input is not provided. @@ -30,7 +31,7 @@ pub fn slice_config(node: &Node) -> Vec> { // Reference: https://burn.dev/docs/burn/prelude/struct.Tensor.html#method.slice // TODO: Default missing axes ranges to the full range of the corresponding axis - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "starts" => starts = value.clone().into_i64s(), "ends" => ends = value.clone().into_i64s(), @@ -40,9 +41,10 @@ pub fn slice_config(node: &Node) -> Vec> { } } - if !steps.is_empty() && steps.iter().any(|&x| x != 1) { - panic!("Slice: steps other than 1 are not supported"); - } + assert!( + steps.is_empty() || steps.iter().all(|&x| x == 1), + "Slice: steps other than 1 are not supported" + ); // Extract the rank of the input tensor let input_rank = match node.inputs.first().unwrap().clone().ty { @@ -57,9 +59,10 @@ pub fn slice_config(node: &Node) -> Vec> { } // Validate input dimensions - if starts.len() != ends.len() || starts.len() != axes.len() { - panic!("Slice: starts, ends, and axes must have the same length"); - } + assert!( + !(starts.len() != ends.len() || starts.len() != axes.len()), + "Slice: starts, ends, and axes must have the same length" + ); // Convert negative axes indices to positive (counting from the end) for axis in &mut axes { @@ -146,7 +149,15 @@ mod tests { .input_tensor_f32("data", 3, None) .output_default("output"); - if !use_attrs { + if use_attrs { + // Add attributes + builder = builder.attr_ints("starts", starts); + builder = builder.attr_ints("ends", ends); + + if let Some(axes_vec) = axes { + builder = builder.attr_ints("axes", axes_vec); + } + } else { // Add inputs as tensors builder = builder.input_tensor_i64_data("starts", starts.clone(), vec![starts.len()]); builder = builder.input_tensor_i64_data("ends", ends.clone(), vec![ends.len()]); @@ -155,14 +166,6 @@ mod tests { builder = builder.input_tensor_i64_data("axes", axes_vec.clone(), vec![axes_vec.len()]); } - } else { - // Add attributes - builder = builder.attr_ints("starts", starts); - builder = builder.attr_ints("ends", ends); - - if let Some(axes_vec) = axes { - builder = builder.attr_ints("axes", axes_vec); - } } builder.build() diff --git a/crates/onnx-ir/src/node/softmax.rs b/crates/onnx-ir/src/node/softmax.rs index 60f018a2ae..7e17c2981b 100644 --- a/crates/onnx-ir/src/node/softmax.rs +++ b/crates/onnx-ir/src/node/softmax.rs @@ -1,17 +1,17 @@ use crate::ir::{ArgType, Node}; /// Create softmax config from the attributes of the node +#[must_use] pub fn softmax_config(node: &Node) -> usize { // the axis is the last dimension (Default: 1 per ONNX spec) let 
mut axis: i64 = -1; // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Softmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } + assert!( + (node.inputs.len() == 1), + "Softmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); // extract the shape of the input tensor let tensor = match node.inputs.first().unwrap().clone().ty { @@ -20,9 +20,9 @@ pub fn softmax_config(node: &Node) -> usize { }; // extract the attributes - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { if key.as_str() == "axis" { - axis = value.clone().into_i64() + axis = value.clone().into_i64(); } } diff --git a/crates/onnx-ir/src/node/split.rs b/crates/onnx-ir/src/node/split.rs index 8f5f70af9a..4bcf640320 100644 --- a/crates/onnx-ir/src/node/split.rs +++ b/crates/onnx-ir/src/node/split.rs @@ -37,6 +37,7 @@ pub struct SplitConfig { } impl SplitConfig { + #[must_use] pub fn new(axis: usize, split_size: Option, split_sizes: Option>) -> Self { SplitConfig { axis, @@ -46,7 +47,8 @@ impl SplitConfig { } } -/// Creates a SplitConfig from the node attributes and inputs. +/// Creates a `SplitConfig` from the node attributes and inputs. +#[must_use] pub fn split_config(node: &Node) -> SplitConfig { // Initialize the axis to split along (default is 0 as per ONNX specification) let mut axis: i64 = 0; @@ -65,7 +67,7 @@ pub fn split_config(node: &Node) -> SplitConfig { let mut num_outputs: Option = None; // Iterate through node attributes to extract relevant values - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { match key.as_str() { "axis" => axis = value.clone().into_i64(), "num_outputs" => num_outputs = Some(value.clone().into_i64() as usize), @@ -75,9 +77,10 @@ pub fn split_config(node: &Node) -> SplitConfig { // Handle the case when num_outputs is provided to calculate uniform split size if let Some(num_outputs) = num_outputs { - if num_outputs == 0 { - panic!("Split: 'num_outputs' must be a positive value greater than zero"); - } + assert!( + (num_outputs != 0), + "Split: 'num_outputs' must be a positive value greater than zero" + ); let dim_size = tensor .static_shape @@ -86,13 +89,12 @@ pub fn split_config(node: &Node) -> SplitConfig { // Calculate the split size considering any remainder for non-evenly divisible dimensions let calculated_split_size = - dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + dim_size / (num_outputs - usize::from(dim_size % num_outputs != 0)); - if calculated_split_size == 0 { - panic!( - "Split: Calculated split size is zero. Please ensure 'num_outputs' is valid for the dimension size" - ); - } + assert!( + (calculated_split_size != 0), + "Split: Calculated split size is zero. 
Please ensure 'num_outputs' is valid for the dimension size" + ); // Assign the calculated split size split_size = Some(calculated_split_size); @@ -119,11 +121,10 @@ pub fn split_config(node: &Node) -> SplitConfig { } // Ensure that only one of 'split_sizes' or 'num_outputs' is specified - if split_sizes.is_some() && split_size.is_some() { - panic!( - "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" - ); - } + assert!( + !(split_sizes.is_some() && split_size.is_some()), + "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" + ); // Infer split_size if neither custom split_sizes nor split_size is provided if split_sizes.is_none() && split_size.is_none() { @@ -135,13 +136,12 @@ pub fn split_config(node: &Node) -> SplitConfig { // Calculate inferred split size based on number of outputs let calculated_split_size = - dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + dim_size / (num_outputs - usize::from(dim_size % num_outputs != 0)); - if calculated_split_size == 0 { - panic!( - "Split: Inferred split size is zero. Please ensure the number of outputs is valid for the dimension size" - ); - } + assert!( + (calculated_split_size != 0), + "Split: Inferred split size is zero. Please ensure the number of outputs is valid for the dimension size" + ); split_size = Some(calculated_split_size); } @@ -183,7 +183,7 @@ mod tests { // Add output tensors for i in 0..num_outputs { builder = builder.output_tensor_f32( - &format!("output_{}", i), + &format!("output_{i}"), 0, // Will be updated None, ); diff --git a/crates/onnx-ir/src/node/squeeze.rs b/crates/onnx-ir/src/node/squeeze.rs index cdb4c9c8c4..abce6739fa 100644 --- a/crates/onnx-ir/src/node/squeeze.rs +++ b/crates/onnx-ir/src/node/squeeze.rs @@ -35,7 +35,10 @@ pub fn squeeze_update_output(node: &mut Node) { None => None, } } else { - node.attrs.get("axes").cloned().map(|v| v.into_i64s()) + node.attrs + .get("axes") + .cloned() + .map(super::super::ir::AttributeValue::into_i64s) }; let axes = axes.unwrap_or_else(|| panic!("Squeeze must specify an axis")); @@ -43,7 +46,7 @@ pub fn squeeze_update_output(node: &mut Node) { let input_rank = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor.rank, - ty => panic!("Squeeze: invalid input type: {:?}", ty), + ty => panic!("Squeeze: invalid input type: {ty:?}"), }; log::debug!("Squeeze input rank for {}: {}", node.name, input_rank); @@ -65,7 +68,7 @@ mod tests { use crate::node::test_utils::NodeBuilder; fn create_test_node(axes: Option>, rank: usize) -> Node { - let output_rank = rank - (axes.as_ref().map_or(0, |a| a.len())); + let output_rank = rank - (axes.as_ref().map_or(0, std::vec::Vec::len)); let mut builder = NodeBuilder::new(NodeType::Squeeze, "test_squeeze") .input_tensor_f32("data", rank, None) diff --git a/crates/onnx-ir/src/node/test_utils.rs b/crates/onnx-ir/src/node/test_utils.rs index d6df87f50f..41371aef74 100644 --- a/crates/onnx-ir/src/node/test_utils.rs +++ b/crates/onnx-ir/src/node/test_utils.rs @@ -14,6 +14,7 @@ pub struct NodeBuilder { impl NodeBuilder { /// Create a new builder with the specified node type and name + #[must_use] pub fn new(node_type: NodeType, name: &str) -> Self { Self { node_type, @@ -29,6 +30,7 @@ impl NodeBuilder { /// Note: Prefer using the specialized methods like `input_tensor_f32`, /// `input_scalar_f32`, etc. for better readability and type safety. 
#[doc(hidden)] + #[must_use] pub fn add_input(mut self, name: &str, ty: ArgType) -> Self { self.inputs.push(Argument { name: name.to_string(), @@ -40,6 +42,7 @@ impl NodeBuilder { } /// Add a float32 tensor input with the given name and rank + #[must_use] pub fn input_tensor_f32( self, name: &str, @@ -57,6 +60,7 @@ impl NodeBuilder { } /// Add a float64 tensor input with the given name and rank + #[must_use] pub fn input_tensor_f64( self, name: &str, @@ -74,6 +78,7 @@ impl NodeBuilder { } /// Add an int32 tensor input with the given name and rank + #[must_use] pub fn input_tensor_i32( self, name: &str, @@ -91,6 +96,7 @@ impl NodeBuilder { } /// Add an int64 tensor input with the given name and rank + #[must_use] pub fn input_tensor_i64( self, name: &str, @@ -108,6 +114,7 @@ impl NodeBuilder { } /// Add a bool tensor input with the given name and rank + #[must_use] pub fn input_tensor_bool( self, name: &str, @@ -125,6 +132,7 @@ impl NodeBuilder { } /// Add a float16 tensor input with the given name and rank + #[must_use] pub fn input_tensor_f16( self, name: &str, @@ -142,6 +150,7 @@ impl NodeBuilder { } /// Add a string tensor input with the given name and rank + #[must_use] pub fn input_tensor_string( self, name: &str, @@ -159,26 +168,31 @@ impl NodeBuilder { } /// Add a scalar input with the given name and element type + #[must_use] pub fn input_scalar(self, name: &str, elem_type: ElementType) -> Self { self.add_input(name, ArgType::Scalar(elem_type)) } /// Add a float32 scalar input with the given name + #[must_use] pub fn input_scalar_f32(self, name: &str) -> Self { self.input_scalar(name, ElementType::Float32) } /// Add an int64 scalar input with the given name + #[must_use] pub fn input_scalar_i64(self, name: &str) -> Self { self.input_scalar(name, ElementType::Int64) } /// Add a shape input with the given name and rank + #[must_use] pub fn input_shape(self, name: &str, rank: usize) -> Self { self.add_input(name, ArgType::Shape(rank)) } /// Add a tensor input with data value + #[must_use] pub fn input_tensor_with_data( mut self, name: &str, @@ -202,6 +216,7 @@ impl NodeBuilder { } /// Add a float32 tensor input with data values + #[must_use] pub fn input_tensor_f32_data(self, name: &str, data: Vec, shape: Vec) -> Self { self.input_tensor_with_data( name, @@ -213,6 +228,7 @@ impl NodeBuilder { } /// Add an int64 tensor input with data values + #[must_use] pub fn input_tensor_i64_data(self, name: &str, data: Vec, shape: Vec) -> Self { self.input_tensor_with_data( name, @@ -224,6 +240,7 @@ impl NodeBuilder { } /// Add a float32 scalar tensor input (rank 0) + #[must_use] pub fn input_scalar_tensor_f32(mut self, name: &str, value: Option) -> Self { let arg = Argument { name: name.to_string(), @@ -243,6 +260,7 @@ impl NodeBuilder { } /// Add an int64 scalar tensor input (rank 0) + #[must_use] pub fn input_scalar_tensor_i64(mut self, name: &str, value: i64) -> Self { let arg = Argument { name: name.to_string(), @@ -262,6 +280,7 @@ impl NodeBuilder { } /// Add multiple tensor inputs with the same type but different names + #[must_use] pub fn input_tensors_f32( mut self, name_prefix: &str, @@ -270,11 +289,7 @@ impl NodeBuilder { static_shape: Option>, ) -> Self { for i in 0..count { - self = self.input_tensor_f32( - &format!("{}_{}", name_prefix, i), - rank, - static_shape.clone(), - ); + self = self.input_tensor_f32(&format!("{name_prefix}_{i}"), rank, static_shape.clone()); } self } @@ -284,6 +299,7 @@ impl NodeBuilder { /// Note: Prefer using the specialized methods like 
`output_tensor_f32`, /// `output_scalar_f32`, etc. for better readability and type safety. #[doc(hidden)] + #[must_use] pub fn add_output(mut self, name: &str, ty: ArgType) -> Self { self.outputs.push(Argument { name: name.to_string(), @@ -295,6 +311,7 @@ impl NodeBuilder { } /// Add a float32 tensor output with the given name and rank + #[must_use] pub fn output_tensor_f32( self, name: &str, @@ -312,6 +329,7 @@ impl NodeBuilder { } /// Add a float64 tensor output with the given name and rank + #[must_use] pub fn output_tensor_f64( self, name: &str, @@ -329,6 +347,7 @@ impl NodeBuilder { } /// Add an int32 tensor output with the given name and rank + #[must_use] pub fn output_tensor_i32( self, name: &str, @@ -346,6 +365,7 @@ impl NodeBuilder { } /// Add an int64 tensor output with the given name and rank + #[must_use] pub fn output_tensor_i64( self, name: &str, @@ -363,6 +383,7 @@ impl NodeBuilder { } /// Add a bool tensor output with the given name and rank + #[must_use] pub fn output_tensor_bool( self, name: &str, @@ -380,6 +401,7 @@ impl NodeBuilder { } /// Add a float16 tensor output with the given name and rank + #[must_use] pub fn output_tensor_f16( self, name: &str, @@ -397,6 +419,7 @@ impl NodeBuilder { } /// Add a string tensor output with the given name and rank + #[must_use] pub fn output_tensor_string( self, name: &str, @@ -414,26 +437,31 @@ impl NodeBuilder { } /// Add a scalar output with the given name and element type + #[must_use] pub fn output_scalar(self, name: &str, elem_type: ElementType) -> Self { self.add_output(name, ArgType::Scalar(elem_type)) } /// Add a float32 scalar output with the given name + #[must_use] pub fn output_scalar_f32(self, name: &str) -> Self { self.output_scalar(name, ElementType::Float32) } /// Add an int64 scalar output with the given name + #[must_use] pub fn output_scalar_i64(self, name: &str) -> Self { self.output_scalar(name, ElementType::Int64) } /// Add a shape output with the given name and rank + #[must_use] pub fn output_shape(self, name: &str, rank: usize) -> Self { self.add_output(name, ArgType::Shape(rank)) } /// Add an integer attribute + #[must_use] pub fn attr_int(mut self, name: &str, value: i64) -> Self { self.attrs .insert(name.to_string(), AttributeValue::Int64(value)); @@ -441,6 +469,7 @@ impl NodeBuilder { } /// Add a float attribute + #[must_use] pub fn attr_float(mut self, name: &str, value: f32) -> Self { self.attrs .insert(name.to_string(), AttributeValue::Float32(value)); @@ -448,6 +477,7 @@ impl NodeBuilder { } /// Add a string attribute + #[must_use] pub fn attr_string(mut self, name: &str, value: &str) -> Self { self.attrs .insert(name.to_string(), AttributeValue::String(value.to_string())); @@ -455,6 +485,7 @@ impl NodeBuilder { } /// Add an integer array attribute + #[must_use] pub fn attr_ints(mut self, name: &str, values: Vec) -> Self { self.attrs .insert(name.to_string(), AttributeValue::Int64s(values)); @@ -462,6 +493,7 @@ impl NodeBuilder { } /// Add a float array attribute + #[must_use] pub fn attr_floats(mut self, name: &str, values: Vec) -> Self { self.attrs .insert(name.to_string(), AttributeValue::Float32s(values)); @@ -469,6 +501,7 @@ impl NodeBuilder { } /// Add a string array attribute + #[must_use] pub fn attr_strings(mut self, name: &str, values: Vec) -> Self { self.attrs .insert(name.to_string(), AttributeValue::Strings(values)); @@ -476,6 +509,7 @@ impl NodeBuilder { } /// Add a tensor attribute + #[must_use] pub fn attr_tensor(mut self, name: &str, tensor: TensorData) -> Self { self.attrs 
.insert(name.to_string(), AttributeValue::Tensor(tensor)); @@ -483,6 +517,7 @@ impl NodeBuilder { } /// Add a default output with the given name + #[must_use] pub fn output_default(mut self, name: &str) -> Self { self.outputs.push(Argument { name: name.to_string(), @@ -494,6 +529,7 @@ impl NodeBuilder { } /// Build the node + #[must_use] pub fn build(self) -> Node { Node { node_type: self.node_type, diff --git a/crates/onnx-ir/src/node/tile.rs b/crates/onnx-ir/src/node/tile.rs index 9538fd635c..a5122b6067 100644 --- a/crates/onnx-ir/src/node/tile.rs +++ b/crates/onnx-ir/src/node/tile.rs @@ -8,12 +8,14 @@ pub struct TileConfig { } impl TileConfig { + #[must_use] pub fn new(repeats: Vec) -> Self { TileConfig { repeats } } } -/// Creates a TileConfig from the node attributes and inputs. +/// Creates a `TileConfig` from the node attributes and inputs. +#[must_use] pub fn tile_config(node: &Node) -> TileConfig { let repeat = node .inputs diff --git a/crates/onnx-ir/src/node/topk.rs b/crates/onnx-ir/src/node/topk.rs index c4a7219156..8c65c36e1d 100644 --- a/crates/onnx-ir/src/node/topk.rs +++ b/crates/onnx-ir/src/node/topk.rs @@ -1,6 +1,6 @@ use crate::ir::{ArgType, ElementType, Node, TensorType}; -/// Update output rank for TopK (same as input rank). +/// Update output rank for `TopK` (same as input rank). pub fn top_k_update_output(node: &mut Node) { log::debug!("TopK rank inference for node {}", node.name); @@ -28,7 +28,7 @@ pub fn top_k_update_output(node: &mut Node) { ); } -/// Configuration for the TopK operation. +/// Configuration for the `TopK` operation. #[derive(Debug, Clone, PartialEq)] pub struct TopKConfig { /// The axis along which to perform the top-k selection. @@ -38,13 +38,15 @@ pub struct TopKConfig { } impl TopKConfig { - /// Creates a new TopKConfig. + /// Creates a new `TopKConfig`. + #[must_use] pub fn new(axis: usize, k: usize) -> Self { Self { axis, k } } } -/// Creates a TopKConfig from the node attributes and inputs. +/// Creates a `TopKConfig` from the node attributes and inputs. 
+#[must_use] pub fn top_k_config(node: &Node) -> TopKConfig { // Extract the shape of the input data tensor let data_tensor = match node.inputs.first().unwrap().clone().ty { @@ -81,13 +83,13 @@ pub fn top_k_config(node: &Node) -> TopKConfig { if largest.clone().into_i64() != 1 { unimplemented!("TopK: only largest elements is supported") } - }; + } if let Some(sorted) = node.attrs.get("sorted") { if sorted.clone().into_i64() != 1 { unimplemented!("TopK: only sorted elements is supported") } - }; + } TopKConfig::new(axis as usize, k as usize) } diff --git a/crates/onnx-ir/src/node/transpose.rs b/crates/onnx-ir/src/node/transpose.rs index fbff00c2c9..f062e1846d 100644 --- a/crates/onnx-ir/src/node/transpose.rs +++ b/crates/onnx-ir/src/node/transpose.rs @@ -1,12 +1,12 @@ use crate::ir::{ArgType, Node}; +#[must_use] pub fn transpose_config(curr: &Node) -> Vec { - if curr.inputs.len() != 1 { - panic!( - "Transpose: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } + assert!( + (curr.inputs.len() == 1), + "Transpose: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); // Extract the shape of the input tensor let tensor = match curr.inputs.first().unwrap().clone().ty { diff --git a/crates/onnx-ir/src/node/trilu.rs b/crates/onnx-ir/src/node/trilu.rs index 41a022f777..4f5dbaf513 100644 --- a/crates/onnx-ir/src/node/trilu.rs +++ b/crates/onnx-ir/src/node/trilu.rs @@ -10,19 +10,21 @@ pub struct TriluConfig { } impl TriluConfig { - /// Creates a TriluConfig from the node attributes and inputs. + /// Creates a `TriluConfig` from the node attributes and inputs. + #[must_use] pub fn new(upper: bool, diagonal: i64) -> Self { Self { upper, diagonal } } } -/// Creates a TriluConfig from the node attributes and inputs. +/// Creates a `TriluConfig` from the node attributes and inputs. +#[must_use] pub fn trilu_config(node: &Node) -> TriluConfig { let mut upper = true; let mut diagonal = 0; - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { if key.as_str() == "upper" { - upper = value.clone().into_i64() != 0 + upper = value.clone().into_i64() != 0; } } // The second input of the Trilu node is the diagonal value, coming from a constant node diff --git a/crates/onnx-ir/src/node/unsqueeze.rs b/crates/onnx-ir/src/node/unsqueeze.rs index 0d18d843b9..2e0173c366 100644 --- a/crates/onnx-ir/src/node/unsqueeze.rs +++ b/crates/onnx-ir/src/node/unsqueeze.rs @@ -72,13 +72,14 @@ pub enum UnsqueezeConfig { Runtime(Argument), } -/// Creates UnsqueezeAxes configuration from the node attributes. +/// Creates `UnsqueezeAxes` configuration from the node attributes. /// /// Note: This function should only execute if the second input is a constant. /// If it wasn't and the output shape was known, unsqueeze has been remapped to reshape. 
+#[must_use] pub fn unsqueeze_config(node: &Node) -> UnsqueezeConfig { // Check if axes attribute exists - for (key, value) in node.attrs.iter() { + for (key, value) in &node.attrs { if key.as_str() == "axes" { return UnsqueezeConfig::Static(value.clone().into_i64s()); } diff --git a/crates/onnx-ir/src/proto_conversion.rs b/crates/onnx-ir/src/proto_conversion.rs index 31cafe7fb7..17b7cce4a0 100644 --- a/crates/onnx-ir/src/proto_conversion.rs +++ b/crates/onnx-ir/src/proto_conversion.rs @@ -20,7 +20,7 @@ pub enum ParseError { VariantNotFound, } -/// Convert a vector of AttributeProto to a HashMap of AttributeValue +/// Convert a vector of `AttributeProto` to a `HashMap` of `AttributeValue` impl TryFrom for TensorData { type Error = ParseError; fn try_from(tensor: TensorProto) -> Result { @@ -28,10 +28,10 @@ impl TryFrom for TensorData { DataType::FLOAT => ( ElementType::Float32, // Convert the raw data to a vector of floats - if !tensor.raw_data.is_empty() { - Data::Float32s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { + if tensor.raw_data.is_empty() { Data::Float32s(tensor.float_data) + } else { + Data::Float32s(cast_slice(&tensor.raw_data[..]).to_vec()) }, ), DataType::INT16 => { @@ -41,28 +41,28 @@ impl TryFrom for TensorData { DataType::INT32 => ( ElementType::Int32, // Convert the raw data to a vector of ints - if !tensor.raw_data.is_empty() { - Data::Int32s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { + if tensor.raw_data.is_empty() { Data::Int32s(tensor.int32_data) + } else { + Data::Int32s(cast_slice(&tensor.raw_data[..]).to_vec()) }, ), DataType::INT64 => ( ElementType::Int64, // Convert the raw data to a vector of ints - if !tensor.raw_data.is_empty() { - Data::Int64s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { + if tensor.raw_data.is_empty() { Data::Int64s(tensor.int64_data) + } else { + Data::Int64s(cast_slice(&tensor.raw_data[..]).to_vec()) }, ), DataType::DOUBLE => ( ElementType::Float64, // Convert the raw data to a vector of floats - if !tensor.raw_data.is_empty() { - Data::Float64s(cast_slice(&tensor.raw_data[..]).to_vec()) - } else { + if tensor.raw_data.is_empty() { Data::Float64s(tensor.double_data) + } else { + Data::Float64s(cast_slice(&tensor.raw_data[..]).to_vec()) }, ), DataType::BOOL => (ElementType::Bool, { @@ -76,7 +76,7 @@ impl TryFrom for TensorData { }; let shape = convert_shape(tensor.dims); - Ok(TensorData { shape, data }) + Ok(TensorData { data, shape }) } } @@ -103,7 +103,7 @@ fn convert_vec_tensor_proto(tensors: Vec) -> Result Ok(result) } -/// Convert a vector of AttributeProto to a HashMap of AttributeValue +/// Convert a vector of `AttributeProto` to a `HashMap` of `AttributeValue` impl TryFrom for AttributeValue { type Error = ParseError; @@ -136,7 +136,7 @@ impl TryFrom for AttributeValue { } } -/// Convert a vector of AttributeProto to a HashMap of AttributeValue +/// Convert a vector of `AttributeProto` to a `HashMap` of `AttributeValue` pub fn convert_vec_attrs_proto(attrs: Vec) -> Attributes { let mut result = Attributes::new(); for attr in attrs { @@ -190,9 +190,10 @@ impl TryFrom for Argument { let name = value.name.clone(); let proto_type = value.type_.unwrap(); - if !proto_type.has_tensor_type() { - panic!("Unsupported argument type {:?}", proto_type); - } + assert!( + proto_type.has_tensor_type(), + "Unsupported argument type {proto_type:?}" + ); let tensor_proto = proto_type.tensor_type(); diff --git a/crates/onnx-ir/src/util.rs b/crates/onnx-ir/src/util.rs index fe63f651a1..5355449697 100644 --- 
a/crates/onnx-ir/src/util.rs +++ b/crates/onnx-ir/src/util.rs @@ -2,13 +2,13 @@ use crate::ir::{ArgType, Node, TensorType}; use crate::protos::OperatorSetIdProto; +#[must_use] pub fn shape_config(curr: &Node) -> (usize, usize) { - if curr.inputs.len() != 1 { - panic!( - "Shape: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } + assert!( + (curr.inputs.len() == 1), + "Shape: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); // Extract the shape of the input tensor let tensor = match curr.inputs.first().unwrap().clone().ty { @@ -21,7 +21,7 @@ pub fn shape_config(curr: &Node) -> (usize, usize) { let mut end_dim: i64 = tensor.rank as i64; // Extract the attributes - for (key, value) in curr.attrs.iter() { + for (key, value) in &curr.attrs { match key.as_str() { "start" => start_dim = value.clone().into_i64(), "end" => end_dim = value.clone().into_i64(), @@ -56,9 +56,10 @@ pub fn shape_config(curr: &Node) -> (usize, usize) { /// * If the domain is not the empty ONNX domain pub fn check_opset_version(opset: &OperatorSetIdProto, min_version: i64) -> bool { // For now, only empty domain (standard ONNX operators) is supported - if !opset.domain.is_empty() { - panic!("Only the standard ONNX domain is supported"); - } + assert!( + opset.domain.is_empty(), + "Only the standard ONNX domain is supported" + ); // Return true if the opset version is greater than or equal to min_version opset.version >= min_version @@ -174,7 +175,7 @@ mod tests { for (i, rank) in input_ranks.iter().enumerate() { inputs.push(Argument { - name: format!("input_{}", i), + name: format!("input_{i}"), ty: ArgType::Tensor(TensorType { elem_type: ElementType::Float32, rank: *rank, @@ -198,7 +199,7 @@ mod tests { Node { node_type: op_type.clone(), - name: format!("test_{:?}", op_type).to_lowercase(), + name: format!("test_{op_type:?}").to_lowercase(), inputs, outputs, attrs: HashMap::new(), diff --git a/examples/custom-csv-dataset/Cargo.toml b/examples/custom-csv-dataset/Cargo.toml index 0cd9ea07e9..61368592db 100644 --- a/examples/custom-csv-dataset/Cargo.toml +++ b/examples/custom-csv-dataset/Cargo.toml @@ -7,6 +7,9 @@ description = "Example implementation for loading a custom CSV dataset from disk publish = false version.workspace = true +[lints] +workspace = true + [features] default = ["burn/dataset"] diff --git a/examples/custom-csv-dataset/examples/custom-csv-dataset.rs b/examples/custom-csv-dataset/examples/custom-csv-dataset.rs index 0fb464c704..5ff62879b7 100644 --- a/examples/custom-csv-dataset/examples/custom-csv-dataset.rs +++ b/examples/custom-csv-dataset/examples/custom-csv-dataset.rs @@ -8,8 +8,8 @@ fn main() { // Display first and last elements let item = dataset.get(0).unwrap(); - println!("First item:\n{:?}", item); + println!("First item:\n{item:?}"); let item = dataset.get(441).unwrap(); - println!("Last item:\n{:?}", item); + println!("Last item:\n{item:?}"); } diff --git a/examples/custom-csv-dataset/src/dataset.rs b/examples/custom-csv-dataset/src/dataset.rs index 26552eb6cf..8462a152d1 100644 --- a/examples/custom-csv-dataset/src/dataset.rs +++ b/examples/custom-csv-dataset/src/dataset.rs @@ -91,10 +91,10 @@ impl DiabetesDataset { let file_name = example_dir.join("diabetes.csv"); if file_name.exists() { - println!("File already downloaded at {:?}", file_name); + println!("File already downloaded at {file_name:?}"); } else { // Get file from web - println!("Downloading file to {:?}", file_name); + println!("Downloading file to {file_name:?}"); let url = 
"https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"; let mut response = reqwest::blocking::get(url).unwrap(); @@ -103,7 +103,7 @@ impl DiabetesDataset { // Copy the downloaded contents copy(&mut response, &mut file).unwrap(); - }; + } file_name } diff --git a/examples/custom-cubecl-kernel/Cargo.toml b/examples/custom-cubecl-kernel/Cargo.toml index bff47f62ca..9a40dd9096 100644 --- a/examples/custom-cubecl-kernel/Cargo.toml +++ b/examples/custom-cubecl-kernel/Cargo.toml @@ -6,6 +6,9 @@ name = "custom-cubecl-kernel" publish = false version.workspace = true +[lints] +workspace = true + [dependencies] burn = { path = "../../crates/burn", default-features = false, features = [ "autodiff", diff --git a/examples/custom-cubecl-kernel/src/lib.rs b/examples/custom-cubecl-kernel/src/lib.rs index d699d53cfe..0c245a516a 100644 --- a/examples/custom-cubecl-kernel/src/lib.rs +++ b/examples/custom-cubecl-kernel/src/lib.rs @@ -13,7 +13,7 @@ pub trait Backend: burn::tensor::backend::Backend { ) -> FloatTensor; } -/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait. +/// We create our own `AutodiffBackend` trait that extends the Burn autodiff backend trait. pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {} /// We define our custom implementation using the added function on our custom backend. diff --git a/examples/custom-image-dataset/Cargo.toml b/examples/custom-image-dataset/Cargo.toml index f14825d9f7..1115ce44ce 100644 --- a/examples/custom-image-dataset/Cargo.toml +++ b/examples/custom-image-dataset/Cargo.toml @@ -7,6 +7,9 @@ description = "Example implementation for loading a custom image dataset from di publish = false version.workspace = true +[lints] +workspace = true + [features] default = ["burn/default"] tch-gpu = ["burn/tch"] diff --git a/examples/custom-renderer/Cargo.toml b/examples/custom-renderer/Cargo.toml index aaa5a7597b..d7aa4e1ca5 100644 --- a/examples/custom-renderer/Cargo.toml +++ b/examples/custom-renderer/Cargo.toml @@ -7,6 +7,9 @@ description = "Example of how to render training progress outside of the tui" publish = false version.workspace = true +[lints] +workspace = true + [dependencies] burn = {path = "../../crates/burn", features=["autodiff", "wgpu", "train", "dataset", "vision"], default-features=false} guide = {path = "../guide", default-features=false} diff --git a/examples/custom-training-loop/Cargo.toml b/examples/custom-training-loop/Cargo.toml index 6e1fca1e92..8552388a3a 100644 --- a/examples/custom-training-loop/Cargo.toml +++ b/examples/custom-training-loop/Cargo.toml @@ -6,6 +6,9 @@ name = "custom-training-loop" publish = false version.workspace = true +[lints] +workspace = true + [dependencies] burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]} guide = {path = "../guide"} diff --git a/examples/custom-training-loop/src/lib.rs b/examples/custom-training-loop/src/lib.rs index 810ad968a7..6eca05bac9 100644 --- a/examples/custom-training-loop/src/lib.rs +++ b/examples/custom-training-loop/src/lib.rs @@ -58,7 +58,7 @@ pub fn run(device: B::Device) { .build(MnistDataset::test()); // Iterate over our training and validation loop for X epochs. - for epoch in 1..config.num_epochs + 1 { + for epoch in 1..=config.num_epochs { // Implement our training loop. 
for (iteration, batch) in dataloader_train.iter().enumerate() { let output = model.forward(batch.images); diff --git a/examples/custom-wgpu-kernel/Cargo.toml b/examples/custom-wgpu-kernel/Cargo.toml index 67e3f438c1..ffc219db0c 100644 --- a/examples/custom-wgpu-kernel/Cargo.toml +++ b/examples/custom-wgpu-kernel/Cargo.toml @@ -6,6 +6,9 @@ name = "custom-wgpu-kernel" publish = false version.workspace = true +[lints] +workspace = true + [dependencies] burn = { path = "../../crates/burn", default-features = false, features = [ "autodiff", diff --git a/examples/custom-wgpu-kernel/src/lib.rs b/examples/custom-wgpu-kernel/src/lib.rs index 0c69ee237f..d453e47bf1 100644 --- a/examples/custom-wgpu-kernel/src/lib.rs +++ b/examples/custom-wgpu-kernel/src/lib.rs @@ -12,7 +12,7 @@ pub trait Backend: burn::tensor::backend::Backend { ) -> FloatTensor; } -/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait. +/// We create our own `AutodiffBackend` trait that extends the Burn autodiff backend trait. pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {} /// We define our custom implementation using the added function on our custom backend. diff --git a/examples/guide/Cargo.toml b/examples/guide/Cargo.toml index aea61f5e25..835177191f 100644 --- a/examples/guide/Cargo.toml +++ b/examples/guide/Cargo.toml @@ -6,6 +6,9 @@ name = "guide" publish = false version.workspace = true +[lints] +workspace = true + [features] default = ["burn/default"] diff --git a/examples/guide/src/bin/print.rs b/examples/guide/src/bin/print.rs index 6f3b710c25..2bf200a858 100644 --- a/examples/guide/src/bin/print.rs +++ b/examples/guide/src/bin/print.rs @@ -7,5 +7,5 @@ fn main() { let device = Default::default(); let model = ModelConfig::new(10, 512).init::(&device); - println!("{}", model); + println!("{model}"); } diff --git a/examples/guide/src/data.rs b/examples/guide/src/data.rs index f8cf8b8840..7ba4849169 100644 --- a/examples/guide/src/data.rs +++ b/examples/guide/src/data.rs @@ -28,7 +28,7 @@ impl Batcher> for MnistBatcher { let targets = items .iter() .map(|item| { - Tensor::::from_data([(item.label as i64).elem::()], device) + Tensor::::from_data([i64::from(item.label).elem::()], device) }) .collect(); diff --git a/examples/guide/src/inference.rs b/examples/guide/src/inference.rs index 8acf23a6fc..835847039d 100644 --- a/examples/guide/src/inference.rs +++ b/examples/guide/src/inference.rs @@ -20,5 +20,5 @@ pub fn infer(artifact_dir: &str, device: B::Device, item: MnistItem) let output = model.forward(batch.images); let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar(); - println!("Predicted {} Expected {}", predicted, label); + println!("Predicted {predicted} Expected {label}"); } diff --git a/examples/guide/src/model.rs b/examples/guide/src/model.rs index 91e646e52d..4e2285d162 100644 --- a/examples/guide/src/model.rs +++ b/examples/guide/src/model.rs @@ -43,8 +43,8 @@ impl ModelConfig { impl Model { /// # Shapes - /// - Images [batch_size, height, width] - /// - Output [batch_size, class_prob] + /// - Images [`batch_size`, height, width] + /// - Output [`batch_size`, `class_prob`] pub fn forward(&self, images: Tensor) -> Tensor { let [batch_size, height, width] = images.dims(); diff --git a/examples/image-classification-web/Cargo.toml b/examples/image-classification-web/Cargo.toml index 385f49a0a2..7dabab0d7d 100644 --- a/examples/image-classification-web/Cargo.toml +++ b/examples/image-classification-web/Cargo.toml @@ -6,6 +6,9 @@ name = 
"image-classification-web" publish = false version.workspace = true +[lints] +workspace = true + [lib] crate-type = ["cdylib"] diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 254655ab4e..5843b6dc2c 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -34,10 +34,10 @@ pub enum ModelType { /// The model is loaded to the Candle backend WithCandleBackend(Model>), - /// The model is loaded to the NdArray backend + /// The model is loaded to the `NdArray` backend WithNdArrayBackend(Model>), - /// The model is loaded to the WebGpu backend + /// The model is loaded to the `WebGpu` backend WithWgpuBackend(Model>), } @@ -56,6 +56,7 @@ pub struct ImageClassifier { impl ImageClassifier { /// Constructor called by JavaScripts with the new keyword. #[wasm_bindgen(constructor)] + #[must_use] pub fn new() -> Self { log::info!("Initializing the image classifier"); let device = Default::default(); @@ -78,7 +79,7 @@ impl ImageClassifier { let duration = start.elapsed(); - log::debug!("Inference is completed in {:?}", duration); + log::debug!("Inference is completed in {duration:?}"); top_5_classes(result) } @@ -90,7 +91,7 @@ impl ImageClassifier { let device = Default::default(); self.model = ModelType::WithCandleBackend(Model::new(&device)); let duration = start.elapsed(); - log::debug!("Model is loaded to the Candle backend in {:?}", duration); + log::debug!("Model is loaded to the Candle backend in {duration:?}"); Ok(()) } @@ -101,7 +102,7 @@ impl ImageClassifier { let device = Default::default(); self.model = ModelType::WithNdArrayBackend(Model::new(&device)); let duration = start.elapsed(); - log::debug!("Model is loaded to the NdArray backend in {:?}", duration); + log::debug!("Model is loaded to the NdArray backend in {duration:?}"); Ok(()) } @@ -113,13 +114,13 @@ impl ImageClassifier { init_setup_async::(&device, Default::default()).await; self.model = ModelType::WithWgpuBackend(Model::new(&device)); let duration = start.elapsed(); - log::debug!("Model is loaded to the Wgpu backend in {:?}", duration); + log::debug!("Model is loaded to the Wgpu backend in {duration:?}"); log::debug!("Warming up the model"); let start = Instant::now(); let _ = self.inference(&[0.0; HEIGHT * WIDTH * CHANNELS]).await; let duration = start.elapsed(); - log::debug!("Warming up is completed in {:?}", duration); + log::debug!("Warming up is completed in {duration:?}"); Ok(()) } } @@ -172,7 +173,7 @@ pub struct InferenceResult { label: String, } -/// Returns the top 5 classes and convert them into a JsValue +/// Returns the top 5 classes and convert them into a `JsValue` fn top_5_classes(probabilities: Vec) -> Result { // Convert the probabilities into a vector of (index, probability) let mut probabilities: Vec<_> = probabilities.iter().enumerate().collect(); diff --git a/examples/import-model-weights/Cargo.toml b/examples/import-model-weights/Cargo.toml index d68ab69ef6..3d2cbb9c10 100644 --- a/examples/import-model-weights/Cargo.toml +++ b/examples/import-model-weights/Cargo.toml @@ -6,6 +6,9 @@ name = "import-model-weights" publish = false version = "0.18.0" +[lints] +workspace = true + [dependencies] burn = { path = "../../crates/burn", features = [ diff --git a/examples/import-model-weights/src/bin/convert.rs b/examples/import-model-weights/src/bin/convert.rs index db52b829e3..34a0601641 100644 --- a/examples/import-model-weights/src/bin/convert.rs +++ 
b/examples/import-model-weights/src/bin/convert.rs @@ -38,34 +38,26 @@ pub fn main() { // Load the model record based on the specified format let model_record: ModelRecord = match weight_format { "pytorch" => { - println!("Loading PyTorch weights from '{}'...", PYTORCH_WEIGHTS_PATH); + println!("Loading PyTorch weights from '{PYTORCH_WEIGHTS_PATH}'..."); PyTorchFileRecorder::::default() .load(PYTORCH_WEIGHTS_PATH.into(), &device) .unwrap_or_else(|_| { - panic!( - "Failed to load PyTorch model weights from '{}'", - PYTORCH_WEIGHTS_PATH - ) + panic!("Failed to load PyTorch model weights from '{PYTORCH_WEIGHTS_PATH}'") }) } "safetensors" => { - println!( - "Loading Safetensors weights from '{}'...", - SAFETENSORS_WEIGHTS_PATH - ); + println!("Loading Safetensors weights from '{SAFETENSORS_WEIGHTS_PATH}'..."); SafetensorsFileRecorder::::default() .load(SAFETENSORS_WEIGHTS_PATH.into(), &device) .unwrap_or_else(|_| { panic!( - "Failed to load Safetensors model weights from '{}'", - SAFETENSORS_WEIGHTS_PATH + "Failed to load Safetensors model weights from '{SAFETENSORS_WEIGHTS_PATH}'" ) }) } _ => { eprintln!( - "Error: Unsupported weight format '{}'. Please use 'pytorch' or 'safetensors'.", - weight_format + "Error: Unsupported weight format '{weight_format}'. Please use 'pytorch' or 'safetensors'." ); process::exit(1); } diff --git a/examples/import-model-weights/src/bin/pytorch.rs b/examples/import-model-weights/src/bin/pytorch.rs index 2d94f80898..07b98a7942 100644 --- a/examples/import-model-weights/src/bin/pytorch.rs +++ b/examples/import-model-weights/src/bin/pytorch.rs @@ -10,7 +10,7 @@ type B = NdArray; const WEIGHTS_FILE: &str = "weights/mnist.pt"; pub fn main() { - println!("Loading PyTorch model weights from file: {}", WEIGHTS_FILE); + println!("Loading PyTorch model weights from file: {WEIGHTS_FILE}"); // Load PyTorch weights into a model record. let record: ModelRecord = PyTorchFileRecorder::::default() diff --git a/examples/import-model-weights/src/bin/safetensors.rs b/examples/import-model-weights/src/bin/safetensors.rs index 2370803976..5b55f756a8 100644 --- a/examples/import-model-weights/src/bin/safetensors.rs +++ b/examples/import-model-weights/src/bin/safetensors.rs @@ -10,10 +10,7 @@ type B = NdArray; const WEIGHTS_FILE: &str = "weights/mnist.safetensors"; pub fn main() { - println!( - "Loading Safetensors model weights from file: {}", - WEIGHTS_FILE - ); + println!("Loading Safetensors model weights from file: {WEIGHTS_FILE}"); // Load Safetensors weights exported from PyTorch into a model record. 
let record: ModelRecord = SafetensorsFileRecorder::::default() .load(WEIGHTS_FILE.into(), &Default::default()) diff --git a/examples/import-model-weights/src/inference.rs b/examples/import-model-weights/src/inference.rs index 3cb4ef3751..cd6a4d0623 100644 --- a/examples/import-model-weights/src/inference.rs +++ b/examples/import-model-weights/src/inference.rs @@ -13,7 +13,7 @@ pub fn infer(record: ModelRecord) { // Get image index argument (first) from command line let image_index = if let Some(image_index) = args().nth(1) { - println!("Image index: {}", image_index); + println!("Image index: {image_index}"); image_index .parse::() .expect("Failed to parse image index") @@ -51,7 +51,7 @@ pub fn infer(record: ModelRecord) { assert!(arg_max == item.label); println!("Success!"); - println!("Predicted: {}", arg_max); + println!("Predicted: {arg_max}"); println!("Actual: {}", item.label); println!("See the image online, click the link below:"); println!("https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row={image_index}"); diff --git a/examples/mnist-inference-web/Cargo.toml b/examples/mnist-inference-web/Cargo.toml index b7aec27a15..4d9091873a 100644 --- a/examples/mnist-inference-web/Cargo.toml +++ b/examples/mnist-inference-web/Cargo.toml @@ -6,6 +6,9 @@ name = "mnist-inference-web" publish = false version.workspace = true +[lints] +workspace = true + [lib] crate-type = ["cdylib"] diff --git a/examples/mnist-inference-web/src/web.rs b/examples/mnist-inference-web/src/web.rs index bd2623ddff..eb32974fa9 100644 --- a/examples/mnist-inference-web/src/web.rs +++ b/examples/mnist-inference-web/src/web.rs @@ -25,8 +25,9 @@ pub struct Mnist { #[cfg_attr(target_family = "wasm", wasm_bindgen)] impl Mnist { - /// Constructor called by JavaScripts with the new keyword. + /// Constructor called by `JavaScripts` with the new keyword. 
#[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))] + #[must_use] pub fn new() -> Self { console_error_panic_hook::set_once(); Self { model: None } diff --git a/examples/mnist/Cargo.toml b/examples/mnist/Cargo.toml index 0c13a4dcf5..007e9a6d1c 100644 --- a/examples/mnist/Cargo.toml +++ b/examples/mnist/Cargo.toml @@ -6,6 +6,9 @@ name = "mnist" publish = false version.workspace = true +[lints] +workspace = true + [features] default = ["burn/dataset", "burn/vision"] ndarray = ["burn/ndarray"] diff --git a/examples/mnist/src/data.rs b/examples/mnist/src/data.rs index 8b6d81bc26..5ae1c260b7 100644 --- a/examples/mnist/src/data.rs +++ b/examples/mnist/src/data.rs @@ -29,7 +29,7 @@ impl Batcher> for MnistBatcher { .iter() .map(|item| { Tensor::::from_data( - TensorData::from([(item.label as i64).elem::()]), + TensorData::from([i64::from(item.label).elem::()]), device, ) }) diff --git a/examples/modern-lstm/Cargo.toml b/examples/modern-lstm/Cargo.toml index 4454236c82..1650b03b03 100644 --- a/examples/modern-lstm/Cargo.toml +++ b/examples/modern-lstm/Cargo.toml @@ -3,6 +3,9 @@ name = "modern-lstm" version = "0.2.0" edition.workspace = true +[lints] +workspace = true + [features] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] diff --git a/examples/modern-lstm/src/dataset.rs b/examples/modern-lstm/src/dataset.rs index 190c5390a0..6f146aff32 100644 --- a/examples/modern-lstm/src/dataset.rs +++ b/examples/modern-lstm/src/dataset.rs @@ -23,6 +23,7 @@ pub struct SequenceDatasetItem { } impl SequenceDatasetItem { + #[must_use] pub fn new(seq_length: usize, noise_level: f32) -> Self { // Start with two random numbers between 0 and 1 let mut seq = vec![rand::rng().random(), rand::rng().random()]; @@ -50,6 +51,7 @@ pub struct SequenceDataset { } impl SequenceDataset { + #[must_use] pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self { let mut items = vec![]; for _i in 0..num_sequences { @@ -84,7 +86,7 @@ impl Batcher> for SequenceB fn batch(&self, items: Vec, device: &B::Device) -> SequenceBatch { let mut sequences: Vec> = Vec::new(); - for item in items.iter() { + for item in &items { let seq_tensor = Tensor::::from_floats(item.sequence.as_slice(), device); // Add feature dimension, the input_size is 1 implicitly. We can change the input_size here with some operations sequences.push(seq_tensor.unsqueeze_dims(&[-1])); diff --git a/examples/modern-lstm/src/model.rs b/examples/modern-lstm/src/model.rs index 5ac62c9b14..6cf31f988c 100644 --- a/examples/modern-lstm/src/model.rs +++ b/examples/modern-lstm/src/model.rs @@ -9,18 +9,18 @@ use burn::{ /// LSTM Cell implementation with layer normalization. 
/// /// Mathematical formulation of LSTM: -/// f_t = σ(W_f · [h_{t-1}, x_t] + b_f) # Forget gate -/// i_t = σ(W_i · [h_{t-1}, x_t] + b_i] # Input gate -/// g_t = tanh(W_g · [h_{t-1}, x_t] + b_g] # Candidate cell state -/// o_t = σ(W_o · [h_{t-1}, x_t] + b_o) # Output gate +/// `f_t` = `σ(W_f` · [h_{t-1}, `x_t`] + `b_f`) # Forget gate +/// `i_t` = `σ(W_i` · [h_{t-1}, `x_t`] + `b_i`] # Input gate +/// `g_t` = `tanh(W_g` · [h_{t-1}, `x_t`] + `b_g`] # Candidate cell state +/// `o_t` = `σ(W_o` · [h_{t-1}, `x_t`] + `b_o`) # Output gate /// -/// c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t # New cell state -/// h_t = o_t ⊙ tanh(c_t) # New hidden state +/// `c_t` = `f_t` ⊙ c_{t-1} + `i_t` ⊙ `g_t` # New cell state +/// `h_t` = `o_t` ⊙ `tanh(c_t)` # New hidden state /// /// where: /// - σ is the sigmoid function /// - ⊙ is the element-wise multiplication -/// - [h_{t-1}, x_t] represents concatenation +/// - [h_{t-1}, `x_t`] represents concatenation #[derive(Module, Debug)] pub struct LstmCell { @@ -95,10 +95,10 @@ impl LstmCellConfig { impl LstmCell { /// Forward pass of LSTM cell. /// Args: - /// x: Input tensor of shape (batch_size, input_size) - /// state: Tuple of (h_{t-1}, c_{t-1}) each of shape (batch_size, hidden_size) + /// x: Input tensor of shape (`batch_size`, `input_size`) + /// state: Tuple of (h_{t-1}, c_{t-1}) each of shape (`batch_size`, `hidden_size`) /// Returns: - /// Tuple of (h_t, c_t) representing new hidden and cell states + /// Tuple of (`h_t`, `c_t`) representing new hidden and cell states pub fn forward(&self, x: Tensor, state: LstmState) -> LstmState { let (h_prev, c_prev) = (state.hidden, state.cell); @@ -200,12 +200,12 @@ impl StackedLstm { /// Process input sequence through stacked LSTM layers. /// /// Args: - /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// x: Input tensor of shape (`batch_size`, `seq_length`, `input_size`) /// states: Optional initial states for each layer /// /// Returns: - /// Tuple of (output, states) where output has shape (batch_size, seq_length, hidden_size) - /// and states is a vector of length num_layers, both cell and hidden state in each element have shape (batch_size, hidden_size) + /// Tuple of (output, states) where output has shape (`batch_size`, `seq_length`, `hidden_size`) + /// and states is a vector of length `num_layers`, both cell and hidden state in each element have shape (`batch_size`, `hidden_size`) pub fn forward( &self, x: Tensor, @@ -217,7 +217,7 @@ impl StackedLstm { let mut states = match states { None => { let mut temp: Vec> = vec![]; - for layer in self.layers.iter() { + for layer in &self.layers { temp.push(layer.init_state(batch_size, &device)); } temp @@ -227,7 +227,7 @@ impl StackedLstm { let mut layer_outputs = vec![]; for t in 0..seq_length { - let mut input_t = x.clone().slice(s![.., t..t + 1, ..]).squeeze::<2>(1); + let mut input_t = x.clone().slice(s![.., t..=t, ..]).squeeze::<2>(1); for (i, lstm_cell) in self.layers.iter().enumerate() { let mut state: LstmState = LstmState::new(states[i].cell.clone(), states[i].hidden.clone()); @@ -324,11 +324,11 @@ impl LstmNetwork { /// 4. 
Apply final linear transformation /// /// Args: - /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// x: Input tensor of shape (`batch_size`, `seq_length`, `input_size`) /// states: Optional initial states /// /// Returns: - /// Output tensor of shape (batch_size, output_size) + /// Output tensor of shape (`batch_size`, `output_size`) pub fn forward(&self, x: Tensor, states: Option>>) -> Tensor { let seq_length = x.dims()[1]; // Forward direction diff --git a/examples/modern-lstm/src/training.rs b/examples/modern-lstm/src/training.rs index b066e8b145..ab3cc0133f 100644 --- a/examples/modern-lstm/src/training.rs +++ b/examples/modern-lstm/src/training.rs @@ -73,7 +73,7 @@ pub fn train(artifact_dir: &str, config: TrainingConfig, dev println!("Starting training..."); // Iterate over our training for X epochs - for epoch in 1..config.num_epochs + 1 { + for epoch in 1..=config.num_epochs { // Initialize the training and validation metrics at the start of each epoch let mut train_losses = vec![]; let mut train_loss = 0.0; diff --git a/examples/named-tensor/Cargo.toml b/examples/named-tensor/Cargo.toml index 628df32c6d..51c0686b41 100644 --- a/examples/named-tensor/Cargo.toml +++ b/examples/named-tensor/Cargo.toml @@ -6,6 +6,9 @@ name = "named-tensor" publish = false version.workspace = true +[lints] +workspace = true + [dependencies] burn = {path = "../../crates/burn", features = ["experimental-named-tensor", "ndarray"]} diff --git a/examples/onnx-inference/Cargo.toml b/examples/onnx-inference/Cargo.toml index 253fe8c35d..2a2d97b78e 100644 --- a/examples/onnx-inference/Cargo.toml +++ b/examples/onnx-inference/Cargo.toml @@ -6,6 +6,9 @@ name = "onnx-inference" publish = false version.workspace = true +[lints] +workspace = true + [features] default = ["embedded-model"] diff --git a/examples/onnx-inference/src/bin/mnist_inference.rs b/examples/onnx-inference/src/bin/mnist_inference.rs index e5492e0249..689d7aad61 100644 --- a/examples/onnx-inference/src/bin/mnist_inference.rs +++ b/examples/onnx-inference/src/bin/mnist_inference.rs @@ -14,7 +14,7 @@ fn main() { // Get image index argument (first) from command line let image_index = if let Some(image_index) = args().nth(1) { - println!("Image index: {}", image_index); + println!("Image index: {image_index}"); image_index .parse::() .expect("Failed to parse image index") @@ -55,7 +55,7 @@ fn main() { assert!(arg_max == item.label); println!("Success!"); - println!("Predicted: {}", arg_max); + println!("Predicted: {arg_max}"); println!("Actual: {}", item.label); println!("See the image online, click the link below:"); println!("https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row={image_index}"); diff --git a/examples/raspberry-pi-pico/Cargo.toml b/examples/raspberry-pi-pico/Cargo.toml index d3db347567..5119ebbfd5 100644 --- a/examples/raspberry-pi-pico/Cargo.toml +++ b/examples/raspberry-pi-pico/Cargo.toml @@ -5,6 +5,9 @@ name = "raspberry-pi-pico" license = "MIT OR Apache-2.0" version = "0.1.0" +[lints] +workspace = true + [dependencies] embassy-embedded-hal = { version = "0.3.0", features = ["defmt"] } embassy-executor = { version = "0.7.0", features = ["arch-cortex-m", "executor-thread", "executor-interrupt", "defmt"] } diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index c91b91b722..b95f8e2bc6 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -6,6 +6,9 @@ name = "server" publish = false version.workspace = true +[lints] +workspace = true + [features] default = 
["webgpu"] cuda = ["burn/cuda", "cubecl/compilation-cache"] diff --git a/examples/simple-regression/Cargo.toml b/examples/simple-regression/Cargo.toml index 108d9b567e..ffa37a9221 100644 --- a/examples/simple-regression/Cargo.toml +++ b/examples/simple-regression/Cargo.toml @@ -6,6 +6,9 @@ name = "simple-regression" publish = false version.workspace = true +[lints] +workspace = true + [features] default = ["burn/dataset", "burn/sqlite-bundled"] ndarray = ["burn/ndarray"] diff --git a/examples/simple-regression/examples/regression.rs b/examples/simple-regression/examples/regression.rs index cac92231a2..7d4329aab3 100644 --- a/examples/simple-regression/examples/regression.rs +++ b/examples/simple-regression/examples/regression.rs @@ -65,7 +65,7 @@ mod remote { /// Train a regression model and predict results on a number of samples. pub fn run(device: B::Device) { training::run::>(ARTIFACT_DIR, device.clone()); - inference::infer::(ARTIFACT_DIR, device) + inference::infer::(ARTIFACT_DIR, device); } fn main() { diff --git a/examples/simple-regression/src/dataset.rs b/examples/simple-regression/src/dataset.rs index 3af3289885..8da3af01c8 100644 --- a/examples/simple-regression/src/dataset.rs +++ b/examples/simple-regression/src/dataset.rs @@ -68,18 +68,22 @@ impl Dataset for HousingDataset { } impl HousingDataset { + #[must_use] pub fn train() -> Self { Self::new("train") } + #[must_use] pub fn validation() -> Self { Self::new("validation") } + #[must_use] pub fn test() -> Self { Self::new("test") } + #[must_use] pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("gvlassis/california_housing") @@ -142,7 +146,7 @@ impl Batcher> for HousingBat fn batch(&self, items: Vec, device: &B::Device) -> HousingBatch { let mut inputs: Vec> = Vec::new(); - for item in items.iter() { + for item in &items { let input_tensor = Tensor::::from_floats( [ item.median_income, diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 8cc83f3f07..00fcd513b4 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -6,6 +6,9 @@ name = "text-classification" publish = false version.workspace = true +[lints] +workspace = true + [features] default = [] f16 = [] diff --git a/examples/text-classification/src/data/batcher.rs b/examples/text-classification/src/data/batcher.rs index 170e096626..18622484a4 100644 --- a/examples/text-classification/src/data/batcher.rs +++ b/examples/text-classification/src/data/batcher.rs @@ -36,7 +36,7 @@ pub struct TextClassificationInferenceBatch { pub mask_pad: Tensor, // Padding mask for the tokenized text } -/// Implement Batcher trait for TextClassificationBatcher struct for training +/// Implement Batcher trait for `TextClassificationBatcher` struct for training impl Batcher> for TextClassificationBatcher { @@ -75,7 +75,7 @@ impl Batcher Batcher> for TextClassificationBatcher { diff --git a/examples/text-classification/src/data/dataset.rs b/examples/text-classification/src/data/dataset.rs index ba3e2a29dc..d8af0ebf6c 100644 --- a/examples/text-classification/src/data/dataset.rs +++ b/examples/text-classification/src/data/dataset.rs @@ -50,16 +50,19 @@ impl Dataset for AgNewsDataset { // Implement methods for constructing the AG News dataset impl AgNewsDataset { /// Returns the training portion of the dataset + #[must_use] pub fn train() -> Self { Self::new("train") } /// Returns the testing portion of the dataset + #[must_use] pub fn test() -> Self { 
Self::new("test") } /// Constructs the dataset from a split (either "train" or "test") + #[must_use] pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("ag_news") .dataset(split) @@ -68,7 +71,7 @@ impl AgNewsDataset { } } -/// Implements the TextClassificationDataset trait for the AG News dataset +/// Implements the `TextClassificationDataset` trait for the AG News dataset impl TextClassificationDataset for AgNewsDataset { /// Returns the number of unique classes in the dataset fn num_classes() -> usize { @@ -88,7 +91,7 @@ impl TextClassificationDataset for AgNewsDataset { } } -/// Struct for items in the DbPedia dataset +/// Struct for items in the `DbPedia` dataset #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { pub title: String, // The title of the item @@ -96,12 +99,12 @@ pub struct DbPediaItem { pub label: usize, // The label of the item (classification category) } -/// Struct for the DbPedia dataset +/// Struct for the `DbPedia` dataset pub struct DbPediaDataset { dataset: SqliteDataset, // Underlying SQLite dataset } -/// Implements the Dataset trait for the DbPedia dataset +/// Implements the Dataset trait for the `DbPedia` dataset impl Dataset for DbPediaDataset { /// Returns a specific item from the dataset fn get(&self, index: usize) -> Option { @@ -119,19 +122,22 @@ impl Dataset for DbPediaDataset { } } -/// Implement methods for constructing the DbPedia dataset +/// Implement methods for constructing the `DbPedia` dataset impl DbPediaDataset { /// Returns the training portion of the dataset + #[must_use] pub fn train() -> Self { Self::new("train") } /// Returns the testing portion of the dataset + #[must_use] pub fn test() -> Self { Self::new("test") } /// Constructs the dataset from a split (either "train" or "test") + #[must_use] pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") .dataset(split) @@ -140,7 +146,7 @@ impl DbPediaDataset { } } -/// Implement the TextClassificationDataset trait for the DbPedia dataset +/// Implement the `TextClassificationDataset` trait for the `DbPedia` dataset impl TextClassificationDataset for DbPediaDataset { /// Returns the number of unique classes in the dataset fn num_classes() -> usize { diff --git a/examples/text-classification/src/inference.rs b/examples/text-classification/src/inference.rs index b760ff2357..231f24a1ee 100644 --- a/examples/text-classification/src/inference.rs +++ b/examples/text-classification/src/inference.rs @@ -63,7 +63,7 @@ pub fn infer( // Print out predictions for each sample for (i, text) in samples.into_iter().enumerate() { #[allow(clippy::single_range_in_vec_init)] - let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample + let prediction = predictions.clone().slice([i..=i]); // Get prediction for current sample let logits = prediction.to_data(); // Convert prediction tensor to data let class_index = prediction.argmax(1).squeeze::<1>(1).into_scalar(); // Get class index with the highest value let class = D::class_name(class_index.elem::() as usize); // Get class name diff --git a/examples/text-generation/Cargo.toml b/examples/text-generation/Cargo.toml index 3f7f2d37fa..534c00faf9 100644 --- a/examples/text-generation/Cargo.toml +++ b/examples/text-generation/Cargo.toml @@ -6,6 +6,9 @@ name = "text-generation" publish = false version.workspace = true +[lints] +workspace = true + [features] default = ["burn/dataset", 
"burn/sqlite-bundled"] f16 = [] diff --git a/examples/text-generation/src/data/dataset.rs b/examples/text-generation/src/data/dataset.rs index 56abdfdaf8..88a6cf8cfc 100644 --- a/examples/text-generation/src/data/dataset.rs +++ b/examples/text-generation/src/data/dataset.rs @@ -27,13 +27,16 @@ impl Dataset for DbPediaDataset { } impl DbPediaDataset { + #[must_use] pub fn train() -> Self { Self::new("train") } + #[must_use] pub fn test() -> Self { Self::new("test") } + #[must_use] pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") .dataset(split) diff --git a/examples/text-generation/src/data/tokenizer.rs b/examples/text-generation/src/data/tokenizer.rs index 53b294bc3f..1940ed4fb4 100644 --- a/examples/text-generation/src/data/tokenizer.rs +++ b/examples/text-generation/src/data/tokenizer.rs @@ -36,9 +36,10 @@ impl Default for Gpt2Tokenizer { impl Tokenizer for Gpt2Tokenizer { fn encode(&self, value: &str, special_tokens: bool) -> Vec { - let text = match special_tokens { - true => "[START]".to_owned() + value + "[END]", - false => value.to_string(), + let text = if special_tokens { + "[START]".to_owned() + value + "[END]" + } else { + value.to_string() }; let tokens = self.tokenizer.encode(text, true).unwrap(); tokens.get_ids().iter().map(|t| *t as usize).collect() diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml index 5039baa352..7cdf88c056 100644 --- a/examples/wgan/Cargo.toml +++ b/examples/wgan/Cargo.toml @@ -3,6 +3,9 @@ name = "wgan" version = "0.2.0" edition.workspace = true +[lints] +workspace = true + [features] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] diff --git a/examples/wgan/src/dataset.rs b/examples/wgan/src/dataset.rs index 6cca5980d3..fd404e464d 100644 --- a/examples/wgan/src/dataset.rs +++ b/examples/wgan/src/dataset.rs @@ -27,7 +27,7 @@ impl Batcher> for MnistBatcher { .iter() .map(|item| { Tensor::::from_data( - TensorData::from([(item.label as i64).elem::()]), + TensorData::from([i64::from(item.label).elem::()]), device, ) }) diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs index 9fa61835c6..b0d24a7148 100644 --- a/examples/wgan/src/training.rs +++ b/examples/wgan/src/training.rs @@ -192,7 +192,7 @@ pub fn train(artifact_dir: &str, config: TrainingConfig, dev // Add 0.5/255.0 to the images, refer to pytorch save_image source let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); // Save images in artifact directory - let path = format!("{artifact_dir}/image-{}.png", epoch); + let path = format!("{artifact_dir}/image-{epoch}.png"); save_image::(fake_images, 5, path).unwrap(); } } diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index f185c295eb..17dbc1f85a 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -6,6 +6,9 @@ license = "MIT OR Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lints] +workspace = true + [dependencies] log = { workspace = true } strum = { workspace = true } diff --git a/xtask/src/commands/build.rs b/xtask/src/commands/build.rs index d661ce0faa..73d386f969 100644 --- a/xtask/src/commands/build.rs +++ b/xtask/src/commands/build.rs @@ -66,7 +66,7 @@ pub(crate) fn handle_command( ]); if std::env::var("DISABLE_WGPU").is_ok() { args.exclude.extend(vec!["burn-wgpu".to_string()]); - }; + } } // Build workspace base_commands::build::handle_command(args.try_into().unwrap())?; diff --git a/xtask/src/commands/test.rs 
b/xtask/src/commands/test.rs index 419ac1b5c6..bbc3a54725 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -58,7 +58,7 @@ pub(crate) fn handle_command( // "burn-router" uses "burn-wgpu" for the tests. "burn-router".to_string(), ]); - }; + } } CiTestType::GcpCudaRunner => { args.target = Target::AllPackages;