Skip to content

Commit f6e458b

Browse files
committed
Remove the unnecessary bitcasts in check_call
Add support fot `bf16` linking
1 parent a56c9dc commit f6e458b

File tree

4 files changed

+77
-46
lines changed

4 files changed

+77
-46
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use smallvec::SmallVec;
2020

2121
use crate::attributes::{self, llfn_attrs_from_instance};
2222
use crate::builder::Builder;
23-
use crate::context::CodegenCx;
23+
use crate::context::{CodegenCx, GenericCx, SCx};
2424
use crate::llvm::{self, Attribute, AttributePlace};
2525
use crate::type_::Type;
2626
use crate::type_of::LayoutLlvmExt;
@@ -362,15 +362,15 @@ fn match_intrinsic_signature<'ll>(
362362
);
363363
}
364364

365-
if !equate_ty(cx, rust_return_ty, llvm_return_ty) {
365+
if !cx.equate_ty(rust_return_ty, llvm_return_ty) {
366366
error!(
367367
"Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as return type for `{fn_name}`"
368368
);
369369
}
370370
for (idx, (&rust_argument_ty, llvm_argument_ty)) in
371371
iter::zip(rust_argument_tys, llvm_argument_tys).enumerate()
372372
{
373-
if !equate_ty(cx, rust_argument_ty, llvm_argument_ty) {
373+
if !cx.equate_ty(rust_argument_ty, llvm_argument_ty) {
374374
error!(
375375
"Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as argument {idx} for `{fn_name}`"
376376
);
@@ -380,28 +380,53 @@ fn match_intrinsic_signature<'ll>(
380380
fn_ty
381381
}
382382

383-
fn equate_ty<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
384-
if rust_ty == llvm_ty {
385-
return true;
386-
}
387-
if cx.type_kind(llvm_ty) == TypeKind::X86_AMX && cx.type_kind(rust_ty) == TypeKind::Vector {
388-
let element_count = cx.vector_length(rust_ty);
389-
let element_ty = cx.element_type(rust_ty);
390-
391-
let element_size_bits = match cx.type_kind(element_ty) {
392-
TypeKind::Half => 16,
393-
TypeKind::Float => 32,
394-
TypeKind::Double => 64,
395-
TypeKind::FP128 => 128,
396-
TypeKind::Integer => cx.int_width(element_ty),
397-
TypeKind::Pointer => cx.int_width(cx.isize_ty),
398-
_ => bug!("Vector element type `{element_ty:?}` not one of integer, float or pointer"),
399-
};
400-
let vector_size_bits = element_size_bits * element_count as u64;
383+
impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
384+
pub(crate) fn equate_ty(&self, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
385+
if rust_ty == llvm_ty {
386+
return true;
387+
}
388+
389+
match self.type_kind(llvm_ty) {
390+
TypeKind::X86_AMX if self.type_kind(rust_ty) == TypeKind::Vector => {
391+
let element_count = self.vector_length(rust_ty);
392+
let element_ty = self.element_type(rust_ty);
393+
394+
let element_size_bits = match self.type_kind(element_ty) {
395+
TypeKind::Half => 16,
396+
TypeKind::Float => 32,
397+
TypeKind::Double => 64,
398+
TypeKind::FP128 => 128,
399+
TypeKind::Integer => self.int_width(element_ty),
400+
TypeKind::Pointer => self.int_width(self.isize_ty()),
401+
_ => bug!(
402+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
403+
),
404+
};
405+
let vector_size_bits = element_size_bits * element_count as u64;
406+
407+
vector_size_bits == 8192
408+
}
409+
TypeKind::BFloat => rust_ty == self.type_i16(),
410+
TypeKind::Vector if self.type_kind(rust_ty) == TypeKind::Vector => {
411+
let llvm_element_count = self.vector_length(llvm_ty);
412+
let rust_element_count = self.vector_length(rust_ty);
401413

402-
return vector_size_bits == 8192;
414+
if llvm_element_count != rust_element_count {
415+
return false;
416+
}
417+
418+
let llvm_element_ty = self.element_type(llvm_ty);
419+
let rust_element_ty = self.element_type(rust_ty);
420+
421+
if llvm_element_ty == self.type_bf16() {
422+
rust_element_ty == self.type_i16()
423+
} else {
424+
false
425+
}
426+
}
427+
_ => false,
428+
}
403429
}
404-
return false;
405430
}
406431

407432
impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,17 +1582,18 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15821582
let actual_ty = self.cx.val_ty(actual_val);
15831583

15841584
if expected_ty != actual_ty {
1585-
warn!(
1585+
assert!(
1586+
self.cx.equate_ty(actual_ty, expected_ty),
15861587
"type mismatch in function call of {llfn:?}. \
1587-
Expected {expected_ty:?} for param {i}, got {actual_ty:?}; injecting bitcast",
1588+
Can't match {expected_ty:?} (expected) as parameter {i} with {actual_ty:?} (found)",
1589+
);
1590+
1591+
assert!(
1592+
has_fnabi,
1593+
"Should inject auto-bitcasts in function call of {llfn:?}, but not able to get Rust signature"
15881594
);
15891595

15901596
casted_args.to_mut()[i] = if self.cx.type_kind(expected_ty) == TypeKind::X86_AMX {
1591-
// we can't do `cast_return` in without `FnAbi`
1592-
assert!(
1593-
has_fnabi,
1594-
"Found `x86amx` for parameter {i} in function call of {llfn:?}, but not able to get Rust return type"
1595-
);
15961597
self.cast_vector_to_tile(actual_val)
15971598
} else {
15981599
self.bitcast(actual_val, expected_ty)
@@ -1609,11 +1610,20 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16091610

16101611
pub(crate) fn simple_call(
16111612
&mut self,
1613+
name: &[u8],
16121614
fn_ty: &'ll Type,
1613-
llfn: &'ll Value,
16141615
args: &[&'ll Value],
16151616
) -> &'ll Value {
1616-
let args = self.cast_arguments("simple call", fn_ty, llfn, args, false);
1617+
let llfn = unsafe {
1618+
llvm::LLVMRustGetOrInsertFunction(
1619+
self.cx.llmod(),
1620+
name.as_c_char_ptr(),
1621+
name.len(),
1622+
fn_ty,
1623+
)
1624+
};
1625+
1626+
let args = self.cast_arguments("simple_call", fn_ty, llfn, args, false);
16171627

16181628
unsafe {
16191629
llvm::LLVMBuildCall2(
@@ -1646,15 +1656,8 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16461656
assert!(type_params.is_empty());
16471657
base_name.as_bytes()
16481658
};
1649-
let llfn = unsafe {
1650-
llvm::LLVMRustGetOrInsertFunction(
1651-
self.cx.llmod(),
1652-
full_name.as_c_char_ptr(),
1653-
full_name.len(),
1654-
fn_ty,
1655-
)
1656-
};
1657-
self.simple_call(fn_ty, llfn, args)
1659+
1660+
self.simple_call(full_name, fn_ty, args)
16581661
}
16591662
}
16601663

@@ -1707,11 +1710,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17071710
let fn_ty =
17081711
self.type_func(&[self.type_ptr(), self.type_ptr(), self.type_isize()], llreturn_ty);
17091712

1710-
let llfn = self
1711-
.get_function("memcmp")
1712-
.unwrap_or_else(|| self.declare_cfn("memcmp", llvm::UnnamedAddr::No, fn_ty));
1713-
1714-
self.simple_call(fn_ty, llfn, &[ptr1, ptr2, num])
1713+
self.simple_call(b"memcmp", fn_ty, &[ptr1, ptr2, num])
17151714
}
17161715

17171716
fn cast_return(

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,9 @@ unsafe extern "C" {
10501050
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
10511051
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
10521052

1053+
// Operations on non-IEEE real types
1054+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
1055+
10531056
// Operations on function types
10541057
pub(crate) fn LLVMFunctionType<'a>(
10551058
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
147147
)
148148
}
149149
}
150+
151+
pub(crate) fn type_bf16(&self) -> &'ll Type {
152+
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
153+
}
150154
}
151155

152156
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {

0 commit comments

Comments
 (0)