1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
(** Copyright 2024, Mikhail Gavrilenko, Danila Rudnev-Stepanyan, Daniel Vlasenko*)
(** SPDX-License-Identifier: LGPL-3.0-or-later *)
open Common.Ast
open Common.Ast.Expression
open Common.Ast.Pattern
open Common.Ast.Structure
module SSet = Set.Make (String)
module SMap = Map.Make (String)
(* ---------- error monad ---------- *)

(** Errors that closure conversion can report. *)
type cc_error = Empty_toplevel_let

(** Human-readable description of a [cc_error]. *)
let string_of_cc_error (err : cc_error) =
  match err with
  | Empty_toplevel_let -> "Cannot have empty let-binding at top level"
;;

(** Monadic bind for [result]: continue with [f] on [Ok], propagate [Error]
    unchanged. *)
let ( let* ) result f =
  match result with
  | Ok value -> f value
  | Error err -> Error err
;;
(** Built-in names (runtime functions and operators) that seed the top-level
    scope; closure conversion never treats them as captured variables. *)
let std_lib_names =
  let runtime_functions = [ "print_int"; "alloc_block"; "alloc_closure"; "apply1" ] in
  let arithmetic = [ "+"; "-"; "*"; "/" ] in
  let comparisons = [ "="; "<>"; "<"; ">"; "<="; ">=" ] in
  let boolean = [ "&&"; "||" ] in
  List.concat [ runtime_functions; arithmetic; comparisons; boolean ]
;;
(* [pattern_vars_list acc pat] prepends to [acc] every variable bound by [pat];
   variables met later in the pattern end up closer to the head. *)
let rec pattern_vars_list acc pat =
  match pat with
  | Pat_var name -> name :: acc
  | Pat_construct (_, Some sub) | Pat_constraint (sub, _) -> pattern_vars_list acc sub
  | Pat_tuple (first, second, rest) ->
    List.fold_left pattern_vars_list acc (first :: second :: rest)
  | Pat_any | Pat_constant _ | Pat_construct (_, None) -> acc
;;

(* Variables bound by [pat], in reverse order of appearance. *)
let pattern_vars pat = pattern_vars_list [] pat
(* [construct_fun patterns body] builds [fun p1 ... pn -> body]; with no
   patterns it returns [body] itself. *)
let construct_fun patterns body =
  match patterns with
  | first :: rest -> Exp_fun ((first, rest), body)
  | [] -> body
;;
(* [free_vars_in bound_vars expr] is the set of variables occurring free in
   [expr], not counting those in [bound_vars]. Binders inside [expr] (function
   parameters, case patterns, let-bound names) are scoped lexically; the
   right-hand sides of a [let rec] see the names the group binds. *)
let rec free_vars_in bound_vars expr =
  let union_over f items =
    List.fold_left (fun acc item -> SSet.union acc (f item)) SSet.empty items
  in
  let binders patterns = union_over (fun p -> SSet.of_list (pattern_vars p)) patterns in
  let under extra e = free_vars_in (SSet.union bound_vars extra) e in
  let in_case { first; second } = under (binders [ first ]) second in
  match expr with
  | Exp_ident id -> if SSet.mem id bound_vars then SSet.empty else SSet.singleton id
  | Exp_constant _ | Exp_construct (_, None) -> SSet.empty
  | Exp_construct (_, Some inner) | Exp_constraint (inner, _) ->
    free_vars_in bound_vars inner
  | Exp_apply (fn, arg) ->
    SSet.union (free_vars_in bound_vars fn) (free_vars_in bound_vars arg)
  | Exp_tuple (e1, e2, rest) -> union_over (free_vars_in bound_vars) (e1 :: e2 :: rest)
  | Exp_if (e_cond, e_then, e_else_opt) ->
    let in_else =
      Option.fold ~none:SSet.empty ~some:(free_vars_in bound_vars) e_else_opt
    in
    SSet.union
      (free_vars_in bound_vars e_cond)
      (SSet.union (free_vars_in bound_vars e_then) in_else)
  | Exp_fun ((p, ps), body) -> under (binders (p :: ps)) body
  | Exp_function (case, cases) -> union_over in_case (case :: cases)
  | Exp_match (scrutinee, (case, cases)) ->
    SSet.union (free_vars_in bound_vars scrutinee) (union_over in_case (case :: cases))
  | Exp_let (rec_flag, (vb, vbs), body) ->
    let bindings = vb :: vbs in
    let introduced = binders (List.map (fun b -> b.pat) bindings) in
    (* Only a recursive group makes its own names visible to its right-hand sides. *)
    let rhs_bound = if rec_flag = Recursive then introduced else SSet.empty in
    SSet.union
      (union_over (fun b -> under rhs_bound b.expr) bindings)
      (under introduced body)
