:- module main.
%=============================================================================%
% Implementation of an extremely simple, unbalanced binary search tree mapping
% keys to values. Has very modest optimizations to avoid unnecessary allocations.
:- interface.
%=============================================================================%
:- use_module io.
%-----------------------------------------------------------------------------%
% Self-test driver: builds a small tree, checks its contents with verify/4,
% and prints "Success!". Throws a software_error on any mismatch.
:- pred main(io.io::di, io.io::uo) is det.
%-----------------------------------------------------------------------------%
% An unbalanced binary search tree mapping keys of type K to values of type V.
% The representation is hidden; use the operations below.
:- type xbtree(K, V).
%-----------------------------------------------------------------------------%
% Returns a tree with no entries.
:- func init = xbtree(K, V).
%-----------------------------------------------------------------------------%
% singleton(K, V) returns a tree containing only the mapping K -> V.
:- func singleton(K, V) = xbtree(K, V).
%-----------------------------------------------------------------------------%
% insert(K, V, !Tree) adds the mapping K -> V.
% Fails if K is already present; existing values are never overwritten.
:- pred insert(K::in, V::in, xbtree(K, V)::in, xbtree(K, V)::out) is semidet.
%-----------------------------------------------------------------------------%
% update(K, V, !Tree) replaces the value stored under an existing key K.
% Fails if K is absent (including when the tree is empty).
:- pred update(K::in, V::in, xbtree(K, V)::in, xbtree(K, V)::out) is semidet.
%-----------------------------------------------------------------------------%
% As insert/4 and update/4, but throw an exception.software_error instead of
% failing.
:- pred det_insert(K::in, V::in, xbtree(K, V)::in, xbtree(K, V)::out) is det.
:- pred det_update(K::in, V::in, xbtree(K, V)::in, xbtree(K, V)::out) is det.
%-----------------------------------------------------------------------------%
% set(K, V, !Tree) inserts K -> V if K is absent, otherwise replaces the
% value stored under K. Never fails.
:- pred set(K::in, V::in, xbtree(K, V)::in, xbtree(K, V)::out) is det.
%-----------------------------------------------------------------------------%
% search(Tree, K, V) succeeds with V bound to the value stored under K,
% and fails if K is not present.
:- pred search(xbtree(K, V)::in, K::in, V::out) is semidet.
%-----------------------------------------------------------------------------%
% As search/3, but throws an exception.software_error if K is not present.
:- pred lookup(xbtree(K, V)::in, K::in, V::out) is det.
%=============================================================================%
% Implementation of a plain (unbalanced) binary search tree; no rebalancing is
% performed, so this is not a red/black tree.
:- implementation.
%=============================================================================%
:- use_module maybe.
:- use_module bool.
:- use_module exception.
%-----------------------------------------------------------------------------%
% Wrapper over insert/4 that turns failure (key already present) into an
% exception.
det_insert(K, V, In, Out) :-
    ( if insert(K, V, In, Inserted) then
        Out = Inserted
    else
        exception.throw(exception.software_error("Could not insert!"))
    ).
%-----------------------------------------------------------------------------%
% Wrapper over update/4 that turns failure (key absent) into an exception.
det_update(K, V, In, Out) :-
    ( if update(K, V, In, Updated) then
        Out = Updated
    else
        exception.throw(exception.software_error("Could not update!"))
    ).
%-----------------------------------------------------------------------------%
% Wrapper over search/3 that turns failure (key absent) into an exception.
lookup(Tree, K, V) :-
    ( if search(Tree, K, Found) then
        V = Found
    else
        exception.throw(exception.software_error("Could not lookup!"))
    ).
