diff --git a/crucible-mir/src/Mir/Trans.hs b/crucible-mir/src/Mir/Trans.hs index 14286f3ca..d03cbfdbc 100644 --- a/crucible-mir/src/Mir/Trans.hs +++ b/crucible-mir/src/Mir/Trans.hs @@ -217,29 +217,25 @@ transConstVal ty@(M.TyAdt aname _ _) tpr (ConstStruct fields) = do adt <- findAdt aname col <- use $ cs . collection case findReprTransparentField col adt of - Just idx -> do - ty <- case adt ^? adtvariants . ix 0 . vfields . ix idx . fty of - Just x -> return x - Nothing -> mirFail $ "repr(transparent) field index " ++ show idx ++ - " out of range for " ++ show (pretty ty) - const <- case fields ^? ix idx of - Just x -> return x - Nothing -> mirFail $ "repr(transparent) field index " ++ show idx ++ - " out of range for " ++ show (pretty ty) ++ " initializer" - transConstVal ty tpr const + Just idx -> + transTransparentVal ty tpr adt fields idx Nothing -> do let fieldDefs = adt ^. adtvariants . ix 0 . vfields let fieldTys = map (\f -> f ^. fty) fieldDefs exps <- zipWithM (\val ty -> transConstVal ty (tyToRepr col ty) val) fields fieldTys buildStruct adt exps -transConstVal (M.TyAdt aname _ _) _ (ConstEnum variant fields) = do +transConstVal ty@(M.TyAdt aname _ _) tpr (ConstEnum variant fields) = do adt <- findAdt aname - let fieldDefs = adt ^. adtvariants . ix variant . vfields - let fieldTys = map (\f -> f ^. fty) fieldDefs col <- use $ cs . collection - exps <- zipWithM (\val ty -> transConstVal ty (tyToRepr col ty) val) fields fieldTys - buildEnum adt variant exps + case findReprTransparentField col adt of + Just idx -> + transTransparentVal ty tpr adt fields idx + Nothing -> do + let fieldDefs = adt ^. adtvariants . ix variant . vfields + let fieldTys = map (\f -> f ^. fty) fieldDefs + exps <- zipWithM (\val ty -> transConstVal ty (tyToRepr col ty) val) fields fieldTys + buildEnum adt variant exps transConstVal ty (Some (MirReferenceRepr tpr)) init = do MirExp tpr' val <- transConstVal (M.typeOfProj M.Deref ty) (Some tpr) init Refl <- testEqualityOrFail tpr tpr' $ @@ -281,6 +277,28 @@ transConstTuple tys tprs vals = do tys (toListFC Some tprs) vals return $ buildTupleMaybe col tys $ map Just vals' +-- Translate a struct or enum marked with repr(transparent). +transTransparentVal :: + M.Ty {- The transparent struct or enum type (only used for error messages) -} -> + Some C.TypeRepr {- The Crucible representation of the transparent struct or + enum type. -} -> + Adt {- The transparent struct or enum's Adt description. -} -> + [ConstVal] {- The field values of the transparent struct or enum variant. + Really, it should only be a single field value, but we must + check that this is actually the case. -} -> + Int {- The index of the underlying field in the variant. -} -> + MirGenerator h s ret (MirExp s) +transTransparentVal ty tpr adt fields idx = do + ty <- case adt ^? adtvariants . ix 0 . vfields . ix idx . fty of + Just x -> return x + Nothing -> mirFail $ "repr(transparent) field index " ++ show idx ++ + " out of range for " ++ show (pretty ty) + const <- case fields ^? ix idx of + Just x -> return x + Nothing -> mirFail $ "repr(transparent) field index " ++ show idx ++ + " out of range for " ++ show (pretty ty) ++ " initializer" + transConstVal ty tpr const + -- Taken from GHC's source code, which is BSD-3 licensed. zipWith3M :: Monad m => (a -> b -> c -> m d) -> [a] -> [b] -> [c] -> m [d] {-# INLINE zipWith3M #-} @@ -1029,12 +1047,21 @@ evalRval rv@(M.RAdtAg (M.AdtAg adt agv ops ty)) = do CTyMethodSpecBuilder -> mirFail $ "evalRval: can't construct MethodSpecBuilder without an override" TyAdt _ _ _ -> do + col <- use $ cs . collection es <- mapM evalOperand ops - case adt^.adtkind of - M.Struct -> buildStruct adt es - M.Enum _ -> buildEnum adt (fromInteger agv) es - M.Union -> do - mirFail $ "evalRval: Union types are unsupported, for " ++ show (adt ^. adtname) + case findReprTransparentField col adt of + Just idx -> do + op <- case ops ^? ix idx of + Just op -> pure op + Nothing -> mirFail $ "repr(transparent) field index " ++ show idx ++ + " out of range for " ++ show (adt ^. adtname) + evalOperand op + Nothing -> do + case adt^.adtkind of + M.Struct -> buildStruct adt es + M.Enum _ -> buildEnum adt (fromInteger agv) es + M.Union -> do + mirFail $ "evalRval: Union types are unsupported, for " ++ show (adt ^. adtname) _ -> mirFail $ "evalRval: unsupported type for AdtAg: " ++ show ty evalRval (M.ThreadLocalRef did _) = staticPlace did >>= addrOfPlace diff --git a/crux-mir/test/conc_eval/enum/repr_transparent.rs b/crux-mir/test/conc_eval/enum/repr_transparent.rs new file mode 100644 index 000000000..6129a78a8 --- /dev/null +++ b/crux-mir/test/conc_eval/enum/repr_transparent.rs @@ -0,0 +1,22 @@ +// A regression test for +// https://github.com/GaloisInc/crucible/issues/1140 +#![cfg_attr(not(with_main), no_std)] + +#[repr(transparent)] +pub enum E { + MkE(u32), +} + +pub fn f() -> u32 { + let x = E::MkE(42); + match x { + E::MkE(y) => y, + } +} + +#[cfg(with_main)] +pub fn main() { + println!("{:?}", f()); +} +#[cfg(not(with_main))] #[cfg_attr(crux, crux::test)] +pub fn crux_test() -> u32 { f() }