;;
(* ---------- closure conversion ---------- *)

(** [vars_of_patterns patterns] is the set of variables bound by [patterns]. *)
let vars_of_patterns patterns =
  List.fold_left
    (fun acc p -> SSet.union acc (SSet.of_list (pattern_vars p)))
    SSet.empty
    patterns
;;

(** [shadow env names] removes [names] from [env]. A new binder for a name hides
    any converted function of that name, so uses under the binder must not be
    rewritten. *)
let shadow env names = SSet.fold SMap.remove names env

(** [apply_to_vars e vars] builds [e v1 ... vn], where [v1 ... vn] are the
    elements of [vars] in increasing order. Converted functions take their
    captured variables as leading parameters in this same order. *)
let apply_to_vars e vars =
  List.fold_left (fun acc v -> Exp_apply (acc, Exp_ident v)) e (SSet.elements vars)
;;

(** [close_over captured patterns body] is [fun c1 ... cn p1 ... pk -> body],
    where [c1 ... cn] are the elements of [captured] in increasing order. *)
let close_over captured patterns body =
  construct_fun (List.map (fun v -> Pat_var v) (SSet.elements captured) @ patterns) body
;;

(** [expand_captures env fvs] adds to [fvs] the captured variables of every
    converted function named in [fvs]. Each use of such a function is rewritten
    to pass its captured variables, so those variables become free in the
    enclosing code too. Entries of [env] are themselves computed with this
    expansion, so one level is enough. *)
let expand_captures env fvs =
  SSet.fold
    (fun v acc ->
      match SMap.find_opt v env with
      | Some captured -> SSet.union acc captured
      | None -> acc)
    fvs
    fvs
;;

(** Parameter name introduced when desugaring [function] into [fun]. *)
let function_arg_name = "__fun_arg"

(** Body of [fun __fun_arg -> match __fun_arg with case | cases], i.e. the
    desugaring of [function case | cases]. *)
let function_body case cases = Exp_match (Exp_ident function_arg_name, (case, cases))

(** [closure_expr toplvl_set env expr] closure-converts [expr].

    - [toplvl_set]: globally visible names (builtins and earlier top-level
      definitions); they are never captured.
    - [env]: maps each converted let-bound function in scope to the variables
      it captures. The binding takes them as extra leading parameters, so every
      use of the name is rewritten to pass them explicitly.

    An anonymous [fun] becomes [(fun c1 ... cn p1 ... -> body) c1 ... cn], so
    the lambda itself refers to no local variable of the enclosing scope.

    NOTE(review): no renaming is performed. If a captured variable is rebound
    between a function's definition and one of its uses, that use passes the
    new binding. Likewise, a local binder that reuses a name from [toplvl_set]
    is still treated as global. Inputs are presumably free of such shadowing;
    confirm against the pipeline. *)
