lustrec / src / causality.ml @ a2d97a3e
History  View  Annotate  Download (18 KB)
1 
(********************************************************************) 

2 
(* *) 
3 
(* The LustreC compiler toolset / The LustreC Development Team *) 
4 
(* Copyright 2012   ONERA  CNRS  INPT  LIFL *) 
5 
(* *) 
6 
(* LustreC is free software, distributed WITHOUT ANY WARRANTY *) 
7 
(* under the terms of the GNU Lesser General Public License *) 
8 
(* version 2.1. *) 
9 
(* *) 
10 
(* This file was originally from the Prelude compiler *) 
11 
(* *) 
12 
(********************************************************************) 
13  
14  
15 
(** Simple modular syntactic causality analysis. Can reject correct 
16 
programs, especially if the program is not flattened first. *) 
17 
open Utils 
18 
open LustreSpec 
19 
open Corelang 
20 
open Graph 
21 
open Format 
22  
23 
exception Cycle of ident list 
24  
25 
module IdentDepGraph = Graph.Imperative.Digraph.ConcreteBidirectional (IdentModule) 
26  
27 
(* Dependency of mem variables on mem variables is cut off 
28 
by duplication of some mem vars into local node vars. 
29 
Thus, cylic dependency errors may only arise between nomem vars. 
30 
nonmem variables are: 
31 
 inputs: not needed for causality/scheduling, needed only for detecting useless vars 
32 
 read mems (fake vars): same remark as above. 
33 
 outputs: decoupled from mems, if necessary 
34 
 locals 
35 
 instance vars (fake vars): simplify causality analysis 
36  
37 
global constants are not part of the dependency graph. 
38  
39 
no_mem' = combinational(no_mem, mem); 
40 
=> (mem > no_mem' > no_mem) 
41  
42 
mem' = pre(no_mem, mem); 
43 
=> (mem' > no_mem), (mem > mem') 
44  
45 
Global roadmap: 
46 
 compute two dep graphs g (nonmem/nonmem&mem) and g' (mem/mem) 
47 
 check cycles in g (a cycle means a dependency error) 
48 
 break cycles in g' (it's legal !): 
49 
 check cycles in g' 
50 
 if any, introduce aux var to break cycle, then start afresh 
51 
 insert g' into g 
52 
 return g 
53 
*) 
54  
55 
(* Tests whether [v] is a root of graph [g], i.e. a source *) 
56 
let is_graph_root v g = 
57 
IdentDepGraph.in_degree g v = 0 
58  
59 
(* Computes the set of graph roots, i.e. the sources of acyclic graph [g] *) 
60 
let graph_roots g = 
61 
IdentDepGraph.fold_vertex 
62 
(fun v roots > if is_graph_root v g then v::roots else roots) 
63 
g [] 
64  
65 
let add_edges src tgt g = 
66 
(*List.iter (fun s > List.iter (fun t > Format.eprintf "add %s > %s@." s t) tgt) src;*) 
67 
List.iter 
68 
(fun s > 
69 
List.iter 
70 
(fun t > IdentDepGraph.add_edge g s t) 
71 
tgt) 
72 
src; 
73 
g 
74  
75 
let add_vertices vtc g = 
76 
(*List.iter (fun t > Format.eprintf "add %s@." t) vtc;*) 
77 
List.iter (fun v > IdentDepGraph.add_vertex g v) vtc; 
78 
g 
79  
80 
let new_graph () = 
81 
IdentDepGraph.create () 
82  
83 
module ExprDep = struct 
84  
85 
let instance_var_cpt = ref 0 
86  
87 
(* read vars represent input/mem readonly vars, 
88 
they are not part of the program/schedule, 
89 
as they are not assigned, 
90 
but used to compute useless inputs/mems. 
91 
a mem read var represents a mem at the beginning of a cycle *) 
92 
let mk_read_var id = 
93 
sprintf "#%s" id 
94  
95 
(* instance vars represent node instance calls, 
96 
they are not part of the program/schedule, 
97 
but used to simplify causality analysis 
98 
*) 
99 
let mk_instance_var id = 
100 
incr instance_var_cpt; sprintf "!%s_%d" id !instance_var_cpt 
101  
102 
let is_read_var v = v.[0] = '#' 
103  
104 
let is_instance_var v = v.[0] = '!' 
105  
106 
let is_ghost_var v = is_instance_var v  is_read_var v 
107  
108 
let undo_read_var id = 
109 
assert (is_read_var id); 
110 
String.sub id 1 (String.length id  1) 
111  
112 
let eq_memory_variables mems eq = 
113 
let rec match_mem lhs rhs mems = 
114 
match rhs.expr_desc with 
115 
 Expr_fby _ 
