
Repa arrays indexed by a bounded data type?


I want to achieve something similar to the bounded arrays in the standard array package, but using Repa arrays.

What is a nice, clean way to achieve this?

This is what I tried, but there must be a better way than wrapping everything in custom functions that check for bounds:

import  Data.Array.Repa

data C = A | F | L deriving (Eq,Enum,Ord,Bounded,Show)

data Ballot c = Ballot {
    vote::Array U (Z :. Int) Int
    } deriving Show

mkBallot::(Eq c ,Enum c,Ord c, Bounded c, Show c) => c -> Ballot c
mkBallot c = Ballot $ fromListUnboxed (Z :. max) (genSc c)
  where
    max = (fromEnum (maxBound `asTypeOf` c)) + 1

genSc::(Eq c,Enum c,Ord c,Bounded c,Show c) => c -> [Int]
genSc c = [ f x | x <- enumFrom (minBound `asTypeOf` c) , let f v = if x == c then 1 else 0]

showScore c b = index (vote b) (Z :. ((fromEnum c)))

I have also tried to write a Shape instance for (sh :. C), but to no avail: I can't get my head around how to implement some of the methods declared in the Shape class for my data type. I'm posting in the hope that someone has a way; if not, I'll try again. Thank you!
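For comparison, this is roughly the behaviour the question wants to reproduce, written with Data.Array from the array package that ships with GHC. It is a minimal sketch: `mkVote` is a hypothetical helper, and deriving `Ix` is what lets the enumeration itself act as the index type. Because the bounds are `(minBound, maxBound)`, every constructor is a valid index.

```haskell
import Data.Array (Array, bounds, listArray, (!))
import Data.Ix (Ix)

-- Deriving Ix lets the enumeration itself serve as the array index type.
data C = A | F | L deriving (Eq, Ord, Enum, Bounded, Show, Ix)

-- Hypothetical helper: one slot per constructor, 1 for the chosen
-- candidate and 0 for every other one.
mkVote :: (Bounded c, Enum c, Ix c) => c -> Array c Int
mkVote c = listArray (minBound, maxBound) [fromEnum (x == c) | x <- [minBound .. maxBound]]

main :: IO ()
main = do
  let v = mkVote F
  print (bounds v)      -- (A,L)
  print (v ! F, v ! A)  -- (1,0)
```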


Solution

  • You can make a Shape instance for a wrapper around your bounded enum. I'm not sure this is the best way, but I think it does roughly what you want.

    {-# LANGUAGE ScopedTypeVariables  #-}
    
    import Data.Array.Repa
    

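    Here we make a Shape instance over bounded types. Besides one Idx per constructor, we need an end-of-index value, EOI, to describe the extent of a "full" array: fromIdx EOI is the number of constructors, one past the last valid offset.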

    data Idx a = Idx a | EOI
               deriving (Eq, Ord, Show)
    
    fromIdx :: forall a . (Bounded a, Enum a) => Idx a -> Int
    fromIdx EOI = fromEnum (maxBound :: a) - fromEnum (minBound :: a) + 1
    fromIdx (Idx x) = fromEnum x - fromEnum (minBound :: a)
    
    toIdx ::  forall a . (Bounded a, Enum a) => Int -> Idx a
    toIdx i | i < 0 = error "negative index"
    toIdx i = case compare i range of
      LT -> Idx $ toEnum (i + fromEnum (minBound :: a))
      EQ -> EOI
      GT -> error "out of range"
      where
        range = fromEnum (maxBound :: a) - fromEnum (minBound :: a) + 1
    
    instance (Bounded a, Enum a, Ord a) => Shape (Idx a) where
      rank _ = 1
      zeroDim = Idx minBound
      unitDim = Idx $ succ minBound
      intersectDim EOI n = n
      intersectDim n EOI = n
      intersectDim (Idx n1) (Idx n2) = Idx $ min n1 n2
      addDim = error "undefined"
      size = fromIdx
      sizeIsValid _ = True
      toIndex _ n = fromIdx n
      fromIndex _ i = toIdx i
      inShapeRange _ _ EOI = False               -- EOI is never a valid index
      inShapeRange n1 n2 n = n >= n1 && n < n2   -- exclusive upper bound, like Repa's built-in shapes
      listOfShape n = [fromIdx n]
      shapeOfList [i] = toIdx i
      shapeOfList _ = error "unsupported shape"
      deepSeq (Idx n) x = n `seq` x
      deepSeq _ x = x
    

    With that, the ballot part is easy and clean:

    data C = A | F | L deriving (Eq, Enum, Ord, Bounded, Show)
    
    data Ballot c = Ballot { vote :: Array U (Idx c) Int
                           } deriving Show
    
    mkBallot :: (Eq c, Enum c, Ord c, Bounded c, Show c) => c -> Ballot c
    mkBallot c = Ballot $ fromListUnboxed EOI vec
      where
        -- Data.Array.Repa exports its own map, so use the list version explicitly.
        vec = Prelude.map (fromEnum . (== c)) [minBound .. maxBound]
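The offset arithmetic behind EOI can be checked without Repa. This sketch copies `fromIdx` and `toIdx` from the answer; `ballotVec` and `scoreOf` are hypothetical helpers that use a plain list in place of the unboxed array, mimicking what `fromListUnboxed` receives and the element a lookup at `Idx c` would land on.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

-- Repa-free model of the Idx/EOI offset arithmetic used by the Shape instance.
data Idx a = Idx a | EOI deriving (Eq, Ord, Show)

data C = A | F | L deriving (Eq, Enum, Ord, Bounded, Show)

-- Copied from the answer: Idx x maps to its offset, EOI to the full size.
fromIdx :: forall a . (Bounded a, Enum a) => Idx a -> Int
fromIdx EOI = fromEnum (maxBound :: a) - fromEnum (minBound :: a) + 1
fromIdx (Idx x) = fromEnum x - fromEnum (minBound :: a)

-- Copied from the answer: the inverse of fromIdx.
toIdx :: forall a . (Bounded a, Enum a) => Int -> Idx a
toIdx i | i < 0 = error "negative index"
toIdx i = case compare i range of
  LT -> Idx $ toEnum (i + fromEnum (minBound :: a))
  EQ -> EOI
  GT -> error "out of range"
  where
    range = fromEnum (maxBound :: a) - fromEnum (minBound :: a) + 1

-- Hypothetical: the flat list that mkBallot hands to fromListUnboxed.
ballotVec :: (Eq c, Bounded c, Enum c) => c -> [Int]
ballotVec c = map (fromEnum . (== c)) [minBound .. maxBound]

-- Hypothetical: the element stored at offset fromIdx (Idx c).
scoreOf :: (Bounded c, Enum c) => c -> [Int] -> Int
scoreOf c v = v !! fromIdx (Idx c)

main :: IO ()
main = do
  print (fromIdx (EOI :: Idx C))                -- 3
  print (map fromIdx [Idx A, Idx F, Idx L])     -- [0,1,2]
  print (map (toIdx :: Int -> Idx C) [0 .. 3])  -- [Idx A,Idx F,Idx L,EOI]
  print (ballotVec F, scoreOf F (ballotVec F))  -- ([0,1,0],1)
```

With the answer's instance, the corresponding Repa lookup would be `vote b ! Idx c` using Repa's `(!)`, which goes through `toIndex` and therefore through `fromIdx`.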