
Sort primitive array with custom Comparator on Clojure


I would like to sort a primitive Java Array using a custom Comparator, but I am getting a type error. I think the comparator function is creating a Comparator<java.lang.Object> rather than a Comparator<Long>, but I can't figure out how to get around this.

Here is a minimal example:

x.core=> (def x (double-array [4 3 5 6 7]))
#'x.core/x
x.core=> (java.util.Arrays/sort x (comparator #(> %1 %2)))

ClassCastException [D cannot be cast to [Ljava.lang.Object;  x.core/eval1524 (form-init5588058267991397340.clj:1)

I have tried adding different type hints to the comparator function, but frankly I am relatively new to the language and was basically just throwing darts.
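The exception message points at the cause. `[D` is the JVM name for `double[]`, and `[Ljava.lang.Object;` is `Object[]`. `java.util.Arrays` has no overload that takes a primitive array together with a Comparator. The only two-argument `sort` that accepts a Comparator is `sort(T[], Comparator)`, whose array parameter erases to `Object[]`, so the reflective call lands there and the cast of the `double[]` fails. The sketch below assumes boxing is acceptable and copies the values into a `Double[]` before sorting:

```clojure
;; Arrays/sort only pairs a Comparator with an object array,
;; so box the primitive doubles into a Double[] first.
(def x (double-array [4 3 5 6 7]))

(def boxed (into-array Double x))   ; Double[] copy of the primitive array

;; Sorts the Double[] in place; (comparator >) gives descending order.
(java.util.Arrays/sort boxed (comparator >))

(vec boxed)   ; => [7.0 6.0 5.0 4.0 3.0]
```

This avoids the exception. However, every comparison now works on boxed `Double` objects, so it presumably gives up much of the speed that motivates using a primitive array.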

I have deliberately simplified the example above to focus on the key question, which is a type error. In the sections below I try to give some more detail to motivate the question and demonstrate why I am using a custom Comparator.

Motivation

What I am trying to do is duplicate R's order function, which works like this:

> x = c(7, 2, 5, 3, 1, 4, 6)
> order(x)
[1] 5 2 4 6 3 7 1
> x[order(x)]
[1] 1 2 3 4 5 6 7

As you can see, it returns the permutation of indices that sorts its input vector.

Here is a working solution in Clojure:

(defn order
  "Permutation of indices sorted by x"
  [x]
  (let [v (vec x)]
    (sort-by #(v %) (range (count v)))))

x.core=> (order [7 2 5 3 1 4 6])
(4 1 3 5 2 6 0)

(Note that R is 1-indexed while Clojure is 0-indexed.) The trick is to sort one vector (namely the indices of x, [0, 1, ..., (count x) - 1]) by the vector x itself.
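To mirror the R session above more literally, the index permutation can be shifted to 1-based indices or used to reorder x directly. The sketch below repeats the same `order` definition so that it stands alone. The only change is that `#(v %)` is shortened to `v`, because a vector is already a function of its indices:

```clojure
(defn order
  "Permutation of indices sorted by x"
  [x]
  (let [v (vec x)]
    ;; the vector itself works as the key fn: (v i) looks up index i
    (sort-by v (range (count v)))))

(def x [7 2 5 3 1 4 6])

(map inc (order x))   ; => (5 2 4 6 3 7 1), 1-based like R's order(x)
(mapv x (order x))    ; => [1 2 3 4 5 6 7], like R's x[order(x)]
```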

R vs. Clojure Performance

Unfortunately, I am bothered by the performance of this solution. The R solution is much faster:

> x = runif(1000000)
> system.time({ y = order(x) })
   user  system elapsed
  0.041   0.004   0.046

Corresponding Clojure code:

x.core=> (def x (repeatedly 1000000 rand))
#'x.core/x
x.core=> (time (def y (order x)))
"Elapsed time: 2857.216452 msecs"
#'x.core/y
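One caveat about this benchmark: `repeatedly` returns a lazy sequence, and `def` does not realize it. As a result, the 1,000,000 calls to `rand` actually happen inside `(order x)` (when `vec` walks the sequence) and are included in the reported time. The sketch below forces the data before timing. `order` is repeated from above so that the block stands alone. I have not measured how much of the 2.8 s this accounts for:

```clojure
;; doall walks the lazy seq, so the random numbers are generated here,
;; outside the timed expression.
(def x (doall (repeatedly 1000000 rand)))

(defn order [x]
  (let [v (vec x)]
    (sort-by v (range (count v)))))

(time (def y (order x)))
```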

Primitive Arrays the Solution?

I discovered that primitive arrays tend to sort in comparable time to R:

> x = runif(1000000)
> system.time({ y = sort(x) })
   user  system elapsed
  0.061   0.005   0.069

vs.

x.core=> (def x (double-array (repeatedly 1000000 rand)))
#'x.core/x
x.core=> (time (java.util.Arrays/sort x))
"Elapsed time: 86.827277 msecs"
nil

This is the motivation for my attempt to use a custom Comparator with the java.util.Arrays class. My hope is that the speed will be comparable to R.

I should add that I can use a custom Comparator with an ArrayList as shown below, but the performance was no better than my starting function:

(defn order2
  [x]
  (let [v (vec x)
        compx (comparator (fn [i j] (< (v i) (v j))))
        ix (java.util.ArrayList. (range (count v)))]
    (java.util.Collections/sort ix compx)
    (vec ix)))

Any help will be appreciated, even if you just want to give some general Clojure advice. I'm still learning the language and having a lot of fun doing it. :-)


Edit

Per Carcigenicate's answer below,

(defn order
  [x]
  (let [ix (int-array (range (count x)))]
    (vec (-> (java.util.Arrays/stream ix)
             (.boxed)
             (.sorted (fn [i j] (< (aget x i) (aget x j))))
             (.mapToInt
               (proxy [java.util.function.ToIntFunction] []
                 (applyAsInt [^long d] d)))
             (.toArray)))))

will work:

x.core=> (def x (double-array [5 3 1 3.14 -10]))
#'x.core/x
x.core=> (order x)
[4 2 1 3 0]
x.core=> (map #(aget x %) (order x))
(-10.0 1.0 3.0 3.14 5.0)

Unfortunately it's super slow. I guess primitives might not be the answer after all.
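A likely reason for the slowness, though I have not profiled it, is that `x` has no type hint in this version. Each `(aget x i)` inside the comparator is therefore resolved by reflection, once per comparison. Setting `*warn-on-reflection*` to true should report this. A second factor: when a Clojure fn that returns a boolean is used as a Comparator, it is called a second time whenever it returns false. This happens because Clojure's `AFunction.compare` calls the fn again with the arguments swapped to distinguish "greater" from "equal".

The sketch below keeps the Comparator approach but hints the array and returns an int directly. `order-boxed` is a hypothetical name. It still sorts boxed indices, so I would not expect it to reach R's speed:

```clojure
(defn order-boxed
  "Hypothetical helper: R-style order by sorting a boxed index array
  with java.util.Arrays/sort and a Comparator that reads a double[]."
  [xs]
  (let [^doubles a (double-array xs)        ; primitive copy of the values
        n (alength a)
        idx (into-array Long (range n))     ; Long[] of indices 0..n-1
        cmp (reify java.util.Comparator
              (compare [_ i j]
                ;; a is hinted as double[], so aget compiles without reflection;
                ;; Double/compare returns an int, so no second call is needed
                (Double/compare (aget a (long i)) (aget a (long j)))))]
    ;; sorts idx in place, ordered by the values the indices point to
    (java.util.Arrays/sort ^objects idx ^java.util.Comparator cmp)
    (vec idx)))

(order-boxed [5 3 1 3.14 -10])   ; => [4 2 1 3 0]
```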


Solution

    Here's a Clojure implementation of the order function using quicksort with randomized pivots. It gets reasonably close to R: using your benchmark with a million doubles, I'm getting timings mostly in the 520-530 ms range, while R generally hovers around 500 ms here.

    Update: With a very basic two-threaded version (two quicksorts, one per half of the array, followed by a merge step that produces the output vector), I'm getting noticeably better timings. The worst benchmark average was 415 ms; otherwise, results tend to fall in the 325-365 ms range. See the end of this message for the two-threaded version, or if you prefer either version in gist form, here they are: two-threaded, single-threaded.
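    The merge step can be shown on its own. The sketch below uses a hypothetical helper, `merge-index-runs`, that takes the key array and two runs of indices, each already sorted by their keys. At each step it emits the index with the smaller key. This is the same idea as the loop at the end of `order3`, except that this version looks keys up by index instead of reading a copy of the keys that was permuted alongside the indices:

    ```clojure
    (defn merge-index-runs
      "Hypothetical helper: merges two runs of indices, each already sorted by
      the corresponding values in a, into one vector sorted by those values."
      [^doubles a run1 run2]
      (loop [out (transient [])
             r1 (seq run1)
             r2 (seq run2)]
        (cond
          ;; one run is exhausted: append whatever is left of the other
          (nil? r1) (persistent! (reduce conj! out r2))
          (nil? r2) (persistent! (reduce conj! out r1))
          :else
          (let [i (first r1)
                j (first r2)]
            (if (<= (aget a (long i)) (aget a (long j)))
              (recur (conj! out i) (next r1) r2)
              (recur (conj! out j) r1 (next r2)))))))

    ;; keys 5.0 1.0 4.0 2.0; run [1 0] is sorted by key (1.0, 5.0), as is [3 2] (2.0, 4.0)
    (merge-index-runs (double-array [5 1 4 2]) [1 0] [3 2])   ; => [1 3 2 0]
    ```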

    Note that it pours its input into an array of doubles as an intermediate step and ultimately returns a vector of longs. Pouring a million doubles into a vector seems to take just over 30 ms on my box, so you could leave off that step if you're happy with an array result.

    The main complication is the invokePrim – as of Clojure 1.9.0-RC1, a regular function call in that position would result in boxing. Other approaches are possible, but this works and seems straightforward enough.
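    Some background on invokePrim: a fn whose parameters and return value are hinted `^long` or `^double` implements one of the primitive interfaces nested in `clojure.lang.IFn`. For two longs in and a long out, that interface is `IFn$LLL`. Calling `.invokePrim` through that interface passes and returns primitives without boxing. Here is a tiny sketch with a hypothetical local named `add`:

    ```clojure
    (let [add (fn ^long [^long a ^long b] (+ a b))]
      ;; a plain (add 1 2) on a local goes through IFn.invoke(Object, Object) and boxes;
      ;; the hinted interface call below takes and returns primitive longs
      (.invokePrim ^clojure.lang.IFn$LLL add 1 2))   ; => 3
    ```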

    See the end of this message for some benchmark results; the lower-quantile figure from the first run is actually the best result reported.

    (defn order2 [xs]
      (let [rnd (java.util.Random.)
            a1 (double-array xs)
            a2 (long-array (alength a1))]
        (dotimes [i (alength a2)]
          (aset a2 i i))
        (letfn [(quicksort [^long l ^long h]
                  (if (< l h)
                    (let [p (.invokePrim ^clojure.lang.IFn$LLL partition l h)]
                      (quicksort l (dec p))
                      (quicksort (inc p) h))))
                (partition ^long [^long l ^long h]
                  (let [pidx (+ l (.nextInt rnd (- h l)))
                        pivot (aget a1 pidx)]
                    (swap1 a1 pidx h)
                    (swap2 a2 pidx h)
                    (loop [i (dec l)
                           j l]
                      (if (< j h)
                        (if (< (aget a1 j) pivot)
                          (let [i (inc i)]
                            (swap1 a1 i j)
                            (swap2 a2 i j)
                            (recur i (inc j)))
                          (recur i (inc j)))
                        (let [i (inc i)]
                          (when (< (aget a1 h) (aget a1 i))
                            (swap1 a1 i h)
                            (swap2 a2 i h))
                          i)))))
                (swap1 [^doubles a ^long i ^long j]
                  (let [tmp (aget a i)]
                    (aset a i (aget a j))
                    (aset a j tmp)))
                (swap2 [^longs a ^long i ^long j]
                  (let [tmp (aget a i)]
                    (aset a i (aget a j))
                    (aset a j tmp)))]
          (quicksort 0 (dec (alength a1)))
          (vec a2))))
    

    Benchmark results (NB: the first run uses x defined as in the question text, (def x (repeatedly 1000000 rand)); it also uses c/bench, whereas the following runs use c/quick-bench, with c aliasing Criterium's criterium.core):

    user> (c/bench (order2 x))
    Evaluation count : 120 in 60 samples of 2 calls.
                 Execution time mean : 522.485408 ms
        Execution time std-deviation : 33.490530 ms
       Execution time lower quantile : 470.089782 ms ( 2.5%)
       Execution time upper quantile : 575.687990 ms (97.5%)
                       Overhead used : 15.378363 ns
    nil
    user> (let [x (repeatedly 1000000 rand)]
            (c/quick-bench (order2 x)))
    Evaluation count : 6 in 6 samples of 1 calls.
                 Execution time mean : 527.020004 ms
        Execution time std-deviation : 14.846061 ms
       Execution time lower quantile : 507.175127 ms ( 2.5%)
       Execution time upper quantile : 543.675752 ms (97.5%)
                       Overhead used : 15.378363 ns
    nil
    user> (let [x (repeatedly 1000000 rand)]
            (c/quick-bench (order2 x)))
    Evaluation count : 6 in 6 samples of 1 calls.
                 Execution time mean : 513.476501 ms
        Execution time std-deviation : 12.828449 ms
       Execution time lower quantile : 497.164534 ms ( 2.5%)
       Execution time upper quantile : 525.094463 ms (97.5%)
                       Overhead used : 15.378363 ns
    nil
    user> (let [x (repeatedly 1000000 rand)]
            (c/quick-bench (order2 x)))
    Evaluation count : 6 in 6 samples of 1 calls.
                 Execution time mean : 529.826816 ms
        Execution time std-deviation : 21.454522 ms
       Execution time lower quantile : 508.547461 ms ( 2.5%)
       Execution time upper quantile : 552.592925 ms (97.5%)
                       Overhead used : 15.378363 ns
    nil
    

    Some R timings from the same box for comparison:

    > system.time({ y = order(x) })
       user  system elapsed 
      0.512   0.004   0.514 
    > system.time({ y = order(x) })
       user  system elapsed 
      0.496   0.000   0.496 
    > system.time({ y = order(x) })
       user  system elapsed 
      0.508   0.000   0.510 
    > system.time({ y = order(x) })
       user  system elapsed 
      0.508   0.000   0.513 
    > system.time({ y = order(x) })
       user  system elapsed 
      0.496   0.000   0.499 
    > system.time({ y = order(x) })
       user  system elapsed 
      0.500   0.000   0.502 
    

    Update: The two-threaded Clojure version:

    (defn order3 [xs]
      (let [rnd (java.util.Random.)
            a1 (double-array xs)
            a2 (long-array (alength a1))]
        (dotimes [i (alength a2)]
          (aset a2 i i))
        (letfn [(quicksort [^long l ^long h]
                  (if (< l h)
                    (let [p (.invokePrim ^clojure.lang.IFn$LLL partition l h)]
                      (quicksort l (dec p))
                      (quicksort (inc p) h))))
                (partition ^long [^long l ^long h]
                  (let [pidx (+ l (.nextInt rnd (- h l)))
                        pivot (aget a1 pidx)]
                    (swap1 a1 pidx h)
                    (swap2 a2 pidx h)
                    (loop [i (dec l)
                           j l]
                      (if (< j h)
                        (if (< (aget a1 j) pivot)
                          (let [i (inc i)]
                            (swap1 a1 i j)
                            (swap2 a2 i j)
                            (recur i (inc j)))
                          (recur i (inc j)))
                        (let [i (inc i)]
                          (when (< (aget a1 h) (aget a1 i))
                            (swap1 a1 i h)
                            (swap2 a2 i h))
                          i)))))
                (swap1 [^doubles a ^long i ^long j]
                  (let [tmp (aget a i)]
                    (aset a i (aget a j))
                    (aset a j tmp)))
                (swap2 [^longs a ^long i ^long j]
                  (let [tmp (aget a i)]
                    (aset a i (aget a j))
                    (aset a j tmp)))]
          (let [lim (alength a1)
                mid (quot lim 2)
                f1 (future (quicksort 0 (dec mid)))
                f2 (future (quicksort mid (dec lim)))]
            @f1
            @f2
            (loop [out (transient [])
                   i 0
                   j mid]
              (cond
                (== i mid)
                (persistent!
                  (if (== j lim)
                    out
                    (reduce (fn [out j]
                              (conj! out (aget a2 j)))
                      out
                      (range j lim))))
    
                (== j lim)
                (persistent!
                  (reduce (fn [out i]
                            (conj! out (aget a2 i)))
                    out
                    (range i mid)))
    
                :else
                (let [ie (aget a1 i)
                      je (aget a1 j)]
                  (if (< ie je)
                    (recur (conj! out (aget a2 i)) (inc i) j)
                    (recur (conj! out (aget a2 j)) i (inc j))))))))))
    

    Some benchmark results for this one:

    user> (let [x (repeatedly 1000000 rand)]
            (c/quick-bench (order3 x)))
    Evaluation count : 6 in 6 samples of 1 calls.
                 Execution time mean : 325.351056 ms
        Execution time std-deviation : 3.511578 ms
       Execution time lower quantile : 321.947510 ms ( 2.5%)
       Execution time upper quantile : 330.375038 ms (97.5%)
                       Overhead used : 15.378363 ns
    nil
    user> (let [x (repeatedly 1000000 rand)]
            (c/quick-bench (order3 x)))
    Evaluation count : 6 in 6 samples of 1 calls.
                 Execution time mean : 339.422989 ms
        Execution time std-deviation : 19.929177 ms
       Execution time lower quantile : 318.996436 ms ( 2.5%)
       Execution time upper quantile : 366.113347 ms (97.5%)
                       Overhead used : 15.378363 ns
    nil
    user> (let [x (repeatedly 1000000 rand)]
            (c/quick-bench (order3 x)))
    Evaluation count : 6 in 6 samples of 1 calls.
                 Execution time mean : 415.171336 ms
        Execution time std-deviation : 13.624262 ms
       Execution time lower quantile : 393.242455 ms ( 2.5%)
       Execution time upper quantile : 428.881001 ms (97.5%)
                       Overhead used : 15.378363 ns
    
    Found 1 outliers in 6 samples (16.6667 %)
        low-severe   1 (16.6667 %)
     Variance from outliers : 13.8889 % Variance is moderately inflated by outliers
    nil
    user> (let [x (repeatedly 1000000 rand)]
            (c/quick-bench (order3 x)))
    Evaluation count : 6 in 6 samples of 1 calls.
                 Execution time mean : 324.547827 ms
        Execution time std-deviation : 5.196817 ms
       Execution time lower quantile : 318.541727 ms ( 2.5%)
       Execution time upper quantile : 331.878289 ms (97.5%)
                       Overhead used : 15.378363 ns
    nil
    user> (c/bench (order3 x))
    Evaluation count : 180 in 60 samples of 3 calls.
                 Execution time mean : 361.529793 ms
        Execution time std-deviation : 45.285047 ms
       Execution time lower quantile : 307.535934 ms ( 2.5%)
       Execution time upper quantile : 446.679687 ms (97.5%)
                       Overhead used : 15.378363 ns
    
    Found 1 outliers in 60 samples (1.6667 %)
        low-severe   1 (1.6667 %)
     Variance from outliers : 78.9377 % Variance is severely inflated by outliers
    nil