Skip to content

Commit

Permalink
Changes per review
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Sep 8, 2023
1 parent 26a20f8 commit ca286a2
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 140 deletions.
9 changes: 5 additions & 4 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ let truncate_dist ud_dists (id : Ast.identifier)
, Some y ) } in
let funapp meta kind name args =
Expr.{Fixed.pattern= FunApp (trans_fn_kind kind name, args); meta} in
let ensure_type tp lb : Expr.Typed.t =
let maybe_promote_to_real tp lb : Expr.Typed.t =
match (tp, Expr.Typed.type_of lb) with
| UnsizedType.UInt, _ -> lb
| _, UInt ->
Expand All @@ -153,7 +153,7 @@ let truncate_dist ud_dists (id : Ast.identifier)
let inclusive_bound tp (lb : Expr.Typed.t) =
if UnsizedType.is_int_type tp then
Expr.Helpers.binop lb Minus Expr.Helpers.one
else ensure_type tp lb in
else maybe_promote_to_real tp lb in
let size_adjust e =
if
(not (UnsizedType.is_container ast_obs.Ast.emeta.type_))
Expand Down Expand Up @@ -185,13 +185,14 @@ let truncate_dist ud_dists (id : Ast.identifier)
(targetme ub.meta.loc
(size_adjust
(funapp ub.meta fk fn
(ensure_type tp ub :: trans_exprs ast_args) ) ) ) ]
(maybe_promote_to_real tp ub :: trans_exprs ast_args) ) ) )
]
| TruncateBetween (lb, ub) ->
let fk, fn, tp = find_function_info cdf_suffices in
let lb, ub = (trans_expr lb, trans_expr ub) in
let expr args =
funapp ub.meta (Ast.StanLib FnPlain) "log_diff_exp"
[ funapp ub.meta fk fn (ensure_type tp ub :: args)
[ funapp ub.meta fk fn (maybe_promote_to_real tp ub :: args)
; funapp ub.meta fk fn (inclusive_bound tp lb :: args) ] in
let statement =
match
Expand Down
32 changes: 15 additions & 17 deletions src/stan_math_backend/Lower_functions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,17 @@ let extra_suffix_args fdsuffix =
| FnRng -> (["base_rng__"], ["RNG"])
| FnLpdf _ | FnPlain -> ([], [])

let signature_comment Program.{fdrt; fdname; fdargs; _} =
GlobalComment
Fmt.(
str "@[<1>@[<v>%a@]@ %s(@[<hov>%a@])@]" UnsizedType.pp_returntype fdrt
fdname
(list ~sep:comma (box UnsizedType.pp_fun_arg))
(List.map ~f:(fun (ad, _id, ty) -> (ad, ty)) fdargs))

let lower_fun_def (functors : Lower_expr.variadic list)
Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} :
defn * fun_defn * struct_defn list =
let comment =
GlobalComment
Fmt.(
str "%a %s(@[%a@])" UnsizedType.pp_returntype fdrt fdname
(list ~sep:comma UnsizedType.pp_fun_arg)
(List.map ~f:(fun (ad, _id, ty) -> (ad, ty)) fdargs)) in
fun_defn * struct_defn list =
let extra_arg_names, extra_template_names = extra_suffix_args fdsuffix in
let template_parameter_and_arg_names is_possibly_eigen_expr variadic_fun_type
=
Expand Down Expand Up @@ -329,7 +331,7 @@ let lower_fun_def (functors : Lower_expr.variadic list)
~args:cpp_args ~cv_qualifiers:[Const] ~body:defn_body () in
make_struct_defn ~param:struct_template ~name:functor_name
~body:[FunDef functor_decl] () in
(comment, fd, functors |> List.map ~f:register_functor)
(fd, functors |> List.map ~f:register_functor)

let get_functor_requirements (p : Program.Numbered.t) =
let open Expr.Fixed in
Expand Down Expand Up @@ -365,21 +367,21 @@ let collect_functors_functions (p : Program.Numbered.t) : defn list =
|> List.stable_dedup
|> List.filter_map ~f:(fun (hof, types) ->
if matching_argtypes d types then Some hof else None ) in
let comment, fn, st = lower_fun_def functors d in
let fn, st = lower_fun_def functors d in
List.iter st ~f:(fun s ->
(* Side effecting, collates functor structs *)
Hashtbl.update structs s.struct_name ~f:(function
| Some x -> {x with body= x.body @ s.body}
| None -> s ) ) ;
(comment, fn) in
fn in
let fun_decls, fun_defns =
p.functions_block
|> List.filter_map ~f:(fun d ->
let comment, fn = register_functors d in
let fn = register_functors d in
if Option.is_none d.fdbody then None
else
let decl, defn = Cpp.split_fun_decl_defn fn in
Some (FunDef decl, [comment; FunDef defn]) )
Some (FunDef decl, [signature_comment d; FunDef defn]) )
|> List.unzip in
let structs = Hashtbl.data structs |> List.map ~f:(fun s -> Struct s) in
fun_decls @ structs @ List.concat fun_defns
Expand Down Expand Up @@ -426,9 +428,7 @@ module Testing = struct
open Fmt

