Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract over the specific Constraint.t type used by the backing constraint system #859

Merged
merged 10 commits into from
Jan 7, 2025
48 changes: 19 additions & 29 deletions src/base/backend_extended.ml
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,31 @@ module type S = sig
val to_constant : t -> Field.t option
end

module R1CS_constraint_system : Constraint_system.S with module Field := Field

module Constraint : sig
type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]

type 'k with_constraint_args = ?label:string -> 'k
type t [@@deriving sexp]

val boolean : (Cvar.t -> t) with_constraint_args
val boolean : Cvar.t -> t
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I thought we were using it everywhere in Pickles.


val equal : (Cvar.t -> Cvar.t -> t) with_constraint_args
val equal : Cvar.t -> Cvar.t -> t

val r1cs : (Cvar.t -> Cvar.t -> Cvar.t -> t) with_constraint_args
val r1cs : Cvar.t -> Cvar.t -> Cvar.t -> t

val square : (Cvar.t -> Cvar.t -> t) with_constraint_args
val square : Cvar.t -> Cvar.t -> t

val annotation : t -> string
val eval : t -> (Cvar.t -> Field.t) -> bool

val eval :
(Cvar.t, Field.t) Constraint.basic_with_annotation
-> (Cvar.t -> Field.t)
-> bool
val log_constraint : t -> (Cvar.t -> Field.t) -> string
end

module Run_state : Run_state_intf.S
module R1CS_constraint_system :
Constraint_system.S
with module Field := Field
with type constraint_ = Constraint.t

module Run_state :
Run_state_intf.S
with type field := Field.t
and type constraint_ := Constraint.t
end

module Make (Backend : Backend_intf.S) :
Expand All @@ -91,7 +92,8 @@ module Make (Backend : Backend_intf.S) :
and type Field.Vector.t = Backend.Field.Vector.t
and type Bigint.t = Backend.Bigint.t
and type R1CS_constraint_system.t = Backend.R1CS_constraint_system.t
and type 'field Run_state.t = 'field Backend.Run_state.t = struct
and type Run_state.t = Backend.Run_state.t
and type Constraint.t = Backend.Constraint.t = struct
open Backend

module Bigint = struct
Expand Down Expand Up @@ -207,19 +209,7 @@ module Make (Backend : Backend_intf.S) :
None
end

module Constraint = struct
open Constraint
include Constraint.T

type 'k with_constraint_args = ?label:string -> 'k

type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]

let m = (module Field : Snarky_intf.Field.S with type t = Field.t)

let eval { basic; _ } get_value = Constraint.Basic.eval m get_value basic
end

module Constraint = Constraint
module R1CS_constraint_system = R1CS_constraint_system
module Run_state = Run_state
end
26 changes: 24 additions & 2 deletions src/base/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,29 @@ module type S = sig

val field_size : Bigint.t

module R1CS_constraint_system : Constraint_system.S with module Field := Field
module Constraint : sig
type t [@@deriving sexp]

module Run_state : Run_state_intf.S
val boolean : Field.t Cvar.t -> t

val equal : Field.t Cvar.t -> Field.t Cvar.t -> t

val r1cs : Field.t Cvar.t -> Field.t Cvar.t -> Field.t Cvar.t -> t

val square : Field.t Cvar.t -> Field.t Cvar.t -> t

val eval : t -> (Field.t Cvar.t -> Field.t) -> bool

val log_constraint : t -> (Field.t Cvar.t -> Field.t) -> string
end

module R1CS_constraint_system :
Constraint_system.S
with module Field := Field
with type constraint_ = Constraint.t

module Run_state :
Run_state_intf.S
with type field := Field.t
and type constraint_ := Constraint.t
end
47 changes: 25 additions & 22 deletions src/base/checked.ml
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
open Core_kernel

module Make (Field : sig
type t [@@deriving sexp]

val equal : t -> t -> bool
end)
(Types : Types.Types)
(Basic : Checked_intf.Basic with type field = Field.t with module Types := Types)
(As_prover : As_prover_intf.Basic
with type field := Basic.field
with module Types := Types) :
module Make
(Backend : Backend_extended.S)
(Types : Types.Types)
(Basic : Checked_intf.Basic
with type field = Backend.Field.t
and type constraint_ = Backend.Constraint.t
with module Types := Types)
(As_prover : As_prover_intf.Basic
with type field := Basic.field
with module Types := Types) :
Checked_intf.S
with module Types := Types
with type field = Field.t
and type run_state = Basic.run_state = struct
with type field = Backend.Field.t
and type run_state = Basic.run_state
and type constraint_ = Basic.constraint_ = struct
include Basic

