Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 197 additions & 24 deletions tools/sqlite3_api_wrapper/sqlite3_api_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ static char *sqlite3_strdup(const char *str);
//! Cached, NUL-terminated copy of a column value handed out by
//! sqlite3_column_text / sqlite3_column_blob for the current row.
struct sqlite3_string_buffer {
	//! String data (nullptr until the column has been materialized)
	unique_ptr<char[]> data;
	//! String length in bytes, excluding the trailing NUL terminator.
	//! Default-initialized to 0 so arrays created with `new sqlite3_string_buffer[n]`
	//! never expose an indeterminate length to sqlite3_column_bytes.
	int data_len = 0;
};

struct sqlite3_stmt {
Expand Down Expand Up @@ -249,7 +251,7 @@ int sqlite3_step(sqlite3_stmt *pStmt) {
// update total changes
auto row_changes = pStmt->current_chunk->GetValue(0, 0);
if (!row_changes.is_null && row_changes.TryCastAs(LogicalType::BIGINT)) {
pStmt->db->last_changes += row_changes.GetValue<int64_t>();
pStmt->db->last_changes = row_changes.GetValue<int64_t>();
pStmt->db->total_changes += row_changes.GetValue<int64_t>();
}
}
Expand Down Expand Up @@ -432,7 +434,8 @@ int sqlite3_column_type(sqlite3_stmt *pStmt, int iCol) {
case LogicalTypeId::BLOB:
return SQLITE_BLOB;
default:
return 0;
// TODO(wangfenjin): agg function don't have type?
return SQLITE_TEXT;
}
return 0;
}
Expand Down Expand Up @@ -501,6 +504,31 @@ const unsigned char *sqlite3_column_text(sqlite3_stmt *pStmt, int iCol) {
// not initialized yet, convert the value and initialize it
entry.data = unique_ptr<char[]>(new char[val.str_value.size() + 1]);
memcpy(entry.data.get(), val.str_value.c_str(), val.str_value.size() + 1);
entry.data_len = val.str_value.length();
}
return (const unsigned char *)entry.data.get();
} catch (...) {
// memory error!
return nullptr;
}
}

