1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
(** Copyright 2025, Tenyaeva Ekaterina *)

(** SPDX-License-Identifier: LGPL-3.0-or-later *)

open Ast
open Monad

(** Errors that abort evaluation. *)
type eval_error =
  | TypeError (** an operator, pattern or application received a value of the wrong kind *)
  | DivisionByZero (** integer division with a zero divisor *)
  | MatchFailure (** no case / function parameter pattern matched the value *)
  | NoVariable of Ast.ident (** lookup of an identifier that is not bound *)
  | OutOfSteps (** the step budget kept by [StepCounter] ran out *)

(** [pp_eval_error ppf err] writes a short human-readable description of
    [err] to the formatter [ppf]. *)
let pp_eval_error ppf : eval_error -> _ =
  let open Format in
  function
  | NoVariable id -> fprintf ppf "Undefined variable '%s'" id
  | OutOfSteps -> pp_print_string ppf "Out of steps"
  | MatchFailure -> pp_print_string ppf "Matching failure"
  | DivisionByZero -> pp_print_string ppf "Division by zero"
  | TypeError -> pp_print_string ppf "Type error"
;;

(** Runtime values produced by the evaluator. *)
type value =
  | ValInt of int
  | ValBool of bool
  | ValUnit
  | ValFun of Ast.rec_flag * Ast.pattern * Ast.expression * environment
  (** closure for [fun pat -> expr] with its captured environment; the flag is
      [NonRecursive] when created by [fun] and is switched to [Recursive] when
      the closure is bound by [let rec] *)
  | ValFunction of Ast.case list * environment
  (** closure for [function cases] with its captured environment *)
  | ValOption of value option
  | ValBuiltin of Ast.ident
  (** primitive dispatched by name at application time (e.g. ["print_int"]) *)

(** Variable bindings: a Base map keyed by identifier, ordered as strings. *)
and environment = (Ast.ident, value, Base.String.comparator_witness) Base.Map.t

(** [pp_value ppf v] prints [v] in OCaml-toplevel style. Closures print as
    [<fun>] / [<function>] and primitives as [<builtin>].

    The payload of [Some] is parenthesised when it is itself [Some _] or a
    negative integer, so nested options and negative numbers stay
    unambiguous ([Some (Some 1)], [Some (-1)]). *)
let rec pp_value ppf =
  let open Stdlib.Format in
  function
  | ValInt int -> fprintf ppf "%i" int
  | ValBool bool -> fprintf ppf "%b" bool
  | ValUnit -> fprintf ppf "()"
  | ValOption None -> fprintf ppf "None"
  | ValOption (Some payload) ->
    let needs_parens =
      match payload with
      | ValOption (Some _) -> true
      | ValInt n -> n < 0
      | ValBool _ | ValUnit | ValOption None | ValFun _ | ValFunction _ | ValBuiltin _ ->
        false
    in
    if needs_parens
    then fprintf ppf "Some (%a)" pp_value payload
    else fprintf ppf "Some %a" pp_value payload
  | ValFun _ -> fprintf ppf "<fun>"
  | ValFunction _ -> fprintf ppf "<function>"
  | ValBuiltin _ -> fprintf ppf "<builtin>"
;;

(** State/error monad whose state is the remaining evaluation-step budget. *)
module StepCounter = struct
  include StateR (struct
      type state = int
      type error = eval_error
    end)

  (** [tick] consumes one step from the budget held in the state, failing
      with [OutOfSteps] once the budget is no longer positive. *)
  let tick =
    let* remaining = get in
    if remaining > 0
    then put (remaining - 1) >>| (fun _ -> ())
    else fail OutOfSteps
  ;;
end

(** Operations on [environment] values. *)
module Env = struct
  open Base
  open StepCounter

  (** [extend env key value] binds [key] to [value], replacing any previous
      binding of [key]. *)
  let extend env key value = Map.set env ~key ~data:value

  (** [compose env1 env2] overlays every binding of [env2] on top of [env1];
      for a key bound in both, the binding from [env2] wins. *)
  let compose env1 env2 =
    Map.fold env2 ~init:env1 ~f:(fun ~key ~data env_acc -> extend env_acc key data)
  ;;

  (** [find_exn env key] looks [key] up inside the step-counting monad.
      Despite its name it does not raise: an unbound [key] becomes the
      monadic failure [NoVariable key]. *)
  let find_exn env key =
    match Map.find env key with
    | Some value -> return value
    | None -> fail (NoVariable key)
  ;;

  (** [find_exn1 env key] is a plain, non-monadic lookup.
      @raise Not_found_s (from [Base.Map.find_exn]) if [key] is unbound. *)
  let find_exn1 env key = Map.find_exn env key

  (** The environment with no bindings. *)
  let empty = Map.empty (module String)

  (** Initial environment: only [print_int], bound to its builtin. *)
  let env_with_print_funs = extend empty "print_int" (ValBuiltin "print_int")
