
Overloading methods for parametrized struct


Short Version

I have a struct like

@kwdef mutable struct Params{B, C}
    a::Float64 = 42
end

where B and C are symbols that determine which of several overloaded methods is called, with one method for each combination of symbols. Examples of valid symbols are :constant and :exponential. I want to write a method that is invoked when one of the type parameters matches a given symbol, regardless of the value of the other parameter. For example:

p = Params{:constant, :gaussian}()
foo(p::Params{:constant, <:Symbol}) = 42
foo(p) # should return 42, but throws a MethodError

How can I do this?
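Checking the subtype relations directly shows why the method above is never selected. A minimal sketch, using the struct from the question (Base.@kwdef is written out because plain @kwdef is only exported from Base since Julia 1.9):

```julia
# Same struct as in the question.
Base.@kwdef mutable struct Params{B, C}
    a::Float64 = 42
end

p = Params{:constant, :gaussian}()

# :gaussian is a value of type Symbol, not a type itself.
println(:gaussian isa Symbol)                                     # true
# Params{:constant, <:Symbol} means Params{:constant, C} where C <: Symbol,
# so C must be a *type* below Symbol; the value :gaussian does not qualify.
println(typeof(p) <: Params{:constant, <:Symbol})                 # false
println(Params{:constant, Symbol} <: Params{:constant, <:Symbol}) # true
# Leaving the second parameter unconstrained matches any value:
println(typeof(p) <: Params{:constant})                           # true
```

So the signature only ever matches Params{:constant, Symbol}, never a Params whose second parameter is a symbol value.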

Context

I am implementing a dynamic rule (to be used with DynamicalSystems.jl) in which certain terms are computed using either the PDF of the Gaussian distribution, a reciprocal function, or a constant (see the supplementary information and Figure 3 of "Quantitative modeling of the terminal differentiation of B cells and mechanisms of lymphomagenesis"). The code below is a simplified model of what I want to do:

using UnPack
using DynamicalSystems
using Distributions

@kwdef mutable struct Params{B, C}
    mu_p    = 10e-6
    sigma_p = 9
    mu_b    = 2
    sigma_b = 100
    mu_r    = 0.1
    sigma_r = 2.6
    
    bcr_0   = 0.05
    cd_0    = 0.025
end

function germinal_center_regulation_rule(u, params, t)
    @unpack mu_p, sigma_p, mu_b, sigma_b, mu_r, sigma_r = params
    @unpack bcr_0, cd_0 = params
    p, b, r = u
    #######IMPORTANT PART############
    # Which term is used is chosen by dispatch on the type parameters of `params`
    bcr = compute_bcr(u, params, t)
    cd40 = compute_cd40(u, params, t)
    #################################
    pdot = mu_p + sigma_p/b
    bdot = mu_p + sigma_p/b - bcr*b
    rdot = mu_r * sigma_r/r
    return SVector(pdot, bdot, rdot)
end

# Different methods for different parameterized structs.
# Arguments are positional because Julia does not dispatch on keyword arguments.
# The `<:Symbol` patterns are the part that does not match (see Short Version).
compute_bcr(u, params::Params{:constant, <:Symbol}, t) = 15
compute_bcr(u, params::Params{:gaussian, <:Symbol}, t) = pdf(Normal(), t)
compute_cd40(u, params::Params{<:Symbol, :reciprocal}, t) = params.bcr_0/u[2]
compute_cd40(u, params::Params{<:Symbol, :gaussian}, t) = pdf(Normal(), t)

# Example usage
p_constant_bcr_gaussian_cd40 = Params{:constant, :gaussian}()
u0 = [0.2, 5.0, 0.2]
mixed_ds = CoupledODEs(
  germinal_center_regulation_rule, u0, p_constant_bcr_gaussian_cd40)
total_time = 200
X, t = trajectory(mixed_ds, total_time)

p_gaussian_bcr_gaussian_cd40 = Params{:gaussian, :gaussian}()
gaussian_ds = CoupledODEs(
  germinal_center_regulation_rule, u0, p_gaussian_bcr_gaussian_cd40)
X, t = trajectory(gaussian_ds, total_time)
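Julia dispatches only on positional arguments; keyword arguments are not part of a method signature. That is why compute_bcr and compute_cd40 need u, params and t as positional arguments: methods written as f(; params::T1) and f(; params::T2) would overwrite each other instead of coexisting. A minimal sketch with hypothetical types KindA and KindB and hypothetical functions kw and pos:

```julia
# Hypothetical types and functions, for illustration only.
struct KindA end
struct KindB end

# Keyword-only methods: both definitions have the positional signature kw(),
# so the second definition replaces the first.
kw(; x::KindA) = "A method"
kw(; x::KindB) = "B method"

# Positional methods: two separate methods, chosen by dispatch.
pos(x::KindA) = "A method"
pos(x::KindB) = "B method"

println(pos(KindA()), ", ", pos(KindB()))   # A method, B method
println(kw(x = KindB()))                    # B method
try
    kw(x = KindA())                         # only the KindB version survives
catch err
    println(typeof(err))                    # TypeError
end
```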

Solution

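  • As pointed out by BallpointPen in their answer, you can make this work by not using the subtype operator <: with symbol type parameters: :gaussian is a value of type Symbol, not a type, so it is never a subtype of Symbol. Writing Params{:constant} (or Params{B, :gaussian} where B) leaves the other parameter unconstrained instead. I am not sure whether DynamicalSystems.jl mandates the use of symbols in your case, but if it doesn't, IMO it is better to let the type system do the work for you.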

    Create a new abstract type and derive your parameter types from it.

    abstract type AbstractParamType end
    
    struct Constant <: AbstractParamType end
    struct Exponential <: AbstractParamType end
    struct Reciprocal <: AbstractParamType end
    struct Gaussian <: AbstractParamType end
    

    Then use this abstract type as the upper bound of the type parameters of Params:

    @kwdef struct Params{P1 <: AbstractParamType, P2 <: AbstractParamType}
        mu_p    = 10e-6
        sigma_p = 9
        mu_b    = 2
        sigma_b = 100
        mu_r    = 0.1
        sigma_r = 2.6
        
        bcr_0   = 0.05
        cd_0    = 0.025
    end 
    

    so that you can define methods that dispatch on the parameter types:

    # Positional arguments: Julia does not dispatch on keyword arguments
    compute_bcr(u, params::Params{Constant, <:AbstractParamType}, t) = 15
    compute_bcr(u, params::Params{Gaussian, <:AbstractParamType}, t) = pdf(Normal(), t)
    compute_cd40(u, params::Params{<:AbstractParamType, Reciprocal}, t) = params.bcr_0/u[2]
    compute_cd40(u, params::Params{<:AbstractParamType, Gaussian}, t) = pdf(Normal(), t)
    
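    A self-contained sketch of the same idea that runs without DynamicalSystems.jl or Distributions.jl. Assumptions: the methods take positional arguments, only the two fields used here are kept, and the hypothetical helper normal_pdf stands in for pdf(Normal(), t).

    ```julia
    abstract type AbstractParamType end
    struct Constant   <: AbstractParamType end
    struct Reciprocal <: AbstractParamType end
    struct Gaussian   <: AbstractParamType end

    # Only subtypes of AbstractParamType are accepted as type parameters,
    # so e.g. Params{:constant, Gaussian}() throws a TypeError.
    Base.@kwdef struct Params{P1 <: AbstractParamType, P2 <: AbstractParamType}
        bcr_0 = 0.05
        cd_0  = 0.025
    end

    # Hypothetical stand-in for pdf(Normal(), t): the standard normal density.
    normal_pdf(t) = exp(-t^2 / 2) / sqrt(2π)

    # P1 selects the BCR term, P2 selects the CD40 term.
    compute_bcr(u, params::Params{Constant, <:AbstractParamType}, t) = 15.0
    compute_bcr(u, params::Params{Gaussian, <:AbstractParamType}, t) = normal_pdf(t)
    compute_cd40(u, params::Params{<:AbstractParamType, Reciprocal}, t) = params.bcr_0 / u[2]
    compute_cd40(u, params::Params{<:AbstractParamType, Gaussian}, t) = normal_pdf(t)

    u = [0.2, 5.0, 0.2]
    p = Params{Constant, Reciprocal}()
    println(compute_bcr(u, p, 0.0))    # 15.0
    println(compute_cd40(u, p, 0.0))   # 0.01 (bcr_0 / u[2] = 0.05 / 5.0)
    ```

    Replacing Reciprocal with Gaussian in the constructor call switches the CD40 term without changing any caller.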

    This has a few advantages over using values as type parameters:

    • No manual check of the type parameters is needed in the Params constructor.
    • Invalid states are not representable: the original Params accepts any symbols as type parameters, and restricting them would require yet another manual check. The type system now handles these checks entirely.
    • Adding a new kind of parameter is still as easy as before: define another subtype of AbstractParamType.
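If you would rather keep Symbol type parameters, the fix mentioned at the start of this answer is to leave the other parameter unconstrained instead of writing <:Symbol. A minimal sketch, assuming positional arguments and a hypothetical helper normal_pdf in place of pdf(Normal(), t); it illustrates the idea and is not BallpointPen's code:

```julia
Base.@kwdef mutable struct Params{B, C}
    bcr_0 = 0.05
    cd_0  = 0.025
end

# Hypothetical stand-in for pdf(Normal(), t): the standard normal density.
normal_pdf(t) = exp(-t^2 / 2) / sqrt(2π)

# Params{:constant} is shorthand for Params{:constant, C} where C,
# which matches any second parameter, including symbols.
compute_bcr(u, params::Params{:constant}, t) = 15.0
compute_bcr(u, params::Params{:gaussian}, t) = normal_pdf(t)

# To fix only the second parameter, bind the first one to a free type variable.
compute_cd40(u, params::Params{B, :reciprocal}, t) where {B} = params.bcr_0 / u[2]
compute_cd40(u, params::Params{B, :gaussian}, t) where {B} = normal_pdf(t)

u = [0.2, 5.0, 0.2]
p = Params{:constant, :gaussian}()
println(compute_bcr(u, p, 0.0))    # 15.0
println(compute_cd40(u, p, 0.0))   # standard normal density at 0, about 0.3989
```

Unlike the abstract-type version, nothing stops Params{:typo, :gaussian}() from being constructed; calling compute_bcr on it simply raises a MethodError.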