lustrec / src / tools / seal_slice.ml @ 0d79d0f3
History  View  Annotate  Download (4.3 KB)
1 
open Lustre_types 

2 
open Utils 
3 
open Seal_utils 
4 

5 
(******************************************************************************) 
6 
(* Computing a slice of a node, selecting only some variables, based on *) 
7 
(* their COI (cone of influence) *) 
8 
(******************************************************************************) 
9  
10 
(* Basic functions to search into nodes. Could be moved to corelang eventually *) 
11 
let is_variable nd vid = 
12 
List.exists 
13 
(fun v > v.var_id = vid) 
14 
nd.node_locals 
15 

16 
let find_variable nd vid = 
17 
List.find 
18 
(fun v > v.var_id = vid) 
19 
nd.node_locals 
20  
21 
(* Returns the vars required to compute v. 
22 
Memories are specifically identified. *) 
23 
let coi_var deps nd v = 
24 
let vname = v.var_id in 
25 
let sliced_deps = 
26 
Causality.slice_graph deps vname 
27 
in 
28 
(* Format.eprintf "sliced graph for %a: %a@." 
29 
* Printers.pp_var v 
30 
* Causality.pp_dep_graph sliced_deps; *) 
31 
let vset, memset = 
32 
IdentDepGraph.fold_vertex 
33 
(fun vname (vset,memset) > 
34 
if Causality.ExprDep.is_read_var vname 
35 
then 
36 
let vname' = String.sub vname 1 (1 + String.length vname) in 
37 
if is_variable nd vname' then 
38 
ISet.add vname' vset, 
39 
ISet.add vname' memset 
40 
else 
41 
vset, memset 
42 
else 
43 
ISet.add vname vset, memset 
44 
) 
45 
sliced_deps 
46 
(ISet.singleton vname, ISet.empty) 
47 
in 
48 
report ~level:3 
49 
(fun fmt > Format.fprintf fmt "COI of var %s: (%a // %a)@." 
50 
v.var_id 
51 
(fprintf_list ~sep:"," Format.pp_print_string) (ISet.elements vset) 
52 
(fprintf_list ~sep:"," Format.pp_print_string) (ISet.elements memset) 
53 
) ; 
54 
vset, memset 
55 

56 

57 
(* Computes the variables required to compute vl. Variables /seen/ do not need 
58 
to be computed *) 
59 
let rec coi_vars deps nd vl seen = 
60 
let coi_vars = coi_vars deps nd in 
61 
List.fold_left 
62 
(fun accu v > 
63 
let vset, memset = coi_var deps nd v in 
64 
(* We handle the new mems discovered in the coi *) 
65 
let memset = 
66 
ISet.filter ( 
67 
fun vid > 
68 
not 
69 
(List.exists 
70 
(fun v > v.var_id = vid) 
71 
vl 
72 
) 
73 
) memset 
74 
in 
75 
let memset_vars = 
76 
ISet.fold ( 
77 
fun vid accu > 
78 
(find_variable nd vid)::accu 
79 
) memset [] 
80 
in 
81 
let vset' = 
82 
coi_vars memset_vars (vl@seen) 
83 
in 
84 
ISet.union accu (ISet.union vset vset') 
85 
) 
86 
ISet.empty 
87 
(List.filter 
88 
(fun v > not (List.mem v seen)) 
89 
vl 
90 
) 
91 

92 

93 
(* compute the coi of vars_to_keeps in node nd *) 
94 
let compute_sliced_vars vars_to_keep deps nd = 
95 
ISet.elements (coi_vars deps nd vars_to_keep []) 
96  
97  
98  
99  
100  
101 
(* If existing outputs are included in vars_to_keep, just slice the content. 
102 
Otherwise outputs are replaced by vars_to_keep *) 
103 
let slice_node vars_to_keep msch nd = 
104 
let coi_vars = 
105 
compute_sliced_vars vars_to_keep msch.Scheduling_type.dep_graph nd 
106 
in 
107 
report ~level:3 (fun fmt > Format.fprintf fmt 
108 
"COI Vars: %a@." 
109 
(Utils.fprintf_list ~sep:"," Format.pp_print_string) 
110 
coi_vars); 
111 
let outputs = 
112 
List.filter 
113 
( 
114 
fun v > List.mem v.var_id coi_vars 
115 
) nd.node_outputs 
116 
in 
117 
let outputs = match outputs with 
118 
[] > ( 
119 
report ~level:2 (fun fmt > Format.fprintf fmt "No visible output variable, subtituting with provided vars@ "); 
120 
vars_to_keep 
121 
) 
122 
 l > l 
123 
in 
124 
let locals = 
125 
List.filter (fun v > List.mem v.var_id coi_vars) nd.node_locals 
126 
in 
127  
128 
report ~level:3 (fun fmt > Format.fprintf fmt "Scheduling node@."); 
129 

130 
(* Split tuples while sorting eqs *) 
131 
let sorted_eqs = Scheduling.sort_equations_from_schedule nd msch.Scheduling_type.schedule in 
132  
133 
report ~level:3 (fun fmt > Format.fprintf fmt "Scheduled node@."); 
134  
135 
let stmts = 
136 
List.filter ( 
137 
fun (* stmt > 
138 
* match stmt with 
139 
*  Aut _ > assert false 
140 
*  Eq *) eq > ( 
141 
match eq.eq_lhs with 
142 
[vid] > List.mem vid coi_vars 
143 
 _ > Format.eprintf "Faulty statement: %a@.@?" Printers.pp_node_eq eq; assert false 
144 
(* should not happen after inlining and normalization *) 
145 
) 
146 
) sorted_eqs 
147 
in 
148 
{ nd 
149 
with 
150 
node_outputs = outputs; 
151 
node_locals = locals; 
152 
node_stmts = List.map (fun e > Eq e) stmts 
153 
} 