(* Union-find over points: each point belongs to exactly one equivalence
   class, and each class carries a descriptor of type [a]. *)
(* A point; boxed (pointer-sized) and abstract to clients. *)
abstype point (a:t@ype) = ptr
(* Create a point alone in a new class whose descriptor is the argument. *)
fun {a:t@ype} fresh (a): point a
(* Return the descriptor of the point's class. *)
fun {a:t@ype} find (point a): a
(* Merge the classes of the two points. The merged class keeps the
   descriptor of the SECOND point's class. The implementation asserts that
   the two points are not already equivalent. *)
fun {a:t@ype} union (point a, point a): void
(* Test whether the two points belong to the same class. *)
fun {a:t@ype} equiv (point a, point a): bool
(* True when the point holds a link to another point, i.e. it has been
   merged under another point and no longer stores the class data. *)
fun {a:t@ype} redundant (point a): bool
(* Replace the descriptor of the point's class with the given value. *)
fun {a:t@ype} change (point a, a): void
#define ATS_DYNLOADFLAG 0
staload "unionfind.sats"
#include "share/atspre_staload.hats"
(* Contents of a point. A point that is not a class root holds a Link to
   another point; a root holds Info with the class weight and descriptor. *)
datatype pt (a:t@ype) =
(* link toward the class root *)
| Link of (ref (pt a))
(* (weight, descriptor); weight starts at 1 in [fresh] and is summed on [union],
   so it counts the points in the class *)
| Info of (ref int, ref a)
(* A point is a mutable cell holding its current contents; its address is
   its identity (see [addr_of]). *)
assume point (a:t@ype) = ref (pt a)
(* Build a brand-new singleton class: its only point is the root, with
   weight 1 and descriptor [desc]. *)
implement {a} fresh (desc) = let
  val weight = ref<int> (1)
  val descriptor = ref<a> (desc)
in
  ref<pt a> (Info (weight, descriptor))
end
local
(* Address of the ref cell that backs [p]; used to compare points by
   identity. Uses the typed prelude accessor rather than an unchecked cast. *)
fun {a:t@ype} addr_of (p: point a): ptr = ref_get_ptr (p)
(* Redirect [p] so that it links directly to [l].
   Whatever [p] held before (an older Link, or the Info of a former root)
   is replaced. The refs inside a former Info stay valid for anyone who
   already bound them.
   Always overwrite [p] itself: writing through [p]'s old parent (as in
   [!cell := !l]) would copy [l]'s Info into that parent. That would turn
   the parent into a second, fake root with a different address, which
   breaks [equiv]. *)
fun {a:t@ype} link (p: point a, l: point a): void =
  !p := Link l
(* Return the root of [p]'s class, compressing the path on the way back.
   Every visited node whose parent is not already the root is redirected
   straight to the root. *)
fun {a:t@ype} repr (p: point a): point a =
  case+ !p of
  | Info _ => p
  | Link parent => let
      val root = repr parent
      (* Rewrite [p] itself, never [parent]; skip when [parent] is already the root. *)
      val () = if addr_of root != addr_of parent then (!p := Link root)
    in
      root
    end
in
(* Descriptor of [p]'s class.
   - If [p] is a root, or a direct child of a root, the root's Info is read
     directly.
   - Otherwise the search continues from [repr] of [p]'s GRANDPARENT, so
     [p] and its parent are not themselves compressed by this call. *)
implement {a} find (p) =
  case+ !p of
  | Info (_, desc) => !desc
  | Link parent => (
      case+ !parent of
      | Info (_, desc) => !desc
      | Link grand => find (repr grand)
    )
(* Overwrite the descriptor of [p]'s class with [desc].
   - If [p] is a root, or a direct child of a root, the root's descriptor
     cell is written directly.
   - Otherwise the search continues from [repr] of [p]'s GRANDPARENT. *)
implement {a} change (p, desc) =
  case+ !p of
  | Info (_, slot) => !slot := desc
  | Link parent => (
      case+ !parent of
      | Info (_, slot) => !slot := desc
      | Link grand => change (repr grand, desc)
    )
(* Merge the classes of [p1] and [p2] (union by weight).
   - The heavier root absorbs the lighter one; ties go to [p1]'s root.
   - Either way, the merged class ends up with the descriptor that [p2]'s
     class had.
   - Merging a class with itself fails the assertion. *)
implement {a} union (p1, p2) = let
  val r1 = repr p1
  val r2 = repr p2
  val () = assertloc (addr_of r1 != addr_of r2)
in
  case- (!r1, !r2) of
  | (Info (wt1, desc1), Info (wt2, desc2)) => let
      val total = !wt1 + !wt2
    in
      if !wt1 >= !wt2
      (* r1 stays root and takes over r2's descriptor *)
      then (!r2 := Link r1; !wt1 := total; !desc1 := !desc2)
      (* r2 stays root and keeps its own descriptor *)
      else (!r1 := Link r2; !wt2 := total)
    end
end
(* Two points are equivalent exactly when they resolve to the same root cell. *)
implement {a} equiv (p1, p2) = let
  val root1 = repr p1
  val root2 = repr p2
in
  addr_of root1 = addr_of root2
end
(* A point is redundant once it holds a Link to another point. *)
implement {a} redundant (p) = (
  case+ !p of
  | Link _ => true
  | Info _ => false
)
end
#include "share/atspre_staload.hats"
staload "unionfind.sats"
staload _ = "unionfind.dats"
(* Smoke test: prints [equiv] for pairs of points before and after two unions.
   Output, one value per line: false true false false true true true *)
implement main0 () = {
  fn check_equiv (x: point(int), y: point(int)): void =
    println! (equiv (x, y))
  val p1 = fresh<int> (1)
  val p2 = fresh<int> (2)
  val p3 = fresh<int> (3)
  val () = check_equiv (p1, p2)
  val () = union (p1, p2)
  val () = check_equiv (p1, p2)
  val () = check_equiv (p1, p3)
  val () = check_equiv (p2, p3)
  val () = union (p2, p3)
  val () = check_equiv (p1, p2)
  val () = check_equiv (p1, p3)
  val () = check_equiv (p2, p3)
}