%-----------------------------------------------------------------------------%
:- use_module string.
% verify(Tree, Key, Expected, Actual):
% Test helper for main. Checks that Key is present in Tree with value
% Expected. On success, binds Actual to the stored value and (via a trace
% goal) prints "Key <Key> was <value>".
% Throws a software_error naming the key if it is missing. If the values
% differ, the error names the key, the expected value and the value found.
:- pred verify(xbtree(string, int)::in, string::in, int::in, int::out) is det.
verify(Tree, Key, Expected, Actual) :-
    ( if search(Tree, Key, Found) then
        ( if Found = Expected then
            Actual = Found,
            trace [io(!IO)] (
                io.write_string("Key ", !IO),
                io.write_string(Key, !IO),
                io.write_string(" was ", !IO),
                io.write_int(Found, !IO),
                io.nl(!IO)
            )
        else
            % Include the observed value too; without it a mismatch cannot
            % be diagnosed from the exception alone.
            Msg0 = string.append("Expected ", Key),
            Msg1 = string.append(Msg0, " to equal "),
            Msg2 = string.append(Msg1, string.from_int(Expected)),
            Msg3 = string.append(Msg2, " but was "),
            Msg = string.append(Msg3, string.from_int(Found)),
            exception.throw(exception.software_error(Msg))
        )
    else
        exception.throw(exception.software_error(
            string.append("Missing key ", Key)))
    ).
% Self-test: exercise every mutation entry point, then check the results.
main(!IO) :-
    Empty = init,
    det_insert("X", 216, Empty, WithX),
    det_insert("Y", 97, WithX, WithXY),
    det_insert("Z", 312, WithXY, WithXYZ),
    det_update("Y", 409, WithXYZ, Updated),
    set("W", 1717, Updated, WithW),
    set("X", 292, WithW, Final),

    % WithX is checked after all later operations, to confirm they did not
    % alter the older version of the tree.
    verify(Final, "X", 292, _),
    verify(WithX, "X", 216, _),
    verify(Final, "Y", 409, _),
    verify(Final, "Z", 312, _),
    verify(Final, "W", 1717, _),
    ( if search(Final, "A", _) then
        exception.throw(exception.software_error("Unexpected search success!"))
    else
        true
    ),
    io.write_string("Success!\n", !IO).
%-----------------------------------------------------------------------------%
% Not referenced by any code visible in this module; retained unchanged.
:- type children(T)
    --->    none
    ;       one(T)
    ;       two(T, T).

% One tree node: the key/value pair stored here, plus the optional subtrees
% holding keys that compare less than (less) and greater than (greater) k.
:- type node(K, V)
    --->    node(
                k       :: K,
                v       :: V,
                less    :: maybe.maybe(node(K, V)),
                greater :: maybe.maybe(node(K, V))
            ).

% Concrete representation of the abstract xbtree type from the interface.
:- type xbtree(K, V)
    --->    empty
    ;       root(node(K, V)).
%-----------------------------------------------------------------------------%
% search_node(Node, K, V):
% Succeeds with V bound to the value stored under K in the subtree rooted at
% Node; fails if K does not occur in that subtree.
:- pred search_node(node(K, V)::in, K::in, V::out) is semidet.
%-----------------------------------------------------------------------------%
% insert_node(K, V, !Node):
% Adds K -> V as a new leaf in the subtree rooted at !.Node, rebuilding the
% nodes on the path to it. Fails if K is already present.
:- pred insert_node(K::in, V::in, node(K, V)::in, node(K, V)::out) is semidet.
%-----------------------------------------------------------------------------%
% update_node(K, V, !Node, Changed):
% Replaces the value stored under an existing key K with V; fails if K is
% absent. Changed is bool.no when the stored value already equals V, in which
% case !:Node is !.Node itself; otherwise Changed is bool.yes and the path
% to K is rebuilt.
:- pred update_node(K::in, V::in, node(K, V)::in, node(K, V)::out, bool.bool::out) is semidet.
%-----------------------------------------------------------------------------%
% set_node(K, V, !Node, Changed):
% Like update_node, but adds K -> V as a new leaf instead of failing when K
% is absent. Changed is bool.no only when K was already mapped to V (and then
% !:Node is !.Node itself).
:- pred set_node(K::in, V::in, node(K, V)::in, node(K, V)::out, bool.bool::out) is det.
%-----------------------------------------------------------------------------%
% node(K, V) = Node, a convenience function overloading the node/4
% constructor by arity.
% Forward mode: builds a leaf node holding K and V, with no children.
% Reverse mode: extracts the key and value from any node, ignoring its
% children.
:- func node(K, V) = node(K, V).
:- mode node(in, in) = (out) is det.
:- mode node(out, out) = (in) is det.
%-----------------------------------------------------------------------------%
% The clauses below are mode-specific (one per mode), which Mercury only
% accepts when the function is promised pure.
:- pragma promise_pure(node/2).
node(K::in, V::in) = (node(K, V, maybe.no, maybe.no)::out).
node(K::out, V::out) = (node(K, V, _, _)::in).
%-----------------------------------------------------------------------------%
% A fresh tree has no root node.
init = Tree :-
    Tree = empty.
