Skip to content

Commit

Permalink
Generalize over field_var type
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmr1993 committed Jan 7, 2025
1 parent b9f1a2c commit a608ed9
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 63 deletions.
53 changes: 25 additions & 28 deletions snarky_integer/integer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Interval = struct

let iter t ~f = match t with Constant x -> f x | Less_than x -> f x

let check (type f) ~m:((module M) : f m) t =
let check (type f v) ~m:((module M) : (f, v) m) t =
iter t ~f:(fun x -> assert (B.(x < M.Field.size))) ;
t

Expand Down Expand Up @@ -113,18 +113,15 @@ module Interval = struct
end

(* TODO: Use <= instead of < for the upper bound *)
type 'f t =
{ value : 'f Cvar.t
; interval : Interval.t
; mutable bits : 'f Cvar.t Boolean.t list option
}
type ('f, 'v) t =
{ value : 'v; interval : Interval.t; mutable bits : 'v Boolean.t list option }

let create ~value ~upper_bound =
{ value; interval = Less_than upper_bound; bits = None }

let to_field t = t.value

let constant (type f) ?length ~m:((module M) as m : f m) x =
let constant (type f v) ?length ~m:((module M) as m : (f, v) m) x =
let open M in
assert (B.( < ) x Field.size) ;
let upper_bound = B.(one + x) in
Expand All @@ -145,7 +142,7 @@ let constant (type f) ?length ~m:((module M) as m : f m) x =
constant Boolean.typ B.(shift_right x i land one = one) ) )
}

let shift_left (type f) ~m:((module M) as m : f m) t k =
let shift_left (type f v) ~m:((module M) as m : (f, v) m) t k =
let open M in
let two_to_k = B.(one lsl k) in
{ value = Field.(constant (bigint_to_field ~m two_to_k) * t.value)
Expand All @@ -155,7 +152,7 @@ let shift_left (type f) ~m:((module M) as m : f m) t k =
List.init k ~f:(fun _ -> Boolean.false_) @ bs )
}

let of_bits (type f) ~m:((module M) : f m) bs =
let of_bits (type f v) ~m:((module M) : (f, v) m) bs =
let bs = Bitstring.Lsb_first.to_list bs in
{ value = M.Field.project bs
; interval = Less_than B.(one lsl List.length bs)
Expand All @@ -167,7 +164,7 @@ let of_bits (type f) ~m:((module M) : f m) bs =
a = q * b + r
r < b
*)
let div_mod (type f) ~m:((module M) as m : f m) a b =
let div_mod (type f v) ~m:((module M) as m : (f, v) m) a b =
let open M in
(* Guess (q, r) *)
let q, r =
Expand Down Expand Up @@ -198,7 +195,7 @@ let div_mod (type f) ~m:((module M) as m : f m) a b =
}
, { value = r; interval = b.interval; bits = Some r_bits } )

let subtract_unpacking (type f) ~m:((module M) : f m) a b =
let subtract_unpacking (type f v) ~m:((module M) : (f, v) m) a b =
M.with_label "Integer.subtract_unpacking" (fun () ->
assert (Interval.gte a.interval b.interval) ;
let value = M.Field.(sub a.value b.value) in
Expand All @@ -207,15 +204,15 @@ let subtract_unpacking (type f) ~m:((module M) : f m) a b =
let bits = M.Field.unpack value ~length in
{ value; interval = a.interval; bits = Some bits } )

let add (type f) ~m:((module M) as m : f m) a b =
let add (type f v) ~m:((module M) as m : (f, v) m) a b =
let interval = Interval.(add ~m a.interval b.interval) in
{ value = M.Field.(a.value + b.value); interval; bits = None }

let mul (type f) ~m:((module M) as m : f m) a b =
let mul (type f v) ~m:((module M) as m : (f, v) m) a b =
let interval = Interval.(mul ~m a.interval b.interval) in
{ value = M.Field.(a.value * b.value); interval; bits = None }

