commit d0d0bdeb4ff3136e7fc1136a4dc7fb0ff1bb2f78
parent 2c983e338c9a53167dbf3c63f9125e7b188ea599
Author: Joris Vink <joris@coders.se>
Date: Thu, 25 Apr 2019 23:13:13 +0200
Improve pgsql support.
- Add kore_pgsql_query_param_fields() which allows you to pass in the
arrays for values, lengths and formats yourself.
- Add kore_pgsql_column_binary() which will return 1 if the given column
index contains a binary result or 0 if it contains a text result.
- Change the query call in req.pgsql() for Python to always use the
parameterized queries.
This adds the 'params' and 'binary' keywords to the req.pgsql method.
Eg:
result = await req.pgsql("db", "INSERT INTO foo (field) VALUES($1"),
params=["this is my value"])
Diffstat:
4 files changed, 182 insertions(+), 38 deletions(-)
diff --git a/include/kore/pgsql.h b/include/kore/pgsql.h
@@ -86,6 +86,8 @@ int kore_pgsql_query_params(struct kore_pgsql *,
const char *, int, int, ...);
int kore_pgsql_v_query_params(struct kore_pgsql *,
const char *, int, int, va_list);
+int kore_pgsql_query_param_fields(struct kore_pgsql *, const char *,
+ int, int, const char **, int *, int *);
int kore_pgsql_register(const char *, const char *);
int kore_pgsql_ntuples(struct kore_pgsql *);
int kore_pgsql_nfields(struct kore_pgsql *);
@@ -93,6 +95,7 @@ void kore_pgsql_logerror(struct kore_pgsql *);
char *kore_pgsql_fieldname(struct kore_pgsql *, int);
char *kore_pgsql_getvalue(struct kore_pgsql *, int, int);
int kore_pgsql_getlength(struct kore_pgsql *, int, int);
+int kore_pgsql_column_binary(struct kore_pgsql *, int);
#if defined(__cplusplus)
}
diff --git a/include/kore/python_methods.h b/include/kore/python_methods.h
@@ -556,7 +556,7 @@ static void pyhttp_dealloc(struct pyhttp_request *);
static void pyhttp_file_dealloc(struct pyhttp_file *);
#if defined(KORE_USE_PGSQL)
-static PyObject *pyhttp_pgsql(struct pyhttp_request *, PyObject *);
+static PyObject *pyhttp_pgsql(struct pyhttp_request *, PyObject *, PyObject *);
#endif
static PyObject *pyhttp_cookie(struct pyhttp_request *, PyObject *);
static PyObject *pyhttp_response(struct pyhttp_request *, PyObject *);
@@ -574,7 +574,7 @@ static PyObject *pyhttp_websocket_handshake(struct pyhttp_request *,
static PyMethodDef pyhttp_request_methods[] = {
#if defined(KORE_USE_PGSQL)
- METHOD("pgsql", pyhttp_pgsql, METH_VARARGS),
+ METHOD("pgsql", pyhttp_pgsql, METH_VARARGS | METH_KEYWORDS),
#endif
METHOD("cookie", pyhttp_cookie, METH_VARARGS),
METHOD("response", pyhttp_response, METH_VARARGS),
@@ -739,15 +739,25 @@ static PyTypeObject pyhttp_client_op_type = {
struct pykore_pgsql {
PyObject_HEAD
int state;
+ int binary;
+ struct kore_pgsql sql;
+
char *db;
- char *query;
struct http_request *req;
+ char *query;
PyObject *result;
- struct kore_pgsql sql;
+ struct {
+ int count;
+ const char **values;
+ int *lengths;
+ int *formats;
+ PyObject **objs;
+ } param;
};
static void pykore_pgsql_dealloc(struct pykore_pgsql *);
-int pykore_pgsql_result(struct pykore_pgsql *);
+static int pykore_pgsql_result(struct pykore_pgsql *);
+static int pykore_pgsql_params(struct pykore_pgsql *, PyObject *);
static PyObject *pykore_pgsql_await(PyObject *);
static PyObject *pykore_pgsql_iternext(struct pykore_pgsql *);
diff --git a/src/pgsql.c b/src/pgsql.c
@@ -202,10 +202,10 @@ kore_pgsql_query(struct kore_pgsql *pgsql, const char *query)
int
kore_pgsql_v_query_params(struct kore_pgsql *pgsql,
- const char *query, int result, int count, va_list args)
+ const char *query, int binary, int count, va_list args)
{
- u_int8_t i;
- char **values;
+ int i;
+ const char **values;
int *lengths, *formats, ret;
if (pgsql->conn == NULL) {
@@ -229,49 +229,59 @@ kore_pgsql_v_query_params(struct kore_pgsql *pgsql,
values = NULL;
}
- ret = KORE_RESULT_ERROR;
+ ret = kore_pgsql_query_param_fields(pgsql, query, binary, count,
+ values, lengths, formats);
+
+ kore_free(values);
+ kore_free(lengths);
+ kore_free(formats);
+
+ return (ret);
+}
+
+int
+kore_pgsql_query_param_fields(struct kore_pgsql *pgsql, const char *query,
+ int binary, int count, const char **values, int *lengths, int *formats)
+{
+ if (pgsql->conn == NULL) {
+ pgsql_set_error(pgsql, "no connection was set before query");
+ return (KORE_RESULT_ERROR);
+ }
if (pgsql->flags & KORE_PGSQL_SYNC) {
pgsql->result = PQexecParams(pgsql->conn->db, query, count,
NULL, (const char * const *)values, lengths, formats,
- result);
+ binary);
if ((PQresultStatus(pgsql->result) != PGRES_TUPLES_OK) &&
(PQresultStatus(pgsql->result) != PGRES_COMMAND_OK)) {
pgsql_set_error(pgsql, PQerrorMessage(pgsql->conn->db));
- goto cleanup;
+ return (KORE_RESULT_ERROR);
}
pgsql->state = KORE_PGSQL_STATE_DONE;
} else {
if (!PQsendQueryParams(pgsql->conn->db, query, count, NULL,
- (const char * const *)values, lengths, formats, result)) {
+ (const char * const *)values, lengths, formats, binary)) {
pgsql_set_error(pgsql, PQerrorMessage(pgsql->conn->db));
- goto cleanup;
+ return (KORE_RESULT_ERROR);
}
pgsql_schedule(pgsql);
}
- ret = KORE_RESULT_OK;
-
-cleanup:
- kore_free(values);
- kore_free(lengths);
- kore_free(formats);
-
- return (ret);
+ return (KORE_RESULT_OK);
}
int
kore_pgsql_query_params(struct kore_pgsql *pgsql,
- const char *query, int result, int count, ...)
+ const char *query, int binary, int count, ...)
{
int ret;
va_list args;
va_start(args, count);
- ret = kore_pgsql_v_query_params(pgsql, query, result, count, args);
+ ret = kore_pgsql_v_query_params(pgsql, query, binary, count, args);
va_end(args);
return (ret);
@@ -432,6 +442,12 @@ kore_pgsql_getvalue(struct kore_pgsql *pgsql, int row, int col)
return (PQgetvalue(pgsql->result, row, col));
}
+int
+kore_pgsql_column_binary(struct kore_pgsql *pgsql, int col)
+{
+ return (PQfformat(pgsql->result, col));
+}
+
static struct pgsql_conn *
pgsql_conn_next(struct kore_pgsql *pgsql, struct pgsql_db *db)
{
diff --git a/src/python.c b/src/python.c
@@ -78,7 +78,8 @@ static void pygather_reap_coro(struct pygather_op *,
#if defined(KORE_USE_PGSQL)
static PyObject *pykore_pgsql_alloc(struct http_request *,
- const char *, const char *);
+ const char *, const char *, PyObject *);
+static int pykore_pgsql_params(struct pykore_pgsql *, PyObject *);
#endif
#if defined(KORE_USE_CURL)
@@ -3572,6 +3573,8 @@ pyhttp_file_get_filename(struct pyhttp_file *pyfile, void *closure)
static void
pykore_pgsql_dealloc(struct pykore_pgsql *pysql)
{
+ Py_ssize_t i;
+
kore_free(pysql->db);
kore_free(pysql->query);
kore_pgsql_cleanup(&pysql->sql);
@@ -3579,12 +3582,22 @@ pykore_pgsql_dealloc(struct pykore_pgsql *pysql)
if (pysql->result != NULL)
Py_DECREF(pysql->result);
+ for (i = 0; i < pysql->param.count; i++)
+ Py_DECREF(pysql->param.objs[i]);
+
+ kore_free(pysql->param.objs);
+ kore_free(pysql->param.values);
+ kore_free(pysql->param.lengths);
+ kore_free(pysql->param.formats);
+
PyObject_Del((PyObject *)pysql);
}
static PyObject *
-pykore_pgsql_alloc(struct http_request *req, const char *db, const char *query)
+pykore_pgsql_alloc(struct http_request *req, const char *db, const char *query,
+ PyObject *kwargs)
{
+ PyObject *obj;
struct pykore_pgsql *pysql;
pysql = PyObject_New(struct pykore_pgsql, &pykore_pgsql_type);
@@ -3597,11 +3610,105 @@ pykore_pgsql_alloc(struct http_request *req, const char *db, const char *query)
pysql->query = kore_strdup(query);
pysql->state = PYKORE_PGSQL_PREINIT;
+ pysql->binary = 0;
+ pysql->param.objs = 0;
+ pysql->param.count = 0;
+ pysql->param.values = NULL;
+ pysql->param.lengths = NULL;
+ pysql->param.formats = NULL;
+
memset(&pysql->sql, 0, sizeof(pysql->sql));
+ if (kwargs != NULL) {
+ if ((obj = PyDict_GetItemString(kwargs, "params")) != NULL) {
+ if (!pykore_pgsql_params(pysql, obj)) {
+ Py_DECREF((PyObject *)pysql);
+ return (NULL);
+ }
+ }
+
+ if ((obj = PyDict_GetItemString(kwargs, "binary")) != NULL) {
+ if (obj == Py_True) {
+ pysql->binary = 1;
+ } else if (obj == Py_False) {
+ pysql->binary = 0;
+ } else {
+ Py_DECREF((PyObject *)pysql);
+ PyErr_SetString(PyExc_RuntimeError,
+ "pgsql: binary not True or False");
+ return (NULL);
+ }
+ }
+ }
+
return ((PyObject *)pysql);
}
+static int
+pykore_pgsql_params(struct pykore_pgsql *op, PyObject *list)
+{
+ union { const char *cp; char *p; } ptr;
+ PyObject *item;
+ int format;
+ Py_ssize_t i, len, vlen;
+
+ if (!PyList_CheckExact(list)) {
+ if (list == Py_None)
+ return (KORE_RESULT_OK);
+
+ PyErr_SetString(PyExc_RuntimeError,
+ "pgsql: params keyword must be a list");
+ return (KORE_RESULT_ERROR);
+ }
+
+ len = PyList_Size(list);
+ if (len == 0)
+ return (KORE_RESULT_OK);
+
+ if (len > INT_MAX) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "pgsql: list length too large");
+ return (KORE_RESULT_ERROR);
+ }
+
+ op->param.count = len;
+ op->param.lengths = kore_calloc(len, sizeof(int));
+ op->param.formats = kore_calloc(len, sizeof(int));
+ op->param.values = kore_calloc(len, sizeof(char *));
+ op->param.objs = kore_calloc(len, sizeof(PyObject *));
+
+ for (i = 0; i < len; i++) {
+ if ((item = PyList_GetItem(list, i)) == NULL)
+ return (KORE_RESULT_ERROR);
+
+ if (PyUnicode_CheckExact(item)) {
+ format = 0;
+ ptr.cp = PyUnicode_AsUTF8AndSize(item, &vlen);
+ } else if (PyBytes_CheckExact(item)) {
+ format = 1;
+ if (PyBytes_AsStringAndSize(item, &ptr.p, &vlen) == -1)
+ ptr.p = NULL;
+ } else {
+ PyErr_Format(PyExc_RuntimeError,
+ "pgsql: item %zu is not a string or bytes", i);
+ return (KORE_RESULT_ERROR);
+ }
+
+ if (ptr.cp == NULL)
+ return (KORE_RESULT_ERROR);
+
+ op->param.lengths[i] = vlen;
+ op->param.values[i] = ptr.cp;
+ op->param.formats[i] = format;
+
+ /* Hold on to it since we are directly referencing its data. */
+ op->param.objs[i] = item;
+ Py_INCREF(item);
+ }
+
+ return (KORE_RESULT_OK);
+}
+
static PyObject *
pykore_pgsql_iternext(struct pykore_pgsql *pysql)
{
@@ -3616,15 +3723,18 @@ pykore_pgsql_iternext(struct pykore_pgsql *pysql)
KORE_PGSQL_ASYNC)) {
if (pysql->sql.state == KORE_PGSQL_STATE_INIT)
break;
- kore_pgsql_logerror(&pysql->sql);
- PyErr_SetString(PyExc_RuntimeError, "pgsql error");
+ PyErr_Format(PyExc_RuntimeError, "pgsql error: %s",
+ pysql->sql.error);
return (NULL);
}
/* fallthrough */
case PYKORE_PGSQL_QUERY:
- if (!kore_pgsql_query(&pysql->sql, pysql->query)) {
- kore_pgsql_logerror(&pysql->sql);
- PyErr_SetString(PyExc_RuntimeError, "pgsql error");
+ if (!kore_pgsql_query_param_fields(&pysql->sql,
+ pysql->query, pysql->binary,
+ pysql->param.count, pysql->param.values,
+ pysql->param.lengths, pysql->param.formats)) {
+ PyErr_Format(PyExc_RuntimeError,
+ "pgsql error: %s", pysql->sql.error);
return (NULL);
}
pysql->state = PYKORE_PGSQL_WAIT;
@@ -3645,9 +3755,8 @@ wait_again:
}
return (NULL);
case KORE_PGSQL_STATE_ERROR:
- kore_pgsql_logerror(&pysql->sql);
- PyErr_SetString(PyExc_RuntimeError,
- "failed to perform query");
+ PyErr_Format(PyExc_RuntimeError,
+ "failed to perform query: %s", pysql->sql.error);
return (NULL);
case KORE_PGSQL_STATE_RESULT:
if (!pykore_pgsql_result(pysql))
@@ -3674,13 +3783,13 @@ pykore_pgsql_await(PyObject *obj)
return (obj);
}
-int
+static int
pykore_pgsql_result(struct pykore_pgsql *pysql)
{
const char *val;
char key[64];
PyObject *list, *pyrow, *pyval;
- int rows, row, field, fields;
+ int rows, row, field, fields, len;
if ((list = PyList_New(0)) == NULL) {
PyErr_SetNone(PyExc_MemoryError);
@@ -3699,8 +3808,14 @@ pykore_pgsql_result(struct pykore_pgsql *pysql)
for (field = 0; field < fields; field++) {
val = kore_pgsql_getvalue(&pysql->sql, row, field);
+ len = kore_pgsql_getlength(&pysql->sql, row, field);
+
+ if (kore_pgsql_column_binary(&pysql->sql, field)) {
+ pyval = PyBytes_FromStringAndSize(val, len);
+ } else {
+ pyval = PyUnicode_FromString(val);
+ }
- pyval = PyUnicode_FromString(val);
if (pyval == NULL) {
Py_DECREF(pyrow);
Py_DECREF(list);
@@ -3741,7 +3856,7 @@ pykore_pgsql_result(struct pykore_pgsql *pysql)
}
static PyObject *
-pyhttp_pgsql(struct pyhttp_request *pyreq, PyObject *args)
+pyhttp_pgsql(struct pyhttp_request *pyreq, PyObject *args, PyObject *kwargs)
{
PyObject *obj;
const char *db, *query;
@@ -3749,7 +3864,7 @@ pyhttp_pgsql(struct pyhttp_request *pyreq, PyObject *args)
if (!PyArg_ParseTuple(args, "ss", &db, &query))
return (NULL);
- if ((obj = pykore_pgsql_alloc(pyreq->req, db, query)) == NULL)
+ if ((obj = pykore_pgsql_alloc(pyreq->req, db, query, kwargs)) == NULL)
return (PyErr_NoMemory());
Py_INCREF(obj);