Search code examples
c++sqlitesql-injection

How to stop SQL injection vulnerability?


Was wondering how vulnerable this was to SQL injection. I heard that using prepared sql statements can circumvent this vulnerability, but I also heard that using double quotes instead of single quotes can prevent SQL injection as well. I am not a security expert, and i'm not great with sqlite either. Also, I need to initalize the database elsewhere, and probably end up using prepared statements instead of sprintf, but i'm just not exactly sure how to do either one of those things. Any help is greatly appreciated! Thank you!

bool sql_console_msgs = false;
void QServ::savestats(clientinfo *ci)
{
    if(enable_sqlite_db) {
        sqlite3 *db;
        char *zErrMsg = 0;
        int  rc;
        const char *sql;
        bool name_match;
        const char* player_database_names;
        char *p_name = ci->name;
        char *p_ip = ci->ip;
        
        rc = sqlite3_open("playerinfo.db", &db);
        if( rc ){
            fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(db));
            exit(0);
        }else{
            if(sql_console_msgs) fprintf(stdout, "Opened database successfully\n");
        }
        if( rc != SQLITE_OK ){
            fprintf(stderr, "SQL Database Error: %s\n", zErrMsg);
            sqlite3_free(zErrMsg);
        }else{
            sqlite3_stmt *stmt;
            defformatstring(sqlstrprep)("SELECT NAME FROM PLAYERINFO");
            rc = sqlite3_prepare_v2(db, sqlstrprep, -1, &stmt, NULL);
            
            while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
                player_database_names = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
                if(!strcmp(player_database_names, p_name)) name_match = true;
                else name_match = false;
            }
        }
        
        sql = "CREATE TABLE IF NOT EXISTS PLAYERINFO("    \
        "NAME                       TEXT    NOT NULL,"    \
        "FRAGS                       INT    NOT NULL,"    \
        "DEATHS                      INT    NOT NULL,"    \
        "FLAGS                       INT    NOT NULL,"    \
        "PASSES                      INT    NOT NULL,"    \
        "IP                         TEXT    NOT NULL,"    \
        "ACCURACY          DECIMAL(4, 2)    NOT NULL,"    \
        "KPD               DECIMAL(4, 2)    NOT NULL);";
        rc = sqlite3_exec(db, sql, callback, 0, &zErrMsg);
        if( rc != SQLITE_OK ){
            fprintf(stderr, "SQLITE3 ERROR @ CREATE TABLE IF NOT EXISTS: %s\n", zErrMsg);
            sqlite3_free(zErrMsg);
        }else{
            if(sql_console_msgs) {
                if(!name_match) fprintf(stdout, "No previous record found under that name\n");
                else fprintf(stdout, "Found name and IP already, updating record instead\n");
            }
        }
        
        char sqlINSERT[256];
        char sqlUPDATE[1000];
        int p_frags = ci->state.frags;
        int p_deaths = ci->state.deaths;
        int p_flags = ci->state.flags;
        int p_passes = ci->state.passes;
        int p_acc = (ci->state.damage*100)/max(ci->state.shotdamage, 1);
        int p_kpd = (ci->state.frags)/max(ci->state.deaths, 1);
        
        //name and ip are different
        if(!name_match) {
            sprintf(sqlINSERT, "INSERT INTO PLAYERINFO( NAME,FRAGS,DEATHS,FLAGS,PASSES,IP,ACCURACY,KPD ) VALUES ('%s', %d, %d, %d, %d, '%s', %d, %d)",p_name,p_frags,p_deaths,p_flags,p_passes,p_ip,p_acc,p_kpd);
            rc = sqlite3_exec(db, sqlINSERT, callback, 0, &zErrMsg);
        }
        //client name matches db record, update db if new info is > than db info
        else if(name_match)  {
            sprintf(sqlUPDATE,
                    "UPDATE PLAYERINFO SET FRAGS = %d+(SELECT FRAGS FROM PLAYERINFO) WHERE NAME = '%s';"     \
                    "UPDATE PLAYERINFO SET DEATHS = %d+(SELECT DEATHS FROM PLAYERINFO) WHERE NAME = '%s';"   \
                    "UPDATE PLAYERINFO SET FLAGS = %d+(SELECT FLAGS FROM PLAYERINFO) WHERE NAME = '%s';"     \
                    "UPDATE PLAYERINFO SET PASSES = %d+(SELECT PASSES FROM PLAYERINFO) WHERE NAME = '%s';"   \
                    "UPDATE PLAYERINFO SET ACCURACY = %d+(SELECT PASSES FROM PLAYERINFO) WHERE NAME = '%s';" \
                    "UPDATE PLAYERINFO SET KPD = %d+(SELECT PASSES FROM PLAYERINFO) WHERE NAME = '%s';",
                    ci->state.frags, ci->name, ci->state.deaths, ci->name, ci->state.flags, ci->name, ci->state.passes, ci->name, p_acc, ci->name, p_kpd, ci->name);
            rc = sqlite3_exec(db, sqlUPDATE, callback, 0, &zErrMsg);
        }
        if( rc != SQLITE_OK ){
            fprintf(stderr, "SQLITE3 ERROR @ INSERT & UPDATE: %s\n", zErrMsg);
            sqlite3_free(zErrMsg);
        }else{
            if(sql_console_msgs) fprintf(stdout, "Playerinfo modified\n");
        }
        sqlite3_close(db);
    }
}

