
Postgres array of Golang structs


I have the following Go struct:

type Bar struct {
    Stuff string `db:"stuff"`
    Other string `db:"other"`
}

type Foo struct {
    ID    int    `db:"id"`
    Bars  []*Bar `db:"bars"`
}

So Foo contains a slice of Bar pointers. I also have the following tables in Postgres:

CREATE TABLE foo (
    id  INT
);

CREATE TABLE bar (
    id      INT,
    stuff   VARCHAR,
    other   VARCHAR,
    trash   VARCHAR
);

I want to LEFT JOIN on table bar and aggregate it as an array to be stored in the struct Foo. I've tried:

SELECT f.*,
ARRAY_AGG(b.stuff, b.other) AS bars
FROM foo f
LEFT JOIN bar b
ON f.id = b.id
WHERE f.id = $1
GROUP BY f.id

But it looks like I'm calling ARRAY_AGG with the wrong signature (function array_agg(character varying, character varying) does not exist). Is there a way to do this without making a separate query to bar?


Solution

  • As you already know, array_agg takes a single argument and returns an array of the argument's type. So, if you want all of a row's columns to be included in the array's elements, you can just pass in the row reference directly, e.g.:

    SELECT array_agg(b) FROM bar AS b
    

    If, however, you only want to include specific columns in the array's elements, you can use the ROW constructor, e.g.:

    SELECT array_agg(ROW(b.stuff, b.other)) FROM bar AS b
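
    Putting the two together for the tables in the question, the failing query could plausibly be rewritten to aggregate a ROW constructor instead of two separate arguments. A rough sketch (the constant name fooWithBarsQuery is my own, introduced as a Go constant so it can be reused in a later example):

    // Sketch of a corrected query: aggregate the (stuff, other) pairs of the
    // joined bar rows into a single "bars" column, one array per foo row.
    const fooWithBarsQuery = `
        SELECT f.id,
               array_agg(ROW(b.stuff, b.other)) AS bars
        FROM foo f
        LEFT JOIN bar b ON f.id = b.id
        WHERE f.id = $1
        GROUP BY f.id`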
    

    Go's standard library provides out-of-the-box support for scanning only scalar values. For scanning more complex values, like arbitrary objects and arrays, one has to either look for third-party solutions or implement their own sql.Scanner.
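
    For reference, this is the interface declared by the database/sql package that such a type has to satisfy (newer Go versions spell the parameter type as any, which is an alias for interface{}):

    // sql.Scanner, as declared in database/sql. Scan is called by the
    // (*sql.Row|*sql.Rows).Scan machinery with the driver's raw value, which
    // for postgres text-format columns is typically a []byte.
    type Scanner interface {
        Scan(src interface{}) error
    }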

    To be able to implement your own sql.Scanner and properly parse a postgres array of rows, you first need to know what format postgres uses to output the value. You can find this out by running a few queries directly in psql:

    -- simple values
    SELECT ARRAY[ROW(123,'foo'),ROW(456,'bar')];
    -- output: {"(123,foo)","(456,bar)"}
    
    -- not so simple values 
    SELECT ARRAY[ROW(1,'a b'),ROW(2,'a,b'),ROW(3,'a",b'),ROW(4,'(a,b)'),ROW(5,'"','""')];
    -- output: {"(1,\"a b\")","(2,\"a,b\")","(3,\"a\"\",b\")","(4,\"(a,b)\")","(5,\"\"\"\",\"\"\"\"\"\")"}
    

    As you can see this can get pretty hairy, but it is nevertheless parseable. The syntax looks to be something like this:

    {"(column_value[, ...])"[, ...]}
    

    where column_value is either an unquoted value, or a quoted value with escaped double quotes. Such a quoted value can itself contain escaped double quotes, but only in pairs, i.e. a single escaped double quote will not occur inside the column_value. So a rough and incomplete implementation of the parser might look something like this:

    NOTE: there may be other syntax rules that I do not know of that need to be taken into consideration during parsing. In addition to that, the code below doesn't handle NULLs properly.

    func parseRowArray(a []byte) (out [][]string) {
        a = a[1 : len(a)-1] // drop surrounding curlies
    
        for i := 0; i < len(a); i++ {
            if a[i] == '"' { // start of row element
                row := []string{}
    
                i += 2 // skip over current '"' and the following '('
                for j := i; j < len(a); j++ {
                    if a[j] == '\\' && a[j+1] == '"' { // start of quoted column value
                        var col string // column value
    
                        j += 2 // skip over current '\' and following '"'
                        for k := j; k < len(a); k++ {
                            if a[k] == '\\' && a[k+1] == '"' { // end of quoted column, maybe
                                if a[k+2] == '\\' && a[k+3] == '"' { // nope, just escaped quote
                                    col += string(a[j:k]) + `"`
                                    k += 3    // skip over `\"\` (the k++ in the for statement will skip over the `"`)
                                    j = k + 1 // skip over `\"\"`
                                    continue  // go to k loop
                                } else { // yes, end of quoted column
                                    col += string(a[j:k])
                                    row = append(row, col)
                                    j = k + 2 // skip over `\"`
                                    break     // go back to j loop
                                }
                            }
    
                        }
    
                        if a[j] == ')' { // row end
                            out = append(out, row)
                            i = j + 1 // move i past the ')' (the i loop will then step over the closing '"' and any following ',')
                            break     // go back to the i loop
                        }
                    } else { // assume non quoted column value
                        for k := j; k < len(a); k++ {
                            if a[k] == ',' || a[k] == ')' { // column value end
                                col := string(a[j:k])
                                row = append(row, col)
                                j = k // advance j to k's position
                                break // go back to j loop
                            }
                        }
    
                        if a[j] == ')' { // row end
                            out = append(out, row)
                            i = j + 1 // move i past the ')' (the i loop will then step over the closing '"' and any following ',')
                            break     // go back to the i loop
                        }
                    }
                }
            }
        }
        return out
    }
    

    Try it on playground.
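
    As a quick sanity check, feeding the sample psql outputs from above through the parser (this assumes parseRowArray is in the same package) might look like this:

    package main

    import "fmt"

    func main() {
        // First sample output from the psql session shown earlier.
        simple := []byte(`{"(123,foo)","(456,bar)"}`)
        fmt.Printf("%q\n", parseRowArray(simple))
        // expected: [["123" "foo"] ["456" "bar"]]

        // Two of the trickier quoted values from the second sample.
        quoted := []byte(`{"(1,\"a b\")","(2,\"a,b\")"}`)
        fmt.Printf("%q\n", parseRowArray(quoted))
        // expected: [["1" "a b"] ["2" "a,b"]]
    }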

    With something like that you can then implement an sql.Scanner for your Go slice of bars.

    type BarList []*Bar
    
    func (ls *BarList) Scan(src interface{}) error {
        switch data := src.(type) {
        case []byte:
            a := parseRowArray(data)
            res := make(BarList, len(a))
            for i := 0; i < len(a); i++ {
                bar := new(Bar)
                // Here I'm assuming the parser produced a slice of at least two
                // strings; if there are cases where this may not be true, you
                // should add proper length checks to avoid unnecessary panics.
                bar.Stuff = a[i][0]
                bar.Other = a[i][1]
                res[i] = bar
            }
            *ls = res
        }
        return nil
    }
    

    Now, if you change the type of the Bars field in the Foo type from []*Bar to BarList, you'll be able to pass a pointer to the field directly to a (*sql.Row|*sql.Rows).Scan call:

    rows.Scan(&f.Bars)
    

    If you don't want to change the field's type, you can still make it work by converting the pointer at the point where it's passed to the Scan method:

    rows.Scan((*BarList)(&f.Bars))
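
    Tying it together, a rough end-to-end sketch (assuming a *sql.DB handle named db and the hypothetical fooWithBarsQuery constant sketched earlier; imports and error handling trimmed to the essentials) could look like this:

    // getFoo loads a single foo row together with its aggregated bars;
    // BarList's Scan method takes care of parsing the array column.
    func getFoo(db *sql.DB, id int) (*Foo, error) {
        f := new(Foo)
        err := db.QueryRow(fooWithBarsQuery, id).Scan(&f.ID, (*BarList)(&f.Bars))
        if err != nil {
            return nil, err
        }
        return f, nil
    }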
    

    JSON

    An sql.Scanner implementation for the JSON solution suggested by Henry Woody would look something like this:

    type BarList []*Bar
    
    func (ls *BarList) Scan(src interface{}) error {
        if b, ok := src.([]byte); ok {
            return json.Unmarshal(b, ls)
        }
        return nil
    }
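
    For the JSON route the query itself also has to change so that the joined rows are aggregated as JSON rather than as a postgres array. The exact query in the referenced answer may differ; a rough sketch using the standard json_agg and json_build_object functions could be:

    // Hypothetical JSON variant of the query. The lowercase object keys are
    // matched to Bar's Stuff and Other fields by encoding/json, since field
    // name matching is case-insensitive when no json struct tags are present.
    const fooWithBarsJSONQuery = `
        SELECT f.id,
               json_agg(json_build_object('stuff', b.stuff, 'other', b.other)) AS bars
        FROM foo f
        LEFT JOIN bar b ON f.id = b.id
        WHERE f.id = $1
        GROUP BY f.id`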