%-----------------------------------------------------------------------------%
% A one-entry tree is a root holding a childless leaf.
singleton(K, V) = Tree :-
    Leaf = node(K, V),
    Tree = root(Leaf).
%-----------------------------------------------------------------------------%
% An empty tree fails the deconstruction below, so search fails on it.
search(Tree, K, V) :-
    Tree = root(Root),
    search_node(Root, K, V).
%-----------------------------------------------------------------------------%
% An empty tree gets the pair as its root. Otherwise insertion is delegated
% to insert_node, which fails if K is already present.
insert(K, V, Tree0, Tree) :-
    (
        Tree0 = empty,
        Tree = singleton(K, V)
    ;
        Tree0 = root(Root0),
        insert_node(K, V, Root0, Root),
        Tree = root(Root)
    ).
%-----------------------------------------------------------------------------%
% Fails on an empty tree or when K is absent.
% update_node reports whether anything changed. When nothing did, the input
% tree is returned as-is instead of wrapping the (identical) node in a new
% root cell, matching the node-level reuse done by update_node.
update(K, V, Tree0, Tree) :-
    Tree0 = root(Root0),
    update_node(K, V, Root0, Root, Changed),
    (
        Changed = bool.yes,
        Tree = root(Root)
    ;
        Changed = bool.no,
        Tree = Tree0
    ).
%-----------------------------------------------------------------------------%
% An empty tree gets the pair as its root. Otherwise set_node inserts or
% replaces. When set_node reports no change (K already mapped to V), the
% input tree is returned as-is rather than re-wrapping the identical node
% in a new root cell.
set(K, V, Tree0, Tree) :-
    (
        Tree0 = empty,
        Tree = root(node(K, V))
    ;
        Tree0 = root(Root0),
        set_node(K, V, Root0, Root, Changed),
        (
            Changed = bool.yes,
            Tree = root(Root)
        ;
            Changed = bool.no,
            Tree = Tree0
        )
    ).
%-----------------------------------------------------------------------------%
% Walk down one level per call, choosing the side by comparing K with the
% node's key. Fails when the required child is absent. The recursive call
% is the last goal in its arm.
search_node(Node, K, V) :-
    builtin.compare(Cmp, K, Node ^ k),
    require_complete_switch [Cmp] (
        Cmp = (=),
        V = Node ^ v
    ;
        Cmp = (<),
        Node ^ less = maybe.yes(Less),
        search_node(Less, K, V)
    ;
        Cmp = (>),
        Node ^ greater = maybe.yes(Greater),
        search_node(Greater, K, V)
    ).
