Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract over the specific Constraint.t type used by the backing constraint system #859

Merged
merged 10 commits into from
Jan 7, 2025
22 changes: 6 additions & 16 deletions src/base/backend_extended.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,15 @@ module type S = sig
module Constraint : sig
type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]

type 'k with_constraint_args = ?label:string -> 'k
val boolean : Cvar.t -> t
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I thought we were using it everywhere in Pickles.


val boolean : (Cvar.t -> t) with_constraint_args
val equal : Cvar.t -> Cvar.t -> t

val equal : (Cvar.t -> Cvar.t -> t) with_constraint_args
val r1cs : Cvar.t -> Cvar.t -> Cvar.t -> t

val r1cs : (Cvar.t -> Cvar.t -> Cvar.t -> t) with_constraint_args
val square : Cvar.t -> Cvar.t -> t

val square : (Cvar.t -> Cvar.t -> t) with_constraint_args

val annotation : t -> string

val eval :
(Cvar.t, Field.t) Constraint.basic_with_annotation
-> (Cvar.t -> Field.t)
-> bool
val eval : (Cvar.t, Field.t) Constraint.t -> (Cvar.t -> Field.t) -> bool
end

module Run_state : Run_state_intf.S
Expand Down Expand Up @@ -208,16 +201,13 @@ module Make (Backend : Backend_intf.S) :
end

module Constraint = struct
open Constraint
include Constraint.T

type 'k with_constraint_args = ?label:string -> 'k

type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]

let m = (module Field : Snarky_intf.Field.S with type t = Field.t)

let eval { basic; _ } get_value = Constraint.Basic.eval m get_value basic
let eval basic get_value = Constraint.Basic.eval m get_value basic
end

module R1CS_constraint_system = R1CS_constraint_system
Expand Down
15 changes: 7 additions & 8 deletions src/base/checked.ml
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,22 @@ end)
in
handle t (fun request -> (Option.value_exn !handler) request)

let assert_ ?label c = add_constraint (Constraint.override_label c label)
let assert_ c = add_constraint c

let assert_r1cs ?label a b c = assert_ (Constraint.r1cs ?label a b c)
let assert_r1cs a b c = assert_ (Constraint.r1cs a b c)

let assert_square ?label a c = assert_ (Constraint.square ?label a c)
let assert_square a c = assert_ (Constraint.square a c)

let assert_all ?label cs =
let assert_all cs =
List.fold_right cs ~init:(return ()) ~f:(fun c (acc : _ t) ->
bind acc ~f:(fun () ->
add_constraint (Constraint.override_label c label) ) )
bind acc ~f:(fun () -> add_constraint c) )

let assert_equal ?label x y =
let assert_equal x y =
match (x, y) with
| Cvar.Constant x, Cvar.Constant y ->
if Field.equal x y then return ()
else
failwithf !"assert_equal: %{sexp: Field.t} != %{sexp: Field.t}" x y ()
| _ ->
assert_ (Constraint.equal ?label x y)
assert_ (Constraint.equal x y)
end
15 changes: 5 additions & 10 deletions src/base/checked_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,15 @@ module type S = sig