let pp_fun_def_test ppf a =
let comment, defn, st = lower_fun_def [FixedArgs] a in
Cpp.Printing.pp_defn ppf comment ;
cut ppf () ;
let defn, st = lower_fun_def [FixedArgs] a in
Cpp.Printing.pp_fun_defn ppf defn ;
cut ppf () ;
(list ~sep:cut Cpp.Printing.pp_struct_defn) ppf st
Expand All @@ -455,7 +455,6 @@ module Testing = struct
|> print_endline ;
[%expect
{|
// void sars(data matrix, row_vector)
template <typename T0__, typename T1__,
stan::require_all_t<stan::is_eigen_matrix_dynamic<T0__>,
stan::is_vt_not_complex<T0__>,
Expand Down Expand Up @@ -518,7 +517,6 @@ module Testing = struct
|> print_endline ;
[%expect
{|
// matrix sars(data matrix, row_vector, row_vector, array[] matrix)
template <typename T0__, typename T1__, typename T2__, typename T3__,
stan::require_all_t<stan::is_eigen_matrix_dynamic<T0__>,
stan::is_vt_not_complex<T0__>,
Expand Down
116 changes: 63 additions & 53 deletions test/integration/good/code-gen/cpp.expected
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,8 @@ f(const T0__& x,
std::vector<
std::tuple<std::vector<std::tuple<T1__0__0__, T1__0__1__>>,
T1__1__>>>& x2, std::ostream* pstream__);
/* tuple(int, array[] tuple(int, real)) f(int,
array[,] tuple(array[] tuple(int,
int),
array[,] int))
/* tuple(int, array[] tuple(int, real))
f(int, array[,] tuple(array[] tuple(int, int), array[,] int))
*/
template <typename T0__, typename T1__0__0__, typename T1__0__1__,
typename T1__1__,
Expand Down Expand Up @@ -7156,8 +7154,8 @@ foo(const T0__& n, std::ostream* pstream__) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[] real sho(real, array[] real, array[] real, data array[] real,
data array[] int)
/* array[] real
sho(real, array[] real, array[] real, data array[] real, data array[] int)
*/
template <typename T0__, typename T1__, typename T2__,
stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
Expand Down Expand Up @@ -7921,9 +7919,10 @@ covsqrt2corsqrt(const T0__& mat_arg__, const T1__& invert, std::ostream*
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* void f0(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
/* void
f0(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8006,9 +8005,10 @@ f0(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* int f1(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
/* int
f1(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8088,9 +8088,10 @@ f1(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[] int f2(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[] int
f2(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8170,9 +8171,10 @@ f2(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[,] int f3(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[,] int
f3(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8252,9 +8254,10 @@ f3(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* real f4(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
/* real
f4(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8337,9 +8340,10 @@ f4(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[] real f5(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[] real
f5(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8424,9 +8428,10 @@ f5(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[,] real f6(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[,] real
f6(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8512,10 +8517,10 @@ f6(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* vector f7(int, array[] int, array[,] int, real, array[] real, array[,]
real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
/* vector
f7(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8600,9 +8605,10 @@ f7(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[] vector f8(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[] vector
f8(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8688,9 +8694,10 @@ f8(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[,] vector f9(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[,] vector
f9(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8777,10 +8784,10 @@ f9(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* matrix f10(int, array[] int, array[,] int, real, array[] real, array[,]
real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
/* matrix
f10(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8865,9 +8872,10 @@ f10(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[] matrix f11(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[] matrix
f11(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -8953,9 +8961,10 @@ f11(const T0__& a1, const T1__& a2, const T2__& a3, const T3__& a4,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* array[,] matrix f12(int, array[] int, array[,] int, real, array[] real,
array[,] real, vector, array[] vector, array[,] vector,
matrix, array[] matrix, array[,] matrix)
/* array[,] matrix
f12(int, array[] int, array[,] int, real, array[] real, array[,] real,
vector, array[] vector, array[,] vector, matrix, array[] matrix,
array[,] matrix)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down Expand Up @@ -24471,8 +24480,8 @@ struct dz_dt_functor__ {
return dz_dt(t, z, theta, x_r, x_i, pstream__);
}
};
/* array[] real dz_dt(real, array[] real, array[] real, array[] real,
array[] int)
/* array[] real
dz_dt(real, array[] real, array[] real, array[] real, array[] int)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__,
Expand Down Expand Up @@ -36259,10 +36268,11 @@ g12(const T0__& y_slice, const T1__& start, const T2__& end, const T3__& a,
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
/* real s(array[] real, int, int, int, real, vector, row_vector, matrix,
array[] int, array[] real, array[] vector, array[] row_vector,
array[] matrix, array[,] int, array[,] real, array[,] vector,
array[,] row_vector, array[,] matrix, array[,,] int, array[,,] real)
/* real
s(array[] real, int, int, int, real, vector, row_vector, matrix,
array[] int, array[] real, array[] vector, array[] row_vector,
array[] matrix, array[,] int, array[,] real, array[,] vector,
array[,] row_vector, array[,] matrix, array[,,] int, array[,,] real)
*/
template <typename T0__, typename T1__, typename T2__, typename T3__,
typename T4__, typename T5__, typename T6__, typename T7__,
Expand Down
Loading

0 comments on commit ca286a2

Please sign in to comment.