1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
(** Copyright 2024-2025, Danil Usoltsev *)
(** SPDX-License-Identifier: LGPL-3.0-or-later *)
(* Template: https://gitlab.com/Kakadu/fp2020course-materials/-/tree/master/code/miniml?ref_type=heads*)
open Base
open Frontend.Ast
open Stdlib.Format
(* Errors that type inference can report; [pp_error] renders each of them. *)
type error =
(* Occurs check failed: the type variable (first component) occurs inside the type. *)
| OccursCheck of string * ty
(* Unbound identifier or custom operator name. *)
| NoVariable of string
(* The two types could not be unified. *)
| UnificationFailed of ty * ty
(* Printed as "Multiple bounds for variable '<payload>'."; the payload is free-form text. *)
| SeveralBounds of string
(* Rejected left-hand side of a binding, e.g. a non-variable pattern in [let rec]. *)
| LHS of string
(* Rejected right-hand side or unsupported construct, e.g. a non-lambda [let rec]
   body or an unknown constructor. *)
| RHS of string
(* A unary operator's type turned out not to be an arrow type. *)
| UnexpectedFunction of ty
(* Pretty-prints a type-inference [error] as a one-line English message. *)
let pp_error fmt err =
  match err with
  | NoVariable name -> fprintf fmt "Unbound variable '%s'." name
  | SeveralBounds name -> fprintf fmt "Multiple bounds for variable '%s'." name
  | LHS msg -> fprintf fmt "Left-hand side error: %s." msg
  | RHS msg -> fprintf fmt "Right-hand side error: %s." msg
  | OccursCheck (var, ty) ->
    fprintf fmt "Occurs check failed. Type variable '%s' occurs inside %a." var pp_ty ty
  | UnificationFailed (lhs, rhs) ->
    fprintf fmt "Failed to unify types: %a and %a." pp_ty lhs pp_ty rhs
  | UnexpectedFunction ty -> fprintf fmt "UnexpectedFunction error: %a" pp_ty ty
;;
(* Ordered sets of type-variable names. *)
module VarSet = Stdlib.Set.Make (String)
(* State + error monad used by inference. The state carries the fresh-variable
   counter, the current let-nesting level and the level recorded per type
   variable. *)
module ResultMonad : sig
  type 'a t

  val return : 'a -> 'a t
  val fail : error -> 'a t

  include Monad.Infix with type 'a t := 'a t

  module Syntax : sig
    val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
  end

  val fresh : int t
  val current_level : int t
  val enter_level : unit t
  val leave_level : unit t
  val set_var_level : string -> int -> unit t
  val get_var_level : string -> int option t
  val run : 'a t -> ('a, error) Result.t

  module RMap : sig
    val fold : ('a, 'b, 'c) Map.t -> init:'d t -> f:('a -> 'b -> 'd -> 'd t) -> 'd t
  end
end = struct
  type state =
    { counter : int
    ; current_level : int
    ; var_levels : (string, int, String.comparator_witness) Map.t
    }

  (* A computation maps an input state to an output state and an outcome; the
     state reached so far is returned even when the outcome is an error. *)
  type 'a t = state -> state * ('a, error) Result.t

  let return value = fun st -> st, Result.return value
  let fail err = fun st -> st, Result.fail err

  (* Sequencing: on error, short-circuit while keeping the state. *)
  let ( >>= ) m k =
    fun st ->
    match m st with
    | st', Result.Ok value -> k value st'
    | st', Result.Error err -> st', Result.fail err
  ;;

  let ( >>| ) m g = m >>= fun value -> return (g value)

  module Syntax = struct
    let ( let* ) m k = m >>= k
  end

  module RMap = struct
    (* Monadic fold over a Base map; each step runs after the previous one. *)
    let fold map ~init ~f =
      Map.fold map ~init ~f:(fun ~key ~data acc -> acc >>= fun acc -> f key data acc)
    ;;
  end

  (* Applies a pure state update and yields unit. *)
  let modify (update : state -> state) : unit t = fun st -> update st, Result.return ()

  let fresh : int t =
    fun ({ counter; _ } as st) -> { st with counter = counter + 1 }, Result.return counter
  ;;

  let current_level : int t = fun st -> st, Result.return st.current_level
  let enter_level = modify (fun st -> { st with current_level = st.current_level + 1 })

  (* The level never goes below zero. *)
  let leave_level =
    modify (fun st -> { st with current_level = max 0 (st.current_level - 1) })
  ;;

  let set_var_level var lvl =
    modify (fun st -> { st with var_levels = Map.set st.var_levels ~key:var ~data:lvl })
  ;;

  let get_var_level var : int option t =
    fun st -> st, Result.return (Map.find st.var_levels var)
  ;;

  (* Runs a computation from the initial state and keeps only its outcome. *)
  let run m =
    let initial =
      { counter = 0; current_level = 0; var_levels = Map.empty (module String) }
    in
    let _final_state, outcome = m initial in
    outcome
  ;;
end
(* Structural queries on types. *)
module Type = struct
  (* [occurs_in var ty] is true when type variable [var] appears anywhere in [ty]. *)
  let rec occurs_in var ty =
    match ty with
    | TyPrim _ -> false
    | TyVar name -> String.equal name var
    | TyList inner | TyOption inner -> occurs_in var inner
    | TyArrow (arg, res) -> occurs_in var arg || occurs_in var res
    | TyTuple elems -> List.exists elems ~f:(fun elem -> occurs_in var elem)
  ;;

  (* Set of all type-variable names occurring in a type. *)
  let free_vars ty =
    let rec collect acc ty =
      match ty with
      | TyPrim _ -> acc
      | TyVar name -> VarSet.add name acc
      | TyList inner | TyOption inner -> collect acc inner
      | TyArrow (arg, res) -> collect (collect acc arg) res
      | TyTuple elems -> List.fold_left elems ~init:acc ~f:collect
    in
    collect VarSet.empty ty
  ;;
end
(* Substitutions: finite maps from type-variable names to types. Creating a
   binding goes through [mapping], which also reads and lowers the variable
   levels kept in [ResultMonad]; [generalize] later reads those levels. *)
module Substitution : sig
type t
(* The substitution with no bindings. *)
val empty : t
(* [singleton v ty] binds [v] to [ty]; it is [empty] when [ty] is [TyVar v] and
   fails with [OccursCheck] when [v] occurs inside [ty]. *)
val singleton : string -> ty -> t ResultMonad.t
(* Drops the binding of the given variable, if any. *)
val remove : t -> string -> t
(* Replaces each bound variable in a type with its image. Single pass: the
   image itself is not substituted again. *)
val apply : t -> ty -> ty
(* Computes a unifier of two types, or fails with [UnificationFailed] /
   [OccursCheck]. *)
val unify : ty -> ty -> t ResultMonad.t
(* Merges two substitutions; if both bind the same variable, the two images
   are unified. *)
val compose : t -> t -> t ResultMonad.t
(* Composes a list left to right, starting from [empty]. *)
val compose_all : t list -> t ResultMonad.t
end = struct
open ResultMonad
open ResultMonad.Syntax
type t = (string, ty, String.comparator_witness) Map.t
let empty = Map.empty (module String)
(* Validates a prospective binding [key -> value] and returns it unchanged.
   - Fails with [OccursCheck] if [key] occurs in [value] (infinite type).
   - If [key] has a recorded level, every free variable of [value] whose level
     is higher than [key]'s is lowered to [key]'s level. [generalize] only
     quantifies variables whose level exceeds the current one, so these
     variables become no more generalizable than [key]. *)
let mapping key value =
if Type.occurs_in key value
then fail (OccursCheck (key, value))
else
let* key_lvl = get_var_level key in
let vars = Type.free_vars value |> VarSet.elements in
let* () =
match key_lvl with
| None -> return ()
| Some key_lvl ->
List.fold_left vars ~init:(return ()) ~f:(fun acc v ->
let* () = acc in
let* v_lvl = get_var_level v in
match v_lvl with
| Some v_lvl when v_lvl > key_lvl -> set_var_level v key_lvl
| _ -> return ())
in
return (key, value)
;;
let singleton key value =
match value with
(* Binding a variable to itself is a no-op. *)
| TyVar v when String.equal v key -> return empty
| _ ->
let* key, value = mapping key value in
return (Map.singleton (module String) key value)
;;
let find = Map.find
let remove = Map.remove
let apply subst =
let rec helper = function
| TyPrim x -> TyPrim x
| TyVar b as ty ->
(* Unbound variables are left as they are. *)
(match find subst b with
| None -> ty
| Some x -> x)
| TyArrow (left, right) -> TyArrow (helper left, helper right)
| TyList ty -> TyList (helper ty)
| TyOption ty -> TyOption (helper ty)
| TyTuple types -> TyTuple (List.map ~f:helper types)
in
helper
;;
(* Unification. [extend] and [compose] are mutually recursive with it
   because merging two bindings of the same variable requires unifying
   their images. *)
let rec unify left right =
match left, right with
| TyPrim l, TyPrim r when String.equal l r -> return empty
| TyPrim _, TyPrim _ -> fail (UnificationFailed (left, right))
| TyVar l, TyVar r when String.equal l r -> return empty
(* A variable on either side is bound to the other type; [singleton]
   performs the occurs check and the level adjustment. *)
| TyVar b, ty | ty, TyVar b -> singleton b ty
(* Arguments first, then results under the argument unifier. *)
| TyArrow (left1, right1), TyArrow (left2, right2) ->
let* subst1 = unify left1 left2 in
let* subst2 = unify (apply subst1 right1) (apply subst1 right2) in
compose subst1 subst2
(* Tuples must have the same arity; components are unified left to right
   while threading the accumulated substitution. *)
| TyTuple types1, TyTuple types2 ->
if List.length types1 <> List.length types2
then fail (UnificationFailed (left, right))
else (
let rec unify_tuples subst types1 types2 =
match types1, types2 with
| [], [] -> return subst
| t1 :: rest1, t2 :: rest2 ->
let* subst' = unify (apply subst t1) (apply subst t2) in
let* composed_subst = compose subst subst' in
unify_tuples composed_subst rest1 rest2
(* Unreachable after the length check; kept for exhaustiveness. *)
| _, _ -> fail (UnificationFailed (left, right))
in
unify_tuples empty types1 types2)
| TyList ty1, TyList ty2 -> unify ty1 ty2
| TyOption ty1, TyOption ty2 -> unify ty1 ty2
| _ -> fail (UnificationFailed (left, right))
(* [extend key value subst] adds [key -> value] to [subst].
   - If [key] is unbound, [value] is first rewritten with [subst]. Every
     existing image is then rewritten with the new binding and re-validated
     through [mapping].
   - If [key] is already bound, the old and new images are unified and the
     unifier is composed into [subst]. *)
and extend key value subst =
match find subst key with
| None ->
let value = apply subst value in
let* subst2 = singleton key value in
RMap.fold subst ~init:(return subst2) ~f:(fun key value acc ->
let value = apply subst2 value in
let* key, value = mapping key value in
return (Map.update acc key ~f:(fun _ -> value)))
| Some value2 ->
let* subst2 = unify value value2 in
compose subst subst2
(* Folds every binding of [subst2] into [subst1] with [extend]. *)
and compose subst1 subst2 = RMap.fold subst2 ~init:(return subst1) ~f:extend
let compose_all =
List.fold_left ~init:(return empty) ~f:(fun acc subst ->
let* acc = acc in
compose acc subst)
;;
end
(* Type schemes: a type together with the set of variables quantified in it. *)
module Scheme = struct
  type t = Scheme of VarSet.t * ty

  (* Variables of the body that are not bound by the scheme. *)
  let free_vars (Scheme (bound, body)) =
    Type.free_vars body |> VarSet.filter (fun v -> not (VarSet.mem v bound))
  ;;

  (* Applies [subst] to the body, leaving the quantified variables untouched. *)
  let apply subst (Scheme (bound, body)) =
    let visible = List.fold (VarSet.elements bound) ~init:subst ~f:Substitution.remove in
    Scheme (bound, Substitution.apply visible body)
  ;;
end
(* Typing environments: identifiers mapped to type schemes. *)
module TypeEnv = struct
  type t = (ident, Scheme.t, String.comparator_witness) Map.t

  (* Binds [key] to [value], replacing any previous binding. *)
  let extend env key value = Map.set env ~key ~data:value

  (* Union of the free variables of every scheme in the environment. *)
  let free_vars (env : t) : VarSet.t =
    Map.data env
    |> List.fold ~init:VarSet.empty ~f:(fun acc scheme ->
      VarSet.union acc (Scheme.free_vars scheme))
  ;;

  let apply subst env = Map.map env ~f:(fun scheme -> Scheme.apply subst scheme)
  let find env key = Map.find env key
  let keys env = Map.keys env

  (* Built-in printing functions, all monomorphic. *)
  let initial_env =
    let builtin arg res = Scheme.Scheme (VarSet.empty, TyArrow (TyPrim arg, TyPrim res)) in
    Map.of_alist_exn
      (module String)
      [ "print_int", builtin "int" "unit"
      ; "print_endline", builtin "string" "unit"
      ; "print_bool", builtin "bool" "unit"
      ]
  ;;

  (* [initial_env] extended with the garbage-collector runtime primitives. *)
  let env_with_gc =
    let builtin arg res = Scheme.Scheme (VarSet.empty, TyArrow (TyPrim arg, TyPrim res)) in
    List.fold
      [ "collect", builtin "unit" "unit"
      ; "print_gc_status", builtin "unit" "unit"
      ; "get_heap_start", builtin "unit" "int"
      ; "get_heap_final", builtin "unit" "int"
      ]
      ~init:initial_env
      ~f:(fun env (key, data) -> Map.set env ~key ~data)
  ;;
end
open ResultMonad
open ResultMonad.Syntax
(* Allocates a fresh type variable named "t<n>" and records the current level
   as its level. *)
let fresh_var =
  fresh
  >>= fun id ->
  current_level
  >>= fun level ->
  let var_name = "t" ^ Int.to_string id in
  set_var_level var_name level >>| fun () -> TyVar var_name
;;
(* Instantiates a scheme: each quantified variable, taken in increasing name
   order, is replaced with a freshly allocated type variable. *)
let instantiate : Scheme.t -> ty ResultMonad.t =
  fun (Scheme (bound_vars, body)) ->
  let rec rename ty = function
    | [] -> return ty
    | var :: rest ->
      let* fresh_ty = fresh_var in
      let* subst = Substitution.singleton var fresh_ty in
      rename (Substitution.apply subst ty) rest
  in
  rename body (VarSet.elements bound_vars)
;;
(* Level-based generalization. The environment argument is unused: a free
   variable of [ty] is quantified exactly when its recorded level is strictly
   greater than the current level. *)
let generalize _env ty =
  let* outer_level = current_level in
  let rec collect generic = function
    | [] -> return generic
    | var :: rest ->
      let* var_level = get_var_level var in
      let generic =
        match var_level with
        | Some level when level > outer_level -> VarSet.add var generic
        | Some _ | None -> generic
      in
      collect generic rest
  in
  let* generic = collect VarSet.empty (VarSet.elements (Type.free_vars ty)) in
  return (Scheme.Scheme (generic, ty))
;;
(* Primitive type of a literal constant. *)
let infer_const const =
  let prim =
    match const with
    | ConstInt _ -> "int"
    | ConstBool _ -> "bool"
    | ConstString _ -> "string"
    | ConstChar _ -> "char"
  in
  TyPrim prim
;;
(* Infers the type of a pattern.
   Returns [(subst, ty, env')], where [env'] is [env] extended with the
   (monomorphic) variables bound by the pattern, [ty] is the pattern's type and
   [subst] collects constraints discovered inside the pattern (e.g. from
   annotations or list elements). *)
let rec infer_pattern env = function
  | PatAny ->
    let* fresh = fresh_var in
    return (Substitution.empty, fresh, env)
  | PatConst const -> return (Substitution.empty, infer_const const, env)
  | PatVariable var ->
    let* fresh = fresh_var in
    let env = TypeEnv.extend env var (Scheme.Scheme (VarSet.empty, fresh)) in
    return (Substitution.empty, fresh, env)
  | PatTuple (first_pat, second_pat, rest_pats) ->
    let* sub_first, type_first, env_first = infer_pattern env first_pat in
    let updated_env_second = TypeEnv.apply sub_first env_first in
    let* sub_second, type_second, env_second =
      infer_pattern updated_env_second second_pat
    in
    (* Keep the first component's substitution instead of dropping it. *)
    let* sub_prefix = Substitution.compose sub_first sub_second in
    let process_remaining_patterns acc pat =
      let* current_sub, types, current_env = acc in
      let* sub_new, type_new, env_new = infer_pattern current_env pat in
      let* combined_sub = Substitution.compose current_sub sub_new in
      return (combined_sub, type_new :: types, env_new)
    in
    let initial_state = return (sub_prefix, [ type_second; type_first ], env_second) in
    let* final_sub, collected_types, final_env =
      List.fold_left rest_pats ~init:initial_state ~f:process_remaining_patterns
    in
    let tuple_type = Substitution.apply final_sub (TyTuple (List.rev collected_types)) in
    return (final_sub, tuple_type, TypeEnv.apply final_sub final_env)
  | PatList pats ->
    let* fresh_el_type = fresh_var in
    let* final_sub, final_env =
      List.fold_left
        pats
        ~init:(return (Substitution.empty, env))
        ~f:(fun acc pat ->
          let* sub_acc, env_acc = acc in
          let* sub_cur, el_type, env_cur = infer_pattern env_acc pat in
          let* sub_merged = Substitution.compose sub_acc sub_cur in
          (* Every element pattern must have the shared element type. *)
          let* sub_elem =
            Substitution.unify
              (Substitution.apply sub_merged fresh_el_type)
              (Substitution.apply sub_merged el_type)
          in
          let* combined_sub = Substitution.compose sub_merged sub_elem in
          (* Apply everything learned so far, so the variables bound by sibling
             elements end up with the same type in the environment. *)
          return (combined_sub, TypeEnv.apply combined_sub env_cur))
    in
    return (final_sub, TyList (Substitution.apply final_sub fresh_el_type), final_env)
  | PatOption opt ->
    let* sub, typ, env =
      match opt with
      | None ->
        let* fresh = fresh_var in
        return (Substitution.empty, fresh, env)
      | Some p -> infer_pattern env p
    in
    return (sub, TyOption typ, env)
  | PatType (pat, annotated_ty) ->
    let* subst, inferred_ty, env = infer_pattern env pat in
    let* unified_subst = Substitution.unify inferred_ty annotated_ty in
    let* total_subst = Substitution.compose subst unified_subst in
    return
      ( total_subst
      , Substitution.apply total_subst annotated_ty
      , TypeEnv.apply total_subst env )
  | PatUnit -> return (Substitution.empty, TyPrim "unit", env)
  | PatConstruct (name, opt) ->
    (match name, opt with
     | "()", None -> return (Substitution.empty, TyPrim "unit", env)
     | "None", None ->
       let* fresh = fresh_var in
       return (Substitution.empty, TyOption fresh, env)
     | "Some", Some p ->
       let* sub, typ, env' = infer_pattern env p in
       return (sub, TyOption typ, env')
     | "[]", None ->
       let* fresh = fresh_var in
       return (Substitution.empty, TyList fresh, env)
     | "::", Some (PatTuple (_, _, []) as pair_pat) ->
       let* sub_pair, ty_pair, env' = infer_pattern env pair_pat in
       let* fresh_hd = fresh_var in
       (* [hd :: tl]: the tail is a list of the head's type. *)
       let* sub_cons =
         Substitution.unify ty_pair (TyTuple [ fresh_hd; TyList fresh_hd ])
       in
       let* sub_total = Substitution.compose sub_cons sub_pair in
       return
         ( sub_total
         , Substitution.apply sub_total (TyList fresh_hd)
         , TypeEnv.apply sub_total env' )
     | "::", _ -> fail (RHS "Constructor (::) expects a pair pattern")
     | _ -> fail (RHS ("Unknown constructor: " ^ name)))
;;
(* Fixed signature (left operand, right operand, result) of a built-in binary
   operator. Comparisons take two operands of one fresh type. [Custom]
   operators are resolved by the caller, so reaching them here is an error. *)
let infer_binop_type op =
  let int_ty = TyPrim "int" in
  let bool_ty = TyPrim "bool" in
  match op with
  | Plus | Minus | Multiply | Division -> return (int_ty, int_ty, int_ty)
  | And | Or -> return (bool_ty, bool_ty, bool_ty)
  | Equal | NotEqual | GreaterThan | GreatestEqual | LowerThan | LowestEqual ->
    let* operand_ty = fresh_var in
    return (operand_ty, operand_ty, bool_ty)
  | Custom _ -> fail (NoVariable "infer_binop_type: Custom handled in infer_expr")
;;
(* Returns [(arg_ty1, arg_ty2, res_ty, subst_op)] for a binary operator.
   - A [Custom] operator is looked up in [env], instantiated and unified with
     [arg1 -> arg2 -> res]. It fails with [NoVariable] if the name is unbound.
   - Built-in operators get their fixed signature and an empty substitution. *)
let get_binop_arg_res env op =
  match op with
  | Custom op_name ->
    (match TypeEnv.find env op_name with
     | None -> fail (NoVariable op_name)
     | Some op_scheme ->
       let* op_ty = instantiate op_scheme in
       let* lhs_ty = fresh_var in
       let* rhs_ty = fresh_var in
       let* result_ty = fresh_var in
       let* subst =
         Substitution.unify op_ty (TyArrow (lhs_ty, TyArrow (rhs_ty, result_ty)))
       in
       let resolve = Substitution.apply subst in
       return (resolve lhs_ty, resolve rhs_ty, resolve result_ty, subst))
  | _ ->
    infer_binop_type op
    >>| fun (lhs_ty, rhs_ty, result_ty) -> lhs_ty, rhs_ty, result_ty, Substitution.empty
;;
(* Infers the type of an expression in [env].
   Returns [(subst, ty)]: the substitution accumulated during inference and
   the expression's type. Failure modes:
   - [NoVariable] for unbound names;
   - [UnificationFailed] / [OccursCheck] for type errors;
   - [LHS] / [RHS] for malformed or unsupported constructs. *)
let rec infer_expr env = function
  | ExpConst const -> return (Substitution.empty, infer_const const)
  | ExpIdent var ->
    (match TypeEnv.find env var with
     | Some scheme ->
       let* ty = instantiate scheme in
       return (Substitution.empty, ty)
     | None -> fail (NoVariable var))
  | ExpUnarOper (operation, expr) ->
    let* subst, ty = infer_expr env expr in
    let* operation_type =
      match operation with
      | Negative -> return (TyArrow (TyPrim "int", TyPrim "int"))
      | Not -> return (TyArrow (TyPrim "bool", TyPrim "bool"))
    in
    let* subst2 =
      match operation_type with
      | TyArrow (arg, _) -> Substitution.unify ty arg
      | ty -> fail (UnexpectedFunction ty)
    in
    let* subst2 = Substitution.compose_all [ subst2; subst ] in
    (match operation_type with
     | TyArrow (_, x) -> return (subst2, Substitution.apply subst2 x)
     | ty -> fail (UnexpectedFunction ty))
  | ExpBinOper (op, expr1, expr2) ->
    let* subst1, ty = infer_expr env expr1 in
    let* subst2, ty' = infer_expr (TypeEnv.apply subst1 env) expr2 in
    let* arg_ty1, arg_ty2, res_ty, subst_op =
      match op with
      | Custom op_name when Option.is_none (TypeEnv.find env op_name) ->
        (* No user binding: try reading the name as a built-in operator. *)
        (match builtin_op_of_string op_name with
         | Some builtin_op -> get_binop_arg_res env builtin_op
         | None -> fail (NoVariable op_name))
      | _ -> get_binop_arg_res env op
    in
    let* subst3 =
      Substitution.unify
        (Substitution.apply subst2 ty)
        (Substitution.apply subst_op arg_ty1)
    in
    let* subst4 =
      Substitution.unify
        (Substitution.apply subst3 ty')
        (Substitution.apply subst3 arg_ty2)
    in
    let* subst = Substitution.compose_all [ subst1; subst2; subst_op; subst3; subst4 ] in
    return (subst, Substitution.apply subst res_ty)
  | ExpBranch (cond, then_expr, else_expr) ->
    let* subst1, ty1 = infer_expr env cond in
    let* subst2, ty2 = infer_expr (TypeEnv.apply subst1 env) then_expr in
    (* Infer the else-branch exactly once, keeping its substitution together
       with its type. A missing else-branch has type unit. *)
    let* subst3, ty3 =
      match else_expr with
      | Some el -> infer_expr (TypeEnv.apply subst2 (TypeEnv.apply subst1 env)) el
      | None -> return (Substitution.empty, TyPrim "unit")
    in
    let* subst4 = Substitution.unify ty1 (TyPrim "bool") in
    let* subst5 = Substitution.unify ty2 ty3 in
    let* total_subst =
      Substitution.compose_all [ subst5; subst4; subst3; subst2; subst1 ]
    in
    return (total_subst, Substitution.apply total_subst ty2)
  | ExpTuple (expr1, expr2, exprs) ->
    let* subst1, ty1 = infer_expr env expr1 in
    let* subst2, ty2 = infer_expr (TypeEnv.apply subst1 env) expr2 in
    let* substs, tys =
      infer_sequence (TypeEnv.apply subst2 (TypeEnv.apply subst1 env)) exprs
    in
    let* subst = Substitution.compose_all (substs @ [ subst2; subst1 ]) in
    (* Components were typed under partial substitutions; resolve them all. *)
    return (subst, Substitution.apply subst (TyTuple (ty1 :: ty2 :: tys)))
  | ExpList exprs ->
    (match exprs with
     | [] ->
       let* fresh = fresh_var in
       return (Substitution.empty, TyList fresh)
     | _ :: _ ->
       let* substs, tys = infer_sequence env exprs in
       let* subst = Substitution.compose_all substs in
       (match tys with
        | [] -> fail (SeveralBounds "inferred empty list type")
        | first_ty :: rest_tys ->
          (* Every element must have the same type as the first one. *)
          let* subst =
            List.fold_left rest_tys ~init:(return subst) ~f:(fun acc ty ->
              let* subst_acc = acc in
              let* subst_u =
                Substitution.unify
                  (Substitution.apply subst_acc first_ty)
                  (Substitution.apply subst_acc ty)
              in
              Substitution.compose subst_acc subst_u)
          in
          return (subst, TyList (Substitution.apply subst first_ty))))
  | ExpLet (NonRec, (PatVariable x, expr1), [], expr2) ->
    let* () = enter_level in
    let* subst1, ty1 = infer_expr env expr1 in
    let* () = leave_level in
    let env2 = TypeEnv.apply subst1 env in
    let* ty_gen = generalize env2 ty1 in
    let env3 = TypeEnv.extend env2 x ty_gen in
    let* subst2, ty2 = infer_expr env3 expr2 in
    let* total_subst = Substitution.compose subst1 subst2 in
    return (total_subst, ty2)
  | ExpLet (NonRec, first_binding, bindings, body) ->
    (* Bindings are inferred left to right. Each right-hand side is unified
       with its OWN pattern; later bindings see names introduced by earlier
       ones. *)
    let* subst_acc, env_acc =
      List.fold_left
        (first_binding :: bindings)
        ~init:(return (Substitution.empty, env))
        ~f:(fun acc (pat, expr) ->
          let* subst_acc, env_acc = acc in
          let* () = enter_level in
          let* subst_expr, ty_expr = infer_expr env_acc expr in
          let* () = leave_level in
          let* subst_pat, ty_pat, env_pat =
            infer_pattern (TypeEnv.apply subst_expr env_acc) pat
          in
          let* subst_u =
            Substitution.unify
              (Substitution.apply subst_pat ty_pat)
              (Substitution.apply subst_pat ty_expr)
          in
          let* subst_total =
            Substitution.compose_all [ subst_acc; subst_expr; subst_pat; subst_u ]
          in
          return (subst_total, TypeEnv.apply subst_total env_pat))
    in
    let* subst_body, ty_body = infer_expr env_acc body in
    let* total_subst = Substitution.compose subst_acc subst_body in
    return (total_subst, ty_body)
  | ExpLet (Rec, value_binding, value_bindings, body) ->
    let* subst_rec, env_rec = infer_rec_bindings env (value_binding :: value_bindings) in
    let* subst_body, ty_body = infer_expr env_rec body in
    let* total_subst = Substitution.compose subst_rec subst_body in
    return (total_subst, ty_body)
  | ExpLambda (pat, pats, body) ->
    (* Bind every parameter pattern; [pat_types] ends up in reverse order. *)
    let* env, pat_types =
      List.fold_left
        (pat :: pats)
        ~init:(return (env, []))
        ~f:(fun acc pat ->
          let* env, pat_types = acc in
          let* sub_pat, typ, env = infer_pattern env pat in
          return (TypeEnv.apply sub_pat env, Substitution.apply sub_pat typ :: pat_types))
    in
    let* subst_body, ty_body = infer_expr env body in
    (* Folding the reversed list builds p1 -> p2 -> ... -> body. *)
    let arrow_type =
      List.fold_left pat_types ~init:ty_body ~f:(fun acc pat_type ->
        TyArrow (Substitution.apply subst_body pat_type, acc))
    in
    return (subst_body, arrow_type)
  | ExpApply (func, arg) ->
    let* subst1, ty_func = infer_expr env func in
    let* subst2, ty_arg = infer_expr (TypeEnv.apply subst1 env) arg in
    let* tv = fresh_var in
    let* subst3 =
      Substitution.unify (Substitution.apply subst2 ty_func) (TyArrow (ty_arg, tv))
    in
    let* total_subst = Substitution.compose_all [ subst3; subst2; subst1 ] in
    return (total_subst, Substitution.apply total_subst tv)
  | ExpFunction ((pat, body), rest_cases) ->
    (match rest_cases with
     | [] ->
       let* sub_pat, ty_pat, env' = infer_pattern env pat in
       let env' = TypeEnv.apply sub_pat env' in
       let ty_pat = Substitution.apply sub_pat ty_pat in
       let* subst_body, ty_body = infer_expr env' body in
       return (subst_body, TyArrow (Substitution.apply subst_body ty_pat, ty_body))
     | _ -> fail (RHS "Multiple function cases not yet supported"))
  | ExpOption opt_expr ->
    (match opt_expr with
     | Some expr ->
       let* subst, ty = infer_expr env expr in
       return (subst, TyOption ty)
     | None ->
       let* tv = fresh_var in
       return (Substitution.empty, TyOption tv))
  | ExpTypeAnnotation (expr, t) ->
    let* subst1, ty1 = infer_expr env expr in
    let* subst2 = Substitution.unify ty1 (Substitution.apply subst1 t) in
    let* total_subst = Substitution.compose subst1 subst2 in
    return (total_subst, Substitution.apply subst2 ty1)
  | ExpMatch (scrut, (pat, expr), bind_list) ->
    let* subst_scrut, ty_scrut = infer_expr env scrut in
    let all_cases = (pat, expr) :: bind_list in
    let* final_subst, ty_res =
      List.fold_left
        all_cases
        ~init:(return (subst_scrut, None))
        ~f:(fun acc (pat', expr') ->
          let* sub_acc, ty_res_opt = acc in
          let env' = TypeEnv.apply sub_acc env in
          let* sub_pat, ty_pat, env_pat = infer_pattern env' pat' in
          let* sub_u =
            Substitution.unify
              (Substitution.apply sub_pat (Substitution.apply sub_acc ty_scrut))
              ty_pat
          in
          let* sub_comp = Substitution.compose sub_u sub_pat in
          let* sub_expr, ty_branch = infer_expr (TypeEnv.apply sub_comp env_pat) expr' in
          let* sub_total = Substitution.compose_all [ sub_expr; sub_comp; sub_acc ] in
          let ty_branch' = Substitution.apply sub_total ty_branch in
          match ty_res_opt with
          | None -> return (sub_total, Some ty_branch')
          | Some ty_prev ->
            (* All branches share one result type; compare under current knowledge. *)
            let* sub_merge =
              Substitution.unify (Substitution.apply sub_total ty_prev) ty_branch'
            in
            let* sub_final = Substitution.compose sub_total sub_merge in
            return (sub_final, Some (Substitution.apply sub_final ty_branch')))
    in
    (match ty_res with
     | Some ty -> return (final_subst, Substitution.apply final_subst ty)
     | None -> fail (RHS "Empty match"))
  | ExpConstruct (name, opt_expr) ->
    (match name, opt_expr with
     | "()", None -> return (Substitution.empty, TyPrim "unit")
     | "None", None ->
       let* tv = fresh_var in
       return (Substitution.empty, TyOption tv)
     | "Some", Some e ->
       let* subst, ty = infer_expr env e in
       return (subst, TyOption ty)
     | "[]", None ->
       let* tv = fresh_var in
       return (Substitution.empty, TyList tv)
     | "::", Some (ExpTuple (head_e, tail_e, [])) ->
       let* subst_h, ty_h = infer_expr env head_e in
       let* subst_t, ty_t = infer_expr (TypeEnv.apply subst_h env) tail_e in
       let ty_h = Substitution.apply subst_t ty_h in
       (* The tail must be a list of the head's type. *)
       let* subst_u = Substitution.unify (TyList ty_h) ty_t in
       let* subst_total = Substitution.compose_all [ subst_u; subst_t; subst_h ] in
       return (subst_total, Substitution.apply subst_total (TyList ty_h))
     | "::", _ -> fail (RHS "Constructor (::) expects a pair argument")
     | _ -> fail (RHS ("Unknown constructor: " ^ name)))

(* Infers [exprs] left to right; each element's substitution is applied to
   the environment of the next one. Returns the substitutions (last element
   first) and the types (source order). *)
and infer_sequence env = function
  | [] -> return ([], [])
  | expr :: rest ->
    let* subst, ty = infer_expr env expr in
    let* substs, tys = infer_sequence (TypeEnv.apply subst env) rest in
    return (substs @ [ subst ], ty :: tys)

(* Infers a [let rec ... and ...] group.
   - Every right-hand side must be a lambda; every left-hand side must be a
     variable.
   - At a deeper level, each name is bound to a monomorphic placeholder.
     Each body is then inferred against that shared environment and unified
     with its placeholder.
   - Back at the outer level, every name is generalized.
   Returns the group's substitution and [env] extended with the
   generalized names. *)
and infer_rec_bindings env bindings =
  let* () = enter_level in
  let* env_rec, placeholders =
    List.fold_left
      bindings
      ~init:(return (env, []))
      ~f:(fun acc (pat, expr) ->
        let* env_acc, placeholders = acc in
        let* () =
          match expr with
          | ExpLambda _ -> return ()
          | _ -> fail (RHS "Right-hand side of let rec must be a lambda expression")
        in
        let* name =
          match pat with
          | PatVariable name -> return name
          | _ -> fail (LHS "Only variables are allowed on the left-hand side of let rec")
        in
        (* Allocated after [enter_level] so the placeholder stays generalizable. *)
        let* tv = fresh_var in
        let env_acc = TypeEnv.extend env_acc name (Scheme.Scheme (VarSet.empty, tv)) in
        return (env_acc, (name, tv, expr) :: placeholders))
  in
  let placeholders = List.rev placeholders in
  let* subst =
    List.fold_left
      placeholders
      ~init:(return Substitution.empty)
      ~f:(fun acc (_, tv, expr) ->
        let* subst_acc = acc in
        let* subst_expr, ty_expr = infer_expr (TypeEnv.apply subst_acc env_rec) expr in
        let* subst_acc = Substitution.compose subst_acc subst_expr in
        let* subst_tie =
          Substitution.unify
            (Substitution.apply subst_acc tv)
            (Substitution.apply subst_acc ty_expr)
        in
        Substitution.compose subst_acc subst_tie)
  in
  let* () = leave_level in
  let* env_final =
    List.fold_left
      placeholders
      ~init:(return (TypeEnv.apply subst env))
      ~f:(fun acc (name, tv, _) ->
        let* env_acc = acc in
        let* scheme = generalize env_acc (Substitution.apply subst tv) in
        return (TypeEnv.extend env_acc name scheme))
  in
  return (subst, env_final)
;;
(* Infers one top-level item in [env] and returns [(subst, env')], where
   [env'] contains the names the item introduces.
   - [SEval]: the expression is type-checked and no names are introduced.
   - [SValue Rec]: the whole group is inferred together and then generalized.
   - [SValue NonRec]: all bindings are processed left to right; a plain
     variable binding is generalized. *)
let infer_structure_item env =
  (* [let rec ... and ...] group.
     - Every right-hand side must be a lambda; every left-hand side must be a
       variable.
     - At a deeper level, each name gets a monomorphic placeholder. Each body
       is inferred against that shared environment and unified with its
       placeholder.
     - Back at the outer level, every name is generalized. *)
  let infer_rec_group env bindings =
    let* () = enter_level in
    let* env_rec, placeholders =
      List.fold_left
        bindings
        ~init:(return (env, []))
        ~f:(fun acc (pat, expr) ->
          let* env_acc, placeholders = acc in
          let* () =
            match expr with
            | ExpLambda _ -> return ()
            | _ -> fail (RHS "Right-hand side of let rec must be a lambda expression")
          in
          let* name =
            match pat with
            | PatVariable name -> return name
            | _ ->
              fail (LHS "Only variables are allowed on the left-hand side of let rec")
          in
          (* Allocated after [enter_level] so the placeholder stays generalizable. *)
          let* tv = fresh_var in
          let env_acc = TypeEnv.extend env_acc name (Scheme.Scheme (VarSet.empty, tv)) in
          return (env_acc, (name, tv, expr) :: placeholders))
    in
    let placeholders = List.rev placeholders in
    let* subst =
      List.fold_left
        placeholders
        ~init:(return Substitution.empty)
        ~f:(fun acc (_, tv, expr) ->
          let* subst_acc = acc in
          let* subst_expr, ty_expr = infer_expr (TypeEnv.apply subst_acc env_rec) expr in
          let* subst_acc = Substitution.compose subst_acc subst_expr in
          let* subst_tie =
            Substitution.unify
              (Substitution.apply subst_acc tv)
              (Substitution.apply subst_acc ty_expr)
          in
          Substitution.compose subst_acc subst_tie)
    in
    let* () = leave_level in
    let* env_final =
      List.fold_left
        placeholders
        ~init:(return (TypeEnv.apply subst env))
        ~f:(fun acc (name, tv, _) ->
          let* env_acc = acc in
          let* scheme = generalize env_acc (Substitution.apply subst tv) in
          return (TypeEnv.extend env_acc name scheme))
    in
    return (subst, env_final)
  in
  (* One non-recursive binding, threaded through [(subst_acc, env_acc)]. *)
  let infer_nonrec_binding (subst_acc, env_acc) (pat, expr) =
    match pat with
    | PatVariable x ->
      let* () = enter_level in
      let* subst, ty = infer_expr env_acc expr in
      let* () = leave_level in
      let env2 = TypeEnv.apply subst env_acc in
      let* generalized_ty = generalize env2 ty in
      let* subst_acc = Substitution.compose subst_acc subst in
      return (subst_acc, TypeEnv.extend env2 x generalized_ty)
    | _ ->
      let* subst_expr, ty = infer_expr env_acc expr in
      let* subst_pat, ty_pat, env_pat = infer_pattern env_acc pat in
      let* combined_subst = Substitution.compose subst_expr subst_pat in
      let* unified_subst =
        Substitution.unify (Substitution.apply combined_subst ty_pat) ty
      in
      let* final_subst =
        Substitution.compose_all [ subst_acc; unified_subst; combined_subst ]
      in
      return (final_subst, TypeEnv.apply final_subst env_pat)
  in
  function
  | SEval expr ->
    let* subst, _ = infer_expr env expr in
    let updated_env = TypeEnv.apply subst env in
    return (subst, updated_env)
  | SValue (Rec, value_binding, value_bindings) ->
    infer_rec_group env (value_binding :: value_bindings)
  | SValue (NonRec, value_binding, value_bindings) ->
    (* Every [and]-binding is processed; previously only the first one was. *)
    List.fold_left
      (value_binding :: value_bindings)
      ~init:(return (Substitution.empty, env))
      ~f:(fun acc binding ->
        let* state = acc in
        infer_nonrec_binding state binding)
;;
(* Infers a whole program item by item. Each item sees the environment
   produced by the previous ones. Returns the composed substitution and the
   final environment. *)
let infer_structure env structure =
  List.fold_left
    structure
    ~init:(return (Substitution.empty, env))
    ~f:(fun acc item ->
      let* subst_acc, env_acc = acc in
      let* subst_item, env_item = infer_structure_item env_acc item in
      let* subst_acc = Substitution.compose subst_acc subst_item in
      return (subst_acc, env_item))
;;
(* Type of a single expression checked in [TypeEnv.initial_env]. *)
let infer_simple_expression expr =
  infer_expr TypeEnv.initial_env expr |> run |> Result.map ~f:snd
;;

(* Typing environment after checking a whole program in [TypeEnv.initial_env]. *)
let run_infer str = infer_structure TypeEnv.initial_env str |> run |> Result.map ~f:snd