diff --git a/src/lib/frontend/transformations/TupleElimination.ml b/src/lib/frontend/transformations/TupleElimination.ml index 0dc8fdf3..2cc81ad6 100644 --- a/src/lib/frontend/transformations/TupleElimination.ml +++ b/src/lib/frontend/transformations/TupleElimination.ml @@ -296,11 +296,21 @@ let replacer = (match IdMap.find_opt id !env with | None -> EVar cid | Some lst -> - ETuple + ETuple (* if we are eliminating tuples, why do we return a tuple? *) (List.map (fun (id, ty) -> { (var_sp (Id id) Span.default) with ety = Some ty }) lst)) + + (* Given a (possibly nested) tuple-type expression, turn it into + a single unnested tuple expression with all elements in order *) + method flatten env exp = + match (Option.get exp.ety).raw_ty with + | TTuple _ -> + let es = self#visit_exp env exp |> extract_etuple in + List.map (self#flatten env) es |> List.concat + | _ -> [self#visit_exp env exp] + method! visit_EOp env op args = match op, args with @@ -308,6 +318,48 @@ let replacer = let es = self#visit_exp env e |> extract_etuple in let e = List.nth es idx |> self#visit_exp env in e.e + (* tuples may appear in equality or inequality operations, + and must be eliminated. *) + | Eq, [x; y] -> ( + let x = self#visit_exp env x in + let y = self#visit_exp env y in + match x.ety with + | Some({raw_ty = TTuple _}) -> + (* unpack tuple *) + let xs = self#flatten env x in + let ys = self#flatten env y in + (* fold into an "and" expression *) + let acc_exp acc x' y' = + let eq' = op_sp Eq [x'; y'] (Span.default) in + match acc with + | None -> Some(eq') + | Some(acc) -> + Some(op_sp And [acc; eq'] (Span.default)) + in + let exp_opt = (List.fold_left2 acc_exp None xs ys) in + (Option.default (value_to_exp (vbool true)) exp_opt).e + | _ -> super#visit_EOp env op args + ) + | Neq, [x; y] -> ( + let x = self#visit_exp env x in + let y = self#visit_exp env y in + match x.ety with + | Some({raw_ty = TTuple _}) -> + (* unpack tuple *) + let xs = self#flatten env x in + let ys = self#flatten env y in + (* fold into an "or" expression *) + let acc_exp acc x' y' = + let eq' = op_sp Neq [x'; y'] (Span.default) in + match acc with + | None -> Some(eq') + | Some(acc) -> + Some(op_sp Or [acc; eq'] (Span.default)) + in + let exp_opt = (List.fold_left2 acc_exp None xs ys) in + (Option.default (value_to_exp (vbool true)) exp_opt).e + | _ -> super#visit_EOp env op args + ) | _ -> super#visit_EOp env op args (** Now we have to visit all expressions which might contain tuples, @@ -328,14 +380,6 @@ let replacer = sequence_statements new_defs | _ -> super#visit_SUnit env exp - (* Given a (possibly nested) tuple-type expression, turn it into - a single unnested tuple expression with all elements in order *) - method flatten env exp = - match (Option.get exp.ety).raw_ty with - | TTuple _ -> - let es = self#visit_exp env exp |> extract_etuple in - List.map (self#flatten env) es |> List.concat - | _ -> [self#visit_exp env exp] (* We replace tuple-type parameters to functions and events with one parameter for each tuple entry. So we need to adjust the @@ -606,15 +650,15 @@ and replace_decls env ds = (* Sanity checker to make sure no tuples remain in the program *) let checker = object - inherit [_] s_iter - + inherit [_] s_iter as super method! visit_exp _ e = match e.e with + | ECall _ -> () | ETuple _ -> Console.error_position e.espan @@ "Internal error: failed to eliminate tuple " ^ Printing.exp_to_string e - | _ -> () + | _ -> super#visit_exp () e end ;; @@ -625,6 +669,7 @@ let print_prog ds = let eliminate_prog ds = let _, ds = replace_decls IdMap.empty ds in + checker#visit_decls () ds; ds ;;