let rec closure_expr toplvl_set env expr =
  match expr with
  | Exp_ident id ->
    (match SMap.find_opt id env with
     | Some captured when not (SSet.is_empty captured) -> apply_to_vars expr captured
     | _ -> expr)
  | Exp_fun ((p, ps), body) ->
    let patterns = p :: ps in
    let params = vars_of_patterns patterns in
    let captured =
      SSet.diff (expand_captures env (free_vars_in params body)) toplvl_set
    in
    (* Parameters shadow converted functions; the added capture parameters do
       not, since they carry the very same values as the outer variables. *)
    let new_body = closure_expr toplvl_set (shadow env params) body in
    apply_to_vars (close_over captured patterns new_body) captured
  | Exp_function (case, cases) ->
    closure_expr
      toplvl_set
      env
      (Exp_fun ((Pat_var function_arg_name, []), function_body case cases))
  | Exp_let (rec_flag, (vb, vbs), body) ->
    let new_bindings, body_env =
      transform_bindings toplvl_set env rec_flag (vb :: vbs)
    in
    let new_body = closure_expr toplvl_set body_env body in
    (match new_bindings with
     | [] -> new_body
     | hd :: tl -> Exp_let (rec_flag, (hd, tl), new_body))
  | Exp_apply (e1, e2) ->
    Exp_apply (closure_expr toplvl_set env e1, closure_expr toplvl_set env e2)
  | Exp_tuple (e1, e2, es) ->
    let f = closure_expr toplvl_set env in
    Exp_tuple (f e1, f e2, List.map f es)
  | Exp_if (e1, e2, e3_opt) ->
    let f = closure_expr toplvl_set env in
    Exp_if (f e1, f e2, Option.map f e3_opt)
  | Exp_match (e, (case, cases)) ->
    (* Variables bound by a case pattern shadow converted functions. *)
    let convert_case { first; second } =
      let case_env = shadow env (SSet.of_list (pattern_vars first)) in
      { first; second = closure_expr toplvl_set case_env second }
    in
    Exp_match
      (closure_expr toplvl_set env e, (convert_case case, List.map convert_case cases))
  | Exp_constant _ | Exp_construct (_, None) -> expr
  | Exp_construct (id, Some e) -> Exp_construct (id, Some (closure_expr toplvl_set env e))
  | Exp_constraint (e, t) -> Exp_constraint (closure_expr toplvl_set env e, t)

