Project

General

Profile

Statistics
| Branch: | Tag: | Revision:

lustrec / src / tools / seal / seal_extract.ml @ 03c767b1

History | View | Annotate | Download (40.8 KB)

1
open Lustre_types
2
open Utils
3
open Seal_utils			
4
open Zustre_data (* Access to Z3 context *)
5
   
6

    
7
(* Switched system extraction: expression are memoized *)
8
(*let expr_mem = Hashtbl.create 13*)
9
    
10
let add_init defs vid =
11
  Hashtbl.add defs vid [[], IsInit]
12

    
13

    
14
(**************************************************************)
15
(* Convert from Lustre expressions to Z3 expressions and back *)
16
(* All (free) variables have to be declared in the Z3 context *)
17
(**************************************************************)
18

    
19
let is_init_name = "__is_init"
20

    
21
let const_defs = Hashtbl.create 13
22
let is_const id = Hashtbl.mem const_defs id
23
let get_const id = Hashtbl.find const_defs id
24
                 
25
(* expressions are only basic constructs here, no more ite, tuples,
26
   arrows, fby, ... *)
27

    
28
(* Set of hash to support memoization *)
29
let expr_hash: (expr * Utils.tag) list ref = ref []
30
let ze_hash: (Z3.Expr.expr, Utils.tag) Hashtbl.t = Hashtbl.create 13
31
let e_hash: (Utils.tag, Z3.Expr.expr) Hashtbl.t = Hashtbl.create 13
32
let pp_hash pp_key pp_v fmt h = Hashtbl.iter (fun key v -> Format.fprintf fmt "%a -> %a@ " pp_key key pp_v v) h
33
let pp_e_map fmt = List.iter (fun (e,t) -> Format.fprintf fmt "%i -> %a@ " t Printers.pp_expr e) !expr_hash
34
let pp_ze_hash fmt = pp_hash
35
                      (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))
36
                      Format.pp_print_int
37
                      fmt
38
                      ze_hash
39
let pp_e_hash fmt = pp_hash
40
                      Format.pp_print_int
41
                      (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))
42
                      fmt
43
                      e_hash                              
44
let mem_expr e =
45
  (* Format.eprintf "Searching for %a in map: @[<v 0>%t@]"
46
   *   Printers.pp_expr e
47
   *   pp_e_map; *)
