Enumerating a subset of paths in a sequential probability tree in R


To illustrate the problem, let us define the following matrix (where NA indicates that the option is unavailable in period t)

set.seed(1)
x <- matrix(NA, 4, 4, dimnames = list(paste0("t=", seq_len(4)), LETTERS[seq_len(4)]))
x[lower.tri(x, diag = TRUE)] <- rnorm(10)

Which gives a matrix that looks like this:

              A           B          C         D
t=1  0.91897737          NA         NA        NA
t=2  0.78213630  0.61982575         NA        NA
t=3  0.07456498 -0.05612874 -1.4707524        NA
t=4 -1.98935170 -0.15579551 -0.4781501 0.4179416

The goal is to calculate the probability that each value is the highest in each time period $t$. However, the values are conditional on the values in the previous periods. For example, in moving from period t=2 to t=3 under the assumption that A is the highest, A is only compared to C and not to B, because A is assumed to have been higher than B in t=2. We can structure the problem as a tree like this:

[Image: tree diagram of the sequential comparisons from t=1 to t=4]

So for t=1 the probability is 1, for t=2 we calculate 2 probabilities from 1 grouping, in t=3 we calculate 4 probabilities from 2 groupings (note how one option is eliminated from the comparison because of the sequential dependence and the inherent assumption that it was not the highest in t-1), and in t=4 we calculate 8 probabilities from 4 groupings. The final probabilities are then the product of the probabilities in each t along each of the 8 paths. In the real problem, t gets large and manually identifying these groupings becomes infeasible.

I've been trying to come up with a clever way of identifying these paths and calculating the probabilities. One idea was to use a set of "masking matrices", one for each possible pattern. That way I could simply multiply by the masking matrix and perform row operations. However, I could not find a robust way to populate the different masking matrices as the number of levels increased.

For example, assume the pattern of choosing A in all periods leading up to the final period can be described by the following masking matrix:

mask <- matrix(c(
1, NA, NA, NA,
1, 1,  NA, NA,
1, NA, 1,  NA,
1, NA, NA, 1
), ncol = 4, byrow = TRUE, dimnames = list(paste0("t=", seq_len(4)), LETTERS[seq_len(4)]))

which looks like this (1 of the 4 possible comparisons in this case):

    A  B  C  D
t=1 1 NA NA NA
t=2 1  1 NA NA
t=3 1 NA  1 NA
t=4 1 NA NA  1

And we can calculate the probabilities in each period like this (all rows sum to one as they should):

exp_x <- exp(x * mask)
sum_exp_x <- rowSums(exp_x, na.rm = TRUE)
pr_x <- exp_x / sum_exp_x
pr_x
             A         B         C         D
t=1 1.00000000        NA        NA        NA
t=2 0.54048879 0.4595112        NA        NA
t=3 0.82423638        NA 0.1757636        NA
t=4 0.08261824        NA        NA 0.9173818
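
The per-period probabilities for this mask combine into path probabilities by multiplication. A minimal sketch, assuming the matrix values are copied from the printed example (rather than regenerated with rnorm) and using `p_AAAA` and `p_AAAD` as illustrative names that are not part of the original code:

```r
# Values copied from the printed example matrix (NA = option unavailable).
x <- matrix(NA_real_, 4, 4,
            dimnames = list(paste0("t=", seq_len(4)), LETTERS[seq_len(4)]))
x[lower.tri(x, diag = TRUE)] <- c(
  0.91897737, 0.78213630, 0.07456498, -1.98935170,  # column A, t=1..4
  0.61982575, -0.05612874, -0.15579551,             # column B, t=2..4
  -1.4707524, -0.4781501,                           # column C, t=3..4
  0.4179416                                         # column D, t=4
)

# Mask for "A is the highest in t=1..3", the same pattern as the example mask.
mask <- matrix(NA_real_, 4, 4, dimnames = dimnames(x))
mask[, "A"] <- 1
diag(mask) <- 1

exp_x <- exp(x * mask)
pr_x <- exp_x / rowSums(exp_x, na.rm = TRUE)

# Path probability = product of the per-period probabilities along the path.
p_AAAA <- prod(pr_x[, "A"])                        # A also stays highest in t=4
p_AAAD <- prod(pr_x[1:3, "A"]) * pr_x["t=4", "D"]  # D overtakes A in t=4
```

The two path probabilities add up to the probability of reaching the t=4 comparison with A as the incumbent, which is a quick sanity check on the mask.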

Is there a clever way of doing this for all possible paths as t grows? Or a good way of populating a set of masking matrices to loop over? I'm trying to keep the problem from getting out of hand. Is it possible that complete path enumeration and elimination is a better option, i.e. faster and more robust? Any help, ideas and pointers are welcome.
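
For reference, the masking-matrix idea can be sketched as a function of the sequence of incumbents. This is only an illustration under stated assumptions: `mask_from_leaders` is a hypothetical helper, `leaders` is assumed to hold the column index of the highest option in periods 1 to n-1, and option t is assumed to become available in period t as in the example.

```r
# Build the masking matrix for one sequence of incumbents.
# leaders[t] = column index of the option assumed highest in period t,
# for t = 1..(n - 1); leaders[1] must be 1 because only A exists in t=1.
mask_from_leaders <- function(leaders,
                              opts = LETTERS[seq_len(length(leaders) + 1L)]) {
  n <- length(leaders) + 1L
  stopifnot(length(opts) == n, leaders[1L] == 1L,
            all(leaders <= seq_along(leaders)))
  mask <- matrix(NA_real_, n, n,
                 dimnames = list(paste0("t=", seq_len(n)), opts))
  mask[1L, 1L] <- 1
  for (t in seq_len(n)[-1L]) {
    # In period t, the previous incumbent is compared with the newcomer t.
    mask[t, c(leaders[t - 1L], t)] <- 1
  }
  mask
}

# "A is the highest in t=1..3" gives the pattern of the example mask.
mask_A <- mask_from_leaders(c(1L, 1L, 1L))
```

Looping this helper over all admissible `leaders` vectors would yield one mask per grouping, though the number of masks still doubles with each additional period.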


Solution

  • Is this what you want?

    find_path <- function(nperiods, opts = LETTERS[seq_len(nperiods)]) {
      stopifnot(length(opts) == nperiods)
      out <- matrix(nrow = 2 ^ (nperiods - 1L), ncol = nperiods)
      r <- 1L
      recur_ <- function(period, branch, outcome) {
        if (period > length(branch)) {
          out[r, ] <<- opts[branch]
          r <<- r + 1L
          return(NULL)
        }
        for (i in c(outcome, period)) {
          branch[[period]] <- i
          recur_(period + 1L, branch, i)
        }
      }
      recur_(1L, integer(nperiods), NULL)
      out
    }
    
    calc_prob <- function(mat) {
      ps <- dimnames(mat)[[1L]]; if (is.null(ps)) ps <- seq_len(nrow(mat))
      ops <- dimnames(mat)[[2L]]; if (is.null(ops)) ops <- seq_len(ncol(mat))
      paths <- find_path(nrow(mat), ops)
      out <- vapply(seq_len(ncol(paths))[-1L], function(i) {
        comp <- ops[[i]]
        comp <- ifelse(paths[, i] == comp, paths[, i - 1L], comp)
        x <- exp(mat[i, paths[, i]])
        y <- exp(mat[i, comp])
        x / (x + y)
      }, numeric(nrow(paths)))
      dimnames(out) <- NULL; out <- cbind(1, out)
      dimnames(out)[[2L]] <- dimnames(paths)[[2L]] <- ps
      list(paths = paths, probs = out)
    }
    

    Output

    > calc_prob(x) # x is the same lower-triangular matrix as shown in your example.
    
    $paths
         t=1 t=2 t=3 t=4
    [1,] "A" "A" "A" "A"
    [2,] "A" "A" "A" "D"
    [3,] "A" "A" "C" "C"
    [4,] "A" "A" "C" "D"
    [5,] "A" "B" "B" "B"
    [6,] "A" "B" "B" "D"
    [7,] "A" "B" "C" "C"
    [8,] "A" "B" "C" "D"
    
    $probs
         t=1       t=2       t=3        t=4
    [1,]   1 0.5404888 0.8242364 0.08261823
    [2,]   1 0.5404888 0.8242364 0.91738177
    [3,]   1 0.5404888 0.1757636 0.28985432
    [4,]   1 0.5404888 0.1757636 0.71014568
    [5,]   1 0.4595112 0.8044942 0.36037495
    [6,]   1 0.4595112 0.8044942 0.63962505
    [7,]   1 0.4595112 0.1955058 0.28985432
    [8,]   1 0.4595112 0.1955058 0.71014568
    

    The variable paths gives you all the possible outcomes for each period t; probs tells you the probability of a corresponding outcome. However, note that such a probability tree grows exponentially as the number of periods increases. The equation is

    $N = 2^{t-1}$

    where N is the number of possible paths at period t. For just 20 periods, there are 524288 different paths. With 30 periods, there are 536870912 different paths, and R cannot realistically handle that much computation. I suggest you reconsider your expected outputs. Are you running a simulation with constraints other than the time dependence, so that some unnecessary paths can be trimmed off? Or do you only need summary statistics such as the expected value, so that not all possible paths have to be generated? There must be a better way than a brute-force approach like this.
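
If per-period summaries are enough, one possible alternative to full enumeration is a forward recursion over the current incumbent. The sketch below is an assumption about what such a summary could look like and is not part of the code above: `leader_marginals` is a hypothetical helper, and it assumes a square matrix in which option t first becomes available in period t, as in the example.

```r
# Marginal probability that each option is the highest in each period,
# computed by a forward recursion instead of enumerating all paths.
leader_marginals <- function(mat) {
  n <- nrow(mat)
  stopifnot(ncol(mat) == n)
  marg <- matrix(0, n, n, dimnames = dimnames(mat))
  marg[1L, 1L] <- 1  # only one option exists in the first period
  for (t in seq_len(n)[-1L]) {
    prev <- seq_len(t - 1L)
    # P(incumbent j beats newcomer t) = exp(x_j) / (exp(x_j) + exp(x_t))
    stay <- 1 / (1 + exp(mat[t, t] - mat[t, prev]))
    marg[t, prev] <- marg[t - 1L, prev] * stay
    marg[t, t] <- sum(marg[t - 1L, prev] * (1 - stay))
  }
  marg
}

# Values copied from the printed example matrix.
x <- matrix(NA_real_, 4, 4,
            dimnames = list(paste0("t=", seq_len(4)), LETTERS[seq_len(4)]))
x[lower.tri(x, diag = TRUE)] <- c(
  0.91897737, 0.78213630, 0.07456498, -1.98935170,
  0.61982575, -0.05612874, -0.15579551,
  -1.4707524, -0.4781501,
  0.4179416
)
marg <- leader_marginals(x)
```

Each row of `marg` sums to one, and the work grows quadratically with the number of periods rather than exponentially. The trade-off is that individual path probabilities are no longer available, only the per-period marginals.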