Tags: postgresql, go, refactoring, bulkinsert, pq

Create a general func from a particular function (refactoring)


I'm using the createUsers func to populate my fake DB, just for tests.

I'm using the bulk imports feature of pq (https://godoc.org/github.com/lib/pq#hdr-Bulk_imports).

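// Uses a package-level DB *sql.DB and a checkErr(error) helper; imports
// database/sql, strconv, github.com/lib/pq and the models package.
func createUsers() {

    users := []models.User{}

    for i := 0; i < 10; i++ {
        // strconv.Itoa is needed: Go can't concatenate a string and an int
        users = append(users, models.User{Username: "username" + strconv.Itoa(i), Age: i})
    }

    connStr := "user=postgres password=postgres dbname=dbname sslmode=disable"
    var err error
    DB, err = sql.Open("postgres", connStr)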
    checkErr(err)

    txn, err := DB.Begin()
    checkErr(err)

    stmt, err := txn.Prepare(pq.CopyIn("users", "username", "age"))
    checkErr(err)

    for _, user := range users {
        _, err = stmt.Exec(user.Username, user.Age)
        checkErr(err)
    }

    _, err = stmt.Exec()
    checkErr(err)

    err = stmt.Close()
    checkErr(err)

    err = txn.Commit()
    checkErr(err)
}

Everything in this code works fine.

THE NEED:

What I need now is to make it "general", not just for the User model.

I think I need something like:

DBBulkInsert(users, "users", "username", "age")

with a DBBulkInsert func like:

func DBBulkInsert(rows []interface{}, tableName string, tableColumns ...string) {
    // DB var from connection func

    txn, err := DB.Begin()
    checkErr(err)

    stmt, err := txn.Prepare(pq.CopyIn(tableName, tableColumns...))
    checkErr(err)

    for _, row := range rows {
        _, err = stmt.Exec(row[0], row[1]) //THIS IS TOTALLY WRONG! WHAT TO DO HERE?
        checkErr(err)
    }

    _, err = stmt.Exec()
    checkErr(err)

    err = stmt.Close()
    checkErr(err)

    err = txn.Commit()
    checkErr(err)
}

THE PROBLEM:

Obviously _, err = stmt.Exec(row[0], row[1]) is totally wrong (row is an interface{}, so it can't be indexed), and I don't understand how to call DBBulkInsert with my users slice.

STILL BETTER:

Maybe I can also remove the "users", "username", "age" parameters from DBBulkInsert(users, "users", "username", "age"), but how? Reflection?


Solution

  • Your rows type needs to be [][]interface{}, i.e. a list of rows where each row is a list of column values. With that type, each row can be "unpacked" into the variadic Exec call using the ... suffix (row...).

    That is:

    for _, row := range rows {
        _, err = stmt.Exec(row...)
    }
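A minimal, DB-free sketch of the unpacking step, assuming you want to see it in isolation: execRows and execFunc are hypothetical names, and a recording function stands in for stmt.Exec so the program runs without PostgreSQL. It shows that row... passes each column value as a separate argument, whatever the number of columns.

```go
package main

import "fmt"

// execFunc has the same variadic parameter list as (*sql.Stmt).Exec.
// Hypothetical type, used only for this sketch.
type execFunc func(args ...interface{}) error

// execRows calls exec once per row. row... spreads the row's
// []interface{} into the variadic parameter, so the same loop works
// for any number of columns. It stops at the first error.
func execRows(exec execFunc, rows [][]interface{}) error {
	for _, row := range rows {
		if err := exec(row...); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	rows := [][]interface{}{
		{"username0", 0},
		{"username1", 1},
	}

	// Stand-in for stmt.Exec: just report what it received.
	record := func(args ...interface{}) error {
		fmt.Printf("%d args: %v\n", len(args), args)
		return nil
	}

	// prints "2 args: [username0 0]" then "2 args: [username1 1]"
	if err := execRows(record, rows); err != nil {
		panic(err)
	}
}
```

With a real *sql.Stmt the wrapper would be func(args ...interface{}) error { _, err := stmt.Exec(args...); return err }, because Exec also returns an sql.Result.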
    

    To get from []models.User or []models.Whatever to [][]interface{} you'll need to use reflection. If you want, you can use reflection to get the column names and the table name too.

    Say you have a model type like:

    type User struct {
        _        struct{} `rel:"users"`
        Username string   `col:"username"`
        Age      int      `col:"age"`
    }
    

    Now you can use reflection to get the table name and the list of columns from the fields' struct tags. (Using the _ (blank) field is just one way to specify the table name; it has upsides and downsides, so the choice is yours. Here it only demonstrates how the reflect package can be leveraged.)

    The following is a more complete example of how to collect the "meta" data from tags and how to aggregate the column values from struct fields.

    func DBBulkInsert(source interface{}) {
        slice := reflect.ValueOf(source)
        if slice.Kind() != reflect.Slice {
            panic("not a slice")
        }
    
        elem := slice.Type().Elem()
        if elem.Kind() == reflect.Ptr {
            elem = elem.Elem()
        }
        if elem.Kind() != reflect.Struct {
            panic("slice elem not a struct, nor a pointer to a struct")
        }
    
        // get table and column names
        var tableName string
        var cols []string
        for i := 0; i < elem.NumField(); i++ {
            f := elem.Field(i)
            if rel := f.Tag.Get("rel"); len(rel) > 0 {
                tableName = rel
            }
            if col := f.Tag.Get("col"); len(col) > 0 {
                cols = append(cols, col)
            }
        }
    
        // aggregate rows
        rows := [][]interface{}{}
        for i := 0; i < slice.Len(); i++ {
            m := slice.Index(i)
            if m.Kind() == reflect.Ptr {
                m = m.Elem()
            }
    
            vals := []interface{}{}
            for j := 0; j < m.NumField(); j++ {
                ft := m.Type().Field(j)
                if col := ft.Tag.Get("col"); len(col) > 0 {
                    f := m.Field(j)
                    vals = append(vals, f.Interface())
                }
            }
    
            rows = append(rows, vals)
        }
    
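        // ... then begin a transaction, Prepare(pq.CopyIn(tableName, cols...)),
        // call stmt.Exec(row...) for each row and finish as in the question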
    }
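A self-contained sketch of the reflection step, assuming you'd rather get errors than panics: collectRows is a hypothetical name, and the User type is copied from the example above so the program runs on its own without a database. Unlike the snippet above, it reads the tags once, remembers which field indexes carry a col tag, and returns the table name, columns and rows instead of discarding them.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// User mirrors the tagged model above: the blank field's rel tag holds
// the table name and each col tag names a column.
type User struct {
	_        struct{} `rel:"users"`
	Username string   `col:"username"`
	Age      int      `col:"age"`
}

// collectRows is a hypothetical, error-returning variant of the reflection
// code above. It returns the table name, the column names and one
// []interface{} of column values per slice element.
func collectRows(source interface{}) (string, []string, [][]interface{}, error) {
	slice := reflect.ValueOf(source)
	if slice.Kind() != reflect.Slice {
		return "", nil, nil, errors.New("not a slice")
	}
	elem := slice.Type().Elem()
	if elem.Kind() == reflect.Ptr {
		elem = elem.Elem()
	}
	if elem.Kind() != reflect.Struct {
		return "", nil, nil, errors.New("slice elem not a struct, nor a pointer to a struct")
	}

	// Read the tags once, remembering which field indexes are columns.
	var table string
	var cols []string
	var fieldIdx []int
	for i := 0; i < elem.NumField(); i++ {
		f := elem.Field(i)
		if rel := f.Tag.Get("rel"); rel != "" {
			table = rel
		}
		if col := f.Tag.Get("col"); col != "" {
			cols = append(cols, col)
			fieldIdx = append(fieldIdx, i)
		}
	}

	// Collect the column values of every element, in column order.
	rows := make([][]interface{}, 0, slice.Len())
	for i := 0; i < slice.Len(); i++ {
		m := slice.Index(i)
		if m.Kind() == reflect.Ptr {
			if m.IsNil() {
				return "", nil, nil, fmt.Errorf("element %d is a nil pointer", i)
			}
			m = m.Elem()
		}
		vals := make([]interface{}, len(fieldIdx))
		for k, j := range fieldIdx {
			vals[k] = m.Field(j).Interface()
		}
		rows = append(rows, vals)
	}
	return table, cols, rows, nil
}

func main() {
	users := []User{{Username: "username0", Age: 0}, {Username: "username1", Age: 1}}
	table, cols, rows, err := collectRows(users)
	if err != nil {
		panic(err)
	}
	fmt.Println(table, cols, rows)
	// prints: users [username age] [[username0 0] [username1 1]]
}
```

The three results plug into the transaction code from the question: txn.Prepare(pq.CopyIn(table, cols...)), then stmt.Exec(row...) for each row, then the final stmt.Exec(), stmt.Close() and txn.Commit(). Returning an error instead of panicking lets the caller keep using checkErr(err).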
    
