I am trying to use dfold from CLaSH.Sized.Vector, which has this type:
dfold :: KnownNat k
      => Proxy (p :: TyFun Nat * -> *)
      -> (forall l. SNat l -> a -> (p @@ l) -> p @@ (l + 1))
      -> (p @@ 0)
      -> Vec k a
      -> p @@ k
Basically, it is a fold whose accumulator is allowed to have a new type after each step: after l steps it has type p @@ l.
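To make the "new type after each step" idea concrete without a CLaSH installation, here is a minimal plain-GHC sketch. It assumes Peano naturals instead of GHC.TypeLits (so no typechecker plugin is needed), and every name in it (N, SN, Vec, Apply, dfoldP, Append, append) is a hypothetical stand-in, not CLaSH's API. The motive Append m a says that after l steps the accumulator is a Vec (l + m), in the style of the (:>) example from the dfold documentation.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables, TypeFamilies #-}
import Data.Kind (Type)
import Data.Proxy (Proxy (..))

-- Hypothetical Peano naturals standing in for GHC.TypeLits' Nat.
data N = Z | S N

-- Singleton: a runtime witness for a type-level N (stand-in for SNat).
data SN (n :: N) where
  SZ :: SN 'Z
  SS :: SN n -> SN ('S n)

-- Length-indexed vector, modelled on CLaSH's Vec.
data Vec (n :: N) a where
  Nil  :: Vec 'Z a
  (:>) :: a -> Vec n a -> Vec ('S n) a
infixr 5 :>

type family Add (m :: N) (n :: N) :: N where
  Add 'Z     n = n
  Add ('S m) n = 'S (Add m n)

-- Defunctionalised motive: Apply p l is the accumulator type after l steps.
type family Apply (p :: Type) (l :: N) :: Type

vlength :: Vec n a -> SN n
vlength Nil       = SZ
vlength (_ :> xs) = SS (vlength xs)

vtoList :: Vec n a -> [a]
vtoList Nil       = []
vtoList (x :> xs) = x : vtoList xs

-- Right fold whose accumulator type changes with the number of elements
-- already folded (same shape as dfold, minus the KnownNat constraint).
dfoldP :: forall p k a. Proxy p
       -> (forall l. SN l -> a -> Apply p l -> Apply p ('S l))
       -> Apply p 'Z
       -> Vec k a
       -> Apply p k
dfoldP _ f z = go
  where
    go :: forall n. Vec n a -> Apply p n
    go Nil       = z
    go (y :> ys) = f (vlength ys) y (go ys)

-- Motive: after l steps the accumulator is a Vec (l + m).
data Append (m :: N) a
type instance Apply (Append m a) l = Vec (Add l m) a

-- Each step conses one element, growing a Vec m into a Vec (k + m).
append :: forall k m a. Vec k a -> Vec m a -> Vec (Add k m) a
append xs ys = dfoldP (Proxy :: Proxy (Append m a)) (\_ x acc -> x :> acc) ys xs
```

In GHCi, vtoList (append (1 :> 2 :> Nil) (3 :> Nil)) evaluates to [1,2,3]; the accumulator starts as a Vec 1 and gains one element of length per step.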
I am trying to generalize the bitonicSort defined in this project: https://github.com/adamwalker/clash-utils/blob/master/src/Clash/Sort.hs
Two functions are important for the types that the dfold will generate:
bitonicSort
    :: forall n a. (KnownNat n, Ord a)
    => (Vec n a -> Vec n a)             -- ^ The recursive step
    -> (Vec (2 * n) a -> Vec (2 * n) a) -- ^ Merge step
    -> Vec (2 * n) a                    -- ^ Input vector
    -> Vec (2 * n) a                    -- ^ Output vector
bitonicMerge
    :: forall n a. (Ord a , KnownNat n)
    => (Vec n a -> Vec n a) -- ^ The recursive step
    -> Vec (2 * n) a        -- ^ Input vector
    -> Vec (2 * n) a        -- ^ Output vector
The example used in the project mentioned above is:
bitonicSorterExample
    :: forall a. (Ord a)
    => Vec 16 a -- ^ Input vector
    -> Vec 16 a -- ^ Sorted output vector
bitonicSorterExample = sort16
    where
        sort16  = bitonicSort sort8 merge16
        merge16 = bitonicMerge merge8
        sort8   = bitonicSort sort4 merge8
        merge8  = bitonicMerge merge4
        sort4   = bitonicSort sort2 merge4
        merge4  = bitonicMerge merge2
        sort2   = bitonicSort id merge2
        merge2  = bitonicMerge id
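To see what the two combinators and the sort16 chain compute, here is a list-based model that runs in plain GHCi. Lists stand in for Vec (lengths are assumed, not checked, to be powers of two), and it assumes, as in a standard bitonic sorter, that bitonicSort reverses the sorted second half before merging. The names halves, bitonicSortL, bitonicMergeL and sort16L are hypothetical.

```haskell
-- Hypothetical list model of the clash-utils combinators: lists stand in
-- for Vec, and lengths are assumed (not checked) to be powers of two.
halves :: [a] -> ([a], [a])
halves xs = splitAt (length xs `div` 2) xs

-- Sort each half, reverse the second so the whole list is bitonic, then merge.
bitonicSortL :: ([a] -> [a]) -> ([a] -> [a]) -> [a] -> [a]
bitonicSortL recurse merge xs = merge (recurse lo ++ reverse (recurse hi))
  where (lo, hi) = halves xs

-- Half-cleaner: pairwise min/max across the halves, then merge each half.
bitonicMergeL :: Ord a => ([a] -> [a]) -> [a] -> [a]
bitonicMergeL recurse xs = recurse (zipWith min lo hi) ++ recurse (zipWith max lo hi)
  where (lo, hi) = halves xs

-- The sort16 chain from the example, rebuilt on lists.
sort16L :: Ord a => [a] -> [a]
sort16L = sort16
  where
    sort16  = bitonicSortL sort8 merge16
    merge16 = bitonicMergeL merge8
    sort8   = bitonicSortL sort4 merge8
    merge8  = bitonicMergeL merge4
    sort4   = bitonicSortL sort2 merge4
    merge4  = bitonicMergeL merge2
    sort2   = bitonicSortL id merge2
    merge2  = bitonicMergeL id
```

Applied to the 16 elements of testVec16 as a list, sort16L returns them in ascending order; each stage only ever sees inputs of the size its Vec counterpart is typed for.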
I went ahead and made a more general version.
genBitonic :: (Ord a, KnownNat n) =>
(Vec n a -> Vec n a, Vec (2 * n) a -> Vec (2 * n) a)
-> (Vec (2 * n) a -> Vec (2 * n) a, Vec (4 * n) a -> Vec (4 * n) a)
genBitonic (bSort,bMerge) = (bitonicSort bSort bMerge, bitonicMerge bMerge)
bitonicBase :: Ord a => (Vec 1 a -> Vec 1 a, Vec 2 a -> Vec 2 a)
bitonicBase = (id, bitonicMerge id)
With this version I can quickly make new bitonic sorts like so:
bSort16 :: Ord a => Vec 16 a -> Vec 16 a
bSort16 = fst $ genBitonic $ genBitonic $ genBitonic $ genBitonic bitonicBase
bSort8 :: Ord a => Vec 8 a -> Vec 8 a
bSort8 = fst $ genBitonic $ genBitonic $ genBitonic bitonicBase
bSort4 :: Ord a => Vec 4 a -> Vec 4 a
bSort4 = fst $ genBitonic $ genBitonic bitonicBase
bSort2 :: Ord a => Vec 2 a -> Vec 2 a
bSort2 = fst $ genBitonic bitonicBase
Each sort works with a vector of the specified size:
testVec16 :: Num a => Vec 16 a
testVec16 = 9 :> 2 :> 8 :> 6 :> 3 :> 7 :> 0 :> 1 :> 4 :> 5 :> 2 :> 8 :> 6 :> 3 :> 7 :> 0 :> Nil
testVec8 :: Num a => Vec 8 a
testVec8 = 9 :> 2 :> 8 :> 6 :> 3 :> 7 :> 0 :> 1 :> Nil
testVec4 :: Num a => Vec 4 a
testVec4 = 9 :> 2 :> 8 :> 6 :> Nil
testVec2 :: Num a => Vec 2 a
testVec2 = 2 :> 9 :> Nil
Quick notes:
I am trying to apply "genBitonic" to "bitonicBase" t times.
I am using CLaSH to synthesize this to VHDL, so I cannot use recursion to apply it t times.
We will always be sorting a Vec of size 2^t into a Vec of the same size.
"Vec n a" is a vector of size n with elements of type a.
I would like to make a function that generates the sorting function for a given Vec. I believe using dfold or dtfold is the correct path here.
I wanted to do the fold with something like the function genBitonic, then use fst to get the function I need for sorting.
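When the size is not part of the type, applying genBitonic t times is ordinary iteration. The list sketch below (hypothetical names; a simulation model only, not synthesizable CLaSH) shows this. With Vec, each application changes the type from Vec n to Vec (2 * n), so iterate cannot type it; that gap is what dfold is meant to fill.

```haskell
-- Hypothetical list model; lists stand in for Vec and lengths are assumed
-- to be powers of two.
halves :: [a] -> ([a], [a])
halves xs = splitAt (length xs `div` 2) xs

bitonicSortL :: ([a] -> [a]) -> ([a] -> [a]) -> [a] -> [a]
bitonicSortL recurse merge xs = merge (recurse lo ++ reverse (recurse hi))
  where (lo, hi) = halves xs

bitonicMergeL :: Ord a => ([a] -> [a]) -> [a] -> [a]
bitonicMergeL recurse xs = recurse (zipWith min lo hi) ++ recurse (zipWith max lo hi)
  where (lo, hi) = halves xs

-- (sort for n, merge for 2n)  ->  (sort for 2n, merge for 4n)
genBitonicL :: Ord a => ([a] -> [a], [a] -> [a]) -> ([a] -> [a], [a] -> [a])
genBitonicL (bSort, bMerge) = (bitonicSortL bSort bMerge, bitonicMergeL bMerge)

-- (sort for 1, merge for 2)
bitonicBaseL :: Ord a => ([a] -> [a], [a] -> [a])
bitonicBaseL = (id, bitonicMergeL id)

-- Sorts a list of length 2^t. Every stage has the same type here, so plain
-- iteration works; with Vec each stage has a different type, which is why
-- the Vec version needs dfold instead.
bitonicSortPow2 :: Ord a => Int -> [a] -> [a]
bitonicSortPow2 t = fst (iterate genBitonicL bitonicBaseL !! t)
```

For example, bitonicSortPow2 3 plays the role of bSort8 and bitonicSortPow2 4 that of bSort16.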
I had two possible designs:
One: fold using composition to get a function that takes a base.
bSort8 :: Ord a => Vec 8 a -> Vec 8 a
bSort8 = fst $ genBitonic . genBitonic . genBitonic $ bitonicBase
Before the base is applied, composing genBitonic three times would give something like:
foo3 ::
(Ord a, KnownNat n) =>
(Vec n a -> Vec n a, Vec (2 * n) a -> Vec (2 * n) a)
-> (Vec (2 * (2 * (2 * n))) a -> Vec (2 * (2 * (2 * n))) a,
Vec (4 * (2 * (2 * n))) a -> Vec (4 * (2 * (2 * n))) a)
The second idea was to use bitonicBase as the initial accumulator (the p @@ 0 argument of dfold). This would directly produce the form I need before applying fst.
is just meant to be the value building up inside of the dfold
In the dfold documentation example they fold using (:>), which is just the vector form of the list operator (:):
>>> :t (:>)
(:>) :: a -> Vec n a -> Vec (n + 1) a
What I want to do is take a tuple of two functions like:
genBitonic :: (Ord a, KnownNat n) =>
(Vec n a -> Vec n a, Vec (2 * n) a -> Vec (2 * n) a)
-> (Vec (2 * n) a -> Vec (2 * n) a, Vec (4 * n) a -> Vec (4 * n) a)
And compose them, so that genBitonic . genBitonic would have type:
(Vec n a -> Vec n a, Vec (2 * n) a -> Vec (2 * n) a)
-> (Vec (2 * (2 * n)) a -> Vec (2 * (2 * n)) a, Vec (4 * (2 * n)) a -> Vec (4 * (2 * n)) a)
The base value then fixes the concrete types, e.g.:
bitonicBase :: Ord a => (Vec 1 a -> Vec 1 a, Vec 2 a -> Vec 2 a)
bitonicBase = (id, bitonicMerge id)
bSort4 :: Ord a => Vec 4 a -> Vec 4 a
bSort4 = fst $ genBitonic $ genBitonic bitonicBase
I am using dfold to build the sorting function for vectors of length n; it should be equivalent to doing the recursion on a vector of length n.
What I tried, following the example listed under dfold:
data SplitHalf (a :: *) (f :: TyFun Nat *) :: *
type instance Apply (SplitHalf a) l = (Vec (2^l) a -> Vec (2^l) a, Vec (2 ^ (l + 1)) a -> Vec (2 ^ (l + 1)) a)
generateBitonicSortN2 :: forall k a . (Ord a, KnownNat k) => SNat k -> Vec (2^k) a -> Vec (2^k) a
generateBitonicSortN2 k = fst $ dfold (Proxy :: Proxy (SplitHalf a)) vecAcum base vecMath
vecMath = operationList k
vecAcum :: (KnownNat l, KnownNat gl, Ord a) => SNat l
-> (SNat gl -> SplitHalf a @@ gl -> SplitHalf a @@ (gl+1))
-> SplitHalf a @@ l
-> SplitHalf a @@ (l+1)
vecAcum l0 f acc = undefined -- (f l0) acc
base :: (Ord a) => SplitHalf a @@ 0
base = (id,id)
general :: (KnownNat l, Ord a)
=> SNat l
-> SplitHalf a @@ l
-> SplitHalf a @@ (l+1)
general _ (x,y) = (bitonicSort x y, bitonicMerge y )
operationList :: (KnownNat k, KnownNat l, Ord a)
=> SNat k
-> Vec k
(SNat l
-> SplitHalf a @@ l
-> SplitHalf a @@ (l+1))
operationList k0 = replicate k0 general
I am using the same extensions as the dfold source code:
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE Trustworthy #-}
Error Messages:
Sort.hs:182:71: error:
* Could not deduce (KnownNat l) arising from a use of `vecAcum'
from the context: (Ord a, KnownNat k)
bound by the type signature for:
generateBitonicSortN2 :: (Ord a, KnownNat k) =>
SNat k -> Vec (2 ^ k) a -> Vec (2 ^ k) a
at Sort.hs:181:1-98
Possible fix:
add (KnownNat l) to the context of
a type expected by the context:
SNat l
-> (SNat l0
-> (Vec (2 ^ l0) a -> Vec (2 ^ l0) a,
Vec (2 ^ (l0 + 1)) a -> Vec (2 ^ (l0 + 1)) a)
-> (Vec (2 ^ (l0 + 1)) a -> Vec (2 ^ (l0 + 1)) a,
Vec (2 ^ ((l0 + 1) + 1)) a -> Vec (2 ^ ((l0 + 1) + 1)) a))
-> SplitHalf a @@ l
-> SplitHalf a @@ (l + 1)
* In the second argument of `dfold', namely `vecAcum'
In the second argument of `($)', namely
`dfold (Proxy :: Proxy (SplitHalf a)) vecAcum base vecMath'
In the expression:
fst $ dfold (Proxy :: Proxy (SplitHalf a)) vecAcum base vecMath
Sort.hs:182:84: error:
* Could not deduce (KnownNat l0) arising from a use of `vecMath'
from the context: (Ord a, KnownNat k)
bound by the type signature for:
generateBitonicSortN2 :: (Ord a, KnownNat k) =>
SNat k -> Vec (2 ^ k) a -> Vec (2 ^ k) a
at Sort.hs:181:1-98
The type variable `l0' is ambiguous
* In the fourth argument of `dfold', namely `vecMath'
In the second argument of `($)', namely
`dfold (Proxy :: Proxy (SplitHalf a)) vecAcum base vecMath'
In the expression:
fst $ dfold (Proxy :: Proxy (SplitHalf a)) vecAcum base vecMath
Failed, modules loaded: none.
** EDIT ** Added much more detail.
Your base case was wrong; it should be:
base :: (Ord a) => SplitHalf a @@ 0
base = (id, bitonicMerge id)
Putting it all together, here's a fully working version, tested on GHC 8.0.2 (it should work the same on a GHC 8.0.2-based CLaSH, modulo the Prelude import details). It turns out the operationList thing is not used except for its spine, so we can use a Vec n () instead:
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}
{-# LANGUAGE Rank2Types, ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies, TypeOperators, UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns -fno-warn-redundant-constraints #-}
{-# LANGUAGE NoImplicitPrelude #-}
import Prelude (Integer, (+), Num, ($), undefined, id, fst, Int, otherwise)
import CLaSH.Sized.Vector
import CLaSH.Promoted.Nat
import Data.Singletons
import GHC.TypeLits
import Data.Ord
type ExpVec k a = Vec (2 ^ k) a
data SplitHalf (a :: *) (f :: TyFun Nat *) :: *
type instance Apply (SplitHalf a) k = (ExpVec k a -> ExpVec k a, ExpVec (k + 1) a -> ExpVec (k + 1) a)
generateBitonicSortN2 :: forall k a . (Ord a, KnownNat k) => SNat k -> ExpVec k a -> ExpVec k a
generateBitonicSortN2 k = fst $ dfold (Proxy :: Proxy (SplitHalf a)) step base (replicate k ())
  where
    step :: SNat l -> () -> SplitHalf a @@ l -> SplitHalf a @@ (l+1)
    step SNat _ (sort, merge) = (bitonicSort sort merge, bitonicMerge merge)

    base = (id, bitonicMerge id)
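The first error in the question comes from vecAcum demanding KnownNat l while dfold's step argument only receives an SNat l. The answer's step sidesteps this by matching on the SNat constructor. The sketch below isolates that mechanism in plain GHC; it assumes, as the step SNat _ pattern suggests, that SNat is a GADT whose constructor packs a KnownNat dictionary. The SNat, needsKnownNat, useAtSomeL and viaPattern defined here are hypothetical stand-ins, not CLaSH's definitions.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes, ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical stand-in for CLaSH's SNat, assuming its constructor packs
-- a KnownNat dictionary (the 'step SNat _' pattern relies on that).
data SNat (n :: Nat) where
  SNat :: KnownNat n => SNat n

-- Like the question's vecAcum: it needs KnownNat l in its context.
needsKnownNat :: forall l. KnownNat l => SNat l -> Integer
needsKnownNat _ = natVal (Proxy :: Proxy l)

-- A rank-2 consumer in the style of dfold's step argument: it hands the
-- function an SNat l, but no KnownNat l constraint.
useAtSomeL :: (forall l. SNat l -> Integer) -> Integer
useAtSomeL f = f (SNat :: SNat 3)

-- Matching on the SNat constructor brings KnownNat l into scope,
-- which is what lets the answer's 'step SNat _ ...' type-check.
viaPattern :: SNat l -> Integer
viaPattern s@SNat = needsKnownNat s

result :: Integer
result = useAtSomeL viaPattern
```

Here result is 3. Writing useAtSomeL needsKnownNat instead is rejected with an error of the same form as the first one in the question, Could not deduce (KnownNat l).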
This works as expected, e.g.:
*Main> generateBitonicSortN2 (snatProxy Proxy) testVec2
*Main> generateBitonicSortN2 (snatProxy Proxy) testVec4
*Main> generateBitonicSortN2 (snatProxy Proxy) testVec8
*Main> generateBitonicSortN2 (snatProxy Proxy) testVec16
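For experimenting without CLaSH or the typechecker plugins, the same construction can be sketched in plain GHC with Peano naturals. This is a sketch under stated assumptions, not the answer's code: Pow2 k stands for 2 ^ k, and every helper (SN, vsplit, dfoldP, generateBitonicSort, ...) is a hypothetical stand-in for the CLaSH function of similar name. As in the answer, a Vec k () supplies only the spine that decides how many steps run, and the accumulator after l steps is the pair (sorter for 2^l elements, merger for 2^(l+1) elements).

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables, TypeFamilies, UndecidableInstances #-}
import Data.Kind (Type)
import Data.Proxy (Proxy (..))

-- Hypothetical Peano stand-ins for CLaSH's Nat, SNat and Vec.
data N = Z | S N
data SN (n :: N) where
  SZ :: SN 'Z
  SS :: SN n -> SN ('S n)
data Vec (n :: N) a where
  Nil  :: Vec 'Z a
  (:>) :: a -> Vec n a -> Vec ('S n) a
infixr 5 :>

type family Add (m :: N) (n :: N) :: N where
  Add 'Z     n = n
  Add ('S m) n = 'S (Add m n)
-- Pow2 k plays the role of 2 ^ k.
type family Pow2 (k :: N) :: N where
  Pow2 'Z     = 'S 'Z
  Pow2 ('S k) = Add (Pow2 k) (Pow2 k)

-- Runtime witnesses for the type-level arithmetic.
sAdd :: SN m -> SN n -> SN (Add m n)
sAdd SZ     n = n
sAdd (SS m) n = SS (sAdd m n)
sPow2 :: SN k -> SN (Pow2 k)
sPow2 SZ     = SS SZ
sPow2 (SS k) = sAdd (sPow2 k) (sPow2 k)

-- Minimal hypothetical Vec helpers.
vlength :: Vec n a -> SN n
vlength Nil       = SZ
vlength (_ :> xs) = SS (vlength xs)
vreplicate :: SN n -> a -> Vec n a
vreplicate SZ     _ = Nil
vreplicate (SS n) x = x :> vreplicate n x
vappend :: Vec m a -> Vec n a -> Vec (Add m n) a
vappend Nil       ys = ys
vappend (x :> xs) ys = x :> vappend xs ys
vsplit :: SN n -> Vec (Add n m) a -> (Vec n a, Vec m a)
vsplit SZ     xs        = (Nil, xs)
vsplit (SS n) (x :> xs) = let (l, r) = vsplit n xs in (x :> l, r)
halves :: SN n -> Vec (Add n n) a -> (Vec n a, Vec n a)
halves = vsplit
vzipWith :: (a -> b -> c) -> Vec n a -> Vec n b -> Vec n c
vzipWith _ Nil       Nil       = Nil
vzipWith f (x :> xs) (y :> ys) = f x y :> vzipWith f xs ys
vsnoc :: Vec n a -> a -> Vec ('S n) a
vsnoc Nil       y = y :> Nil
vsnoc (x :> xs) y = x :> vsnoc xs y
vreverse :: Vec n a -> Vec n a
vreverse Nil       = Nil
vreverse (x :> xs) = vsnoc (vreverse xs) x
vtoList :: Vec n a -> [a]
vtoList Nil       = []
vtoList (x :> xs) = x : vtoList xs

-- Size-indexed counterparts of bitonicMerge and bitonicSort.
bitonicMerge :: Ord a => SN n -> (Vec n a -> Vec n a) -> Vec (Add n n) a -> Vec (Add n n) a
bitonicMerge n recurse xs = vappend (recurse (vzipWith min lo hi)) (recurse (vzipWith max lo hi))
  where (lo, hi) = halves n xs
bitonicSort :: SN n -> (Vec n a -> Vec n a) -> (Vec (Add n n) a -> Vec (Add n n) a)
            -> Vec (Add n n) a -> Vec (Add n n) a
bitonicSort n recurse merge xs = merge (vappend (recurse lo) (vreverse (recurse hi)))
  where (lo, hi) = halves n xs

-- dfold-like right fold; Apply p l is the accumulator type after l steps.
type family Apply (p :: Type) (l :: N) :: Type
dfoldP :: forall p k a. Proxy p
       -> (forall l. SN l -> a -> Apply p l -> Apply p ('S l))
       -> Apply p 'Z -> Vec k a -> Apply p k
dfoldP _ f z = go
  where
    go :: forall n. Vec n a -> Apply p n
    go Nil       = z
    go (y :> ys) = f (vlength ys) y (go ys)

-- After l steps: (sorter for 2^l elements, merger for 2^(l+1) elements).
data SplitHalf a
type instance Apply (SplitHalf a) l =
  (Vec (Pow2 l) a -> Vec (Pow2 l) a, Vec (Pow2 ('S l)) a -> Vec (Pow2 ('S l)) a)

generateBitonicSort :: forall k a. Ord a => SN k -> Vec (Pow2 k) a -> Vec (Pow2 k) a
generateBitonicSort k = fst (dfoldP (Proxy :: Proxy (SplitHalf a)) step base (vreplicate k ()))
  where
    base :: Apply (SplitHalf a) 'Z
    base = (id, bitonicMerge (SS SZ) id)
    -- The () elements are ignored; only the spine length k matters.
    step :: forall l. SN l -> () -> Apply (SplitHalf a) l -> Apply (SplitHalf a) ('S l)
    step l _ (sortL, mergeL) =
      (bitonicSort (sPow2 l) sortL mergeL, bitonicMerge (sPow2 (SS l)) mergeL)
```

For example, vtoList (generateBitonicSort (SS (SS SZ)) (9 :> 2 :> 8 :> 6 :> Nil)) is [2,6,8,9]. The explicit SN arguments to sPow2 play the part that KnownNat and the solver plugins play in the CLaSH version.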