I have the following situation:
A variable moves +1 with prob=0.5 and -1 with prob=0.5 ... if this Markov Chain starts at position=5
I tried to do this as follows (I created a tracking variable to see where the simulation gets stuck):
# the simulations can take a long time to run, I interrupted them
library(ggplot2)
library(gridExtra)
# Number of independent repetitions of each situation.
n_sims <- 100

# Per-run results: step counts until the stopping rule fires, and the full
# trajectory of every run as a data frame (time index, position, run id).
times_to_end_0 <- numeric(n_sims)
times_to_end_0_after_10 <- numeric(n_sims)
paths_0 <- vector("list", n_sims)
paths_0_after_10 <- vector("list", n_sims)

for (i in 1:n_sims) {
  # ---- Situation 1: walk from 5 until the position is 0 --------------------
  print(paste("Running simulation", i, "for situation 1..."))
  y <- 5
  time <- 0
  path_0 <- y
  # Steps are +/-1 from a positive integer, so leaving `y > 0` means y == 0.
  while (y > 0) {
    y <- y + sample(c(-1, 1), 1, prob = c(0.5, 0.5))
    path_0 <- c(path_0, y)
    time <- time + 1
  }
  times_to_end_0[i] <- time
  paths_0[[i]] <- data.frame(time = seq_along(path_0), y = path_0, sim = i)

  # ---- Situation 2: walk from 5 until 0 is hit after 10 has been visited ---
  print(paste("Running simulation", i, "for situation 2..."))
  y <- 5
  time <- 0
  reached_10 <- FALSE
  path_0_after_10 <- y
  repeat {
    y <- y + sample(c(-1, 1), 1, prob = c(0.5, 0.5))
    path_0_after_10 <- c(path_0_after_10, y)
    time <- time + 1
    # Latch the visit to 10; visits to 0 before that do not stop the walk.
    reached_10 <- reached_10 || y == 10
    if (reached_10 && y == 0) {
      break
    }
  }
  times_to_end_0_after_10[i] <- time
  paths_0_after_10[[i]] <- data.frame(
    time = seq_along(path_0_after_10),
    y = path_0_after_10,
    sim = i
  )
}
# Both result vectors start out as zeros and are only filled in when a run
# finishes, so a remaining 0 marks a run that never completed (e.g. because
# the loop was interrupted). Drop those from both data sets; log(0) would
# otherwise turn the summaries below into -Inf.
df1 <- data.frame(time = times_to_end_0[times_to_end_0 > 0])
df2 <- data.frame(time = times_to_end_0_after_10[times_to_end_0_after_10 > 0])

# Means of log(time); exp() of these is the geometric mean of the times.
mean1 <- mean(log(df1$time))
mean2 <- mean(log(df2$time))

# The x axis shows log(time), so the reference line has to be placed at the
# log-scale mean itself; the back-transformed value is only used in the title.
p1 <- ggplot(df1, aes(x = log(time))) +
  geom_density() +
  geom_vline(xintercept = mean1, color = "red", linetype = "dotted") +
  labs(
    title = paste(
      "Density of Times to Reach 0 - Geometric Mean Time:",
      round(exp(mean1), 2)
    ),
    x = "log(Time)",
    y = "Density"
  ) +
  theme_bw()

p2 <- ggplot(df2, aes(x = log(time))) +
  geom_density() +
  geom_vline(xintercept = mean2, color = "red", linetype = "dotted") +
  labs(
    title = paste(
      "Density of Times to Reach 0 After Reaching 10 - Geometric Mean Time:",
      round(exp(mean2), 2)
    ),
    x = "log(Time)",
    y = "Density"
  ) +
  theme_bw()
# Draw every simulated trajectory as its own line against log(time index).
path_plot <- function(paths_df, heading) {
  ggplot(paths_df, aes(x = log(time), y = y, group = sim)) +
    geom_line() +
    labs(title = heading, x = "Time", y = "Y") +
    theme_bw()
}

