1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
(** Copyright 2025, Ksenia Kotelnikova <xeniia.ka@gmail.com>, Sofya Kozyreva <k81sofia@gmail.com>, Vyacheslav Kochergin <vyacheslav.kochergin1@gmail.com> *)
(** SPDX-License-Identifier: LGPL-3.0-or-later *)
open Ast
open Format
(* Immediate (atomic) operands: values that need no further evaluation
   and can appear directly as arguments of a complex expression. *)
type immexpr =
  | ImmNum of int (* integer literal, e.g. 42 *)
  | ImmId of ident (* variable reference, e.g. a *)
  | ITuple of immexpr * immexpr * immexpr list
  (* tuple of immediates: first component, second component, then the rest *)
[@@deriving show { with_path = false }]
(* Binary operators that can appear in a [CBinop]; [binop_map] translates
   the supported source operators into these. *)
type cbinop =
  | CPlus (* 42 + a *)
  | CMinus (* 42 - a *)
  | CMul (* 42 * a *)
  | CDiv (* 42 / a *)
  | CEq (* 42 = a *)
  | CNeq (* 42 != a *)
  | CLt (* 42 < a *)
  | CLte (* 42 <= a *)
  | CGt (* 42 > a *)
  | CGte (* 42 >= a *)
[@@deriving show { with_path = false }]
(* Complex expressions: a single computation step whose operands are
   immediates. [CIte] carries full ANF branches and [CLam] a full ANF body. *)
type cexpr =
  | CBinop of cbinop * immexpr * immexpr (* left op right, e.g. 42 + a *)
  | CIte of cexpr * aexpr * aexpr option (* if (42 > a) then 42 else a *)
  | CImmexpr of immexpr (* an immediate used as an expression *)
  | CLam of ident * aexpr (* fun a -> a + 42 *)
  | CApp of immexpr * immexpr list (* func_name arg1 arg2 ... argn *)
  | CField of immexpr * int
  (* projection of the component at the given 0-based index of a tuple *)
[@@deriving show { with_path = false }]

(* ANF expressions: a chain of let-bindings ending in a complex expression. *)
and aexpr =
  | ALet of ident * cexpr * aexpr (* let id = cexpr in aexpr *)
  | ACExpr of cexpr (* final complex expression of the chain *)
[@@deriving show { with_path = false }]
(* One top-level program item after ANF conversion. *)
type aconstruction =
  | AExpr of aexpr (* a top-level expression *)
  | AStatement of is_recursive * (ident * aexpr) list
  (* a top-level let with its recursion flag and its bindings *)
[@@deriving show { with_path = false }]

(* A whole program in ANF, in source order. *)
type aconstructions = aconstruction list [@@deriving show { with_path = false }]
(* Failures reported by the ANF conversion. *)
type anf_error =
  | Unreachable (* an internal invariant of the conversion was violated *)
  | Not_Yet_Implemented of string (* the payload names the unsupported construct *)
[@@deriving show { with_path = false }]
(* Human-readable printer for [anf_error].

   This deliberately shadows the [pp_anf_error] generated by the deriving
   attribute on the type above, so later code that prints errors with %a
   gets these messages. The generated [show_anf_error] is not affected: it was
   produced together with the derived printer and keeps using it.

   A deriving attribute has no effect on a value binding, so none is attached
   here. *)
let pp_anf_error fmt = function
  | Unreachable -> fprintf fmt "Panic: reached unreachable state in ANF computation"
  | Not_Yet_Implemented str ->
    fprintf fmt "ANF for this structure is not yet implemented: %s" str
;;
open ResultCounter.ResultCounterMonad
open Syntax
(* Return a fresh identifier of the form base_N, where N is the counter value
   read from the monad. The counter is bumped first, so the next call gets a
   different suffix. *)
let gen_temp base =
  let* counter = read in
  let* () = write (counter + 1) in
  let fresh_name = Stdlib.Format.sprintf "%s_%d" base counter in
  return (Ident fresh_name)
;;
(* Translate a source binary operator into the base name used for its
   temporary result together with the matching [cbinop]. Operators with no
   [cbinop] counterpart fail with [Not_Yet_Implemented]. *)
