Search code examples
performancehaskellparallel-processingk-means

Why I get no gains during parallelizing K-means?


I am learning parallel programming in Haskell using Simon Marlow's book. I am on Chapter 3. I don't understand why launching this code on 6 core machine with -N6 flag changes nothing about performance. I got similar problem and then just realized that adding more data to process makes difference between parallel and single-thread execution more significant. But in that case even adding more data gives no time of execution difference. In my exercise I tried to find centroids of world cities. To give it more data I created some 'fake' cities.

Here is first five lines of the cities file:

worldcities.csv

"city","city_ascii","lat","lng","country","iso2","iso3","admin_name","capital","population","id"
"Tokyo","Tokyo","35.6839","139.7744","Japan","JP","JPN","Tōkyō","primary","39105000","1392685764"
"Jakarta","Jakarta","-6.2146","106.8451","Indonesia","ID","IDN","Jakarta","primary","35362000","1360771077"
"Delhi","Delhi","28.6667","77.2167","India","IN","IND","Delhi","admin","31870000","1356872604"
"Manila","Manila","14.6000","120.9833","Philippines","PH","PHL","Manila","primary","23971000","1608618140"

Here is a code I used for adding more 'cities':

main :: IO ()
main = do
    [c] <- getArgs
    cities <- getCities c
    let target = "wc_extended.csv"
    BS.writeFile target (encode [("city" :: BS.ByteString, "lat" :: BS.ByteString, "lng" :: BS.ByteString)])
    forM_ cities $ \(City name (Point lat lng)) -> BS.appendFile target $ encode [
            (name,lat,lng)
        ,   ("A" <> name,lat-10,lng+10)
        ,   ("B" <> name,lat-20,lng+20)
        ,   ("C" <> name,lat-30,lng+30)
        ,   ("D" <> name,lat-40,lng+40)
        ,   ("E" <> name,lat-50,lng+50)
        ]

And here are the source files of my K-means clustering program.

Types.hs

module Types where

import Data.ByteString (ByteString)
import System.Random
import System.Random.Stateful

data Point = Point {
        lat :: !Double
    ,   lng :: !Double
} deriving (Eq,Show)

instance Uniform Point where
    uniformM g = do
                    lat <- uniformRM (-180, 180) g
                    lng <- uniformRM (-180, 180) g
                    return $ Point lat lng

instance Semigroup Point where
    (Point lat lng) <> (Point lat' lng') = Point (lat + lat') (lng + lng')

instance Monoid Point where
    mempty = Point 0 0

sqDistance :: Point -> Point -> Double
sqDistance (Point lat lng) (Point lat' lng') = (lat-lat')^2 + (lng-lng')^2 

data City = City {
        name :: ByteString
    ,   location :: Point
} deriving Show

data Cluster = Cluster {
        cId :: Int
    ,   center :: Point
} deriving (Eq, Show)

data PointSum = PointSum !Int !Point

instance Semigroup PointSum where
    (PointSum c p) <> (PointSum c' p') = PointSum (c+c') (p <> p')

instance Monoid PointSum where
    mempty = PointSum 0 mempty

addToPointSum :: Point -> PointSum -> PointSum
addToPointSum point' (PointSum count point) = PointSum (count+1) $ point <> point'

pointSumToCluster :: Int -> PointSum -> Cluster
pointSumToCluster i (PointSum count (Point lat lng)) = Cluster {
      cId = i
    , center = Point (lat / fromIntegral count) (lng / fromIntegral count)
}

CitiesLoader.hs