%-----------------------------------------------------------------------------%
% Descend to the empty child slot where K belongs and place a leaf there,
% copying each node on the way back up with its updated child.
insert_node(K, V, Node0, Node) :-
    Node0 = node(NodeK, _, MaybeLess0, MaybeGreater0),
    builtin.compare(Cmp, K, NodeK),
    require_complete_switch [Cmp] (
        Cmp = (=),
        % K is already present: insertion never overwrites, so it fails.
        fail
    ;
        Cmp = (<),
        (
            MaybeLess0 = maybe.no,
            Less = node(K, V)
        ;
            MaybeLess0 = maybe.yes(Less0),
            % TODO: not tail recursive.
            insert_node(K, V, Less0, Less)
        ),
        Node = (Node0 ^ less := maybe.yes(Less))
    ;
        Cmp = (>),
        (
            MaybeGreater0 = maybe.no,
            Greater = node(K, V)
        ;
            MaybeGreater0 = maybe.yes(Greater0),
            % TODO: not tail recursive.
            insert_node(K, V, Greater0, Greater)
        ),
        Node = (Node0 ^ greater := maybe.yes(Greater))
    ).
%-----------------------------------------------------------------------------%
% Find the node holding K (failing if it is absent) and replace its value.
% If the stored value already equals V, or nothing below changed, the
% original node is returned untouched and Changed is bool.no.
update_node(K, V, Node0, Node, Changed) :-
    Node0 = node(NodeK, NodeV, MaybeLess0, MaybeGreater0),
    builtin.compare(Cmp, K, NodeK),
    require_complete_switch [Cmp] (
        Cmp = (<),
        MaybeLess0 = maybe.yes(Less0),
        % TODO: not tail recursive.
        update_node(K, V, Less0, Less, Changed),
        (
            Changed = bool.no,
            Node = Node0
        ;
            Changed = bool.yes,
            Node = (Node0 ^ less := maybe.yes(Less))
        )
    ;
        Cmp = (>),
        MaybeGreater0 = maybe.yes(Greater0),
        % TODO: not tail recursive.
        update_node(K, V, Greater0, Greater, Changed),
        (
            Changed = bool.no,
            Node = Node0
        ;
            Changed = bool.yes,
            Node = (Node0 ^ greater := maybe.yes(Greater))
        )
    ;
        Cmp = (=),
        ( if V = NodeV then
            Changed = bool.no,
            Node = Node0
        else
            Changed = bool.yes,
            Node = (Node0 ^ v := V)
        )
    ).
%-----------------------------------------------------------------------------%
% Insert-or-replace. Descends like update_node, but an absent child slot on
% the search path receives a new leaf instead of causing failure. Nodes
% are only copied when something below (or at) them actually changed.
set_node(K, V, Node0, Node, Changed) :-
    Node0 = node(NodeK, NodeV, MaybeLess0, MaybeGreater0),
    builtin.compare(Cmp, K, NodeK),
    require_complete_switch [Cmp] (
        Cmp = (=),
        ( if V = NodeV then
            Changed = bool.no,
            Node = Node0
        else
            Changed = bool.yes,
            Node = (Node0 ^ v := V)
        )
    ;
        Cmp = (<),
        (
            MaybeLess0 = maybe.no,
            Changed = bool.yes,
            Less = node(K, V)
        ;
            MaybeLess0 = maybe.yes(Less0),
            % TODO: not tail recursive.
            set_node(K, V, Less0, Less, Changed)
        ),
        (
            Changed = bool.no,
            Node = Node0
        ;
            Changed = bool.yes,
            Node = (Node0 ^ less := maybe.yes(Less))
        )
    ;
        Cmp = (>),
        (
            MaybeGreater0 = maybe.no,
            Changed = bool.yes,
            Greater = node(K, V)
        ;
            MaybeGreater0 = maybe.yes(Greater0),
            % TODO: not tail recursive.
            set_node(K, V, Greater0, Greater, Changed)
        ),
        (
            Changed = bool.no,
            Node = Node0
        ;
            Changed = bool.yes,
            Node = (Node0 ^ greater := maybe.yes(Greater))
        )
    ).