# Stack the per-run data frames of each situation into one long data frame.
plot_df_0 <- do.call(rbind, paths_0)
plot_df_0_after_10 <- do.call(rbind, paths_0_after_10)

p3 <- path_plot(plot_df_0, "Paths of First Simulation")
p4 <- path_plot(plot_df_0_after_10, "Paths of Second Simulation")

# 2 x 2 layout: densities on top, trajectories below.
grid.arrange(p1, p2, p3, p4, ncol = 2)
My Question: Is there anything I can do to improve the efficiency of this simulation?
Thanks!
Calling `sample()` in each cycle of your for/while loop creates a lot of overhead and causes your simulations to run slowly.
The solution below runs `sample()` once per simulation with a size of `1e6` (1,000,000).
You could also go bigger: `1e7` seems reasonably fast on my machine, and all cycles have hit `0` for several runs with `1e7`. At `1e8` things get slow, and `1e9` is likely too much for most machines. The smaller you set the sample size, the more runs you'll get that don't hit `0`.
This function relies on vectorization to speed things up.
#' Simulate a symmetric +/-1 random walk and locate its visits to 0
#'
#' All `size` steps are drawn with a single `sample()` call and the hitting
#' times are found with `cumsum()` / `which()` rather than a step-by-step loop.
#'
#' Times are reported as the number of steps taken (the start is time 0), so a
#' walk started at 5 can only report odd values for `times_to_0`.
#'
#' @param y Starting position of the walk.
#' @param size Number of steps to simulate. Events that do not occur within
#'   `size` steps are reported as `NA`.
#' @return A list with elements
#'   * `times_to_0`: steps until the walk is first at 0, or `NA`.
#'   * `path`: length-1 list holding the positions from the start up to and
#'     including that first visit to 0 (or `NA`).
#'   * `touched_10`: steps until the walk is first at 10, or `NA`.
#'   * `times_to_0_after_10`: steps until the first visit to 0 that happens
#'     after the first visit to 10, or `NA`.
#'   * `path_to_0_after_10`: length-1 list holding the positions from the
#'     start up to and including that visit (or `NA`).
sim <- function(y = 5,
                size = 1e6) {
  steps <- sample(c(-1, 1),
                  size,
                  prob = c(0.5, 0.5),
                  replace = TRUE)

  # positions[k] is the position after k - 1 steps; positions[1] is the start.
  positions <- cumsum(c(y, steps))

  zero_idx <- which(positions == 0)
  first_zero_idx <- zero_idx[1] # NA when the walk never reaches 0
  first_ten_idx <- which(positions == 10)[1] # NA when it never reaches 10

  zero_after_ten_idx <-
    if (is.na(first_ten_idx)) {
      NA_integer_
    } else {
      zero_idx[zero_idx > first_ten_idx][1]
    }

  # Positions from the start through vector index `idx`, or NA if no such event.
  path_through <- function(idx) {
    if (is.na(idx)) NA_real_ else positions[seq_len(idx)]
  }

  # Vector indices are one larger than the number of steps taken because
  # index 1 is the starting position; subtract 1 to report elapsed steps.
  list(
    times_to_0 = first_zero_idx - 1L,
    path = list(path_through(first_zero_idx)),
    touched_10 = first_ten_idx - 1L,
    times_to_0_after_10 = zero_after_ten_idx - 1L,
    path_to_0_after_10 = list(path_through(zero_after_ten_idx))
  )
}
A single simulation:
# Run one simulation with the defaults (y = 5, size = 1e6) and print the result.
sim()
#> $times_to_0
#> [1] 314
#>
#> $path
#> $path[[1]]
#> [1] 5 4 3 2 3 4 3 4 3 2 3 4 5 6 7 8 9 8 9 10 9 10 9 8 7
#> [26] 6 7 6 5 4 5 6 7 8 9 10 9 10 11 12 11 12 11 12 13 12 11 12 13 12
#> [51] 13 12 11 12 11 10 11 10 11 10 11 12 13 12 13 14 13 12 11 10 9 8 7 6 7
#> [76] 6 7 8 9 8 7 8 9 8 7 8 7 6 7 8 9 8 9 8 7 6 5 6 5 4
#> [101] 3 4 5 6 5 6 5 6 7 8 7 8 7 6 7 6 7 8 9 8 7 8 9 8 9
#> [126] 8 7 6 5 6 7 8 9 8 7 8 7 8 9 10 9 10 9 8 9 8 9 8 9 10
#> [151] 11 12 11 12 11 10 11 10 11 12 11 12 13 12 13 14 15 16 15 14 15 14 13 12 13
#> [176] 14 15 14 15 16 17 18 19 18 19 20 19 18 17 16 15 14 13 12 11 10 11 10 9 10
#> [201] 9 10 11 12 11 10 11 10 9 10 11 10 11 12 13 12 13 14 15 14 15 16 17 16 15
#> [226] 14 15 14 13 12 11 10 9 8 9 8 7 8 9 10 11 10 11 12 11 10 11 10 9 10
#> [251] 9 10 11 12 11 12 11 10 9 10 11 12 11 12 13 12 11 12 13 12 13 14 13 12 13
#> [276] 12 11 12 11 10 9 8 7 6 7 8 7 6 7 6 5 6 5 4 5 4 3 2 3 4
#> [301] 5 4 5 4 5 4 5 4 3 2 1 2 1 0
#>
#>
#> $touched_10
#> [1] 20
#>
#> $times_to_0_after_10
#> [1] 314
#>
#> $path_to_0_after_10
#> $path_to_0_after_10[[1]]
#> [1] 5 4 3 2 3 4 3 4 3 2 3 4 5 6 7 8 9 8 9 10 9 10 9 8 7
#> [26] 6 7 6 5 4 5 6 7 8 9 10 9 10 11 12 11 12 11 12 13 12 11 12 13 12
#> [51] 13 12 11 12 11 10 11 10 11 10 11 12 13 12 13 14 13 12 11 10 9 8 7 6 7
#> [76] 6 7 8 9 8 7 8 9 8 7 8 7 6 7 8 9 8 9 8 7 6 5 6 5 4
#> [101] 3 4 5 6 5 6 5 6 7 8 7 8 7 6 7 6 7 8 9 8 7 8 9 8 9
#> [126] 8 7 6 5 6 7 8 9 8 7 8 7 8 9 10 9 10 9 8 9 8 9 8 9 10
#> [151] 11 12 11 12 11 10 11 10 11 12 11 12 13 12 13 14 15 16 15 14 15 14 13 12 13
#> [176] 14 15 14 15 16 17 18 19 18 19 20 19 18 17 16 15 14 13 12 11 10 11 10 9 10
#> [201] 9 10 11 12 11 10 11 10 9 10 11 10 11 12 13 12 13 14 15 14 15 16 17 16 15
#> [226] 14 15 14 13 12 11 10 9 8 9 8 7 8 9 10 11 10 11 12 11 10 11 10 9 10
#> [251] 9 10 11 12 11 12 11 10 9 10 11 12 11 12 13 12 11 12 13 12 13 14 13 12 13
#> [276] 12 11 12 11 10 9 8 7 6 7 8 7 6 7 6 5 6 5 4 5 4 3 2 3 4
#> [301] 5 4 5 4 5 4 5 4 3 2 1 2 1 0
100 simulations:
library(tidyverse)
# Time 100 runs, each started at y = 5, then stack the per-run result lists
# into one tibble; `.id = "sim"` adds the run's position in the list as a
# column named `sim`.
tictoc::tic() # tic() and toc() are for benchmarking
runs <- lapply(rep(5, 100), sim)
res <- bind_rows(runs, .id = "sim")
tictoc::toc()
#> 2.596 sec elapsed
res
#> # A tibble: 100 × 6
#> sim times_to_0 path touched_10 times_to_0_after_10 path_to_0_after_10
#> <chr> <int> <list> <int> <int> <list>
#> 1 1 3202 <dbl> 26 3202 <dbl [3,202]>
#> 2 2 1690 <dbl> 16 1690 <dbl [1,690]>
#> 3 3 32 <dbl [32]> 11586 12012 <dbl [12,012]>
#> 4 4 386 <dbl> 54 386 <dbl [386]>
#> 5 5 280 <dbl> 20 280 <dbl [280]>
#> 6 6 20 <dbl [20]> 272 2890 <dbl [2,890]>
#> 7 7 3520 <dbl> 20 3520 <dbl [3,520]>
#> 8 8 160 <dbl> 8 160 <dbl [160]>
#> 9 9 141814 <dbl> 8 141814 <dbl [141,814]>
#> 10 10 1474 <dbl> 10 1474 <dbl [1,474]>
#> # ℹ 90 more rows
library(ggplot2)
library(gridExtra)
#>
#> Attaching package: 'gridExtra'
#> The following object is masked from 'package:dplyr':
#>
#> combine
# Means of log(time); exp() of these is the geometric mean of the times.
# Runs in which the event did not occur within the simulated steps carry NA
# and are left out via na.rm = TRUE.
mean1 <- mean(log(res$times_to_0), na.rm = TRUE)
mean2 <- mean(log(res$times_to_0_after_10), na.rm = TRUE)