{-# LANGUAGE OverloadedStrings #-}
module CitiesLoader where

import Data.Attoparsec.ByteString
import Data.Csv
import Data.Vector (Vector)
import qualified Data.Vector as V 
import qualified Data.ByteString as BS
import qualified Data.ByteString.UTF8 as UTF8
import Data.Csv.Parser (csvWithHeader)
import Data.HashMap.Strict ( (!) )

import Types

getCSV :: FilePath -> IO (Vector NamedRecord)
getCSV path = do
    raw <- BS.readFile path
    case parseOnly (csvWithHeader defaultDecodeOptions) raw of
        Left error -> do
            putStrLn $ "Error during parsing: " <> error <> ", returned empty result"
            return mempty
        Right (_, values) -> return values

extractCities :: Vector NamedRecord -> Vector City
extractCities = fmap f
                where f vmap = City (vmap ! "city") $ Point ((read . UTF8.toString) $ vmap ! "lat") ((read . UTF8.toString) $ vmap ! "lng")

getCities :: FilePath -> IO (Vector City)
getCities = (fmap . fmap) extractCities getCSV

Clustering.hs

module Clustering where

import Types
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M
import Data.Function (on)
import Data.List (minimumBy)
import Control.Monad.Trans.Except
import Control.Parallel.Strategies

assign :: Int -> [Cluster] -> Vector City -> Vector PointSum
assign n clusters points = V.create $ do
    vec <- M.replicate n mempty
    let addpoint (City _ p) = M.modify vec (addToPointSum p) (cId $ nearest p)
    V.mapM_ addpoint points
    return vec
    where nearest p = fst $ minimumBy (compare `on` snd) [(c, sqDistance p (center c)) | c <- clusters]

makeNewClusters :: Vector PointSum -> [Cluster]
makeNewClusters vec = [pointSumToCluster i ps | (i,ps@(PointSum count _)) <- zip [0..] (V.toList vec), count > 0]

step :: Int -> Vector City -> [Cluster] -> [Cluster]
step n cities clusters = makeNewClusters $ assign n clusters cities

kmeansSeq :: Int -> Vector City -> [Cluster] -> Except String [Cluster]
kmeansSeq limit cities clusters = loop 0 clusters
                            where loop n c | n > limit = throwE "reached loop limit"
                                  loop n c = let c' = step nClusters cities c
                                                     in if c' == c
                                                            then return c'
                                                            else loop (n+1) c'
                                  nClusters = length clusters

split :: Int -> Vector a -> [Vector a]
split numChunks xs = chunk (V.length xs `quot` numChunks) xs

chunk :: Int -> Vector a -> [Vector a]
chunk n xs | V.null xs = []
chunk n xs = as : chunk n bs
    where (as, bs) = V.splitAt n xs

combine :: Vector PointSum -> Vector PointSum -> Vector PointSum
combine = V.zipWith (<>)

parStepsStrat :: Int -> [Vector City] -> [Cluster] -> [Cluster]
parStepsStrat n pointss clusters = makeNewClusters $ foldr1 combine (map (assign n clusters) pointss `using` parList rseq)

kMeansStrat :: Int -> Int -> Vector City -> [Cluster] -> Except String [Cluster]
kMeansStrat limit numChunks points clusters = loop 0 clusters
                                        where loop n clusters | n > limit = throwE "reached loop limit"
                                              loop n clusters = let c' = parStepsStrat nClusters chunks clusters
                                                                         in if c' == clusters
                                                                            then return c'
                                                                            else loop (n+1) c'
                                              chunks = split numChunks points
                                              nClusters = length clusters

Main.hs

{-# LANGUAGE OverloadedStrings#-}

module Main where
import System.Environment (getArgs)
import CitiesLoader (getCities)
import System.Random
import System.Random.Stateful (newIOGenM, uniformListM)
import Types
import Clustering
import Control.Monad.Trans.Except (runExcept)
import Data.Vector(forM_)
import Data.Csv(encode)
import qualified Data.ByteString.Lazy as BS

main :: IO ()
main = do
    [c] <- getArgs
    cities <- getCities c
    print (length cities)
    let seed = mkStdGen $ length cities
    g <- newIOGenM seed
    centroids <- uniformListM 1000 g
    let clusters = zipWith Cluster [0..] centroids
    case runExcept (kMeansStrat 10000 6 cities clusters) of
        Left err -> putStrLn err
        Right c -> print c

When I run this on file containing 257431 records I got following times of execution:

cabal exec kcities -- wc_extended.csv +RTS -N1 -s -l

Total time 266.631s (266.351s elapsed)

and threadscope profile

N1 threadscope

cabal exec kcities -- wc_extended.csv +RTS -N6 -s -l

Total time 1737.342s (340.016s elapsed) (execution time even increased)

and threadscope profile

enter image description here


Solution

  • Instead of parList rseq try using parList rdeepseq. For this of course you will need an NFData instance for PointSum, but I am sure you can figure out how it can be done.

    Problem is that you are parallelizing evaluation to WHNF only, which means none of the work is done in parallel.

    Feel free to leave a comment if it is not enough information, and I can expend this as a more detailed answer