-
-
Notifications
You must be signed in to change notification settings - Fork 50
Expand file tree
/
Copy pathProgram.ml
More file actions
158 lines (132 loc) · 5.32 KB
/
Program.ml
File metadata and controls
158 lines (132 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
(** Defines the core of the MIR *)
open Core
(** A list of function argument declarations. Each entry is a triple
    (autodiff type, argument name, unsized type) — the same layout that
    [pp_fun_arg_decl] destructures.
    NOTE(review): [fun_def.fdargs] spells out this same list type instead of
    using this alias (it would need [compare] and [fold] derivers added here
    to do so). *)
type fun_arg_decl = (UnsizedType.autodifftype * string * UnsizedType.t) list
[@@deriving sexp, hash, map]
(** A function definition (or external declaration) in the MIR,
    parameterised over the type ['a] of the function body.
    - [fdrt]: the function's return type.
    - [fdname]: the function's name.
    - [fdsuffix]: the function-kind suffix ([unit Fun_kind.suffix]).
    - [fdargs]: argument declarations as (autodiff type, name, unsized type).
    - [fdbody]: the body, or [None] for an external declaration.
    - [fdloc]: source location; ignored by the derived [compare] and
      rendered opaquely by the derived sexp converter. *)
type 'a fun_def =
{ fdrt: UnsizedType.returntype
; fdname: string
; fdsuffix: unit Fun_kind.suffix
; fdargs: (UnsizedType.autodifftype * string * UnsizedType.t) list
; fdbody: 'a option
(* If fdbody is None, this is an external function declaration (forward
decls are removed during AST lowering) *)
; fdloc: (Location_span.t[@sexp.opaque] [@compare.ignore]) }
[@@deriving compare, hash, sexp, map, fold]
(** The program block an output variable belongs to (stored in
    [outvar.out_block] and printed by [pp_io_block]). *)
type io_block = Parameters | TransformedParameters | GeneratedQuantities
[@@deriving sexp, hash]
(** Description of one output variable, parameterised over the expression
    type ['e] used inside its sized types.
    - [out_unconstrained_st]: sized type on the unconstrained scale.
    - [out_constrained_st]: sized type on the constrained scale (the one
      [pp_output_var] prints as the declaration).
    - [out_block]: the block the variable is declared in.
    - [out_trans]: the variable's [Transformation.t].
    NOTE(review): presumably the two sized types differ only for variables
    with a non-trivial transformation — confirm. *)
type 'e outvar =
{ out_unconstrained_st: 'e SizedType.t
; out_constrained_st: 'e SizedType.t
; out_block: io_block
; out_trans: 'e Transformation.t }
[@@deriving sexp, map, hash, fold]
(** A whole MIR program.
    - ['a]: expression type used inside sized types (e.g. [Expr.Typed.t]).
    - ['b]: statement type of every statement list and function body.
    - ['m]: metadata stored with each input and output variable
      ([Location_span.t] in [Typed], [int] in [Numbered]). *)
type ('a, 'b, 'm) t =
{ functions_block: 'b fun_def list
(* entries of input_vars: (variable name, metadata, sized type) *)
; input_vars: (string * 'm * 'a SizedType.t) list
; prepare_data: 'b list (* data & transformed data decls and statements *)
; log_prob: 'b list (* assumes data & params are in scope and ready *)
; reverse_mode_log_prob: 'b list
(* assumes data & params ready & in scope. A copy of log_prob but with
optimizations which are specific to reverse-mode autodiff. This is
used in the C++ backend, but can be ignored if not needed. It is
initialized to [[]] in [Ast_to_Mir]; set it equal to log_prob before
calling into the optimization suite if desired. *)
; generate_quantities: 'b list
(* assumes data & params ready & in scope *)
(* NOTE: the following two items are really backend-specific,
and are set to [] by [Ast_to_Mir] before being populated in
[Stan_math_backend.Transform_Mir].
It would be nice to abstract this out somehow.
*)
; transform_inits: 'b list
; unconstrain_array: 'b list
(* entries of output_vars: (variable name, metadata, output description) *)
; output_vars: (string * 'm * 'a outvar) list
(* prog_name and prog_path are not printed by [pp] *)
; prog_name: string
; prog_path: string }
[@@deriving sexp, map, fold]
(* -- Pretty printers -- *)
(** [pp_fun_arg_decl ppf arg] prints one argument declaration as the autodiff
    annotation immediately followed by the type, a space, then the name. *)
let pp_fun_arg_decl ppf arg =
  let ad_level, arg_name, arg_ty = arg in
  Fmt.pf ppf "%a%a %s" UnsizedType.pp_autodifftype ad_level UnsizedType.pp
    arg_ty arg_name
(** [pp_fun_def pp_s ppf fd] prints the function [fd], using [pp_s] for its
    body. A definition without a body is printed as an [extern] declaration
    terminated by [;]. The suffix and location fields are not printed. *)
let pp_fun_def pp_s ppf {fdrt; fdname; fdargs; fdbody; _} =
  (* parenthesised, comma-separated argument list *)
  let pp_args = Fmt.(parens (list ~sep:comma pp_fun_arg_decl)) in
  match fdbody with
  | Some body ->
      Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" UnsizedType.pp_returntype fdrt
        fdname pp_args fdargs pp_s body
  | None ->
      Fmt.pf ppf "@[<v2>extern %a %s%a;@]" UnsizedType.pp_returntype fdrt
        fdname pp_args fdargs
(** [pp_io_block ppf blk] prints the snake_case name of the block [blk]. *)
let pp_io_block ppf blk =
  let label =
    match blk with
    | Parameters -> "parameters"
    | TransformedParameters -> "transformed_parameters"
    | GeneratedQuantities -> "generated_quantities"
  in
  Fmt.string ppf label
(** [pp_block label pp_elem ppf elems] prints [label { ... }] with one element
    per line inside an indented vertical box, followed by a forced newline.
    Prints nothing at all when [elems] is empty. *)
let pp_block label pp_elem ppf elems =
  if List.is_empty elems then ()
  else
    Fmt.pf ppf "@[<v2>%s {@ %a@]@ }@\n" label
      Fmt.(list ~sep:cut pp_elem)
      elems
(* Section printers: each prints one statement list of a program under a fixed
   label by delegating to [pp_block]. *)
let pp_functions_block pp_s = pp_block "functions" pp_s

let pp_prepare_data pp_s = pp_block "prepare_data" pp_s

let pp_log_prob pp_s = pp_block "log_prob" pp_s

let pp_reverse_mode_log_prob pp_s = pp_block "rev_log_prob" pp_s

let pp_generate_quantities pp_s = pp_block "generate_quantities" pp_s

let pp_transform_inits pp_s = pp_block "transform_inits" pp_s
(** [pp_output_var pp_e ppf (name, _, ov)] prints an output variable as its
    block, constrained sized type and name, followed by a [//] comment giving
    its unconstrained sized type. [pp_e] prints expressions inside the sized
    types; the metadata component is ignored. *)
let pp_output_var pp_e ppf (name, _, ov) =
  let pp_sized = SizedType.pp pp_e in
  Fmt.pf ppf "@[<hov 2>%a %a %s;@ //%a@]" pp_io_block ov.out_block pp_sized
    ov.out_constrained_st name pp_sized ov.out_unconstrained_st
(** [pp_input_var pp_e ppf (name, _, st)] prints an input variable as
    [<sized type> <name>;] in a horizontal box; metadata is ignored. *)
let pp_input_var pp_e ppf input_var =
  let var_name, _, st = input_var in
  let pp_sized = SizedType.pp pp_e in
  Fmt.pf ppf "@[<h>%a %s;@]" pp_sized st var_name
(* Print the input_vars / output_vars sections of a program via [pp_block],
   using [pp_e] for expressions inside sized types. *)
let pp_input_vars pp_e = pp_block "input_vars" (pp_input_var pp_e)

let pp_output_vars pp_e = pp_block "output_vars" (pp_output_var pp_e)
(** [pp pp_e pp_s ppf prog] prints the sections of [prog] to [ppf] inside a
    single vertical box, with a cut between consecutive sections. [pp_e]
    prints expressions (inside sized types) and [pp_s] prints statements.
    A section with no elements produces no text of its own, but the
    separating cuts are still emitted. [unconstrain_array], [prog_name] and
    [prog_path] are not printed. *)
let pp pp_e pp_s ppf
    { functions_block
    ; input_vars
    ; prepare_data
    ; log_prob
    ; reverse_mode_log_prob
    ; generate_quantities
    ; transform_inits
    ; output_vars
    ; _ } =
  (* Open the box on [ppf] itself. [Format.open_vbox]/[Format.close_box]
     act on [Format.std_formatter], so printing to any other formatter
     (e.g. through [Fmt.str]) used to get no vertical box, and the cuts
     below were not guaranteed to break. It also left a stray box
     opened/closed on stdout. *)
  Format.pp_open_vbox ppf 0;
  pp_functions_block (pp_fun_def pp_s) ppf functions_block;
  Fmt.cut ppf ();
  pp_input_vars pp_e ppf input_vars;
  Fmt.cut ppf ();
  pp_prepare_data pp_s ppf prepare_data;
  Fmt.cut ppf ();
  pp_log_prob pp_s ppf log_prob;
  Fmt.cut ppf ();
  pp_reverse_mode_log_prob pp_s ppf reverse_mode_log_prob;
  Fmt.cut ppf ();
  pp_generate_quantities pp_s ppf generate_quantities;
  Fmt.cut ppf ();
  pp_transform_inits pp_s ppf transform_inits;
  Fmt.cut ppf ();
  pp_output_vars pp_e ppf output_vars;
  Format.pp_close_box ppf ()
(** Programs with typed expressions and locations *)
module Typed = struct
  type nonrec t = (Expr.Typed.t, Stmt.Located.t, Location_span.t) t

  (** Print a program using [Expr.Typed.pp] for expressions and
      [Stmt.Located.pp] for statements. *)
  let pp ppf prog = pp Expr.Typed.pp Stmt.Located.pp ppf prog

  (** S-expression conversion; the per-variable location metadata is
      rendered opaquely. *)
  let sexp_of_t : t -> Sexp.t =
    let sexp_of_location = Sexplib.Conv.sexp_of_opaque in
    sexp_of_t Expr.Typed.sexp_of_t Stmt.Located.sexp_of_t sexp_of_location
end
(** Programs with typed expressions, numbered statements and integer
    per-variable metadata. *)
module Numbered = struct
  type nonrec t = (Expr.Typed.t, Stmt.Numbered.t, int) t

  (** Print a program using [Expr.Typed.pp] for expressions and
      [Stmt.Numbered.pp] for statements. *)
  let pp ppf prog = pp Expr.Typed.pp Stmt.Numbered.pp ppf prog
end