1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
(** Copyright 2026, Dmitrii Kuznetsov *)

(** SPDX-License-Identifier: LGPL-3.0-or-later *)

open Ast
open Monads.TYPECHECK
open Common

(** [value_to_type v] gives the type the checker assigns to the literal [v].
    Note that [ValNull] is given the same type as an integer literal. *)
let value_to_type v =
  match v with
  | ValInt _ | ValNull -> TypeBase TypeInt
  | ValChar _ -> TypeBase TypeChar
  | ValBool _ -> TypeBase TypeBool
  | ValString _ -> TypeBase TypeString
;;

(** [string_of_ident id] returns the string wrapped by the [Id] constructor. *)
let string_of_ident = function
  | Id name -> name
;;

(** [vartype_to_type vt] unwraps the type carried by a [TypeVar]. *)
let vartype_to_type (TypeVar inner) = inner

(* [name_to_obj_ctx name] looks [name] up in the local context; it is a plain
   alias for [read_local_el]. *)
let name_to_obj_ctx name = read_local_el name

(* [eq f e1 e2] returns [e1] when [f e1 e2] holds and fails with
   [TCError TypeMismatch] otherwise. *)
let eq f e1 e2 =
  match f e1 e2 with
  | true -> return e1
  | false -> fail (TCError TypeMismatch)
;;

(* [eq_type t1 t2] is [eq] specialised to type equality: it yields [t1] when
   both types are equal and fails with [TCError TypeMismatch] otherwise. *)
let eq_type t1 t2 = t2 |> eq equal__type t1

(** [field_of_ast memb] turns a [VarField] class member into its field record.
    [is_static] is set when [MStatic] appears among the modifiers.
    Returns [Error (TCError TypeMismatch)] when [memb] is a [Method]. *)
let field_of_ast memb =
  let has_static_modifier mods =
    List.exists
      (fun modifier ->
        match modifier with
        | MStatic -> true
        | _ -> false)
      mods
  in
  match memb with
  | Method _ -> Error (TCError TypeMismatch)
  | VarField (field_modifiers, field_type, field_name, field_init) ->
    Ok
      { field_modifiers
      ; field_type
      ; field_name
      ; field_init
      ; is_static = has_static_modifier field_modifiers
      }
;;

(* [field_of_ast] reports [TypeMismatch] when it expected a field but was given a method. *)

(** [method_of_ast memb] turns a [Method] class member into its method record.
    [is_static] is set when [MStatic] appears among the modifiers and
    [is_main] when the method is named ["Main"].
    Returns [Error (TCError TypeMismatch)] when [memb] is a [VarField]. *)
let method_of_ast memb =
  match memb with
  | Ast.VarField _ -> Error (TCError TypeMismatch)
  | Ast.Method (method_modifiers, method_return, method_name, method_params, method_body)
    ->
    let is_static =
      List.exists
        (fun modifier ->
          match modifier with
          | MStatic -> true
          | _ -> false)
        method_modifiers
    in
    let is_main = equal_ident method_name (Id "Main") in
    Ok
      { method_modifiers
      ; method_return
      ; method_name
      ; method_params
      ; method_body
      ; is_static
      ; is_main
      }
;;

(* [method_of_ast] reports [TypeMismatch] when it expected a method but was given a field. *)

(** [get_class_memb id memb] returns [memb] converted to a context entry
    ([TCField] or [TCMethod]) when its name equals [id], and [None] otherwise
    (including when the conversion itself reports an error). *)
let get_class_memb id memb =
  let name_matches name = equal_ident name id in
  match memb with
  | VarField (_, _, f_id, _) ->
    if name_matches f_id
    then (
      match field_of_ast memb with
      | Ok f_info -> Some (TCField f_info)
      | Error _ -> None)
    else None
  | Method (_, _, m_id, _, _) ->
    if name_matches m_id
    then (
      match method_of_ast memb with
      | Ok m_info -> Some (TCMethod m_info)
      | Error _ -> None)
    else None