48
  let res = List.exists (fun (e',_) -> Corelang.is_eq_expr e e') !expr_hash in
49
  (* Format.eprintf "found?%b@." res; *)
50
  res
51
  
52
let mem_zexpr ze =
53
  Hashtbl.mem ze_hash ze
54
let get_zexpr e =
55
  let eref, uid = List.find (fun (e',_) -> Corelang.is_eq_expr e e') !expr_hash in
56
  (* Format.eprintf "found expr=%a id=%i@." Printers.pp_expr eref eref.expr_tag; *)
57
  Hashtbl.find e_hash uid
58
let get_expr ze =
59
  let uid = Hashtbl.find ze_hash ze in
60
  let e,_ = List.find (fun (e,t) -> t = uid) !expr_hash in
61
  e
62
  
63
let neg_ze z3e = Z3.Boolean.mk_not !ctx z3e 
64
let is_init_z3e =
65
  Z3.Expr.mk_const_s !ctx is_init_name  Zustre_common.bool_sort 
66

    
67
let get_zid (ze:Z3.Expr.expr) : Utils.tag = 
68
  try
69
    if Z3.Expr.equal ze is_init_z3e then -1 else
70
      if Z3.Expr.equal ze (neg_ze is_init_z3e) then -2 else
71
    Hashtbl.find ze_hash ze
72
  with _ -> (Format.eprintf "Looking for ze %s in Hash %a"
73
               (Z3.Expr.to_string ze)
74
               (fun fmt hash -> Hashtbl.iter (fun ze uid -> Format.fprintf fmt "%s -> %i@ " (Z3.Expr.to_string ze) uid) hash ) ze_hash;
75
             assert false)
76
let add_expr =
77
  let cpt = ref 0 in
78
  fun e ze ->
79
  incr cpt;
80
  let uid = !cpt in
81
  expr_hash := (e, uid)::!expr_hash;
82
  Hashtbl.add e_hash uid ze;
83
  Hashtbl.add ze_hash ze uid
84

    
85
  
86
let expr_to_z3_expr, zexpr_to_expr =
87
  (* List to store converted expression. *)
88
  (* let hash = ref [] in
89
   * let comp_expr e (e', _) = Corelang.is_eq_expr e e' in
90
   * let comp_zexpr ze (_, ze') = Z3.Expr.equal ze ze' in *)
91
  
92
  let rec e2ze e =
93
    if mem_expr e then (
94
      get_zexpr e
95
    )
96
    else (
97
      let res =
98
        match e.expr_desc with
99
        | Expr_const c ->
100
           let z3e = Zustre_common.horn_const_to_expr c in
101
           add_expr e z3e;
102
           z3e
103
        | Expr_ident id -> (
104
          if is_const id then (
105
            let c = get_const id in
106
            let z3e = Zustre_common.horn_const_to_expr c in
107
            add_expr e z3e;
108
            z3e
109
        )
110
        else (
111
          let fdecl_id = Zustre_common.get_fdecl id in
112
          let z3e = Z3.Expr.mk_const_f !ctx fdecl_id in
113
          add_expr e z3e;
114
          z3e
115
          )
116
        )
117
      | Expr_appl (id,args, None) (* no reset *) ->
118
         let z3e = Zustre_common.horn_basic_app id e2ze (Corelang.expr_list_of_expr args) in
119
         add_expr e z3e;
120
         z3e
121
      | Expr_tuple [e] ->
122
         let z3e = e2ze e in
123
         add_expr e z3e;
124
         z3e
125
      | _ -> ( match e.expr_desc with Expr_tuple _ -> Format.eprintf "tuple e2ze(%a)@.@?" Printers.pp_expr e
126
                                    | _ -> Format.eprintf "e2ze(%a)@.@?" Printers.pp_expr e)
127
                 ; assert false
128
      in
129
      res
130
    )
131
  in
132
  let rec ze2e ze =
133
    let ze_name ze =
134
      let fd = Z3.Expr.get_func_decl ze in
135
      Z3.Symbol.to_string (Z3.FuncDecl.get_name fd)
136
    in
137
    if mem_zexpr ze then
138
      None, Some (get_expr ze)
139
    else
140
      let open Corelang in
141
      let fd = Z3.Expr.get_func_decl ze in
142
      let zel = Z3.Expr.get_args ze in
143
      match Z3.Symbol.to_string (Z3.FuncDecl.get_name fd), zel with
144
      (*      | var, [] -> (* should be in env *) get_e *)
145

    
146
      (* Extracting IsInit status *)
147
      | "not", [ze] when ze_name ze = is_init_name ->
148
         Some false, None
149
      | name, [] when name = is_init_name -> Some true, None
150
      (* Other constructs are converted to a lustre expression *)
151
      | op, _ -> (
152
        
153
        
154
        if Z3.Expr.is_numeral ze then
155
          let e =
156
            if Z3.Arithmetic.is_real ze then
157
              let num =  Num.num_of_ratio (Z3.Arithmetic.Real.get_ratio ze) in
158
              let s = Z3.Arithmetic.Real.numeral_to_string ze in
159
              mkexpr
160
                Location.dummy_loc
161
                (Expr_const
162
                   (Const_real
163
                      (num, 0, s)))
164
            else if Z3.Arithmetic.is_int ze then
165
              mkexpr
166
                Location.dummy_loc
167
                (Expr_const
168
                   (Const_int
169
                      (Big_int.int_of_big_int (Z3.Arithmetic.Integer.get_big_int ze))))
170
            else if Z3.Expr.is_const ze then
171
              match Z3.Expr.to_string ze with
172
              | "true" -> mkexpr Location.dummy_loc
173
                            (Expr_const (Const_tag (tag_true)))
174
              | "false" ->
175
                 mkexpr Location.dummy_loc
176
                   (Expr_const (Const_tag (tag_false)))
177
              | _ -> assert false
178
            else
179
              (
180
                Format.eprintf "Const err: %s %b@." (Z3.Expr.to_string ze) (Z3.Expr.is_const ze);
181
                assert false (* a numeral but no int nor real *)
182
              )
183
          in
184
          None, Some e
185
        else
186
          match op with
187
          | "not" | "=" | "-" | "*" | "/"
188
            | ">=" | "<=" | ">" | "<" 
189
            ->
190
             let args = List.map (fun ze -> Utils.desome (snd (ze2e ze))) zel in
191
             None, Some (mkpredef_call Location.dummy_loc op args)
192
          | "+" -> ( (* Special treatment of + for 2+ args *)
193
            let args = List.map (fun ze -> Utils.desome (snd (ze2e ze))) zel in
194
            let e = match args with
195
                [] -> assert false
196
              | [hd] -> hd
197
              | e1::e2::tl ->
198
                 let first_binary_and = mkpredef_call Location.dummy_loc op [e1;e2] in
199
                 if tl = [] then first_binary_and else
200
                   List.fold_left (fun e e_new ->
201
                       mkpredef_call Location.dummy_loc op [e;e_new]
202
                     ) first_binary_and tl
203
                 
204
            in
205
            None, Some e 
206
          )
207
          | "and" | "or" -> (
208
            (* Special case since it can contain is_init pred *)
209
            let args = List.map (fun ze -> ze2e ze) zel in
210
            let op = if op = "and" then "&&" else if op = "or" then "||" else assert false in 
211
            match args with
212
            | [] -> assert false
213
            | [hd] -> hd
214
            | hd::tl ->
215
               List.fold_left
216
                 (fun (is_init_opt1, expr_opt1) (is_init_opt2, expr_opt2) ->
217
                   (match is_init_opt1, is_init_opt2 with
218
                      None, x | x, None -> x
219
                      | Some _, Some _ -> assert false),
220
                   (match expr_opt1, expr_opt2 with
221
                    | None, x | x, None -> x
222
                    | Some e1, Some e2 ->
223
                       Some (mkpredef_call Location.dummy_loc op [e1; e2])
224
                 ))
225
                 hd tl 
226
          )
227
          | op -> 
228
             let args = List.map (fun ze ->  (snd (ze2e ze))) zel in
229
             Format.eprintf "deal with op %s (nb args: %i). Expr is %s@."  op (List.length args) (Z3.Expr.to_string ze); assert false
230
      )
231
  in
232
  (fun e -> e2ze e), (fun ze -> ze2e ze)
233

    
234
               
235
let zexpr_to_guard_list ze =
236
  let init_opt, expr_opt = zexpr_to_expr ze in
237
  (match init_opt with
238
   | None -> []
239
   |Some b -> [IsInit, b]
240
  ) @ (match expr_opt with
241
       | None -> []
242
       | Some e -> [Expr e, true]
243
      )
244
                
245
               
246
let simplify_neg_guard l =
247
  List.map (fun (g,posneg) ->
248
      match g with
249
      | IsInit -> g, posneg
250
      | Expr g ->
251
          if posneg then
252
            Expr (Corelang.push_negations g),
253
            true
254
          else
255
            (* Pushing the negation in the expression *)
256
            Expr(Corelang.push_negations ~neg:true g),
257
            true
258
    ) l 
259

    
260
(* TODO:
261
individuellement demander si g1 => g2. Si c'est le cas, on peut ne garder que g1 dans la liste
262
*)    
263
(*****************************************************************)
264
(* Checking sat(isfiability) of an expression and simplifying it *)
265
(* All (free) variables have to be declared in the Z3 context    *)
266
(*****************************************************************)
267
(*
268
let goal_simplify zl =
269
  let goal = Z3.Goal.mk_goal !ctx false false false in
270
  Z3.Goal.add goal zl;
271
  let goal' = Z3.Goal.simplify goal None in
272
  (* Format.eprintf "Goal before: %s@.Goal after : %s@.Sat? %s@."
273
   *   (Z3.Goal.to_string goal)
274
   *   (Z3.Goal.to_string goal')
275
   *   (Z3.Solver.string_of_status status_res) 
276
   * ; *)
277
  let ze = Z3.Goal.as_expr goal' in
278
  (* Format.eprintf "as an expr: %s@." (Z3.Expr.to_string ze); *)
279
  zexpr_to_guard_list ze
280
  *)
281
  
282
let implies =
283
  let ze_implies_hash : ((Utils.tag * Utils.tag), bool) Hashtbl.t  = Hashtbl.create 13 in
284
  fun ze1 ze2 ->
285
  let ze1_uid = get_zid ze1 in
286
  let ze2_uid = get_zid ze2 in
287
  if Hashtbl.mem ze_implies_hash (ze1_uid, ze2_uid) then
288
    Hashtbl.find ze_implies_hash (ze1_uid, ze2_uid)
289
  else
290
    begin
291
      if !seal_debug then (
292
        report ~level:6 (fun fmt -> Format.fprintf fmt "Checking implication: %s => %s?@ "
293
          (Z3.Expr.to_string ze1) (Z3.Expr.to_string ze2)
294
      )); 
295
      let solver = Z3.Solver.mk_simple_solver !ctx in
296
      let tgt = Z3.Boolean.mk_not !ctx (Z3.Boolean.mk_implies !ctx ze1 ze2) in
297
      let res =
298
        try
299
          let status_res = Z3.Solver.check solver [tgt] in
300
          match status_res with
301
          | Z3.Solver.UNSATISFIABLE -> if !seal_debug then
302
                                         report ~level:6 (fun fmt -> Format.fprintf fmt "Valid!@ "); 
303
             true
304
          | _ -> if !seal_debug then report ~level:6 (fun fmt -> Format.fprintf fmt "not proved valid@ "); 
305
             false
306
        with Zustre_common.UnknownFunction(id, msg) -> (
307
          report ~level:1 msg;
308
          false
309
        )
310
      in
311
      Hashtbl.add ze_implies_hash (ze1_uid,ze2_uid) res ;
312
      res
313
    end
314
                                               
315
let rec simplify zl =
316
  match zl with
317
  | [] | [_] -> zl
318
  | hd::tl -> (
319
    (* Forall e in tl, checking whether hd => e or e => hd, to keep hd
320
     in the first case and e in the second one *)
321
    let tl = simplify tl in
322
    let keep_hd, tl =
323
      List.fold_left (fun (keep_hd, accu) e ->
324
          if implies hd e then
325
            true, accu (* throwing away e *)
326
          else if implies e hd then
327
            false, e::accu (* throwing away hd *)
328
          else
329
            keep_hd, e::accu (* keeping both *)
330
        ) (true,[]) tl
331
    in
332
    (* Format.eprintf "keep_hd?%b hd=%s, tl=[%a]@."
333
     *   keep_hd
334
     *   (Z3.Expr.to_string hd)
335
     * (Utils.fprintf_list ~sep:"; " (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))) tl
336
     *   ; *)
337
    if keep_hd then
338
      hd::tl
339
    else
340
      tl
341
  )
342
  
343
let check_sat ?(just_check=false) (l: elem_boolexpr guard) : bool * (elem_boolexpr guard) =
344
  (* Syntactic simplification *)
345
  if false then
346
    Format.eprintf "Before simplify: %a@." (pp_guard_list pp_elem) l; 
347
  let l = simplify_neg_guard l in
348
  if false then (
349
    Format.eprintf "After simplify: %a@." (pp_guard_list pp_elem) l; 
350
    Format.eprintf "@[<v 2>Z3 check sat: [%a]@ " (pp_guard_list pp_elem)l;
351
  );
352
  let solver = Z3.Solver.mk_simple_solver !ctx in
353
  try (
354
    let zl =
355
      List.map (fun (e, posneg) ->
356
          let ze =
357
            match e with
358
            | IsInit -> is_init_z3e
359
            | Expr e -> expr_to_z3_expr e 
360
          in
361
          if posneg then
362
            ze
363
          else
364
            neg_ze ze
365
        ) l
366
    in
367
    if false then Format.eprintf "Z3 exprs1: [%a]@ " (fprintf_list ~sep:",@ " (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))) zl; 
368
    let zl = simplify zl in
369
        if false then Format.eprintf "Z3 exprs2: [%a]@ " (fprintf_list ~sep:",@ " (fun fmt e -> Format.fprintf fmt "%s" (Z3.Expr.to_string e))) zl; 
370
    let status_res = Z3.Solver.check solver zl in
371
     if false then Format.eprintf "Z3 status: %s@ @]@. " (Z3.Solver.string_of_status status_res); 
372
    match status_res with
373
    | Z3.Solver.UNSATISFIABLE -> false, []
374
    | _ -> (
375
      if false && just_check then
376
        true, l
377
      else
378
        (* TODO: may be reactivate but it may create new expressions *)
379
        (* let l = goal_simplify zl in *)
380
        let l = List.fold_left
381
                  (fun accu ze -> accu @ (zexpr_to_guard_list ze))
382
                  []
383
                  zl
384
        in
385
  (* Format.eprintf "@.@[<v 2>Check_Sat:@ before: %a@ after:
386
       %a@. Goal precise? %b/%b@]@.@. " * pp_guard_list l
387
       pp_guard_list l' * (Z3.Goal.is_precise goal) *
388
       (Z3.Goal.is_precise goal'); *)
389
  
390

    
391
        true, l
392
        
393
         
394
    )
395
         
396
  )
397
  with Zustre_common.UnknownFunction(id, msg) -> (
398
    report ~level:1 msg;
399
    true, l (* keeping everything. *)
400
  )
401
      
402
  
403

    
404

    
405
(**************************************************************)
406

    
407
  
408
let clean_sys sys =
409
  List.fold_left (fun accu (guards, updates) ->
410
      let sat, guards' =  check_sat (List.map (fun (g, pn) -> Expr g, pn) guards) in
411
      (*Format.eprintf "Guard: %a@.Guard cleaned: %a@.Sat? %b@."
412
        (fprintf_list ~sep:"@ " (pp_guard_expr Printers.pp_expr))  guards
413
        (fprintf_list ~sep:"@ " (pp_guard_expr Printers.pp_expr))  guards'
414
        sat
415

    
416
      ;*)
417
        if sat then
418
        (List.map (fun (e, b) -> (deelem e, b)) guards', updates)::accu
419
      else
420
        accu
421
    )
422
    [] sys
423

    
424
(* Most costly function: has the be efficiently implemented.  All
425
   registered guards are initially produced by the call to
426
   combine_guards. We csan normalize the guards to ease the
427
   comparisons.
428

    
429
   We assume that gl1 and gl2 are already both satisfiable and in a
430
   kind of reduced form. Let lshort and llong be the short and long
431
   list of gl1, gl2. We check whether each element elong of llong is
432
   satisfiable with lshort. If no, stop. If yes, we search to reduce
433
   the list. If elong => eshort_i, we can remove eshort_i from
434
   lshort. we can continue with this lshort shortened, lshort'; it is
435
   not necessary to add yet elong in lshort' since we already know
436
   rthat elong is somehow reduced with respect to other elements of
437
   llong. If eshort_i => elong, then we keep ehosrt_i in lshort and do
438
   not store elong.
439

    
440
   After iterating through llong, we have a shortened version of
441
   lshort + some elements of llong that have to be remembered. We add
442
   them to this new consolidated list. 
443

    
444
 *)
445

    
446
(* combine_guards ~fresh:Some(e,b) gl1 gl2 returns ok, gl with ok=true
447
   when (e=b) ang gl1 and gl2 is satisfiable and gl is a consilidated
448
   version of it.  *)
449
let combine_guards ?(fresh=None) gl1 gl2 =
450
  (* Filtering out trivial cases. More semantics ones would have to be
451
     addressed later *)
452
  let check_sat e = (* temp function before we clean the original one *)
453
    let ok, _ = check_sat e in
454
    ok
455
  in
456
  let implies (e1,pn1) (e2,pn2) =
457
    let e2z e pn =
458
      match e with
459
        | IsInit -> if pn then is_init_z3e else neg_ze is_init_z3e
460
        | Expr e -> expr_to_z3_expr (if pn then e else (Corelang.push_negations ~neg:true e))
461
      in
462
    implies (e2z e1 pn1) (e2z e2 pn2)
463
  in
464
  let lshort, llong =
465
    if List.length gl1 > List.length gl2 then gl2, gl1 else gl1, gl2
466
  in
467
  let merge long short =
468
    let short, long_sel, ok = 
469
    List.fold_left (fun (short,long_sel, ok) long_e ->
470
        if not ok then
471
          [],[], false (* Propagating unsat case *)
472
        else if check_sat (long_e::short) then
473
          let short, keep_long_e =
474
            List.fold_left (fun (accu_short, keep_long_e) eshort_i ->
475
                if not keep_long_e then (* shorten the algo *)
476
                  eshort_i :: accu_short, false
477
                else (* keep_long_e = true in the following *)
478
                  if implies eshort_i long_e then
479
                    (* First case is trying to remove long_e!
480

    
481
                     Since short is already normalized, we can remove
482
                     long_e. If later long_e is stronger than another
483
                     element of short, then necessarily eshort_i =>
484
                     long_e ->
485
                     that_other_element_of_short. Contradiction. *)
486
                    eshort_i::accu_short, false
487
                  else if implies long_e eshort_i then 
488
                    (* removing eshort_i, keeping long_e. *)
489
                    accu_short, true
490
                  else (* Not comparable, keeping both *)
491
                    eshort_i::accu_short, true
492
              )
493
              ([],true) (* Initially we assume that we will keep long_e *)
494
              short
495
          in
496
          if keep_long_e then
497
            short, long_e::long_sel, true
498
          else
499
            short, long_sel, true
500
        else
501
          [],[],false
502
      ) (short, [], true) long
503
    in
504
    ok, long_sel@short
505
  in
506
  let ok, l = match fresh with
507
    | None -> true, []
508
    | Some g -> merge [g] []
509
  in
510
  if not ok then
511
    false, []
512
  else
513
    let ok, lshort = merge lshort l in
514
    if not ok then
515
      false, []
516
    else
517
      merge llong lshort
518
    
519

    
520
(* Encode "If gel1=posneg then gel2":
521
   - Compute the combination of guarded exprs in gel1 and gel2:
522
     - Each guarded_expr in gel1 is transformed as a guard: the
523
       expression is associated to posneg.
524
     - Existing guards in gel2 are concatenated to that list of guards
525
     - We keep expr in the ge of gel2 as the legitimate expression 
526
 *)
527
let concatenate_ge gel1 posneg gel2 =
528
  let l, all_invalid =
529
    List.fold_left (
530
        fun (accu, all_invalid) (g2,e2) ->
531
        List.fold_left (
532
            fun (accu, all_invalid) (g1,e1) ->
533
            (* Format.eprintf "@[<v 2>Combining guards: (%a=%b) AND [%a] AND [%a]@ "
534
             *  pp_elem e1
535
             *  posneg
536
             *  pp_guard_list g1
537
             *  pp_guard_list g2; *)
538

    
539
            let ok, gl = combine_guards ~fresh:(Some(e1,posneg)) g1 g2 in
540
            (* Format.eprintf "@]@ Result is [%a]@ "
541
             *   pp_guard_list gl; *)
542

    
543
            if ok then
544
              (gl, e2)::accu, false
545
            else (
546
              accu, all_invalid
547
            )
548
          ) (accu, all_invalid) gel1
549
      ) ([], true) gel2
550
  in
551
  not all_invalid, l
552
     
553
(* Rewrite the expression expr, replacing any occurence of a variable
554
   by its definition.
555
*)
556
let rec rewrite defs expr : elem_guarded_expr list =
557
  let rewrite = rewrite defs in
558
  let res =
559
    match expr.expr_desc with
560
    | Expr_appl (id, args, None) ->
561
       let args = rewrite args in
562
       List.map (fun (guards, e) ->
563
           guards,
564
           Expr (Corelang.mkexpr expr.expr_loc (Expr_appl(id, deelem e, None)))
565
         ) args 
566
    | Expr_const _  -> [[], Expr expr]
567
    | Expr_ident id ->
568
       if Hashtbl.mem defs id then
569
         Hashtbl.find defs id
570
       else
571
         (* id should be an input *)
572
         [[], Expr expr]
573
    | Expr_ite (g, e1, e2) ->
574
       let g = rewrite g and
575
           e1 = rewrite e1 and
576
           e2 = rewrite e2 in
577
       let ok_then, g_then = concatenate_ge g true e1 in
578
       let ok_else, g_else = concatenate_ge g false e2 in
579
       (if ok_then then g_then else [])@
580
         (if ok_else then g_else else [])
581
    | Expr_merge (g, branches) ->
582
       assert false (* TODO: deal with merges *)
583
      
584
    | Expr_when (e, _, _) -> rewrite e
585
    | Expr_arrow _ -> [[], IsInit] (* At this point the only arrow should be true -> false *)
586
    | Expr_tuple el ->
587
       (* Each expr is associated to its flatten guarded expr list *)
588
       let gell = List.map rewrite el in
589
       (* Computing all combinations: we obtain a list of guarded tuple *)
590
       let rec aux gell : (elem_boolexpr guard * expr list) list =
591
         match gell with
592
         | [] -> assert false (* Not happening *)
593
         | [gel] -> List.map (fun (g,e) -> g, [deelem e]) gel
594
         | gel::getl ->
595
            let getl = aux getl in
596
            List.fold_left (
597
                fun accu (g,e) ->
598
                List.fold_left (
599
                    fun accu (gl, minituple) ->
600
                    let is_compat, guard_comb = combine_guards g gl in
601
                    if is_compat then
602
                      let new_gt : elem_boolexpr guard * expr list = (guard_comb, (deelem e)::minituple) in
603
                      new_gt::accu
604
                    else
605
                      accu
606
                    
607
                  ) accu getl
608
              ) [] gel
609
       in
610
       let gtuples = aux gell in
611
       (* Rebuilding the valid type: guarded expr list (with tuple exprs) *)
612
       List.map
613
         (fun (g,tuple) -> g, Expr (Corelang.mkexpr expr.expr_loc (Expr_tuple tuple)))
614
         gtuples
615
    | Expr_fby _
616
      | Expr_appl _
617
                  (* Should be removed by mormalization and inlining *)
618
      -> Format.eprintf "Pb expr: %a@.@?" Printers.pp_expr expr; assert false
619
    | Expr_array _ | Expr_access _ | Expr_power _
620
                                                (* Arrays not handled here yet *)
621
      -> assert false
622
    | Expr_pre _ -> (* Not rewriting mem assign *)
623
       assert false
624
  in
625
  (* Format.eprintf "Rewriting %a as [@[<v 0>%a@]]@ "
626
   *   Printers.pp_expr expr
627
   *   (Utils.fprintf_list ~sep:"@ "
628
   *        pp_guard_expr) res; *)
629
  res
630
and add_def defs vid expr =
631
  let vid_defs = rewrite defs expr in
632
  (* Format.eprintf "Add_def: %s = %a@. -> @[<v 0>%a@]@."
633
   *     vid
634
   *     Printers.pp_expr expr
635
   *     (
636
   *       (Utils.fprintf_list ~sep:"@ "
637
   *          pp_guard_expr)) vid_defs;  *)
638
  Hashtbl.add defs vid vid_defs
639

    
640
(* Takes a list of guarded exprs (ge) and a guard
641
returns the same list of ge splited into the ones where the guard is true and the ones where it is false. In both lists the associated ge do not mention that guard anymore.
642

    
643
When a given ge doesn't mention positively or negatively such guards, it is duplicated in both lists *)
644
let split_mdefs elem (mdefs: elem_guarded_expr list) =
645
  List.fold_left (
646
      fun (selected, left_out)
647
          ((guards, expr) as ge) ->
648
      (* select the element of guards that match the argument elem *)
649
      let sel, others_guards = List.partition (select_elem elem) guards in
650
      match sel with
651
      (* we extract the element from the list and add it to the
652
         appropriate list *)
653
      | [_, sel_status] ->
654
         if sel_status then
655
           (others_guards,expr)::selected, left_out
656
         else selected, (others_guards,expr)::left_out
657
      | [] -> (* no such guard exists, we have to duplicate the
658
              guard_expr in both lists *)
659
         ge::selected, ge::left_out
660
      | _ -> (
661
        Format.eprintf "@.Spliting list on elem %a.@.List:%a@."
662
          pp_elem elem
663
          (pp_mdefs pp_elem) mdefs;
664
        assert false (* more then one element selected. Should
665
                          not happen , or trival dead code like if
666
                              x then if not x then dead code *)
667
      )
668
    ) ([],[]) mdefs
669
    
670
let split_mem_defs
671
      (elem: element)
672
      (mem_defs: (ident * elem_guarded_expr list) list)
673
      :
674
      ((ident * elem_guarded_expr mdef_t) list) * ((ident * elem_guarded_expr mdef_t) list)
675
  
676
  =
677
  List.fold_right (fun (m,mdefs)
678
                       (accu_pos, accu_neg) ->
679
      let pos, neg =
680
        split_mdefs elem mdefs 
681
      in
682
      (m, pos)::accu_pos,
683
      (m, neg)::accu_neg
684
    ) mem_defs ([],[])
685

    
686

    
687
(* Split a list of mem_defs into init and step lists of guarded
688
   expressions per memory. *)
689
let split_init mem_defs =
690
  split_mem_defs IsInit mem_defs 
691

    
692
(* Previous version of the function: way too costly 
693
let pick_guard mem_defs : expr option =  
694
  let gel = List.flatten (List.map snd mem_defs) in
695
  let gl = List.flatten (List.map fst gel) in
696
  let all_guards =
697
    List.map (
698
        (* selecting guards and getting rid of boolean *)
699
        fun (e,b) ->
700
        match e with
701
        | Expr e -> e
702
        | _ -> assert false
703
      (* should have been filtered out
704
                                      yet *)
705
      ) gl
706
  in
707
  (* TODO , one could sort by occurence and provided the most common
708
     one *)
709
  try
710
  Some (List.hd all_guards)  
711
  with _ -> None
712
   *)
713

    
714
(* Returning the first non empty guard expression *)
715
let rec pick_guard mem_defs : expr option =
716
  match mem_defs with
717
  | [] -> None
718
  | (_, gel)::tl -> (
719
    let found =
720
      List.fold_left (fun found (g,_) ->
721
          if found = None then
722
            match g with
723
            | [] -> None
724
            | (Expr e, _)::_ -> Some e
725
            | (IsInit, _)::_ -> assert false (* should be removed already *)
726
          else
727
            found
728
        ) None gel
729
    in
730
    if found = None then pick_guard tl else found
731
  )
732
          
733

    
734
(* Transform a list of variable * guarded exprs into a list of guarded pairs (variable, expressions)
735
*)
736
let rec build_switch_sys
737
          (mem_defs : (Utils.ident * elem_guarded_expr list) list )
738
          prefix
739
        :
740
          ((expr * bool) list * (ident * expr) list ) list =
741
  if !seal_debug then
742
    report ~level:4 (fun fmt -> Format.fprintf fmt "@[<v 2>Build_switch with@ %a@]@."
743
      (Utils.fprintf_list ~sep:",@ "
744
         (fun fmt (id, gel) -> Format.fprintf fmt "%s -> [@[<v 0>%a]@]"
745
                                 id
746
                                 (pp_mdefs pp_elem) gel))
747
      mem_defs);
748
  (* if all mem_defs have empty guards, we are done, return prefix,
749
     mem_defs expr.
750

    
751
     otherwise pick a guard in one of the mem, eg (g, b) then for each
752
     other mem, one need to select the same guard g with the same
753
     status b, *)
754
  let res =
755
  if List.for_all (fun (m,mdefs) ->
756
         (* All defs are unguarded *)
757
         match mdefs with
758
         | [[], _] -> true
759
         | _ -> false
760
       ) mem_defs
761
  then
762
    [prefix ,
763
     List.map (fun (m,gel) ->
764
         match gel with
765
         | [_,e] ->
766
            let e =
767
              match e with
768
              | Expr e -> e
769
              | _ -> assert false (* No IsInit expression *)
770
            in
771
            m,e
772
         | _ -> assert false
773
       ) mem_defs]
774
  else
775
    (* Picking a guard *)
776
    let elem_opt : expr option = pick_guard mem_defs in
777
    match elem_opt with
778
      None -> assert false (* Otherwise the first case should have matched *)
779
    | Some elem -> (
780
      report ~level:4 (fun fmt -> Format.fprintf fmt "selecting guard %a@." Printers.pp_expr elem);
781
      let pos, neg =
782
        split_mem_defs
783
          (Expr elem)
784
          mem_defs
785
      in
786
      (* Special cases to avoid useless computations: true, false conditions *)
787
      match elem.expr_desc with
788
      | Expr_ident "true"  ->   build_switch_sys pos prefix
789
      | Expr_const (Const_tag tag) when tag = Corelang.tag_true
790
        ->   build_switch_sys pos prefix
791
      | Expr_ident "false" ->   build_switch_sys neg prefix
792
      | Expr_const (Const_tag tag) when tag = Corelang.tag_false 
793
        ->   build_switch_sys neg prefix
794
      | _ -> (* Regular case *)
795
         (* let _ = (
796
          *     Format.eprintf "Expr is %a@." Printers.pp_expr elem;
797
          *     match elem.expr_desc with
798
          *     | Expr_const _ -> Format.eprintf "a const@."
799
          *                     
800
          *     | Expr_ident _ -> Format.eprintf "an ident@."
801
          *     | _ -> Format.eprintf "something else@."
802
          *   )
803
          * in *)
804
         let clean l =
805
           let l = List.map (fun (e,b) -> (Expr e), b) l in
806
           let ok, l = check_sat l in
807
           let l = List.map (fun (e,b) -> deelem e, b) l in
808
           ok, l
809
         in
810
         let pos_prefix = (elem, true)::prefix in
811
         let neg_prefix = (elem, false)::prefix in
812
         let ok_pos, pos_prefix = clean pos_prefix in         
813
         let ok_neg, neg_prefix = clean neg_prefix in
814
         report ~level:4 (fun fmt -> Format.fprintf fmt "Enforcing %a@." Printers.pp_expr elem);
815
         let ok_l = if ok_pos then build_switch_sys pos pos_prefix else [] in
816
         report ~level:4 (fun fmt -> Format.fprintf fmt "Enforcing not(%a)@." Printers.pp_expr elem);
817
         let nok_l = if ok_neg then build_switch_sys neg neg_prefix else [] in
818
         ok_l @ nok_l
819
    )
820
  in
821
    if !seal_debug then (
822
      report ~level:4 (fun fmt ->
823
          Format.fprintf fmt
824
            "@[<v 2>===> @[%t@ @]@]@ @]@ "
825
            (fun fmt -> List.iter (fun (gl,up) ->
826
                            Format.fprintf fmt "[@[%a@]] -> (%a)@ "
827
                              (pp_guard_list Printers.pp_expr) gl pp_up up) res);
828
          
829
    ));
830
    res
831
  
832

    
833
      
834
(* Take a normalized node and extract a list of switches: (cond,
835
   update) meaning "if cond then update" where update shall define all
836
   node memories. Everything as to be expressed over inputs or memories, intermediate variables are removed through inlining *)
837
let node_as_switched_sys consts (mems:var_decl list) nd =
838
  (* rescheduling node: has been scheduled already, no need to protect
839
     the call to schedule_node *)
840
  let nd_report = Scheduling.schedule_node nd in
841
  let schedule = nd_report.Scheduling_type.schedule in
842
  let eqs, auts = Corelang.get_node_eqs nd in
843
  assert (auts = []); (* Automata should be expanded by now *)
844
  let sorted_eqs = Scheduling.sort_equations_from_schedule eqs schedule in
845
  let defs : (ident,  elem_guarded_expr list) Hashtbl.t = Hashtbl.create 13 in
846
  let add_def = add_def defs in
847

    
848
  let vars = Corelang.get_node_vars nd in
849
  (* Registering all locals variables as Z3 predicates. Will be use to
850
     simplify the expansion *) 
851
  let _ =
852
    List.iter (fun v ->
853
        let fdecl = Z3.FuncDecl.mk_func_decl_s
854
                      !ctx
855
                      v.var_id
856
                      []
857
                      (Zustre_common.type_to_sort v.var_type)
858
        in
859
        ignore (Zustre_common.register_fdecl v.var_id fdecl)
860
      ) vars
861
  in
862
  let _ =
863
    List.iter (fun c -> Hashtbl.add const_defs c.const_id c.const_value) consts
864
  in
865

    
866
  
867
  (* Registering node equations: identifying mem definitions and
868
     storing others in the "defs" hashtbl.
869

    
870
     Each assign is stored in a hash tbl as list of guarded
871
     expressions. The memory definition is also "rewritten" as such a
872
     list of guarded assigns.  *)
873
  let mem_defs, output_defs =
874
    List.fold_left (fun (accu_mems, accu_outputs) eq ->
875
        match eq.eq_lhs with
876
        | [vid] ->
877
           (* Only focus on memory definitions *)
878
           if List.exists (fun v -> v.var_id = vid) mems then 
879
             (
880
               match eq.eq_rhs.expr_desc with
881
               | Expr_pre def_m ->
882
                  report ~level:3 (fun fmt ->
883
                      Format.fprintf fmt "Preparing mem %s@." vid);
884
                  (vid, rewrite defs def_m)::accu_mems, accu_outputs
885
               | _ -> assert false
886
             ) 
887
           else if List.exists (fun v -> v.var_id = vid) nd.node_outputs then (
888
             report ~level:3 (fun fmt ->
889
                 Format.fprintf fmt "Output variable %s@." vid);
890
             add_def vid eq.eq_rhs;
891
             accu_mems, (vid, rewrite defs eq.eq_rhs)::accu_outputs
892
             
893
           )
894
           else
895
             (
896
               report ~level:3 (fun fmt ->
897
                   Format.fprintf fmt "Registering variable %s@." vid);
898
               add_def vid eq.eq_rhs;
899
               accu_mems, accu_outputs
900
             )
901
        | _ -> assert false (* should have been removed by normalization *)
902
      ) ([], []) sorted_eqs
903
  in
904

    
905
  
906
  report ~level:2 (fun fmt -> Format.fprintf fmt "Printing out (guarded) memories definitions (may takes time)@.");
907
  (* Printing memories definitions *)
908
  report ~level:3
909
    (fun fmt ->
910
      Format.fprintf fmt
911
        "@[<v 0>%a@]@."
912
        (Utils.fprintf_list ~sep:"@ "
913
           (fun fmt (m,mdefs) ->
914
             Format.fprintf fmt
915
               "%s -> [@[<v 0>%a@] ]@ "
916
               m
917
               (Utils.fprintf_list ~sep:"@ "
918
                  (pp_guard_expr pp_elem)) mdefs
919
        ))
920
        mem_defs);
921
  (* Format.eprintf "Split init@."; *)
922
  let init_defs, update_defs =
923
    split_init mem_defs 
924
  in
925
  let init_out, update_out =
926
    split_init output_defs
927
  in
928
  report ~level:3
929
    (fun fmt ->
930
      Format.fprintf fmt
931
        "@[<v 0>Init:@ %a@]@."
932
        (Utils.fprintf_list ~sep:"@ "
933
           (fun fmt (m,mdefs) ->
934
             Format.fprintf fmt
935
               "%s -> @[<v 0>[%a@] ]@ "
936
               m
937
               (Utils.fprintf_list ~sep:"@ "
938
                  (pp_guard_expr pp_elem)) mdefs
939
        ))
940
        init_defs);
941
  report ~level:3
942
    (fun fmt ->
943
      Format.fprintf fmt
944
        "@[<v 0>Step:@ %a@]@."
945
        (Utils.fprintf_list ~sep:"@ "
946
           (fun fmt (m,mdefs) ->
947
             Format.fprintf fmt
948
               "%s -> @[<v 0>[%a@] ]@ "
949
               m
950
               (Utils.fprintf_list ~sep:"@ "
951
                  (pp_guard_expr pp_elem)) mdefs
952
        ))
953
        update_defs);
954
  (* Format.eprintf "Build init@."; *)
955
  report ~level:4 (fun fmt -> Format.fprintf fmt "Build init@.");
956
  let sw_init= 
957
    build_switch_sys init_defs []
958
  in
959
  (* Format.eprintf "Build step@."; *)
960
  report ~level:4 (fun fmt -> Format.fprintf fmt "Build step@.");
961
  let sw_sys =
962
    build_switch_sys update_defs []
963
  in
964

    
965
  report ~level:4 (fun fmt -> Format.fprintf fmt "Build init out@.");
966
  let init_out =
967
    build_switch_sys init_out []
968
  in
969
  report ~level:4 (fun fmt -> Format.fprintf fmt "Build step out@.");
970

    
971
  let update_out =
972
    build_switch_sys update_out []
973
  in
974

    
975
  let sw_init = clean_sys sw_init in
976
  let sw_sys = clean_sys sw_sys in
977
  let init_out = clean_sys init_out in
978
  let update_out = clean_sys update_out in
979

    
980
  (* Some additional checks *)
981
  
982
  if false then
983
    begin
984
      Format.eprintf "@.@.CHECKING!!!!!!!!!!!@.";
985
      Format.eprintf "Any duplicate expression in guards?@.";
986
      
987
      let sw_sys =
988
        List.map (fun (gl, up) ->
989
            let gl = List.sort (fun (e,b) (e',b') ->
990
                         let res = compare e.expr_tag e'.expr_tag in
991
                         if res = 0 then (Format.eprintf "Same exprs?@.%a@.%a@.@."
992
                                            Printers.pp_expr e
993
                                            Printers.pp_expr e'
994
                                         );
995
                         res
996
                       ) gl in
997
            gl, up
998
          ) sw_sys 
999
      in
1000
      Format.eprintf "Another check for duplicates in guard list@.";
1001
      List.iter (fun (gl, _) ->
1002
          let rec aux hd l =
1003
            match l with
1004
              [] -> ()
1005
            | (e,b)::tl -> let others = hd@tl in
1006
                           List.iter (fun (e',_) -> if Corelang.is_eq_expr e e' then
1007
                                                      (Format.eprintf "Same exprs?@.%a@.%a@.@."
1008
                                                         Printers.pp_expr e
1009
                                                         Printers.pp_expr e'
1010
                             )) others;
1011
                           aux ((e,b)::hd) tl
1012
          in
1013
          aux [] gl
1014
        ) sw_sys;
1015
      Format.eprintf "Checking duplicates in updates@.";
1016
      let rec check_dup_up accu l =
1017
        match l with
1018
        | [] -> ()
1019
        | ((gl, up) as hd)::tl ->
1020
           let others = accu@tl in
1021
           List.iter (fun (gl',up') -> if up = up' then
1022
                                         Format.eprintf "Same updates?@.%a@.%a@.%a@.%a@.@."
1023

    
1024
                                           pp_gl_short gl
1025
                                           pp_up up
1026
                                           pp_gl_short gl'
1027
                                           pp_up up'
1028
                                       
1029
             ) others;
1030
           
1031
           
1032

    
1033
           check_dup_up (hd::accu) tl
1034
           
1035
      in
1036
      check_dup_up [] sw_sys;
1037
      let sw_sys =
1038
        List.sort (fun (gl1, _) (gl2, _) ->
1039
            let glid gl = List.map (fun (e,_) -> e.expr_tag) gl in
1040
            
1041
            let res = compare (glid gl1) (glid gl2) in
1042
            if res = 0 then Format.eprintf "Same guards?@.%a@.%a@.@."
1043
                              pp_gl_short gl1 pp_gl_short gl2
1044
            ;
1045
              res
1046

    
1047
          ) sw_sys
1048

    
1049
      in
1050
      ()
1051
    end;
1052
  
1053

    
1054
  (* Iter through the elements and gather them by updates *)
1055
  let process sys =
1056
    (* The map will associate to each update up the pair (set, set
1057
       list) where set is the share guards and set list a list of
1058
       disjunctive guards. Each set represents a conjunction of
1059
       expressions. *)
1060
    report ~level:3 (fun fmt -> Format.fprintf fmt "%t@."
1061
                                  (fun fmt -> List.iter (fun (gl,up) ->
1062
                                                  Format.fprintf fmt "[@[%a@]] -> (%a)@ "
1063
                                                    (pp_guard_list Printers.pp_expr) gl pp_up up) sw_init));
1064
    
1065
    (* We perform multiple pass to avoid errors *)
1066
    let map =
1067
      List.fold_left (fun map (gl,up) ->
1068
          (* creating a new set to describe gl *)
1069
          let new_set =
1070
            List.fold_left
1071
              (fun set g -> Guards.add g set)
1072
              Guards.empty
1073
              gl
1074
          in
1075
          (* updating the map with up -> new_set *)
1076
          if UpMap.mem up map then
1077
            let guard_set = UpMap.find up map in
1078
            UpMap.add up (new_set::guard_set) map
1079
          else
1080
            UpMap.add up [new_set] map
1081
        ) UpMap.empty sys
1082
    in
1083
    (* Processing the set of guards leading to the same update: return
1084
       conj, disj with conf is a set of guards, and disj a DNF, ie a
1085
       list of set of guards *)
1086
    let map =
1087
      UpMap.map (
1088
          fun guards ->
1089
          match guards with
1090
          | [] -> Guards.empty, [] (* Nothing *)
1091
          | [s]-> s, [] (* basic case *)
1092
          | hd::tl ->
1093
             let shared = List.fold_left (fun shared s -> Guards.inter shared s) hd tl in
1094
             let remaining = List.map (fun s -> Guards.diff s shared) guards in
1095
             (* If one of them is empty, we can remove the others, otherwise keep them *)
1096
             if List.exists Guards.is_empty remaining then
1097
               shared, []
1098
             else
1099
               shared, remaining
1100
        ) map
1101
    in
1102
  let rec mk_binop op l = match l with
1103
      [] -> assert false
1104
    | [e] -> e
1105
    | hd::tl -> Corelang.mkpredef_call hd.expr_loc op [hd; mk_binop op tl]
1106
  in
1107
  let gl_as_expr gl =
1108
    let gl = Guards.elements gl in
1109
    let export (e,b) = if b then e else Corelang.push_negations ~neg:true e in 
1110
    match gl with
1111
      [] -> []
1112
    | [e] -> [export e]
1113
    | _ ->
1114
       [mk_binop "&&"
1115
          (List.map export gl)]
1116
  in
1117
  let rec clean_disj disj =
1118
    match disj with
1119
    | [] -> []
1120
    | [_] -> assert false (* A disjunction was a single case can be ignored *) 
1121
    | _::_::_ -> (
1122
      (* First basic version: producing a DNF One can later, (1)
1123
           simplify it with z3, or (2) build the compact tree with
1124
           maximum shared subexpression (and simplify it with z3) *)
1125
      let elems = List.fold_left (fun accu gl -> (gl_as_expr gl) @ accu) [] disj in
1126
      let or_expr = mk_binop "||" elems in
1127
      [or_expr]
1128

    
1129

    
1130
    (* TODO disj*)
1131
    (* get the item that occurs in most case *)
1132
    (* List.fold_left (fun accu s ->
1133
     *     List.fold_left (fun accu (e,b) ->
1134
     *         if List.mem_assoc (e.expr_tag, b)
1135
     *       ) accu (Guards.elements s)
1136
     *   ) [] disj *)
1137

    
1138
    )
1139
  in
1140
  if !seal_debug then Format.eprintf "Map: %i elements@ " (UpMap.cardinal map);
1141
  UpMap.fold (fun up (common, disj) accu ->
1142
      if !seal_debug then
1143
        Format.eprintf
1144
          "Guards:@.shared: [%a]@.disj: [@[<v 0>%a@ ]@]@.Updates: %a@."
1145
          Guards.pp_short common
1146
          (fprintf_list ~sep:";@ " Guards.pp_long) disj
1147
          UpMap.pp up;
1148
      let disj = clean_disj disj in
1149
      let guard_expr = (gl_as_expr common)@disj in
1150
      
1151
      ((match guard_expr with
1152
        | [] -> None 
1153
        | _ -> Some (mk_binop "&&" guard_expr)), up)::accu
1154
    ) map []
1155
  
1156
    in
1157
    
1158
    
1159
    
1160
    report ~level:3 (fun fmt -> Format.fprintf fmt "Process sw_init:@.");
1161
    let sw_init = process sw_init in
1162
    report ~level:3 (fun fmt -> Format.fprintf fmt "Process sw_sys:@.");
1163
    let sw_sys = process sw_sys in
1164
    report ~level:3 (fun fmt -> Format.fprintf fmt "Process init_out:@.");
1165
    let init_out = process init_out in
1166
    report ~level:3 (fun fmt -> Format.fprintf fmt "Process update_out:@.");
1167
    let update_out = process update_out in
1168
    sw_init , sw_sys, init_out, update_out