116 
 Expr_pre _ > List.fold_right ISet.add lhs mems 
117 
 Expr_tuple tl > 
118 
let lhs' = (transpose_list [lhs]) in 
119 
List.fold_right2 match_mem lhs' tl mems 
120 
 _ > mems in 
121 
match_mem eq.eq_lhs eq.eq_rhs mems 
122  
123 
let node_memory_variables nd = 
124 
List.fold_left eq_memory_variables ISet.empty nd.node_eqs 
125  
126 
let node_input_variables nd = 
127 
List.fold_left (fun inputs v > ISet.add v.var_id inputs) ISet.empty nd.node_inputs 
128  
129 
let node_local_variables nd = 
130 
List.fold_left (fun locals v > ISet.add v.var_id locals) ISet.empty nd.node_locals 
131  
132 
let node_output_variables nd = 
133 
List.fold_left (fun outputs v > ISet.add v.var_id outputs) ISet.empty nd.node_outputs 
134  
135 
let node_auxiliary_variables nd = 
136 
ISet.diff (node_local_variables nd) (node_memory_variables nd) 
137  
138 
let node_variables nd = 
139 
let inputs = node_input_variables nd in 
140 
let inoutputs = List.fold_left (fun inoutputs v > ISet.add v.var_id inoutputs) inputs nd.node_outputs in 
141 
List.fold_left (fun vars v > ISet.add v.var_id vars) inoutputs nd.node_locals 
142  
143 
(* computes the equivalence relation relating variables 
144 
in the same equation lhs, under the form of a table 
145 
of class representatives *) 
146 
let node_eq_equiv nd = 
147 
let eq_equiv = Hashtbl.create 23 in 
148 
List.iter (fun eq > 
149 
let first = List.hd eq.eq_lhs in 
150 
List.iter (fun v > Hashtbl.add eq_equiv v first) eq.eq_lhs 
151 
) 
152 
nd.node_eqs; 
153 
eq_equiv 
154  
155 
(* Create a tuple of right dimension, according to [expr] type, *) 
156 
(* filled with variable [v] *) 
157 
let adjust_tuple v expr = 
158 
match expr.expr_type.Types.tdesc with 
159 
 Types.Ttuple tl > duplicate v (List.length tl) 
160 
 _ > [v] 
