
Extending a single member of a discriminated union


I have a discriminated union that looks like so:

type Union = {
    // This is important! The value to discriminate over can itself be a union!
    type: "foo" | "bar"
} | {
    type: "baz"
} | {
    type: "quux"
}

I would like to extend the member of the union whose type is "baz" with an additional property, for example value: string, so that the final result is similar to this:

type ExtendedUnion = Extend<{
    "baz": { value: string }
}> /* {
    type: "foo" | "bar"
} | {
    type: "baz"
    value: string
} | {
    type: "quux"
} */

The exact signature of Extend is not important; it is just an example. It would be very useful, though, if it could extend several members of the union at the same time.

When the discriminant is itself a union and the mapping names one of its values (for example, Extend<{ "foo": { value: string } }>), the result should be { type: "foo" | "bar"; value: string }. If the mapping names both values, e.g. Extend<{ "foo": { value: string }, "bar": { other: number } }>, the result should be { type: "foo" | "bar"; value: string; other: number }.
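For reference, here is the two-key target from the paragraph above written out by hand. This is a sketch, and Target is a hypothetical name that is not part of the question's code. Both added pieces land on the single member whose discriminant is "foo" | "bar":

```typescript
// The discriminated union from the question.
type Union =
  | { type: "foo" | "bar" }
  | { type: "baz" }
  | { type: "quux" };

// Hand-written target for Extend<{ foo: { value: string }, bar: { other: number } }>:
// the "foo" | "bar" member receives BOTH pieces, merged into one object type.
type Target =
  | { type: "foo" | "bar"; value: string; other: number }
  | { type: "baz" }
  | { type: "quux" };

// A value tagged "bar" must carry both added properties.
const fooBar: Target = { type: "bar", value: "x", other: 1 };

// Members that were not extended are unchanged.
const baz: Target = { type: "baz" };
```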

I have created a utility type that does this, except that it does not handle the member with type: "foo" | "bar"; it simply omits that member from the result. My type looks like this:

type Extend<TTypeOverride extends { [key in Union["type"]]?: unknown }> = {
    [key in Union["type"]]:
        // This Extract is the failing point, if "key" is "foo" for example, it will not extract
        // { type: "foo" | "bar" } but I can not figure out an alternative
        Extract<Union, { type: key }> &
        (TTypeOverride[key] extends undefined ? unknown : TTypeOverride[key])
}[Union["type"]]

// Will not contain { type: "foo" | "bar" }
type ExtendedUnion = Extend<{
    quux: { value: string }
}> /* {
    type: "baz"
} | {
    type: "quux"
    value: string
} */

As noted in the comment, the Extract is the culprit: it does not handle the "foo" | "bar" case. Any pointers?
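Here is a minimal sketch of why Extract drops that member, and of one possible alternative. Extract<T, U> keeps only the members of T that are assignable to U, and { type: "foo" | "bar" } is not assignable to { type: "foo" }. ExtractByTag and Equal are hypothetical helper names. ExtractByTag tests the opposite direction: it asks whether the key is one of the values the discriminant can take. This alone restores the missing member. However, two keys that hit the same member would still produce a union of pieces rather than one merged member.

```typescript
type Union =
  | { type: "foo" | "bar" }
  | { type: "baz" }
  | { type: "quux" };

// Extract<T, U> keeps members of T assignable to U.
// { type: "foo" | "bar" } is not assignable to { type: "foo" }, so this is never.
type Fails = Extract<Union, { type: "foo" }>;

// Hypothetical alternative: keep a member when Key is one of the
// values its discriminant can take (Key extends Tag, not the reverse).
type ExtractByTag<U, Key> = U extends { type: infer Tag }
  ? Key extends Tag
    ? U
    : never
  : never;

type Works = ExtractByTag<Union, "foo">; // { type: "foo" | "bar" }

// Hypothetical compile-time helper: true only if A and B are the same type.
type Equal<A, B> =
  (<G>() => G extends A ? 1 : 2) extends (<G>() => G extends B ? 1 : 2) ? true : false;

// Each annotation fails to compile if the claimed equality does not hold.
const failsIsNever: Equal<Fails, never> = true;
const worksIsFooBar: Equal<Works, { type: "foo" | "bar" }> = true;
```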


Solution

    My approach would be to write a more general utility type, let's call it ExtendDiscriminatedUnion<T, K, M>. Here T is a discriminated union type, K is the key of the discriminant property, and M maps each discriminant value to the piece you want to add. Then your Extend<M> can be written as:

    type Extend<M extends Partial<Record<Union["type"], object>>> =
      ExtendDiscriminatedUnion<Union, "type", M>;
    

    We want ExtendDiscriminatedUnion<T, K, M> to act on each union member of T independently and then unite the results back together. So, for example, ExtendDiscriminatedUnion<A | B | C, K, M> should be equivalent to ExtendDiscriminatedUnion<A, K, M> | ExtendDiscriminatedUnion<B, K, M> | ExtendDiscriminatedUnion<C, K, M>. That means ExtendDiscriminatedUnion<T, K, M> should be distributive over unions in T. The easiest way to do that is to make it a distributive conditional type of the form:

    type ExtendDiscriminatedUnion<
      T extends Record<K, PropertyKey>, K extends keyof T,
      M extends Partial<Record<T[K], object>>
    > = T extends unknown ? ⋯T⋯ : never;
    

    That looks like it doesn't do anything, but when T is a generic type parameter, then T extends unknown ? ⋯T⋯ : never will distribute over unions in T automatically, so that the operation in ⋯T⋯ acts on union members of T and not the whole union at once.
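To see distribution in isolation, here is a small sketch with hypothetical names (BoxEach, BoxWhole, Equal). Wrapping the checked type in a one-element tuple is a common way to switch distribution off, which makes the contrast visible:

```typescript
// Hypothetical compile-time helper: true only if A and B are the same type.
type Equal<A, B> =
  (<G>() => G extends A ? 1 : 2) extends (<G>() => G extends B ? 1 : 2) ? true : false;

// Naked type parameter in the checked position: distributes over unions.
type BoxEach<T> = T extends unknown ? { value: T } : never;

// Wrapping T in a one-element tuple turns distribution off.
type BoxWhole<T> = [T] extends [unknown] ? { value: T } : never;

type Each = BoxEach<"a" | "b">; // { value: "a" } | { value: "b" }
type Whole = BoxWhole<"a" | "b">; // { value: "a" | "b" }

const eachOk: Equal<Each, { value: "a" } | { value: "b" }> = true;
const wholeOk: Equal<Whole, { value: "a" | "b" }> = true;
```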


    If we didn't have to worry about the complication where multiple keys of M apply to the same member of T, then we could write it like this:

    type ExtendDiscriminatedUnion<
      T extends Record<K, PropertyKey>,
      K extends keyof T,
      M extends Partial<Record<T[K], object>>
    > = T extends unknown ?
      T & (M[T[K]] extends object ? M[T[K]] : unknown)
      : never;
    

    For each member T of the union, the output will be T intersected with M[T[K]] extends object ? M[T[K]] : unknown. T[K] is the discriminant property value for T, which we look up in M to get M[T[K]]. If that's an object type then we want to keep it to get T & that type; otherwise (if it's undefined or not present or something) then we want to replace it with unknown so that the intersection T & unknown is just T.
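The fallback to unknown can be checked on its own. In this sketch PieceOrUnknown is a hypothetical name. The checked type is wrapped in a tuple so that, like M[T[K]] in the code above, it does not distribute. The undefined case assumes strictNullChecks is enabled:

```typescript
// Hypothetical compile-time helper: true only if A and B are the same type.
type Equal<A, B> =
  (<G>() => G extends A ? 1 : 2) extends (<G>() => G extends B ? 1 : 2) ? true : false;

// Keep the piece if it is an object type, otherwise fall back to unknown.
// The tuple wrapper prevents distribution over unions.
type PieceOrUnknown<X> = [X] extends [object] ? X : unknown;

type Hit = PieceOrUnknown<{ value: string }>; // { value: string }
type Miss = PieceOrUnknown<unknown>; // unknown
type MissUndefined = PieceOrUnknown<undefined>; // unknown (with strictNullChecks)

// Intersecting with unknown changes nothing, so an un-extended member is kept as-is.
type Kept = { type: "quux" } & Miss;

const hitOk: Equal<Hit, { value: string }> = true;
const missOk: Equal<Miss, unknown> = true;
const keptOk: Equal<Kept, { type: "quux" }> = true;
```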

    This works for your main example:

    type ExtendedUnion1 = Extend<{
      "baz": { value: string }
    }>;
    /* type ExtendedUnion1 = 
      { type: "foo" | "bar"; } | 
      {  type: "quux"; } | 
      ({ type: "baz"; } & { value: string; })
    */
    

    but unfortunately it gives you a union when multiple keys of M correspond to the same member of T, because indexing M with the union "foo" | "bar" yields the union of the corresponding pieces:

    type ExtendedUnion2 = Extend<{
      foo: { a: 0 },
      bar: { b: 1 },
      baz: { c: 2 }, quux: { d: 3 }
    }>
    /* type ExtendedUnion2 = 
         ({ type: "foo" | "bar"; } & ({ a: 0; } | { b: 1; })) | 
         ({ type: "baz"; } & { c: 2; }) | 
         ({ type: "quux"; } & { d: 3; })
    */
    

    So we need to make it more complicated.
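The union arises because indexing with a union of keys produces the union of the property types. The following sketch uses hypothetical names (Mapping, Member, Equal). It shows that the resulting member type accepts a value carrying only one of the two pieces:

```typescript
// Hypothetical compile-time helper: true only if A and B are the same type.
type Equal<A, B> =
  (<G>() => G extends A ? 1 : 2) extends (<G>() => G extends B ? 1 : 2) ? true : false;

type Mapping = { foo: { a: 0 }; bar: { b: 1 } };

// Indexing with a union of keys gives the union of the property types.
type Looked = Mapping["foo" | "bar"]; // { a: 0 } | { b: 1 }

// So the simple version produces this for the "foo" | "bar" member...
type Member = { type: "foo" | "bar" } & Looked;

// ...which accepts a value carrying only ONE of the two pieces.
const onlyA: Member = { type: "bar", a: 0 };

const lookedOk: Equal<Looked, { a: 0 } | { b: 1 }> = true;
```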


    When looking up T[K] in M, we want to get the intersection of the results, not the union. To do this we need to play some TypeScript type system variance tricks (see Difference between Variance, Covariance, Contravariance and Bivariance in TypeScript), as done in the UnionToIntersection<T> type described in Transform union type to intersection type:

    type ExtendDiscriminatedUnion<
      T extends Record<K, PropertyKey>,
      K extends keyof T,
      M extends Partial<Record<T[K], object>>
    > = T extends unknown ? (
      { [P in T[K]]: (x: M[P] extends object ? M[P] : unknown) => void }[T[K]]
    ) extends (x: infer I) => void ? T & I : never : never;
    

    Here we use a mapped type to walk through each member P of T[K] and look up M[P]. We then put the result (either M[P] or unknown, depending on whether it's an object type, as before) in a contravariant position, namely the parameter of a function type. Next, we use conditional type inference via infer to ask the compiler to infer a single type I from all of those function types. Inferring from multiple candidates in contravariant positions produces their intersection (as documented in the conditional type inference section of the TS 2.8 release notes). Finally we return T & I.
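For comparison, here is a sketch of the widely used UnionToIntersection pattern that this technique builds on. It uses a distributive conditional type to place each union member in parameter position, whereas the answer uses a mapped type indexed by T[K]. The inference step is the same:

```typescript
// Each member U becomes (x: U) => void; inferring I from all of those
// parameter (contravariant) positions yields the intersection of the members.
type UnionToIntersection<U> =
  (U extends unknown ? (x: U) => void : never) extends (x: infer I) => void
    ? I
    : never;

type Both = UnionToIntersection<{ a: 0 } | { b: 1 }>; // { a: 0 } & { b: 1 }

// A value of the intersection must carry both properties.
const both: Both = { a: 0, b: 1 };

// @ts-expect-error: `b` is missing, so this is not a Both
const onlyA: Both = { a: 0 };
```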

    Let's try that:

    type ExtendedUnion1 = Extend<{
      "baz": { value: string }
    }>;
    /* type ExtendedUnion1 = 
      { type: "foo" | "bar"; } | 
      {  type: "quux"; } | 
      ({ type: "baz"; } & { value: string; })
    */
    
    type ExtendedUnion2 = Extend<{
      foo: { a: 0 },
      bar: { b: 1 },
      baz: { c: 2 }, quux: { d: 3 }
    }>
    /* type ExtendedUnion2 = 
         ({ type: "foo" | "bar"; } & { a: 0; } & { b: 1; }) | 
         ({ type: "baz"; } & { c: 2; }) | 
         ({ type: "quux"; } & { d: 3; })
    */
    

    So that's exactly what you wanted, I think.


    Still, it's kind of ugly. Intersections of object types are equivalent to single object types (e.g., {a: 0} & {b: 1} and {a: 0; b: 1} are equivalent) so we can use a trick described in How can I see the full expanded contract of a Typescript type? to just map over the keys of T & I with an identity mapping:

    type ExtendDiscriminatedUnion<
      T extends Record<K, PropertyKey>,
      K extends keyof T,
      M extends Partial<Record<T[K], object>>
    > = T extends unknown ? (
      { [P in T[K]]: (x: M[P] extends object ? M[P] : unknown) => void }[T[K]]
    ) extends (x: infer I) => void ? {
      [P in keyof (T & I)]: (T & I)[P]
    } : never : never;
    

    And now we get:

    type ExtendedUnion1 = Extend<{ "baz": { value: string } }>;
    /* type ExtendedUnion1 = 
        { type: "foo" | "bar"; } | 
        { type: "baz"; value: string; } | 
        { type: "quux"; } 
    */
    
    type ExtendedUnion2 = Extend<{
      foo: { a: 0 },
      bar: { b: 1 },
      baz: { c: 2 }, quux: { d: 3 }
    }>
    /* type ExtendedUnion2 = 
       { type: "foo" | "bar"; a: 0; b: 1; } | 
       { type: "baz"; c: 2; } | 
       { type: "quux"; d: 3; }
    */
    

    which is about the best I can imagine doing here.
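Putting the pieces together, here is a self-contained sketch of the final version with a small usage example. The total function is a hypothetical consumer. It shows that narrowing on type gives access to both added properties on the merged "foo" | "bar" member:

```typescript
type Union =
  | { type: "foo" | "bar" }
  | { type: "baz" }
  | { type: "quux" };

type ExtendDiscriminatedUnion<
  T extends Record<K, PropertyKey>,
  K extends keyof T,
  M extends Partial<Record<T[K], object>>
> = T extends unknown // distribute over the members of T
  ? ({
      // One function type per discriminant value, piece in parameter position.
      [P in T[K]]: (x: M[P] extends object ? M[P] : unknown) => void;
    }[T[K]]) extends (x: infer I) => void // I = intersection of the pieces
    ? { [P in keyof (T & I)]: (T & I)[P] } // flatten T & I into one object type
    : never
  : never;

type Extend<M extends Partial<Record<Union["type"], object>>> =
  ExtendDiscriminatedUnion<Union, "type", M>;

type ExtendedUnion1 = Extend<{ baz: { value: string } }>;

type ExtendedUnion2 = Extend<{
  foo: { a: 0 };
  bar: { b: 1 };
  baz: { c: 2 };
  quux: { d: 3 };
}>;

const withValue: ExtendedUnion1 = { type: "baz", value: "hi" };
const plain: ExtendedUnion1 = { type: "foo" };

// Hypothetical consumer: narrowing by the discriminant exposes the added pieces.
function total(u: ExtendedUnion2): number {
  switch (u.type) {
    case "foo":
    case "bar":
      return u.a + u.b; // both pieces live on the merged member
    case "baz":
      return u.c;
    case "quux":
      return u.d;
  }
}

// @ts-expect-error: `b` is required on the "foo" | "bar" member
const missingB: ExtendedUnion2 = { type: "foo", a: 0 };
```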
