Project

General

Profile

Statistics
| Branch: | Tag: | Revision:

lustrec / src / tools / seal / seal_extract.ml @ d75eb6f1

History | View | Annotate | Download (40.1 KB)

1
open Lustre_types
2
open Utils
3
open Seal_utils			
4
open Zustre_data (* Access to Z3 context *)
5
   
6

    
7
(* Switched system extraction: expression are memoized *)
8
(*let expr_mem = Hashtbl.create 13*)
9
             
10
type element = IsInit | Expr of expr
11
                              
12
type guard = (element * bool) list
13
type guarded_expr = guard * element
14
type mdef_t = guarded_expr list
15

    
16
let pp_elem fmt e =
17
  match e with
18
  | IsInit -> Format.fprintf fmt "init"
19
  | Expr e -> Format.fprintf fmt "%a" Printers.pp_expr e 
20

    
21
let pp_guard_list fmt gl =
22
  (fprintf_list ~sep:"; "
23
     (fun fmt (e,b) -> if b then pp_elem fmt e else Format.fprintf fmt "not(%a)" pp_elem e)) fmt gl
24
  
25
let pp_guard_expr fmt (gl,e) =
26
  Format.fprintf fmt "%a -> %a"
27
    pp_guard_list  gl
28
    pp_elem e
29

    
30
let pp_mdefs fmt gel = fprintf_list ~sep:"@ " pp_guard_expr fmt gel
31

    
32
  
33
let add_init defs vid =
34
  Hashtbl.add defs vid [[], IsInit]
35

    
36
let deelem e =  match e with
37
    Expr e -> e
38
  | IsInit -> assert false (* Wasn't expecting isinit here: we are building values! *)
39

    
40
let is_eq_elem elem elem' =
41
  match elem, elem' with
42
  | IsInit, IsInit -> true
43
  | Expr e, Expr e' -> e = e' (*
44
     Corelang.is_eq_expr e e' *)
45
  | _ -> false
46

    
47
let select_elem elem (gelem, _) =
48
  is_eq_elem elem gelem
49

    
50

    
51
(**************************************************************)
52
(* Convert from Lustre expressions to Z3 expressions and back *)
53
(* All (free) variables have to be declared in the Z3 context *)
54
(**************************************************************)
55

    
56
let is_init_name = "__is_init"
57

    
58
let const_defs = Hashtbl.create 13
59
let is_const id = Hashtbl.mem const_defs id
60
let get_const id = Hashtbl.find const_defs id
61
                 
62
(* expressions are only basic constructs here, no more ite, tuples,
63
   arrows, fby, ... *)
64

    
65
(* Set of hash to support memoization *)
66
let expr_hash: (expr * Utils.tag) list ref = ref []
67
let ze_hash: (Z3.Expr.expr, Utils.tag) Hashtbl.t = Hashtbl.create 13
68
let e_hash: (Utils.tag, Z3.Expr.expr) Hashtbl.t = Hashtbl.create 13
69
let pp_hash pp_key pp_v fmt h = Hashtbl.iter (fun key v -> Format.fprintf fmt "%a -> %a@ " pp_key key pp_v v) h
70
let pp_e_map fmt = List.iter (fun (e,t) -> Format.fprintf fmt "%i -> %a@ " t Printers.pp_expr e) !expr_hash
71
let pp_ze_hash fmt = pp_hash
72
                      (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))
73
                      Format.pp_print_int
74
                      fmt
75
                      ze_hash
76
let pp_e_hash fmt = pp_hash
77
                      Format.pp_print_int
78
                      (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))
79
                      fmt
80
                      e_hash                              
81
let mem_expr e =
82
  (* Format.eprintf "Searching for %a in map: @[<v 0>%t@]"
83
   *   Printers.pp_expr e
84
   *   pp_e_map; *)
