1
|
(***********************************************************************)
|
2
|
(* *)
|
3
|
(* OCaml *)
|
4
|
(* *)
|
5
|
(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *)
|
6
|
(* *)
|
7
|
(* Copyright 1996 Institut National de Recherche en Informatique et *)
|
8
|
(* en Automatique. All rights reserved. This file is distributed *)
|
9
|
(* under the terms of the GNU Library General Public License, with *)
|
10
|
(* the special exception on linking described in file ../LICENSE. *)
|
11
|
(* *)
|
12
|
(***********************************************************************)
|
13
|
|
14
|
module type OrderedType = sig
|
15
|
type t
|
16
|
|
17
|
val compare : t -> t -> int
|
18
|
end
|
19
|
|
20
|
module type S = sig
|
21
|
type key
|
22
|
|
23
|
type +'a t
|
24
|
|
25
|
val empty : 'a t
|
26
|
|
27
|
val is_empty : 'a t -> bool
|
28
|
|
29
|
val mem : key -> 'a t -> bool
|
30
|
|
31
|
val add : key -> 'a -> 'a t -> 'a t
|
32
|
|
33
|
val singleton : key -> 'a -> 'a t
|
34
|
|
35
|
val remove : key -> 'a t -> 'a t
|
36
|
|
37
|
val merge :
|
38
|
(key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t
|
39
|
|
40
|
val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int
|
41
|
|
42
|
val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
|
43
|
|
44
|
val iter : (key -> 'a -> unit) -> 'a t -> unit
|
45
|
|
46
|
val fold : (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
|
47
|
|
48
|
val for_all : (key -> 'a -> bool) -> 'a t -> bool
|
49
|
|
50
|
val exists : (key -> 'a -> bool) -> 'a t -> bool
|
51
|
|
52
|
val filter : (key -> 'a -> bool) -> 'a t -> 'a t
|
53
|
|
54
|
val partition : (key -> 'a -> bool) -> 'a t -> 'a t * 'a t
|
55
|
|
56
|
val cardinal : 'a t -> int
|
57
|
|
58
|
val bindings : 'a t -> (key * 'a) list
|
59
|
|
60
|
val min_binding : 'a t -> key * 'a
|
61
|
|
62
|
val max_binding : 'a t -> key * 'a
|
63
|
|
64
|
val choose : 'a t -> key * 'a
|
65
|
|
66
|
val split : key -> 'a t -> 'a t * 'a option * 'a t
|
67
|
|
68
|
val find : key -> 'a t -> 'a
|
69
|
|
70
|
val map : ('a -> 'b) -> 'a t -> 'b t
|
71
|
|
72
|
val mapi : (key -> 'a -> 'b) -> 'a t -> 'b t
|
73
|
end
|
74
|
|
75
|
module Make (Ord : OrderedType) = struct
|
76
|
type key = Ord.t
|
77
|
|
78
|
type 'a t = Empty | Node of 'a t * key * 'a * 'a t * int
|
79
|
|
80
|
let height = function Empty -> 0 | Node (_, _, _, _, h) -> h
|
81
|
|
82
|
let create l x d r =
|
83
|
let hl = height l and hr = height r in
|
84
|
Node (l, x, d, r, if hl >= hr then hl + 1 else hr + 1)
|
85
|
|
86
|
let singleton x d = Node (Empty, x, d, Empty, 1)
|
87
|
|
88
|
let bal l x d r =
|
89
|
let hl = match l with Empty -> 0 | Node (_, _, _, _, h) -> h in
|
90
|
let hr = match r with Empty -> 0 | Node (_, _, _, _, h) -> h in
|
91
|
if hl > hr + 2 then
|
92
|
match l with
|
93
|
| Empty ->
|
94
|
invalid_arg "Map.bal"
|
95
|
| Node (ll, lv, ld, lr, _) -> (
|
96
|
if height ll >= height lr then create ll lv ld (create lr x d r)
|
97
|
else
|
98
|
match lr with
|
99
|
| Empty ->
|
100
|
invalid_arg "Map.bal"
|
101
|
| Node (lrl, lrv, lrd, lrr, _) ->
|
102
|
create (create ll lv ld lrl) lrv lrd (create lrr x d r))
|
103
|
else if hr > hl + 2 then
|
104
|
match r with
|
105
|
| Empty ->
|
106
|
invalid_arg "Map.bal"
|
107
|
| Node (rl, rv, rd, rr, _) -> (
|
108
|
if height rr >= height rl then create (create l x d rl) rv rd rr
|
109
|
else
|
110
|
match rl with
|
111
|
| Empty ->
|
112
|
invalid_arg "Map.bal"
|
113
|
| Node (rll, rlv, rld, rlr, _) ->
|
114
|
create (create l x d rll) rlv rld (create rlr rv rd rr))
|
115
|
else Node (l, x, d, r, if hl >= hr then hl + 1 else hr + 1)
|
116
|
|
117
|
let empty = Empty
|
118
|
|
119
|
let is_empty = function Empty -> true | _ -> false
|
120
|
|
121
|
let rec add x data = function
|
122
|
| Empty ->
|
123
|
Node (Empty, x, data, Empty, 1)
|
124
|
| Node (l, v, d, r, h) ->
|
125
|
let c = Ord.compare x v in
|
126
|
if c = 0 then Node (l, x, data, r, h)
|
127
|
else if c < 0 then bal (add x data l) v d r
|
128
|
else bal l v d (add x data r)
|
129
|
|
130
|
let rec find x = function
|
131
|
| Empty ->
|
132
|
raise Not_found
|
133
|
| Node (l, v, d, r, _) ->
|
134
|
let c = Ord.compare x v in
|
135
|
if c = 0 then d else find x (if c < 0 then l else r)
|
136
|
|
137
|
let rec mem x = function
|
138
|
| Empty ->
|
139
|
false
|
140
|
| Node (l, v, _, r, _) ->
|
141
|
let c = Ord.compare x v in
|
142
|
c = 0 || mem x (if c < 0 then l else r)
|
143
|
|
144
|
let rec min_binding = function
|
145
|
| Empty ->
|
146
|
raise Not_found
|
147
|
| Node (Empty, x, d, _, _) ->
|
148
|
x, d
|
149
|
| Node (l, _, _, _, _) ->
|
150
|
min_binding l
|
151
|
|
152
|
let rec max_binding = function
|
153
|
| Empty ->
|
154
|
raise Not_found
|
155
|
| Node (_, x, d, Empty, _) ->
|
156
|
x, d
|
157
|
| Node (_, _, _, r, _) ->
|
158
|
max_binding r
|
159
|
|
160
|
let rec remove_min_binding = function
|
161
|
| Empty ->
|
162
|
invalid_arg "Map.remove_min_elt"
|
163
|
| Node (Empty, _, _, r, _) ->
|
164
|
r
|
165
|
| Node (l, x, d, r, _) ->
|
166
|
bal (remove_min_binding l) x d r
|
167
|
|
168
|
let merge t1 t2 =
|
169
|
match t1, t2 with
|
170
|
| Empty, t ->
|
171
|
t
|
172
|
| t, Empty ->
|
173
|
t
|
174
|
| _, _ ->
|
175
|
let x, d = min_binding t2 in
|
176
|
bal t1 x d (remove_min_binding t2)
|
177
|
|
178
|
let rec remove x = function
|
179
|
| Empty ->
|
180
|
Empty
|
181
|
| Node (l, v, d, r, _) ->
|
182
|
let c = Ord.compare x v in
|
183
|
if c = 0 then merge l r
|
184
|
else if c < 0 then bal (remove x l) v d r
|
185
|
else bal l v d (remove x r)
|
186
|
|
187
|
let rec iter f = function
|
188
|
| Empty ->
|
189
|
()
|
190
|
| Node (l, v, d, r, _) ->
|
191
|
iter f l;
|
192
|
f v d;
|
193
|
iter f r
|
194
|
|
195
|
let rec map f = function
|
196
|
| Empty ->
|
197
|
Empty
|
198
|
| Node (l, v, d, r, h) ->
|
199
|
let l' = map f l in
|
200
|
let d' = f d in
|
201
|
let r' = map f r in
|
202
|
Node (l', v, d', r', h)
|
203
|
|
204
|
let rec mapi f = function
|
205
|
| Empty ->
|
206
|
Empty
|
207
|
| Node (l, v, d, r, h) ->
|
208
|
let l' = mapi f l in
|
209
|
let d' = f v d in
|
210
|
let r' = mapi f r in
|
211
|
Node (l', v, d', r', h)
|
212
|
|
213
|
let rec fold f m accu =
|
214
|
match m with
|
215
|
| Empty ->
|
216
|
accu
|
217
|
| Node (l, v, d, r, _) ->
|
218
|
fold f r (f v d (fold f l accu))
|
219
|
|
220
|
let rec for_all p = function
|
221
|
| Empty ->
|
222
|
true
|
223
|
| Node (l, v, d, r, _) ->
|
224
|
p v d && for_all p l && for_all p r
|
225
|
|
226
|
let rec exists p = function
|
227
|
| Empty ->
|
228
|
false
|
229
|
| Node (l, v, d, r, _) ->
|
230
|
p v d || exists p l || exists p r
|
231
|
|
232
|
(* Beware: those two functions assume that the added k is *strictly* smaller
|
233
|
(or bigger) than all the present keys in the tree; it does not test for
|
234
|
equality with the current min (or max) key.
|
235
|
|
236
|
Indeed, they are only used during the "join" operation which respects this
|
237
|
precondition. *)
|
238
|
|
239
|
let rec add_min_binding k v = function
|
240
|
| Empty ->
|
241
|
singleton k v
|
242
|
| Node (l, x, d, r, _) ->
|
243
|
bal (add_min_binding k v l) x d r
|
244
|
|
245
|
let rec add_max_binding k v = function
|
246
|
| Empty ->
|
247
|
singleton k v
|
248
|
| Node (l, x, d, r, _) ->
|
249
|
bal l x d (add_max_binding k v r)
|
250
|
|
251
|
(* Same as create and bal, but no assumptions are made on the relative heights
|
252
|
of l and r. *)
|
253
|
|
254
|
let rec join l v d r =
|
255
|
match l, r with
|
256
|
| Empty, _ ->
|
257
|
add_min_binding v d r
|
258
|
| _, Empty ->
|
259
|
add_max_binding v d l
|
260
|
| Node (ll, lv, ld, lr, lh), Node (rl, rv, rd, rr, rh) ->
|
261
|
if lh > rh + 2 then bal ll lv ld (join lr v d r)
|
262
|
else if rh > lh + 2 then bal (join l v d rl) rv rd rr
|
263
|
else create l v d r
|
264
|
|
265
|
(* Merge two trees l and r into one. All elements of l must precede the
|
266
|
elements of r. No assumption on the heights of l and r. *)
|
267
|
|
268
|
let concat t1 t2 =
|
269
|
match t1, t2 with
|
270
|
| Empty, t ->
|
271
|
t
|
272
|
| t, Empty ->
|
273
|
t
|
274
|
| _, _ ->
|
275
|
let x, d = min_binding t2 in
|
276
|
join t1 x d (remove_min_binding t2)
|
277
|
|
278
|
let concat_or_join t1 v d t2 =
|
279
|
match d with Some d -> join t1 v d t2 | None -> concat t1 t2
|
280
|
|
281
|
let rec split x = function
|
282
|
| Empty ->
|
283
|
Empty, None, Empty
|
284
|
| Node (l, v, d, r, _) ->
|
285
|
let c = Ord.compare x v in
|
286
|
if c = 0 then l, Some d, r
|
287
|
else if c < 0 then
|
288
|
let ll, pres, rl = split x l in
|
289
|
ll, pres, join rl v d r
|
290
|
else
|
291
|
let lr, pres, rr = split x r in
|
292
|
join l v d lr, pres, rr
|
293
|
|
294
|
let rec merge f s1 s2 =
|
295
|
match s1, s2 with
|
296
|
| Empty, Empty ->
|
297
|
Empty
|
298
|
| Node (l1, v1, d1, r1, h1), _ when h1 >= height s2 ->
|
299
|
let l2, d2, r2 = split v1 s2 in
|
300
|
concat_or_join (merge f l1 l2) v1 (f v1 (Some d1) d2) (merge f r1 r2)
|
301
|
| _, Node (l2, v2, d2, r2, _) ->
|
302
|
let l1, d1, r1 = split v2 s1 in
|
303
|
concat_or_join (merge f l1 l2) v2 (f v2 d1 (Some d2)) (merge f r1 r2)
|
304
|
| _ ->
|
305
|
assert false
|
306
|
|
307
|
let rec filter p = function
|
308
|
| Empty ->
|
309
|
Empty
|
310
|
| Node (l, v, d, r, _) ->
|
311
|
(* call [p] in the expected left-to-right order *)
|
312
|
let l' = filter p l in
|
313
|
let pvd = p v d in
|
314
|
let r' = filter p r in
|
315
|
if pvd then join l' v d r' else concat l' r'
|
316
|
|
317
|
let rec partition p = function
|
318
|
| Empty ->
|
319
|
Empty, Empty
|
320
|
| Node (l, v, d, r, _) ->
|
321
|
(* call [p] in the expected left-to-right order *)
|
322
|
let lt, lf = partition p l in
|
323
|
let pvd = p v d in
|
324
|
let rt, rf = partition p r in
|
325
|
if pvd then join lt v d rt, concat lf rf else concat lt rt, join lf v d rf
|
326
|
|
327
|
type 'a enumeration = End | More of key * 'a * 'a t * 'a enumeration
|
328
|
|
329
|
let rec cons_enum m e =
|
330
|
match m with
|
331
|
| Empty ->
|
332
|
e
|
333
|
| Node (l, v, d, r, _) ->
|
334
|
cons_enum l (More (v, d, r, e))
|
335
|
|
336
|
let compare cmp m1 m2 =
|
337
|
let rec compare_aux e1 e2 =
|
338
|
match e1, e2 with
|
339
|
| End, End ->
|
340
|
0
|
341
|
| End, _ ->
|
342
|
-1
|
343
|
| _, End ->
|
344
|
1
|
345
|
| More (v1, d1, r1, e1), More (v2, d2, r2, e2) ->
|
346
|
let c = Ord.compare v1 v2 in
|
347
|
if c <> 0 then c
|
348
|
else
|
349
|
let c = cmp d1 d2 in
|
350
|
if c <> 0 then c else compare_aux (cons_enum r1 e1) (cons_enum r2 e2)
|
351
|
in
|
352
|
compare_aux (cons_enum m1 End) (cons_enum m2 End)
|
353
|
|
354
|
let equal cmp m1 m2 =
|
355
|
let rec equal_aux e1 e2 =
|
356
|
match e1, e2 with
|
357
|
| End, End ->
|
358
|
true
|
359
|
| End, _ ->
|
360
|
false
|
361
|
| _, End ->
|
362
|
false
|
363
|
| More (v1, d1, r1, e1), More (v2, d2, r2, e2) ->
|
364
|
Ord.compare v1 v2 = 0
|
365
|
&& cmp d1 d2
|
366
|
&& equal_aux (cons_enum r1 e1) (cons_enum r2 e2)
|
367
|
in
|
368
|
equal_aux (cons_enum m1 End) (cons_enum m2 End)
|
369
|
|
370
|
let rec cardinal = function
|
371
|
| Empty ->
|
372
|
0
|
373
|
| Node (l, _, _, r, _) ->
|
374
|
cardinal l + 1 + cardinal r
|
375
|
|
376
|
let rec bindings_aux accu = function
|
377
|
| Empty ->
|
378
|
accu
|
379
|
| Node (l, v, d, r, _) ->
|
380
|
bindings_aux ((v, d) :: bindings_aux accu r) l
|
381
|
|
382
|
let bindings s = bindings_aux [] s
|
383
|
|
384
|
let choose = min_binding
|
385
|
end
|