Skip to content

Commit

Permalink
Add dichotomic symbolic clz
Browse files Browse the repository at this point in the history
This is supposedly SMT friendly, and keeps the
bit-hacks for the concrete case.
  • Loading branch information
krtab committed Feb 26, 2024
1 parent d17beb3 commit c4040c1
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/int32.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

include Stdlib.Int32

let clz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_leading_zeros n)
let clz =
Some
(fun n -> Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_leading_zeros n))

let ctz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_trailing_zeros n)

Expand Down
2 changes: 1 addition & 1 deletion src/int32.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ val unsigned_to_int : t -> int option

(** unary operators *)

val clz : t -> t
val clz : (t -> t) option

val ctz : t -> t

Expand Down
4 changes: 3 additions & 1 deletion src/int64.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

include Stdlib.Int64

let clz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_leading_zeros n)
let clz =
Some
(fun n -> Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_leading_zeros n))

let ctz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_trailing_zeros n)

Expand Down
2 changes: 1 addition & 1 deletion src/int64.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ val extend_s : int -> t -> t

val abs : t -> t

val clz : t -> t
val clz : (t -> t) option

val ctz : t -> t

Expand Down
53 changes: 48 additions & 5 deletions src/interpret.ml
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,61 @@ module Make (P : Interpret_intf.P) :

let consti i = const_i32 (Int32.of_int i)

let clz_impl_32 n =
let rec aux (lb : int) ub =
if ub = lb + 1 then return (const_i32 (Int32.of_int (32 - ub)))
else begin
let mid = (lb + ub) / 2 in
let two_pow_mid = Int32.shl 1l (Int32.of_int mid) in
let> cond = I32.(lt_u n (const_i32 two_pow_mid)) in
if cond then aux lb mid else aux mid ub
end
in
let> cond = I32.(eqz n) in
if cond then return @@ const 32l else aux 0 32

let clz_impl_64 n =
let rec aux (lb : int) ub =
if ub = lb + 1 then return (const_i64 (Int64.of_int (64 - ub)))
else begin
let mid = (lb + ub) / 2 in
let two_pow_mid = Int64.shl 1L (Int64.of_int mid) in
let> cond = I64.(lt_u n (const_i64 two_pow_mid)) in
if cond then aux lb mid else aux mid ub
end
in
let> cond = I64.(eqz n) in
if cond then return @@ const_i64 64L else aux 0 64

let with_choosing_default_impl f ch_f =
match f with
| Some f -> fun n -> Choice.return (f n)
| None -> fun n -> ch_f n

let exec_iunop stack nn op =
match nn with
| S32 ->
let n, stack = Stack.pop_i32 stack in
let res =
let+ res =
let open I32 in
match op with Clz -> clz n | Ctz -> ctz n | Popcnt -> popcnt n
match op with
| Clz ->
let clz = with_choosing_default_impl clz clz_impl_32 in
clz n
| Ctz -> Choice.return @@ ctz n
| Popcnt -> Choice.return @@ popcnt n
in
Stack.push_i32 stack res
| S64 ->
let n, stack = Stack.pop_i64 stack in
let res =
let+ res =
let open I64 in
match op with Clz -> clz n | Ctz -> ctz n | Popcnt -> popcnt n
match op with
| Clz ->
let clz = with_choosing_default_impl clz clz_impl_64 in
clz n
| Ctz -> Choice.return @@ ctz n
| Popcnt -> Choice.return @@ popcnt n
in
Stack.push_i64 stack res

Expand Down Expand Up @@ -831,7 +872,9 @@ module Make (P : Interpret_intf.P) :
| I64_const n -> st @@ Stack.push_const_i64 stack n
| F32_const f -> st @@ Stack.push_const_f32 stack f
| F64_const f -> st @@ Stack.push_const_f64 stack f
| I_unop (nn, op) -> st @@ exec_iunop stack nn op
| I_unop (nn, op) ->
let* stack = exec_iunop stack nn op in
st stack
| F_unop (nn, op) -> st @@ exec_funop stack nn op
| I_binop (nn, op) ->
let* stack = exec_ibinop stack nn op in
Expand Down
2 changes: 1 addition & 1 deletion src/interpret_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ module type S = sig
-> Func_intf.t
-> value list Result.t choice

val exec_iunop : State.stack -> Types.nn -> Types.iunop -> State.stack
val exec_iunop : State.stack -> Types.nn -> Types.iunop -> State.stack choice

val exec_funop : State.stack -> Types.nn -> Types.funop -> State.stack

Expand Down
4 changes: 2 additions & 2 deletions src/symbolic_value.ml
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ module I32 = struct

let zero = const_i32 0l

let clz e = unop ty Clz e
let clz = None

let ctz _ =
(* TODO *)
Expand Down Expand Up @@ -281,7 +281,7 @@ module I64 = struct

let zero = const_i64 0L

let clz e = unop ty Clz e
let clz = None

let ctz _ =
(* TODO *)
Expand Down
2 changes: 1 addition & 1 deletion src/value_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module type Iop = sig

val zero : num

val clz : num -> num
val clz : (num -> num) option

val ctz : num -> num

Expand Down
51 changes: 51 additions & 0 deletions test/sym/clz_32.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
(module
(import "symbolic" "i32_symbol" (func $i32_symbol (result i32)))
(import "symbolic" "assume" (func $assume (param i32)))
(import "symbolic" "assert" (func $assert (param i32)))

(func $countLeadingZeros (param i32) (result i32)
(local $x i32)
(local $res i32)


;; Initialize local variables
(local.set $res (i32.const 32)) ;; Initialize with the highest possible index of a bit
(local.set $x (local.get 0)) ;; Store the input

;; Loop to find the leading zeros
(block $outter
(loop $loop

;; Check if all bits are shifted out
(if (i32.eqz (local.get $x))
(then (br $outter))
)

;; Shift the input to the right by 1 bit
(local.set $x (i32.shr_u (local.get $x) (i32.const 1)))

;; Decrement the count of zero bits
(local.set $res (i32.sub (local.get $res) (i32.const 1)))

(br $loop)
)
)

;; Return the number of leading zeros
(return (local.get $res))
)

(func $start

(local $n i32)
(local.set $n (call $i32_symbol))

(call $assert (i32.eq
(call $countLeadingZeros (local.get $n))
(i32.clz (local.get $n))
))
)


(start $start)
)
51 changes: 51 additions & 0 deletions test/sym/clz_64.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
(module
(import "symbolic" "i64_symbol" (func $i64_symbol (result i64)))
(import "symbolic" "assume" (func $assume (param i32)))
(import "symbolic" "assert" (func $assert (param i32)))

(func $countLeadingZeros (param i64) (result i64)
(local $x i64)
(local $res i64)


;; Initialize local variables
(local.set $res (i64.const 64)) ;; Initialize with the highest possible index of a bit
(local.set $x (local.get 0)) ;; Store the input

;; Loop to find the leading zeros
(block $outter
(loop $loop

;; Check if all bits are shifted out
(if (i64.eqz (local.get $x))
(then (br $outter))
)

;; Shift the input to the right by 1 bit
(local.set $x (i64.shr_u (local.get $x) (i64.const 1)))

;; Decrement the count of zero bits
(local.set $res (i64.sub (local.get $res) (i64.const 1)))

(br $loop)
)
)

;; Return the number of leading zeros
(return (local.get $res))
)

(func $start

(local $n i64)
(local.set $n (call $i64_symbol))

(call $assert (i64.eq
(call $countLeadingZeros (local.get $n))
(i64.clz (local.get $n))
))
)


(start $start)
)

0 comments on commit c4040c1

Please sign in to comment.