let binop_map op =
  let translated =
    match op with
    | Binary_add -> Some ("res_of_plus", CPlus)
    | Binary_subtract -> Some ("res_of_minus", CMinus)
    | Binary_multiply -> Some ("res_of_mul", CMul)
    | Binary_divide -> Some ("res_of_div", CDiv)
    | Binary_equal -> Some ("eq", CEq)
    | Binary_unequal -> Some ("neq", CNeq)
    | Binary_less -> Some ("lt", CLt)
    | Binary_less_or_equal -> Some ("lte", CLte)
    | Binary_greater -> Some ("gt", CGt)
    | Binary_greater_or_equal -> Some ("gte", CGte)
    | _ -> None
  in
  match translated with
  | Some name_and_op -> return name_and_op
  | None -> fail (Not_Yet_Implemented "binary operator")
;;
(* Flatten a left-nested application Apply (Apply (f, a1), a2) ... into its
   head and its arguments in source order, i.e. (f, [a1; a2; ...]).
   A non-application expression is returned with no arguments.

   Arguments are consed onto an accumulator while walking down the spine.
   This keeps the walk linear and tail-recursive, instead of appending with
   [@] at every level, which is quadratic in the number of arguments. *)
let collect_app_args e =
  let rec walk args = function
    | Apply (f, a) -> walk (a :: args) f
    | head -> head, args
  in
  walk [] e
;;
(* Sets of identifiers; used for the [functions] field of [state], i.e. the
   names known to denote functions. *)
module FuncSet = Set.Make (struct
    type t = ident

    let compare = compare
  end)
(* Information accumulated while converting a program to ANF. *)
type state =
  { lifted_lams : (ident * cexpr) list
    (* anonymous lambdas, each recorded under a fresh lifted_lam_N name *)
  ; lifted_letins : (ident * cexpr) list
    (* local functions defined by a let-in with parameters, under their name *)
  ; functions : FuncSet.t
    (* names known to be functions: top-level functions with parameters,
       local let-in functions and lifted lambdas *)
  }
(* Initial conversion state: nothing lifted yet and no known functions. *)
let empty_state = { functions = FuncSet.empty; lifted_letins = []; lifted_lams = [] }
(* Bind the component at index [num] of the tuple named [parent_tuple]
   according to the pattern [field], wrapping [body] in the resulting lets.
   A variable pattern gets a projection let. A nested tuple pattern is bound
   to a fresh res_of_tuple_FIELD_N name whose components are destructured in
   turn. A wildcard binds nothing. Other patterns are not supported. *)
