lustrec / src / plugins / salsa / salsaDatatypes.ml @ 53206908
History | View | Annotate | Download (11.4 KB)
1 |
module LT = LustreSpec |
---|---|
2 |
module MC = Machine_code |
3 |
module ST = Salsa.SalsaTypes |
4 |
module Float = Salsa.Float |
5 |
|
6 |
let debug = true |
7 |
|
8 |
let pp_hash ~sep f fmt r = |
9 |
Format.fprintf fmt "[@[<v>"; |
10 |
Hashtbl.iter (fun k v -> Format.fprintf fmt "%t%s@ " (f k v) sep) r; |
11 |
Format.fprintf fmt "]@]"; |
12 |
|
13 |
module FormalEnv = |
14 |
struct |
15 |
type fe_t = (LT.ident, (int * LT.value_t)) Hashtbl.t |
16 |
let cpt = ref 0 |
17 |
|
18 |
exception NoDefinition of LT.var_decl |
19 |
(* Returns the expression associated to v in env *) |
20 |
let get_def (env: fe_t) v = |
21 |
try |
22 |
snd (Hashtbl.find env v.LT.var_id) |
23 |
with Not_found -> raise (NoDefinition v) |
24 |
|
25 |
let def (env: fe_t) d expr = |
26 |
incr cpt; |
27 |
let fresh = Hashtbl.copy env in |
28 |
Hashtbl.add fresh d.LT.var_id (!cpt, expr); fresh |
29 |
|
30 |
let empty (): fe_t = Hashtbl.create 13 |
31 |
|
32 |
let pp fmt env = pp_hash ~sep:";" (fun k (_,v) fmt -> Format.fprintf fmt "%s -> %a" k MC.pp_val v) fmt env |
33 |
|
34 |
let fold f = Hashtbl.fold (fun k (_,v) accu -> f k v accu) |
35 |
|
36 |
let get_sort_fun env = |
37 |
let order = Hashtbl.fold (fun k (cpt, _) accu -> (k,cpt)::accu) env [] in |
38 |
fun v1 v2 -> |
39 |
if List.mem_assoc v1.LT.var_id order && List.mem_assoc v2.LT.var_id order then |
40 |
if (List.assoc v1.LT.var_id order) <= (List.assoc v2.LT.var_id order) then |
41 |
-1 |
42 |
else |
43 |
1 |
44 |
else |
45 |
assert false |
46 |
|
47 |
end |
48 |
|
49 |
module Ranges = |
50 |
functor (Value: sig type t val union: t -> t -> t val pp: Format.formatter -> t -> unit end) -> |
51 |
struct |
52 |
type t = Value.t |
53 |
type r_t = (LT.ident, Value.t) Hashtbl.t |
54 |
|
55 |
let empty: r_t = Hashtbl.create 13 |
56 |
|
57 |
(* Look for def of node i with inputs living in vtl_ranges, reinforce ranges |
58 |
to bound vdl: each output of node i *) |
59 |
let add_call ranges vdl id vtl_ranges = ranges (* TODO assert false. On est |
60 |
pas obligé de faire |
61 |
qqchose. On peut supposer |
62 |
que les ranges sont donnés |
63 |
pour chaque noeud *) |
64 |
|
65 |
|
66 |
let pp = pp_hash ~sep:";" (fun k v fmt -> Format.fprintf fmt "%s -> %a" k Value.pp v) |
67 |
let pp_val = Value.pp |
68 |
|
69 |
let add_def ranges name r = |
70 |
(* Format.eprintf "%s: declare %a@." *) |
71 |
(* x.LT.var_id *) |
72 |
(* Value.pp r ; *) |
73 |
|
74 |
let fresh = Hashtbl.copy ranges in |
75 |
Hashtbl.add fresh name r; fresh |
76 |
|
77 |
let enlarge ranges name r = |
78 |
let fresh = Hashtbl.copy ranges in |
79 |
if Hashtbl.mem fresh name then |
80 |
Hashtbl.replace fresh name (Value.union r (Hashtbl.find fresh name)) |
81 |
else |
82 |
Hashtbl.add fresh name r; |
83 |
fresh |
84 |
|
85 |
|
86 |
(* Compute a join per variable *) |
87 |
let merge ranges1 ranges2 = |
88 |
Format.eprintf "Mergeing rangesint %a with %a@." pp ranges1 pp ranges2; |
89 |
let ranges = Hashtbl.copy ranges1 in |
90 |
Hashtbl.iter (fun k v -> |
91 |
if Hashtbl.mem ranges k then ( |
92 |
(* Format.eprintf "%s: %a union %a = %a@." *) |
93 |
(* k *) |
94 |
(* Value.pp v *) |
95 |
(* Value.pp (Hashtbl.find ranges k) *) |
96 |
(* Value.pp (Value.union v (Hashtbl.find ranges k)); *) |
97 |
Hashtbl.replace ranges k (Value.union v (Hashtbl.find ranges k)) |
98 |
) |
99 |
else |
100 |
Hashtbl.add ranges k v |
101 |
) ranges2; |
102 |
Format.eprintf "Merge result %a@." pp ranges; |
103 |
ranges |
104 |
|
105 |
end |
106 |
|
107 |
module FloatIntSalsa = |
108 |
struct |
109 |
type t = ST.abstractValue |
110 |
|
111 |
let pp fmt (f,r) = |
112 |
match f, r with |
113 |
| ST.I(a,b), ST.J(c,d) -> |
114 |
Format.fprintf fmt "[%f, %f] + [%f, %f]" a b c d |
115 |
| ST.I(a,b), ST.JInfty -> Format.fprintf fmt "[%f, %f] + oo" a b |
116 |
| ST.Empty, _ -> Format.fprintf fmt "???" |
117 |
|
118 |
| _ -> assert false |
119 |
|
120 |
let union v1 v2 = |
121 |
match v1, v2 with |
122 |
|(ST.I(x1, x2), ST.J(y1, y2)), (ST.I(x1', x2'), ST.J(y1', y2')) -> |
123 |
ST.(I(min x1 x1', max x2 x2'), J(min y1 y1', max y2 y2')) |
124 |
| _ -> Format.eprintf "%a cup %a failed@.@?" pp v1 pp v2; assert false |
125 |
|
126 |
let inject cst = match cst with |
127 |
| LT.Const_int(i) -> Salsa.Builder.mk_cst (ST.I(float_of_int i,float_of_int i),ST.J(0.0,0.0)) |
128 |
| LT.Const_real (c,e,s) -> (* TODO: this is incorrect. We should rather |
129 |
compute the error associated to the float *) |
130 |
let r = float_of_string s in |
131 |
if r = 0. then |
132 |
Salsa.Builder.mk_cst (ST.I(-. min_float, min_float),Float.ulp (ST.I(-. min_float, min_float))) |
133 |
else |
134 |
Salsa.Builder.mk_cst (ST.I(r*.(1.-.epsilon_float),r*.(1.+.epsilon_float)),Float.ulp (ST.I(r,r))) |
135 |
| _ -> assert false |
136 |
end |
137 |
|
138 |
module RangesInt = Ranges (FloatIntSalsa) |
139 |
|
140 |
module Vars = |
141 |
struct |
142 |
module VarSet = Set.Make (struct type t = LT.var_decl let compare x y = compare x.LT.var_id y.LT.var_id end) |
143 |
let real_vars vs = VarSet.filter (fun v -> Types.is_real_type v.LT.var_type) vs |
144 |
let of_list = List.fold_left (fun s e -> VarSet.add e s) VarSet.empty |
145 |
|
146 |
include VarSet |
147 |
|
148 |
let remove_list (set:t) (v_list: elt list) : t = List.fold_right VarSet.remove v_list set |
149 |
let pp fmt vs = Utils.fprintf_list ~sep:", " Printers.pp_var fmt (VarSet.elements vs) |
150 |
end |
151 |
|
152 |
|
153 |
|
154 |
|
155 |
|
156 |
|
157 |
|
158 |
|
159 |
|
160 |
|
161 |
(*************************************************************************************) |
162 |
(* Converting values back and forth *) |
163 |
(*************************************************************************************) |
164 |
|
165 |
let rec value_t2salsa_expr constEnv vt = |
166 |
let value_t2salsa_expr = value_t2salsa_expr constEnv in |
167 |
let res = |
168 |
match vt.LT.value_desc with |
169 |
(* | LT.Cst(LT.Const_tag(t) as c) -> *) |
170 |
(* Format.eprintf "v2s: cst tag@."; *) |
171 |
(* if List.mem_assoc t constEnv then ( *) |
172 |
(* Format.eprintf "trouvé la constante %s: %a@ " t Printers.pp_const c; *) |
173 |
(* FloatIntSalsa.inject (List.assoc t constEnv) *) |
174 |
(* ) *) |
175 |
(* else ( *) |
176 |
(* Format.eprintf "Const tag %s unhandled@.@?" t ; *) |
177 |
(* raise (Salsa.Prelude.Error ("Entschuldigung6, constant tag not yet implemented")) *) |
178 |
(* ) *) |
179 |
| LT.Cst(cst) -> Format.eprintf "v2s: cst tag 2: %a@." Printers.pp_const cst; FloatIntSalsa.inject cst |
180 |
| LT.LocalVar(v) |
181 |
| LT.StateVar(v) -> Format.eprintf "v2s: var %s@." v.LT.var_id; |
182 |
let sel_fun = (fun (vname, _) -> v.LT.var_id = vname) in |
183 |
if List.exists sel_fun constEnv then |
184 |
let _, cst = List.find sel_fun constEnv in |
185 |
FloatIntSalsa.inject cst |
186 |
else |
187 |
let id = v.LT.var_id in |
188 |
Salsa.Builder.mk_id id |
189 |
| LT.Fun(binop, [x;y]) -> let salsaX = value_t2salsa_expr x in |
190 |
let salsaY = value_t2salsa_expr y in |
191 |
let op = ( |
192 |
let pred f x y = Salsa.Builder.mk_int_of_bool (f x y) in |
193 |
match binop with |
194 |
| "+" -> Salsa.Builder.mk_plus |
195 |
| "-" -> Salsa.Builder.mk_minus |
196 |
| "*" -> Salsa.Builder.mk_times |
197 |
| "/" -> Salsa.Builder.mk_div |
198 |
| "=" -> pred Salsa.Builder.mk_eq |
199 |
| "<" -> pred Salsa.Builder.mk_lt |
200 |
| ">" -> pred Salsa.Builder.mk_gt |
201 |
| "<=" -> pred Salsa.Builder.mk_lte |
202 |
| ">=" -> pred Salsa.Builder.mk_gte |
203 |
| _ -> assert false |
204 |
) |
205 |
in |
206 |
op salsaX salsaY |
207 |
| LT.Fun(unop, [x]) -> let salsaX = value_t2salsa_expr x in |
208 |
Salsa.Builder.mk_uminus salsaX |
209 |
|
210 |
| LT.Fun(f,_) -> raise (Salsa.Prelude.Error |
211 |
("Unhandled function "^f^" in conversion to salsa expression")) |
212 |
|
213 |
| LT.Array(_) |
214 |
| LT.Access(_) |
215 |
| LT.Power(_) -> raise (Salsa.Prelude.Error ("Unhandled construct in conversion to salsa expression")) |
216 |
in |
217 |
(* if debug then *) |
218 |
(* Format.eprintf "value_t2salsa_expr: %a -> %a@ " *) |
219 |
(* MC.pp_val vt *) |
220 |
(* (fun fmt x -> Format.fprintf fmt "%s" (Salsa.Print.printExpression x)) res; *) |
221 |
res |
222 |
|
223 |
type var_decl = { vdecl: LT.var_decl; is_local: bool } |
224 |
module VarEnv = Map.Make (struct type t = LT.ident let compare = compare end ) |
225 |
|
226 |
(* let is_local_var vars_env v = *) |
227 |
(* try *) |
228 |
(* (VarEnv.find v vars_env).is_local *) |
229 |
(* with Not_found -> Format.eprintf "Impossible to find var %s@.@?" v; assert false *) |
230 |
|
231 |
let get_var vars_env v = |
232 |
try |
233 |
VarEnv.find v vars_env |
234 |
with Not_found -> Format.eprintf "Impossible to find var %s@.@?" v; assert false |
235 |
|
236 |
let compute_vars_env m = |
237 |
let env = VarEnv.empty in |
238 |
let env = |
239 |
List.fold_left |
240 |
(fun accu v -> VarEnv.add v.LT.var_id {vdecl = v; is_local = false; } accu) |
241 |
env |
242 |
m.MC.mmemory |
243 |
in |
244 |
let env = |
245 |
List.fold_left ( |
246 |
fun accu v -> VarEnv.add v.LT.var_id {vdecl = v; is_local = true; } accu |
247 |
) |
248 |
env |
249 |
MC.(m.mstep.step_inputs@m.mstep.step_outputs@m.mstep.step_locals) |
250 |
in |
251 |
env |
252 |
|
253 |
let rec salsa_expr2value_t vars_env cst_env e = |
254 |
let salsa_expr2value_t = salsa_expr2value_t vars_env cst_env in |
255 |
let binop op e1 e2 t = |
256 |
let x = salsa_expr2value_t e1 in |
257 |
let y = salsa_expr2value_t e2 in |
258 |
MC.mk_val (LT.Fun (op, [x;y])) t |
259 |
in |
260 |
match e with |
261 |
ST.Cst((ST.I(f1,f2),_),_) -> (* We project ranges into constants. We |
262 |
forget about errors and provide the |
263 |
mean/middle value of the interval |
264 |
*) |
265 |
let new_float = |
266 |
if f1 = f2 then |
267 |
f1 |
268 |
else |
269 |
(f1 +. f2) /. 2.0 |
270 |
in |
271 |
Format.eprintf "Converting [%.45f, %.45f] in %.45f@." f1 f2 new_float; |
272 |
let cst = |
273 |
let s = |
274 |
if new_float = 0. then "0." else |
275 |
(* We have to convert it into our format: int * int * real *) |
276 |
let _ = Format.flush_str_formatter () in |
277 |
Format.fprintf Format.str_formatter "%.50f" new_float; |
278 |
Format.flush_str_formatter () |
279 |
in |
280 |
Parser_lustre.signed_const Lexer_lustre.token (Lexing.from_string s) |
281 |
in |
282 |
MC.mk_val (LT.Cst(cst)) Type_predef.type_real |
283 |
| ST.Id(id, _) -> |
284 |
Format.eprintf "Looking for id=%s@.@?" id; |
285 |
if List.mem_assoc id cst_env then ( |
286 |
let cst = List.assoc id cst_env in |
287 |
Format.eprintf "Found cst = %a@.@?" Printers.pp_const cst; |
288 |
MC.mk_val (LT.Cst cst) Type_predef.type_real |
289 |
) |
290 |
else |
291 |
(* if is_const salsa_label then *) |
292 |
(* MC.Cst(LT.Const_tag(get_const salsa_label)) *) |
293 |
(* else *) |
294 |
let var_id = try get_var vars_env id with Not_found -> assert false in |
295 |
if var_id.is_local then |
296 |
MC.mk_val (LT.LocalVar(var_id.vdecl)) var_id.vdecl.LT.var_type |
297 |
else |
298 |
MC.mk_val (LT.StateVar(var_id.vdecl)) var_id.vdecl.LT.var_type |
299 |
| ST.Plus(x, y, _) -> binop "+" x y Type_predef.type_real |
300 |
| ST.Minus(x, y, _) -> binop "-" x y Type_predef.type_real |
301 |
| ST.Times(x, y, _) -> binop "*" x y Type_predef.type_real |
302 |
| ST.Div(x, y, _) -> binop "/" x y Type_predef.type_real |
303 |
| ST.Uminus(x,_) -> let x = salsa_expr2value_t x in |
304 |
MC.mk_val (LT.Fun("uminus",[x])) Type_predef.type_real |
305 |
| ST.IntOfBool(ST.Eq(x, y, _),_) -> binop "=" x y Type_predef.type_bool |
306 |
| ST.IntOfBool(ST.Lt(x,y,_),_) -> binop "<" x y Type_predef.type_bool |
307 |
| ST.IntOfBool(ST.Gt(x,y,_),_) -> binop ">" x y Type_predef.type_bool |
308 |
| ST.IntOfBool(ST.Lte(x,y,_),_) -> binop "<=" x y Type_predef.type_bool |
309 |
| ST.IntOfBool(ST.Gte(x,y,_),_) -> binop ">=" x y Type_predef.type_bool |
310 |
| _ -> raise (Salsa.Prelude.Error "Entschuldigung, salsaExpr2value_t case not yet implemented") |
311 |
|
312 |
|
313 |
let rec get_salsa_free_vars vars_env constEnv absenv e = |
314 |
let f = get_salsa_free_vars vars_env constEnv absenv in |
315 |
match e with |
316 |
| ST.Id (id, _) -> |
317 |
if not (List.mem_assoc id absenv) && not (List.mem_assoc id constEnv) then |
318 |
Vars.singleton ((try VarEnv.find id vars_env with Not_found -> assert false).vdecl) |
319 |
else |
320 |
Vars.empty |
321 |
| ST.Plus(x, y, _) |
322 |
| ST.Minus(x, y, _) |
323 |
| ST.Times(x, y, _) |
324 |
| ST.Div(x, y, _) |
325 |
| ST.IntOfBool(ST.Eq(x, y, _),_) |
326 |
| ST.IntOfBool(ST.Lt(x,y,_),_) |
327 |
| ST.IntOfBool(ST.Gt(x,y,_),_) |
328 |
| ST.IntOfBool(ST.Lte(x,y,_),_) |
329 |
| ST.IntOfBool(ST.Gte(x,y,_),_) |
330 |
-> Vars.union (f x) (f y) |
331 |
| ST.Uminus(x,_) -> f x |
332 |
| ST.Cst _ -> Vars.empty |
333 |
| _ -> assert false |
334 |
|
335 |
(* Local Variables: *) |
336 |
(* compile-command:"make -C ../../.." *) |
337 |
(* End: *) |