val with_label : string -> (unit -> 'a t) -> 'a t

val assert_ :
?label:Base.string -> (field Cvar.t, field) Constraint.t -> unit t
val assert_ : (field Cvar.t, field) Constraint.t -> unit t

val assert_r1cs :
?label:Base.string -> field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t
val assert_r1cs : field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t

val assert_square :
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
val assert_square : field Cvar.t -> field Cvar.t -> unit t

val assert_all :
?label:Base.string -> (field Cvar.t, field) Constraint.t list -> unit t
val assert_all : (field Cvar.t, field) Constraint.t list -> unit t

val assert_equal :
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
val assert_equal : field Cvar.t -> field Cvar.t -> unit t

val direct : (run_state -> run_state * 'a) -> 'a t

Expand Down
11 changes: 4 additions & 7 deletions src/base/checked_runner.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ struct
f ~at_label_boundary:(`End, lab) None ) ;
(Run_state.set_stack s' stack, y) )

let log_constraint ({ basic; _ } : Constraint.t) s =
let log_constraint (basic : Constraint.t) s =
let open Constraint0 in
match basic with
| Boolean var ->
Expand All @@ -169,10 +169,9 @@ struct
!"%{sexp:(Field.t, Field.t) Constraint0.basic}"
(Constraint0.Basic.map basic ~f:(get_value s))

let add_constraint ~stack ({ basic; annotation } : Constraint.t)
let add_constraint (basic : Constraint.t)
(Constraint_system.T ((module C), system) : Field.t Constraint_system.t) =
let label = Option.value annotation ~default:"<unknown>" in
C.add_constraint system basic ~label:(stack_to_string (label :: stack))
C.add_constraint system basic

let add_constraint c : _ t =
Function
Expand All @@ -189,19 +188,17 @@ struct
then
failwithf
"Constraint unsatisfied (unreduced):\n\
%s\n\
%s\n\n\
Constraint:\n\
%s\n\
Data:\n\
%s"
(Constraint.annotation c)
(stack_to_string (Run_state.stack s))
(Sexp.to_string (Constraint.sexp_of_t c))
(log_constraint c s) () ;
if not (Run_state.as_prover s) then
Option.iter (Run_state.system s) ~f:(fun system ->
add_constraint ~stack:(Run_state.stack s) c system ) ;
add_constraint c system ) ;
(s, ()) ) )

let with_handler h t : _ t =
Expand Down
24 changes: 5 additions & 19 deletions src/base/constraint.ml
Original file line number Diff line number Diff line change
Expand Up @@ -161,30 +161,16 @@ let () =
end in
Basic.add_case (module M)

type ('v, 'f) basic_with_annotation =
{ basic : ('v, 'f) basic; annotation : string option }
[@@deriving sexp]

type ('v, 'f) t = ('v, 'f) basic_with_annotation [@@deriving sexp]
type ('v, 'f) t = ('v, 'f) basic [@@deriving sexp]

module T = struct
let create_basic ?label basic = { basic; annotation = label }

let override_label { basic; annotation = a } label_opt =
{ basic
; annotation = (match label_opt with Some x -> Some x | None -> a)
}

let equal ?label x y = create_basic ?label (Equal (x, y))

let boolean ?label x = create_basic ?label (Boolean x)
let equal x y = Equal (x, y)

let r1cs ?label a b c = create_basic ?label (R1CS (a, b, c))
let boolean x = Boolean x

let square ?label a c = create_basic ?label (Square (a, c))
let r1cs a b c = R1CS (a, b, c)

let annotation (t : _ t) =
match t.annotation with Some str -> str | None -> ""
let square a c = Square (a, c)
end

include T
3 changes: 1 addition & 2 deletions src/base/constraint_system.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ module type S = sig

val finalize : t -> unit

val add_constraint :
?label:string -> t -> (Field.t Cvar.t, Field.t) Constraint.basic -> unit
val add_constraint : t -> (Field.t Cvar.t, Field.t) Constraint.basic -> unit

val digest : t -> Md5.t

Expand Down
16 changes: 7 additions & 9 deletions src/base/snark0.ml
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ struct
let open Let_syntax in
let%bind bits = choose_preimage_unchecked v ~length in
let lc = packing_sum bits in
let%map () =
assert_r1cs ~label:"Choose_preimage" lc (Cvar.constant Field.one) v
in
let%map () = assert_r1cs lc (Cvar.constant Field.one) v in
bits

let choose_preimage_flagged (v : Cvar.t) ~length =
Expand Down Expand Up @@ -394,7 +392,7 @@ struct
| _ ->
Checked.assert_non_zero v

let equal x y = Checked.assert_equal ~label:"Checked.Assert.equal" x y
let equal x y = Checked.assert_equal x y

let not_equal (x : t) (y : t) =
match (x, y) with
Expand Down Expand Up @@ -568,7 +566,7 @@ struct
let open Checked in
Base.List.map (chunk_for_equality t1 t2) ~f:(fun (x1, x2) ->
Constraint.equal (Cvar1.pack x1) (Cvar1.pack x2) )
|> assert_all ~label:"Bitstring.Assert.equal"
|> assert_all
end
end

Expand Down Expand Up @@ -1186,13 +1184,13 @@ module Run = struct
active_counters := counters ;
raise exn

let assert_ ?label c = run (assert_ ?label c)
let assert_ c = run (assert_ c)

let assert_all ?label c = run (assert_all ?label c)
let assert_all c = run (assert_all c)

let assert_r1cs ?label a b c = run (assert_r1cs ?label a b c)
let assert_r1cs a b c = run (assert_r1cs a b c)

let assert_square ?label x y = run (assert_square ?label x y)
let assert_square x y = run (assert_square x y)

let as_prover p = run (as_prover (As_prover.run_prover p))

Expand Down
41 changes: 16 additions & 25 deletions src/base/snark_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -278,24 +278,22 @@ module type Constraint_intf = sig
*)
type t = (field_var, field) Constraint0.t

type 'k with_constraint_args = ?label:string -> 'k

(** A constraint that asserts that the field variable is a boolean: either
{!val:Field.zero} or {!val:Field.one}.
*)
val boolean : (field_var -> t) with_constraint_args
val boolean : field_var -> t

(** A constraint that asserts that the field variable arguments are equal.
*)
val equal : (field_var -> field_var -> t) with_constraint_args
val equal : field_var -> field_var -> t

(** A bare rank-1 constraint. *)
val r1cs : (field_var -> field_var -> field_var -> t) with_constraint_args
val r1cs : field_var -> field_var -> field_var -> t

(** A constraint that asserts that the first variable squares to the
second, ie. [square x y] => [x*x = y] within the field.
*)
val square : (field_var -> field_var -> t) with_constraint_args
val square : field_var -> field_var -> t
end

module type Field_var_intf = sig
Expand Down Expand Up @@ -802,30 +800,23 @@ let multiply3 (x : Field.Var.t) (y : Field.Var.t) (z : Field.Var.t)
type t = request -> response
end

(** Add a constraint to the constraint system, optionally with the label
given by [label]. *)
val assert_ : ?label:string -> Constraint.t -> unit Checked.t
(** Add a constraint to the constraint system. *)
val assert_ : Constraint.t -> unit Checked.t

(** Add all of the constraints in the list to the constraint system,
optionally with the label given by [label].
*)
val assert_all : ?label:string -> Constraint.t list -> unit Checked.t
(** Add all of the constraints in the list to the constraint system. *)
val assert_all : Constraint.t list -> unit Checked.t

(** Add a rank-1 constraint to the constraint system, optionally with the
label given by [label].
(** Add a rank-1 constraint to the constraint system.

See {!val:Constraint.r1cs} for more information on rank-1 constraints.
*)
val assert_r1cs :
?label:string -> Field.Var.t -> Field.Var.t -> Field.Var.t -> unit Checked.t
val assert_r1cs : Field.Var.t -> Field.Var.t -> Field.Var.t -> unit Checked.t

(** Add a 'square' constraint to the constraint system, optionally with the
label given by [label].
(** Add a 'square' constraint to the constraint system.

See {!val:Constraint.square} for more information.
*)
val assert_square :
?label:string -> Field.Var.t -> Field.Var.t -> unit Checked.t
val assert_square : Field.Var.t -> Field.Var.t -> unit Checked.t

(** Run an {!module:As_prover} block. *)
val as_prover : unit As_prover.t -> unit Checked.t
Expand Down Expand Up @@ -1266,13 +1257,13 @@ module type Run_basic = sig
type t = request -> response
end

val assert_ : ?label:string -> Constraint.t -> unit
val assert_ : Constraint.t -> unit

val assert_all : ?label:string -> Constraint.t list -> unit
val assert_all : Constraint.t list -> unit

val assert_r1cs : ?label:string -> Field.t -> Field.t -> Field.t -> unit
val assert_r1cs : Field.t -> Field.t -> Field.t -> unit

val assert_square : ?label:string -> Field.t -> Field.t -> unit
val assert_square : Field.t -> Field.t -> unit

val as_prover : (unit -> unit) As_prover.t -> unit

Expand Down
14 changes: 4 additions & 10 deletions src/base/utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ struct
*)
let equal_constraints (z : Cvar.t) (z_inv : Cvar.t) (r : Cvar.t) =
Checked.assert_all
[ Constraint.r1cs ~label:"equals_1" z_inv z Cvar.(constant Field.one - r)
; Constraint.r1cs ~label:"equals_2" r z (Cvar.constant Field.zero)
[ Constraint.r1cs z_inv z Cvar.(constant Field.one - r)
; Constraint.r1cs r z (Cvar.constant Field.zero)
]

(* [equal_vars z] computes [(r, z_inv)] that satisfy the constraints in
Expand Down Expand Up @@ -130,10 +130,7 @@ struct
if Field.(equal zero x) then Field.zero
else Backend.Field.inv x ))
in
let%map () =
assert_r1cs ~label:"field_inverse" x x_inv
(Cvar.constant Field.one)
in
let%map () = assert_r1cs x x_inv (Cvar.constant Field.one) in
x_inv )

let div ?(label = "Checked.div") (x : Cvar.t) (y : Cvar.t) =
Expand Down Expand Up @@ -277,10 +274,7 @@ struct
in
Typ
{ typ with
check =
(fun v ->
Checked.assert_
(Constraint.boolean ~label:"boolean-alloc" (v :> Cvar.t)) )
check = (fun v -> Checked.assert_ (Constraint.boolean (v :> Cvar.t)))
}

let typ_unchecked : (var, value) Typ.t =
Expand Down