This answer is inspired by DDub's, but I think it's rather simpler, and it should offer slightly better type inference and probably better type errors. Let's first clear our throats:
```haskell
{-# language FlexibleContexts #-}
{-# language FlexibleInstances #-}
{-# language MultiParamTypeClasses #-}
{-# language DataKinds #-}
{-# language KindSignatures #-}
{-# language TypeFamilies #-}
{-# language TypeOperators #-}
{-# language TypeApplications #-}
{-# language AllowAmbiguousTypes #-}
{-# language UndecidableInstances #-}
{-# language ScopedTypeVariables #-}
module DMap where
import Data.Kind (Type)
import GHC.TypeNats
```
GHC's built-in `Nat`s are pretty awkward to work with, because we can't pattern match on "not 0". So let's make them just part of the interface, and avoid them in the implementation.
```haskell
-- Real unary naturals
data UNat = Z | S UNat

-- Convert 'Nat' to 'UNat' in the obvious way.
type family ToUnary (n :: Nat) where
  ToUnary 0 = 'Z
  ToUnary n = 'S (ToUnary (n - 1))

-- This is just a little wrapper function to deal with the
-- 'Nat'-to-'UNat' business.
dmap :: forall n s t a b. DMap (ToUnary n) s t a b
     => (a -> b) -> s -> t
dmap = dmap' @(ToUnary n)
```
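If you want to convince yourself that `ToUnary` reduces the way I'd hope, one quick sketch is to state a couple of reductions as type equalities and let the type checker confirm them. The names `zeroIsZ` and `threeIsSSSZ` are made up for this check; the file should only compile if both reductions actually hold.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Type.Equality ((:~:) (..))
import GHC.TypeNats

-- Unary naturals, as above.
data UNat = Z | S UNat

-- Closed type family: the second equation can only fire once GHC
-- knows that n is apart from 0, so literals bottom out at 'Z.
type family ToUnary (n :: Nat) :: UNat where
  ToUnary 0 = 'Z
  ToUnary n = 'S (ToUnary (n - 1))

-- Each proof type-checks only if ToUnary reduces as claimed.
zeroIsZ :: ToUnary 0 :~: 'Z
zeroIsZ = Refl

threeIsSSSZ :: ToUnary 3 :~: 'S ('S ('S 'Z))
threeIsSSSZ = Refl
```

If either equation were off, this would fail at compile time rather than at run time.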
Now that we've gotten the utterly boring part out of the way, the rest turns out to be pretty simple.
```haskell
-- @n@ indicates how many 'Functor' layers to peel off @s@
-- and @t@ to reach @a@ and @b@, respectively.
class DMap (n :: UNat) s t a b where
  dmap' :: (a -> b) -> s -> t
```
How do we write the instances? Let's start with the obvious way, and then transform it into a way that will give better inference. The obvious way:
```haskell
instance DMap 'Z a b a b where
  dmap' = id

instance (Functor f, DMap n x y a b)
  => DMap ('S n) (f x) (f y) a b where
  dmap' = fmap . dmap' @n
```
The trouble with writing it this way is the usual problem with multi-parameter instance resolution. GHC will only choose the first instance if it sees that the first argument is `'Z`, the second and fourth arguments are the same, and the third and fifth arguments are the same. Similarly, it will only choose the second instance if it sees that the first argument is `'S`, the second and third arguments are both type applications, and the constructors applied in the second and third arguments are the same.
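To see the problem concretely, here is a minimal standalone sketch that uses only the obvious instances (so it repeats `UNat` and the class). `incrNested` is a hypothetical name for the demo. It only resolves because every type is pinned down by hand; if I drop the `:: Int -> Int` annotation, I'd expect GHC to get stuck on a constraint along the lines of `DMap 'Z Int Int c c`, since matching the first instance head would require it to already know that `c` is `Int`.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

data UNat = Z | S UNat

class DMap (n :: UNat) s t a b where
  dmap' :: (a -> b) -> s -> t

-- The "obvious" instances: the heads themselves demand that s/a and t/b
-- coincide, or that s and t are applications of the same f.
instance DMap 'Z a b a b where
  dmap' = id

instance (Functor f, DMap n x y a b) => DMap ('S n) (f x) (f y) a b where
  dmap' = fmap . dmap' @n

-- Resolves only because the input, output, and function types are all fixed.
incrNested :: [Maybe Int] -> [Maybe Int]
incrNested = dmap' @('S ('S 'Z)) ((+ 1) :: Int -> Int)
```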
We want to choose the right instance as soon as we know the first argument. We can do that by simply shifting everything else to the left of the double arrow:
```haskell
-- This stays the same.
class DMap (n :: UNat) s t a b where
  dmap' :: (a -> b) -> s -> t

instance (s ~ a, t ~ b) => DMap 'Z s t a b where
  dmap' = id

-- Notice how we're allowed to pull @f@, @x@,
-- and @y@ out of thin air here.
instance (Functor f, fx ~ (f x), fy ~ (f y), DMap n x y a b)
  => DMap ('S n) fx fy a b where
  dmap' = fmap . dmap' @n
```
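Putting the pieces together, here's a standalone sketch of the module with the improved instances, plus a couple of uses (`nested` and `thrice` are just names for the demo). Only the element type of the input is annotated: the first argument alone picks each instance, and the equality constraints then fill in the rest, including the result type and the type of `(+ 1)`.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE UndecidableInstances #-}

import GHC.TypeNats

data UNat = Z | S UNat

type family ToUnary (n :: Nat) :: UNat where
  ToUnary 0 = 'Z
  ToUnary n = 'S (ToUnary (n - 1))

class DMap (n :: UNat) s t a b where
  dmap' :: (a -> b) -> s -> t

instance (s ~ a, t ~ b) => DMap 'Z s t a b where
  dmap' = id

instance (Functor f, fx ~ f x, fy ~ f y, DMap n x y a b)
    => DMap ('S n) fx fy a b where
  dmap' = fmap . dmap' @n

dmap :: forall n s t a b. DMap (ToUnary n) s t a b => (a -> b) -> s -> t
dmap = dmap' @(ToUnary n)

-- No signature and no annotation on (+ 1): GHC works out that the
-- result is [[Int]] from the input alone.
nested = dmap @2 (+ 1) [[1, 2], [3 :: Int]]

-- dmap @3 fits wherever fmap . fmap . fmap would.
thrice :: (a -> b) -> [Maybe [a]] -> [Maybe [b]]
thrice = dmap @3
```

In GHCi, `:t nested` should report `[[Int]]`.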
Now, I claimed above that this gives better type inference than DDub's, so I'd better back that up. Let me just pull up GHCi:
```
*DMap> :t dmap @3
dmap @3
  :: (Functor f1, Functor f2, Functor f3) =>
     (a -> b) -> f1 (f2 (f3 a)) -> f1 (f2 (f3 b))
```
That's precisely the type of `fmap . fmap . fmap`. Perfect! With DDub's code, I instead get
```
dmap @3
  :: (DMap (FType 3 c), DT (FType 3 c) a ~ c,
      FType 3 (DT (FType 3 c) b) ~ FType 3 c) =>
     (a -> b) -> c -> DT (FType 3 c) b
```
which is ... not so clear. As I mentioned in a comment, this could be fixed, but it adds a bit more complexity to code that is already somewhat complicated.
Just for fun, we can pull the same trick with `traverse` and `foldMap`.
```haskell
dtraverse :: forall n f s t a b. (DTraverse (ToUnary n) s t a b, Applicative f)
          => (a -> f b) -> s -> f t
dtraverse = dtraverse' @(ToUnary n)

class DTraverse (n :: UNat) s t a b where
  dtraverse' :: Applicative f => (a -> f b) -> s -> f t

instance (s ~ a, t ~ b) => DTraverse 'Z s t a b where
  dtraverse' = id

instance (Traversable t, tx ~ (t x), ty ~ (t y), DTraverse n x y a b)
  => DTraverse ('S n) tx ty a b where
  dtraverse' = traverse . dtraverse' @n

dfoldMap :: forall n m s a. (DFold (ToUnary n) s a, Monoid m) => (a -> m) -> s -> m
dfoldMap = dfoldMap' @(ToUnary n)

class DFold (n :: UNat) s a where
  dfoldMap' :: Monoid m => (a -> m) -> s -> m

instance s ~ a => DFold 'Z s a where
  dfoldMap' = id

instance (Foldable t, tx ~ (t x), DFold n x a) => DFold ('S n) tx a where
  dfoldMap' = foldMap . dfoldMap' @n
```
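And a standalone sketch exercising both. I renamed the `Traversable`/`Foldable` variable from `t` to `g` so it doesn't read like the class's own `t` parameter; `positive`, `checkAll`, and `total` are hypothetical helpers for the demo, not part of the answer's interface.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Monoid (Sum (..))
import GHC.TypeNats

data UNat = Z | S UNat

type family ToUnary (n :: Nat) :: UNat where
  ToUnary 0 = 'Z
  ToUnary n = 'S (ToUnary (n - 1))

class DTraverse (n :: UNat) s t a b where
  dtraverse' :: Applicative f => (a -> f b) -> s -> f t

instance (s ~ a, t ~ b) => DTraverse 'Z s t a b where
  dtraverse' = id

instance (Traversable g, tx ~ g x, ty ~ g y, DTraverse n x y a b)
    => DTraverse ('S n) tx ty a b where
  dtraverse' = traverse . dtraverse' @n

dtraverse :: forall n f s t a b. (DTraverse (ToUnary n) s t a b, Applicative f)
          => (a -> f b) -> s -> f t
dtraverse = dtraverse' @(ToUnary n)

class DFold (n :: UNat) s a where
  dfoldMap' :: Monoid m => (a -> m) -> s -> m

instance s ~ a => DFold 'Z s a where
  dfoldMap' = id

instance (Foldable g, tx ~ g x, DFold n x a) => DFold ('S n) tx a where
  dfoldMap' = foldMap . dfoldMap' @n

dfoldMap :: forall n m s a. (DFold (ToUnary n) s a, Monoid m) => (a -> m) -> s -> m
dfoldMap = dfoldMap' @(ToUnary n)

-- Hypothetical helper: accept only positive numbers.
positive :: Int -> Maybe Int
positive x = if x > 0 then Just x else Nothing

-- Validate every element two layers down; any failure makes the whole result Nothing.
checkAll :: [[Int]] -> Maybe [[Int]]
checkAll = dtraverse @2 positive

-- Sum every element two layers down.
total :: [[Int]] -> Int
total = getSum . dfoldMap @2 Sum
```

With these, `checkAll [[1, -2], [3]]` should be `Nothing`, while `total [[1, 2], [3]]` should be `6`.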