void QServ::getstats(clientinfo *ci)
{
    if(enable_sqlite_db) {
        sqlite3 *db;
        char *zErrMsg = 0;
        int rc;
        char *sql;
        const char* data = "Callback function called";
        
        rc = sqlite3_open("playerinfo.db", &db);
        if( rc ){
            fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(db));
            exit(0);
        }
        
        if( rc != SQLITE_OK ){
            fprintf(stderr, "SQL Database Error: %s\n", zErrMsg);
            sqlite3_free(zErrMsg);
        }else{
            sqlite3_stmt *stmt;
            defformatstring(sqlstrprep)("SELECT NAME,FRAGS,ACCURACY,KPD FROM PLAYERINFO WHERE NAME == '%s';", ci->name);
            rc = sqlite3_prepare_v2(db, sqlstrprep, -1, &stmt, NULL);
            
            bool necho = false;
            while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
                const char* name = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
                const char* allfrags = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
                const char* avgacc = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
                const char* avgkpd = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 3));
                if(!necho) {
                    if(avgacc == NULL) out(ECHO_SERV, "Name: \f0%s\f7, Total Frags: \f3%s\f7, Average KPD: \f6%s", name, allfrags, avgkpd);
                    else if(avgkpd == NULL) out(ECHO_SERV, "Name: \f0%s\f7, Total Frags: \f3%s\f7, Average Accuracy: \f2%s%%", name, allfrags, avgacc);
                    else out(ECHO_SERV, "Name: \f0%s\f7, Total Frags: \f3%s\f7, Average Accuracy: \f2%s%%\f7, Average KPD: \f6%s", name,allfrags,avgacc,avgkpd);
                    necho = true;
                }
            }
        }
        sqlite3_close(db);
    }
}

void QServ::getnames(clientinfo *ci) {
    if(enable_sqlite_db) {
        sqlite3_stmt *stmt3;
        sqlite3 *db;
        int rc;
        rc = sqlite3_open("playerinfo.db", &db);
        defformatstring(sqlstrprep3)("SELECT group_concat(NAME, ', ') FROM PLAYERINFO WHERE IP == '%s';", ci->ip);
        rc = sqlite3_prepare_v2(db, sqlstrprep3, -1, &stmt3, NULL);
        while ((rc = sqlite3_step(stmt3)) == SQLITE_ROW) {
            std::string names(reinterpret_cast<const char*>(sqlite3_column_text(stmt3, 0)));
            defformatstring(nmsg)("Names from IP \f2%s\f7: %s", ci->ip, names.c_str());
            out(ECHO_SERV, nmsg);
        }
        sqlite3_close(db);
    }
}

Solution

  • When formatting an SQL statement by hand (which you really SHOULD NOT do when input parameters are involved!), it is not enough to merely wrap parameter values in quotes. You need to also escape reserved characters inside of the parameter data, otherwise you are still susceptible to injection attacks. The attacker could simply place a matching quote inside the data, closing off your opening quote, and then the rest of the parameter data can contain malicious instructions.

    For example:

    const char *p_name = "'); DROP TABLE MyTable; --";
    sprintf(sql, "INSERT INTO MyTable(NAME) VALUES ('%s')", p_name);
    

    Or:

    const char *p_name = "\"); DROP TABLE MyTable; --";
    sprintf(sql, "INSERT INTO MyTable(NAME) VALUES (\"%s\")", p_name);
    

    These would create the following SQL statements:

    INSERT INTO MyTable(NAME) VALUES (''); DELETE TABLE MyTable; --')
    
    INSERT INTO MyTable(NAME) VALUES (""); DELETE TABLE MyTable; --")
    

    Say "bye bye" to your table when the SQL is executed! (assuming the user executing the SQL has DELETE access to the table - that is a whole other security concern of its own).

    In this case, you would need to double up any single-quote characters, or slash-escape any double-quote characters, that are in the parameter data, eg:

    const char *p_name = "'); DROP TABLE MyTable; --";
    char *p_escaped_name = sqlEscape(p_name); // <-- you have to implement this yourself!
    sprintf(sql, "INSERT INTO MyTable(NAME) VALUES ('%s')", p_escaped_name);
    // or:
    // sprintf(sql, "INSERT INTO MyTable(NAME) VALUES (\"%s\")", p_escaped_name);
    free(p_escaped_name);
    

    Thus, the resulting SQL statements would look like these instead:

    INSERT INTO MyTable(NAME) VALUES ('''); DELETE TABLE MyTable; --')
    
    INSERT INTO MyTable(NAME) VALUES ("\"); DELETE TABLE MyTable; --")
    

    Thus, the name inserted into the table would be '); DELETE TABLE MyTable; -- (or "); DELETE TABLE MyTable; --). Not pretty, but the table would be saved.

    Some DB frameworks offer functions to do this escaping for you, but I don't see one in sqlite, so you will have to implement it manually in your own code, eg:

    char* sqlEscape(const char *str)
    {
        int len = strlen(str);
        int newlen = len;
    
        for (int i = 0; i < len; ++i) {
            switch (str[i]) {
                case '\'':
                case '"':
                    ++newlen;
                    break;
            }
        }
    
        if (newlen == len)
            return strdup(str);
    
        char *newstr = (char*) malloc(newlen + 1);
        if (!newstr)
            return NULL;
    
        newlen = 0;
        for (int i = 0; i < len; ++i) {
            switch (str[i]) {
                case '\'':
                    newstr[newlen++] = '\'';
                    break;
                case '"':
                    newstr[newlen++] = '\\';
                    break;
            }
            newstr[newlen++] = str[i];
        }
    
        newstr[newlen] = '\0';
    
        return newstr;
    }
    

    A prepared statement avoids the need to do this escaping manually, by letting the DB engine handle these details for you when executing the prepared statement.


    Also, your code's use of sprintf() is susceptible to buffer overflows, which is even worse, because a carefully crafted buffer overflow can let an attacker execute arbitrary machine code inside your app, not just in the database. Use snprintf() instead to avoid that.