I have a library that currently requires its users to provide a helper function with type:

```haskell
tEnum :: (KnownNat n) => MyType -> Finite n
```

so that the library implementation can use a very efficient sized-vector representation of a function with type:

```haskell
foo :: MyType -> a
```

(`MyType` is discrete and finite.)
Assuming that deriving a `Generic` instance for `MyType` is possible, is there a way to generate `tEnum` automatically, thus lifting that burden from my library's users?

I would also like to go the other way; that is, automatically derive:

```haskell
tGen :: (KnownNat n) => Finite n -> MyType
```
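The sketch below makes the burden concrete. It shows what a user might currently write by hand, and one way a library might use those functions to tabulate `foo`. It rests on several assumptions:

- `Int` indices stand in for `Finite n`.
- A `Data.Array` (from the `array` package bundled with GHC) stands in for the sized vector.
- The constructors of `MyType`, together with `maxIndex`, `Table`, `tabulate`, and `lookupT`, are hypothetical names. They are not the library's API.

```haskell
import Data.Array (Array, listArray, (!))

-- Hypothetical finite type; Int indices stand in for Finite n.
data MyType = Red | Green | Blue deriving (Show, Eq)

-- Hand-written index functions: the boilerplate the question wants derived.
tEnum :: MyType -> Int
tEnum Red   = 0
tEnum Green = 1
tEnum Blue  = 2

tGen :: Int -> MyType
tGen 0 = Red
tGen 1 = Green
tGen 2 = Blue
tGen i = error ("tGen: index out of range: " ++ show i)

-- Largest valid index (plays the role of n in Finite n).
maxIndex :: Int
maxIndex = 2

-- A function MyType -> a stored as an array indexed by tEnum.
newtype Table a = Table (Array Int a)

-- Slot i holds f applied to the value tGen i (computed lazily, at most once).
tabulate :: (MyType -> a) -> Table a
tabulate f = Table (listArray (0, maxIndex) [f (tGen i) | i <- [0 .. maxIndex]])

-- Constant-time lookup through the index.
lookupT :: Table a -> MyType -> a
lookupT (Table arr) x = arr ! tEnum x

main :: IO ()
main = do
  let t = tabulate (length . show)
  print (map (lookupT t) [Red, Green, Blue])
```

Both `tabulate` and `lookupT` need the two hand-written functions. That is why deriving `tEnum` and `tGen` together removes all of the per-type work.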
I have something working for both directions: `toFin` plays the role of `tEnum`, and `fromFin` plays the role of `tGen`. Since you did not specify your representation of `Finite`, I used my own `Finite` and `Nat`.
I have included a full code snippet with an example at the bottom of the post, but will only discuss the generic programming parts, leaving out the reasonably standard construction of Peano arithmetic and various useful theorems about it.
A typeclass is used to keep track of things that can be converted into and out of these finite enums. The important bits here are the default type signatures and the default definitions. Because of them, someone who writes an `EnumFin` instance for a type deriving `Generic` doesn't have to write any code in it, since these defaults will be used.

The defaults use methods from another class, which is implemented for the various representation types that `GHC.Generics` can produce.

Notice that both the normal and the default signatures use `(n ~ ...) => ... n` instead of writing the size of the `Finite` directly in the type signature. This is because a default signature may differ from the main signature only in its context. If the size appeared directly in the types, GHC would reject the defaults, since `GSize (Rep a)` need not be equal to `Size a` (for example, in an instance that defines `Size` but not `fromFin` or `toFin`):
```haskell
class EnumFin a where
  type Size a :: Nat
  type Size a = GSize (Rep a)

  toFin :: (n ~ Size a) => a -> Finite n
  default toFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                => a -> Finite n
  toFin = gToFin . from

  fromFin :: (n ~ Size a) => Finite n -> a
  default fromFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                  => Finite n -> a
  fromFin = to . gFromFin
```
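The `DefaultSignatures` mechanism that makes empty instances possible is easier to see on its own, in a much smaller value-level sketch. The hypothetical `Idx` class below has a default method that goes through `GHC.Generics`, so instances can be left empty. It uses plain `Integer` indices instead of `Finite n` and only handles enumeration-like types (all constructors nullary). `Idx`, `GIdx`, `gCount`, `Colour`, and `Suit` are illustrative names.

```haskell
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

import Data.Proxy (Proxy (..))
import GHC.Generics

-- Generic helper for enumeration-like types (every constructor nullary).
class GIdx f where
  gIdx   :: f p -> Integer
  gCount :: Proxy f -> Integer -- number of constructors below this node

instance GIdx U1 where
  gIdx U1 = 0
  gCount _ = 1

instance (GIdx f, GIdx g) => GIdx (f :+: g) where
  gIdx (L1 x) = gIdx x
  gIdx (R1 y) = gCount (Proxy :: Proxy f) + gIdx y
  gCount _ = gCount (Proxy :: Proxy f) + gCount (Proxy :: Proxy g)

instance GIdx f => GIdx (M1 i c f) where
  gIdx (M1 x) = gIdx x
  gCount _ = gCount (Proxy :: Proxy f)

-- User-facing class: the default signature lets instances be empty.
class Idx a where
  toIdx :: a -> Integer
  default toIdx :: (Generic a, GIdx (Rep a)) => a -> Integer
  toIdx = gIdx . from

data Colour = Red | Green | Blue deriving (Show, Generic)
data Suit = Hearts | Spades deriving (Show, Generic)

-- No method definitions: the generic default is used.
instance Idx Colour
instance Idx Suit

main :: IO ()
main = do
  print (map toIdx [Red, Green, Blue])
  print (map toIdx [Hearts, Spades])
```

The default signature differs from `toIdx :: a -> Integer` only in its context, which is the shape GHC requires for default signatures.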
There are actually also a couple of other utility methods in the class. The generic implementation uses them to get the minimum and maximum `Finite n` produced by an implementation (`0` and `n`). This avoids having to use more typeclasses and propagate `KnownNat`-style constraints:
```haskell
  zero :: (n ~ Size a) => Finite n
  default zero :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
               => Finite n
  zero = gzero @(Rep a)

  gt :: (n ~ Size a) => Finite n
  default gt :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
             => Finite n
  gt = ggt @(Rep a)
```
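`zero` and `gt` are enough because, in this encoding, every `Finite n` can reflect its own bound: the `ZF` constructor stores an `SNat`. The snippet below is a cut-down, standalone copy of the answer's types, redefined so that it compiles on its own. It shows that the smallest and the largest element both recover the same bound through `toSN`. `snToInteger`, `finToInteger`, `two`, `zero2`, and `top2` are illustrative helpers; `finToInteger` mirrors the answer's `toNum`.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

-- Cut-down copies of Nat, SNat and Finite from the full file.
data Nat = Z | S Nat

data SNat (n :: Nat) where
  ZS :: SNat 'Z
  SS :: SNat n -> SNat ('S n)

-- A number in 0..n (inclusive); ZF remembers the bound n.
data Finite (n :: Nat) where
  ZF :: SNat n -> Finite n
  SF :: Finite n -> Finite ('S n)

-- Recover the bound n from any inhabitant of Finite n.
toSN :: Finite n -> SNat n
toSN (ZF sn) = sn
toSN (SF f)  = SS (toSN f)

snToInteger :: SNat n -> Integer
snToInteger ZS     = 0
snToInteger (SS n) = 1 + snToInteger n

finToInteger :: Finite n -> Integer
finToInteger (ZF _) = 0
finToInteger (SF f) = 1 + finToInteger f

-- The bound 2 as a singleton.
two :: SNat ('S ('S 'Z))
two = SS (SS ZS)

-- What zero and gt would return for a type whose Size is 2.
zero2, top2 :: Finite ('S ('S 'Z))
zero2 = ZF two
top2  = SF (SF (ZF ZS))

main :: IO ()
main = do
  print (finToInteger zero2, snToInteger (toSN zero2))
  print (finToInteger top2, snToInteger (toSN top2))
```

The generic instances take advantage of this by calling `toSN (ggt @a)` or `toSN (gzero @a)` whenever a proof or a split needs the size of a sub-representation.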
The declaration of the generic class is fairly simple. Note, however, that its parameter has kind `* -> *`, not `*`:
```haskell
class GEnumFin f where
  type GSize f :: Nat
  gToFin :: f a -> Finite (GSize f)
  gFromFin :: Finite (GSize f) -> f a
  gzero :: Finite (GSize f)
  ggt :: Finite (GSize f)
```
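The kind follows from `Rep a` itself, which has kind `* -> *` (`from :: a -> Rep a x`). Any class indexed by representation types therefore needs the same kind. The standalone sketch below uses a hypothetical `GCount` class, shaped like `GEnumFin`, to count how many values a type has. In the answer's encoding that count should be `Size + 1`.

Unlike the answer, its `K1` case recurses through the field type's `Rep` rather than through a user-facing class. That simplification only works for non-recursive field types that derive `Generic`. `GCount`, `gcount`, and `cardinality` are illustrative names.

```haskell
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.Generics

-- f has kind Type -> Type because Rep a does: from :: a -> Rep a x.
class GCount (f :: Type -> Type) where
  -- Number of distinct values of type f x.
  gcount :: Proxy f -> Integer

instance GCount U1 where
  gcount _ = 1

instance (GCount f, GCount g) => GCount (f :+: g) where
  gcount _ = gcount (Proxy :: Proxy f) + gcount (Proxy :: Proxy g)

instance (GCount f, GCount g) => GCount (f :*: g) where
  gcount _ = gcount (Proxy :: Proxy f) * gcount (Proxy :: Proxy g)

instance GCount f => GCount (M1 i c f) where
  gcount _ = gcount (Proxy :: Proxy f)

-- Simplification: recurse into the field type's own Rep
-- (fields must derive Generic and must not be recursive).
instance GCount (Rep b) => GCount (K1 i b) where
  gcount _ = gcount (Proxy :: Proxy (Rep b))

cardinality :: forall a. (Generic a, GCount (Rep a)) => Proxy a -> Integer
cardinality _ = gcount (Proxy :: Proxy (Rep a))

-- The same types as the demo in the full file.
data Foo = A | B deriving Generic
data Bar = C | D deriving Generic
data Baz = E Foo | F Bar | G Foo Bar deriving Generic

main :: IO ()
main = print (cardinality (Proxy :: Proxy Baz))
```

This prints 8, which is one more than the value 7 that the demo prints for `gt @Baz`.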
This generic class must now be implemented for each of the relevant generic representation types. For example, `U1` is a very simple one. It represents a constructor without fields, which is just encoded as the `Finite` number `0`:
```haskell
instance GEnumFin U1 where
  type GSize U1 = 'Z
  gToFin U1 = ZF ZS
  gFromFin (ZF ZS) = U1
  gzero = ZF ZS
  ggt = ZF ZS
```
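The small standalone example below shows where `U1` sits in practice by pattern-matching on `from` directly. It uses a copy of the demo's `Foo` type; `describe` is an illustrative name.

```haskell
{-# LANGUAGE DeriveGeneric #-}

import GHC.Generics

data Foo = A | B deriving (Show, Generic)

-- Rep Foo is (ignoring metadata) M1 D (M1 C U1 :+: M1 C U1):
-- each field-less constructor ends in a U1 leaf, which the U1 instance
-- encodes as 0 before the :+: instance adds its offset.
describe :: Foo -> String
describe x = case from x of
  M1 (L1 (M1 U1)) -> "left branch, U1 leaf"
  M1 (R1 (M1 U1)) -> "right branch, U1 leaf"

main :: IO ()
main = mapM_ (putStrLn . describe) [A, B]
```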
`:*:` is used to combine individual fields, so both parts need to be encoded. It encodes `lhs*(m+1)+rhs`, where `m` is the maximum value of the rhs:
```haskell
instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :*: b) where
  type GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b)
  gToFin (a :*: b) = addFin (mulFin (gToFin a) (SF (ggt @b))) (gToFin b)
  gFromFin x = (gFromFin a :*: gFromFin b)
    where (a, b) = quotRemFin (toSN (ggt @a)) (toSN (ggt @b)) x
  gzero = addFin (mulFin (gzero @a) (SF (ggt @b))) (gzero @b)
  ggt = addFin (mulFin (ggt @a) (SF (ggt @b))) (ggt @b)
```
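At the value level, the `:*:` instance is ordinary mixed-radix arithmetic, and `quotRemFin` is its inverse. Here is a minimal sketch with `Integer` standing in for `Finite`. `encodeProd` and `decodeProd` are illustrative names and are not part of the answer.

```haskell
-- Value-level model of the :*: instance, with Integer instead of Finite.
-- maxB is the largest index of the right-hand factor (ggt @b in the answer).
encodeProd :: Integer -> Integer -> Integer -> Integer
encodeProd maxB lhs rhs = lhs * (maxB + 1) + rhs

-- Inverse: the quotient is the left index, the remainder the right one.
decodeProd :: Integer -> Integer -> (Integer, Integer)
decodeProd maxB x = x `quotRem` (maxB + 1)

main :: IO ()
main = do
  -- Left factor 0..2, right factor 0..3: codes 0..11, since 2*(3+1)+3 = 11.
  let pairs = [(l, r) | l <- [0 .. 2], r <- [0 .. 3]]
      codes = [encodeProd 3 l r | (l, r) <- pairs]
  print codes
  print (all (\(p, c) -> decodeProd 3 c == p) (zip pairs codes))
```

The maximum code, `2 * (3 + 1) + 3 = 11`, has the same shape as `GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b)`.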
`:+:`, on the other hand, is used when representing sums, so it must be able to encode either of its constituents. It encodes the left-hand side as `0..n` and the right-hand side as `n+1..n+1+m`, where `n` and `m` are the maximum values of the left and right sides:
```haskell
instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :+: b) where
  type GSize (a :+: b) = 'S (Plus (GSize a) (GSize b))
  gToFin (L1 a) = case proofPlusComm (toSN (gzero @a)) (toSN (gzero @b)) of
    Refl -> addFin (injFin (gzero @b)) (gToFin a)
  gToFin (R1 b) = addFin (SF (ggt @a)) (gToFin b)
  gFromFin x = case proofPlusComm (toSN (ggt @a)) (toSN (ggt @b)) of
    Refl -> splitFin (toSN (ggt @b)) (toSN (ggt @a)) x
                     (R1 . gFromFin @b) (L1 . gFromFin @a)
  gzero = addFin (injFin (gzero @a)) (gzero @b)
  ggt = addFin (SF (ggt @a)) (ggt @b)
```
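The same idea works at the value level for sums. The left summand keeps its index, the right summand is shifted up past the left's maximum, and decoding (the job `splitFin` does in the typed version) is a comparison against that boundary. This is a minimal sketch with `Either` standing in for `:+:` and `Integer` for `Finite`. `encodeSum` and `decodeSum` are illustrative names.

```haskell
-- Value-level model of the :+: instance: the left summand (indices 0..maxL)
-- keeps its index; the right summand is shifted up by maxL + 1.
encodeSum :: Integer -> Either Integer Integer -> Integer
encodeSum _    (Left l)  = l
encodeSum maxL (Right r) = maxL + 1 + r

-- Inverse, playing the role of splitFin: compare against the boundary.
decodeSum :: Integer -> Integer -> Either Integer Integer
decodeSum maxL x
  | x <= maxL = Left x
  | otherwise = Right (x - maxL - 1)

main :: IO ()
main = do
  -- Left summand 0..1, right summand 0..2: combined codes 0..4.
  let vals  = map Left [0, 1] ++ map Right [0 .. 2]
      codes = map (encodeSum 1) vals
  print codes
  print (map (decodeSum 1) codes == vals)
```

The combined maximum, `1 + 2 + 1 = 4`, matches `GSize (a :+: b) = 'S (Plus (GSize a) (GSize b))`.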
There is also an important instance for a single constructor field, which requires that the contained type also implement `EnumFin`:
```haskell
instance (EnumFin a) => GEnumFin (K1 i a) where
  type GSize (K1 i a) = Size a
  gToFin (K1 a) = toFin a
  gFromFin = K1 . fromFin
  gzero = zero @a
  ggt = gt @a
```
Finally, it is necessary to implement the `M1` constructor, which is used to attach metadata to the generic tree and which we don't care about at all here:
```haskell
instance forall i c a. (GEnumFin a) => GEnumFin (M1 i c a) where
  type GSize (M1 i c a) = GSize a
  gToFin (M1 a) = gToFin a
  gFromFin = M1 . gFromFin
  gzero = gzero @a
  ggt = ggt @a
```
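Putting the instances together, here is an untyped analogue of the whole scheme. It is a minimal sketch that keeps the same structure (`U1`, `:*:`, `:+:`, `K1`, `M1`, plus a default-signature user class) but uses `Integer` indices and a `Proxy`-based maximum instead of `Finite n` and the type-level proofs. It therefore gives up the static guarantee that indices are in range, which is the main benefit of the typed version. `EnumInt`, `GEnumInt`, `maxIdx`, `gMax`, `gTo`, `gFrom`, `toInt`, and `fromInt` are illustrative names.

```haskell
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

import Data.Proxy (Proxy (..))
import GHC.Generics

-- Untyped analogue of EnumFin: Integer indices in 0..maxIdx.
class EnumInt a where
  maxIdx  :: Proxy a -> Integer
  toInt   :: a -> Integer
  fromInt :: Integer -> a

  default maxIdx :: (Generic a, GEnumInt (Rep a)) => Proxy a -> Integer
  maxIdx _ = gMax (Proxy :: Proxy (Rep a))
  default toInt :: (Generic a, GEnumInt (Rep a)) => a -> Integer
  toInt = gTo . from
  default fromInt :: (Generic a, GEnumInt (Rep a)) => Integer -> a
  fromInt = to . gFrom

class GEnumInt f where
  gMax  :: Proxy f -> Integer
  gTo   :: f p -> Integer
  gFrom :: Integer -> f p

instance GEnumInt U1 where
  gMax _ = 0
  gTo U1 = 0
  gFrom _ = U1

-- Product: mixed radix, lhs * (maxR + 1) + rhs.
instance (GEnumInt f, GEnumInt g) => GEnumInt (f :*: g) where
  gMax _ = gMax (Proxy :: Proxy f) * (mg + 1) + mg
    where mg = gMax (Proxy :: Proxy g)
  gTo (x :*: y) = gTo x * (gMax (Proxy :: Proxy g) + 1) + gTo y
  gFrom n = gFrom q :*: gFrom r
    where (q, r) = n `quotRem` (gMax (Proxy :: Proxy g) + 1)

-- Sum: left keeps its index, right is shifted past the left's maximum.
instance (GEnumInt f, GEnumInt g) => GEnumInt (f :+: g) where
  gMax _ = gMax (Proxy :: Proxy f) + 1 + gMax (Proxy :: Proxy g)
  gTo (L1 x) = gTo x
  gTo (R1 y) = gMax (Proxy :: Proxy f) + 1 + gTo y
  gFrom n
    | n <= mf   = L1 (gFrom n)
    | otherwise = R1 (gFrom (n - mf - 1))
    where mf = gMax (Proxy :: Proxy f)

-- Field: delegate to the field type's own EnumInt instance (like K1 above).
instance EnumInt a => GEnumInt (K1 i a) where
  gMax _ = maxIdx (Proxy :: Proxy a)
  gTo (K1 x) = toInt x
  gFrom = K1 . fromInt

-- Metadata: transparent (like M1 above).
instance GEnumInt f => GEnumInt (M1 i c f) where
  gMax _ = gMax (Proxy :: Proxy f)
  gTo (M1 x) = gTo x
  gFrom = M1 . gFrom

-- Hand-written instance: the K1 case only needs some EnumInt instance.
instance EnumInt Bool where
  maxIdx _ = 1
  toInt b = if b then 1 else 0
  fromInt n = n /= 0

data Foo = A | B deriving (Show, Eq, Generic)
data Baz = E Foo | F Bool | G Foo Bool deriving (Show, Eq, Generic)

instance EnumInt Foo
instance EnumInt Baz

main :: IO ()
main = do
  let xs = [E A, E B, F False, F True, G A False, G A True, G B False, G B True]
  print (maxIdx (Proxy :: Proxy Baz))
  print (map toInt xs)
  print (map fromInt [0 .. 7] == xs)
```

The `Bool` instance is written by hand to show that the `K1` case only needs some instance for the field type, whether it is derived or not.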
For completeness, here is a complete file that defines all of the `Nat`/`Finite` infrastructure used above and demonstrates the `Generic` implementation:
```haskell
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveGeneric #-}

import GHC.Generics
import Data.Type.Equality

-- Fairly standard Peano naturals & various useful theorems about them:
data Nat = Z | S Nat

data SNat (n :: Nat) where
  ZS :: SNat 'Z
  SS :: SNat n -> SNat ('S n)
deriving instance Show (SNat n)

type family Plus (n :: Nat) (m :: Nat) where
  Plus 'Z m = m
  Plus ('S n) m = 'S (Plus n m)

plus :: SNat n -> SNat m -> SNat (Plus n m)
plus ZS m = m
plus (SS n) m = SS (plus n m)

proofPlusNZ :: SNat n -> Plus n 'Z :~: n
proofPlusNZ ZS = Refl
proofPlusNZ (SS n) = case proofPlusNZ n of Refl -> Refl

proofPlusNS :: SNat n -> SNat m -> Plus n ('S m) :~: 'S (Plus n m)
proofPlusNS ZS _ = Refl
proofPlusNS (SS n) m = case proofPlusNS n m of Refl -> Refl

proofPlusAssoc :: SNat n -> SNat m -> SNat o
               -> Plus n (Plus m o) :~: Plus (Plus n m) o
proofPlusAssoc ZS _ _ = Refl
proofPlusAssoc (SS n) ZS _ = case proofPlusNZ n of Refl -> Refl
proofPlusAssoc (SS n) (SS m) ZS =
  case proofPlusNZ m of
    Refl -> case proofPlusNZ (plus n (SS m)) of
      Refl -> Refl
proofPlusAssoc (SS n) (SS m) (SS o) =
  case proofPlusAssoc n (SS m) (SS o) of Refl -> Refl

proofPlusComm :: SNat n -> SNat m -> Plus n m :~: Plus m n
proofPlusComm ZS ZS = Refl
proofPlusComm ZS (SS m) = case proofPlusNZ m of Refl -> Refl
proofPlusComm (SS n) ZS = case proofPlusNZ n of Refl -> Refl
proofPlusComm (SS n) (SS m) =
  case proofPlusComm (SS n) m of
    Refl -> case proofPlusComm n (SS m) of
      Refl -> case proofPlusComm n m of
        Refl -> Refl

type family Times (n :: Nat) (m :: Nat) where
  Times 'Z m = 'Z
  Times ('S n) m = Plus m (Times n m)

times :: SNat n -> SNat m -> SNat (Times n m)
times ZS _ = ZS
times (SS n) m = plus m (times n m)

proofMultNZ :: SNat n -> Times n 'Z :~: 'Z
proofMultNZ ZS = Refl
proofMultNZ (SS n) = case proofMultNZ n of Refl -> Refl

proofMultNS :: SNat n -> SNat m -> Times n ('S m) :~: Plus n (Times n m)
proofMultNS ZS ZS = Refl
proofMultNS ZS (SS m) =
  case proofMultNZ (SS m) of
    Refl -> case proofMultNZ m of
      Refl -> Refl
proofMultNS (SS n) ZS =
  case proofMultNS n ZS of Refl -> Refl
proofMultNS (SS n) (SS m) =
  case proofMultNS (SS n) m of
    Refl -> case proofMultNS n (SS m) of
      Refl -> case proofMultNS n m of
        Refl -> case lemma1 n m (times n (SS m)) of
          Refl -> Refl
  where lemma1 :: SNat n -> SNat m -> SNat o
               -> Plus n ('S (Plus m o)) :~: 'S (Plus m (Plus n o))
        lemma1 n' m' o' =
          case proofPlusComm n' (SS (plus m' o')) of
            Refl -> case proofPlusComm m' (plus n' o') of
              Refl -> case proofPlusAssoc m' o' n' of
                Refl -> case proofPlusComm n' o' of
                  Refl -> Refl

proofMultSN :: SNat n -> SNat m -> Times ('S n) m :~: Plus (Times n m) m
proofMultSN ZS m = case proofPlusNZ m of Refl -> Refl
proofMultSN (SS n) m =
  case proofPlusNZ (times n m) of
    Refl -> case proofPlusComm m (plus m (plus (times n m) ZS)) of
      Refl -> Refl

proofMultComm :: SNat n -> SNat m -> Times n m :~: Times m n
proofMultComm ZS ZS = Refl
proofMultComm ZS (SS m) = case proofMultNZ (SS m) of
  Refl -> case proofMultComm ZS m of
    Refl -> Refl
proofMultComm (SS n) ZS = case proofMultComm n ZS of Refl -> Refl
proofMultComm (SS n) (SS m) =
  case proofMultNS n m of
    Refl -> case proofMultNS m n of
      Refl -> case proofPlusAssoc m n (times n m) of
        Refl -> case proofPlusAssoc n m (times m n) of
          Refl -> case proofPlusComm n m of
            Refl -> case proofMultComm n m of
              Refl -> Refl

-- `Finite n` represents a number in 0..n (inclusive).
--
-- Notice that the "zero" branch includes an `SNat`; this is useful to be
-- able to conveniently write `toSN` below (generally, to be able to
-- reflect the `n` component to the value level) without needing to use a
-- singleton typeclass & pass constraints around everywhere.
--
-- It should be possible to switch this out for other implementations of
-- `Finite` with different choices, but doing so may require rewriting many
-- of the following functions.
data Finite (n :: Nat) where
  ZF :: SNat n -> Finite n
  SF :: Finite n -> Finite ('S n)
deriving instance Show (Finite n)

toSN :: Finite n -> SNat n
toSN (ZF sn) = sn
toSN (SF f) = SS (toSN f)

addFin :: forall n m. Finite n -> Finite m -> Finite (Plus n m)
addFin (ZF n) (ZF m) = ZF (plus n m)
addFin (ZF n) (SF b) =
  case proofPlusNS n (toSN b) of
    Refl -> SF (addFin (ZF n) b)
addFin (SF a) b = SF (addFin a b)

mulFin :: forall n m. Finite n -> Finite m -> Finite (Times n m)
mulFin (ZF n) (ZF m) = ZF (times n m)
mulFin (ZF n) (SF b) = case proofMultNS n (toSN b) of
  Refl -> addFin (ZF n) (mulFin (ZF n) b)
mulFin (SF a) b = addFin b (mulFin a b)

quotRemFin :: SNat n -> SNat m -> Finite (Plus (Times n ('S m)) m)
           -> (Finite n, Finite m)
quotRemFin nn mm xx = go mm xx nn mm (ZF ZS) (ZF ZS)
  where go :: forall n m s p q r.
              (Plus q s ~ n, Plus r p ~ m)
           => SNat m
           -> Finite (Plus (Times s ('S m)) p)
           -> SNat s
           -> SNat p
           -> Finite q
           -> Finite r
           -> (Finite n, Finite m)
        go _ (ZF _) s p q r = (addFin q (ZF s), addFin r (ZF p))
        go m (SF x) s (SS p) q r =
          case proofPlusComm (SS p) (times s m) of
            Refl -> case proofPlusNS (times s (SS m)) p of
              Refl -> case proofPlusNS (toSN r) p of
                Refl -> go m x s p q (SF r)
        go m (SF x) (SS s) ZS q _ =
          case proofPlusNS (toSN q) s of
            Refl -> case proofMultSN s (SS m) of
              Refl -> case proofPlusNS (times s (SS m)) m of
                Refl -> case proofPlusComm (times s (SS m)) (SS m) of
                  Refl -> case proofPlusNZ (times (SS s) (SS m)) of
                    Refl -> go m x s m (SF q) (ZF ZS)

splitFin :: forall n m a. SNat n -> SNat m -> Finite ('S (Plus n m))
         -> (Finite n -> a) -> (Finite m -> a) -> a
splitFin nn mm xx f g = go nn mm xx mm (ZF ZS)
  where go :: forall r s. (Plus r s ~ m)
           => SNat n -> SNat m -> Finite ('S (Plus n s))
           -> SNat s -> Finite r -> a
        go _ _ (ZF _) s r = g (addFin r (ZF s))
        go n m (SF x) (SS s) r =
          case proofPlusNS (toSN r) s of
            Refl -> case proofPlusNS n s of
              Refl -> go n m x s (SF r)
        go n _ (SF x) ZS _ = case proofPlusNZ n of Refl -> f x

injFin :: Finite n -> Finite ('S n)
injFin (ZF n) = ZF (SS n)
injFin (SF a) = SF (injFin a)

toNum :: (Num a) => Finite n -> a
toNum (ZF _) = 0
toNum (SF n) = 1 + toNum n

-- The actual classes & Generic stuff:
class EnumFin a where
  type Size a :: Nat
  type Size a = GSize (Rep a)

  toFin :: (n ~ Size a) => a -> Finite n
  default toFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                => a -> Finite n
  toFin = gToFin . from

  fromFin :: (n ~ Size a) => Finite n -> a
  default fromFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                  => Finite n -> a
  fromFin = to . gFromFin

  zero :: (n ~ Size a) => Finite n
  default zero :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
               => Finite n
  zero = gzero @(Rep a)

  gt :: (n ~ Size a) => Finite n
  default gt :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
             => Finite n
  gt = ggt @(Rep a)

class GEnumFin f where
  type GSize f :: Nat
  gToFin :: f a -> Finite (GSize f)
  gFromFin :: Finite (GSize f) -> f a
  gzero :: Finite (GSize f)
  ggt :: Finite (GSize f)

instance GEnumFin U1 where
  type GSize U1 = 'Z
  gToFin U1 = ZF ZS
  gFromFin (ZF ZS) = U1
  gzero = ZF ZS
  ggt = ZF ZS

instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :*: b) where
  type GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b)
  gToFin (a :*: b) = addFin (mulFin (gToFin a) (SF (ggt @b))) (gToFin b)
  gFromFin x = (gFromFin a :*: gFromFin b)
    where (a, b) = quotRemFin (toSN (ggt @a)) (toSN (ggt @b)) x
  gzero = addFin (mulFin (gzero @a) (SF (ggt @b))) (gzero @b)
  ggt = addFin (mulFin (ggt @a) (SF (ggt @b))) (ggt @b)

instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :+: b) where
  type GSize (a :+: b) = 'S (Plus (GSize a) (GSize b))
  gToFin (L1 a) = case proofPlusComm (toSN (gzero @a)) (toSN (gzero @b)) of
    Refl -> addFin (injFin (gzero @b)) (gToFin a)
  gToFin (R1 b) = addFin (SF (ggt @a)) (gToFin b)
  gFromFin x = case proofPlusComm (toSN (ggt @a)) (toSN (ggt @b)) of
    Refl -> splitFin (toSN (ggt @b)) (toSN (ggt @a)) x
                     (R1 . gFromFin @b) (L1 . gFromFin @a)
  gzero = addFin (injFin (gzero @a)) (gzero @b)
  ggt = addFin (SF (ggt @a)) (ggt @b)

instance forall i c a. (GEnumFin a) => GEnumFin (M1 i c a) where
  type GSize (M1 i c a) = GSize a
  gToFin (M1 a) = gToFin a
  gFromFin = M1 . gFromFin
  gzero = gzero @a
  ggt = ggt @a

instance (EnumFin a) => GEnumFin (K1 i a) where
  type GSize (K1 i a) = Size a
  gToFin (K1 a) = toFin a
  gFromFin = K1 . fromFin
  gzero = zero @a
  ggt = gt @a

-- Demo:
data Foo = A | B deriving (Show, Generic)
data Bar = C | D deriving (Show, Generic)
data Baz = E Foo | F Bar | G Foo Bar deriving (Show, Generic)

instance EnumFin Foo
instance EnumFin Bar
instance EnumFin Baz

main :: IO ()
main = do
  putStrLn $ show $ toNum @Integer $ gt @Baz
  putStrLn $ show $ toNum @Integer $ toFin $ E A
  putStrLn $ show $ toNum @Integer $ toFin $ E B
  putStrLn $ show $ toNum @Integer $ toFin $ F C
  putStrLn $ show $ toNum @Integer $ toFin $ F D
  putStrLn $ show $ toNum @Integer $ toFin $ G A C
  putStrLn $ show $ toNum @Integer $ toFin $ G A D
  putStrLn $ show $ toNum @Integer $ toFin $ G B C
  putStrLn $ show $ toNum @Integer $ toFin $ G B D
  putStrLn $ show $ fromFin @Baz $ toFin $ E A
  putStrLn $ show $ fromFin @Baz $ toFin $ E B
  putStrLn $ show $ fromFin @Baz $ toFin $ F C
  putStrLn $ show $ fromFin @Baz $ toFin $ F D
  putStrLn $ show $ fromFin @Baz $ toFin $ G A C
  putStrLn $ show $ fromFin @Baz $ toFin $ G A D
  putStrLn $ show $ fromFin @Baz $ toFin $ G B C
  putStrLn $ show $ fromFin @Baz $ toFin $ G B D
```