Tags: statistics, julia, normal-distribution

Correct way to combine Normal distributions in Julia with Distributions.jl


I have two multivariate normal distributions like so:

using Distributions, LinearAlgebra
g1 = MvNormal([1,2], [2 1; 1 2])
A = [3 1; 1 3]
B = [[A [0;0]]; transpose([0,0,1])]
g2 = MvNormal([1,2,3], B)

I would like to combine them into a new distribution over the concatenation of both random vectors, assuming they are independent. The function product_distribution seems like it should do the trick:

g3 = product_distribution(g1, g2)

But that results in an error:

ERROR: all distributions must be of the same size

I really don't understand this error. It makes me think this function's purpose is not what I thought it was, but I can't find any other function that would be more appropriate.
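Going by the Distributions.jl docstring for `product_distribution`, the likely reason for the error is that, given several multivariate distributions, it does not concatenate them. Instead it stacks them into a single matrix-variate distribution whose draws have one column per component, so all components must have the same length. A minimal sketch, assuming a recent Distributions.jl 0.25.x (`h1` and `h2` are illustrative distributions, not from the question):

```julia
using Distributions

# Two independent bivariate normals of the *same* size (illustrative values).
h1 = MvNormal([1.0, 2.0], [2.0 1.0; 1.0 2.0])
h2 = MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0])

# product_distribution stacks them: the result is matrix-variate, and each
# draw is a 2x2 matrix whose columns come from h1 and h2 respectively.
p = product_distribution(h1, h2)
x = rand(p)
println(size(x))
```

With `g1` (length 2) and `g2` (length 3) there is no common column length, which would explain the "same size" error.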

To be clear, the desired output should be equivalent to:

m3 = vcat(mean(g1), mean(g2))
s3 = hvcat( (2,2), cov(g1), zeros(2,3), zeros(3,2), cov(g2))
g3 = MvNormal(m3, s3)

(A sparse matrix or some other optimised block-diagonal matrix type might be more appropriate, but I don't care about that in this case.)
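The desired construction generalizes to any number of independent `MvNormal`s. A minimal dense sketch, assuming only Distributions.jl plus Base's `cat` with `dims = (1, 2)`, which places its arguments along the diagonal and fills the rest with zeros; `concat_dense` is a hypothetical helper name, not a library function:

```julia
using Distributions

# Hypothetical helper (not part of Distributions.jl): joint distribution of
# independent multivariate normals, using a dense block-diagonal covariance.
function concat_dense(ds::MvNormal...)
    μ = reduce(vcat, map(mean, ds))           # stack the mean vectors
    Σ = cat(map(cov, ds)...; dims = (1, 2))   # covariances on the diagonal, zeros elsewhere
    return MvNormal(μ, Σ)
end

g1 = MvNormal([1.0, 2.0], [2.0 1.0; 1.0 2.0])
g2 = MvNormal([1.0, 2.0, 3.0], [3.0 1.0 0.0; 1.0 3.0 0.0; 0.0 0.0 1.0])
g3 = concat_dense(g1, g2)

# Independence check: the joint log-density equals the sum of the marginals.
x = rand(g3)
println(logpdf(g3, x) ≈ logpdf(g1, x[1:2]) + logpdf(g2, x[3:5]))
```

Because the off-diagonal blocks are zero, the joint log-density factorizes into the marginal ones, which is what the last line checks.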


Solution

  • I couldn't find a clean built-in solution, although this issue has come up before. The best I can suggest so far is:

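    using BlockDiagonals, Distributions
    
    # Concatenate independent (Mv)Normal distributions into one MvNormal
    # whose covariance is block-diagonal. The fold starts from an empty MvNormal.
    concat(ds::Union{MvNormal, Normal}...) =
      foldl(ds; init = MvNormal(Float64[], zeros((0,0)))) do x, D
          # promote a univariate Normal to a 1-dimensional MvNormal
          d = D isa Normal ? MvNormal([mean(D)], [var(D);;]) : D
          m1, c1 = mean(x), cov(x)
          m2, c2 = mean(d), cov(d)
          # stack the means; put the two covariances on the diagonal
          return MvNormal(vcat(m1, m2), BlockDiagonal([c1, c2]))
      end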
    

    giving:

    julia> g1 = MvNormal([1,2], [2 1; 1 2]);
    
    julia> A = [3 1; 1 3];
    
    julia> B = BlockDiagonal([A,[1;;]]);
    
    julia> g2 = MvNormal([1,2,3], B);
    
    julia> g12 = concat(g1,g2)
    MvNormal{Float64, PDMats.PDMat{Float64, BlockDiagonal{Float64, Matrix{Float64}}}, Vector{Float64}}(
    dim: 5
    μ: [1.0, 2.0, 1.0, 2.0, 3.0]
    Σ: [2.0 1.0 … 0.0 0.0; 1.0 2.0 … 0.0 0.0; … ; 0.0 0.0 … 3.0 0.0; 0.0 0.0 … 0.0 1.0]
    )
    
    julia> rand(g12)
    5-element Vector{Float64}:
      1.0362522596298223
      2.0469956784329866
     -0.21925320262748982
     -1.2334419613775114
      3.2146164519549814
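A possible non-fold variant, assuming the same BlockDiagonals.jl and Distributions.jl packages as above: collect every covariance block first and build a single `BlockDiagonal` with one block per input distribution. `as_mv` and `concat_blocks` are hypothetical helper names, and the `[v;;]` 1x1-matrix syntax needs Julia 1.7 or later.

```julia
using BlockDiagonals, Distributions

# Hypothetical helpers: promote univariate Normals to 1-dimensional MvNormals.
as_mv(d::MvNormal) = d
as_mv(d::Normal) = MvNormal([mean(d)], [var(d);;])

# Hypothetical variant of concat: build all blocks at once instead of folding.
function concat_blocks(ds::Union{MvNormal, Normal}...)
    mvs = map(as_mv, ds)
    μ = reduce(vcat, map(mean, mvs))             # stacked means
    Σ = BlockDiagonal([cov(d) for d in mvs])     # one diagonal block per input
    return MvNormal(μ, Σ)
end

g1 = MvNormal([1.0, 2.0], [2.0 1.0; 1.0 2.0])
g12 = concat_blocks(g1, Normal(3.0, 2.0))
println(mean(g12))
```

Here the univariate `Normal(3.0, 2.0)` contributes a 1x1 block holding its variance, 4.0.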