;;

(** Association list of built-in methods, keyed by their identifier.
    Currently holds only [System.Console.WriteLine], declared as a static
    void method with an empty body taking a single [int] parameter named
    [value]. *)
let builtin_methods =
  let write_line_id = Id "System.Console.WriteLine" in
  let write_line_info =
    { method_modifiers = [ MStatic ]
    ; method_return = TypeVoid
    ; method_name = write_line_id
    ; method_params = Params [ Var (TypeVar (TypeBase TypeInt), Id "value") ]
    ; method_body = SBlock []
    ; is_static = true
    ; is_main = false
    }
  in
  [ write_line_id, write_line_info ]
;;

(** [find_memb_from_obj obj_id id] looks up the member named [id] of the
    class stored under [obj_id] in the global context.

    The class's own members are searched first (the first field or method
    whose name equals [id] wins); if none matches, [builtin_methods] is
    consulted. Yields [Some entry] on success and [None] when neither the
    class nor the built-ins define [id]. Fails if [read_global_el obj_id]
    fails.

    A previous version re-read the class and scanned it for static fields
    named [id] before falling back to the built-ins. That scan could never
    find anything, because [get_class_memb] already matches every field with
    that name, so it has been removed without changing results. *)
let find_memb_from_obj obj_id id =
  read_global_el obj_id
  >>= function
  | TCClass (Class (_, _, members)) ->
    (match List.find_map (get_class_memb id) members with
     | Some _ as found -> return found
     | None ->
       (* Not declared in the class: fall back to the built-in methods. *)
       builtin_methods
       |> List.find_opt (fun (builtin_id, _) -> equal_ident builtin_id id)
       |> Option.map (fun (_, info) -> TCMethod info)
       |> return)
;;

(** [find_memb_type memb] yields the type associated with a context entry:
    the declared type of a local variable or field, or the return type of a
    method. *)
let find_memb_type memb =
  let typ =
    match memb with
    | TCLocalVar { var_type; _ } -> vartype_to_type var_type
    | TCField { field_type; _ } -> vartype_to_type field_type
    | TCMethod { method_return; _ } -> method_return
  in
  return typ
;;

(** [find_expr_type e expr_tc] checks [e] with the supplied expression checker
    [expr_tc] and yields the type of the resulting context entry. *)
let find_expr_type e expr_tc = expr_tc e >>= fun memb -> find_memb_type memb

(** [tc_bin_op b e1 e2 expr_tc] type-checks the binary operation [b] applied
    to [e1] and [e2], using [expr_tc] to check the operands.

    - arithmetic operators require two [int] operands and produce [int];
    - ordering comparisons require two [int] operands and produce [bool];
    - [OpEqual] / [OpNonEqual] require operands of the same type and produce
      [bool];
    - [OpAnd] / [OpOr] require two [bool] operands and produce [bool];
    - [OpAssign] requires operands of the same type and produces the type of
      the left operand.

    The result is a [TCLocalVar] entry marked as initialized. Fails with
    [TCError TypeMismatch] when the operand types do not fit, or with whatever
    error [expr_tc] reports for an operand. *)
let tc_bin_op b e1 e2 expr_tc =
  let type_of e = find_expr_type e expr_tc in
  (* Checks both operands (left first) and yields their common type. *)
  let common_type e1 e2 =
    type_of e1 >>= fun t1 -> type_of e2 >>= fun t2 -> eq_type t1 t2
  in
  (* Like [common_type], additionally requiring that type to be [expected]. *)
  let both_of_type e1 e2 expected =
    common_type e1 e2 >>= fun t -> eq_type t expected
  in
  let result_of t = return (TCLocalVar { var_type = TypeVar t; initialized = true }) in
  match b with
  | OpAdd | OpMul | OpSub | OpDiv | OpMod ->
    both_of_type e1 e2 (TypeBase TypeInt) *> result_of (TypeBase TypeInt)
  | OpLess | OpLessEqual | OpMore | OpMoreEqual ->
    both_of_type e1 e2 (TypeBase TypeInt) *> result_of (TypeBase TypeBool)
  | OpEqual | OpNonEqual -> common_type e1 e2 *> result_of (TypeBase TypeBool)
  | OpAnd | OpOr ->
    both_of_type e1 e2 (TypeBase TypeBool) *> result_of (TypeBase TypeBool)
  | OpAssign ->
    (* [common_type] already yields the left operand's type, so the left
       operand does not need a separate, second check. *)
    common_type e1 e2 >>= result_of