end

(** Big-step evaluator for expressions and top-level structures. Every
    evaluated expression node spends one step of the [StepCounter] budget. *)
module Eval = struct
  open StepCounter
  open Env

  (** [eval_un_op (op, v)] applies a unary operator: [Negative] / [Positive]
      on integers, [Not] on booleans. Anything else fails with [TypeError]. *)
  let eval_un_op = function
    | Negative, ValInt val1 -> return (ValInt (-val1))
    | Positive, ValInt val1 -> return (ValInt val1)
    | Not, ValBool val1 -> return (ValBool (not val1))
    | _ -> fail TypeError
  ;;

  (** [compare_values v1 v2] orders two values of the same kind (integers,
      booleans, unit, and options of those), returning a negative, zero or
      positive integer; [None] sorts before [Some _].

      Closures, builtins and values of different kinds fail with [TypeError].
      Polymorphic [compare] / [=] is deliberately avoided: closures carry Base
      maps, which contain functional values, and comparing those can raise
      [Invalid_argument]. *)
  let rec compare_values v1 v2 =
    match v1, v2 with
    | ValInt a, ValInt b -> return (Stdlib.Int.compare a b)
    | ValBool a, ValBool b -> return (Stdlib.Bool.compare a b)
    | ValUnit, ValUnit -> return 0
    | ValOption None, ValOption None -> return 0
    | ValOption None, ValOption (Some _) -> return (-1)
    | ValOption (Some _), ValOption None -> return 1
    | ValOption (Some a), ValOption (Some b) -> compare_values a b
    | _ -> fail TypeError
  ;;

  (** [eval_comparison holds v1 v2] is [ValBool (holds c)], where [c] is the
      result of [compare_values v1 v2]. *)
  let eval_comparison holds v1 v2 =
    let* order = compare_values v1 v2 in
    return (ValBool (holds order))
  ;;

  (** [eval_bin_op (op, v1, v2)] applies a binary operator. Arithmetic needs
      two integers; a zero divisor fails with [DivisionByZero]. Comparisons
      go through [compare_values]. Any other combination fails with
      [TypeError]. *)
  let eval_bin_op = function
    | Mult, ValInt val1, ValInt val2 -> return (ValInt (val1 * val2))
    | Div, ValInt _, ValInt 0 -> fail DivisionByZero
    | Div, ValInt val1, ValInt val2 -> return (ValInt (val1 / val2))
    | Add, ValInt val1, ValInt val2 -> return (ValInt (val1 + val2))
    | Sub, ValInt val1, ValInt val2 -> return (ValInt (val1 - val2))
    | Gte, val1, val2 -> eval_comparison (fun c -> c >= 0) val1 val2
    | Lte, val1, val2 -> eval_comparison (fun c -> c <= 0) val1 val2
    | Neq, val1, val2 -> eval_comparison (fun c -> c <> 0) val1 val2
    | Eq, val1, val2 -> eval_comparison (fun c -> c = 0) val1 val2
    | Gt, val1, val2 -> eval_comparison (fun c -> c > 0) val1 val2
    | Lt, val1, val2 -> eval_comparison (fun c -> c < 0) val1 val2
    | _ -> fail TypeError
  ;;

  (** [match_pattern env (pat, value)] is [Some env'] — [env] extended with
      the variables bound by [pat] — when [value] matches [pat], and [None]
      otherwise. Type annotations inside patterns are ignored. *)
  let rec match_pattern env = function
    | Pat_any, _ -> Some env
    | Pat_var id, value -> Some (extend env id value)
    | Pat_constant (Const_int pat), ValInt value when pat = value -> Some env
    | Pat_constant (Const_bool pat), ValBool value when pat = value -> Some env
    | Pat_constant Const_unit, ValUnit -> Some env
    | Pat_constraint (_, pat), value -> match_pattern env (pat, value)
    | Pat_option None, ValOption None -> Some env
    | Pat_option (Some pat), ValOption (Some value) -> match_pattern env (pat, value)
    | _ -> None
  ;;

  (** [extend_names_from_pat env (pat, value)] binds the variables of a
      [let]-binding pattern. Unlike [match_pattern], a mismatch fails with
      [TypeError], and integer/boolean constant patterns are not accepted. *)
  let rec extend_names_from_pat (env : environment) = function
    | Pat_any, _ | Pat_constant Const_unit, ValUnit | Pat_option None, ValOption None ->
      return env
    | Pat_var id, value -> return (extend env id value)
    | Pat_constraint (_, pat), value | Pat_option (Some pat), ValOption (Some value) ->
      extend_names_from_pat env (pat, value)
    | _ -> fail TypeError
  ;;

  (** [eval_expression env ex] evaluates [ex] in [env], spending one step
      per visited node. *)
  let rec eval_expression env ex =
    let* () = tick in
    match ex with
    | Expr_ident id -> find_exn env id
    | Expr_const const ->
      (match const with
       | Const_int int -> return (ValInt int)
       | Const_bool bool -> return (ValBool bool)
       | Const_unit -> return ValUnit)
    | Expr_let (NonRecursive, value_binding, value_binding_list, exp) ->
      let* env = eval_value_binding_list env (value_binding :: value_binding_list) in
      eval_expression env exp
    | Expr_let (Recursive, value_binding, value_binding_list, exp) ->
      let* env = eval_rec_value_binding_list env (value_binding :: value_binding_list) in
      eval_expression env exp
    | Expr_fun (pat, exp) -> return (ValFun (NonRecursive, pat, exp, env))
    | Expr_function (case, case_list) -> return (ValFunction (case :: case_list, env))
    | Expr_match (exp, case, case_list) ->
      let* match_value = eval_expression env exp in
      find_and_eval_case env match_value (case :: case_list)
    | Expr_binop (op, exp1, exp2) ->
      let* value1 = eval_expression env exp1 in
      let* value2 = eval_expression env exp2 in
      eval_bin_op (op, value1, value2)
    | Expr_unop (op, e) ->
      let* v = eval_expression env e in
      eval_un_op (op, v)
    | Expr_apply (exp1, exp2) ->
      let* fun_val = eval_expression env exp1 in
      let* arg_val = eval_expression env exp2 in
      (match fun_val with
       | ValFun (rec_flag, pat, exp, fun_env) ->
         (* A [let rec]-bound closure does not contain its own name in its
            captured environment, so the body runs in the caller's [env]
            overlaid with the closure's bindings; that is how the name is
            found again.
            NOTE(review): this resolves names the closure did not capture in
            the caller's scope (dynamic scoping). A recursive function called
            where its name is not bound fails with [NoVariable]. *)
         let* new_env =
           match rec_flag, match_pattern fun_env (pat, arg_val) with
           | Recursive, Some extended_env -> return (compose env extended_env)
           | NonRecursive, Some extended_env -> return extended_env
           | _, None -> fail MatchFailure
         in
         eval_expression new_env exp
       | ValFunction (case_list, closure_env) ->
         find_and_eval_case closure_env arg_val case_list
       | ValBuiltin builtin ->
         (match builtin, arg_val with
          | "print_int", ValInt integer ->
            Format.printf "%d\n" integer;
            return ValUnit
          | _ -> fail TypeError)
       | _ -> fail TypeError)
    | Expr_option None -> return (ValOption None)
    | Expr_option (Some expr) ->
      let* value = eval_expression env expr in
      return (ValOption (Some value))
    | Expr_if (if_exp, then_exp, Some else_exp) ->
      let* value_if_exp = eval_expression env if_exp in
      (match value_if_exp with
       | ValBool true -> eval_expression env then_exp
       | ValBool false -> eval_expression env else_exp
       | _ -> fail TypeError)
    | Expr_if (fst_val, snd_val, None) ->
      (* Without an [else], the [then] branch must produce unit. *)
      let* value_fst_val = eval_expression env fst_val in
      (match value_fst_val with
       | ValBool true ->
         let* value_snd_val = eval_expression env snd_val in
         (match value_snd_val with
          | ValUnit as v -> return v
          | _ -> fail TypeError)
       | ValBool false -> return ValUnit
       | _ -> fail TypeError)
    | Expr_constraint (_, exp) -> eval_expression env exp

  (** [find_and_eval_case env value cases] evaluates the body of the first
      case whose pattern matches [value], or fails with [MatchFailure]. *)
  and find_and_eval_case env value = function
    | [] -> fail MatchFailure
    | { case_pat; case_expr } :: tail ->
      let env_temp = match_pattern env (case_pat, value) in
      (match env_temp with
       | Some env -> eval_expression env case_expr
       | None -> find_and_eval_case env value tail)

  (** [eval_value_binding_list env bindings] evaluates non-recursive
      bindings left to right.
      NOTE(review): each binding sees the ones before it, unlike OCaml's
      simultaneous [let ... and ...]. *)
  and eval_value_binding_list env value_binding_list =
    Base.List.fold_left
      ~f:(fun acc { vb_pat; vb_expr } ->
        let* env = acc in
        let* value = eval_expression env vb_expr in
        extend_names_from_pat env (vb_pat, value))
      ~init:(return env)
      value_binding_list

  (** [eval_rec_value_binding_list env bindings] evaluates [let rec]
      bindings left to right. Each binding must be a (possibly annotated)
      variable bound to a [fun] closure; the closure is re-tagged
      [Recursive] and added to the accumulated environment. Anything else
      fails with [TypeError]. *)
  and eval_rec_value_binding_list env value_binding_list =
    Base.List.fold_left
      ~f:(fun acc { vb_pat; vb_expr } ->
        let* env = acc in
        let* value = eval_expression env vb_expr in
        match vb_pat, value with
        | ( (Pat_var name | Pat_constraint (_, Pat_var name))
          , ValFun (_, pat, expr, closure_env) ) ->
          (* Extend the accumulated [env], not the closure's captured one:
             the latter may lack earlier bindings or contain locals of the
             bound expression. *)
          return (extend env name (ValFun (Recursive, pat, expr, closure_env)))
        | _ -> fail TypeError)
      ~init:(return env)
      value_binding_list
  ;;

  (** [eval_structure_item env out_list item] evaluates one top-level item
      and appends its results to [out_list]: [(None, v)] for an expression,
      and [(Some name, v)] for every variable bound by a [let]. *)
  let eval_structure_item env out_list =
    let rec extract_names_from_pat (env : environment) acc = function
      | Pat_var id -> acc @ [ Some id, find_exn1 env id ]
      | Pat_constraint (_, pat) -> extract_names_from_pat env acc pat
      | Pat_option (Some pat) -> extract_names_from_pat env acc pat
      | _ -> acc
    in
    let get_names_from_let_binds env =
      Base.List.fold_left ~init:[] ~f:(fun acc { vb_pat; _ } ->
        extract_names_from_pat env acc vb_pat)
    in
    function
    | Str_eval exp ->
      let* val' = eval_expression env exp in
      let* () = tick in
      return (env, out_list @ [ None, val' ])
    | Str_value (NonRecursive, value_binding, value_binding_list) ->
      let value_binding_list = value_binding :: value_binding_list in
      let* env = eval_value_binding_list env value_binding_list in
      let eval_list = get_names_from_let_binds env value_binding_list in
      let* () = tick in
      return (env, out_list @ eval_list)
    | Str_value (Recursive, value_binding, value_binding_list) ->
      let value_binding_list = value_binding :: value_binding_list in
      let* env = eval_rec_value_binding_list env value_binding_list in
      let eval_list = get_names_from_let_binds env value_binding_list in
      let* () = tick in
      return (env, out_list @ eval_list)
  ;;

  (** [eval_structure env ast] evaluates the top-level items in order and
      returns the final environment together with the collected results.
      For each name only its last binding is kept; unnamed expression
      results are all kept. Relative order is preserved. *)
  let eval_structure env ast =
    let* env, out_list =
      Base.List.fold_left
        ~f:(fun acc item ->
          let* env, out_list = acc in
          eval_structure_item env out_list item)
        ~init:(return (env, []))
        ast
    in
    (* Walk from the end, remember the names already kept, and drop every
       earlier re-binding of a kept name. *)
    let remove_duplicates items =
      let _seen, kept =
        Base.List.fold
          (Base.List.rev items)
          ~init:(Base.Set.empty (module Base.String), [])
          ~f:(fun (seen, kept) item ->
            match item with
            | Some id, _ when Base.Set.mem seen id -> seen, kept
            | Some id, _ -> Base.Set.add seen id, item :: kept
            | None, _ -> seen, item :: kept)
      in
      kept
    in
    return (env, remove_duplicates out_list)
  ;;
end

(** [run_interpreter structure n] evaluates [structure] with a budget of [n]
    steps, starting from the environment that provides [print_int]. It
    returns the collected top-level results or the evaluation error; the
    final environment and the remaining step count are discarded. *)
let run_interpreter structure n =
  let program = Eval.eval_structure Env.env_with_print_funs structure in
  let _, outcome = StepCounter.run program n in
  match outcome with
  | Error err -> Error err
  | Ok (_, value) -> Ok value
;;