let request_witness (typ : ('var, 'value) Types.Typ.t)
Expand Down Expand Up @@ -69,23 +70,25 @@ end)
in
handle t (fun request -> (Option.value_exn !handler) request)

let assert_ ?label c = add_constraint (Constraint.override_label c label)
let assert_ c = add_constraint c

let assert_r1cs ?label a b c = assert_ (Constraint.r1cs ?label a b c)
let assert_r1cs a b c = assert_ (Backend.Constraint.r1cs a b c)

let assert_square ?label a c = assert_ (Constraint.square ?label a c)
let assert_square a c = assert_ (Backend.Constraint.square a c)

let assert_all ?label cs =
let assert_all cs =
List.fold_right cs ~init:(return ()) ~f:(fun c (acc : _ t) ->
bind acc ~f:(fun () ->
add_constraint (Constraint.override_label c label) ) )
bind acc ~f:(fun () -> add_constraint c) )

let assert_equal ?label x y =
let assert_equal x y =
match (x, y) with
| Cvar.Constant x, Cvar.Constant y ->
if Field.equal x y then return ()
if Backend.Field.equal x y then return ()
else
failwithf !"assert_equal: %{sexp: Field.t} != %{sexp: Field.t}" x y ()
failwithf
!"assert_equal: %{sexp: Backend.Field.t} != %{sexp: \
Backend.Field.t}"
x y ()
| _ ->
assert_ (Constraint.equal ?label x y)
assert_ (Backend.Constraint.equal x y)
end
25 changes: 12 additions & 13 deletions src/base/checked_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module type Basic = sig

type field

type constraint_

type 'a t = 'a Types.Checked.t

type run_state

include Monad_let.S with type 'a t := 'a t

val add_constraint : (field Cvar.t, field) Constraint.t -> unit t
val add_constraint : constraint_ -> unit t

val as_prover : unit Types.As_prover.t -> unit t

