From 1a28bc452913714a4cb861e940af23b963be0f18 Mon Sep 17 00:00:00 2001
From: sayantn
Date: Tue, 14 Apr 2026 01:57:13 +0530
Subject: [PATCH 1/2] Disable ABI checks for the `unadjusted` ABI

---
 .../rustc_monomorphize/src/mono_checks/abi_check.rs | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs b/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
index 0921e57844b03..6d0d6f565a90b 100644
--- a/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
+++ b/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
@@ -1,6 +1,6 @@
 //! This module ensures that if a function's ABI requires a particular target feature,
 //! that target feature is enabled both on the callee and all callers.
-use rustc_abi::{BackendRepr, CanonAbi, RegKind, X86Call};
+use rustc_abi::{BackendRepr, CanonAbi, ExternAbi, RegKind, X86Call};
 use rustc_hir::{CRATE_HIR_ID, HirId};
 use rustc_middle::mir::{self, Location, traversal};
 use rustc_middle::ty::{self, Instance, InstanceKind, Ty, TyCtxt};
@@ -157,6 +157,12 @@ fn do_check_unsized_params<'tcx>(
 /// - the signature requires target features that are not enabled
 fn check_instance_abi<'tcx>(tcx: TyCtxt<'tcx>, instance: Instance<'tcx>) {
     let typing_env = ty::TypingEnv::fully_monomorphized();
+    let ty = instance.ty(tcx, typing_env);
+    if ty.is_fn() && ty.fn_sig(tcx).abi() == ExternAbi::Unadjusted {
+        // We disable all checks for the unadjusted ABI to allow linking to arbitrary LLVM
+        // intrinsics
+        return;
+    }
     let Ok(abi) = tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
     else {
         // An error will be reported during codegen if we cannot determine the ABI of this
@@ -191,9 +197,12 @@ fn check_call_site_abi<'tcx>(
     caller: InstanceKind<'tcx>,
     loc: impl Fn() -> (Span, HirId) + Copy,
 ) {
-    if callee.fn_sig(tcx).abi().is_rustic_abi() {
+    let extern_abi = callee.fn_sig(tcx).abi();
+    if extern_abi.is_rustic_abi() || extern_abi == ExternAbi::Unadjusted {
         // We directly handle the soundness of Rust ABIs -- so let's skip the majority of
         // call sites to avoid a perf regression.
+        // We disable all checks for the unadjusted ABI to allow linking to arbitrary LLVM
+        // intrinsics
         return;
     }
     let typing_env = ty::TypingEnv::fully_monomorphized();

From 7e24cd823dc81a3647131f4255a20f956e6de629 Mon Sep 17 00:00:00 2001
From: sayantn
Date: Tue, 14 Apr 2026 02:10:36 +0530
Subject: [PATCH 2/2] Add autocast for `x86_amx`

---
 compiler/rustc_codegen_llvm/src/intrinsic.rs | 24 ++++++++++++++++++++
 tests/codegen-llvm/inject-autocast.rs        | 29 +++++++++++++++++++++++++++--
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index 0d3d682ece21f..34b3f0d81a708 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1015,6 +1015,24 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll
             }
         }
         TypeKind::BFloat => rust_ty == cx.type_i16(),
+        TypeKind::X86_AMX if cx.type_kind(rust_ty) == TypeKind::Vector => {
+            let element_ty = cx.element_type(rust_ty);
+            let element_count = cx.vector_length(rust_ty) as u64;
+
+            let element_size_bits = match cx.type_kind(element_ty) {
+                TypeKind::Half => 16,
+                TypeKind::Float => 32,
+                TypeKind::Double => 64,
+                TypeKind::FP128 => 128,
+                TypeKind::Integer => cx.int_width(element_ty),
+                TypeKind::Pointer => cx.int_width(cx.isize_ty),
+                _ => bug!(
+                    "Vector element type `{element_ty:?}` not one of integer, float or pointer"
+                ),
+            };
+
+            element_size_bits * element_count == 8192
+        }
         _ => false,
     }
 }
@@ -1084,6 +1102,12 @@ fn autocast<'ll>(
                 )
             }
         }
+        (TypeKind::Vector, TypeKind::X86_AMX) => {
+            bx.call_intrinsic("llvm.x86.cast.vector.to.tile", &[src_ty], &[val])
+        }
+        (TypeKind::X86_AMX, TypeKind::Vector) => {
+            bx.call_intrinsic("llvm.x86.cast.tile.to.vector", &[dest_ty], &[val])
+        }
         _ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
     }
 }
diff --git a/tests/codegen-llvm/inject-autocast.rs b/tests/codegen-llvm/inject-autocast.rs
index fec9d3f0b1955..cc74256ebbe8d 100644
--- a/tests/codegen-llvm/inject-autocast.rs
+++ b/tests/codegen-llvm/inject-autocast.rs
@@ -1,7 +1,7 @@
-//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert
+//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avx512dq,+avxneconvert,+amx-int8
 //@ only-x86_64
 
-#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
+#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd, repr_simd)]
 #![crate_type = "lib"]
 
 use std::simd::{f32x4, i16x8, i64x2};
@@ -10,6 +10,9 @@ use std::simd::{f32x4, i16x8, i64x2};
 pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
 // CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
 
+#[repr(simd)]
+pub struct Tile([i8; 1024]);
+
 // CHECK-LABEL: @struct_autocast
 #[no_mangle]
 pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
@@ -84,6 +87,22 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
     foo(a)
 }
 
+// CHECK-LABEL: @amx_autocast
+#[no_mangle]
+pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile {
+    extern "unadjusted" {
+        #[link_name = "llvm.x86.tdpbuud.internal"]
+        fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile;
+    }
+
+    // CHECK: [[A:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
+    // CHECK: [[B:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
+    // CHECK: [[C:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
+    // CHECK: [[D:%[0-9]+]] = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx [[A]], x86_amx [[B]], x86_amx [[C]])
+    // CHECK: call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx [[D]])
+    foo(m, n, k, a, b, c)
+}
+
 // CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
 
 // CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
@@ -91,3 +110,9 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
 // CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>)
 
 // CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)
+
+// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+
+// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>)
+
+// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)