Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting for user functions, fixes #140 #448

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions src/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include "database.h"
#include "statement.h"

#ifndef SQLITE_DETERMINISTIC
#define SQLITE_DETERMINISTIC 0x800
#endif

using namespace node_sqlite3;

Persistent<FunctionTemplate> Database::constructor_template;
Expand All @@ -24,6 +28,7 @@ void Database::Init(Handle<Object> target) {
NODE_SET_PROTOTYPE_METHOD(t, "serialize", Serialize);
NODE_SET_PROTOTYPE_METHOD(t, "parallelize", Parallelize);
NODE_SET_PROTOTYPE_METHOD(t, "configure", Configure);
NODE_SET_PROTOTYPE_METHOD(t, "registerFunction", RegisterFunction);

NODE_SET_GETTER(t, "open", OpenGetter);

Expand Down Expand Up @@ -356,6 +361,159 @@ NAN_METHOD(Database::Configure) {
NanReturnValue(args.This());
}

NAN_METHOD(Database::RegisterFunction) {
NanScope();
Database* db = ObjectWrap::Unwrap<Database>(args.This());

REQUIRE_ARGUMENTS(2);
REQUIRE_ARGUMENT_STRING(0, functionName);
REQUIRE_ARGUMENT_FUNCTION(1, callback);

FunctionBaton *baton = new FunctionBaton(db, *functionName, callback);
sqlite3_create_function(
db->_handle,
*functionName,
-1, // arbitrary number of args
SQLITE_UTF8 | SQLITE_DETERMINISTIC,
baton,
FunctionEnqueue,
NULL,
NULL);

uv_mutex_init(&baton->mutex);
uv_cond_init(&baton->condition);
uv_async_init(uv_default_loop(), &baton->async, (uv_async_cb)Database::AsyncFunctionProcessQueue);

NanReturnValue(args.This());
}

void Database::FunctionEnqueue(sqlite3_context *context, int argc, sqlite3_value **argv) {
// the JS function can only be safely executed on the main thread
// (uv_default_loop), so setup an invocation w/ the relevant information,
// enqueue it and signal the main thread to process the invocation queue.
// sqlite3 requires the result to be set before this function returns, so
// wait for the invocation to be completed.
FunctionBaton *baton = (FunctionBaton *)sqlite3_user_data(context);
FunctionInvocation invocation = {};
invocation.context = context;
invocation.argc = argc;
invocation.argv = argv;

uv_async_send(&baton->async);
uv_mutex_lock(&baton->mutex);
baton->queue.push(&invocation);
while (!invocation.complete) {
uv_cond_wait(&baton->condition, &baton->mutex);
}
uv_mutex_unlock(&baton->mutex);
}

void Database::AsyncFunctionProcessQueue(uv_async_t *async) {
FunctionBaton *baton = (FunctionBaton *)async->data;

for (;;) {
FunctionInvocation *invocation = NULL;

uv_mutex_lock(&baton->mutex);
if (!baton->queue.empty()) {
invocation = baton->queue.front();
baton->queue.pop();
}
uv_mutex_unlock(&baton->mutex);

if (!invocation) { break; }

Database::FunctionExecute(baton, invocation);

uv_mutex_lock(&baton->mutex);
invocation->complete = true;
uv_cond_signal(&baton->condition); // allow paused thread to complete
uv_mutex_unlock(&baton->mutex);
}
}

void Database::FunctionExecute(FunctionBaton *baton, FunctionInvocation *invocation) {
NanScope();

Database *db = baton->db;
Local<Function> cb = NanNew(baton->callback);
sqlite3_context *context = invocation->context;
sqlite3_value **values = invocation->argv;
int argc = invocation->argc;

if (!cb.IsEmpty() && cb->IsFunction()) {

// build the argument list for the function call
typedef Local<Value> LocalValue;
std::vector<LocalValue> argv;
for (int i = 0; i < argc; i++) {
sqlite3_value *value = values[i];
int type = sqlite3_value_type(value);
Local<Value> arg;
switch(type) {
case SQLITE_INTEGER: {
arg = NanNew<Number>(sqlite3_value_int64(value));
} break;
case SQLITE_FLOAT: {
arg = NanNew<Number>(sqlite3_value_double(value));
} break;
case SQLITE_TEXT: {
const char* text = (const char*)sqlite3_value_text(value);
int length = sqlite3_value_bytes(value);
arg = NanNew<String>(text, length);
} break;
case SQLITE_BLOB: {
const void *blob = sqlite3_value_blob(value);
int length = sqlite3_value_bytes(value);
arg = NanNew(NanNewBufferHandle((char *)blob, length));
} break;
case SQLITE_NULL: {
arg = NanNew(NanNull());
} break;
}

argv.push_back(arg);
}

TryCatch trycatch;
Local<Value> result = cb->Call(NanObjectWrapHandle(db), argc, argv.data());

// process the result
if (trycatch.HasCaught()) {
String::Utf8Value message(trycatch.Message()->Get());
sqlite3_result_error(context, *message, message.length());
}
else if (result->IsString() || result->IsRegExp()) {
String::Utf8Value value(result->ToString());
sqlite3_result_text(context, *value, value.length(), SQLITE_TRANSIENT);
}
else if (result->IsInt32()) {
sqlite3_result_int(context, result->Int32Value());
}
else if (result->IsNumber() || result->IsDate()) {
sqlite3_result_double(context, result->NumberValue());
}
else if (result->IsBoolean()) {
sqlite3_result_int(context, result->BooleanValue());
}
else if (result->IsNull() || result->IsUndefined()) {
sqlite3_result_null(context);
}
else if (Buffer::HasInstance(result)) {
Local<Object> buffer = result->ToObject();
sqlite3_result_blob(context,
Buffer::Data(buffer),
Buffer::Length(buffer),
SQLITE_TRANSIENT);
}
else {
std::string message("invalid return type in user function");
message = message + " " + baton->name;
sqlite3_result_error(context, message.c_str(), message.length());
}
}
}

void Database::SetBusyTimeout(Baton* baton) {
assert(baton->db->open);
assert(baton->db->_handle);
Expand Down
31 changes: 31 additions & 0 deletions src/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,32 @@ class Database : public ObjectWrap {
Baton(db_, cb_), filename(filename_) {}
};

struct FunctionInvocation {
sqlite3_context *context;
sqlite3_value **argv;
int argc;
bool complete;
};

struct FunctionBaton {
Database* db;
std::string name;
Persistent<Function> callback;
uv_async_t async;
uv_mutex_t mutex;
uv_cond_t condition;
std::queue<FunctionInvocation*> queue;

FunctionBaton(Database* db_, const char* name_, Handle<Function> cb_) :
db(db_), name(name_) {
async.data = this;
NanAssignPersistent(callback, cb_);
}
virtual ~FunctionBaton() {
NanDisposePersistent(callback);
}
};

typedef void (*Work_Callback)(Baton* baton);

struct Call {
Expand Down Expand Up @@ -152,6 +178,11 @@ class Database : public ObjectWrap {

static NAN_METHOD(Configure);

static NAN_METHOD(RegisterFunction);
static void FunctionEnqueue(sqlite3_context *context, int argc, sqlite3_value **argv);
static void FunctionExecute(FunctionBaton *baton, FunctionInvocation *invocation);
static void AsyncFunctionProcessQueue(uv_async_t *async);

static void SetBusyTimeout(Baton* baton);

static void RegisterTraceCallback(Baton* baton);
Expand Down
107 changes: 107 additions & 0 deletions test/user_functions.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
var sqlite3 = require('..');
var assert = require('assert');

describe('user functions', function() {
var db;
before(function(done) { db = new sqlite3.Database(':memory:', done); });

it('should allow registration of user functions', function() {
db.registerFunction('MY_UPPERCASE', function(value) {
return value.toUpperCase();
});
db.registerFunction('MY_STRING_JOIN', function(value1, value2) {
return [value1, value2].join(' ');
});
db.registerFunction('MY_Add', function(value1, value2) {
return value1 + value2;
});
db.registerFunction('MY_REGEX', function(regex, value) {
return !!value.match(new RegExp(regex));
});
db.registerFunction('MY_REGEX_VALUE', function(regex, value) {
return /match things/i;
});
db.registerFunction('MY_ERROR', function(value) {
throw new Error('This function always throws');
});
db.registerFunction('MY_UNHANDLED_TYPE', function(value) {
return {};
});
db.registerFunction('MY_NOTHING', function(value) {

});
});

it('should process user functions with one arg', function(done) {
db.all('SELECT MY_UPPERCASE("hello") AS txt', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].txt, 'HELLO')
done();
});
});

it('should process user functions with two args', function(done) {
db.all('SELECT MY_STRING_JOIN("hello", "world") AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, 'hello world');
done();
});
});

it('should process user functions with number args', function(done) {
db.all('SELECT MY_ADD(1, 2) AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, 3);
done();
});
});

it('allows writing of a regex function', function(done) {
db.all('SELECT MY_REGEX("colou?r", "color") AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(Boolean(rows[0].val), true);
done();
});
});

it('converts returned regex instances to strings', function(done) {
db.all('SELECT MY_REGEX_VALUE() AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, '/match things/i');
done();
});
});

it('reports errors thrown in functions', function(done) {
db.all('SELECT MY_ERROR() AS val', function(err, rows) {
assert.equal(err.message, 'SQLITE_ERROR: Uncaught Error: This function always throws');
assert.equal(rows, undefined);
done();
});
});

it('reports errors for unhandled types', function(done) {
db.all('SELECT MY_UNHANDLED_TYPE() AS val', function(err, rows) {
assert.equal(err.message, 'SQLITE_ERROR: invalid return type in ' +
'user function MY_UNHANDLED_TYPE');
assert.equal(rows, undefined);
done();
});
});

it('allows no return value from functions', function(done) {
db.all('SELECT MY_NOTHING() AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, undefined);
done();
});
});

after(function(done) { db.close(done); });
});