Expand All @@ -29,7 +31,7 @@ module type Basic = sig
val direct : (run_state -> run_state * 'a) -> 'a t

val constraint_count :
?weight:((field Cvar.t, field) Constraint.t -> int)
?weight:(constraint_ -> int)
-> ?log:(?start:bool -> string -> int -> unit)
-> (unit -> 'a t)
-> int
Expand All @@ -40,6 +42,8 @@ module type S = sig

type field

type constraint_

type run_state

type 'a t = 'a Types.Checked.t
Expand Down Expand Up @@ -89,25 +93,20 @@ module type S = sig

val with_label : string -> (unit -> 'a t) -> 'a t

val assert_ :
?label:Base.string -> (field Cvar.t, field) Constraint.t -> unit t
val assert_ : constraint_ -> unit t

val assert_r1cs :
?label:Base.string -> field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t
val assert_r1cs : field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t

val assert_square :
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
val assert_square : field Cvar.t -> field Cvar.t -> unit t

val assert_all :
?label:Base.string -> (field Cvar.t, field) Constraint.t list -> unit t
val assert_all : constraint_ list -> unit t

val assert_equal :
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
val assert_equal : field Cvar.t -> field Cvar.t -> unit t

val direct : (run_state -> run_state * 'a) -> 'a t

val constraint_count :
?weight:((field Cvar.t, field) Constraint.t -> int)
?weight:(constraint_ -> int)
-> ?log:(?start:bool -> string -> int -> unit)
-> (unit -> 'a t)
-> int
Expand Down
64 changes: 16 additions & 48 deletions src/base/checked_runner.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
open Core_kernel
module Constraint0 = Constraint

let stack_to_string = String.concat ~sep:"\n"

Expand All @@ -10,9 +9,7 @@ let eval_constraints_ref = eval_constraints
module T (Backend : Backend_extended.S) = struct
type 'a t =
| Pure of 'a
| Function of
( Backend.Field.t Backend.Run_state.t
-> Backend.Field.t Backend.Run_state.t * 'a )
| Function of (Backend.Run_state.t -> Backend.Run_state.t * 'a)
end

module Simple_types (Backend : Backend_extended.S) = Types.Make_types (struct
Expand Down Expand Up @@ -40,15 +37,15 @@ module Make_checked
with type field := Backend.Field.t
with module Types := Types) =
struct
type run_state = Backend.Field.t Backend.Run_state.t
type run_state = Backend.Run_state.t

type constraint_ = Backend.Constraint.t

type field = Backend.Field.t

type 'a t = 'a T(Backend).t =
| Pure of 'a
| Function of
( Backend.Field.t Backend.Run_state.t
-> Backend.Field.t Backend.Run_state.t * 'a )
| Function of (Backend.Run_state.t -> Backend.Run_state.t * 'a)

let eval (t : 'a t) : run_state -> run_state * 'a =
match t with Pure a -> fun s -> (s, a) | Function g -> g
Expand Down Expand Up @@ -83,7 +80,7 @@ struct

open Backend

let get_value (t : Field.t Run_state.t) : Cvar.t -> Field.t =
let get_value (t : Run_state.t) : Cvar.t -> Field.t =
let get_one i = Run_state.get_variable_value t i in
Cvar.eval (`Return_values_will_be_mutated get_one)

Expand Down Expand Up @@ -143,36 +140,10 @@ struct
f ~at_label_boundary:(`End, lab) None ) ;
(Run_state.set_stack s' stack, y) )

let log_constraint ({ basic; _ } : Constraint.t) s =
let open Constraint0 in
match basic with
| Boolean var ->
Format.(asprintf "Boolean %s" (Field.to_string (get_value s var)))
| Equal (var1, var2) ->
Format.(
asprintf "Equal %s %s"
(Field.to_string (get_value s var1))
(Field.to_string (get_value s var2)))
| Square (var1, var2) ->
Format.(
asprintf "Square %s %s"
(Field.to_string (get_value s var1))
(Field.to_string (get_value s var2)))
| R1CS (var1, var2, var3) ->
Format.(
asprintf "R1CS %s %s %s"
(Field.to_string (get_value s var1))
(Field.to_string (get_value s var2))
(Field.to_string (get_value s var3)))
| _ ->
Format.asprintf
!"%{sexp:(Field.t, Field.t) Constraint0.basic}"
(Constraint0.Basic.map basic ~f:(get_value s))

let add_constraint ~stack ({ basic; annotation } : Constraint.t)
(Constraint_system.T ((module C), system) : Field.t Constraint_system.t) =
let label = Option.value annotation ~default:"<unknown>" in
C.add_constraint system basic ~label:(stack_to_string (label :: stack))
let add_constraint (basic : Constraint.t)
(Constraint_system.T ((module C), system) :
(Field.t, Constraint.t) Constraint_system.t ) =
C.add_constraint system basic

let add_constraint c : _ t =
Function
Expand All @@ -189,19 +160,18 @@ struct
then
failwithf
"Constraint unsatisfied (unreduced):\n\
%s\n\
%s\n\n\
Constraint:\n\
%s\n\
Data:\n\
%s"
(Constraint.annotation c)
(stack_to_string (Run_state.stack s))
(Sexp.to_string (Constraint.sexp_of_t c))
(log_constraint c s) () ;
(Backend.Constraint.log_constraint c (get_value s))
() ;
if not (Run_state.as_prover s) then
Option.iter (Run_state.system s) ~f:(fun system ->
add_constraint ~stack:(Run_state.stack s) c system ) ;
add_constraint c system ) ;
(s, ()) ) )

let with_handler h t : _ t =
Expand Down Expand Up @@ -422,17 +392,15 @@ module type S = sig
module State : sig
val make :
num_inputs:int
-> input:field Run_state.Vector.t
-> input:field Run_state_intf.Vector.t
-> next_auxiliary:int ref
-> aux:field Run_state.Vector.t
-> aux:field Run_state_intf.Vector.t
-> ?system:r1cs
-> ?eval_constraints:bool
-> ?handler:Request.Handler.t
-> with_witness:bool
-> ?log_constraint:
( ?at_label_boundary:[ `End | `Start ] * string
-> (field Cvar.t, field) Constraint.t option
-> unit )
(?at_label_boundary:[ `End | `Start ] * string -> constr -> unit)
-> unit
-> run_state
end
Expand Down
Loading
Loading