diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs index e110c155da089..86fdfae1ffbf5 100644 --- a/compiler/rustc_const_eval/src/interpret/cast.rs +++ b/compiler/rustc_const_eval/src/interpret/cast.rs @@ -430,10 +430,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { }; let erased_trait_ref = ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref); - assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env( - erased_trait_ref, - self.tcx.instantiate_bound_regions_with_erased(b) - ))); + assert_eq!( + data_b.principal().map(|b| { + self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b) + }), + Some(erased_trait_ref), + ); } else { // In this case codegen would keep using the old vtable. We don't want to do // that as it has the wrong trait. The reason codegen can do this is that diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs index 95a72d3cbc1d7..242cf6484dd9f 100644 --- a/compiler/rustc_const_eval/src/interpret/eval_context.rs +++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs @@ -4,9 +4,6 @@ use either::{Left, Right}; use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout}; use rustc_errors::DiagCtxtHandle; use rustc_hir::def_id::DefId; -use rustc_infer::infer::TyCtxtInferExt; -use rustc_infer::infer::at::ToTrace; -use rustc_infer::traits::ObligationCause; use rustc_middle::mir::interpret::{ErrorHandled, InvalidMetaKind, ReportedErrorInfo}; use rustc_middle::query::TyCtxtAt; use rustc_middle::ty::layout::{ @@ -17,8 +14,7 @@ use rustc_middle::{mir, span_bug}; use rustc_session::Limit; use rustc_span::Span; use rustc_target::callconv::FnAbi; -use rustc_trait_selection::traits::ObligationCtxt; -use tracing::{debug, instrument, trace}; +use tracing::{debug, trace}; use super::{ Frame, FrameInfo, GlobalId, InterpErrorInfo, InterpErrorKind, InterpResult, MPlaceTy, Machine, @@ -323,40 +319,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { } } - /// Check if the two things are equal in the current param_env, using an infcx to get proper - /// equality checks. - #[instrument(level = "trace", skip(self), ret)] - pub(super) fn eq_in_param_env(&self, a: T, b: T) -> bool - where - T: PartialEq + TypeFoldable> + ToTrace<'tcx>, - { - // Fast path: compare directly. - if a == b { - return true; - } - // Slow path: spin up an inference context to check if these traits are sufficiently equal. - let (infcx, param_env) = self.tcx.infer_ctxt().build_with_typing_env(self.typing_env); - let ocx = ObligationCtxt::new(&infcx); - let cause = ObligationCause::dummy_with_span(self.cur_span()); - // equate the two trait refs after normalization - let a = ocx.normalize(&cause, param_env, a); - let b = ocx.normalize(&cause, param_env, b); - - if let Err(terr) = ocx.eq(&cause, param_env, a, b) { - trace!(?terr); - return false; - } - - let errors = ocx.select_all_or_error(); - if !errors.is_empty() { - trace!(?errors); - return false; - } - - // All good. - true - } - /// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a /// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic, /// and is primarily intended for the panic machinery. diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs index 4cfaacebfcd0e..a5029eea5a796 100644 --- a/compiler/rustc_const_eval/src/interpret/traits.rs +++ b/compiler/rustc_const_eval/src/interpret/traits.rs @@ -86,21 +86,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type }); } + // This checks whether there is a subtyping relation between the predicates in either direction. + // For example: + // - casting between `dyn for<'a> Trait` and `dyn Trait` is OK + // - casting between `dyn Trait fn(&'a u8)>` and either of the above is UB for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) { - let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) { - ( - ty::ExistentialPredicate::Trait(a_data), - ty::ExistentialPredicate::Trait(b_data), - ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)), + let a_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, a_pred); + let b_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b_pred); - ( - ty::ExistentialPredicate::Projection(a_data), - ty::ExistentialPredicate::Projection(b_data), - ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)), - - _ => false, - }; - if !is_eq { + if a_pred != b_pred { throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type }); } } diff --git a/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.rs b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.rs new file mode 100644 index 0000000000000..7de4aef422a08 --- /dev/null +++ b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.rs @@ -0,0 +1,30 @@ +// Test that transmuting from `&dyn Trait` to `&dyn Trait fn(&'a ())>` is UB. +// +// The vtable of `() as Trait` and `() as Trait fn(&'a ())>` can have +// different entries and, because in the former the entry for `foo` is vacant, this test will +// segfault at runtime. + +trait Trait { + fn foo(&self) + where + U: HigherRanked, + { + } +} +impl Trait for T {} + +trait HigherRanked {} +impl HigherRanked for for<'a> fn(&'a ()) {} + +// 2nd candidate is required so that selecting `(): Trait` will +// evaluate the candidates and fail the leak check instead of returning the +// only applicable candidate. +trait Unsatisfied {} +impl HigherRanked for T {} + +fn main() { + let x: &dyn Trait = &(); + let y: &dyn Trait fn(&'a ())> = unsafe { std::mem::transmute(x) }; + //~^ ERROR: wrong trait in wide pointer vtable + y.foo(); +} diff --git a/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.stderr b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.stderr new file mode 100644 index 0000000000000..92c7c54c1dcc4 --- /dev/null +++ b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.stderr @@ -0,0 +1,15 @@ +error: Undefined Behavior: constructing invalid value: wrong trait in wide pointer vtable: expected `Trait fn(&'a ())>`, but encountered `Trait` + --> tests/fail/validity/dyn-trait-leak-check.rs:LL:CC + | +LL | let y: &dyn Trait fn(&'a ())> = unsafe { std::mem::transmute(x) }; + | ^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value: wrong trait in wide pointer vtable: expected `Trait fn(&'a ())>`, but encountered `Trait` + | + = help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior + = help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information + = note: BACKTRACE: + = note: inside `main` at tests/fail/validity/dyn-trait-leak-check.rs:LL:CC + +note: some details are omitted, run with `MIRIFLAGS=-Zmiri-backtrace=full` for a verbose backtrace + +error: aborting due to 1 previous error + diff --git a/src/tools/miri/tests/pass/dyn-upcast.rs b/src/tools/miri/tests/pass/dyn-upcast.rs index 6d5e228e8ecc1..de1467299dced 100644 --- a/src/tools/miri/tests/pass/dyn-upcast.rs +++ b/src/tools/miri/tests/pass/dyn-upcast.rs @@ -12,6 +12,7 @@ fn main() { drop_principal(); modulo_binder(); modulo_assoc(); + bidirectional_subtyping(); } fn vtable_nop_cast() { @@ -532,3 +533,31 @@ fn modulo_assoc() { (&() as &dyn Trait as &dyn Middle<()>).say_hello(&0); } + +fn bidirectional_subtyping() { + // Test that transmuting between subtypes of dyn traits is fine, even in the + // "wrong direction", i.e. going from a lower-ranked to a higher-ranked dyn trait. + // Note that compared to the `dyn-trait-leak-check` test, the `for` is on the *outside* here! + + trait Trait {} + impl Trait for T {} + + struct Wrapper(T); + + let x: &dyn Trait = &(); + let _y: &dyn for<'a> Trait = unsafe { std::mem::transmute(x) }; + + let x: &dyn for<'a> Trait = &(); + let _y: &dyn Trait = unsafe { std::mem::transmute(x) }; + + let x: &dyn Trait> = &(); + let _y: &dyn for<'a> Trait> = unsafe { std::mem::transmute(x) }; + + let x: &dyn for<'a> Trait> = &(); + let _y: &dyn Trait> = unsafe { std::mem::transmute(x) }; + + // This lowers to a ptr-to-ptr cast (which behaves like a transmute) + // and not an unsizing coercion: + let x: *const dyn for<'a> Trait<&'a ()> = &(); + let _y: *const Wrapper> = x as _; +}