let rec anf_field parent_tuple field body state num =
  let projection = CField (ImmId parent_tuple, num) in
  match field with
  | Wild -> return (body, state)
  | PVar id -> return (ALet (id, projection, body), state)
  | PTuple (first, second, others) ->
    let* nested = gen_temp "res_of_tuple_FIELD" in
    let* inner, state' =
      anf_tuple nested (first :: second :: others) body state 0
    in
    return (ALet (nested, projection, inner), state')
  | _ -> fail (Not_Yet_Implemented "pattern expr")

(* Destructure the patterns [tuple_rest] against consecutive components of
   [parent_tuple], the first pattern matching the component at index [num]. *)
and anf_tuple parent_tuple tuple_rest body state num =
  match tuple_rest with
  | [] -> return (body, state)
  | pat :: pats ->
    let* wrapped, state' = anf_field parent_tuple pat body state num in
    anf_tuple parent_tuple pats wrapped state' (num + 1)
;;
(* Convert the expression [e] to ANF in continuation-passing style.

   [expr_with_hole] is the continuation. It receives the immediate holding
   the value of [e] and returns the rest of the ANF term together with a
   state. Each case returns the (aexpr, state) pair produced by the
   continuation, extended with the lets it introduces. Lambdas and local
   functions with parameters are not emitted inline: they are recorded in
   the state (lifted_lams / lifted_letins).

   Fresh names come from [gen_temp], so the order of the monadic steps below
   fixes the numeric suffixes of the generated names.

   Constructs not matched below fail with [Not_Yet_Implemented]. *)
let rec anf (state : state) e expr_with_hole =
  (* ANF for [left op right]. The result name opname_N is generated before
     either operand is converted. The lets of [left] end up outside those of
     [right], followed by the operator binding and then the continuation. *)
  let anf_binop opname op left right expr_with_hole =
    let* varname = gen_temp opname in
    let* left_anf, state1 =
      anf state left (fun limm ->
        let* right_anf, state2 =
          anf state right (fun rimm ->
            let* inner, state3 = expr_with_hole (ImmId varname) in
            return (ALet (varname, CBinop (op, limm, rimm), inner), state3))
        in
        return (right_anf, state2))
    in
    return (left_anf, state1)
  in
  match e with
  | Const (Int_lt n) -> expr_with_hole (ImmNum n)
  | Variable id -> expr_with_hole (ImmId id)
  | Bin_expr (op, l, r) ->
    let* opname, op_name = binop_map op in
    anf_binop opname op_name l r expr_with_hole
  (* let id = expr in body. The body is converted before the bound
     expression; the bound value is then wrapped as let id = value in body. *)
  | LetIn (_, Let_bind (PVar id, [], expr), [], body) ->
    let* body_anf, state1 = anf state body expr_with_hole in
    anf state1 expr (fun immval -> return (ALet (id, CImmexpr immval, body_anf), state1))
  (* let () = expr in body and let _ = expr in body: expr is converted and its
     immediate is discarded before converting body. *)
  | LetIn (_, Let_bind (PConst Unit_lt, [], expr), [], body) ->
    anf state expr (fun _ -> anf state body expr_with_hole)
  | LetIn (_, Let_bind (Wild, [], expr), [], body) ->
    anf state expr (fun _ -> anf state body expr_with_hole)
  (* let (p1, p2, ...) = expr in body: the value is bound to a fresh
     res_of_tuple_OUTER_N name, and [anf_tuple] destructures it around the
     converted body. *)
  | LetIn (_, Let_bind (PTuple (fst, snd, rest), [], expr), [], body) ->
    let* tuple_varname = gen_temp "res_of_tuple_OUTER" in
    let rest = fst :: snd :: rest in
    let* body_anf, state = anf state body expr_with_hole in
    let* body, state = anf_tuple tuple_varname rest body_anf state 0 in
    anf state expr (fun immval ->
      return (ALet (tuple_varname, CImmexpr immval, body), state))
  (* let f p1 ... pn = expr in body, with plain-variable parameters only.
     The function is not bound in place: its nested CLam is recorded in
     lifted_letins and f is added to [functions], then body is converted.
     NOTE(review): the recursion flag (first LetIn field) is discarded here. *)
  | LetIn (_, Let_bind (PVar id, args, expr), [], body) ->
    let* arg_names =
      List.fold_right
        (fun pat acc ->
          let* names = acc in
          match pat with
          | PVar s -> return (s :: names)
          | _ -> fail (Not_Yet_Implemented "complex patterns"))
        args
        (return [])
    in
    let* value, state1 =
      anf state expr (fun imm -> return (ACExpr (CImmexpr imm), state))
    in
    let clams =
      List.fold_right (fun id body -> ACExpr (CLam (id, body))) arg_names value
    in
    (* With at least one parameter the fold above always yields an ACExpr;
       the empty-parameter case was handled by an earlier arm. *)
    let* cclams =
      match clams with
      | ACExpr c -> return c
      | _ -> fail Unreachable
    in
    let* body, state2 = anf state1 body expr_with_hole in
    let state3 =
      if List.mem (id, cclams) state2.lifted_letins
      then state2
      else
        { state2 with
          lifted_letins = state2.lifted_letins @ [ id, cclams ]
        ; functions = FuncSet.add id state2.functions
        }
    in
    return (body, state3)
  (* The continuation is converted separately into each branch, so its code
     is duplicated in both. An if without else is not matched here. *)
  | If_then_else (cond, thn, Some els) ->
    let* thn, state1 = anf state thn expr_with_hole in
    let* els, state2 = anf state1 els expr_with_hole in
    anf state2 cond (fun condimm ->
      return (ACExpr (CIte (CImmexpr condimm, thn, Some els)), state2))
  (* Flatten the application, convert the head and then the arguments left to
     right, and bind the call result to a fresh res_of_app_N name. *)
  | Apply (f, args) ->
    let f, arg_exprs = collect_app_args (Apply (f, args)) in
    anf state f (fun fimm ->
      let rec anf_args acc st = function
        | [] ->
          let* varname = gen_temp "res_of_app" in
          let* e, state1 = expr_with_hole (ImmId varname) in
          return (ALet (varname, CApp (fimm, List.rev acc), e), state1)
        (* A unit literal in last argument position is dropped, so f () is
           represented as CApp with the arguments before it (none for f ()). *)
        | [ Const Unit_lt ] ->
          let* varname = gen_temp "res_of_app" in
          let* e, state1 = expr_with_hole (ImmId varname) in
          return (ALet (varname, CApp (fimm, List.rev acc), e), state1)
        | expr :: rest -> anf st expr (fun immval -> anf_args (immval :: acc) st rest)
      in
      anf_args [] state arg_exprs)
  (* fun p1 ... pn -> body, with plain-variable parameters only. A fresh lam_N
     is passed to the continuation and bound to a fresh lifted_lam_N, whose
     nested CLam is recorded in lifted_lams. The continuation is converted
     before the lambda body, so its fresh names are generated first. *)
  | Lambda (first, rest, body) ->
    let* arg_names =
      List.fold_right
        (fun pat acc ->
          let* names = acc in
          match pat with
          | PVar s -> return (s :: names)
          | _ -> fail (Not_Yet_Implemented "complex patterns"))
        (first :: rest)
        (return [])
    in
    let* varname = gen_temp "lam" in
    let* e, state1 = expr_with_hole (ImmId varname) in
    let pair_exists lst (id, ce) =
      List.exists (fun (id', ce') -> id = id' && ce = ce') lst
    in
    (* Add the pairs of lst2 that are absent from lst1 to the front of lst1. *)
    let merge_unique_pairs lst1 lst2 =
      lst2
      |> List.fold_left
           (fun acc pair -> if pair_exists lst1 pair then acc else pair :: acc)
           lst1
    in
    (* Keep lambdas lifted in the incoming state that the continuation's
       state does not already contain. *)
    let lifted_lams = merge_unique_pairs state1.lifted_lams state.lifted_lams in
    let state2 = { state1 with lifted_lams } in
    let* body, state2 =
      anf state2 body (fun imm -> return (ACExpr (CImmexpr imm), state2))
    in
    let clams =
      List.fold_right (fun id body -> ACExpr (CLam (id, body))) arg_names body
    in
    let* cclams =
      match clams with
      | ACExpr c -> return c
      | _ -> fail Unreachable
    in
    let* lifted_name = gen_temp "lifted_lam" in
    (* lifted_name was just generated, so presumably it is never already
       present and the else branch is always taken. *)
    let state3 =
      if List.mem (lifted_name, cclams) state2.lifted_lams
      then state2
      else
        { state2 with
          lifted_lams = state2.lifted_lams @ [ lifted_name, cclams ]
        ; functions = FuncSet.add lifted_name state2.functions
        }
    in
    return (ALet (varname, CImmexpr (ImmId lifted_name), e), state3)
  (* Components are converted left to right and the tuple is bound to a fresh
     res_of_tuple_N name as an ITuple immediate. *)
  | Tuple (fst, snd, rest) ->
    anf state fst (fun fst_imm ->
      anf state snd (fun snd_imm ->
        let rec anf_list acc st = function
          | [] ->
            let* varname = gen_temp "res_of_tuple" in
            let* e, state1 = expr_with_hole (ImmId varname) in
            return
              ( ALet (varname, CImmexpr (ITuple (fst_imm, snd_imm, List.rev acc)), e)
              , state1 )
          | expr :: e_rest ->
            anf st expr (fun immval -> anf_list (immval :: acc) st e_rest)
        in
        anf_list [] state rest))
  | _ ->
    (* Stdlib.Format.printf "%a@." pp_expr e; *)
    fail (Not_Yet_Implemented "ANF expr")
;;
(* Convert one top-level construction to ANF.
   - let x = e becomes a statement binding x to the ANF of e.
   - let f p1 ... pn = e (plain-variable parameters only) becomes a statement
     binding f to nested CLam nodes over the ANF of e, and f is added to the
     set of known functions.
   - A bare expression becomes an AExpr.
   Any other shape fails with Not_Yet_Implemented. *)
let anf_construction (state : state) construction =
  (* Convert [e] with a continuation that just returns its final immediate. *)
  let anf_value e = anf state e (fun imm -> return (ACExpr (CImmexpr imm), state)) in
  (* Parameter names in order; only plain-variable patterns are accepted. *)
  let rec param_names = function
    | [] -> return []
    | PVar s :: rest ->
      let* names = param_names rest in
      return (s :: names)
    | _ :: _ -> fail (Not_Yet_Implemented "complex patterns")
  in
  match construction with
  | Statement (Let (flag, Let_bind (PVar id, [], expr), [])) ->
    let* value, state' = anf_value expr in
    return (AStatement (flag, [ id, value ]), state')
  | Statement (Let (flag, Let_bind (PVar name, args, expr), [])) ->
    let* params = param_names args in
    let* body, state' = anf_value expr in
    let lam = List.fold_right (fun p acc -> ACExpr (CLam (p, acc))) params body in
    let state'' = { state' with functions = FuncSet.add name state'.functions } in
    return (AStatement (flag, [ name, lam ]), state'')
  | Expr e ->
    let* inner, state' = anf_value e in
    return (AExpr inner, state')
  | _ -> fail (Not_Yet_Implemented "ANF construction")
;;
(*
   Split an application whose first argument is a known function.
     let name = func1 func2 args... in ...
   becomes
     let res_of_inner_N = func2 args... in
     let name = func1 res_of_inner_N in ...
   where func2 must be a member of state.functions.
*)
(* Only the spine of lets is walked: applications inside CIte branches, CLam
   bodies or a final ACExpr are left untouched. The state is returned
   unchanged; only the name counter advances.
   NOTE(review): every argument after the known function is handed to it, so
   a higher-order call such as map f xs with f in state.functions becomes
   map (f xs). Confirm this is intended. *)
let rec refine_applications (state : state) (ae : aexpr) =
  match ae with
  | ALet (id, CApp (f, args), body) ->
    (* The rest of the chain is refined first. *)
    let* body', state' = refine_applications state body in
    (match args with
     | [] -> return (ALet (id, CApp (f, []), body'), state')
     | inner_func :: inner_args ->
       (match inner_func, inner_args with
        | ImmId name, _ ->
          (match FuncSet.find_opt name state.functions with
           | Some func ->
             let* inner = gen_temp "res_of_inner" in
             let new_let =
               ALet
                 ( inner
                 , CApp (ImmId func, inner_args)
                 , ALet (id, CApp (f, [ ImmId inner ]), body') )
             in
             (* Refine the rewritten chain again so the inner application is
                split in turn when its own first argument is a known function.
                This also walks the already refined body' a second time. *)
             let* new_new_let, state'' = refine_applications state' new_let in
             return (new_new_let, state'')
           | None -> return (ALet (id, CApp (f, args), body'), state'))
        | _, _ -> return (ALet (id, CApp (f, args), body'), state')))
  | ALet (id, ce, body) ->
    let* body', state' = refine_applications state body in
    return (ALet (id, ce, body'), state')
  | _ -> return (ae, state)
;;
(* Run [refine_applications] on a top-level expression or on the body of a
   single-binding statement. Statements with any other number of bindings
   are rejected. *)
let refine_applications_aconstr (construction, state) =
  match construction with
  | AExpr ae ->
    let* refined, state' = refine_applications state ae in
    return (AExpr refined, state')
  | AStatement (flag, [ (name, ae) ]) ->
    let* refined, state' = refine_applications state ae in
    return (AStatement (flag, [ name, refined ]), state')
  | AStatement _ -> fail (Not_Yet_Implemented "ANF construction")
;;
(* Convert and refine every top-level construction in order, threading the
   state from one construction to the next. Returns the converted program
   and the final state. *)
let anf_constructions (state : state) constructions =
  let rec go st converted = function
    | [] -> return (List.rev converted, st)
    | c :: remaining ->
      let* c_anf, st = anf_construction st c in
      let* c_refined, st = refine_applications_aconstr (c_anf, st) in
      go st (c_refined :: converted) remaining
  in
  go state [] constructions
;;
(* ---------- Closure conversion & Lambda lifting ---------- *)
(* Sets of identifiers used by the free-variable analysis below; same element
   type and comparison as FuncSet. *)
module IdentSet = Set.Make (struct
    type t = ident

    let compare = compare
  end)
(* let pp_identset fmt s =
Format.fprintf fmt "{";
IdentSet.iter (fun (Ident id) -> Format.fprintf fmt "%s; " id) s;
Format.fprintf fmt "}"
;; *)
(* Look up the body of the lifted lambda named [id]; the first matching entry
   of [lams] wins. Returns None when [id] is not a lifted lambda. *)
let find_lifted (id : ident) (lams : (ident * cexpr) list) : cexpr option =
  lams
  |> List.find_opt (fun (lam_id, _) -> lam_id = id)
  |> Option.map (fun (_, body) -> body)
;;
(* ---------- collect free vars: from fun m -> k (m*n) we get {k, n} ---------- *)
(* Free variables of immediates, ANF expressions and complex expressions.

   [lams] is the list of lifted lambdas. A reference to a lifted lambda name
   contributes the free variables of that lambda rather than the name itself,
   so code that refers to a lifted lambda also captures what it captures. *)
let rec free_vars_imm lams = function
  | ImmNum _ -> IdentSet.empty
  | ImmId id ->
    (match find_lifted id lams with
     | Some cexpr -> free_vars_cexpr lams cexpr
     | None -> IdentSet.singleton id)
  | ITuple (first, second, rest) -> free_vars_imms lams (first :: second :: rest)

(* Union of the free variables of a list of immediates. *)
and free_vars_imms lams imms =
  List.fold_left
    (fun acc imm -> IdentSet.union acc (free_vars_imm lams imm))
    IdentSet.empty
    imms

and free_vars_aexpr lams (expr : aexpr) : IdentSet.t =
  match expr with
  | ALet (id, cexpr, body) ->
    (* [id] is bound in [body] only, so it is removed from the body's set but
       not from the set of the bound expression. *)
    let fv_c = free_vars_cexpr lams cexpr in
    let fv_b = free_vars_aexpr lams body in
    IdentSet.union fv_c (IdentSet.remove id fv_b)
  | ACExpr c -> free_vars_cexpr lams c

and free_vars_cexpr lams = function
  | CImmexpr imm -> free_vars_imm lams imm
  | CBinop (_, l, r) -> IdentSet.union (free_vars_imm lams l) (free_vars_imm lams r)
  | CApp (f, args) ->
    (* The callee and every argument are ordinary immediates, so they go
       through [free_vars_imm] like everywhere else. Variables inside tuple
       arguments are therefore counted, and lifted-lambda references are
       resolved consistently. *)
    free_vars_imms lams (f :: args)
  | CLam (param, body) -> IdentSet.remove param (free_vars_aexpr lams body)
  | CIte (cond, t, fopt) ->
    let fv_cond = free_vars_cexpr lams cond in
    let fv_t = free_vars_aexpr lams t in
    let fv_f =
      match fopt with
      | Some e -> free_vars_aexpr lams e
      | None -> IdentSet.empty
    in
    IdentSet.union fv_cond (IdentSet.union fv_t fv_f)
  | CField (imm, _) -> free_vars_imm lams imm
;;
(* Maps keyed by identifier; used to associate each lifted lambda name with
   its set of free variables. *)
module IdentMap = Map.Make (struct
    type t = ident

    let compare = compare
  end)
(* Map every lifted lambda name to the set of its free variables. If a name
   occurs more than once in [lams], the later entry wins. *)
let collect_freevars_map (lams : (ident * cexpr) list) : IdentSet.t IdentMap.t =
  let record map (lam_id, body) =
    IdentMap.add lam_id (free_vars_cexpr lams body) map
  in
  List.fold_left record IdentMap.empty lams
;;
(* ---------- add free vars: from fun m -> k (m*n) to fun k n m -> k (m*n) ---------- *)
(* Close each lifted lambda over its free variables. Every free variable
   becomes an extra leading parameter, in increasing IdentSet order, with the
   smallest as the outermost parameter. *)
let add_free_args_lam (lams : (ident * cexpr) list) : (ident * cexpr) list =
  let close_over (lam_id, body) =
    let captured = IdentSet.elements (free_vars_cexpr lams body) in
    (* Wrap starting from the largest captured name so that the smallest one
       ends up outermost. *)
    let closed =
      List.fold_left (fun acc v -> CLam (v, ACExpr acc)) body (List.rev captured)
    in
    lam_id, closed
  in
  List.map close_over lams
;;
(* ---------- apply args: from fun k n m -> k (m*n) to (fun k n m -> k (m*n)) k n ---------- *)
(* Replace every reference to a lifted lambda (a CImmexpr naming a key of
   [env]) by an application of that lambda to its captured variables, listed
   in IdentSet.elements order. The walk recurses into let chains, lambda
   bodies and if branches; all other complex expressions are kept as is.
   NOTE(review): a lambda with no captured variables still becomes a CApp
   with an empty argument list. Confirm the backend treats that as the
   function value. *)
let rec apply_lifted_args_aexpr env (ae : aexpr) : aexpr =
  match ae with
  | ACExpr c -> ACExpr (apply_lifted_args_cexpr env c)
  | ALet (id, ce, body) ->
    let ce' = apply_lifted_args_cexpr env ce in
    let body' = apply_lifted_args_aexpr env body in
    ALet (id, ce', body')

and apply_lifted_args_cexpr env (ce : cexpr) : cexpr =
  match ce with
  | CImmexpr (ImmId lam_id) ->
    (match IdentMap.find_opt lam_id env with
     | Some captured ->
       let args = List.map (fun v -> ImmId v) (IdentSet.elements captured) in
       CApp (ImmId lam_id, args)
     | None -> ce)
  | CLam (param, body) -> CLam (param, apply_lifted_args_aexpr env body)
  | CIte (cond, thn, els) ->
    let cond' = apply_lifted_args_cexpr env cond in
    let thn' = apply_lifted_args_aexpr env thn in
    let els' = Option.map (apply_lifted_args_aexpr env) els in
    CIte (cond', thn', els')
  | CImmexpr _ | CApp _ | CBinop _ | CField _ -> ce
;;
(* Apply [apply_lifted_args_aexpr] to a top-level expression, or to every
   binding of a top-level statement. *)
let apply_lifted_args_aconstruction env construction =
  match construction with
  | AStatement (flag, binds) ->
    let rewrite (name, ae) = name, apply_lifted_args_aexpr env ae in
    AStatement (flag, List.map rewrite binds)
  | AExpr ae -> AExpr (apply_lifted_args_aexpr env ae)
;;
(* ---------- lift lambdas ---------- *)
(* Prepend the lifted definitions to the program [acs]: first all lifted
   lambdas, then all lifted local functions, each as its own Nonrec
   top-level statement.
   NOTE(review): state carries no recursion flag for lifted local functions,
   so a recursive local function is also emitted as Nonrec here. Confirm. *)
let lift_program (state : state) acs =
  let as_toplevel (id, ce) = AStatement (Nonrec, [ id, ACExpr ce ]) in
  let lifted = List.map as_toplevel (state.lifted_lams @ state.lifted_letins) in
  return (lifted @ acs)
;;
(* Full pipeline for a program:
   1. convert every construction to ANF;
   2. close each lifted lambda over its free variables;
   3. turn references to lifted lambdas into applications to those variables,
      in the lifted definitions and in the program itself;
   4. prepend the lifted definitions to the program. *)
let anf_and_lift_program ast =
  let* anf_program, final_state = anf_constructions empty_state ast in
  let fv_map = collect_freevars_map final_state.lifted_lams in
  let with_captures (id, expr) = id, apply_lifted_args_cexpr fv_map expr in
  let lifted_lams =
    List.map with_captures (add_free_args_lam final_state.lifted_lams)
  in
  let lifted_letins = List.map with_captures final_state.lifted_letins in
  let program = List.map (apply_lifted_args_aconstruction fv_map) anf_program in
  lift_program { final_state with lifted_lams; lifted_letins } program
;;