diff --git a/tests/common/test_postgresql_asyncpool.nim b/tests/common/test_postgresql_asyncpool.nim new file mode 100644 index 000000000..42a3af047 --- /dev/null +++ b/tests/common/test_postgresql_asyncpool.nim @@ -0,0 +1,16 @@ +{.used.} + +import + std/[strutils, os], + stew/results, + testutils/unittests, + chronos +import + ../../waku/common/postgres/asyncpool, + ../../waku/common/postgres/pg_asyncpool_opts + +suite "Async pool": + + asyncTest "Create connection pool": + ## TODO: extend unit tests + var pgOpts = PgAsyncPoolOptions.init() diff --git a/tests/v2/waku_archive/test_driver_postgres.nim b/tests/v2/waku_archive/test_driver_postgres.nim index c75549e44..1994af6b8 100644 --- a/tests/v2/waku_archive/test_driver_postgres.nim +++ b/tests/v2/waku_archive/test_driver_postgres.nim @@ -27,76 +27,61 @@ suite "Postgres driver": const storeMessageDbUrl = "postgres://postgres:test123@localhost:5432/postgres" asyncTest "Asynchronous queries": - #TODO: make the test asynchronous - return + let driverRes = PostgresDriver.new(dbUrl = storeMessageDbUrl, + maxConnections = 100) - ## When - let driverRes = PostgresDriver.new(storeMessageDbUrl) + assert driverRes.isOk(), driverRes.error - ## Then - require: - driverRes.isOk() + let driver = driverRes.value + discard await driver.reset() - let driver: ArchiveDriver = driverRes.tryGet() - require: - not driver.isNil() + var futures = newSeq[Future[ArchiveDriverResult[void]]](0) let beforeSleep = now() - for _ in 1 .. 20: - discard (PostgresDriver driver).sleep(1) + for _ in 1 .. 100: + futures.add(driver.sleep(1)) - require (now() - beforeSleep) < 20 + await allFutures(futures) + + let diff = now() - beforeSleep + # Actually, the diff randomly goes between 1 and 2 seconds. + # although in theory it should spend 1s because we establish 100 + # connections and we spawn 100 tasks that spend ~1s each. + require diff < 20 - ## Cleanup (await driver.close()).expect("driver to close") - asyncTest "init driver and database": - - ## When + asyncTest "Init database": let driverRes = PostgresDriver.new(storeMessageDbUrl) + assert driverRes.isOk(), driverRes.error - ## Then - require: - driverRes.isOk() + let driver = driverRes.value + discard await driver.reset() - let driver: ArchiveDriver = driverRes.tryGet() - require: - not driver.isNil() + let initRes = await driver.init() + assert initRes.isOk(), initRes.error - discard driverRes.get().reset() - let initRes = driverRes.get().init() - - require: - initRes.isOk() - - ## Cleanup (await driver.close()).expect("driver to close") - asyncTest "insert a message": - ## Given + asyncTest "Insert a message": const contentTopic = "test-content-topic" let driverRes = PostgresDriver.new(storeMessageDbUrl) + assert driverRes.isOk(), driverRes.error - require: - driverRes.isOk() + let driver = driverRes.get() - discard driverRes.get().reset() - discard driverRes.get().init() + discard await driver.reset() - let driver: ArchiveDriver = driverRes.tryGet() - require: - not driver.isNil() + let initRes = await driver.init() + assert initRes.isOk(), initRes.error let msg = fakeWakuMessage(contentTopic=contentTopic) let computedDigest = computeDigest(msg) - ## When - let putRes = await driver.put(DefaultPubsubTopic, msg, computedDigest, msg.timestamp) - ## Then - require: - putRes.isOk() + let putRes = await driver.put(DefaultPubsubTopic, msg, computedDigest, msg.timestamp) + assert putRes.isOk(), putRes.error let storedMsg = (await driver.getAllMessages()).tryGet() require: @@ -108,80 +93,61 @@ suite "Postgres driver": toHex(computedDigest.data) == toHex(digest) and toHex(actualMsg.payload) == toHex(msg.payload) - ## Cleanup (await driver.close()).expect("driver to close") - asyncTest "insert and query message": - ## Given + asyncTest "Insert and query message": const contentTopic1 = "test-content-topic-1" const contentTopic2 = "test-content-topic-2" const pubsubTopic1 = "pubsubtopic-1" const pubsubTopic2 = "pubsubtopic-2" let driverRes = PostgresDriver.new(storeMessageDbUrl) + assert driverRes.isOk(), driverRes.error - require: - driverRes.isOk() + let driver = driverRes.value - discard driverRes.get().reset() - discard driverRes.get().init() + discard await driver.reset() - let driver: ArchiveDriver = driverRes.tryGet() - require: - not driver.isNil() + let initRes = await driver.init() + assert initRes.isOk(), initRes.error let msg1 = fakeWakuMessage(contentTopic=contentTopic1) - ## When var putRes = await driver.put(pubsubTopic1, msg1, computeDigest(msg1), msg1.timestamp) - - ## Then - require: - putRes.isOk() + assert putRes.isOk(), putRes.error let msg2 = fakeWakuMessage(contentTopic=contentTopic2) - ## When putRes = await driver.put(pubsubTopic2, msg2, computeDigest(msg2), msg2.timestamp) - - ## Then - require: - putRes.isOk() + assert putRes.isOk(), putRes.error let countMessagesRes = await driver.getMessagesCount() - require: - countMessagesRes.isOk() and - countMessagesRes.get() == 2 + require countMessagesRes.isOk() and countMessagesRes.get() == 2 var messagesRes = await driver.getMessages(contentTopic = @[contentTopic1]) - require: - messagesRes.isOk() - - require: - messagesRes.get().len == 1 + require messagesRes.isOk() + require messagesRes.get().len == 1 # Get both content topics, check ordering messagesRes = await driver.getMessages(contentTopic = @[contentTopic1, contentTopic2]) - require: - messagesRes.isOk() + assert messagesRes.isOk(), messagesRes.error require: messagesRes.get().len == 2 and - messagesRes.get()[0][1].WakuMessage.contentTopic == contentTopic1 + messagesRes.get()[0][1].contentTopic == contentTopic1 # Descending order messagesRes = await driver.getMessages(contentTopic = @[contentTopic1, contentTopic2], ascendingOrder = false) - require: - messagesRes.isOk() + assert messagesRes.isOk(), messagesRes.error require: messagesRes.get().len == 2 and - messagesRes.get()[0][1].WakuMessage.contentTopic == contentTopic2 + messagesRes.get()[0][1].contentTopic == contentTopic2 # cursor # Get both content topics @@ -191,50 +157,39 @@ suite "Postgres driver": cursor = some( computeTestCursor(pubsubTopic1, messagesRes.get()[0][1]))) - require: - messagesRes.isOk() - - require: - messagesRes.get().len == 1 + require messagesRes.isOk() + require messagesRes.get().len == 1 # Get both content topics but one pubsub topic messagesRes = await driver.getMessages(contentTopic = @[contentTopic1, contentTopic2], pubsubTopic = some(pubsubTopic1)) - require: - messagesRes.isOk() + assert messagesRes.isOk(), messagesRes.error require: messagesRes.get().len == 1 and - messagesRes.get()[0][1].WakuMessage.contentTopic == contentTopic1 + messagesRes.get()[0][1].contentTopic == contentTopic1 # Limit messagesRes = await driver.getMessages(contentTopic = @[contentTopic1, contentTopic2], maxPageSize = 1) - require: - messagesRes.isOk() + assert messagesRes.isOk(), messagesRes.error + require messagesRes.get().len == 1 - require: - messagesRes.get().len == 1 - - ## Cleanup (await driver.close()).expect("driver to close") - asyncTest "insert true duplicated messages": + asyncTest "Insert true duplicated messages": # Validates that two completely equal messages can not be stored. - ## Given let driverRes = PostgresDriver.new(storeMessageDbUrl) + assert driverRes.isOk(), driverRes.error - require: - driverRes.isOk() + let driver = driverRes.value - discard driverRes.get().reset() - discard driverRes.get().init() + discard await driver.reset() - let driver: ArchiveDriver = driverRes.tryGet() - require: - not driver.isNil() + let initRes = await driver.init() + assert initRes.isOk(), initRes.error let now = now() @@ -243,14 +198,8 @@ suite "Postgres driver": var putRes = await driver.put(DefaultPubsubTopic, msg1, computeDigest(msg1), msg1.timestamp) - ## Then - require: - putRes.isOk() + assert putRes.isOk(), putRes.error putRes = await driver.put(DefaultPubsubTopic, msg2, computeDigest(msg2), msg2.timestamp) - ## Then - require: - not putRes.isOk() - - + require not putRes.isOk() diff --git a/waku/v2/waku_archive/driver/postgres_driver/asyncpool.nim b/waku/v2/waku_archive/driver/postgres_driver/asyncpool.nim new file mode 100644 index 000000000..2d7866f31 --- /dev/null +++ b/waku/v2/waku_archive/driver/postgres_driver/asyncpool.nim @@ -0,0 +1,208 @@ +# Simple async pool driver for postgress. +# Inspired by: https://github.com/treeform/pg/ +when (NimMajor, NimMinor) < (1, 4): + {.push raises: [Defect].} +else: + {.push raises: [].} + +import + std/sequtils, + stew/results, + chronicles, + chronos +import + ../../driver, + ./connection + +logScope: + topics = "postgres asyncpool" + +type PgAsyncPoolState {.pure.} = enum + Closed, + Live, + Closing + +type + PgDbConn = object + dbConn: DbConn + busy: bool + open: bool + insertStmt: SqlPrepared + +type + # Database connection pool + PgAsyncPool* = ref object + connString: string + maxConnections: int + + state: PgAsyncPoolState + conns: seq[PgDbConn] + +proc new*(T: type PgAsyncPool, + connString: string, + maxConnections: int): T = + + let pool = PgAsyncPool( + connString: connString, + maxConnections: maxConnections, + state: PgAsyncPoolState.Live, + conns: newSeq[PgDbConn](0) + ) + + return pool + +func isLive(pool: PgAsyncPool): bool = + pool.state == PgAsyncPoolState.Live + +func isBusy(pool: PgAsyncPool): bool = + pool.conns.mapIt(it.busy).allIt(it) + +proc close*(pool: PgAsyncPool): + Future[Result[void, string]] {.async.} = + ## Gracefully wait and close all openned connections + + if pool.state == PgAsyncPoolState.Closing: + while pool.state == PgAsyncPoolState.Closing: + await sleepAsync(0.milliseconds) # Do not block the async runtime + return ok() + + pool.state = PgAsyncPoolState.Closing + + # wait for the connections to be released and close them, without + # blocking the async runtime + if pool.conns.anyIt(it.busy): + while pool.conns.anyIt(it.busy): + await sleepAsync(0.milliseconds) + + for i in 0.. 0: + return err($message) + + return ok() + +proc open*(connString: string): + Result[DbConn, string] = + ## Opens a new connection. + var conn: DbConn = nil + try: + conn = open("","", "", connString) + except DbError: + return err("exception opening new connection: " & + getCurrentExceptionMsg()) + + if conn.status != CONNECTION_OK: + let checkRes = conn.check() + if checkRes.isErr(): + return err("failed to connect to database: " & checkRes.error) + + return err("unknown reason") + + ok(conn) + +proc rows*(db: DbConn, + query: SqlQuery, + args: seq[string]): + Future[Result[seq[Row], string]] {.async.} = + ## Runs the SQL getting results. + + if db.status != CONNECTION_OK: + let checkRes = db.check() + if checkRes.isErr(): + return err("failed to connect to database: " & checkRes.error) + + return err("unknown reason") + + var wellFormedQuery = "" + try: + wellFormedQuery = dbFormat(query, args) + except DbError: + return err("exception formatting the query: " & + getCurrentExceptionMsg()) + + let success = db.pqsendQuery(cstring(wellFormedQuery)) + if success != 1: + let checkRes = db.check() + if checkRes.isErr(): + return err("failed pqsendQuery: " & checkRes.error) + + return err("failed pqsendQuery: unknown reason") + + var ret = newSeq[Row](0) + + while true: + + let success = db.pqconsumeInput() + if success != 1: + let checkRes = db.check() + if checkRes.isErr(): + return err("failed pqconsumeInput: " & checkRes.error) + + return err("failed pqconsumeInput: unknown reason") + + if db.pqisBusy() == 1: + await sleepAsync(0.milliseconds) # Do not block the async runtime + continue + + var pqResult = db.pqgetResult() + if pqResult == nil: + # Check if its a real error or just end of results + let checkRes = db.check() + if checkRes.isErr(): + return err("error in rows: " & checkRes.error) + + return ok(ret) # reached the end of the results + + var cols = pqResult.pqnfields() + var row = cols.newRow() + for i in 0'i32 .. pqResult.pqNtuples() - 1: + pqResult.setRow(row, i, cols) # puts the value in the row + ret.add(row) + + pqclear(pqResult) diff --git a/waku/v2/waku_archive/driver/postgres_driver/postgres_driver.nim b/waku/v2/waku_archive/driver/postgres_driver/postgres_driver.nim index 71e157527..47786faf6 100644 --- a/waku/v2/waku_archive/driver/postgres_driver/postgres_driver.nim +++ b/waku/v2/waku_archive/driver/postgres_driver/postgres_driver.nim @@ -4,24 +4,23 @@ else: {.push raises: [].} import - std/db_postgres, std/strformat, std/nre, std/options, std/strutils, stew/[results,byteutils], + db_postgres, chronos - import ../../../waku_core, ../../common, - ../../driver + ../../driver, + asyncpool export postgres_driver type PostgresDriver* = ref object of ArchiveDriver - connection: DbConn - preparedInsert: SqlPrepared + connPool: PgAsyncPool proc dropTableQuery(): string = "DROP TABLE messages" @@ -39,85 +38,84 @@ proc createTableQuery(): string = ");" proc insertRow(): string = + # TODO: get the sql queries from a file """INSERT INTO messages (id, storedAt, contentTopic, payload, pubsubTopic, version, timestamp) VALUES ($1, $2, $3, $4, $5, $6, $7);""" -proc new*(T: type PostgresDriver, storeMessageDbUrl: string): ArchiveDriverResult[T] = - var host: string - var user: string - var password: string - var dbName: string - var port: string - var connectionString: string - var dbConn: DbConn +const DefaultMaxConnections = 5 + +proc new*(T: type PostgresDriver, + dbUrl: string, + maxConnections: int = DefaultMaxConnections): + ArchiveDriverResult[T] = + + var connPool: PgAsyncPool + try: let regex = re("""^postgres:\/\/([^:]+):([^@]+)@([^:]+):(\d+)\/(.+)$""") - let matches = find(storeMessageDbUrl,regex).get.captures - user = matches[0] - password = matches[1] - host = matches[2] - port = matches[3] - dbName = matches[4] - connectionString = "user={user} host={host} port={port} dbname={dbName} password={password}".fmt + let matches = find(dbUrl,regex).get.captures + let user = matches[0] + let password = matches[1] + let host = matches[2] + let port = matches[3] + let dbName = matches[4] + let connectionString = fmt"user={user} host={host} port={port} dbname={dbName} password={password}" + + connPool = PgAsyncPool.new(connectionString, maxConnections) + except KeyError,InvalidUnicodeError, RegexInternalError, ValueError, StudyError, SyntaxError: return err("could not parse postgres string") - try: - dbConn = open("","", "", connectionString) - except DbError: - return err("could not connect to postgres") + return ok(PostgresDriver(connPool: connPool)) - return ok(PostgresDriver(connection: dbConn)) +proc createMessageTable(s: PostgresDriver): + Future[ArchiveDriverResult[void]] {.async.} = -method reset*(s: PostgresDriver): ArchiveDriverResult[void] {.base.} = - try: - let res = s.connection.tryExec(sql(dropTableQuery())) - if not res: - return err("failed to reset database") - except DbError: - return err("failed to reset database") + let execRes = await s.connPool.exec(createTableQuery(), newSeq[string](0)) + if execRes.isErr(): + return err("error in createMessageTable: " & execRes.error) return ok() -method init*(s: PostgresDriver): ArchiveDriverResult[void] {.base.} = - try: - let res = s.connection.tryExec(sql(createTableQuery())) - if not res: - return err("failed to initialize") - s.preparedInsert = prepare(s.connection, "insertRow", sql(insertRow()), 7) - except DbError: - let - e = getCurrentException() - msg = getCurrentExceptionMsg() - exceptionMessage = "failed to init driver, got exception " & - repr(e) & " with message " & msg - return err(exceptionMessage) +proc deleteMessageTable*(s: PostgresDriver): + Future[ArchiveDriverResult[void]] {.async.} = + + let ret = await s.connPool.exec(dropTableQuery(), newSeq[string](0)) + return ret + +proc init*(s: PostgresDriver): Future[ArchiveDriverResult[void]] {.async.} = + + let createMsgRes = await s.createMessageTable() + if createMsgRes.isErr(): + return err("createMsgRes.isErr in init: " & createMsgRes.error) return ok() +proc reset*(s: PostgresDriver): Future[ArchiveDriverResult[void]] {.async.} = + + let ret = await s.deleteMessageTable() + return ret + method put*(s: PostgresDriver, pubsubTopic: PubsubTopic, message: WakuMessage, digest: MessageDigest, receivedTime: Timestamp): Future[ArchiveDriverResult[void]] {.async.} = - try: - let res = s.connection.tryExec(s.preparedInsert, - toHex(digest.data), - receivedTime, - message.contentTopic, - toHex(message.payload), - pubsubTopic, - int64(message.version), - message.timestamp) - if not res: - return err("failed to insert into database") - except DbError: - return err("failed to insert into database") - return ok() + let ret = await s.connPool.runStmt(insertRow(), + @[toHex(digest.data), + $receivedTime, + message.contentTopic, + toHex(message.payload), + pubsubTopic, + $message.version, + $message.timestamp]) + return ret + +proc toArchiveRow(r: Row): ArchiveDriverResult[ArchiveRow] = + # Converts a postgres row into an ArchiveRow -proc extractRow(r: Row): ArchiveDriverResult[ArchiveRow] = var wakuMessage: WakuMessage var timestamp: Timestamp var version: uint @@ -151,17 +149,18 @@ proc extractRow(r: Row): ArchiveDriverResult[ArchiveRow] = method getAllMessages*(s: PostgresDriver): Future[ArchiveDriverResult[seq[ArchiveRow]]] {.async.} = ## Retrieve all messages from the store. - var rows: seq[Row] - var results: seq[ArchiveRow] - try: - rows = s.connection.getAllRows(sql("""SELECT storedAt, contentTopic, - payload, pubsubTopic, version, timestamp, - id FROM messages ORDER BY storedAt ASC""")) - except DbError: - return err("failed to query rows") - for r in rows: - let rowRes = extractRow(r) + let rowsRes = await s.connPool.query("""SELECT storedAt, contentTopic, + payload, pubsubTopic, version, timestamp, + id FROM messages ORDER BY storedAt ASC""", + newSeq[string](0)) + + if rowsRes.isErr(): + return err("failed in query: " & rowsRes.error) + + var results: seq[ArchiveRow] + for r in rowsRes.value: + let rowRes = r.toArchiveRow() if rowRes.isErr(): return err("failed to extract row") @@ -221,17 +220,15 @@ method getMessages*(s: PostgresDriver, query &= " LIMIT ?" args.add($maxPageSize) - var rows: seq[Row] - var results: seq[ArchiveRow] - try: - rows = s.connection.getAllRows(sql(query), args) - except DbError: - return err("failed to query rows") + let rowsRes = await s.connPool.query(query, args) + if rowsRes.isErr(): + return err("failed to run query: " & rowsRes.error) - for r in rows: - let rowRes = extractRow(r) + var results: seq[ArchiveRow] + for r in rowsRes.value: + let rowRes = r.toArchiveRow() if rowRes.isErr(): - return err("failed to extract row") + return err("failed to extract row: " & rowRes.error) results.add(rowRes.get()) @@ -239,16 +236,20 @@ method getMessages*(s: PostgresDriver, method getMessagesCount*(s: PostgresDriver): Future[ArchiveDriverResult[int64]] {.async.} = - var count: int64 - try: - let row = s.connection.getRow(sql("""SELECT COUNT(1) FROM messages""")) - count = parseInt(row[0]) - except DbError: - return err("failed to query count") - except ValueError: - return err("failed to parse query count result") + let rowsRes = await s.connPool.query("SELECT COUNT(1) FROM messages") + if rowsRes.isErr(): + return err("failed to get messages count: " & rowsRes.error) + let rows = rowsRes.get() + if rows.len == 0: + return err("failed to get messages count: rows.len == 0") + + let rowFields = rows[0] + if rowFields.len == 0: + return err("failed to get messages count: rowFields.len == 0") + + let count = parseInt(rowFields[0]) return ok(count) method getOldestMessageTimestamp*(s: PostgresDriver): @@ -272,8 +273,8 @@ method deleteOldestMessagesNotWithinLimit*(s: PostgresDriver, method close*(s: PostgresDriver): Future[ArchiveDriverResult[void]] {.async.} = ## Close the database connection - s.connection.close() - return ok() + let result = await s.connPool.close() + return result proc sleep*(s: PostgresDriver, seconds: int): Future[ArchiveDriverResult[void]] {.async.} = @@ -282,9 +283,11 @@ proc sleep*(s: PostgresDriver, seconds: int): # database for the amount of seconds given as a parameter. try: let params = @[$seconds] - s.connection.exec(sql"SELECT pg_sleep(?)", params) + let sleepRes = await s.connPool.query("SELECT pg_sleep(?)", params) + if sleepRes.isErr(): + return err("error in postgres_driver sleep: " & sleepRes.error) except DbError: # This always raises an exception although the sleep works return err("exception sleeping: " & getCurrentExceptionMsg()) - return ok() \ No newline at end of file + return ok()