java random priority-queue weighted

Weighted Randomized Ordering


The problem:

I have items that have weights. The higher an item's weight, the greater the chance that it goes first. I need a clean, simple way of doing this that is based on core Java (no third-party libraries, jars, etc.).

I've done this for 2 items by summing the weights and then picking a random number within that range using Math.random(). Very simple. But for more than 2 items, I can either keep sampling in the same range and risk misses (landing on an item that has already been picked), or I can recompute the sum of the weights of the remaining items and select again (a recursive approach). I think there might be something out there that can do this faster/cleaner. This code will be used over and over, so I'm looking for an efficient solution.

In essence, it's like generating randomized, weighted permutations.

Some Examples:

  1. A has weight of 1, B has weight of 99. If I ran the simulation with this, I would expect to get BA 99% of the time and AB 1% of the time.

  2. A has a weight of 10, B has a weight of 10, and C has a weight of 80. If I ran simulations with this, I would expect C to be the first item in the ordering 80% of the time; in those cases, A and B would have an equal chance of being the next item.

Extra Details:

For my particular problem, there is a small number of items with potentially large weights: say 20 to 50 items, with weights stored as a long, where the minimum weight is at least 1000. The number of items may increase quite a bit too, so a solution that doesn't require the number of items to be small would be preferred.


Solution

  • This seems to work fine:

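    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    
    // Can do a weighted sort on weighted items.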
    public interface Weighted {
        int getWeight();
    }
    
    /**
     * Weighted sort of an array - orders them at random but the weight of each
     * item makes it more likely to be earlier.
     *
     * @param values the items to reorder in place
     */
    public static void weightedSort(Weighted[] values) {
        // Build a list holding each item as many times as its weight.
        List<Weighted> full = new ArrayList<>();
        for (Weighted v : values) {
            // Add v once per unit of weight.
            for (int i = 0; i < v.getWeight(); i++) {
                full.add(v);
            }
        }
        // Shuffle it.
        Collections.shuffle(full);
        // Roll them out in the order required.
        int i = 0;
        do {
            // Get the first one in the shuffled list.
            Weighted next = full.get(0);
            // Put it back into the array.
            values[i++] = next;
            // Remove all occurrences of that one from the list
            // (List.remove(Object) would remove only the first one).
            full.removeAll(Collections.singleton(next));
        } while (!full.isEmpty());
    }
    
    // A bunch of weighted items.
    enum Heavies implements Weighted {
    
        Rare(1),
        Few(3),
        Common(6);
        final int weight;
    
        Heavies(int weight) {
            this.weight = weight;
        }
    
        @Override
        public int getWeight() {
            return weight;
        }
    }
    
    public void test() {
        Weighted[] w = Heavies.values();
        for (int i = 0; i < 10; i++) {
            // Sort it weighted.
            weightedSort(w);
            // What did we get.
            System.out.println(Arrays.toString(w));
        }
    }
    

    Essentially, I add each item to a new list as many times as its weight. I then shuffle the list, pull the top item out, and clear all occurrences of it from what remains, repeating until the list is empty.

    The last test run produced:

    [Rare, Common, Few]
    [Common, Rare, Few]
    [Few, Common, Rare]
    [Common, Few, Rare]
    [Common, Rare, Few]
    [Few, Rare, Common]
    

    which seems to be about right.

    NB - this algorithm will fail under at least the following conditions:

    1. The original array has the same object in it more than once.
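    2. The weights of the items are very large: the expanded list holds one entry per unit of weight, so memory and time grow with the total weight.
    3. Any weight is zero or negative: such an item is never added to the expanded list, so it is never placed and the tail of the array keeps stale entries.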
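
The three failure conditions above all stem from building one list entry per unit of weight. One possible way around them, sketched here as an assumption and not as part of this answer's method, is to give each item an exponentially distributed random key whose rate is the item's weight and then sort by key. The smallest key belongs to item i with probability w_i divided by the sum of the weights, and the same holds among the items that remain, so the resulting order follows the same distribution as repeated weighted picks. The class name `ExponentialKeyOrder`, the method `order` and the `long[]` input are hypothetical choices made for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

// Hypothetical helper, not part of the answer's code.
final class ExponentialKeyOrder {

    /**
     * Returns the indices 0..n-1 in a weighted random order: indices with
     * larger weights tend to come first. Assumes every weight is positive.
     */
    static Integer[] order(long[] weights, Random rng) {
        int n = weights.length;
        double[] keys = new double[n];
        Integer[] indices = new Integer[n];
        for (int i = 0; i < n; i++) {
            if (weights[i] <= 0) {
                throw new IllegalArgumentException("weight must be positive: " + weights[i]);
            }
            // 1 - nextDouble() lies in (0, 1], so the logarithm is finite.
            // The key is exponentially distributed with rate weights[i].
            keys[i] = -Math.log(1.0 - rng.nextDouble()) / weights[i];
            indices[i] = i;
        }
        // Smallest key first; memory use is O(n) whatever the weights are.
        Arrays.sort(indices, Comparator.comparingDouble(i -> keys[i]));
        return indices;
    }

    public static void main(String[] args) {
        // A, B and C from the question's second example, scaled up.
        long[] weights = {10_000L, 10_000L, 80_000L};
        String[] names = {"A", "B", "C"};
        Random rng = new Random();
        for (int run = 0; run < 5; run++) {
            StringBuilder sb = new StringBuilder();
            for (int idx : order(weights, rng)) {
                sb.append(names[idx]);
            }
            System.out.println(sb);
        }
    }
}
```

Because the method returns indices, the same object may appear in the input more than once, and the cost is a single sort, O(n log n), however large the weights are.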

    Added

    This implements Rossum's idea - please be sure to give him the credit for the algorithm.

    public static void weightedSort2(Weighted[] values) {
        // Calculate the total weight.
        int total = 0;
        for (Weighted v : values) {
            total += v.getWeight();
        }
        // Start with all of them.
        List<Weighted> remaining = new ArrayList<>(Arrays.asList(values));
        // Take each at random - weighted by its weight.
        int which = 0;
        do {
            // Pick a random point.
            int random = (int) (Math.random() * total);
            // Pick one from the list.
            Weighted picked = null;
            int pos = 0;
            for (Weighted v : remaining) {
                // Pick this one?
                if (pos + v.getWeight() > random) {
                    picked = v;
                    break;
                }
                // Move forward by that much.
                pos += v.getWeight();
            }
            // Remove picked from the remaining.
            remaining.remove(picked);
            // Reduce total.
            total -= picked.getWeight();
            // Record picked.
            values[which++] = picked;
        } while (!remaining.isEmpty());
    }
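
For the weights described in the question (stored as a long, at least 1000 each), the same pick-and-remove idea can be sketched with long arithmetic. This is a minimal sketch under assumptions: `LongWeighted`, `Item` and `weightedOrder` are hypothetical names, and it rejects non-positive weights instead of guessing what they should mean. `ThreadLocalRandom.nextLong(bound)` draws the random point directly as a long, and `Math.addExact` fails loudly if the total would overflow.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

// Hypothetical variant of weightedSort2 for long weights.
final class LongWeightedOrder {

    /** Assumed interface: like Weighted above, but returning a long. */
    interface LongWeighted {
        long getWeight();
    }

    /** Hypothetical item type used only for the demo in main. */
    static final class Item implements LongWeighted {
        final String name;
        final long weight;
        Item(String name, long weight) { this.name = name; this.weight = weight; }
        @Override public long getWeight() { return weight; }
        @Override public String toString() { return name; }
    }

    /** Returns a new list with the items in weighted random order. */
    static <T extends LongWeighted> List<T> weightedOrder(List<T> items) {
        long total = 0;
        for (T item : items) {
            if (item.getWeight() <= 0) {
                throw new IllegalArgumentException("weight must be positive: " + item.getWeight());
            }
            total = Math.addExact(total, item.getWeight()); // throws on overflow
        }
        List<T> remaining = new ArrayList<>(items);
        List<T> ordered = new ArrayList<>(items.size());
        while (!remaining.isEmpty()) {
            // Random point in [0, total), where total covers the remaining items.
            long random = ThreadLocalRandom.current().nextLong(total);
            long pos = 0;
            int picked = remaining.size() - 1;
            for (int i = 0; i < remaining.size(); i++) {
                pos += remaining.get(i).getWeight();
                if (pos > random) {
                    picked = i;
                    break;
                }
            }
            // Removing by index keeps duplicates of the same object intact.
            T item = remaining.remove(picked);
            total -= item.getWeight();
            ordered.add(item);
        }
        return ordered;
    }

    public static void main(String[] args) {
        List<Item> items = Arrays.asList(
                new Item("A", 10_000L), new Item("B", 10_000L), new Item("C", 80_000L));
        for (int run = 0; run < 5; run++) {
            System.out.println(weightedOrder(items));
        }
    }
}
```

Each pick scans the remaining items, so the whole ordering is O(n²); for the 20 to 50 items mentioned in the question that is a few thousand steps at most.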