(** [transform_bindings toplvl_set env rec_flag bindings] converts the
    bindings of one [let] / [let rec]. It returns them together with the
    environment for the scope those bindings introduce.

    A binding [f = fun p1 ... -> body] (or [f = function ...]) becomes
    [f = fun c1 ... cn p1 ... -> body'], where [c1 ... cn] are the variables it
    captures. It is NOT applied to them here; instead every use of [f] in scope
    is rewritten to [f c1 ... cn].

    In a recursive group, the group's own names are never captured. All of its
    functions take the union of their captured variables, so sibling calls can
    pass the same arguments.

    Any other binding is converted as a plain expression, and its names hide
    outer converted functions of the same name. *)
and transform_bindings toplvl_set env rec_flag bindings =
  let is_rec = rec_flag = Recursive in
  let group_names =
    List.fold_left
      (fun acc { pat; _ } -> SSet.union acc (SSet.of_list (pattern_vars pat)))
      SSet.empty
      bindings
  in
  (* Environment seen by the right-hand sides: a recursive group rebinds its names. *)
  let rhs_scope = if is_rec then shadow env group_names else env in
  (* [Some (name, patterns, body, captured)] when the binding defines a named
     function; [captured] is that function's own capture set. *)
  let analyse { pat; expr } =
    let named =
      match pat, expr with
      | Pat_var v, Exp_fun ((p, ps), body) -> Some (v, p :: ps, body)
      | Pat_var v, Exp_function (case, cases) ->
        Some (v, [ Pat_var function_arg_name ], function_body case cases)
      | _ -> None
    in
    Option.map
      (fun (v, patterns, body) ->
        let bound = vars_of_patterns patterns in
        let bound = if is_rec then SSet.union bound group_names else bound in
        let captured =
          SSet.diff (expand_captures rhs_scope (free_vars_in bound body)) toplvl_set
        in
        v, patterns, body, captured)
      named
  in
  let analysed = List.map (fun b -> b, analyse b) bindings in
  let group_captured =
    List.fold_left
      (fun acc (_, info) ->
        match info with
        | Some (_, _, _, captured) -> SSet.union acc captured
        | None -> acc)
      SSet.empty
      analysed
  in
  let captured_of own = if is_rec then group_captured else own in
  (* Record the converted functions of this group on top of [base]. *)
  let add_functions base =
    List.fold_left
      (fun acc (_, info) ->
        match info with
        | Some (v, _, _, captured) -> SMap.add v (captured_of captured) acc
        | None -> acc)
      base
      analysed
  in
  (* Non-recursive right-hand sides only see the outer scope, even siblings'
     conversions are invisible to them. *)
  let rhs_env = if is_rec then add_functions rhs_scope else env in
  let convert ({ pat; expr }, info) =
    match info with
    | Some (_, patterns, body, captured) ->
      let body_env = shadow rhs_env (vars_of_patterns patterns) in
      let new_body = closure_expr toplvl_set body_env body in
      { pat; expr = close_over (captured_of captured) patterns new_body }
    | None -> { pat; expr = closure_expr toplvl_set rhs_env expr }
  in
  List.map convert analysed, add_functions (shadow env group_names)
;;
(** [closure_structure_item_result toplvl_set item] closure-converts one
    top-level item. It returns the converted item together with [toplvl_set]
    extended by the names the item binds.

    The names bound by a top-level [let rec] are already in scope, as globals,
    inside its own right-hand sides. They are therefore added to the top-level
    set before the bindings are converted; otherwise a self- or mutually
    recursive reference (e.g. from [let rec f = function ...]) would be treated
    as a captured local variable. *)
let closure_structure_item_result toplvl_set = function
  | Str_eval e ->
    let e' = closure_expr toplvl_set SMap.empty e in
    Ok (Str_eval e', toplvl_set)
  | Str_value (rec_flag, (vb, vbs)) ->
    let bindings = vb :: vbs in
    let bound_names = List.concat_map (fun b -> pattern_vars b.pat) bindings in
    let new_toplvl_set = SSet.union toplvl_set (SSet.of_list bound_names) in
    let rhs_toplvl_set = if rec_flag = Recursive then new_toplvl_set else toplvl_set in
    let new_bindings, _ =
      transform_bindings rhs_toplvl_set SMap.empty rec_flag bindings
    in
    (match new_bindings with
     | [] -> Error Empty_toplevel_let
     | hd :: tl -> Ok (Str_value (rec_flag, (hd, tl)), new_toplvl_set))
  | Str_adt _ as item -> Ok (item, toplvl_set)
;;
(* Infallible variant of [closure_structure_item_result]. On error the item is
   returned unconverted; the names bound by a value item are still added to
   the top-level set. *)
let closure_structure_item toplvl_set item =
  match closure_structure_item_result toplvl_set item, item with
  | Ok converted, _ -> converted
  | Error _, Str_value (_, (vb, vbs)) ->
    let names = List.concat_map (fun b -> pattern_vars b.pat) (vb :: vbs) in
    item, SSet.union toplvl_set (SSet.of_list names)
  | Error _, _ -> item, toplvl_set
;;
(* Closure-converts every item of [ast] in order, threading the set of
   top-level names (seeded with [std_lib_names]). The first error aborts the
   whole conversion. *)
let cc_program_result (ast : program) : (program, cc_error) result =
  let step acc item =
    let* converted_rev, top = acc in
    let* converted, next_top = closure_structure_item_result top item in
    Ok (converted :: converted_rev, next_top)
  in
  let* converted_rev, _ =
    List.fold_left step (Ok ([], SSet.of_list std_lib_names)) ast
  in
  Ok (List.rev converted_rev)
;;
(* Closure-converts [ast]; if conversion reports an error, the original
   program is returned unchanged. *)
let cc_program (ast : program) : program =
  Result.value (cc_program_result ast) ~default:ast
;;