As noted in the comments, your type won't quite cut it. A reasonable implementation of a set as a tree might have type:
```haskell
data Set a = Leaf | Node a (Set a) (Set a) deriving (Show)
-- note the extra `a` field in Node: each node stores an element
```
where each internal node `Node x l r` has all elements in `l` less than `x` and all elements in `r` greater than `x`.
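The invariant constrains *all* descendants of a node, not just its immediate children, so checking it means carrying lower and upper bounds down the tree. Here is a minimal sketch of such a check. `isValid` is a hypothetical helper, not part of the original answer, and the `Set` type is repeated so the snippet stands alone:

```haskell
data Set a = Leaf | Node a (Set a) (Set a) deriving (Show)

-- Hypothetical helper: checks the search-tree invariant by threading
-- optional lower/upper bounds down the tree. Every element must lie
-- strictly between the bounds inherited from its ancestors.
isValid :: Ord a => Set a -> Bool
isValid = go Nothing Nothing
  where
    go _ _ Leaf = True
    go lo hi (Node x l r) =
      maybe True (< x) lo          -- lower bound (if any) is below x
        && maybe True (> x) hi     -- upper bound (if any) is above x
        && go lo (Just x) l        -- left subtree: everything below x
        && go (Just x) hi r        -- right subtree: everything above x
```

For example, `Node 5 (Node 2 Leaf (Node 7 Leaf Leaf)) Leaf` is rejected: `7` is a fine right child of `2`, but it sits in the left subtree of `5`.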
You can partition such a `Set` recursively as follows:

```haskell
partition :: (a -> Bool) -> Set a -> (Set a, Set a)
```
The `Leaf` case is easy, obviously:

```haskell
partition _ Leaf = (Leaf, Leaf)
```
Here's how we handle the `Node` case. For the sub-case where the predicate holds for the value `x` in the node, note that we want:

```haskell
partition f (Node x l r) | f x = (Node x l1 r1, ...)
```
where `l1` and `r1` are the subsets of elements in `l` and `r` that satisfy the predicate, which we can get by recursively partitioning `l` and `r`:

```haskell
  where (l1, l2) = partition f l
        (r1, r2) = partition f r
```
The `Set` invariant is preserved here, because all elements in `l`, including those in the subset `l1`, are less than `x`; for the same reason, all elements in `r1` are greater than `x`. The only missing piece is that we somehow need to combine `l2` and `r2` to form the second part of the tuple:

```haskell
partition f (Node x l r) | f x = (Node x l1 r1, combine l2 r2)
```
Since `combine` only ever receives two trees where every element of the first is less than every element of the second (everything in `l2` is less than `x`, and everything in `r2` is greater than `x`), the following recursive function will do:

```haskell
combine Leaf r' = r'
combine (Node x l r) r' = Node x l (combine r r')
```
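One way to convince yourself that `combine` is correct: under its precondition, the in-order traversal of the result is simply the traversal of the first tree followed by that of the second. A self-contained sketch (the example trees are hypothetical):

```haskell
data Set a = Leaf | Node a (Set a) (Set a) deriving (Show)

-- In-order traversal: yields the elements in ascending order for a valid Set.
toList :: Set a -> [a]
toList Leaf = []
toList (Node x l r) = toList l ++ x : toList r

-- combine as above. Precondition: every element of the first tree is less
-- than every element of the second. It walks down the right spine of the
-- first tree and attaches the second tree where that spine ends.
combine :: Set a -> Set a -> Set a
combine Leaf r' = r'
combine (Node x l r) r' = Node x l (combine r r')

-- Example inputs satisfying the precondition: {1,2,3} and {5,6}.
small, large :: Set Int
small = Node 2 (Node 1 Leaf Leaf) (Node 3 Leaf Leaf)
large = Node 6 (Node 5 Leaf Leaf) Leaf
```

In GHCi, `toList (combine small large)` gives `[1,2,3,5,6]`. Note that `combine` does no rebalancing: the result can be noticeably deeper than either input, which is consistent with the unbalanced `Set` used throughout this answer.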
The case where the predicate does not hold for `x` is handled similarly, giving the full definition:

```haskell
data Set a = Leaf | Node a (Set a) (Set a)

partition :: (a -> Bool) -> Set a -> (Set a, Set a)
partition _ Leaf = (Leaf, Leaf)
partition f (Node x l r)
  | f x = (Node x l1 r1, combine l2 r2)
  | otherwise = (combine l1 r1, Node x l2 r2)
  where (l1, l2) = partition f l
        (r1, r2) = partition f r
        combine Leaf r' = r'
        combine (Node x l r) r' = Node x l (combine r r')
```
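Here is the full definition applied to a small hand-built tree holding 1 to 5 with `3` at the root. `example` is a hypothetical name; `Show` is derived and an in-order `toList` is added so the result can be inspected:

```haskell
data Set a = Leaf | Node a (Set a) (Set a) deriving (Show)

partition :: (a -> Bool) -> Set a -> (Set a, Set a)
partition _ Leaf = (Leaf, Leaf)
partition f (Node x l r)
  | f x = (Node x l1 r1, combine l2 r2)
  | otherwise = (combine l1 r1, Node x l2 r2)
  where (l1, l2) = partition f l
        (r1, r2) = partition f r
        combine Leaf r' = r'
        combine (Node x l r) r' = Node x l (combine r r')

-- In-order traversal, so the output of a valid Set is sorted.
toList :: Set a -> [a]
toList Leaf = []
toList (Node x l r) = toList l ++ x : toList r

-- Hypothetical example tree for {1,2,3,4,5}:
--        3
--      /   \
--     1     5
--      \   /
--       2 4
example :: Set Int
example = Node 3 (Node 1 Leaf (Node 2 Leaf Leaf)) (Node 5 (Node 4 Leaf Leaf) Leaf)
```

Partitioning `example` by `even` gives `Node 2 Leaf (Node 4 Leaf Leaf)` and `Node 3 (Node 1 Leaf Leaf) (Node 5 Leaf Leaf)`. The odd root `3` keeps its place in the second tree, while the even elements pulled out of its two subtrees (`2` and `4`) are joined by `combine`.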
Here's the complete code plus a QuickCheck test that the `partition` function works as expected:
```haskell
import Test.QuickCheck
import qualified Data.List (partition)
import Data.List (nub, sort)

data Set a = Leaf | Node a (Set a) (Set a) deriving (Show)

partition :: (a -> Bool) -> Set a -> (Set a, Set a)
partition _ Leaf = (Leaf, Leaf)
partition f (Node x l r)
  | f x = (Node x l1 r1, combine l2 r2)
  | otherwise = (combine l1 r1, Node x l2 r2)
  where (l1, l2) = partition f l
        (r1, r2) = partition f r
        combine Leaf r' = r'
        combine (Node x l r) r' = Node x l (combine r r')

-- Plain (unbalanced) search-tree insertion; duplicates are ignored.
insert :: (Ord a) => a -> Set a -> Set a
insert x Leaf = Node x Leaf Leaf
insert x (Node y l r) = case compare x y of
  LT -> Node y (insert x l) r
  GT -> Node y l (insert x r)
  EQ -> Node y l r

fromList :: (Ord a) => [a] -> Set a
fromList = foldr insert Leaf

-- In-order traversal, so the result is sorted for a valid Set.
toList :: Set a -> [a]
toList Leaf = []
toList (Node x l r) = toList l ++ x : toList r

-- Partitioning the tree must agree with partitioning the sorted,
-- duplicate-free list.
prop_partition :: [Int] -> Bool
prop_partition lst =
  let (l, r) = Main.partition even (fromList lst')
  in (toList l, toList r) == Data.List.partition even (sort lst')
  where lst' = nub lst

main :: IO ()
main = quickCheck (withMaxSuccess 10000 prop_partition)
```