let to_bits ?length (type f) ~m:((module M) : f m) t =
let to_bits ?length (type f v) ~m:((module M) : (f, v) m) t =
match t.bits with
| Some bs -> (
let bs = Bitstring.Lsb_first.of_list bs in
Expand All @@ -239,7 +236,7 @@ let to_bits_exn t = Bitstring.Lsb_first.of_list (Option.value_exn t.bits)

let to_bits_opt t = Option.map ~f:Bitstring.Lsb_first.of_list t.bits

let min (type f) ~m:((module M) : f m) (a : f t) (b : f t) =
let min (type f v) ~m:((module M) : (f, v) m) (a : (f, v) t) (b : (f, v) t) =
let open M in
let bit_length =
Int.max (Interval.bits_needed a.interval) (Interval.bits_needed b.interval)
Expand All @@ -250,48 +247,48 @@ let min (type f) ~m:((module M) : f m) (a : f t) (b : f t) =
; bits = None
}

let if_ (type f) ~m:((module M) : f m) cond ~then_ ~else_ =
let if_ (type f v) ~m:((module M) : (f, v) m) cond ~then_ ~else_ =
{ value = M.Field.if_ cond ~then_:then_.value ~else_:else_.value
; interval = Interval.lub then_.interval else_.interval
; bits = None
}

let succ_if (type f) ~m:((module M) as m : f m) t (cond : f Cvar.t Boolean.t) =
let succ_if (type f v) ~m:((module M) as m : (f, v) m) t (cond : v Boolean.t) =
let open M in
{ value = Field.(add (cond :> t) t.value)
; interval = Interval.(lub t.interval (succ ~m t.interval))
; bits = None
}

let succ (type f) ~m:((module M) as m : f m) t =
let succ (type f v) ~m:((module M) as m : (f, v) m) t =
let open M in
{ value = Field.(add one t.value)
; interval = Interval.succ ~m t.interval
; bits = None
}

let equal (type f) ~m:((module M) : f m) a b = M.Field.equal a.value b.value
let equal (type f v) ~m:((module M) : (f, v) m) a b =
M.Field.equal a.value b.value

let max_bits a b =
Int.max (Interval.bits_needed a.interval) (Interval.bits_needed b.interval)

let lt (type f) ~m:((module M) : f m) a b =
let lt (type f v) ~m:((module M) : (f, v) m) a b =
(M.Field.compare ~bit_length:(max_bits a b) a.value b.value).less

let lte (type f) ~m:((module M) : f m) a b =
let lte (type f v) ~m:((module M) : (f, v) m) a b =
(M.Field.compare ~bit_length:(max_bits a b) a.value b.value).less_or_equal

let gte (type f) ~m:((module M) as m : f m) a b = M.Boolean.not (lt ~m a b)
let gte (type f v) ~m:((module M) as m : (f, v) m) a b =
M.Boolean.not (lt ~m a b)

let gt (type f) ~m:((module M) as m : f m) a b = M.Boolean.not (lte ~m a b)
let gt (type f v) ~m:((module M) as m : (f, v) m) a b =
M.Boolean.not (lte ~m a b)