const void *sqlite3_column_blob(sqlite3_stmt *pStmt, int iCol) {
Value val;
if (!sqlite3_column_has_value(pStmt, iCol, LogicalType::BLOB, val)) {
return nullptr;
}
try {
if (!pStmt->current_text) {
pStmt->current_text =
unique_ptr<sqlite3_string_buffer[]>(new sqlite3_string_buffer[pStmt->result->types.size()]);
}
auto &entry = pStmt->current_text[iCol];
if (!entry.data) {
// not initialized yet, convert the value and initialize it
entry.data = unique_ptr<char[]>(new char[val.str_value.size() + 1]);
memcpy(entry.data.get(), val.str_value.c_str(), val.str_value.size() + 1);
entry.data_len = val.str_value.length();
}
return (const unsigned char *)entry.data.get();
} catch (...) {
Expand Down Expand Up @@ -585,6 +613,7 @@ int sqlite3_bind_text(sqlite3_stmt *stmt, int idx, const char *val, int length,
}
if (free_func && ((ptrdiff_t)free_func) != -1) {
free_func((void *)val);
val = nullptr;
}
try {
return sqlite3_internal_bind_value(stmt, idx, Value(value));
Expand All @@ -593,6 +622,32 @@ int sqlite3_bind_text(sqlite3_stmt *stmt, int idx, const char *val, int length,
}
}

//! Bind a BLOB to the 1-based parameter `idx` of `stmt`.
//! The bytes at `val` are copied into an owned Value before this returns.
//! If `free_func` is neither nullptr (SQLITE_STATIC) nor SQLITE_TRANSIENT (-1),
//! it is invoked on `val` once the copy was attempted - also when binding fails,
//! matching SQLite's destructor contract for the bind APIs.
//! A negative `length` (undefined in SQLite) reads `val` as a NUL-terminated string.
//! Returns SQLITE_MISUSE for a null `val`, otherwise the bind result, or
//! SQLITE_ERROR if the value could not be created or bound.
int sqlite3_bind_blob(sqlite3_stmt *stmt, int idx, const void *val, int length, void (*free_func)(void *)) {
	if (!val) {
		return SQLITE_MISUSE;
	}
	int rc;
	try {
		// copy the caller's bytes first; `val` may be released right below
		Value blob = length < 0 ? Value::BLOB(string((const char *)val))
		                        : Value::BLOB((const_data_ptr_t)val, length);
		rc = sqlite3_internal_bind_value(stmt, idx, blob);
	} catch (...) {
		// never let a C++ exception escape through the C API boundary
		rc = SQLITE_ERROR;
	}
	// SQLITE_TRANSIENT is encoded as the destructor value -1 and must not be called
	if (free_func && ((ptrdiff_t)free_func) != -1) {
		free_func((void *)val);
	}
	return rc;
}

//! Bind a BLOB consisting of `length` zero bytes to the 1-based parameter `idx`.
//! A negative `length` binds an empty BLOB, as SQLite does.
//! Returns the bind result, or SQLITE_ERROR if the value could not be created.
SQLITE_API int sqlite3_bind_zeroblob(sqlite3_stmt *stmt, int idx, int length) {
	try {
		string zeros(length > 0 ? (size_t)length : 0, '\0');
		return sqlite3_internal_bind_value(stmt, idx, Value::BLOB((const_data_ptr_t)zeros.data(), zeros.size()));
	} catch (...) {
		// allocation or conversion failure: do not let exceptions cross the C API
		return SQLITE_ERROR;
	}
}

int sqlite3_clear_bindings(sqlite3_stmt *stmt) {
if (!stmt) {
return SQLITE_MISUSE;
Expand Down Expand Up @@ -769,6 +824,10 @@ int sqlite3_total_changes(sqlite3 *db) {
return db->total_changes;
}

//! Return the rowid of the most recent successful INSERT on `db`.
//! Rowid tracking is not implemented by this wrapper, so this always reports 0,
//! which is SQLite's value for "no successful INSERT". The error code
//! SQLITE_ERROR (1) used previously was indistinguishable from a real rowid.
SQLITE_API sqlite3_int64 sqlite3_last_insert_rowid(sqlite3 *db) {
	(void)db;
	return 0;
}

// some code borrowed from sqlite
// its probably best to match its behavior

Expand Down Expand Up @@ -946,20 +1005,11 @@ int sqlite3_complete_old(const char *sql) {
return -1;
}

//! Legacy stub: reports blob binding as unsupported and fails with -1.
//! NOTE(review): this listing also contains a full sqlite3_bind_blob
//! implementation; both must not be compiled into the same file.
int sqlite3_bind_blob(sqlite3_stmt *, int, const void *, int n, void (*)(void *)) {
	const int failure = -1;
	fputs("sqlite3_bind_blob: unsupported.\n", stderr);
	return failure;
}

//! Legacy stub: reports blob column access as unsupported and yields nullptr.
//! NOTE(review): this listing also contains a full sqlite3_column_blob
//! implementation; both must not be compiled into the same file.
const void *sqlite3_column_blob(sqlite3_stmt *, int iCol) {
	const void *no_blob = nullptr;
	fputs("sqlite3_column_blob: unsupported.\n", stderr);
	return no_blob;
}

// length of varchar or blob value
int sqlite3_column_bytes(sqlite3_stmt *, int iCol) {
fprintf(stderr, "sqlite3_column_bytes: unsupported.\n");
return -1;
int sqlite3_column_bytes(sqlite3_stmt *pStmt, int iCol) {
// fprintf(stderr, "sqlite3_column_bytes: unsupported.\n");
return pStmt->current_text[iCol].data_len;
// return -1;
}

sqlite3_value *sqlite3_column_value(sqlite3_stmt *, int iCol) {
Expand All @@ -973,10 +1023,7 @@ int sqlite3_db_config(sqlite3 *, int op, ...) {
}

//! Return non-zero while the connection's transaction is in auto-commit mode,
//! zero while an explicit transaction is open.
//! A null handle or missing connection returns 0 instead of crashing. This
//! mirrors SQLite's API-armor behavior for an invalid handle.
int sqlite3_get_autocommit(sqlite3 *db) {
	if (!db || !db->con) {
		return 0;
	}
	return db->con->context->transaction.IsAutoCommit();
}

int sqlite3_limit(sqlite3 *, int id, int newVal) {
Expand Down Expand Up @@ -1134,6 +1181,13 @@ int sqlite3_create_function(sqlite3 *db, const char *zFunctionName, int nArg, in
return -1;
}

//! Register an application-defined SQL function - not supported by this wrapper.
//! SQLite invokes the destructor `xDestroy(pApp)` when sqlite3_create_function_v2()
//! fails, and callers rely on that to release `pApp`. This stub always fails, so
//! it invokes the destructor itself before returning SQLITE_ERROR.
int sqlite3_create_function_v2(sqlite3 *db, const char *zFunctionName, int nArg, int eTextRep, void *pApp,
                               void (*xFunc)(sqlite3_context *, int, sqlite3_value **),
                               void (*xStep)(sqlite3_context *, int, sqlite3_value **),
                               void (*xFinal)(sqlite3_context *), void (*xDestroy)(void *)) {
	if (xDestroy) {
		xDestroy(pApp);
	}
	return SQLITE_ERROR;
}

int sqlite3_set_authorizer(sqlite3 *, int (*xAuth)(void *, int, const char *, const char *, const char *, const char *),
void *pUserData) {
fprintf(stderr, "sqlite3_set_authorizer: unsupported.\n");
Expand Down Expand Up @@ -1530,11 +1584,6 @@ SQLITE_API void sqlite3_free_table(char **result) {
fprintf(stderr, "sqlite3_free_table: unsupported.\n");
}

//! Legacy stub: reports rowid tracking as unsupported and returns SQLITE_ERROR.
//! NOTE(review): this listing also defines sqlite3_last_insert_rowid earlier;
//! both must not be compiled into the same file.
SQLITE_API sqlite3_int64 sqlite3_last_insert_rowid(sqlite3 *) {
	const sqlite3_int64 result = SQLITE_ERROR;
	fputs("sqlite3_last_insert_rowid: unsupported.\n", stderr);
	return result;
}

SQLITE_API int sqlite3_prepare(sqlite3 *db, /* Database handle */
const char *zSql, /* SQL statement, UTF-8 encoded */
int nByte, /* Maximum length of zSql in bytes. */
Expand All @@ -1553,3 +1602,127 @@ SQLITE_API void *sqlite3_profile(sqlite3 *, void (*xProfile)(void *, const char
fprintf(stderr, "sqlite3_profile: unsupported.\n");
return nullptr;
}

//! Report the SQLite version number advertised by the compiled-in header.
SQLITE_API int sqlite3_libversion_number(void) {
	const int header_version = SQLITE_VERSION_NUMBER;
	return header_version;
}

//! Report the threading mode of this library build.
//! NOTE(review): the value returned is SQLITE_OK, i.e. 0. sqlite3_threadsafe()
//! callers read 0 as "compiled single-threaded"; confirm this is intended.
SQLITE_API int sqlite3_threadsafe(void) {
	const int threading_mode = SQLITE_OK;
	return threading_mode;
}

//! Mutex allocation is not supported: prints a notice and yields no mutex.
SQLITE_API sqlite3_mutex *sqlite3_mutex_alloc(int) {
	sqlite3_mutex *no_mutex = nullptr;
	fputs("sqlite3_mutex_alloc: unsupported.\n", stderr);
	return no_mutex;
}

//! Mutex release is not supported: prints a notice and does nothing else.
SQLITE_API void sqlite3_mutex_free(sqlite3_mutex *) {
	fputs("sqlite3_mutex_free: unsupported.\n", stderr);
}

//! Toggling extended result codes is not supported; the request always fails.
SQLITE_API int sqlite3_extended_result_codes(sqlite3 *db, int onoff) {
	const int status = SQLITE_ERROR;
	fputs("sqlite3_extended_result_codes: unsupported.\n", stderr);
	return status;
}

//! Row-change hooks are not supported: `xCallback` is never registered or
//! invoked, and no previous hook argument is reported (returns nullptr).
SQLITE_API void *sqlite3_update_hook(sqlite3 *db, /* Attach the hook to this database */
                                     void (*xCallback)(void *, int, char const *, char const *, sqlite_int64),
                                     void *pArg /* Argument to the function */
) {
	void *previous_arg = nullptr;
	fputs("sqlite3_update_hook: unsupported.\n", stderr);
	return previous_arg;
}

//! Error-log forwarding is not supported: the message is discarded and only
//! an "unsupported" notice is printed.
SQLITE_API void sqlite3_log(int iErrCode, const char *zFormat, ...) {
	fputs("sqlite3_log: unsupported.\n", stderr);
}

//! Unlock notifications are not supported; `xNotify` is never called.
SQLITE_API int sqlite3_unlock_notify(sqlite3 *db, void (*xNotify)(void **, int), void *pArg) {
	const int status = SQLITE_ERROR;
	fputs("sqlite3_unlock_notify: unsupported.\n", stderr);
	return status;
}

//! Function auxiliary data is not supported, so no metadata is ever found.
SQLITE_API void *sqlite3_get_auxdata(sqlite3_context *pCtx, int iArg) {
	void *metadata = nullptr;
	fputs("sqlite3_get_auxdata: unsupported.\n", stderr);
	return metadata;
}

//! Rollback hooks are not supported: `xCallback` is never registered, and
//! nullptr is returned in place of a previous hook argument.
SQLITE_API void *sqlite3_rollback_hook(sqlite3 *db,               /* Attach the hook to this database */
                                       void (*xCallback)(void *), /* Callback function */
                                       void *pArg                 /* Argument to the function */
) {
	void *previous_arg = nullptr;
	fputs("sqlite3_rollback_hook: unsupported.\n", stderr);
	return previous_arg;
}

//! Commit hooks are not supported: `xCallback` is never registered, and
//! nullptr is returned in place of a previous hook argument.
SQLITE_API void *sqlite3_commit_hook(sqlite3 *db,              /* Attach the hook to this database */
                                     int (*xCallback)(void *), /* Function to invoke on each commit */
                                     void *pArg                /* Argument to the function */
) {
	void *previous_arg = nullptr;
	fputs("sqlite3_commit_hook: unsupported.\n", stderr);
	return previous_arg;
}

//! Incremental BLOB I/O is not supported; this always fails with SQLITE_ERROR.
//! Like SQLite on any error, `*ppBlob` is set to nullptr (when `ppBlob` is
//! non-null). Callers may then pass the handle to sqlite3_blob_close without
//! touching an uninitialized pointer.
SQLITE_API int sqlite3_blob_open(sqlite3 *db,          /* The database connection */
                                 const char *zDb,      /* The attached database containing the blob */
                                 const char *zTable,   /* The table containing the blob */
                                 const char *zColumn,  /* The column containing the blob */
                                 sqlite_int64 iRow,    /* The row containing the glob */
                                 int wrFlag,           /* True -> read/write access, false -> read-only */
                                 sqlite3_blob **ppBlob /* Handle for accessing the blob returned here */
) {
	if (ppBlob) {
		*ppBlob = nullptr;
	}
	fprintf(stderr, "sqlite3_blob_open: unsupported.\n");
	return SQLITE_ERROR;
}

//! Database filename lookup is not supported; no filename is ever reported.
SQLITE_API const char *sqlite3_db_filename(sqlite3 *db, const char *zDbName) {
	const char *filename = nullptr;
	fputs("sqlite3_db_filename: unsupported.\n", stderr);
	return filename;
}

//! Statement busy-state is not tracked; every statement is reported as idle (0).
SQLITE_API int sqlite3_stmt_busy(sqlite3_stmt *) {
	fputs("sqlite3_stmt_busy: unsupported.\n", stderr);
	return 0;
}

//! Pointer binding is not supported; this always fails with SQLITE_ERROR.
//! SQLite calls `xDestructor(pPtr)` when the bind fails, so ownership of `pPtr`
//! is released here too instead of leaking it.
SQLITE_API int sqlite3_bind_pointer(sqlite3_stmt *pStmt, int i, void *pPtr, const char *zPTtype,
                                    void (*xDestructor)(void *)) {
	fprintf(stderr, "sqlite3_bind_pointer: unsupported.\n");
	if (xDestructor) {
		xDestructor(pPtr);
	}
	return SQLITE_ERROR;
}

//! Virtual-table modules are not supported; this always fails with SQLITE_ERROR.
//! SQLite invokes `xDestroy(pAux)` when sqlite3_create_module_v2() fails, so the
//! client data is released here as well to avoid leaking it.
SQLITE_API int sqlite3_create_module_v2(sqlite3 *db,                   /* Database in which module is registered */
                                        const char *zName,             /* Name assigned to this module */
                                        const sqlite3_module *pModule, /* The definition of the module */
                                        void *pAux,                    /* Context pointer for xCreate/xConnect */
                                        void (*xDestroy)(void *)       /* Module destructor function */
) {
	fprintf(stderr, "sqlite3_create_module_v2: unsupported.\n");
	if (xDestroy) {
		xDestroy(pAux);
	}
	return SQLITE_ERROR;
}

//! Incremental BLOB writes are not supported; nothing is written.
SQLITE_API int sqlite3_blob_write(sqlite3_blob *, const void *z, int n, int iOffset) {
	const int status = SQLITE_ERROR;
	fputs("sqlite3_blob_write: unsupported.\n", stderr);
	return status;
}

//! Function auxiliary data is not supported, so the metadata is never retained.
//! SQLite runs the destructor exactly once when the metadata is discarded,
//! including when it cannot be stored. The metadata is discarded immediately
//! here, so `xDelete(pAux)` runs right away instead of leaking `pAux`.
SQLITE_API void sqlite3_set_auxdata(sqlite3_context *, int N, void *pAux, void (*xDelete)(void *)) {
	fprintf(stderr, "sqlite3_set_auxdata: unsupported.\n");
	if (xDelete) {
		xDelete(pAux);
	}
}

//! Prepared-statement enumeration is not supported; no next statement is returned.
SQLITE_API sqlite3_stmt *sqlite3_next_stmt(sqlite3 *pDb, sqlite3_stmt *pStmt) {
	sqlite3_stmt *next = nullptr;
	fputs("sqlite3_next_stmt: unsupported.\n", stderr);
	return next;
}

//! Collation-needed callbacks are not supported; the callback is never registered.
SQLITE_API int sqlite3_collation_needed(sqlite3 *, void *, void (*)(void *, sqlite3 *, int eTextRep, const char *)) {
	const int status = SQLITE_ERROR;
	fputs("sqlite3_collation_needed: unsupported.\n", stderr);
	return status;
}

//! Custom collations are not supported; registration always fails.
//! `xDestroy` is intentionally not invoked: unlike other SQLite APIs,
//! sqlite3_create_collation_v2 leaves `pArg` to the caller when it fails.
SQLITE_API int sqlite3_create_collation_v2(sqlite3 *, const char *zName, int eTextRep, void *pArg,
                                           int (*xCompare)(void *, int, const void *, int, const void *),
                                           void (*xDestroy)(void *)) {
	const int status = SQLITE_ERROR;
	fputs("sqlite3_create_collation_v2: unsupported.\n", stderr);
	return status;
}
30 changes: 22 additions & 8 deletions tools/sqlite3_api_wrapper/test_sqlite3_api_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,22 +150,22 @@ TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {

// open an in-memory db
REQUIRE(db.Open(":memory:"));
REQUIRE(db.Execute("CREATE TABLE test(i INTEGER, j BIGINT, k DATE, l VARCHAR)"));
REQUIRE(db.Execute("CREATE TABLE test(i INTEGER, j BIGINT, k DATE, l VARCHAR, b BLOB)"));
#ifndef SQLITE_TEST
// sqlite3_prepare_v2 errors
// nullptr for db/stmt, note: normal sqlite segfaults here
REQUIRE(sqlite3_prepare_v2(nullptr, "INSERT INTO test VALUES ($1, $2, $3, $4)", -1, nullptr, nullptr) ==
REQUIRE(sqlite3_prepare_v2(nullptr, "INSERT INTO test VALUES ($1, $2, $3, $4, $5)", -1, nullptr, nullptr) ==
SQLITE_MISUSE);
REQUIRE(sqlite3_prepare_v2(db.db, "INSERT INTO test VALUES ($1, $2, $3, $4)", -1, nullptr, nullptr) ==
REQUIRE(sqlite3_prepare_v2(db.db, "INSERT INTO test VALUES ($1, $2, $3, $4, $5)", -1, nullptr, nullptr) ==
SQLITE_MISUSE);
#endif
// prepared statement
REQUIRE(stmt.Prepare(db.db, "INSERT INTO test VALUES ($1, $2, $3, $4)", -1, nullptr) == SQLITE_OK);
REQUIRE(stmt.Prepare(db.db, "INSERT INTO test VALUES ($1, $2, $3, $4, $5)", -1, nullptr) == SQLITE_OK);

// test for parameter count, names and indexes
REQUIRE(sqlite3_bind_parameter_count(nullptr) == 0);
REQUIRE(sqlite3_bind_parameter_count(stmt.stmt) == 4);
for (int i = 1; i < 5; i++) {
REQUIRE(sqlite3_bind_parameter_count(stmt.stmt) == 5);
for (int i = 1; i < 6; i++) {
REQUIRE(sqlite3_bind_parameter_name(nullptr, i) == nullptr);
REQUIRE(sqlite3_bind_parameter_index(nullptr, nullptr) == 0);
REQUIRE(sqlite3_bind_parameter_index(stmt.stmt, nullptr) == 0);
Expand All @@ -174,7 +174,7 @@ TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {
REQUIRE(sqlite3_bind_parameter_index(stmt.stmt, sqlite3_bind_parameter_name(stmt.stmt, i)) == i);
}
REQUIRE(sqlite3_bind_parameter_name(stmt.stmt, 0) == nullptr);
REQUIRE(sqlite3_bind_parameter_name(stmt.stmt, 5) == nullptr);
REQUIRE(sqlite3_bind_parameter_name(stmt.stmt, 6) == nullptr);

#ifndef SQLITE_TEST
// this segfaults in SQLITE
Expand All @@ -186,7 +186,7 @@ TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {
// incorrect bindings: nullptr as statement, wrong type and out of range binding
REQUIRE(sqlite3_bind_int(nullptr, 1, 1) == SQLITE_MISUSE);
REQUIRE(sqlite3_bind_int(stmt.stmt, 0, 1) == SQLITE_RANGE);
REQUIRE(sqlite3_bind_int(stmt.stmt, 5, 1) == SQLITE_RANGE);
REQUIRE(sqlite3_bind_int(stmt.stmt, 6, 1) == SQLITE_RANGE);

// we can bind the incorrect type just fine
// error will only be thrown on execution
Expand All @@ -196,7 +196,18 @@ TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {
REQUIRE(sqlite3_bind_int(stmt.stmt, 1, 2) == SQLITE_OK);
REQUIRE(sqlite3_bind_int64(stmt.stmt, 2, 1000) == SQLITE_OK);
REQUIRE(sqlite3_bind_text(stmt.stmt, 3, "1992-01-01", -1, nullptr) == SQLITE_OK);
REQUIRE(sqlite3_bind_text(stmt.stmt, 4, nullptr, -1, &free) == SQLITE_MISUSE);
char *buffer = (char *)malloc(12);
strcpy(buffer, "hello world");
REQUIRE(sqlite3_bind_text(stmt.stmt, 4, buffer, -1, &free) == SQLITE_OK);
REQUIRE(sqlite3_bind_text(stmt.stmt, 4, "hello world", -1, nullptr) == SQLITE_OK);
// test for bind blob
REQUIRE(sqlite3_bind_blob(stmt.stmt, 5, "hello world", -1, nullptr) == SQLITE_OK);
REQUIRE(sqlite3_bind_blob(stmt.stmt, 5, "hello world", 11, nullptr) == SQLITE_OK);
REQUIRE(sqlite3_bind_blob(stmt.stmt, 5, NULL, 10, &free) == SQLITE_MISUSE);
buffer = (char *)malloc(6);
strcpy(buffer, "hello");
REQUIRE(sqlite3_bind_blob(stmt.stmt, 5, buffer, 5, &free) == SQLITE_OK);

REQUIRE(sqlite3_step(nullptr) == SQLITE_MISUSE);
REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
Expand All @@ -211,6 +222,7 @@ TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {
REQUIRE(sqlite3_bind_null(stmt.stmt, 2) == SQLITE_OK);
REQUIRE(sqlite3_bind_null(stmt.stmt, 3) == SQLITE_OK);
REQUIRE(sqlite3_bind_null(stmt.stmt, 4) == SQLITE_OK);
REQUIRE(sqlite3_bind_null(stmt.stmt, 5) == SQLITE_OK);

// we can step multiple times
REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
Expand All @@ -227,6 +239,7 @@ TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {
REQUIRE(db.CheckColumn(1, {"", "", "", "", "1000"}));
REQUIRE(db.CheckColumn(2, {"", "", "", "", "1992-01-01"}));
REQUIRE(db.CheckColumn(3, {"", "", "", "", "hello world"}));
REQUIRE(db.CheckColumn(4, {"", "", "", "", "hello"}));

REQUIRE(sqlite3_finalize(nullptr) == SQLITE_OK);

Expand Down Expand Up @@ -264,6 +277,7 @@ TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {
REQUIRE(string((char *)sqlite3_column_text(stmt.stmt, 1)) == string("1000"));
REQUIRE(string((char *)sqlite3_column_text(stmt.stmt, 2)) == string("1992-01-01"));
REQUIRE(string((char *)sqlite3_column_text(stmt.stmt, 3)) == string("hello world"));
REQUIRE(string((char *)sqlite3_column_blob(stmt.stmt, 4)) == string("hello"));
REQUIRE(sqlite3_column_int(stmt.stmt, 3) == 0);
REQUIRE(sqlite3_column_int64(stmt.stmt, 3) == 0);
REQUIRE(sqlite3_column_double(stmt.stmt, 3) == 0);
Expand Down