;;

(** [tc_un_op u e expr_tc] type-checks the unary operation [u] applied to [e]:
    [OpNot] requires a [bool] operand and [OpNeg] an [int] operand. The result
    is an initialized [TCLocalVar] entry of the operand's type; a wrong
    operand type fails with [TCError TypeMismatch]. *)
let tc_un_op u e expr_tc =
  let expected_operand =
    match u with
    | OpNot -> TypeBase TypeBool
    | OpNeg -> TypeBase TypeInt
  in
  let checked_operand =
    find_expr_type e expr_tc >>= fun operand_t -> eq_type operand_t expected_operand
  in
  checked_operand
  >>= fun t -> return (TCLocalVar { var_type = TypeVar t; initialized = true })
;;

(** [tc_method_args params args expr_tc] checks a call's arguments against the
    formal parameters: every argument is checked with [expr_tc], and the
    resulting types must match the declared parameter types one by one.
    Yields the parameter list on success; a different number of arguments or
    any type difference fails with
    [OtherError "Method invocation check error"]. *)
let tc_method_args (Params params) (Args args) expr_tc =
  let param_types = List.map (fun (Var (t, _)) -> vartype_to_type t) params in
  map (fun arg -> expr_tc arg >>= find_memb_type) args
  >>= fun arg_types ->
  (* [List.equal] is already false for lists of different lengths, so a single
     comparison covers both the arity and the type check. *)
  if List.equal equal__type param_types arg_types
  then return params
  else fail (TCError (OtherError "Method invocation check error"))
;;

(** [tc_method_invoke e args expr_tc] type-checks a method call used as an
    expression. [e] must resolve to a method whose parameters accept [args];
    the result is an initialized [TCLocalVar] entry of the method's return
    type. Calling a void method, a field or a local variable is an error. *)
let tc_method_invoke e args expr_tc =
  expr_tc e
  >>= fun callee ->
  match callee with
  | TCLocalVar _ -> fail (TCError (OtherError "Cannot call a variable as a method"))
  | TCField _ -> fail (TCError (OtherError "Cannot call a field as a method"))
  | TCMethod { method_params; method_return; _ } ->
    tc_method_args method_params args expr_tc
    >>= fun _ ->
    (match method_return with
     | TypeVoid ->
       fail (TCError (OtherError "Void methods cannot be used in expressions"))
     | TypeBase _ as ret ->
       return (TCLocalVar { var_type = TypeVar ret; initialized = true }))
;;

(** [check_initialized n] looks [n] up in the local context and fails with
    [OtherError "Variable may be uninitialized"] when it is a local variable
    that has not been initialized. Fields and methods always pass. *)
let check_initialized n =
  read_local_el n
  >>= fun entry ->
  match entry with
  | TCLocalVar { initialized = false; _ } ->
    fail (TCError (OtherError "Variable may be uninitialized"))
  | TCLocalVar _ | TCField _ | TCMethod _ -> return ()
;;

(** [tc_expr e] type-checks the expression [e] and yields the context entry
    describing it.

    An identifier is first resolved in the local context (where it must also
    pass [check_initialized]); if that fails, it is looked up as a member of
    the current class, and otherwise reported as
    ["Variable not found: <name>"]. Literals become initialized local entries
    of the literal's type. Calls and unary/binary operations are delegated to
    [tc_method_invoke], [tc_un_op] and [tc_bin_op]. Any other expression form
    fails with [TCError NotImplemented]. *)
