Search code examples
Tags: sql, postgresql, go, utf-8, emoji

One printable latin1 char in valid full utf8 sequence


I'm trying to insert an email address into Postgres using the lib/pq driver for Go. The error message is:

Unable to execute the query. pq: character with byte sequence 0xf0 0x9f 0x9a 0x80 in encoding "UTF8" has no equivalent in encoding "LATIN1"

Of those four bytes, only 0xf0 on its own is a printable LATIN1 character (ð); the other three are not. Taken together, the full sequence is the UTF-8 encoding of the 🚀 (:rocket:) emoji.

This is the email address.

// Display name that contains emoji. Each emoji is a 4-byte UTF-8 sequence
// (🚀 is 0xf0 0x9f 0x9a 0x80, the bytes quoted in the error message).
author := "🚀 Steven Duncan 🌌 <steven.duncan@gmail.com>"

I am trying to use a SanitizeSQL function to catch this error, but it is not catching the error.

// Insert statement with a single placeholder, $1, for the author value.
sqlStatement := `INSERT INTO logs (author) VALUES ($1) RETURNING id`

// Interpolate author into the statement text on the client side.
sql_with_args, err := SanitizeSQL(sqlStatement, author)

This is the code (full demo on the Go Playground):

// SanitizeSQL parses sql with NewQuery and returns the statement with every
// placeholder replaced by the sanitized form of the matching entry in args.
// An error from either step is returned together with an empty string.
func SanitizeSQL(sql string, args ...any) (string, error) {
    parsed, parseErr := NewQuery(sql)
    if parseErr != nil {
        return "", parseErr
    }

    sanitized, err := parsed.Sanitize(args...)
    return sanitized, err
}

// Part is one element of a parsed query. Sanitize accepts a string (literal
// SQL text) or an int (a 1-based placeholder number) and rejects any other
// dynamic type.
type Part any

// Query is a SQL statement broken into Parts, as produced by NewQuery.
type Query struct {
    // Parts holds the pieces of the statement in the order they are written
    // out by Sanitize.
    Parts []Part
}

// replacementcharacterwidth is the encoded width, in bytes, of U+FFFD, the Unicode replacement character.
//
// utf8.DecodeRuneInString returns utf8.RuneError for errors. But utf8.RuneError is actually the rune U+FFFD, so a
// result of utf8.RuneError is not an error if its width is also 3: the input really contained U+FFFD. Decoding
// failures are reported with a width of 0 (empty input) or 1 (invalid encoding).
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3

// Sanitize renders the query as one SQL string in which every placeholder is
// replaced by the formatted or quoted text of its argument.
//
// Placeholders are 1-based: placeholder n takes args[n-1]. Supported argument
// types are nil, int64, float64, bool, []byte, string and time.Time. An error
// is returned for a placeholder below 1, a placeholder without a matching
// argument, an unsupported argument or Part type, and for any argument that
// no placeholder refers to.
func (q *Query) Sanitize(args ...any) (string, error) {
    referenced := make([]bool, len(args))
    out := &bytes.Buffer{}

    for _, p := range q.Parts {
        // Literal SQL text passes through untouched.
        if text, isText := p.(string); isText {
            out.WriteString(text)
            continue
        }

        number, isPlaceholder := p.(int)
        if !isPlaceholder {
            return "", fmt.Errorf("invalid Part type: %T", p)
        }

        idx := number - 1
        switch {
        case idx < 0:
            return "", fmt.Errorf("first sql argument must be > 0")
        case idx >= len(args):
            return "", fmt.Errorf("insufficient arguments")
        }

        var literal string
        switch v := args[idx].(type) {
        case nil:
            literal = "null"
        case bool:
            literal = strconv.FormatBool(v)
        case int64:
            literal = strconv.FormatInt(v, 10)
        case float64:
            literal = strconv.FormatFloat(v, 'f', -1, 64)
        case string:
            literal = QuoteString(v)
        case []byte:
            literal = QuoteBytes(v)
        case time.Time:
            // Timestamps are cut to microsecond precision before formatting.
            literal = v.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
        default:
            return "", fmt.Errorf("invalid arg type: %T", v)
        }
        referenced[idx] = true

        // Prevent SQL injection via Line Comment Creation: the value is
        // always surrounded by spaces so it cannot join adjacent text.
        // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
        out.WriteByte(' ')
        out.WriteString(literal)
        out.WriteByte(' ')
    }

    // Every supplied argument must have been consumed by some placeholder.
    for idx := range referenced {
        if !referenced[idx] {
            return "", fmt.Errorf("unused argument: %d", idx)
        }
    }
    return out.String(), nil
}

// NewQuery splits sql into literal text and placeholder parts by running the
// lexer state machine, starting in rawState, until a state returns nil.
// The returned error is always nil.
func NewQuery(sql string) (*Query, error) {
    lexer := &sqlLexer{src: sql, stateFn: rawState}

    // Each state consumes some input and names its successor.
    for next := lexer.stateFn; next != nil; next = lexer.stateFn {
        lexer.stateFn = next(lexer)
    }

    return &Query{Parts: lexer.parts}, nil
}

// QuoteString returns str as a single-quoted SQL string literal, doubling
// every single quote inside it. All other bytes are copied unchanged.
func QuoteString(str string) string {
    var quoted strings.Builder
    quoted.Grow(len(str) + 2)

    quoted.WriteByte('\'')
    for i := 0; i < len(str); i++ {
        if str[i] == '\'' {
            quoted.WriteByte('\'')
        }
        quoted.WriteByte(str[i])
    }
    quoted.WriteByte('\'')

    return quoted.String()
}

// QuoteBytes returns buf as a quoted hex literal of the form '\x...', with
// each byte written as two lowercase hexadecimal digits.
func QuoteBytes(buf []byte) string {
    const prefix = `'\x`

    literal := make([]byte, len(prefix)+hex.EncodedLen(len(buf))+1)
    n := copy(literal, prefix)
    n += hex.Encode(literal[n:], buf)
    literal[n] = '\''

    return string(literal)
}

// sqlLexer carries the scanning state shared by the state functions that
// split a SQL string into Parts.
type sqlLexer struct {
    // src is the SQL text being scanned.
    src     string
    // start is the byte offset in src where the text not yet emitted as a
    // part begins.
    start   int
    // pos is the byte offset in src of the next rune to decode.
    pos     int
    nested  int // multiline comment nesting level.
    // stateFn is the state to run next; NewQuery stops when it is nil.
    stateFn stateFn
    // parts collects the emitted text fragments and placeholder numbers.
    parts   []Part
}

// stateFn is one state of the lexer. It consumes input from the sqlLexer and
// returns the state to run next, or nil to end the scan.
type stateFn func(*sqlLexer) stateFn

// rawState scans SQL text that is outside any string, quoted identifier or
// comment, and returns the state for the construct it runs into:
//
//   - e' or E'  -> escapeStringState
//   - '         -> singleQuoteState
//   - "         -> doubleQuoteState
//   - $ + digit -> placeholderState (the text pending before the '$' is
//     emitted as a part first)
//   - --        -> oneLineCommentState
//   - /*        -> multilineCommentState
//
// At end of input, or at a byte that is not valid UTF-8, the pending text is
// emitted and nil is returned, which ends the scan.
func rawState(l *sqlLexer) stateFn {
    for {
        ch, chWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        l.pos += chWidth

        if ch == utf8.RuneError {
            if chWidth == replacementcharacterwidth {
                // A real U+FFFD in the input is ordinary text.
                continue
            }
            if l.start < l.pos {
                l.parts = append(l.parts, l.src[l.start:l.pos])
                l.start = l.pos
            }
            return nil
        }

        switch ch {
        case '\'':
            return singleQuoteState
        case '"':
            return doubleQuoteState
        case 'e', 'E', '$', '-', '/':
            // These only matter in combination with the following rune.
        default:
            continue
        }

        peek, peekWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        switch {
        case ch == '$':
            if peek < '0' || peek > '9' {
                continue
            }
            // Emit the text before the '$'; the digits are left unread for
            // placeholderState.
            if l.start < l.pos {
                l.parts = append(l.parts, l.src[l.start:l.pos-chWidth])
            }
            l.start = l.pos
            return placeholderState
        case ch == '-' && peek == '-':
            l.pos += peekWidth
            return oneLineCommentState
        case ch == '/' && peek == '*':
            l.pos += peekWidth
            return multilineCommentState
        case (ch == 'e' || ch == 'E') && peek == '\'':
            l.pos += peekWidth
            return escapeStringState
        }
    }
}

// singleQuoteState scans the inside of a single-quoted string whose opening
// quote has already been consumed. A doubled quote ('') stays inside the
// string; a lone quote closes it and hands control back to rawState. At end
// of input or at an invalid UTF-8 byte the pending text is emitted and nil
// is returned.
func singleQuoteState(l *sqlLexer) stateFn {
    for {
        ch, chWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        l.pos += chWidth

        if ch == utf8.RuneError && chWidth != replacementcharacterwidth {
            if l.start < l.pos {
                l.parts = append(l.parts, l.src[l.start:l.pos])
                l.start = l.pos
            }
            return nil
        }
        if ch != '\'' {
            continue
        }

        following, followingWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        if following != '\'' {
            return rawState
        }
        // Second half of a doubled quote: skip it and keep scanning.
        l.pos += followingWidth
    }
}

// doubleQuoteState scans the inside of a double-quoted section whose opening
// quote has already been consumed. A doubled quote ("") stays inside the
// section; a lone quote closes it and hands control back to rawState. At end
// of input or at an invalid UTF-8 byte the pending text is emitted and nil
// is returned.
func doubleQuoteState(l *sqlLexer) stateFn {
    for {
        ch, chWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        l.pos += chWidth

        if ch == utf8.RuneError && chWidth != replacementcharacterwidth {
            if l.start < l.pos {
                l.parts = append(l.parts, l.src[l.start:l.pos])
                l.start = l.pos
            }
            return nil
        }
        if ch != '"' {
            continue
        }

        following, followingWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        if following != '"' {
            return rawState
        }
        // Second half of a doubled quote: skip it and keep scanning.
        l.pos += followingWidth
    }
}

// placeholderState consumes a placeholder value. The $ must already have
// been consumed. The first rune must be a digit.
//
// The decimal digits are accumulated into the placeholder number, which is
// emitted as an int part; scanning then resumes in rawState at the first
// non-digit rune (or at end of input).
//
// A digit run too large for an int saturates at the maximum int value instead
// of wrapping around, so an absurdly long placeholder can never alias a small,
// valid argument number.
func placeholderState(l *sqlLexer) stateFn {
    const maxInt = int(^uint(0) >> 1)

    num := 0

    for {
        r, width := utf8.DecodeRuneInString(l.src[l.pos:])

        if r < '0' || r > '9' {
            // Leave the non-digit unread for rawState.
            l.parts = append(l.parts, num)
            l.start = l.pos
            return rawState
        }
        l.pos += width

        digit := int(r - '0')
        if num > (maxInt-digit)/10 {
            // num*10 + digit would overflow.
            num = maxInt
        } else {
            num = num*10 + digit
        }
    }
}

// escapeStringState scans the inside of an escape string (e'...' / E'...')
// whose opening e' has already been consumed. A backslash makes the rune
// after it part of the string, a doubled quote ('') stays inside the string,
// and a lone quote closes it and hands control back to rawState. At end of
// input or at an unescaped invalid UTF-8 byte the pending text is emitted
// and nil is returned.
func escapeStringState(l *sqlLexer) stateFn {
    for {
        ch, chWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        l.pos += chWidth

        switch {
        case ch == '\\':
            // Skip whatever follows the backslash without inspecting it.
            _, escapedWidth := utf8.DecodeRuneInString(l.src[l.pos:])
            l.pos += escapedWidth

        case ch == '\'':
            following, followingWidth := utf8.DecodeRuneInString(l.src[l.pos:])
            if following != '\'' {
                return rawState
            }
            l.pos += followingWidth

        case ch == utf8.RuneError && chWidth != replacementcharacterwidth:
            if l.start < l.pos {
                l.parts = append(l.parts, l.src[l.start:l.pos])
                l.start = l.pos
            }
            return nil
        }
    }
}

// oneLineCommentState scans the body of a "--" comment whose leading dashes
// have already been consumed. The comment ends at the first line feed or
// carriage return, after which scanning resumes in rawState. At end of input
// or at an invalid UTF-8 byte the pending text is emitted and nil is
// returned.
//
// A backslash has no special meaning inside a "--" comment in PostgreSQL, so
// it is treated like any other character. In particular it must not hide the
// line break that ends the comment; otherwise the lexer and the server would
// disagree about where the comment stops.
func oneLineCommentState(l *sqlLexer) stateFn {
    for {
        r, width := utf8.DecodeRuneInString(l.src[l.pos:])
        l.pos += width

        switch r {
        case '\n', '\r':
            return rawState
        case utf8.RuneError:
            if width != replacementcharacterwidth {
                if l.pos-l.start > 0 {
                    l.parts = append(l.parts, l.src[l.start:l.pos])
                    l.start = l.pos
                }
                return nil
            }
        }
    }
}

// multilineCommentState scans the body of a /* ... */ comment whose opening
// marker has already been consumed. Comments nest: every further /* raises
// l.nested, and every */ either lowers it or, when it is already zero, ends
// the comment and hands control back to rawState. At end of input or at an
// invalid UTF-8 byte the pending text is emitted and nil is returned.
func multilineCommentState(l *sqlLexer) stateFn {
    for {
        ch, chWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        l.pos += chWidth

        if ch == utf8.RuneError {
            if chWidth == replacementcharacterwidth {
                // A real U+FFFD in the input is ordinary comment text.
                continue
            }
            if l.start < l.pos {
                l.parts = append(l.parts, l.src[l.start:l.pos])
                l.start = l.pos
            }
            return nil
        }

        // Only '/' and '*' can start a comment marker.
        if ch != '/' && ch != '*' {
            continue
        }

        peek, peekWidth := utf8.DecodeRuneInString(l.src[l.pos:])
        switch {
        case ch == '/' && peek == '*':
            l.pos += peekWidth
            l.nested++
        case ch == '*' && peek == '/':
            l.pos += peekWidth
            if l.nested == 0 {
                return rawState
            }
            l.nested--
        }
    }
}

Solution

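  • You cannot store a character that cannot be encoded in LATIN1 in a database with encoding LATIN1. SanitizeSQL cannot catch this, because the string is perfectly valid UTF-8; the error is raised by the server when it converts the text from the client encoding (UTF8) to the database encoding (LATIN1). There is really only one good solution: create a new database with UTF8 encoding, dump your database and restore it into the UTF8 database.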