let subtract_unpacking_or_zero (type f) ~m:((module M) as m : f m) a b =
let subtract_unpacking_or_zero (type f v) ~m:((module M) as m : (f, v) m) a b =
let flag = lt ~m a b in
( `Underflow flag
, { value =
M.Field.mul
(M.Field.sub a.value b.value)
(M.Boolean.not flag :> f Cvar.t)
, { value = M.Field.mul (M.Field.sub a.value b.value) (M.Boolean.not flag :> v)
; interval = a.interval
; bits = None
} )
63 changes: 36 additions & 27 deletions snarky_integer/integer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,28 @@ module Interval : sig
type t = Constant of B.t | Less_than of B.t
end

type 'f t =
{ value : 'f Cvar.t
; interval : Interval.t
; mutable bits : 'f Cvar.t Boolean.t list option
}
type ('f, 'v) t =
{ value : 'v; interval : Interval.t; mutable bits : 'v Boolean.t list option }

(** Create an value representing the given constant value.
The bit representation of the constant is cached, and is padded to [length]
when given.
*)
val constant : ?length:int -> m:'f m -> Bigint.t -> 'f t
val constant : ?length:int -> m:('f, 'v) m -> Bigint.t -> ('f, 'v) t

(** [shift_left ~m x k] is equivalent to multiplying [x] by [2^k].
The result has a cached bit representation whenever the given [x] had a
cached bit representation.
*)
val shift_left : m:'f m -> 'f t -> int -> 'f t
val shift_left : m:('f, 'v) m -> ('f, 'v) t -> int -> ('f, 'v) t

(** Create a value from the given bit string.
The given bit representation is cached.
*)
val of_bits : m:'f m -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t -> 'f t
val of_bits : m:('f, 'v) m -> 'v Boolean.t Bitstring.Lsb_first.t -> ('f, 'v) t

(** Compute the bit representation of the given integer.
Expand All @@ -51,72 +48,81 @@ val of_bits : m:'f m -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t -> 'f t
value is updated to include the cache.
*)
val to_bits :
?length:int -> m:'f m -> 'f t -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t
?length:int
-> m:('f, 'v) m
-> ('f, 'v) t
-> 'v Boolean.t Bitstring.Lsb_first.t

(** Return the cached bit representation, or raise an exception if the bit
representation has not been cached.
*)
val to_bits_exn : 'f t -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t
val to_bits_exn : ('f, 'v) t -> 'v Boolean.t Bitstring.Lsb_first.t

(** Returns [Some bs] for [bs] the cached bit representation, or [None] if the
bit representation has not been cached.
*)
val to_bits_opt : 'f t -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t option
val to_bits_opt : ('f, 'v) t -> 'v Boolean.t Bitstring.Lsb_first.t option

(** [div_mod ~m a b = (q, r)] such that [a = q * b + r] and [r < b].
The bit representations of [q] and [r] are calculated and cached.
NOTE: This uses approximately [log2(a) + 2 * log2(b)] constraints.
*)
val div_mod : m:'f m -> 'f t -> 'f t -> 'f t * 'f t
val div_mod :
m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t * ('f, 'v) t

val to_field : 'f t -> 'f Cvar.t
val to_field : ('f, 'v) t -> 'v

val create : value:'f Cvar.t -> upper_bound:Bigint.t -> 'f t
val create : value:'v -> upper_bound:Bigint.t -> ('f, 'v) t

(** [min ~m x y] returns a value equal the lesser of [x] and [y].
The result does not carry a cached bit representation.
*)
val min : m:'f m -> 'f t -> 'f t -> 'f t
val min : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

val if_ : m:'f m -> 'f Cvar.t Boolean.t -> then_:'f t -> else_:'f t -> 'f t
val if_ :
m:('f, 'v) m
-> 'v Boolean.t
-> then_:('f, 'v) t
-> else_:('f, 'v) t
-> ('f, 'v) t

(** [succ ~m x] computes the successor [x+1] of [x].
The result does not carry a cached bit representation.
*)
val succ : m:'f m -> 'f t -> 'f t
val succ : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t

(** [succ_if ~m x b] computes the integer [x+1] if [b] is [true], or [x]
otherwise.
The result does not carry a cached bit representation.
*)
val succ_if : m:'f m -> 'f t -> 'f Cvar.t Boolean.t -> 'f t
val succ_if : m:('f, 'v) m -> ('f, 'v) t -> 'v Boolean.t -> ('f, 'v) t

val equal : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val equal : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val lt : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val lt : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val lte : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val lte : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val gt : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val gt : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val gte : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val gte : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

(** [add ~m x y] computes [x + y].
The result does not carry a cached bit representation.
*)
val add : m:'f m -> 'f t -> 'f t -> 'f t
val add : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

(** [mul ~m x y] computes [x * y].
The result does not carry a cached bit representation.
*)
val mul : m:'f m -> 'f t -> 'f t -> 'f t
val mul : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

(** [subtract_unpacking ~m x y] computes [x - y].
Expand All @@ -125,7 +131,7 @@ val mul : m:'f m -> 'f t -> 'f t -> 'f t
NOTE: This uses approximately [log2(x)] constraints.
*)
val subtract_unpacking : m:'f m -> 'f t -> 'f t -> 'f t
val subtract_unpacking : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

(** [subtract_unpacking_or_zero ~m x y] computes [x - y].
Expand All @@ -139,4 +145,7 @@ val subtract_unpacking : m:'f m -> 'f t -> 'f t -> 'f t
NOTE: This uses approximately [log2(x)] constraints.
*)
val subtract_unpacking_or_zero :
m:'f m -> 'f t -> 'f t -> [ `Underflow of 'f Cvar.t Boolean.t ] * 'f t
m:('f, 'v) m
-> ('f, 'v) t
-> ('f, 'v) t
-> [ `Underflow of 'v Boolean.t ] * ('f, 'v) t
4 changes: 2 additions & 2 deletions snarky_integer/util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ open Snarky_backendless
open Snark
module B = Bigint

let bigint_to_field (type f) ~m:((module M) : f m) =
let bigint_to_field (type f v) ~m:((module M) : (f, v) m) =
let open M in
Fn.compose Bigint.to_field Bigint.of_bignum_bigint

let bigint_of_field (type f) ~m:((module M) : f m) =
let bigint_of_field (type f v) ~m:((module M) : (f, v) m) =
let open M in
Fn.compose Bigint.to_bignum_bigint Bigint.of_field
7 changes: 5 additions & 2 deletions src/base/snark0.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1566,8 +1566,11 @@ module Run = struct
end
end

type 'field m = (module Snark_intf.Run with type field = 'field)
type ('field, 'field_var) m =
(module Snark_intf.Run
with type field = 'field
and type field_var = 'field_var )

let make (type field) (module Backend : Backend_intf.S with type Field.t = field)
: field m =
: (field, field Cvar.t) m =
(module Run.Make (Backend))
11 changes: 9 additions & 2 deletions src/base/snark0.mli
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ exception Runtime_error of string list * exn * string
module Make (Backend : Backend_intf.S) :
Snark_intf.S
with type field = Backend.Field.t
and type field_var = Backend.Cvar.t
and type Bigint.t = Backend.Bigint.t
and type R1CS_constraint_system.t = Backend.R1CS_constraint_system.t
and type Field.Vector.t = Backend.Field.Vector.t
Expand All @@ -31,12 +32,18 @@ module Run : sig
module Make (Backend : Backend_intf.S) :
Snark_intf.Run
with type field = Backend.Field.t
and type field_var = Backend.Cvar.t
and type Bigint.t = Backend.Bigint.t
and type R1CS_constraint_system.t = Backend.R1CS_constraint_system.t
and type Field.Constant.Vector.t = Backend.Field.Vector.t
and type Constraint.t = Backend.Constraint.t
end

type 'field m = (module Snark_intf.Run with type field = 'field)
type ('field, 'field_var) m =
(module Snark_intf.Run
with type field = 'field
and type field_var = 'field_var )

val make : (module Backend_intf.S with type Field.t = 'field) -> 'field m
val make :
(module Backend_intf.S with type Field.t = 'field)
-> ('field, 'field Cvar.t) m
4 changes: 2 additions & 2 deletions src/base/snark_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ module type Basic = sig
type field

(** The variable type over which the R1CS operates. *)
type field_var = field Cvar.t
type field_var

(** The rank-1 constraint system used by this instance. See
{!module:Backend_intf.S.R1CS_constraint_system}. *)
Expand Down Expand Up @@ -1090,7 +1090,7 @@ module type Run_basic = sig
type field

(** The variable type over which the R1CS operates. *)
type field_var = field Cvar.t
type field_var

module Bigint : sig
include Snarky_intf.Bigint_intf.Extended with type field := field
Expand Down

0 comments on commit a608ed9

Please sign in to comment.