85
  let res = List.exists (fun (e',_) -> Corelang.is_eq_expr e e') !expr_hash in
86
  (* Format.eprintf "found?%b@." res; *)
87
  res
88
  
89
let mem_zexpr ze =
90
  Hashtbl.mem ze_hash ze
91
let get_zexpr e =
92
  let eref, uid = List.find (fun (e',_) -> Corelang.is_eq_expr e e') !expr_hash in
93
  (* Format.eprintf "found expr=%a id=%i@." Printers.pp_expr eref eref.expr_tag; *)
94
  Hashtbl.find e_hash uid
95
let get_expr ze =
96
  let uid = Hashtbl.find ze_hash ze in
97
  let e,_ = List.find (fun (e,t) -> t = uid) !expr_hash in
98
  e
99
  
100
let neg_ze z3e = Z3.Boolean.mk_not !ctx z3e 
101
let is_init_z3e =
102
  Z3.Expr.mk_const_s !ctx is_init_name  Zustre_common.bool_sort 
103

    
104
let get_zid (ze:Z3.Expr.expr) : Utils.tag = 
105
  try
106
    if Z3.Expr.equal ze is_init_z3e then -1 else
107
      if Z3.Expr.equal ze (neg_ze is_init_z3e) then -2 else
108
    Hashtbl.find ze_hash ze
109
  with _ -> (Format.eprintf "Looking for ze %s in Hash %a"
110
               (Z3.Expr.to_string ze)
111
               (fun fmt hash -> Hashtbl.iter (fun ze uid -> Format.fprintf fmt "%s -> %i@ " (Z3.Expr.to_string ze) uid) hash ) ze_hash;
112
             assert false)
113
let add_expr =
114
  let cpt = ref 0 in
115
  fun e ze ->
116
  incr cpt;
117
  let uid = !cpt in
118
  expr_hash := (e, uid)::!expr_hash;
119
  Hashtbl.add e_hash uid ze;
120
  Hashtbl.add ze_hash ze uid
121

    
122
  
123
let expr_to_z3_expr, zexpr_to_expr =
124
  (* List to store converted expression. *)
125
  (* let hash = ref [] in
126
   * let comp_expr e (e', _) = Corelang.is_eq_expr e e' in
127
   * let comp_zexpr ze (_, ze') = Z3.Expr.equal ze ze' in *)
128
  
129
  let rec e2ze e =
130
    if mem_expr e then (
131
      get_zexpr e
132
    )
133
    else (
134
      let res =
135
        match e.expr_desc with
136
        | Expr_const c ->
137
           let z3e = Zustre_common.horn_const_to_expr c in
138
           add_expr e z3e;
139
           z3e
140
        | Expr_ident id -> (
141
          if is_const id then (
142
            let c = get_const id in
143
            let z3e = Zustre_common.horn_const_to_expr c in
144
            add_expr e z3e;
145
            z3e
146
        )
147
        else (
148
          let fdecl_id = Zustre_common.get_fdecl id in
149
          let z3e = Z3.Expr.mk_const_f !ctx fdecl_id in
150
          add_expr e z3e;
151
          z3e
152
          )
153
        )
154
      | Expr_appl (id,args, None) (* no reset *) ->
155
         let z3e = Zustre_common.horn_basic_app id e2ze (Corelang.expr_list_of_expr args) in
156
         add_expr e z3e;
157
         z3e
158
      | Expr_tuple [e] ->
159
         let z3e = e2ze e in
160
         add_expr e z3e;
161
         z3e
162
      | _ -> ( match e.expr_desc with Expr_tuple _ -> Format.eprintf "tuple e2ze(%a)@.@?" Printers.pp_expr e
163
                                    | _ -> Format.eprintf "e2ze(%a)@.@?" Printers.pp_expr e)
164
                 ; assert false
165
      in
166
      res
167
    )
168
  in
169
  let rec ze2e ze =
170
    let ze_name ze =
171
      let fd = Z3.Expr.get_func_decl ze in
172
      Z3.Symbol.to_string (Z3.FuncDecl.get_name fd)
173
    in
174
    if mem_zexpr ze then
175
      None, Some (get_expr ze)
176
    else
177
      let open Corelang in
178
      let fd = Z3.Expr.get_func_decl ze in
179
      let zel = Z3.Expr.get_args ze in
180
      match Z3.Symbol.to_string (Z3.FuncDecl.get_name fd), zel with
181
      (*      | var, [] -> (* should be in env *) get_e *)
182

    
183
      (* Extracting IsInit status *)
184
      | "not", [ze] when ze_name ze = is_init_name ->
185
         Some false, None
186
      | name, [] when name = is_init_name -> Some true, None
187
      (* Other constructs are converted to a lustre expression *)
188
      | op, _ -> (
189
        
190
        
191
        if Z3.Expr.is_numeral ze then
192
          let e =
193
            if Z3.Arithmetic.is_real ze then
194
              let num =  Num.num_of_ratio (Z3.Arithmetic.Real.get_ratio ze) in
195
              let s = Z3.Arithmetic.Real.numeral_to_string ze in
196
              mkexpr
197
                Location.dummy_loc
198
                (Expr_const
199
                   (Const_real
200
                      (num, 0, s)))
201
            else if Z3.Arithmetic.is_int ze then
202
              mkexpr
203
                Location.dummy_loc
204
                (Expr_const
205
                   (Const_int
206
                      (Big_int.int_of_big_int (Z3.Arithmetic.Integer.get_big_int ze))))
207
            else if Z3.Expr.is_const ze then
208
              match Z3.Expr.to_string ze with
209
              | "true" -> mkexpr Location.dummy_loc
210
                            (Expr_const (Const_tag (tag_true)))
211
              | "false" ->
212
                 mkexpr Location.dummy_loc
213
                   (Expr_const (Const_tag (tag_false)))
214
              | _ -> assert false
215
            else
216
              (
217
                Format.eprintf "Const err: %s %b@." (Z3.Expr.to_string ze) (Z3.Expr.is_const ze);
218
                assert false (* a numeral but no int nor real *)
219
              )
220
          in
221
          None, Some e
222
        else
223
          match op with
224
          | "not" | "=" | "-" | "*" | "/"
225
            | ">=" | "<=" | ">" | "<" 
226
            ->
227
             let args = List.map (fun ze -> Utils.desome (snd (ze2e ze))) zel in
228
             None, Some (mkpredef_call Location.dummy_loc op args)
229
          | "+" -> ( (* Special treatment of + for 2+ args *)
230
            let args = List.map (fun ze -> Utils.desome (snd (ze2e ze))) zel in
231
            let e = match args with
232
                [] -> assert false
233
              | [hd] -> hd
234
              | e1::e2::tl ->
235
                 let first_binary_and = mkpredef_call Location.dummy_loc op [e1;e2] in
236
                 if tl = [] then first_binary_and else
237
                   List.fold_left (fun e e_new ->
238
                       mkpredef_call Location.dummy_loc op [e;e_new]
239
                     ) first_binary_and tl
240
                 
241
            in
242
            None, Some e 
243
          )
244
          | "and" | "or" -> (
245
            (* Special case since it can contain is_init pred *)
246
            let args = List.map (fun ze -> ze2e ze) zel in
247
            let op = if op = "and" then "&&" else if op = "or" then "||" else assert false in 
248
            match args with
249
            | [] -> assert false
250
            | [hd] -> hd
251
            | hd::tl ->
252
               List.fold_left
253
                 (fun (is_init_opt1, expr_opt1) (is_init_opt2, expr_opt2) ->
254
                   (match is_init_opt1, is_init_opt2 with
255
                      None, x | x, None -> x
256
                      | Some _, Some _ -> assert false),
257
                   (match expr_opt1, expr_opt2 with
258
                    | None, x | x, None -> x
259
                    | Some e1, Some e2 ->
260
                       Some (mkpredef_call Location.dummy_loc op [e1; e2])
261
                 ))
262
                 hd tl 
263
          )
264
          | op -> 
265
             let args = List.map (fun ze ->  (snd (ze2e ze))) zel in
266
             Format.eprintf "deal with op %s (nb args: %i). Expr is %s@."  op (List.length args) (Z3.Expr.to_string ze); assert false
267
      )
268
  in
269
  (fun e -> e2ze e), (fun ze -> ze2e ze)
270

    
271
               
272
let zexpr_to_guard_list ze =
273
  let init_opt, expr_opt = zexpr_to_expr ze in
274
  (match init_opt with
275
   | None -> []
276
   |Some b -> [IsInit, b]
277
  ) @ (match expr_opt with
278
       | None -> []
279
       | Some e -> [Expr e, true]
280
      )
281
                
282
               
283
let simplify_neg_guard l =
284
  List.map (fun (g,posneg) ->
285
      match g with
286
      | IsInit -> g, posneg
287
      | Expr g ->
288
          if posneg then
289
            Expr (Corelang.push_negations g),
290
            true
291
          else
292
            (* Pushing the negation in the expression *)
293
            Expr(Corelang.push_negations ~neg:true g),
294
            true
295
    ) l 
296

    
297
(* TODO:
298
individuellement demander si g1 => g2. Si c'est le cas, on peut ne garder que g1 dans la liste
299
*)    
300
(*****************************************************************)
301
(* Checking sat(isfiability) of an expression and simplifying it *)
302
(* All (free) variables have to be declared in the Z3 context    *)
303
(*****************************************************************)
304
(*
305
let goal_simplify zl =
306
  let goal = Z3.Goal.mk_goal !ctx false false false in
307
  Z3.Goal.add goal zl;
308
  let goal' = Z3.Goal.simplify goal None in
309
  (* Format.eprintf "Goal before: %s@.Goal after : %s@.Sat? %s@."
310
   *   (Z3.Goal.to_string goal)
311
   *   (Z3.Goal.to_string goal')
312
   *   (Z3.Solver.string_of_status status_res) 
313
   * ; *)
314
  let ze = Z3.Goal.as_expr goal' in
315
  (* Format.eprintf "as an expr: %s@." (Z3.Expr.to_string ze); *)
316
  zexpr_to_guard_list ze
317
  *)
318
  
319
let implies =
320
  let ze_implies_hash : ((Utils.tag * Utils.tag), bool) Hashtbl.t  = Hashtbl.create 13 in
321
  fun ze1 ze2 ->
322
  let ze1_uid = get_zid ze1 in
323
  let ze2_uid = get_zid ze2 in
324
  if Hashtbl.mem ze_implies_hash (ze1_uid, ze2_uid) then
325
    Hashtbl.find ze_implies_hash (ze1_uid, ze2_uid)
326
  else
327
    begin
328
      if !debug then (
329
        Format.eprintf "Checking implication: %s => %s? "
330
          (Z3.Expr.to_string ze1) (Z3.Expr.to_string ze2)
331
      ); 
332
      let solver = Z3.Solver.mk_simple_solver !ctx in
333
      let tgt = Z3.Boolean.mk_not !ctx (Z3.Boolean.mk_implies !ctx ze1 ze2) in
334
      let res =
335
        try
336
          let status_res = Z3.Solver.check solver [tgt] in
337
          match status_res with
338
          | Z3.Solver.UNSATISFIABLE -> if !debug then Format.eprintf "Valid!@."; 
339
             true
340
          | _ -> if !debug then Format.eprintf "not proved valid@."; 
341
             false
342
        with Zustre_common.UnknownFunction(id, msg) -> (
343
          report ~level:1 msg;
344
          false
345
        )
346
      in
347
      Hashtbl.add ze_implies_hash (ze1_uid,ze2_uid) res ;
348
      res
349
    end
350
                                               
351
let rec simplify zl =
352
  match zl with
353
  | [] | [_] -> zl
354
  | hd::tl -> (
355
    (* Forall e in tl, checking whether hd => e or e => hd, to keep hd
356
     in the first case and e in the second one *)
357
    let tl = simplify tl in
358
    let keep_hd, tl =
359
      List.fold_left (fun (keep_hd, accu) e ->
360
          if implies hd e then
361
            true, accu (* throwing away e *)
362
          else if implies e hd then
363
            false, e::accu (* throwing away hd *)
364
          else
365
            keep_hd, e::accu (* keeping both *)
366
        ) (true,[]) tl
367
    in
368
    (* Format.eprintf "keep_hd?%b hd=%s, tl=[%a]@."
369
     *   keep_hd
370
     *   (Z3.Expr.to_string hd)
371
     * (Utils.fprintf_list ~sep:"; " (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))) tl
372
     *   ; *)
373
    if keep_hd then
374
      hd::tl
375
    else
376
      tl
377
  )
378
  
379
let check_sat ?(just_check=false) (l:guard) : bool * guard =
380
  (* Syntactic simplification *)
381
  if false then
382
    Format.eprintf "Before simplify: %a@." pp_guard_list l; 
383
  let l = simplify_neg_guard l in
384
  if false then (
385
    Format.eprintf "After simplify: %a@." pp_guard_list l; 
386
    Format.eprintf "@[<v 2>Z3 check sat: [%a]@ " pp_guard_list l;
387
  );
388
  let solver = Z3.Solver.mk_simple_solver !ctx in
389
  try (
390
    let zl =
391
      List.map (fun (e, posneg) ->
392
          let ze =
393
            match e with
394
            | IsInit -> is_init_z3e
395
            | Expr e -> expr_to_z3_expr e 
396
          in
397
          if posneg then
398
            ze
399
          else
400
            neg_ze ze
401
        ) l
402
    in
403
    if false then Format.eprintf "Z3 exprs1: [%a]@ " (fprintf_list ~sep:",@ " (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))) zl; 
404
    let zl = simplify zl in
405
        if false then Format.eprintf "Z3 exprs2: [%a]@ " (fprintf_list ~sep:",@ " (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))) zl; 
406
    let status_res = Z3.Solver.check solver zl in
407
     if false then Format.eprintf "Z3 status: %s@ @]@. " (Z3.Solver.string_of_status status_res); 
408
    match status_res with
409
    | Z3.Solver.UNSATISFIABLE -> false, []
410
    | _ -> (
411
      if false && just_check then
412
        true, l
413
      else
414
        (* TODO: may be reactivate but it may create new expressions *)
415
        (* let l = goal_simplify zl in *)
416
        let l = List.fold_left
417
                  (fun accu ze -> accu @ (zexpr_to_guard_list ze))
418
                  []
419
                  zl
420
        in
421
  (* Format.eprintf "@.@[<v 2>Check_Sat:@ before: %a@ after:
422
       %a@. Goal precise? %b/%b@]@.@. " * pp_guard_list l
423
       pp_guard_list l' * (Z3.Goal.is_precise goal) *
424
       (Z3.Goal.is_precise goal'); *)
425
  
426

    
427
        true, l
428
        
429
         
430
    )
431
         
432
  )
433
  with Zustre_common.UnknownFunction(id, msg) -> (
434
    report ~level:1 msg;
435
    true, l (* keeping everything. *)
436
  )
437
      
438
  
439

    
440

    
441
(**************************************************************)
442

    
443
  
444
let clean_sys sys =
445
  List.fold_left (fun accu (guards, updates) ->
446
      let sat, guards' =  check_sat (List.map (fun (g, pn) -> Expr g, pn) guards) in
447
      (*Format.eprintf "Guard: %a@.Guard cleaned: %a@.Sat? %b@."
448
        (fprintf_list ~sep:"@ " (pp_guard_expr Printers.pp_expr))  guards
449
        (fprintf_list ~sep:"@ " (pp_guard_expr Printers.pp_expr))  guards'
450
        sat
451

    
452
      ;*)
453
        if sat then
454
        (List.map (fun (e, b) -> (deelem e, b)) guards', updates)::accu
455
      else
456
        accu
457
    )
458
    [] sys
459

    
460
(* Most costly function: has the be efficiently implemented.  All
461
   registered guards are initially produced by the call to
462
   combine_guards. We csan normalize the guards to ease the
463
   comparisons.
464

    
465
   We assume that gl1 and gl2 are already both satisfiable and in a
466
   kind of reduced form. Let lshort and llong be the short and long
467
   list of gl1, gl2. We check whether each element elong of llong is
468
   satisfiable with lshort. If no, stop. If yes, we search to reduce
469
   the list. If elong => eshort_i, we can remove eshort_i from
470
   lshort. we can continue with this lshort shortened, lshort'; it is
471
   not necessary to add yet elong in lshort' since we already know
472
   rthat elong is somehow reduced with respect to other elements of
473
   llong. If eshort_i => elong, then we keep ehosrt_i in lshort and do
474
   not store elong.
475

    
476
   After iterating through llong, we have a shortened version of
477
   lshort + some elements of llong that have to be remembered. We add
478
   them to this new consolidated list. 
479

    
480
 *)
481

    
482
(* combine_guards ~fresh:Some(e,b) gl1 gl2 returns ok, gl with ok=true
483
   when (e=b) ang gl1 and gl2 is satisfiable and gl is a consilidated
484
   version of it.  *)
485
let combine_guards ?(fresh=None) gl1 gl2 =
486
  (* Filtering out trivial cases. More semantics ones would have to be
487
     addressed later *)
488
  let check_sat e = (* temp function before we clean the original one *)
489
    let ok, _ = check_sat e in
490
    ok
491
  in
492
  let implies (e1,pn1) (e2,pn2) =
493
    let e2z e pn =
494
      match e with
495
        | IsInit -> if pn then is_init_z3e else neg_ze is_init_z3e
496
        | Expr e -> expr_to_z3_expr (if pn then e else (Corelang.push_negations ~neg:true e))
497
      in
498
    implies (e2z e1 pn1) (e2z e2 pn2)
499
  in
500
  let lshort, llong =
501
    if List.length gl1 > List.length gl2 then gl2, gl1 else gl1, gl2
502
  in
503
  let merge long short =
504
    let short, long_sel, ok = 
505
    List.fold_left (fun (short,long_sel, ok) long_e ->
506
        if not ok then
507
          [],[], false (* Propagating unsat case *)
508
        else if check_sat (long_e::short) then
509
          let short, keep_long_e =
510
            List.fold_left (fun (accu_short, keep_long_e) eshort_i ->
511
                if not keep_long_e then (* shorten the algo *)
512
                  eshort_i :: accu_short, false
513
                else (* keep_long_e = true in the following *)
514
                  if implies eshort_i long_e then
515
                    (* First case is trying to remove long_e!
516

    
517
                     Since short is already normalized, we can remove
518
                     long_e. If later long_e is stronger than another
519
                     element of short, then necessarily eshort_i =>
520
                     long_e ->
521
                     that_other_element_of_short. Contradiction. *)
522
                    eshort_i::accu_short, false
523
                  else if implies long_e eshort_i then 
524
                    (* removing eshort_i, keeping long_e. *)
525
                    accu_short, true
526
                  else (* Not comparable, keeping both *)
527
                    eshort_i::accu_short, true
528
              )
529
              ([],true) (* Initially we assume that we will keep long_e *)
530
              short
531
          in
532
          if keep_long_e then
533
            short, long_e::long_sel, true
534
          else
535
            short, long_sel, true
536
        else
537
          [],[],false
538
      ) (short, [], true) long
539
    in
540
    ok, long_sel@short
541
  in
542
  let ok, l = match fresh with
543
    | None -> true, []
544
    | Some g -> merge [g] []
545
  in
546
  if not ok then
547
    false, []
548
  else
549
    let ok, lshort = merge lshort l in
550
    if not ok then
551
      false, []
552
    else
553
      merge llong lshort
554
    
555

    
556
(* Encode "If gel1=posneg then gel2":
557
   - Compute the combination of guarded exprs in gel1 and gel2:
558
     - Each guarded_expr in gel1 is transformed as a guard: the
559
       expression is associated to posneg.
560
     - Existing guards in gel2 are concatenated to that list of guards
561
     - We keep expr in the ge of gel2 as the legitimate expression 
562
 *)
563
let concatenate_ge gel1 posneg gel2 =
564
  let l, all_invalid =
565
    List.fold_left (
566
        fun (accu, all_invalid) (g2,e2) ->
567
        List.fold_left (
568
            fun (accu, all_invalid) (g1,e1) ->
569
            (* Format.eprintf "@[<v 2>Combining guards: (%a=%b) AND [%a] AND [%a]@ "
570
             *  pp_elem e1
571
             *  posneg
572
             *  pp_guard_list g1
573
             *  pp_guard_list g2; *)
574

    
575
            let ok, gl = combine_guards ~fresh:(Some(e1,posneg)) g1 g2 in
576
            (* Format.eprintf "@]@ Result is [%a]@ "
577
             *   pp_guard_list gl; *)
578

    
579
            if ok then
580
              (gl, e2)::accu, false
581
            else (
582
              accu, all_invalid
583
            )
584
          ) (accu, all_invalid) gel1
585
      ) ([], true) gel2
586
  in
587
  not all_invalid, l
588
     
589
(* Rewrite the expression expr, replacing any occurence of a variable
590
   by its definition.
591
*)
592
let rec rewrite defs expr : guarded_expr list =
593
  let rewrite = rewrite defs in
594
  let res =
595
    match expr.expr_desc with
596
    | Expr_appl (id, args, None) ->
597
       let args = rewrite args in
598
       List.map (fun (guards, e) ->
599
           guards,
600
           Expr (Corelang.mkexpr expr.expr_loc (Expr_appl(id, deelem e, None)))
601
         ) args 
602
    | Expr_const _  -> [[], Expr expr]
603
    | Expr_ident id ->
604
       if Hashtbl.mem defs id then
605
         Hashtbl.find defs id
606
       else
607
         (* id should be an input *)
608
         [[], Expr expr]
609
    | Expr_ite (g, e1, e2) ->
610
       let g = rewrite g and
611
           e1 = rewrite e1 and
612
           e2 = rewrite e2 in
613
       let ok_then, g_then = concatenate_ge g true e1 in
614
       let ok_else, g_else = concatenate_ge g false e2 in
615
       (if ok_then then g_then else [])@
616
         (if ok_else then g_else else [])
617
    | Expr_merge (g, branches) ->
618
       assert false (* TODO: deal with merges *)
619
      
620
    | Expr_when (e, _, _) -> rewrite e
621
    | Expr_arrow _ -> [[], IsInit] (* At this point the only arrow should be true -> false *)
622
    | Expr_tuple el ->
623
       (* Each expr is associated to its flatten guarded expr list *)
624
       let gell = List.map rewrite el in
625
       (* Computing all combinations: we obtain a list of guarded tuple *)
626
       let rec aux gell : (guard * expr list) list =
627
         match gell with
628
         | [] -> assert false (* Not happening *)
629
         | [gel] -> List.map (fun (g,e) -> g, [deelem e]) gel
630
         | gel::getl ->
631
            let getl = aux getl in
632
            List.fold_left (
633
                fun accu (g,e) ->
634
                List.fold_left (
635
                    fun accu (gl, minituple) ->
636
                    let is_compat, guard_comb = combine_guards g gl in
637
                    if is_compat then
638
                      let new_gt : guard * expr list = (guard_comb, (deelem e)::minituple) in
639
                      new_gt::accu
640
                    else
641
                      accu
642
                    
643
                  ) accu getl
644
              ) [] gel
645
       in
646
       let gtuples = aux gell in
647
       (* Rebuilding the valid type: guarded expr list (with tuple exprs) *)
648
       List.map
649
         (fun (g,tuple) -> g, Expr (Corelang.mkexpr expr.expr_loc (Expr_tuple tuple)))
650
         gtuples
651
    | Expr_fby _
652
      | Expr_appl _
653
                  (* Should be removed by mormalization and inlining *)
654
      -> Format.eprintf "Pb expr: %a@.@?" Printers.pp_expr expr; assert false
655
    | Expr_array _ | Expr_access _ | Expr_power _
656
                                                (* Arrays not handled here yet *)
657
      -> assert false
658
    | Expr_pre _ -> (* Not rewriting mem assign *)
659
       assert false
660
  in
661
  (* Format.eprintf "Rewriting %a as [@[<v 0>%a@]]@ "
662
   *   Printers.pp_expr expr
663
   *   (Utils.fprintf_list ~sep:"@ "
664
   *        pp_guard_expr) res; *)
665
  res
666
and add_def defs vid expr =
667
  let vid_defs = rewrite defs expr in
668
  (* Format.eprintf "Add_def: %s = %a@. -> @[<v 0>%a@]@."
669
   *     vid
670
   *     Printers.pp_expr expr
671
   *     (
672
   *       (Utils.fprintf_list ~sep:"@ "
673
   *          pp_guard_expr)) vid_defs;  *)
674
  Hashtbl.add defs vid vid_defs
675

    
676
(* Takes a list of guarded exprs (ge) and a guard
677
returns the same list of ge splited into the ones where the guard is true and the ones where it is false. In both lists the associated ge do not mention that guard anymore.
678

    
679
When a given ge doesn't mention positively or negatively such guards, it is duplicated in both lists *)
680
let split_mdefs elem (mdefs: guarded_expr list) =
681
  List.fold_left (
682
      fun (selected, left_out)
683
          ((guards, expr) as ge) ->
684
      (* select the element of guards that match the argument elem *)
685
      let sel, others_guards = List.partition (select_elem elem) guards in
686
      match sel with
687
      (* we extract the element from the list and add it to the
688
         appropriate list *)
689
      | [_, sel_status] ->
690
         if sel_status then
691
           (others_guards,expr)::selected, left_out
692
         else selected, (others_guards,expr)::left_out
693
      | [] -> (* no such guard exists, we have to duplicate the
694
              guard_expr in both lists *)
695
         ge::selected, ge::left_out
696
      | _ -> (
697
        Format.eprintf "@.Spliting list on elem %a.@.List:%a@."
698
          pp_elem elem
699
          pp_mdefs mdefs;
700
        assert false (* more then one element selected. Should
701
                          not happen , or trival dead code like if
702
                              x then if not x then dead code *)
703
      )
704
    ) ([],[]) mdefs
705
    
706
let split_mem_defs
707
      (elem: element)
708
      (mem_defs: (ident * guarded_expr list) list)
709
      :
710
      ((ident * mdef_t) list) * ((ident * mdef_t) list)
711
  
712
  =
713
  List.fold_right (fun (m,mdefs)
714
                       (accu_pos, accu_neg) ->
715
      let pos, neg =
716
        split_mdefs elem mdefs 
717
      in
718
      (m, pos)::accu_pos,
719
      (m, neg)::accu_neg
720
    ) mem_defs ([],[])
721

    
722

    
723
(* Split a list of mem_defs into init and step lists of guarded
724
   expressions per memory. *)
725
let split_init mem_defs =
726
  split_mem_defs IsInit mem_defs 
727

    
728
(* Previous version of the function: way too costly 
729
let pick_guard mem_defs : expr option =  
730
  let gel = List.flatten (List.map snd mem_defs) in
731
  let gl = List.flatten (List.map fst gel) in
732
  let all_guards =
733
    List.map (
734
        (* selecting guards and getting rid of boolean *)
735
        fun (e,b) ->
736
        match e with
737
        | Expr e -> e
738
        | _ -> assert false
739
      (* should have been filtered out
740
                                      yet *)
741
      ) gl
742
  in
743
  (* TODO , one could sort by occurence and provided the most common
744
     one *)
745
  try
746
  Some (List.hd all_guards)  
747
  with _ -> None
748
   *)
749

    
750
(* Returning the first non empty guard expression *)
751
let rec pick_guard mem_defs : expr option =
752
  match mem_defs with
753
  | [] -> None
754
  | (_, gel)::tl -> (
755
    let found =
756
      List.fold_left (fun found (g,_) ->
757
          if found = None then
758
            match g with
759
            | [] -> None
760
            | (Expr e, _)::_ -> Some e
761
            | (IsInit, _)::_ -> assert false (* should be removed already *)
762
          else
763
            found
764
        ) None gel
765
    in
766
    if found = None then pick_guard tl else found
767
  )
768
          
769

    
770
(* Transform a list of variable * guarded exprs into a list of guarded pairs (variable, expressions)
771
*)
772
let rec build_switch_sys
773
          (mem_defs : (Utils.ident * guarded_expr list) list )
774
          prefix
775
        :
776
          ((expr * bool) list * (ident * expr) list ) list =
777
  Format.eprintf "Build_switch with %a@."
778
    (Utils.fprintf_list ~sep:",@ "
779
       (fun fmt (id, gel) -> Format.fprintf fmt "%s -> [@[<v 0>%a@ ]@]"
780
                               id
781
                               pp_mdefs gel))
782
    mem_defs;
783
  (* if all mem_defs have empty guards, we are done, return prefix,
784
     mem_defs expr.
785

    
786
     otherwise pick a guard in one of the mem, eg (g, b) then for each
787
     other mem, one need to select the same guard g with the same
788
     status b, *)
789
  if List.for_all (fun (m,mdefs) ->
790
         (* All defs are unguarded *)
791
         match mdefs with
792
         | [[], _] -> true
793
         | _ -> false
794
       ) mem_defs
795
  then
796
    [prefix ,
797
     List.map (fun (m,gel) ->
798
         match gel with
799
         | [_,e] ->
800
            let e =
801
              match e with
802
              | Expr e -> e
803
              | _ -> assert false (* No IsInit expression *)
804
            in
805
            m,e
806
         | _ -> assert false
807
       ) mem_defs]
808
  else
809
    (* Picking a guard *)
810
    let elem_opt : expr option = pick_guard mem_defs in
811
    match elem_opt with
812
      None -> []
813
    | Some elem -> (
814
      let pos, neg =
815
        split_mem_defs
816
          (Expr elem)
817
          mem_defs
818
      in
819
      (* Special cases to avoid useless computations: true, false conditions *)
820
      match elem.expr_desc with
821
      | Expr_ident "true"  ->   build_switch_sys pos prefix
822
      | Expr_const (Const_tag tag) when tag = Corelang.tag_true
823
        ->   build_switch_sys pos prefix
824
      | Expr_ident "false" ->   build_switch_sys neg prefix
825
      | Expr_const (Const_tag tag) when tag = Corelang.tag_false 
826
        ->   build_switch_sys neg prefix
827
      | _ -> (* Regular case *)
828
         (* let _ = (
829
          *     Format.eprintf "Expr is %a@." Printers.pp_expr elem;
830
          *     match elem.expr_desc with
831
          *     | Expr_const _ -> Format.eprintf "a const@."
832
          *                     
833
          *     | Expr_ident _ -> Format.eprintf "an ident@."
834
          *     | _ -> Format.eprintf "something else@."
835
          *   )
836
          * in *)
837
         let clean l =
838
           let l = List.map (fun (e,b) -> (Expr e), b) l in
839
           let ok, l = check_sat l in
840
           let l = List.map (fun (e,b) -> deelem e, b) l in
841
           ok, l
842
         in
843
         let pos_prefix = (elem, true)::prefix in
844
         let neg_prefix = (elem, false)::prefix in
845
         let ok_pos, pos_prefix = clean pos_prefix in         
846
         let ok_neg, neg_prefix = clean neg_prefix in         
847
         (if ok_pos then build_switch_sys pos pos_prefix else [])
848
         @
849
           (if ok_neg then build_switch_sys neg neg_prefix else [])
850
    )
851

    
852

    
853
      
854
(* Take a normalized node and extract a list of switches: (cond,
855
   update) meaning "if cond then update" where update shall define all
856
   node memories. Everything as to be expressed over inputs or memories, intermediate variables are removed through inlining *)
857
let node_as_switched_sys consts (mems:var_decl list) nd =
858
  (* rescheduling node: has been scheduled already, no need to protect
859
     the call to schedule_node *)
860
  let nd_report = Scheduling.schedule_node nd in
861
  let schedule = nd_report.Scheduling_type.schedule in
862
  let eqs, auts = Corelang.get_node_eqs nd in
863
  assert (auts = []); (* Automata should be expanded by now *)
864
  let sorted_eqs = Scheduling.sort_equations_from_schedule eqs schedule in
865
  let defs : (ident,  guarded_expr list) Hashtbl.t = Hashtbl.create 13 in
866
  let add_def = add_def defs in
867

    
868
  let vars = Corelang.get_node_vars nd in
869
  (* Registering all locals variables as Z3 predicates. Will be use to
870
     simplify the expansion *) 
871
  let _ =
872
    List.iter (fun v ->
873
        let fdecl = Z3.FuncDecl.mk_func_decl_s
874
                      !ctx
875
                      v.var_id
876
                      []
877
                      (Zustre_common.type_to_sort v.var_type)
878
        in
879
        ignore (Zustre_common.register_fdecl v.var_id fdecl)
880
      ) vars
881
  in
882
  let _ =
883
    List.iter (fun c -> Hashtbl.add const_defs c.const_id c.const_value) consts
884
  in
885

    
886
  
887
  (* Registering node equations: identifying mem definitions and
888
     storing others in the "defs" hashtbl.
889

    
890
     Each assign is stored in a hash tbl as list of guarded
891
     expressions. The memory definition is also "rewritten" as such a
892
     list of guarded assigns.  *)
893
  let mem_defs, output_defs =
894
    List.fold_left (fun (accu_mems, accu_outputs) eq ->
895
        match eq.eq_lhs with
896
        | [vid] ->
897
           (* Only focus on memory definitions *)
898
           if List.exists (fun v -> v.var_id = vid) mems then 
899
             (
900
               match eq.eq_rhs.expr_desc with
901
               | Expr_pre def_m ->
902
                  report ~level:3 (fun fmt ->
903
                      Format.fprintf fmt "Preparing mem %s@." vid);
904
                  (vid, rewrite defs def_m)::accu_mems, accu_outputs
905
               | _ -> assert false
906
             ) 
907
           else if List.exists (fun v -> v.var_id = vid) nd.node_outputs then (
908
             report ~level:3 (fun fmt ->
909
                 Format.fprintf fmt "Output variable %s@." vid);
910
             add_def vid eq.eq_rhs;
911
             accu_mems, (vid, rewrite defs eq.eq_rhs)::accu_outputs
912
          
913
           )
914
           else
915
             (
916
             report ~level:3 (fun fmt ->
917
                 Format.fprintf fmt "Registering variable %s@." vid);
918
             add_def vid eq.eq_rhs;
919
             accu_mems, accu_outputs
920
           )
921
        | _ -> assert false (* should have been removed by normalization *)
922
      ) ([], []) sorted_eqs
923
  in
924

    
925
  
926
  report ~level:2 (fun fmt -> Format.fprintf fmt "Printing out (guarded) memories definitions (may takes time)@.");
927
  (* Printing memories definitions *)
928
  report ~level:3
929
    (fun fmt ->
930
      Format.fprintf fmt
931
        "@[<v 0>%a@]@ "
932
        (Utils.fprintf_list ~sep:"@ "
933
           (fun fmt (m,mdefs) ->
934
             Format.fprintf fmt
935
               "%s -> [@[<v 0>%a@] ]@ "
936
               m
937
               (Utils.fprintf_list ~sep:"@ "
938
                  pp_guard_expr) mdefs
939
        ))
940
        mem_defs);
941
  (* Format.eprintf "Split init@."; *)
942
  let init_defs, update_defs =
943
    split_init mem_defs 
944
  in
945
  let init_out, update_out =
946
    split_init output_defs
947
  in
948
  report ~level:3
949
    (fun fmt ->
950
      Format.fprintf fmt
951
        "@[<v 0>Init:@ %a@ Step@ %a@]@ "
952
        (Utils.fprintf_list ~sep:"@ "
953
           (fun fmt (m,mdefs) ->
954
             Format.fprintf fmt
955
               "%s -> @[<v 0>[%a@] ]@ "
956
               m
957
               (Utils.fprintf_list ~sep:"@ "
958
                  pp_guard_expr) mdefs
959
        ))
960
        init_defs
961
        (Utils.fprintf_list ~sep:"@ "
962
           (fun fmt (m,mdefs) ->
963
             Format.fprintf fmt
964
               "%s -> @[<v 0>[%a@] ]@ "
965
               m
966
               (Utils.fprintf_list ~sep:"@ "
967
                  pp_guard_expr) mdefs
968
        ))
969
        update_defs);
970
  (* Format.eprintf "Build init@."; *)
971

    
972
  let sw_init= 
973
    build_switch_sys init_defs []
974
  in
975
  (* Format.eprintf "Build step@."; *)
976
  let sw_sys =
977
    build_switch_sys update_defs []
978
  in
979

    
980
  let init_out =
981
    build_switch_sys init_out []
982
  in
983
  let update_out =
984
    build_switch_sys update_out []
985
  in
986

    
987
  let sw_init = clean_sys sw_init in
988
  let sw_sys = clean_sys sw_sys in
989
  let init_out = clean_sys init_out in
990
  let update_out = clean_sys update_out in
991

    
992
  (* Some additional checks *)
993
  let pp_gl pp_expr =
994
    fprintf_list ~sep:", " (fun fmt (e,b) -> Format.fprintf fmt "%s%a" (if b then "" else "NOT ") pp_expr e)
995
  in
996
  let pp_gl_short = pp_gl (fun fmt e -> Format.fprintf fmt "%i" e.Lustre_types.expr_tag) in
997
  let pp_up fmt up = List.iter (fun (id,e) -> Format.fprintf fmt "%s->%i; " id e.expr_tag) up in
998
  
999
  if false then
1000
    begin
1001
      Format.eprintf "@.@.CHECKING!!!!!!!!!!!@.";
1002
      Format.eprintf "Any duplicate expression in guards?@.";
1003
      
1004
      let sw_sys =
1005
        List.map (fun (gl, up) ->
1006
            let gl = List.sort (fun (e,b) (e',b') ->
1007
                         let res = compare e.expr_tag e'.expr_tag in
1008
                         if res = 0 then (Format.eprintf "Same exprs?@.%a@.%a@.@."
1009
                                            Printers.pp_expr e
1010
                                            Printers.pp_expr e'
1011
                                         );
1012
                         res
1013
                       ) gl in
1014
            gl, up
1015
          ) sw_sys 
1016
      in
1017
      Format.eprintf "Another check for duplicates in guard list@.";
1018
      List.iter (fun (gl, _) ->
1019
          let rec aux hd l =
1020
            match l with
1021
              [] -> ()
1022
            | (e,b)::tl -> let others = hd@tl in
1023
                           List.iter (fun (e',_) -> if Corelang.is_eq_expr e e' then
1024
                                                      (Format.eprintf "Same exprs?@.%a@.%a@.@."
1025
                                                         Printers.pp_expr e
1026
                                                         Printers.pp_expr e'
1027
                             )) others;
1028
                           aux ((e,b)::hd) tl
1029
          in
1030
          aux [] gl
1031
        ) sw_sys;
1032
      Format.eprintf "Checking duplicates in updates@.";
1033
      let rec check_dup_up accu l =
1034
        match l with
1035
        | [] -> ()
1036
        | ((gl, up) as hd)::tl ->
1037
           let others = accu@tl in
1038
           List.iter (fun (gl',up') -> if up = up' then
1039
                                         Format.eprintf "Same updates?@.%a@.%a@.%a@.%a@.@."
1040

    
1041
                                           pp_gl_short gl
1042
                                           pp_up up
1043
                                           pp_gl_short gl'
1044
                                           pp_up up'
1045
                                       
1046
             ) others;
1047
           
1048
           
1049

    
1050
           check_dup_up (hd::accu) tl
1051
           
1052
      in
1053
      check_dup_up [] sw_sys;
1054
      let sw_sys =
1055
        List.sort (fun (gl1, _) (gl2, _) ->
1056
            let glid gl = List.map (fun (e,_) -> e.expr_tag) gl in
1057
            
1058
            let res = compare (glid gl1) (glid gl2) in
1059
            if res = 0 then Format.eprintf "Same guards?@.%a@.%a@.@."
1060
                              pp_gl_short gl1 pp_gl_short gl2
1061
            ;
1062
              res
1063

    
1064
          ) sw_sys
1065

    
1066
      in
1067
      ()
1068
    end;
1069
  
1070

    
1071
  (* Iter through the elements and gather them by updates *)
1072
  let module UpMap =
1073
    struct
1074
      include Map.Make (
1075
                  struct
1076
                    type t = (ident * expr) list
1077
                    let compare l1 l2 =
1078
                      let proj l = List.map (fun (s,e) -> s, e.expr_tag) l in
1079
                      compare (proj l1) (proj l2) 
1080
                  end)
1081
      let pp = pp_up 
1082
    end
1083
  in
1084
  let module Guards = struct
1085
      include Set.Make (
1086
                  struct
1087
                    type t = (expr * bool) 
1088
                    let compare l1 l2 =
1089
                      let proj (e,b) = e.expr_tag, b in
1090
                      compare (proj l1) (proj l2) 
1091
                  end)
1092
      let pp_short fmt s = pp_gl_short fmt (elements s)
1093
      let pp_long fmt s = pp_gl Printers.pp_expr fmt (elements s)
1094
    end
1095
  in
1096
  let process sys =
1097
    (* The map will associate to each update up the pair (set, set
1098
       list) where set is the share guards and set list a list of
1099
       disjunctive guards. Each set represents a conjunction of
1100
       expressions. *)
1101
    let map = 
1102
      List.fold_left (fun map (gl,up) ->
1103
          (* creating a new set to describe gl *)
1104
          let new_set =
1105
            List.fold_left
1106
              (fun set g -> Guards.add g set)
1107
              Guards.empty
1108
              gl
1109
          in
1110
          (* updating the map with up -> new_set *)
1111
          if UpMap.mem up map then
1112
            let (shared, disj) = UpMap.find up map in
1113
            let new_shared = Guards.inter shared new_set in
1114
            let remaining_shared = Guards.diff shared new_shared in
1115
            let remaining_new_set = Guards.diff new_set new_shared in
1116
            (* Adding remaining_shared to all elements of disj *)
1117
            let disj' = List.map (fun gs -> Guards.union remaining_shared gs) disj in
1118
            UpMap.add up (new_shared, remaining_new_set::disj') map
1119
          else
1120
            UpMap.add up (new_set, []) map
1121
        ) UpMap.empty sys
1122
    in
1123
     let rec mk_binop op l = match l with
1124
         [] -> assert false
1125
       | [e] -> e
1126
       | hd::tl -> Corelang.mkpredef_call hd.expr_loc op [hd; mk_binop op tl]
1127
    in
1128
    let gl_as_expr gl =
1129
      let gl = Guards.elements gl in
1130
      let export (e,b) = if b then e else Corelang.push_negations ~neg:true e in 
1131
      match gl with
1132
        [] -> []
1133
      | [e] -> [export e]
1134
      | _ ->
1135
         [mk_binop "&&"
1136
            (List.map export gl)]
1137
    in
1138
    let rec clean_disj disj =
1139
      match disj with
1140
      | [] | [_] -> [] 
1141
      | _::_::_ -> (
1142
        (* First basic version: producing a DNF One can later, (1)
1143
           simplify it with z3, or (2) build the compact tree with
1144
           maximum shared subexpression (and simplify it with z3) *)
1145
        let elems = List.fold_left (fun accu gl -> (gl_as_expr gl) @ accu) [] disj in
1146
        let or_expr = mk_binop "||" elems in
1147
        [or_expr]
1148

    
1149

    
1150
         (* TODO disj*)
1151
      (* get the item that occurs in most case *)
1152
      (* List.fold_left (fun accu s ->
1153
       *     List.fold_left (fun accu (e,b) ->
1154
       *         if List.mem_assoc (e.expr_tag, b)
1155
       *       ) accu (Guards.elements s)
1156
       *   ) [] disj *)
1157

    
1158
      )
1159
    in
1160
    Format.eprintf "Map: %i elements@." (UpMap.cardinal map);
1161
    UpMap.fold (fun up (common, disj) accu ->
1162
        Format.eprintf
1163
          "Guards:@.shared: [%a]@.disj: [@[<v 0>%a@ ]@]@.Updates: %a@."
1164
          Guards.pp_short common
1165
          (fprintf_list ~sep:";@ " Guards.pp_long) disj
1166
          UpMap.pp up;
1167
        let disj = clean_disj disj in
1168
        let guard_expr = (gl_as_expr common)@disj in
1169
        
1170
        ((match guard_expr with
1171
         | [] -> None
1172
         | _ -> Some (mk_binop "&&" guard_expr)), up)::accu
1173
      ) map []
1174
    
1175
  in
1176
  process sw_init, process sw_sys, process init_out, process update_out