lustrec / src / inliner.ml @ 53206908
History | View | Annotate | Download (16.1 KB)
1 |
(********************************************************************) |
---|---|
2 |
(* *) |
3 |
(* The LustreC compiler toolset / The LustreC Development Team *) |
4 |
(* Copyright 2012 - -- ONERA - CNRS - INPT *) |
5 |
(* *) |
6 |
(* LustreC is free software, distributed WITHOUT ANY WARRANTY *) |
7 |
(* under the terms of the GNU Lesser General Public License *) |
8 |
(* version 2.1. *) |
9 |
(* *) |
10 |
(********************************************************************) |
11 |
|
12 |
open LustreSpec |
13 |
open Corelang |
14 |
open Utils |
15 |
|
16 |
(* Local annotations are declared with the following key /inlining/: true *) |
17 |
let keyword = ["inlining"] |
18 |
|
19 |
let is_inline_expr expr = |
20 |
match expr.expr_annot with |
21 |
| Some ann -> |
22 |
List.exists (fun (key, value) -> key = keyword) ann.annots |
23 |
| None -> false |
24 |
|
25 |
let check_node_name id = (fun t -> |
26 |
match t.top_decl_desc with |
27 |
| Node nd -> nd.node_id = id |
28 |
| _ -> false) |
29 |
|
30 |
let is_node_var node v = |
31 |
try |
32 |
ignore (Corelang.get_node_var v node); true |
33 |
with Not_found -> false |
34 |
|
35 |
let rename_expr rename expr = expr_replace_var rename expr |
36 |
|
37 |
let rename_eq rename eq = |
38 |
{ eq with |
39 |
eq_lhs = List.map rename eq.eq_lhs; |
40 |
eq_rhs = rename_expr rename eq.eq_rhs |
41 |
} |
42 |
(* |
43 |
let get_static_inputs input_arg_list = |
44 |
List.fold_right (fun (vdecl, arg) res -> |
45 |
if vdecl.var_dec_const |
46 |
then (vdecl.var_id, Corelang.dimension_of_expr arg) :: res |
47 |
else res) |
48 |
input_arg_list [] |
49 |
|
50 |
let get_carrier_inputs input_arg_list = |
51 |
List.fold_right (fun (vdecl, arg) res -> |
52 |
if Corelang.is_clock_dec_type vdecl.var_dec_type.ty_dec_desc |
53 |
then (vdecl.var_id, ident_of_expr arg) :: res |
54 |
else res) |
55 |
input_arg_list [] |
56 |
*) |
57 |
(* |
58 |
expr, locals', eqs = inline_call id args' reset locals node nodes |
59 |
|
60 |
We select the called node equations and variables. |
61 |
renamed_inputs = args |
62 |
renamed_eqs |
63 |
|
64 |
the resulting expression is tuple_of_renamed_outputs |
65 |
|
66 |
TODO: convert the specification/annotation/assert and inject them |
67 |
DONE: annotations |
68 |
TODO: deal with reset |
69 |
*) |
70 |
let inline_call node orig_expr args reset locals caller = |
71 |
let loc = orig_expr.expr_loc in |
72 |
let uid = orig_expr.expr_tag in |
73 |
let rename v = |
74 |
if v = tag_true || v = tag_false || not (is_node_var node v) then v else |
75 |
Corelang.mk_new_node_name caller (Format.sprintf "%s_%i_%s" node.node_id uid v) |
76 |
in |
77 |
let eqs' = List.map (rename_eq rename) (get_node_eqs node) in |
78 |
let input_arg_list = List.combine node.node_inputs (Corelang.expr_list_of_expr args) in |
79 |
let static_inputs, dynamic_inputs = List.partition (fun (vdecl, arg) -> vdecl.var_dec_const) input_arg_list in |
80 |
let static_inputs = List.map (fun (vdecl, arg) -> vdecl, Corelang.dimension_of_expr arg) static_inputs in |
81 |
let carrier_inputs, other_inputs = List.partition (fun (vdecl, arg) -> Corelang.is_clock_dec_type vdecl.var_dec_type.ty_dec_desc) dynamic_inputs in |
82 |
let carrier_inputs = List.map (fun (vdecl, arg) -> vdecl, Corelang.ident_of_expr arg) carrier_inputs in |
83 |
let rename_static v = |
84 |
try |
85 |
snd (List.find (fun (vdecl, _) -> v = vdecl.var_id) static_inputs) |
86 |
with Not_found -> Dimension.mkdim_ident loc v in |
87 |
let rename_carrier v = |
88 |
try |
89 |
snd (List.find (fun (vdecl, _) -> v = vdecl.var_id) carrier_inputs) |
90 |
with Not_found -> v in |
91 |
let rename_var v = |
92 |
let vdecl = |
93 |
Corelang.mkvar_decl v.var_loc |
94 |
(rename v.var_id, |
95 |
{ v.var_dec_type with ty_dec_desc = Corelang.rename_static rename_static v.var_dec_type.ty_dec_desc }, |
96 |
{ v.var_dec_clock with ck_dec_desc = Corelang.rename_carrier rename_carrier v.var_dec_clock.ck_dec_desc }, |
97 |
v.var_dec_const, |
98 |
Utils.option_map (rename_expr rename) v.var_dec_value) in |
99 |
begin |
100 |
(* |
101 |
(try |
102 |
Format.eprintf "Inliner.inline_call unify %a %a@." Types.print_ty vdecl.var_type Dimension.pp_dimension (List.assoc v.var_id static_inputs); |
103 |
Typing.unify vdecl.var_type (Type_predef.type_static (List.assoc v.var_id static_inputs) (Types.new_var ())) |
104 |
with Not_found -> ()); |
105 |
(try |
106 |
Clock_calculus.unify vdecl.var_clock (Clock_predef.ck_carrier (List.assoc v.var_id carrier_inputs) (Clocks.new_var true)) |
107 |
with Not_found -> ()); |
108 |
(*Format.eprintf "Inliner.inline_call res=%a@." Printers.pp_var vdecl;*) |
109 |
*) |
110 |
vdecl |
111 |
end |
112 |
(*Format.eprintf "Inliner.rename_var %a@." Printers.pp_var v;*) |
113 |
in |
114 |
let inputs' = List.map (fun (vdecl, _) -> rename_var vdecl) dynamic_inputs in |
115 |
let outputs' = List.map rename_var node.node_outputs in |
116 |
let locals' = |
117 |
(List.map (fun (vdecl, arg) -> let vdecl' = rename_var vdecl in { vdecl' with var_dec_value = Some (Corelang.expr_of_dimension arg) }) static_inputs) |
118 |
@ (List.map rename_var node.node_locals) |
119 |
in |
120 |
(* checking we are at the appropriate (early) step: node_checks and |
121 |
node_gencalls should be empty (not yet assigned) *) |
122 |
assert (node.node_checks = []); |
123 |
assert (node.node_gencalls = []); |
124 |
|
125 |
(* Bug included: todo deal with reset *) |
126 |
assert (reset = None); |
127 |
|
128 |
let assign_inputs = mkeq loc (List.map (fun v -> v.var_id) inputs', expr_of_expr_list args.expr_loc (List.map snd dynamic_inputs)) in |
129 |
let expr = expr_of_expr_list loc (List.map expr_of_vdecl outputs') |
130 |
in |
131 |
let asserts' = (* We rename variables in assert expressions *) |
132 |
List.map |
133 |
(fun a -> |
134 |
{a with assert_expr = |
135 |
let expr = a.assert_expr in |
136 |
rename_expr rename expr |
137 |
}) |
138 |
node.node_asserts |
139 |
in |
140 |
let annots' = |
141 |
Plugins.inline_annots rename node.node_annot |
142 |
in |
143 |
expr, |
144 |
inputs'@outputs'@locals'@locals, |
145 |
assign_inputs::eqs', |
146 |
asserts', |
147 |
annots' |
148 |
|
149 |
|
150 |
|
151 |
let inline_table = Hashtbl.create 23 |
152 |
|
153 |
(* |
154 |
new_expr, new_locals, new_eqs = inline_expr expr locals node nodes |
155 |
|
156 |
Each occurence of a node in nodes in the expr should be replaced by fresh |
157 |
variables and the code of called node instance added to new_eqs |
158 |
|
159 |
*) |
160 |
let rec inline_expr ?(selection_on_annotation=false) expr locals node nodes = |
161 |
let inline_expr = inline_expr ~selection_on_annotation:selection_on_annotation in |
162 |
let inline_node = inline_node ~selection_on_annotation:selection_on_annotation in |
163 |
let inline_tuple el = |
164 |
List.fold_right (fun e (el_tail, locals, eqs, asserts, annots) -> |
165 |
let e', locals', eqs', asserts', annots' = inline_expr e locals node nodes in |
166 |
e'::el_tail, locals', eqs'@eqs, asserts@asserts', annots@annots' |
167 |
) el ([], locals, [], [], []) |
168 |
in |
169 |
let inline_pair e1 e2 = |
170 |
let el', l', eqs', asserts', annots' = inline_tuple [e1;e2] in |
171 |
match el' with |
172 |
| [e1'; e2'] -> e1', e2', l', eqs', asserts', annots' |
173 |
| _ -> assert false |
174 |
in |
175 |
let inline_triple e1 e2 e3 = |
176 |
let el', l', eqs', asserts', annots' = inline_tuple [e1;e2;e3] in |
177 |
match el' with |
178 |
| [e1'; e2'; e3'] -> e1', e2', e3', l', eqs', asserts', annots' |
179 |
| _ -> assert false |
180 |
in |
181 |
|
182 |
match expr.expr_desc with |
183 |
| Expr_appl (id, args, reset) -> |
184 |
let args', locals', eqs', asserts', annots' = inline_expr args locals node nodes in |
185 |
if List.exists (check_node_name id) nodes && (* the current node call is provided |
186 |
as arguments nodes *) |
187 |
(not selection_on_annotation || is_inline_expr expr) (* and if selection on annotation is activated, |
188 |
it is explicitely inlined here *) |
189 |
then |
190 |
(* The node should be inlined *) |
191 |
(* let _ = Format.eprintf "Inlining call to %s@." id in *) |
192 |
let called = try List.find (check_node_name id) nodes |
193 |
with Not_found -> (assert false) in |
194 |
let called = node_of_top called in |
195 |
let called' = inline_node called nodes in |
196 |
let expr, locals', eqs'', asserts'', annots'' = |
197 |
inline_call called' expr args' reset locals' node in |
198 |
expr, locals', eqs'@eqs'', asserts'@asserts'', annots'@annots'' |
199 |
else |
200 |
(* let _ = Format.eprintf "Not inlining call to %s@." id in *) |
201 |
{ expr with expr_desc = Expr_appl(id, args', reset)}, |
202 |
locals', |
203 |
eqs', |
204 |
asserts', |
205 |
annots' |
206 |
|
207 |
(* For other cases, we just keep the structure, but convert sub-expressions *) |
208 |
| Expr_const _ |
209 |
| Expr_ident _ -> expr, locals, [], [], [] |
210 |
| Expr_tuple el -> |
211 |
let el', l', eqs', asserts', annots' = inline_tuple el in |
212 |
{ expr with expr_desc = Expr_tuple el' }, l', eqs', asserts', annots' |
213 |
| Expr_ite (g, t, e) -> |
214 |
let g', t', e', l', eqs', asserts', annots' = inline_triple g t e in |
215 |
{ expr with expr_desc = Expr_ite (g', t', e') }, l', eqs', asserts', annots' |
216 |
| Expr_arrow (e1, e2) -> |
217 |
let e1', e2', l', eqs', asserts', annots' = inline_pair e1 e2 in |
218 |
{ expr with expr_desc = Expr_arrow (e1', e2') } , l', eqs', asserts', annots' |
219 |
| Expr_fby (e1, e2) -> |
220 |
let e1', e2', l', eqs', asserts', annots' = inline_pair e1 e2 in |
221 |
{ expr with expr_desc = Expr_fby (e1', e2') }, l', eqs', asserts', annots' |
222 |
| Expr_array el -> |
223 |
let el', l', eqs', asserts', annots' = inline_tuple el in |
224 |
{ expr with expr_desc = Expr_array el' }, l', eqs', asserts', annots' |
225 |
| Expr_access (e, dim) -> |
226 |
let e', l', eqs', asserts', annots' = inline_expr e locals node nodes in |
227 |
{ expr with expr_desc = Expr_access (e', dim) }, l', eqs', asserts', annots' |
228 |
| Expr_power (e, dim) -> |
229 |
let e', l', eqs', asserts', annots' = inline_expr e locals node nodes in |
230 |
{ expr with expr_desc = Expr_power (e', dim) }, l', eqs', asserts', annots' |
231 |
| Expr_pre e -> |
232 |
let e', l', eqs', asserts', annots' = inline_expr e locals node nodes in |
233 |
{ expr with expr_desc = Expr_pre e' }, l', eqs', asserts', annots' |
234 |
| Expr_when (e, id, label) -> |
235 |
let e', l', eqs', asserts', annots' = inline_expr e locals node nodes in |
236 |
{ expr with expr_desc = Expr_when (e', id, label) }, l', eqs', asserts', annots' |
237 |
| Expr_merge (id, branches) -> |
238 |
let el, l', eqs', asserts', annots' = inline_tuple (List.map snd branches) in |
239 |
let branches' = List.map2 (fun (label, _) v -> label, v) branches el in |
240 |
{ expr with expr_desc = Expr_merge (id, branches') }, l', eqs', asserts', annots' |
241 |
|
242 |
and inline_node ?(selection_on_annotation=false) node nodes = |
243 |
try copy_node (Hashtbl.find inline_table node.node_id) |
244 |
with Not_found -> |
245 |
let inline_expr = inline_expr ~selection_on_annotation:selection_on_annotation in |
246 |
let new_locals, eqs, asserts, annots = |
247 |
List.fold_left (fun (locals, eqs, asserts, annots) eq -> |
248 |
let eq_rhs', locals', new_eqs', asserts', annots' = |
249 |
inline_expr eq.eq_rhs locals node nodes |
250 |
in |
251 |
locals', { eq with eq_rhs = eq_rhs' }::new_eqs'@eqs, asserts'@asserts, annots'@annots |
252 |
) (node.node_locals, [], node.node_asserts, node.node_annot) (get_node_eqs node) |
253 |
in |
254 |
let inlined = |
255 |
{ node with |
256 |
node_locals = new_locals; |
257 |
node_stmts = List.map (fun eq -> Eq eq) eqs; |
258 |
node_asserts = asserts; |
259 |
node_annot = annots; |
260 |
} |
261 |
in |
262 |
begin |
263 |
(*Format.eprintf "inline node:<< %a@.>>@." Printers.pp_node inlined;*) |
264 |
Hashtbl.add inline_table node.node_id inlined; |
265 |
inlined |
266 |
end |
267 |
|
268 |
let inline_all_calls node nodes = |
269 |
let nd = match node.top_decl_desc with Node nd -> nd | _ -> assert false in |
270 |
{ node with top_decl_desc = Node (inline_node nd nodes) } |
271 |
|
272 |
|
273 |
|
274 |
|
275 |
|
276 |
let witness filename main_name orig inlined type_env clock_env = |
277 |
let loc = Location.dummy_loc in |
278 |
let rename_local_node nodes prefix id = |
279 |
if List.exists (check_node_name id) nodes then |
280 |
prefix ^ id |
281 |
else |
282 |
id |
283 |
in |
284 |
let main_orig_node = match (List.find (check_node_name main_name) orig).top_decl_desc with |
285 |
Node nd -> nd | _ -> assert false in |
286 |
|
287 |
let orig_rename = rename_local_node orig "orig_" in |
288 |
let inlined_rename = rename_local_node inlined "inlined_" in |
289 |
let identity = (fun x -> x) in |
290 |
let is_node top = match top.top_decl_desc with Node _ -> true | _ -> false in |
291 |
let orig = rename_prog orig_rename identity identity orig in |
292 |
let inlined = rename_prog inlined_rename identity identity inlined in |
293 |
let nodes_origs, others = List.partition is_node orig in |
294 |
let nodes_inlined, _ = List.partition is_node inlined in |
295 |
|
296 |
(* One ok_i boolean variable per output var *) |
297 |
let nb_outputs = List.length main_orig_node.node_outputs in |
298 |
let ok_ident = "OK" in |
299 |
let ok_i = List.map (fun id -> |
300 |
mkvar_decl |
301 |
loc |
302 |
(Format.sprintf "%s_%i" ok_ident id, |
303 |
{ty_dec_desc=Tydec_bool; ty_dec_loc=loc}, |
304 |
{ck_dec_desc=Ckdec_any; ck_dec_loc=loc}, |
305 |
false, |
306 |
None) |
307 |
) (Utils.enumerate nb_outputs) |
308 |
in |
309 |
|
310 |
(* OK = ok_1 and ok_2 and ... ok_n-1 *) |
311 |
let ok_output = mkvar_decl |
312 |
loc |
313 |
(ok_ident, |
314 |
{ty_dec_desc=Tydec_bool; ty_dec_loc=loc}, |
315 |
{ck_dec_desc=Ckdec_any; ck_dec_loc=loc}, |
316 |
false, |
317 |
None) |
318 |
in |
319 |
let main_ok_expr = |
320 |
let mkv x = mkexpr loc (Expr_ident x) in |
321 |
match ok_i with |
322 |
| [] -> assert false |
323 |
| [x] -> mkv x.var_id |
324 |
| hd::tl -> |
325 |
List.fold_left (fun accu elem -> |
326 |
mkpredef_call loc "&&" [mkv elem.var_id; accu] |
327 |
) (mkv hd.var_id) tl |
328 |
in |
329 |
|
330 |
(* Building main node *) |
331 |
|
332 |
let ok_i_eq = |
333 |
{ eq_loc = loc; |
334 |
eq_lhs = List.map (fun v -> v.var_id) ok_i; |
335 |
eq_rhs = |
336 |
let inputs = expr_of_expr_list loc (List.map (fun v -> mkexpr loc (Expr_ident v.var_id)) main_orig_node.node_inputs) in |
337 |
let call_orig = |
338 |
mkexpr loc (Expr_appl ("orig_" ^ main_name, inputs, None)) in |
339 |
let call_inlined = |
340 |
mkexpr loc (Expr_appl ("inlined_" ^ main_name, inputs, None)) in |
341 |
let args = mkexpr loc (Expr_tuple [call_orig; call_inlined]) in |
342 |
mkexpr loc (Expr_appl ("=", args, None)) |
343 |
} in |
344 |
let ok_eq = |
345 |
{ eq_loc = loc; |
346 |
eq_lhs = [ok_ident]; |
347 |
eq_rhs = main_ok_expr; |
348 |
} in |
349 |
let main_node = { |
350 |
node_id = "check"; |
351 |
node_type = Types.new_var (); |
352 |
node_clock = Clocks.new_var true; |
353 |
node_inputs = main_orig_node.node_inputs; |
354 |
node_outputs = [ok_output]; |
355 |
node_locals = ok_i; |
356 |
node_gencalls = []; |
357 |
node_checks = []; |
358 |
node_asserts = []; |
359 |
node_stmts = [Eq ok_i_eq; Eq ok_eq]; |
360 |
node_dec_stateless = false; |
361 |
node_stateless = None; |
362 |
node_spec = Some |
363 |
{requires = []; |
364 |
ensures = [mkeexpr loc (mkexpr loc (Expr_ident ok_ident))]; |
365 |
behaviors = []; |
366 |
spec_loc = loc |
367 |
}; |
368 |
node_annot = []; |
369 |
} |
370 |
in |
371 |
let main = [{ top_decl_desc = Node main_node; top_decl_loc = loc; top_decl_owner = filename; top_decl_itf = false }] in |
372 |
let new_prog = others@nodes_origs@nodes_inlined@main in |
373 |
(* |
374 |
let _ = Typing.type_prog type_env new_prog in |
375 |
let _ = Clock_calculus.clock_prog clock_env new_prog in |
376 |
*) |
377 |
|
378 |
let witness_file = (Options.get_witness_dir filename) ^ "/" ^ "inliner_witness.lus" in |
379 |
let witness_out = open_out witness_file in |
380 |
let witness_fmt = Format.formatter_of_out_channel witness_out in |
381 |
begin |
382 |
List.iter (fun vdecl -> Typing.try_unify Type_predef.type_bool vdecl.var_type vdecl.var_loc) (ok_output::ok_i); |
383 |
Format.fprintf witness_fmt |
384 |
"(* Generated lustre file to check validity of inlining process *)@."; |
385 |
Printers.pp_prog witness_fmt new_prog; |
386 |
Format.fprintf witness_fmt "@."; |
387 |
() |
388 |
end (* xx *) |
389 |
|
390 |
let global_inline basename prog type_env clock_env = |
391 |
(* We select the main node desc *) |
392 |
let main_node, other_nodes, other_tops = |
393 |
List.fold_right |
394 |
(fun top (main_opt, nodes, others) -> |
395 |
match top.top_decl_desc with |
396 |
| Node nd when nd.node_id = !Options.main_node -> |
397 |
Some top, nodes, others |
398 |
| Node _ -> main_opt, top::nodes, others |
399 |
| _ -> main_opt, nodes, top::others) |
400 |
prog (None, [], []) |
401 |
in |
402 |
(* Recursively each call of a node in the top node is replaced *) |
403 |
let main_node = Utils.desome main_node in |
404 |
let main_node' = inline_all_calls main_node other_nodes in |
405 |
let res = main_node'::other_tops in |
406 |
if !Options.witnesses then ( |
407 |
witness |
408 |
basename |
409 |
(match main_node.top_decl_desc with Node nd -> nd.node_id | _ -> assert false) |
410 |
prog res type_env clock_env |
411 |
); |
412 |
res |
413 |
|
414 |
let local_inline basename prog type_env clock_env = |
415 |
let local_anns = Annotations.get_expr_annotations keyword in |
416 |
if local_anns != [] then ( |
417 |
let nodes_with_anns = List.fold_left (fun accu (k, _) -> ISet.add k accu) ISet.empty local_anns in |
418 |
ISet.iter (fun node_id -> Format.eprintf "Node %s has local expression annotations@." node_id) nodes_with_anns; |
419 |
List.fold_right (fun top accu -> |
420 |
( match top.top_decl_desc with |
421 |
| Node nd when ISet.mem nd.node_id nodes_with_anns -> |
422 |
{ top with top_decl_desc = Node (inline_node ~selection_on_annotation:true nd prog) } |
423 |
| _ -> top |
424 |
)::accu) prog [] |
425 |
|
426 |
) |
427 |
else |
428 |
prog |
429 |
|
430 |
|
431 |
(* Local Variables: *) |
432 |
(* compile-command:"make -C .." *) |
433 |
(* End: *) |