161  
162  
163 
(* Add dependencies from lhs to rhs in [g, g'], *) 
164 
(* nomem/nomem and mem/nomem in g *) 
165 
(* mem/mem in g' *) 
166 
(* match (lhs_is_mem, ISet.mem x mems) with 
167 
 (false, true ) > (add_edges [x] lhs g, 
168 
g') 
169 
 (false, false) > (add_edges lhs [x] g, 
170 
g') 
171 
 (true , false) > (add_edges lhs [x] g, 
172 
g') 
173 
 (true , true ) > (g, 
174 
add_edges [x] lhs g') 
175 
*) 
176 
let add_eq_dependencies mems inputs node_vars eq (g, g') = 
177 
let add_var lhs_is_mem lhs x (g, g') = 
178 
if is_instance_var x  ISet.mem x node_vars then 
179 
if ISet.mem x mems 
180 
then 
181 
let g = add_edges lhs [mk_read_var x] g in 
182 
if lhs_is_mem 
183 
then 
184 
(g, add_edges [x] lhs g') 
185 
else 
186 
(add_edges [x] lhs g, g') 
187 
else 
188 
let x = if ISet.mem x inputs then mk_read_var x else x in 
189 
(add_edges lhs [x] g, g') 
190 
else (g, g') in 
191 
(* Add dependencies from [lhs] to rhs clock [ck]. *) 
192 
let rec add_clock lhs_is_mem lhs ck g = 
193 
(*Format.eprintf "add_clock %a@." Clocks.print_ck ck;*) 
194 
match (Clocks.repr ck).Clocks.cdesc with 
195 
 Clocks.Con (ck', cr, _) > add_var lhs_is_mem lhs (Clocks.const_of_carrier cr) (add_clock lhs_is_mem lhs ck' g) 
196 
 Clocks.Ccarrying (_, ck') > add_clock lhs_is_mem lhs ck' g 
197 
 _ > g 
198 
in 
199 
let rec add_dep lhs_is_mem lhs rhs g = 
200 
(* Add mashup dependencies for a userdefined node instance [lhs] = [f]([e]) *) 
201 
(* i.e every input is connected to every output, through a ghost var *) 
202 
let mashup_appl_dependencies f e g = 
203 
let f_var = mk_instance_var (sprintf "%s_%d" f eq.eq_loc.Location.loc_start.Lexing.pos_lnum) in 
204 
List.fold_right (fun rhs > add_dep lhs_is_mem (adjust_tuple f_var rhs) rhs) 
205 
(expr_list_of_expr e) (add_var lhs_is_mem lhs f_var g) 
206 
in 
207 
match rhs.expr_desc with 
208 
 Expr_const _ > g 
209 
 Expr_fby (e1, e2) > add_dep true lhs e2 (add_dep false lhs e1 g) 
210 
 Expr_pre e > add_dep true lhs e g 
211 
 Expr_ident x > add_var lhs_is_mem lhs x (add_clock lhs_is_mem lhs rhs.expr_clock g) 
212 
 Expr_access (e1, _) 
213 
 Expr_power (e1, _) > add_dep lhs_is_mem lhs e1 g 
214 
 Expr_array a > List.fold_right (add_dep lhs_is_mem lhs) a g 
215 
 Expr_tuple t > 
216 
(* 
217 
if List.length t <> List.length lhs then ( 
218 
match lhs with 
219 
 [l] > List.fold_right (fun r > add_dep lhs_is_mem [l] r) t g 
220 
 _ > 
221 
Format.eprintf "Incompatible tuple assign: %a (%i) vs %a (%i)@.@?" 
222 
(Utils.fprintf_list ~sep:"," (Format.pp_print_string)) lhs 
223 
(List.length lhs) 
224 
Printers.pp_expr rhs 
225 
(List.length t) 
226 
; 
227 
assert false 
228 
) 
229 
else 
230 
*) 
231 
List.fold_right2 (fun l r > add_dep lhs_is_mem [l] r) lhs t g 
232 
 Expr_merge (c, hl) > add_var lhs_is_mem lhs c (List.fold_right (fun (_, h) > add_dep lhs_is_mem lhs h) hl g) 
233 
 Expr_ite (c, t, e) > add_dep lhs_is_mem lhs c (add_dep lhs_is_mem lhs t (add_dep lhs_is_mem lhs e g)) 
234 
 Expr_arrow (e1, e2) > add_dep lhs_is_mem lhs e2 (add_dep lhs_is_mem lhs e1 g) 
235 
 Expr_when (e, c, _) > add_dep lhs_is_mem lhs e (add_var lhs_is_mem lhs c g) 
236 
 Expr_appl (f, e, None) > 
237 
if Basic_library.is_internal_fun f 
238 
(* tuple componentwise dependency for internal operators *) 
239 
then 
240 
List.fold_right (add_dep lhs_is_mem lhs) (expr_list_of_expr e) g 
241 
(* mashed up dependency for userdefined operators *) 
242 
else 
243 
mashup_appl_dependencies f e g 
244 
 Expr_appl (f, e, Some (r, _)) > 
245 
mashup_appl_dependencies f e (add_var lhs_is_mem lhs r g) 
246 
in 
247 
let g = 
248 
List.fold_left 
249 
(fun g lhs > if ISet.mem lhs mems then add_vertices [lhs; mk_read_var lhs] g else add_vertices [lhs] g) g eq.eq_lhs in 
250 
add_dep false eq.eq_lhs eq.eq_rhs (g, g') 
251 

252  
253 
(* Returns the dependence graph for node [n] *) 
254 
let dependence_graph mems inputs node_vars n = 
255 
instance_var_cpt := 0; 
256 
let g = new_graph (), new_graph () in 
257 
(* Basic dependencies *) 
258 
let g = List.fold_right (add_eq_dependencies mems inputs node_vars) n.node_eqs g in 
259 
g 
260  
261 
end 
262  
263 
module NodeDep = struct 
264  
265 
module ExprModule = 
266 
struct 
267 
type t = expr 
268 
let compare = compare 
269 
let hash n = Hashtbl.hash n 
270 
let equal n1 n2 = n1 = n2 
271 
end 
272  
273 
module ESet = Set.Make(ExprModule) 
274  
275 
let rec get_expr_calls prednode expr = 
276 
match expr.expr_desc with 
277 
 Expr_const _ 
278 
 Expr_ident _ > ESet.empty 
279 
 Expr_access (e, _) 
280 
 Expr_power (e, _) > get_expr_calls prednode e 
281 
 Expr_array t 
282 
 Expr_tuple t > List.fold_right (fun x set > ESet.union (get_expr_calls prednode x) set) t ESet.empty 
283 
 Expr_merge (_,hl) > List.fold_right (fun (_,h) set > ESet.union (get_expr_calls prednode h) set) hl ESet.empty 
284 
 Expr_fby (e1,e2) 
285 
 Expr_arrow (e1,e2) > ESet.union (get_expr_calls prednode e1) (get_expr_calls prednode e2) 
286 
 Expr_ite (c, t, e) > ESet.union (get_expr_calls prednode c) (ESet.union (get_expr_calls prednode t) (get_expr_calls prednode e)) 
287 
 Expr_pre e 
288 
 Expr_when (e,_,_) > get_expr_calls prednode e 
289 
 Expr_appl (id,e, _) > 
290 
if not (Basic_library.is_internal_fun id) && prednode id 
291 
then ESet.add expr (get_expr_calls prednode e) 
292 
else (get_expr_calls prednode e) 
293  
294 
let get_callee expr = 
295 
match expr.expr_desc with 
296 
 Expr_appl (id, args, _) > Some (id, expr_list_of_expr args) 
297 
 _ > None 
298  
299 
let get_calls prednode eqs = 
300 
let deps = 
301 
List.fold_left 
302 
(fun accu eq > ESet.union accu (get_expr_calls prednode eq.eq_rhs)) 
303 
ESet.empty 
304 
eqs in 
305 
ESet.elements deps 
306  
307 
let dependence_graph prog = 
308 
let g = new_graph () in 
309 
let g = List.fold_right 
310 
(fun td accu > (* for each node we add its dependencies *) 
311 
match td.top_decl_desc with 
312 
 Node nd > 
313 
(*Format.eprintf "Computing deps of node %s@.@?" nd.node_id; *) 
314 
let accu = add_vertices [nd.node_id] accu in 
315 
let deps = List.map (fun e > fst (desome (get_callee e))) (get_calls (fun _ > true) nd.node_eqs) in 
316 
(*Format.eprintf "%a@.@?" (Utils.fprintf_list ~sep:"@." Format.pp_print_string) deps; *) 
317 
add_edges [nd.node_id] deps accu 
318 
 _ > assert false (* should not happen *) 
319 

320 
) prog g in 
321 
g 
322  
323 
let rec filter_static_inputs inputs args = 
324 
match inputs, args with 
325 
 [] , [] > [] 
326 
 v::vq, a::aq > if v.var_dec_const then (dimension_of_expr a) :: filter_static_inputs vq aq else filter_static_inputs vq aq 
327 
 _ > assert false 
328  
329 
let compute_generic_calls prog = 
330 
List.iter 
331 
(fun td > 
332 
match td.top_decl_desc with 
333 
 Node nd > 
334 
let prednode n = is_generic_node (Hashtbl.find node_table n) in 
335 
nd.node_gencalls < get_calls prednode nd.node_eqs 
336 
 _ > () 
337 

338 
) prog 
339  
340 
end 
341  
342 
module CycleDetection = struct 
343  
344 
(*  Look for cycles in a dependency graph *) 
345 
module Cycles = Graph.Components.Make (IdentDepGraph) 
346  
347 
let mk_copy_var n id = 
348 
mk_new_name (get_node_vars n) id 
349  
350 
let mk_copy_eq n var = 
351 
let var_decl = get_node_var var n in 
352 
let cp_var = mk_copy_var n var in 
353 
let expr = 
354 
{ expr_tag = Utils.new_tag (); 
355 
expr_desc = Expr_ident var; 
356 
expr_type = var_decl.var_type; 
357 
expr_clock = var_decl.var_clock; 
358 
expr_delay = Delay.new_var (); 
359 
expr_annot = None; 
360 
expr_loc = var_decl.var_loc } in 
361 
{ var_decl with var_id = cp_var }, 
362 
mkeq var_decl.var_loc ([cp_var], expr) 
363  
364 
let wrong_partition g partition = 
365 
match partition with 
366 
 [id] > IdentDepGraph.mem_edge g id id 
367 
 _::_::_ > true 
368 
 [] > assert false 
369  
370 
(* Checks that the dependency graph [g] does not contain a cycle. Raises 
371 
[Cycle partition] if the succession of dependencies [partition] forms a cycle *) 
372 
let check_cycles g = 
373 
let scc_l = Cycles.scc_list g in 
374 
List.iter (fun partition > 
375 
if wrong_partition g partition then 
376 
raise (Cycle partition) 
377 
else () 
378 
) scc_l 
379  
380 
(* Creates the subgraph of [g] restricted to vertices and edges in partition *) 
381 
let copy_partition g partition = 
382 
let copy_g = IdentDepGraph.create () in 
383 
IdentDepGraph.iter_edges 
384 
(fun src tgt > 
385 
if List.mem src partition && List.mem tgt partition 
386 
then IdentDepGraph.add_edge copy_g src tgt) 
387 
g 
388  
389 

390 
(* Breaks dependency cycles in a graph [g] by inserting aux variables. 
391 
[head] is a head of a nontrivial scc of [g]. 
392 
In Lustre, this is legal only for mem/mem cycles *) 
393 
let break_cycle head cp_head g = 
394 
let succs = IdentDepGraph.succ g head in 
395 
IdentDepGraph.add_edge g head cp_head; 
396 
IdentDepGraph.add_edge g cp_head (ExprDep.mk_read_var head); 
397 
List.iter 
398 
(fun s > 
399 
IdentDepGraph.remove_edge g head s; 
400 
IdentDepGraph.add_edge g s cp_head) 
401 
succs 
402  
403 
(* Breaks cycles of the dependency graph [g] of memory variables [mems] 
404 
belonging in node [node]. Returns: 
405 
 a list of new auxiliary variable declarations 
406 
 a list of new equations 
407 
 a modified acyclic version of [g] 
408 
*) 
409 
let break_cycles node mems g = 
410 
let (mem_eqs, non_mem_eqs) = List.partition (fun eq > List.exists (fun v > ISet.mem v mems) eq.eq_lhs) node.node_eqs in 
411 
let rec break vdecls mem_eqs g = 
412 
let scc_l = Cycles.scc_list g in 
413 
let wrong = List.filter (wrong_partition g) scc_l in 
414 
match wrong with 
415 
 [] > (vdecls, non_mem_eqs@mem_eqs, g) 
416 
 [head]::_ > 
417 
begin 
418 
IdentDepGraph.remove_edge g head head; 
419 
break vdecls mem_eqs g 
420 
end 
421 
 (head::part)::_ > 
422 
begin 
423 
let vdecl_cp_head, cp_eq = mk_copy_eq node head in 
424 
let pvar v = List.mem v part in 
425 
let fvar v = if v = head then vdecl_cp_head.var_id else v in 
426 
let mem_eqs' = List.map (eq_replace_rhs_var pvar fvar) mem_eqs in 
427 
break_cycle head vdecl_cp_head.var_id g; 
428 
break (vdecl_cp_head::vdecls) (cp_eq::mem_eqs') g 
429 
end 
430 
 _ > assert false 
431 
in break [] mem_eqs g 
432  
433 
end 
434  
435 
(* Module used to compute static disjunction of variables based upon their clocks. *) 
436 
module Disjunction = 
437 
struct 
438 
module ClockedIdentModule = 
439 
struct 
440 
type t = var_decl 
441 
let root_branch vdecl = Clocks.root vdecl.var_clock, Clocks.branch vdecl.var_clock 
442 
let compare v1 v2 = compare (root_branch v2) (root_branch v1) 
443 
end 
444  
445 
module CISet = Set.Make(ClockedIdentModule) 
446  
447 
(* map: var > list of disjoint vars, sorted in increasing branch length order, 
448 
maybe removing shorter branches *) 
449 
type disjoint_map = (ident, CISet.t) Hashtbl.t 
450  
451 
let pp_ciset fmt t = 
452 
begin 
453 
Format.fprintf fmt "{@ "; 
454 
CISet.iter (fun s > Format.fprintf fmt "%a@ " Printers.pp_var_name s) t; 
455 
Format.fprintf fmt "}@." 
456 
end 
457  
458 
let clock_disjoint_map vdecls = 
459 
let map = Hashtbl.create 23 in 
460 
begin 
461 
List.iter 
462 
(fun v1 > let disj_v1 = 
463 
List.fold_left 
464 
(fun res v2 > if Clocks.disjoint v1.var_clock v2.var_clock then CISet.add v2 res else res) 
465 
CISet.empty 
466 
vdecls in 
467 
(* disjoint vdecls are stored in increasing branch length order *) 
468 
Hashtbl.add map v1.var_id disj_v1) 
469 
vdecls; 
470 
(map : disjoint_map) 
471 
end 
472  
473 
(* replace variable [v] by [v'] in disjunction [map]. Then: 
474 
 the mapping v' becomes v' > (map v) inter (map v') 
475 
 the mapping v > ... then disappears 
476 
 other mappings become x > (map x) \ (if v in x then v else v') 
477 
*) 
478 
let replace_in_disjoint_map map v v' = 
479 
begin 
480 
Hashtbl.replace map v'.var_id (CISet.inter (Hashtbl.find map v.var_id) (Hashtbl.find map v'.var_id)); 
481 
Hashtbl.remove map v.var_id; 
482 
Hashtbl.iter (fun x map_x > Hashtbl.replace map x (CISet.remove (if CISet.mem v map_x then v else v') map_x)) map; 
483 
end 
484  
485 
let pp_disjoint_map fmt map = 
486 
begin 
487 
Format.fprintf fmt "{ /* disjoint map */@."; 
488 
Hashtbl.iter (fun k v > Format.fprintf fmt "%s # { %a }@." k (Utils.fprintf_list ~sep:", " Printers.pp_var_name) (CISet.elements v)) map; 
489 
Format.fprintf fmt "}@." 
490 
end 
491 
end 
492  
493 
let pp_dep_graph fmt g = 
494 
begin 
495 
Format.fprintf fmt "{ /* graph */@."; 
496 
IdentDepGraph.iter_edges (fun s t > Format.fprintf fmt "%s > %s@." s t) g; 
497 
Format.fprintf fmt "}@." 
498 
end 
499  
500 
let pp_error fmt trace = 
501 
fprintf fmt "@.Causality error, cyclic data dependencies: %a@." 
502 
(fprintf_list ~sep:">" pp_print_string) trace 
503  
504 
(* Merges elements of graph [g2] into graph [g1] *) 
505 
let merge_with g1 g2 = 
506 
IdentDepGraph.iter_vertex (fun v > IdentDepGraph.add_vertex g1 v) g2; 
507 
IdentDepGraph.iter_edges (fun s t > IdentDepGraph.add_edge g1 s t) g2 
508  
509 
let global_dependency node = 
510 
let mems = ExprDep.node_memory_variables node in 
511 
let inputs = ExprDep.node_input_variables node in 
512 
let node_vars = ExprDep.node_variables node in 
513 
let (g_non_mems, g_mems) = ExprDep.dependence_graph mems inputs node_vars node in 
514 
(*Format.eprintf "g_non_mems: %a" pp_dep_graph g_non_mems; 
515 
Format.eprintf "g_mems: %a" pp_dep_graph g_mems;*) 
516 
CycleDetection.check_cycles g_non_mems; 
517 
let (vdecls', eqs', g_mems') = CycleDetection.break_cycles node mems g_mems in 
518 
(*Format.eprintf "g_mems': %a" pp_dep_graph g_mems';*) 
519 
merge_with g_non_mems g_mems'; 
520 
{ node with node_eqs = eqs'; node_locals = vdecls'@node.node_locals }, 
521 
g_non_mems 
522  
523  
524 
(* Local Variables: *) 
525 
(* compilecommand:"make C .." *) 
526 
(* End: *) 