let tc_expr =
  let rec check expr =
    match expr with
    | EValue v ->
      return (TCLocalVar { var_type = TypeVar (value_to_type v); initialized = true })
    | EBinOp (b, e1, e2) -> tc_bin_op b e1 e2 check
    | EUnOp (u, e) -> tc_un_op u e check
    | EFuncCall (e, args) -> tc_method_invoke e args check
    | EId n ->
      let as_local =
        name_to_obj_ctx n >>= fun ctx -> check_initialized n *> return ctx
      in
      let as_class_member =
        get_curr_class_name
        >>= fun class_name ->
        find_memb_from_obj class_name n
        >>= function
        | Some memb -> return memb
        | None ->
          fail (TCError (OtherError ("Variable not found: " ^ string_of_ident n)))
      in
      as_local <|> as_class_member
    | _ -> fail (TCError NotImplemented)
  in
  check
;;

(* [tc_expr_with_type e] type-checks [e] and yields its type. *)
let tc_expr_with_type e = tc_expr e >>= fun memb -> find_memb_type memb

(* [eq_type_with_expr t e] type-checks [e] and requires its type to equal [t];
   yields the expression's type, or fails with [TCError TypeMismatch]. *)
let eq_type_with_expr t e =
  tc_expr_with_type e >>= fun actual -> eq_type actual t
;;

(** [save_decl n ctx] binds [n] to [ctx] in the local context, failing with
    [OtherError "This variable is already declared"] when [n] is already
    bound there. *)
let save_decl n ctx =
  read_local_el_opt n
  >>= fun existing ->
  match existing with
  | Some _ -> fail (TCError (OtherError "This variable is already declared"))
  | None -> write_local_el n ctx
;;

(** [apply_local f] runs [f] in a scope of its own: the local context is read
    before [f] and written back once [f] has succeeded, so bindings added by
    [f] do not remain visible afterwards. *)
let apply_local f =
  let run_then_restore saved_locals = f *> write_local saved_locals in
  read_local >>= run_then_restore
;;

(** [tc_stmt s] type-checks the statement [s], yielding [()] on success.

    - [SExpr]: only a call to a void method or an assignment is accepted;
    - [SDecl]: the initializer (if any) must have the declared type, and the
      name must not already be bound in the local context;
    - [SReturn]: checked against the current method's return type;
    - [SWhile], [SFor], [SIf], [SBlock]: checked inside [apply_local], so
      declarations made inside them are dropped afterwards;
    - [SBreak], [SContinue]: always accepted here. *)
let rec tc_stmt =
  (* Succeeds only when [e] has type [bool]. *)
  let is_expr_bool e = tc_expr_with_type e >>= fun t -> eq_type t (TypeBase TypeBool) in
  (* Expression statements: a call must target a void method whose parameters
     accept the arguments; an assignment is checked as an ordinary expression.
     Everything else, including calls to non-void methods, is rejected with
     [TypeMismatch]. *)
  let tc_stmt_expr expr =
    match expr with
    | EFuncCall (e, args) ->
      tc_expr e
      >>= (function
       | TCMethod { method_return = TypeVoid; method_params = pms; _ } ->
         tc_method_args pms args tc_expr *> return ()
       | TCMethod _ -> fail (TCError TypeMismatch)
       | _ -> fail (TCError TypeMismatch))
    | EBinOp (OpAssign, _, _) -> tc_expr expr *> return ()
    | _ -> fail (TCError TypeMismatch)
  in
  (* Shadows the top-level [save_decl]: binds [n] to a fresh local variable of
     type [t] with the given initialization flag, unless [n] is already bound
     in the local context. *)
  let save_decl n t initialized =
    read_local_el_opt n
    >>= function
    | None ->
      let var_info = { var_type = TypeVar t; initialized } in
      write_local_el n (TCLocalVar var_info)
    | Some _ -> fail (TCError (OtherError "This variable is already declared"))
  in
  (* Declaration of [n] with type [t]: the initializer is type-checked before
     the name is bound; without an initializer the variable is recorded as
     uninitialized. *)
  let tc_decl t n = function
    | Some e -> eq_type_with_expr t e *> save_decl n t true *> return ()
    | None -> save_decl n t false *> return ()
  in
  (* [return] must carry no value in a void method and a value of the method's
     type otherwise; any other combination (including an unset method type) is
     a [TypeMismatch]. *)
  let tc_return e_opt =
    read_meth_type
    >>= fun m_t ->
    match m_t, e_opt with
    | Some TypeVoid, None -> return ()
    | Some (TypeBase t), Some e ->
      (eq_type_with_expr (TypeBase t) e
       <|> fail (TCError (OtherError "Returned type does not match the function type")))
      *> return ()
    | _ -> fail (TCError TypeMismatch)
  in
  (* Runs [f] on the payload of an option, doing nothing for [None]. *)
  let opt_unpack f = function
    | None -> return ()
    | Some s -> f s *> return ()
  in
  (* Header of a [for] loop: the initializer may only be a declaration, the
     condition must be [bool], and the iteration part must be a valid
     expression statement. NOTE(review): [lift3] is presumably applicative
     lifting that runs the three checks in order - confirm in Monads. *)
  let tc_for_state init cond iter =
    let tc_init = function
      | None -> return ()
      | Some (SDecl (Var (TypeVar t, n), e)) -> tc_decl t n e
      | _ -> fail (TCError TypeMismatch)
    in
    let tc_cond = opt_unpack is_expr_bool cond in
    let tc_iter = opt_unpack tc_stmt_expr iter in
    lift3 (fun _ _ _ -> ()) (tc_init init) tc_cond tc_iter
  in
  (* [if]: the condition must be [bool]; the then-branch and the optional
     else-branch are checked with [tc_st]. *)
  let tc_if_state cond b s_opt tc_st =
    let tc_cond = is_expr_bool cond in
    let tc_state = function
      | Some st -> tc_st st
      | None -> return ()
    in
    lift3 (fun _ _ _ -> ()) tc_cond (tc_st b) (tc_state s_opt)
  in
  function
  | SExpr expr -> tc_stmt_expr expr
  | SDecl (Var (TypeVar t, n), e) -> tc_decl t n e
  | SReturn e -> tc_return e
  | SWhile (e, s) -> apply_local (is_expr_bool e *> tc_stmt s)
  | SFor (init, cond, iter, b) -> apply_local (tc_for_state init cond iter *> tc_stmt b)
  | SIf (e, b, s_opt) -> apply_local (tc_if_state e b s_opt tc_stmt)
  | SBlock st_l -> apply_local (iter tc_stmt st_l)
  | SBreak | SContinue -> return () (* Will check execution in interpreter *)
;;

(** [tc_member mem class_fields] type-checks a single class member.

    - A field with an initializer must be initialized with an expression of
      the field's declared type; a field without one is accepted as is.
    - A method body is checked in a fresh local scope into which every field
      of [class_fields] and every parameter has been written, with the
      method's return type recorded via [write_meth_type].
    - A method named ["Main"] must additionally have exactly the modifiers
      [[MPublic; MStatic]] (in that order), no parameters and return [int] or
      [void]. After its body is checked, the current class is recorded as the
      main class; a second [Main] fails with
      ["Main method already exists"].

    [class_fields] is the full member list of the enclosing class. *)
let tc_member mem class_fields =
  let tc_class_field f_type = function
    | Some e -> eq_type_with_expr (vartype_to_type f_type) e *> return ()
    | None -> return ()
  in
  (* Parameters are written to the local context as initialized variables. *)
  let save_params_to_l (Params params) =
    let f = function
      | Var (t, n) ->
        let var_info = { var_type = t; initialized = true } in
        write_local_el n (TCLocalVar var_info)
    in
    iter f params
  in
  let tc_meth typ params body class_fields =
    apply_local
      (* Fields are converted with the shared [field_of_ast] helper rather
         than rebuilding the field record by hand. *)
      (let add_field_to_env memb =
         match memb with
         | VarField (_, _, id, _) ->
           (match field_of_ast memb with
            | Ok field_info -> write_local_el id (TCField field_info)
            | Error e -> fail e)
         | Method _ -> return ()
       in
       iter add_field_to_env class_fields
       *> write_meth_type typ
       *> save_params_to_l params
       *> tc_stmt body)
  in
  let tc_class_method (mds, tp, id, pms, b) class_fields =
    match method_of_ast (Method (mds, tp, id, pms, b)) with
    | Error e -> fail e
    | Ok m ->
      if not m.is_main
      then tc_meth tp pms b class_fields
      else (
        let returns_int_or_void =
          match tp with
          | TypeBase TypeInt | TypeVoid -> true
          | _ -> false
        in
        let is_valid_signature =
          mds = [ MPublic; MStatic ] && pms = Params [] && returns_int_or_void
        in
        if is_valid_signature
        then (
          (* Check the body first, then make sure no other [Main] was
             registered before claiming the slot for the current class. *)
          tc_meth tp (Params []) b class_fields *> read_main_class
          >>= function
          | None -> get_curr_class_name >>= fun n -> write_main_class (Some n)
          | Some _ -> fail (TCError (OtherError "Main method already exists")))
        else
          fail
            (TCError
               (OtherError "Main must be static, non-async, no params, return int/void")))
  in
  match mem with
  | VarField (_, tp, _, e_opt) -> tc_class_field tp e_opt
  | Method (mds, tp, id, pms, b) -> tc_class_method (mds, tp, id, pms, b) class_fields
;;

(** [save_global id ctx] binds [id] to [ctx] in the global context, failing
    with [OtherError "This variable is already declared"] when [id] is already
    bound there. *)
let save_global id ctx =
  read_global_el_opt id
  >>= fun existing ->
  match existing with
  | Some _ -> fail (TCError (OtherError "This variable is already declared"))
  | None -> write_global_el id ctx
;;

(** [tc_obj cl] type-checks the class [cl].

    The class name is recorded as the current class; then, inside a local
    scope, every member and every built-in method is declared in the local
    context, the class itself is saved in the global context, and finally each
    member is checked with [tc_member]. Duplicate member names or an already
    registered class name fail with ["This variable is already declared"]. *)
let tc_obj cl =
  let (Class (_, class_id, members)) = cl in
  let declare_member mem =
    match mem with
    | VarField (_, _, name, _) ->
      (match field_of_ast mem with
       | Error e -> fail e
       | Ok field_info -> save_decl name (TCField field_info))
    | Method (_, _, name, _, _) ->
      (match method_of_ast mem with
       | Error e -> fail e
       | Ok method_info -> save_decl name (TCMethod method_info))
  in
  let declare_builtin (name, method_info) = save_decl name (TCMethod method_info) in
  let check_class =
    iter declare_member members
    *> iter declare_builtin builtin_methods
    *> save_global class_id (TCClass cl)
    *> iter (fun mem -> tc_member mem members) members
  in
  write_curr_class_name class_id *> apply_local check_class *> return ()
;;

(* [typecheck prog] runs the checker for [prog] from an empty initial state
   (two empty identifier maps and three unset optional slots) and returns
   whatever [run] produces. *)
let typecheck prog =
  let initial_state = IdMap.empty, IdMap.empty, None, None, None in
  run (tc_obj prog) initial_state
;;

(* [typecheck_main prog] runs [typecheck] and keeps only the last component of
   the final state (bound to [main] here) together with the check result. *)
let typecheck_main prog =
  let (_, _, _, _, main), res = typecheck prog in
  main, res
;;