# The x axis shows log(time), so the reference line has to be placed at the
# log-scale mean itself; the back-transformed value is only used in the title.
p1 <-
  ggplot(res, aes(x = log(times_to_0))) +
  geom_density() +
  geom_vline(xintercept = mean1,
             color = "red",
             linetype = "dotted") +
  labs(
    title = paste(
      "Density of Times to Reach 0 - Geometric Mean Time:",
      round(exp(mean1), 2)
    ),
    x = "log(Time)",
    y = "Density"
  ) + theme_bw()

p2 <-
  ggplot(res, aes(x = log(times_to_0_after_10))) +
  geom_density() +
  geom_vline(xintercept = mean2,
             color = "red",
             linetype = "dotted") +
  labs(
    title = paste(
      "Density of Times to Reach 0 After Reaching 10 - Geometric Mean Time:",
      round(exp(mean2), 2)
    ),
    x = "log(Time)",
    y = "Density"
  ) + theme_bw()
# Expand each run's stored path into one row per position and number the rows
# within each run so they can be plotted against log(time index).
paths_first <- res |>
  select(sim, path) |>
  unnest(path) |>
  mutate(time = row_number(), .by = sim)

paths_second <- res |>
  select(sim, path_to_0_after_10) |>
  unnest(path_to_0_after_10) |>
  mutate(time = row_number(), .by = sim)

# One thin line per run.
p3 <- ggplot(paths_first, aes(x = log(time), y = path, group = sim)) +
  geom_line(linewidth = .05) +
  labs(title = "Paths of First Simulation", x = "Time", y = "Y") +
  theme_bw()

p4 <- ggplot(
  paths_second,
  aes(x = log(time), y = path_to_0_after_10, group = sim)
) +
  geom_line(linewidth = .05) +
  labs(title = "Paths of Second Simulation", x = "Time", y = "Y") +
  theme_bw()

# 2 x 2 layout: densities on top, trajectories below.
grid.arrange(p1, p2, p3, p4, ncol = 2)
#> Warning: Removed 1 rows containing non-finite values (`stat_density()`).
#> Warning: Removed 3 rows containing non-finite values (`stat_density()`).
#> Warning: Removed 1 row containing missing values (`geom_line()`).
#> Warning: Removed 3 rows containing missing values (`geom_line()`).