Migrate to new cubecl multi tensor handle changes #3136
Conversation
Thanks for fixing the current limitations with the quantization kernels!
Just to give you a bit more (historical) context (might help fill in the blanks lol):
The current implementation dates back a long time; it hasn't changed much since the first draft, which came around the initial release of CubeCL. It was also written specifically with wgpu in mind, which only supported f32 and u32 (hence the packing/unpacking), because the CUDA backend was barely a thing at the time lol.
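As a rough illustration of that wgpu-era workaround (a minimal sketch with hypothetical helper names, not the actual burn/CubeCL kernel code): with only u32 storage available, four quantized i8 values get packed into a single u32 and unpacked on read.

```rust
// Hypothetical sketch: pack four i8 quantized values into one u32 (little-endian
// byte order) and unpack them again, as needed when the backend only exposes
// f32/u32 storage types. Function names are illustrative.

fn pack_i8x4(values: [i8; 4]) -> u32 {
    values
        .iter()
        .enumerate()
        .fold(0u32, |acc, (i, &v)| acc | ((v as u8 as u32) << (i * 8)))
}

fn unpack_i8x4(packed: u32) -> [i8; 4] {
    let mut out = [0i8; 4];
    for i in 0..4 {
        out[i] = ((packed >> (i * 8)) & 0xFF) as u8 as i8;
    }
    out
}

fn main() {
    let vals = [-128i8, -1, 0, 127];
    let packed = pack_i8x4(vals);
    assert_eq!(packed, 0x7F00FF80);
    assert_eq!(unpack_i8x4(packed), vals);
    println!("round-trip ok: {packed:#010x}");
}
```

With native int8 support, this pack/unpack step disappears entirely, which is part of why the new changes simplify the kernels.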
I was just discussing this recently with Max and the plan was to add int8 support. This is exactly in line with your changes, and the additions in cubecl will make things a lot easier. Thank you 🙏
(from the linked cubecl PR)
Adds multi-tensor allocations so quantization params can be allocated along with strided tensors. The multiple tensors are allocated into the same buffer, but with separate strides.
This is awesome! It makes a lot more sense. The initial work actually dealt with multiple handles on the burn-ir side, but it was messy, and there was no easy way to do this in cubecl at the time. Tensor operations and transformations (e.g., layout changes) will be easier to handle with such a separation, especially as we add more quantization levels (e.g., per-block).
TL;DR: thanks for the contribution(s) 😄 the changes LGTM
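To make the multi-tensor idea concrete (a minimal sketch under assumed names, not the actual cubecl API): the quantized data and its quantization params live in one shared buffer, each described by its own offset and strides.

```rust
// Illustrative sketch of multi-tensor allocation: two logical tensors share
// one backing buffer, each with its own element offset and strides.
// `SubTensor` and `row_major_strides` are hypothetical names.

struct SubTensor {
    offset: usize,       // element offset into the shared buffer
    shape: Vec<usize>,
    strides: Vec<usize>, // per-sub-tensor strides
}

fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

fn main() {
    // A 4x8 quantized tensor followed by a per-tensor scale (1 element).
    let data = SubTensor {
        offset: 0,
        shape: vec![4, 8],
        strides: row_major_strides(&[4, 8]),
    };
    let scale = SubTensor {
        offset: data.shape.iter().product(), // placed right after the data
        shape: vec![1],
        strides: vec![1],
    };
    assert_eq!(data.strides, vec![8, 1]);
    assert_eq!(scale.offset, 32);
    println!("data strides {:?}, scale offset {}", data.strides, scale.offset);
}
```

Keeping the params addressable through their own offset/strides is what makes layout transformations and finer-grained schemes (e.g., per-block) tractable without juggling multiple handles.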
let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

match scheme {
    QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
This will change with #3042 as the quantization scheme will be more like a config struct.
But no worries, we'll make sure that this PR is merged first so you can avoid the possible conflict hell 😅
/edit: sorry, I guess I was wrong. I'll fix the conflicts now that it was merged on cubecl!
Btw I merged conflicts earlier to bring this up to date since we introduced some conflicts on main 🙂
On CUDA, matmul and quantization work, but dequantization doesn't yet (it hasn't been modified to match the changes, so that's expected; we also need to keep the quantized output scale around).
On wgpu there seem to be alignment issues, e.g.
thread 'tests::cube::kernel::quantization::tests::should_quantize_dequantize_symmetric_multiple' panicked at /Users/admin/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/wgpu-25.0.0/src/backend/wgpu_core.rs:1223:26:
wgpu error: Validation Error
Caused by:
In Device::create_bind_group
Buffer offset 520 does not respect device's requested `min_storage_buffer_offset_alignment` limit 256
Ah, it seems like the copy alignment is incorrect for offsets; we need to align to the min storage buffer offset alignment instead. I'll have to fix that in cubecl.
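The fix described above amounts to rounding each sub-allocation's byte offset up to the device's `min_storage_buffer_offset_alignment` (256 in the error above, where offset 520 was rejected). A minimal sketch, with `align_up` as an illustrative helper:

```rust
// Sketch of the alignment fix: buffer offsets handed to bind groups must be
// rounded up to min_storage_buffer_offset_alignment, otherwise wgpu rejects
// offsets like the 520 in the validation error. `align_up` is a hypothetical
// helper, not an existing cubecl function.

fn align_up(offset: u64, align: u64) -> u64 {
    debug_assert!(align.is_power_of_two());
    (offset + align - 1) & !(align - 1)
}

fn main() {
    let min_align = 256; // the limit reported in the wgpu error above
    assert_eq!(align_up(520, min_align), 768); // 520 is invalid; 768 is the next valid offset
    assert_eq!(align_up(512, min_align), 512); // already aligned, unchanged
    println!("aligned offset: {}", align_up(520, min_align));
}
```

Rounding up wastes at most `align - 1` bytes of padding per sub-allocation, which is the usual trade-off for satisfying bind-group offset limits.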
Codecov Report
❌ Your patch check has failed because the patch coverage (55.50%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
@@ Coverage Diff @@
## main #3136 +/- ##
==========================================
- Coverage 82.20% 82.10% -0.10%
==========================================
Files 962 963 +1
Lines 122541 122833 +292
==========================================
+ Hits 100732 100854 +122
- Misses 21809 21979 +170
This PR has been marked as stale because it has not been updated for over a month.
Pull Request Template
Checklist
The cargo run-checks command has been executed.
Related Issues/PRs
Update PR for tracel-ai/cubecl#661
Changes
Updates to the new matmul signature; fixes some bugs in the quantize kernel.
Testing
Tests pass