From 31a4289cc635ab3c4aad7ec4f94fb124ca239d77 Mon Sep 17 00:00:00 2001 From: Guangcong Luo Date: Mon, 4 Aug 2025 19:51:42 -0700 Subject: [PATCH] Refactor Teams plugin to use lib/database (#11109) * Refactor Teams.save Fixes a bug where passwords were changed every time a team was updated. * Refactor Teams to use lib/database * Add unit tests (This found a bug which has also been fixed) * Test more things --------- Co-authored-by: Mia <49593536+mia-pi-git@users.noreply.github.com> --- lib/database.ts | 56 ++++++----- lib/index.ts | 1 - lib/postgres.ts | 131 ------------------------- server/chat-plugins/teams.ts | 182 ++++++++++++++++------------------- test/lib/postgres.js | 56 +++++++---- 5 files changed, 149 insertions(+), 277 deletions(-) delete mode 100644 lib/postgres.ts diff --git a/lib/database.ts b/lib/database.ts index 5f7b79e7cc..c35fb9dc7f 100644 --- a/lib/database.ts +++ b/lib/database.ts @@ -11,7 +11,8 @@ import * as pg from 'pg'; export type BasicSQLValue = string | number | null; export type SQLRow = { [k: string]: BasicSQLValue }; -export type SQLValue = BasicSQLValue | SQLStatement | PartialOrSQL | BasicSQLValue[] | undefined; +export type SQLValue = + BasicSQLValue | SQLStatement | SQLStatement[] | PartialOrSQL | BasicSQLValue[] | undefined; export function isSQL(value: any): value is SQLStatement { /** @@ -35,61 +36,66 @@ export class SQLStatement { constructor(strings: TemplateStringsArray | string[], values: SQLValue[]) { this.sql = [strings[0]]; this.values = []; - for (let i = 0; i < strings.length; i++) { - this.append(values[i], strings[i + 1]); + for (let i = 0; i < strings.length - 1; i++) { + this.append(values[i]).appendRaw(strings[i + 1]); } } - append(value: SQLValue, nextString = ''): this { + appendRaw(str: string): this { + this.sql[this.sql.length - 1] += str; + return this; + } + append(value: SQLValue): this { if (isSQL(value)) { if (!value.sql.length) return this; - const oldLength = this.sql.length; + this.appendRaw(value.sql[0]); this.sql = this.sql.concat(value.sql.slice(1)); - this.sql[oldLength - 1] += value.sql[0]; this.values = this.values.concat(value.values); - if (nextString) this.sql[this.sql.length - 1] += nextString; } else if (typeof value === 'string' || typeof value === 'number' || value === null) { this.values.push(value); - this.sql.push(nextString); + this.sql.push(''); } else if (value === undefined) { - this.sql[this.sql.length - 1] += nextString; + // do nothing } else if (Array.isArray(value)) { - if ('"`'.includes(this.sql[this.sql.length - 1].slice(-1))) { + if (!value.length || isSQL(value[0])) { + // array of SQL statements + for (const part of value) this.append(part); + } else if ('"`'.includes(this.sql[this.sql.length - 1].slice(-1))) { // "`a`, `b`" syntax const quoteChar = this.sql[this.sql.length - 1].slice(-1); for (const col of value) { - this.append(col, `${quoteChar}, ${quoteChar}`); + this.append(col).appendRaw(`${quoteChar}, ${quoteChar}`); } - this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + nextString; + this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4); } else { // "1, 2" syntax for (const val of value) { - this.append(val, `, `); + this.append(val).appendRaw(`, `); } - this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2) + nextString; + this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2); } } else if (this.sql[this.sql.length - 1].endsWith('(')) { // "(`a`, `b`) VALUES (1, 2)" syntax - this.sql[this.sql.length - 1] += `"`; + this.appendRaw(`"`); for (const col in value) { - this.append(col, `", "`); + this.append(col).appendRaw(`", "`); } this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `") VALUES (`; for (const col in value) { - this.append(value[col], `, `); + this.append(value[col]).appendRaw(`, `); } - this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2) + nextString; + this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2); } else if (this.sql[this.sql.length - 1].toUpperCase().endsWith(' SET ')) { // "`a` = 1, `b` = 2" syntax - this.sql[this.sql.length - 1] += `"`; + this.appendRaw(`"`); for (const col in value) { - this.append(col, `" = `); - this.append(value[col], `, "`); + this.append(col).appendRaw(`" = `); + this.append(value[col]).appendRaw(`, "`); } - this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -3) + nextString; + this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -3); } else { throw new Error( `Objects can only appear in (obj) or after SET; ` + - `unrecognized: ${this.sql[this.sql.length - 1]}[obj]${nextString}` + `unrecognized: ${this.sql[this.sql.length - 1]}[obj]` ); } return this; @@ -251,7 +257,7 @@ export class DatabaseTable { return (strings, ...rest) => this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(strings, rest)}`; } - updateOne(partialRow: PartialOrSQL): + updateOne(partialRow: PartialOrSQL | SQLStatement): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise> { return (s, ...r) => this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`; @@ -375,7 +381,7 @@ export class MySQLDatabase extends Database { export class PGDatabase extends Database { override type = 'pg' as const; constructor(config: pg.PoolConfig) { - super(new pg.Pool(config)); + super(config ? new pg.Pool(config) : null!); } override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] { let sql = query.sql[0]; diff --git a/lib/index.ts b/lib/index.ts index f4fe095cd8..f2bd72e64b 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -7,4 +7,3 @@ export * as Utils from './utils'; export { crashlogger } from './crashlogger'; export * as ProcessManager from './process-manager'; export { SQL } from './sql'; -export { PostgresDatabase } from './postgres'; diff --git a/lib/postgres.ts b/lib/postgres.ts deleted file mode 100644 index 3027375dfc..0000000000 --- a/lib/postgres.ts +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Library made to simplify accessing / connecting to postgres databases, - * and to cleanly handle when the pg module isn't installed. - * @author mia-pi-git - */ - -// eslint-disable-next-line @typescript-eslint/ban-ts-comment -// @ts-ignore in case module doesn't exist -import type * as PG from 'pg'; -import type { SQLStatement } from 'sql-template-strings'; -import * as Streams from './streams'; -import { FS } from './fs'; -import * as Utils from './utils'; - -interface MigrationOptions { - table: string; - migrationsFolder: string; - baseSchemaFile: string; -} - -export class PostgresDatabase { - private pool: PG.Pool; - constructor(config = PostgresDatabase.getConfig()) { - try { - this.pool = new (require('pg').Pool)(config); - } catch { - this.pool = null!; - } - } - destroy() { - return this.pool.end(); - } - async query(statement: string | SQLStatement, values?: any[]) { - if (!this.pool) { - throw new Error(`Attempting to use postgres without 'pg' installed`); - } - let result; - try { - result = await this.pool.query(statement, values); - } catch (e: any) { - // postgres won't give accurate stacks unless we do this - throw new Error(e.message); - } - return result?.rows || []; - } - static getConfig() { - let config: AnyObject = {}; - try { - config = require(FS.ROOT_PATH + '/config/config').usepostgres; - if (!config) throw new Error('Missing config for pg database'); - } catch {} - return config; - } - async transaction(callback: (conn: PG.PoolClient) => any, depth = 0): Promise { - const conn = await this.pool.connect(); - await conn.query(`BEGIN`); - let result; - try { - result = await callback(conn); - } catch (e: any) { - await conn.query(`ROLLBACK`); - // two concurrent transactions conflicted, try again - if (e.code === '40001' && depth <= 10) { - return this.transaction(callback, depth + 1); - // There is a bug in Postgres that causes some - // serialization failures to be reported as failed - // unique constraint checks. Only retrying once since - // it could be our fault (thanks chaos for this info / the first half of this comment) - } else if (e.code === '23505' && !depth) { - return this.transaction(callback, depth + 1); - } else { - throw e; - } - } - await conn.query(`COMMIT`); - return result; - } - stream(query: string) { - // eslint-disable-next-line @typescript-eslint/no-this-alias - const db = this; - return new Streams.ObjectReadStream({ - async read(this: Streams.ObjectReadStream) { - const result = await db.query(query) as T[]; - if (!result.length) return this.pushEnd(); - // getting one row at a time means some slower queries - // might help with performance - this.buf.push(...result); - }, - }); - } - async ensureMigrated(opts: MigrationOptions) { - let value; - try { - const stored = await this.query( - `SELECT value FROM db_info WHERE key = 'version' AND name = $1`, [opts.table] - ); - if (stored.length) { - value = stored[0].value || "0"; - } - } catch { - await this.query(`CREATE TABLE db_info (name TEXT NOT NULL, key TEXT NOT NULL, value TEXT NOT NULL)`); - } - if (!value) { // means nothing inserted - create row - value = "0"; - await this.query('INSERT INTO db_info (name, key, value) VALUES ($1, $2, $3)', [opts.table, 'version', value]); - } - value = Number(value); - const files = FS(opts.migrationsFolder) - .readdirSync() - .filter(f => f.endsWith('.sql')) - .map(f => Number(f.slice(1).split('.')[0])); - Utils.sortBy(files, f => f); - const curVer = files[files.length - 1] || 0; - if (curVer !== value) { - if (!value) { - try { - await this.query(`SELECT * FROM ${opts.table} LIMIT 1`); - } catch { - await this.query(FS(opts.baseSchemaFile).readSync()); - } - } - for (const n of files) { - if (n <= value) continue; - await this.query(FS(`${opts.migrationsFolder}/v${n}.sql`).readSync()); - await this.query( - `UPDATE db_info SET value = $1 WHERE key = 'version' AND name = $2`, [`${n}`, opts.table] - ); - } - } - } -} diff --git a/server/chat-plugins/teams.ts b/server/chat-plugins/teams.ts index 8257bda7c7..240d8caea8 100644 --- a/server/chat-plugins/teams.ts +++ b/server/chat-plugins/teams.ts @@ -4,7 +4,8 @@ * @author mia-pi-git */ -import { PostgresDatabase, FS, Utils } from '../../lib'; +import { SQL, PGDatabase } from '../../lib/database'; +import { FS, Utils } from '../../lib'; import * as crypto from 'crypto'; /** Maximum amount of teams a user can have stored at once. */ @@ -13,6 +14,11 @@ const MAX_TEAMS = 200; const MAX_SEARCH = 3000; const ALPHABET = '0123456789abcdefghijklmnopqrstuvwxyz'.split(''); +export const teamsDB = Config.usepostgres ? new PGDatabase(Config.usepostgres) : null!; +export const teamsTable = teamsDB?.getTable< + StoredTeam +>('teams', 'teamid'); + export interface StoredTeam { teamid: string; team: string; @@ -42,41 +48,40 @@ function refresh(context: Chat.PageContext) { } export const TeamsHandler = new class { - database = new PostgresDatabase(); - readyPromise: Promise | null = Config.usepostgres ? (async () => { + readyPromise: Promise | null = teamsDB ? (async () => { try { - await this.database.query('SELECT * FROM teams LIMIT 1'); + await teamsDB.query()`SELECT * FROM teams LIMIT 1`; } catch { - await this.database.query(FS(`databases/schemas/teams.sql`).readSync()); + await teamsDB.query(SQL(FS(`databases/schemas/teams.sql`).readSync())); } })() : null; destroy() { - void this.database.destroy(); + void teamsDB.close(); } async search(search: TeamSearch, user: User, count = 10, includePrivate = false) { - const args = []; const where = []; if (count > 500) { throw new Chat.ErrorMessage("Cannot search more than 500 teams."); } if (search.format) { - where.push(`format = $${args.length + 1}`); - args.push(toID(search.format)); + where.push(where.length ? SQL` AND ` : SQL`WHERE `); + where.push(SQL`format = ${toID(search.format)}`); } if (search.owner) { - where.push(`ownerid = $${args.length + 1}`); - args.push(toID(search.owner)); + where.push(where.length ? SQL` AND ` : SQL`WHERE `); + where.push(SQL`ownerid = ${toID(search.owner)}`); } if (search.gen) { - where.push(`format LIKE 'gen${search.gen}%'`); + where.push(where.length ? SQL` AND ` : SQL`WHERE `); + where.push(SQL`format LIKE ${`gen${search.gen}%`}`); + } + if (!includePrivate) { + where.push(where.length ? SQL` AND ` : SQL`WHERE `); + where.push(SQL`private IS NULL`); } - if (!includePrivate) where.push('private IS NULL'); - const result = await this.query( - `SELECT * FROM teams${where.length ? ` WHERE ${where.join(' AND ')}` : ''} ORDER BY date DESC LIMIT ${count}`, - args, - ); + const result = await teamsTable.selectAll()`${where} ORDER BY date DESC LIMIT ${count}`; return result.filter(row => { const team = Teams.unpack(row.team)!; if (row.private && row.ownerid !== user.id) { @@ -104,11 +109,6 @@ export const TeamsHandler = new class { }); } - async query(statement: string, values: any[] = []) { - if (this.readyPromise) await this.readyPromise; - return this.database.query(statement, values) as Promise; - } - isOMNickname(nickname: string) { // allow nicknames named after other mons/types/abilities/items - to support those OMs if (Dex.species.get(nickname).exists) { @@ -127,10 +127,12 @@ export const TeamsHandler = new class { async save( context: Chat.CommandContext, - formatName: string, - rawTeam: string, - teamName: string | null = null, - isPrivate?: string | null, + team: { + name?: string | null, + packedTeam: string, + format: string, + privacy?: boolean | string | null, + }, isUpdate?: number ) { const connection = context.connection; @@ -140,9 +142,9 @@ export const TeamsHandler = new class { return null; } const user = connection.user; - const format = Dex.formats.get(toID(formatName)); + const format = Dex.formats.get(toID(team.format)); if (format.effectType !== 'Format' || format.team) { - connection.popup("Invalid format:\n\n" + formatName); + connection.popup("Invalid format:\n\n" + team.format); return null; } let existing = null; @@ -158,9 +160,9 @@ export const TeamsHandler = new class { } } - const team = Teams.import(rawTeam, true); - if (!team) { - connection.popup('Invalid team:\n\n' + rawTeam); + const sets = Teams.import(team.packedTeam, true); + if (!sets) { + connection.popup('Invalid team:\n\n' + team.packedTeam); return null; } if (team.length > 50) { @@ -231,13 +233,13 @@ export const TeamsHandler = new class { return null; } } - if (teamName) { - if (teamName.length > 100) { + if (team.name) { + if (team.name.length > 100) { connection.popup("Your team's name is too long."); return null; } - const filtered = context.filter(teamName); - if (!filtered || filtered?.trim() !== teamName.trim()) { + const filtered = context.filter(team.name); + if (!filtered || filtered?.trim() !== team.name.trim()) { connection.popup(`Your team's name has a filtered word.`); return null; } @@ -247,39 +249,41 @@ export const TeamsHandler = new class { connection.popup(`You have too many teams stored. If you wish to upload this team, delete some first.`); return null; } - rawTeam = Teams.pack(team); - if (!rawTeam.trim()) { // extra sanity check + // eslint-disable-next-line require-atomic-updates + team.packedTeam = Teams.pack(sets); + if (!team.packedTeam.trim()) { // extra sanity check connection.popup("Invalid team provided."); return null; } + team.privacy ||= null; + if (team.privacy === true) team.privacy = existing?.private || TeamsHandler.generatePassword(); // the && existing doesn't really matter because we've verified it above, this is just for TS if (isUpdate && existing) { const differenceExists = ( - existing.team !== rawTeam || - (teamName && teamName !== existing.title) || + existing.team !== team.packedTeam || + (team.name && team.name !== existing.title) || format.id !== existing.format || - existing.private !== isPrivate + existing.private !== team.privacy ); if (!differenceExists) { connection.popup("Your team was not saved as no changes were made."); return null; } - await this.query( - 'UPDATE teams SET team = $1, title = $2, private = $3, format = $4 WHERE teamid = $5', - [rawTeam, teamName, isPrivate, format.id, isUpdate] - ); - return isUpdate; + await teamsTable.updateOne( + { team: team.packedTeam, title: team.name, private: team.privacy, format: format.id } + )`WHERE teamid = ${isUpdate}`; + return { teamid: isUpdate, teamName: team.name, privacy: team.privacy }; } else { - const exists = await this.query('SELECT * FROM teams WHERE ownerid = $1 AND team = $2', [user.id, rawTeam]); - if (exists.length) { + const exists = await teamsTable.selectOne()`WHERE ownerid = ${user.id} AND team = ${team.packedTeam}`; + if (exists) { connection.popup("You've already uploaded that team."); return null; } - const loaded = await this.query( - `INSERT INTO teams (ownerid, team, date, format, views, title, private) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING teamid`, - [user.id, rawTeam, new Date(), format.id, 0, teamName, isPrivate] - ); - return loaded?.[0].teamid; + const loaded = await teamsTable.queryOne()`INSERT INTO teams (${{ + ownerid: user.id, team: team.packedTeam, date: Math.round(Date.now() / 1000), format: format.id, + views: 0, title: team.name, private: team.privacy, + }}) RETURNING teamid`; + return { teamid: loaded?.teamid, teamName: team.name, privacy: team.privacy }; } } generatePassword(len = 20) { @@ -288,17 +292,11 @@ export const TeamsHandler = new class { return pw; } updateViews(teamid: string) { - return this.query(`UPDATE teams SET views = views + 1 WHERE teamid = $1`, [teamid]); + return teamsTable.updateOne(SQL`views = views + 1`)`WHERE teamid = ${teamid}`; } list(userid: ID, count: number, publicOnly = false) { - let query = `SELECT * FROM teams WHERE ownerid = $1 `; - if (publicOnly) { - query += `AND private IS NULL `; - } - query += `ORDER BY date DESC LIMIT $2`; - return this.query( - query, [userid, count] - ); + const publicOnlyQuery = publicOnly ? SQL`AND private IS NULL ` : SQL``; + return teamsTable.selectAll()`WHERE ownerid = ${userid} ${publicOnlyQuery} ORDER BY date DESC LIMIT ${count}`; } preview(teamData: StoredTeam, user?: User | null, isFull = false) { let buf = Utils.html`${teamData.title || `Untitled ${teamData.teamid}`}`; @@ -388,28 +386,24 @@ export const TeamsHandler = new class { } async count(user: string | User) { const id = toID(user); - const result = await this.query<{ count: number }>(`SELECT count(*) AS count FROM teams WHERE ownerid = $1`, [id]); - return result?.[0]?.count || 0; + const result = await teamsTable.queryOne<{ count: number }>( + )`SELECT count(*) AS count FROM teams WHERE ownerid = ${id}`; + return result?.count || 0; } async get(teamid: number | string): Promise { teamid = Number(teamid); if (isNaN(teamid)) { throw new Chat.ErrorMessage(`Invalid team ID.`); } - const rows = await this.query( - `SELECT * FROM teams WHERE teamid = $1`, [teamid], - ); - if (!rows.length) return null; - return rows[0] as StoredTeam; + const team = await teamsTable.get(teamid); + return team || null; } async delete(id: string | number) { id = Number(id); if (isNaN(id)) { throw new Chat.ErrorMessage("Invalid team ID"); } - await this.query( - `DELETE FROM teams WHERE teamid = $1`, [id], - ); + await teamsTable.delete(id); } }; @@ -426,34 +420,29 @@ export const commands: Chat.ChatCommands = { const isEdit = cmd === 'update'; const targets = Utils.splitFirst(target, ',', isEdit ? 4 : 3).map(x => x.trim()); const rawTeamID = isEdit ? targets.shift() : undefined; - let [teamName, formatid, rawPrivacy, rawTeam] = targets; + const [teamName, formatid, isPrivate, rawTeam] = targets; const teamID = isEdit ? Number(rawTeamID) : undefined; if (isEdit && (!rawTeamID?.length || isNaN(teamID!))) { connection.popup("Invalid team ID provided."); return null; } - if (rawTeam.includes('\n')) { - rawTeam = Teams.pack(Teams.import(rawTeam, true)); - } - if (!rawTeam) { - connection.popup("Invalid team."); - return null; - } - formatid = toID(formatid); - teamName = toID(teamName) ? teamName : null!; - const privacy = toID(rawPrivacy) === '1' ? TeamsHandler.generatePassword() : null; - const id = await TeamsHandler.save( - this, formatid, rawTeam, teamName, privacy, teamID + const result = await TeamsHandler.save( + this, { + name: toID(teamName) ? teamName : null, + format: toID(formatid), + packedTeam: rawTeam, + privacy: toID(isPrivate) === '1' ? true : null, + }, teamID ); if (!id) { return; // error messages were thrown to the user } const page = isEdit ? 'edit' : 'upload'; - if (id) { - connection.send(`|queryresponse|teamupload|` + JSON.stringify({ teamid: id, teamName, privacy })); + if (result) { + connection.send(`|queryresponse|teamupload|` + JSON.stringify(result)); connection.send(`>view-teams-${page}\n|deinit`); - this.parse(`/join view-teams-view-${id}-${id}`); + this.parse(`/join view-teams-view-${result.teamid}`); } else { this.parse(`/join view-teams-${page}`); } @@ -520,7 +509,7 @@ export const commands: Chat.ChatCommands = { if (team.ownerid !== user.id && !user.can('rangeban')) { return this.popupReply(`You cannot change privacy for a team you don't own.`); } - await TeamsHandler.query(`UPDATE teams SET private = $1 WHERE teamid = $2`, [privacy, teamId]); + await teamsTable.set(teamId, { private: privacy }); for (const pageid of this.connection.openPages || new Set()) { if (pageid.startsWith('teams-')) { this.refreshPage(pageid); @@ -598,17 +587,13 @@ export const pages: Chat.PageTable = { switch (type) { case 'views': this.title = `[Most Viewed Teams]`; - teams = await TeamsHandler.query( - `SELECT * FROM teams WHERE private IS NULL ORDER BY views DESC LIMIT $1`, [count] - ); + teams = await teamsTable.selectAll()`WHERE private IS NULL ORDER BY views DESC LIMIT ${count}`; title = `Most viewed teams:`; delete buttons.views; break; default: this.title = `[Latest Teams]`; - teams = await TeamsHandler.query( - `SELECT * FROM teams WHERE private IS NULL ORDER BY date DESC LIMIT $1`, [count] - ); + teams = await teamsTable.selectAll()`WHERE private IS NULL ORDER BY date DESC LIMIT ${count}`; title = `Recently uploaded teams:`; delete buttons.latest; break; @@ -790,20 +775,19 @@ export const pages: Chat.PageTable = { if (count > MAX_SEARCH) { count = MAX_SEARCH; } - let queryStr = 'SELECT * FROM teams WHERE private IS NULL'; let name = sorter; + let order; switch (sorter) { case 'views': - queryStr += ` ORDER BY views DESC `; + order = SQL` ORDER BY views DESC `; name = 'most viewed'; break; case 'latest': - queryStr += ` ORDER BY date DESC`; + order = SQL` ORDER BY date DESC`; break; default: throw new Chat.ErrorMessage(`Invalid sort term '${sorter}'. Must be either 'views' or 'latest'.`); } - queryStr += ` LIMIT ${count}`; let buf = `

Browse ${name} teams

`; buf += refresh(this); buf += `
Search`; @@ -811,7 +795,7 @@ export const pages: Chat.PageTable = { buf += ``; buf += `
`; - const results = await TeamsHandler.query(queryStr, []); + const results = await teamsTable.selectAll()`WHERE private IS NULL ${order} LIMIT ${count}`; if (!results.length) { buf += `
None found.
`; return buf; diff --git a/test/lib/postgres.js b/test/lib/postgres.js index 5704e3872b..d5956d18c1 100644 --- a/test/lib/postgres.js +++ b/test/lib/postgres.js @@ -1,22 +1,42 @@ "use strict"; const assert = require('assert').strict; -const { PostgresDatabase } = require('../../dist/lib'); +const { PGDatabase, SQL } = require('../../dist/lib/database'); -function testMod(mod) { - try { - require(mod); - } catch { - return it.skip; - } - return it; -} +const database = new PGDatabase(); +const assertSQL = (sql, rawSql, args) => assert.deepEqual( + database._resolveSQL(sql), [rawSql, args || []] +); -// only run these if you already have postgres configured -describe.skip("Postgres features", () => { - it("Should be able to connect to a database", async () => { - this.database = new PostgresDatabase(); +describe("Postgres library", () => { + it("should support template strings", async () => { + assertSQL(SQL`INSERT INTO test (col1, col2) VALUES (${'a'}, ${'b'})`, + "INSERT INTO test (col1, col2) VALUES ($1, $2)", ["a", "b"]); + assertSQL(SQL`INSERT INTO test (${{ col1: "a", col2: "b" }})`, + `INSERT INTO test ("col1", "col2") VALUES ($1, $2)`, ["a", "b"]); + assertSQL(SQL`SELECT * FROM test ${SQL`WHERE `}${SQL`a = 1`} LIMIT 1`, + `SELECT * FROM test WHERE a = 1 LIMIT 1`); + assertSQL(SQL`SELECT ${undefined}1+1`, + `SELECT 1+1`); + assertSQL(SQL`SELECT ${[]}2+2`, + `SELECT 2+2`); + + const constructed = SQL`SELECT `; + constructed.appendRaw(`3`); + constructed.append(SQL` + `); + constructed.append(3); + assertSQL(constructed, `SELECT 3 + $1`, [3]); + + assertSQL(SQL`SELECT * FROM test ${[SQL`WHERE `, SQL`a = 2`]} LIMIT 1`, + `SELECT * FROM test WHERE a = 2 LIMIT 1`); }); - it("Should be able to insert data", async () => { + + // only run these if you already have postgres configured + // TODO: update for new db + + it.skip("Should be able to connect to a database", async () => { + this.database = new PGDatabase(); + }); + it.skip("Should be able to insert data", async () => { await assert.doesNotThrowAsync(async () => { await this.database.query(`CREATE TABLE test (col TEXT, col2 TEXT)`); await this.database.query( @@ -25,13 +45,7 @@ describe.skip("Postgres features", () => { ); }); }); - testMod('sql-template-strings')('Should support sql-template-strings', async () => { - await assert.doesNotThrowAsync(async () => { - const SQL = require('sql-template-strings'); - await this.database.query(SQL`INSERT INTO test (col1, col2) VALUES (${'a'}, ${'b'})`); - }); - }); - it("Should be able to run multiple statements in transaction", async () => { + it.skip("Should be able to run multiple statements in transaction", async () => { await assert.doesNotThrowAsync(async () => { await this.database.transaction(async worker => { const tables = await worker.query(