/**
 * Database abstraction layer that's vaguely ORM-like.
 * Modern (Promises, strict types, tagged template literals), but ORMs
 * are a bit _too_ magical for me, so none of that magic here.
 *
 * @author Zarel
 */

import type * as mysql from 'mysql2';
import type * as pg from 'pg';

/** A value that is passed to the driver as a bound parameter. */
export type BasicSQLValue = string | number | null;
/** A row expressed as column name -> bound value. */
export type SQLRow = { [k: string]: BasicSQLValue };
/** Everything that may be interpolated into an `SQL` tagged template. */
export type SQLValue =
	BasicSQLValue | SQLStatement | SQLStatement[] | PartialOrSQL<SQLRow> | BasicSQLValue[] | undefined;

let mysqlRuntime: typeof import('mysql2');
let pgRuntime: typeof import('pg');
// Lazy-load mysql2
try {
	mysqlRuntime = require('mysql2');
} catch {
	// Only throw if someone tries to use MySQL
}
// Lazy-load pg
try {
	pgRuntime = require('pg');
} catch {
	// Only throw if someone tries to use Postgres
}

/**
 * Type guard for `SQLStatement`.
 *
 * This addresses a scenario where objects get out of sync due to hotpatching.
 * Table A is instantiated, and retains SQLStatement at that specific point in time.
 * Consumer A is also instantiated at the same time, and both can interact freely,
 * since consumer A and table A share the same reference to SQLStatement.
 * However, when consumer A is hotpatched, consumer A imports a new instance of
 * SQLStatement. Thus, when consumer A provides that new SQLStatement, it does not
 * pass the `instanceof SQLStatement` check in Table A, since table A is still
 * referencing the old SQLStatement (checking that the new is an instance of the old).
 * This does not work. Thus, we're forced to check constructor name instead.
 *
 * @param value anything; `null`, `undefined`, and prototype-less objects are accepted
 * @returns whether `value` can be treated as an `SQLStatement`
 */
export function isSQL(value: any): value is SQLStatement {
	if (value instanceof SQLStatement) return true;
	// assorted safety checks to be sure it'll actually work (theoretically preventing certain attacks)
	// `constructor?.` matters: objects made with Object.create(null) have no constructor at all
	return value?.constructor?.name === 'SQLStatement' &&
		Array.isArray(value.sql) && Array.isArray(value.values);
}

/**
 * A partially-built SQL query: `sql` holds the literal fragments and `values`
 * holds the interpolated values that sit between them, so
 * `sql.length === values.length + 1` at all times.
 *
 * Whether a value ends up as a bound parameter or as an escaped identifier is
 * decided later, by a database's `_resolveSQL`.
 */
export class SQLStatement {
	sql: string[];
	values: BasicSQLValue[];
	/**
	 * @param strings the literal parts of the template
	 * @param values the interpolations; `values[i]` sits between `strings[i]` and `strings[i + 1]`
	 */
	constructor(strings: TemplateStringsArray | string[], values: SQLValue[]) {
		this.sql = [strings[0]];
		this.values = [];
		for (let i = 0; i < strings.length - 1; i++) {
			this.append(values[i]).appendRaw(strings[i + 1]);
		}
	}
	/** Appends literal (unescaped) SQL text to the last fragment. */
	appendRaw(str: string): this {
		this.sql[this.sql.length - 1] += str;
		return this;
	}
	/**
	 * Appends an interpolated value. How it is rendered depends on its type and,
	 * for arrays and objects, on the SQL text immediately before it.
	 *
	 * @throws Error if an object appears anywhere other than right after `(` or ` SET `,
	 * or if such an object has no enumerable columns
	 */
	append(value: SQLValue): this {
		if (isSQL(value)) {
			// splice the nested statement in: its first fragment continues our last one
			if (!value.sql.length) return this;
			this.appendRaw(value.sql[0]);
			this.sql = this.sql.concat(value.sql.slice(1));
			this.values = this.values.concat(value.values);
			return this;
		}
		if (typeof value === 'string' || typeof value === 'number' || value === null) {
			this.values.push(value);
			this.sql.push('');
			return this;
		}
		if (value === undefined) {
			// do nothing
			return this;
		}

		const preceding = this.sql[this.sql.length - 1];
		if (Array.isArray(value)) {
			const quoteChar = preceding.slice(-1);
			if (!value.length || isSQL(value[0])) {
				// array of SQL statements
				for (const part of value) this.append(part);
			} else if (quoteChar === '"' || quoteChar === '`') {
				// "`a`, `b`" syntax
				for (const col of value) {
					this.append(col).appendRaw(`${quoteChar}, ${quoteChar}`);
				}
				// drop the trailing `", "` separator; the template supplies the closing quote
				this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4);
			} else {
				// "1, 2" syntax
				for (const val of value) {
					this.append(val).appendRaw(`, `);
				}
				this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2);
			}
			return this;
		}

		const inParens = preceding.endsWith('(');
		const afterSet = preceding.toUpperCase().endsWith(' SET ');
		if (!inParens && !afterSet) {
			throw new Error(
				`Objects can only appear in (obj) or after SET; ` +
				`unrecognized: ${this.sql[this.sql.length - 1]}[obj]`
			);
		}

		const columns: string[] = [];
		for (const col in value) columns.push(col);
		if (!columns.length) {
			// the fixed-width trimming below would otherwise eat real SQL text
			throw new Error(`Objects interpolated into SQL must have at least one column`);
		}

		if (inParens) {
			// "(`a`, `b`) VALUES (1, 2)" syntax
			this.appendRaw(`"`);
			for (const col of columns) {
				this.append(col).appendRaw(`", "`);
			}
			this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `") VALUES (`;
			for (const col of columns) {
				this.append(value[col]).appendRaw(`, `);
			}
			this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2);
		} else {
			// "`a` = 1, `b` = 2" syntax
			this.appendRaw(`"`);
			for (const col of columns) {
				this.append(col).appendRaw(`" = `);
				this.append(value[col]).appendRaw(`, "`);
			}
			this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -3);
		}
		return this;
	}
}

/**
 * Tag function for SQL, with some magic.
 *
 * Plain values become bound parameters:
 *
 * * `` SQL`UPDATE table SET a = ${'hello"'}` ``
 * * sends `UPDATE table SET a = ?` plus the value `hello"`
 *
 * Values surrounded by `"` or `` ` `` become identifiers:
 *
 * * ``` SQL`SELECT * FROM "${'table'}"` ```
 * * `` `SELECT * FROM "table"` ``
 *
 * (Make sure to use `"` for Postgres and `` ` `` for MySQL.)
 *
 * Objects preceded by SET become setters:
 *
 * * `` SQL`UPDATE table SET ${{a: 1, b: 2}}` ``
 * * `` `UPDATE table SET "a" = 1, "b" = 2` ``
 *
 * Objects surrounded by `()` become keys and values:
 *
 * * `` SQL`INSERT INTO table (${{a: 1, b: 2}})` ``
 * * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
 *
 * Arrays become lists; surrounding by `"` or `` ` `` turns them into lists of names:
 *
 * * `` SQL`INSERT INTO table ("${['a', 'b']}") VALUES (${[1, 2]})` ``
 * * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
 *
 * SQL statements can be nested:
 *
 * * `` SQL`SELECT * FR${SQL`OM table`}` ``
 * * `` `SELECT * FROM table` ``
 *
 * Raw unescaped strings can be put inside SQL() but I can't actually think of a
 * use case, so probably don't ever do this:
 *
 * * `` secondPart = SQL('OM table'); SQL`SELECT * FR${secondPart}` ``
 * * `` `SELECT * FROM table` ``
 */
export function SQL(strings: TemplateStringsArray | string[] | string, ...values: SQLValue[]) {
	if (typeof strings === 'string') strings = [strings];
	return new SQLStatement(strings, values);
}

export interface ResultRow { [k: string]: BasicSQLValue }

/** Every `Database` that has been constructed; entries are added by the constructor. */
export const connectedDatabases: Database[] = [];

/**
 * Driver-independent base class. Subclasses supply the dialect-specific pieces
 * (`_resolveSQL`, `_query`, `_queryExec`, `escapeId`).
 *
 * `query`, `queryOne` and `queryExec` can each be called either with a ready
 * `SQLStatement`, or with no arguments to get a template tag:
 * ``db.query()`SELECT ...` ``.
 */
export abstract class Database<Pool extends mysql.Pool | pg.Pool = mysql.Pool | pg.Pool, OkPacket = unknown> {
	connection: Pool;
	/** Prepended to every table name handed to `getTable`. */
	prefix: string;
	type = '';
	constructor(connection: Pool, prefix = '') {
		this.prefix = prefix;
		this.connection = connection;
		connectedDatabases.push(this);
	}
	/** Flattens a statement into dialect-specific SQL text plus its bound parameters. */
	abstract _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]];
	abstract _query(sql: string, values: BasicSQLValue[]): Promise<any>;
	abstract _queryExec(sql: string, values: BasicSQLValue[]): Promise<OkPacket>;
	abstract escapeId(param: string): string;

	query<T = ResultRow>(sql: SQLStatement): Promise<T[]>;
	query<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
	query(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.query(new SQLStatement(strings, rest));

		const [query, values] = this._resolveSQL(sql);
		return this._query(query, values);
	}
	/** Like `query`, but resolves to the first row only (or `undefined` when there is none). */
	queryOne<T = ResultRow>(sql: SQLStatement): Promise<T | undefined>;
	queryOne<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
	queryOne(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.queryOne(new SQLStatement(strings, rest));

		return this.query(sql).then(res => Array.isArray(res) ? res[0] : res);
	}
	/** For statements that modify data; resolves to whatever the subclass's `_queryExec` produces. */
	queryExec(sql: SQLStatement): Promise<OkPacket>;
	queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacket>;
	queryExec(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.queryExec(new SQLStatement(strings, rest));

		const [query, values] = this._resolveSQL(sql);
		return this._queryExec(query, values);
	}
	/**
	 * @param name table name, without the database prefix
	 * @param primaryKeyName the single-column primary key, if any; required by the
	 * `DatabaseTable` helpers that address one row by key
	 */
	getTable<Row>(name: string, primaryKeyName: keyof Row & string | null = null): DatabaseTable<Row, this> {
		return new DatabaseTable<Row, this>(this, name, primaryKeyName);
	}
	close() {
		void this.connection.end();
	}
}

type PartialOrSQL<T> = {
	[P in keyof T]?: T[P] | SQLStatement;
};

type OkPacketOf<DB extends Database> = DB extends Database<any, infer T> ? T : never;

/**
 * Query helpers bound to one table.
 *
 * The "low-level" helpers return a template tag that takes the rest of the
 * statement (typically a `WHERE` clause):
 * ``table.selectAll(['id'])`WHERE name = ${name}` ``.
 */
// Row extends SQLRow but TS doesn't support closed types so we can't express this
export class DatabaseTable<Row, DB extends Database> {
	db: DB;
	/** Full table name, i.e. the database prefix followed by the name given to the constructor. */
	name: string;
	primaryKeyName: keyof Row & string | null;
	constructor(
		db: DB,
		name: string,
		primaryKeyName: keyof Row & string | null = null
	) {
		this.db = db;
		this.name = db.prefix + name;
		this.primaryKeyName = primaryKeyName;
	}
	escapeId(param: string) {
		return this.db.escapeId(param);
	}

	// raw

	query<T = Row>(sql: SQLStatement): Promise<T[]>;
	query<T = Row>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
	query(sql?: SQLStatement) {
		return this.db.query(sql as any) as any;
	}
	queryOne<T = Row>(sql: SQLStatement): Promise<T | undefined>;
	queryOne<T = Row>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
	queryOne(sql?: SQLStatement) {
		return this.db.queryOne(sql as any) as any;
	}
	queryExec(sql: SQLStatement): Promise<OkPacketOf<DB>>;
	queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>>;
	queryExec(sql?: SQLStatement) {
		return this.db.queryExec(sql as any) as any;
	}

	// low-level

	/**
	 * @param entries column names or a raw column expression; defaults to `*`
	 */
	selectAll<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]> {
		if (!entries) entries = SQL`*`;
		if (Array.isArray(entries)) entries = SQL`"${entries}"`;
		return (strings, ...rest) =>
			this.query<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
	}
	/** Same as `selectAll`, with `LIMIT 1` appended and only the first row returned. */
	selectOne<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
		if (!entries) entries = SQL`*`;
		if (Array.isArray(entries)) entries = SQL`"${entries}"`;
		return (strings, ...rest) =>
			this.queryOne<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
	}
	updateAll(partialRow: PartialOrSQL<Row>):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(strings, rest)}`;
	}
	/**
	 * Builds the same statement as `updateAll` (no `LIMIT` is appended), but also
	 * accepts a raw `SQLStatement` as the `SET` list.
	 */
	updateOne(partialRow: PartialOrSQL<Row> | SQLStatement):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (s, ...r) =>
			this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(s, r)}`;
	}
	deleteAll():
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
	}
	/** Same as `deleteAll`, with `LIMIT 1` appended. */
	deleteOne():
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
	}
	/**
	 * Evaluates an SQL expression against this table and resolves to its value
	 * for the first row, or `undefined` when no row comes back.
	 */
	eval<T>():
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
		return (strings, ...rest) =>
			this.queryOne<{ result: T }>(
			)`SELECT ${new SQLStatement(strings, rest)} AS result FROM "${this.name}" LIMIT 1`
				.then(row => row?.result);
	}

	// high-level

	/**
	 * @param where optional SQL appended after the `VALUES (...)` list
	 */
	insert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`INSERT INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}
	insertIgnore(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`INSERT IGNORE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}
	/**
	 * Like `insert`, but resolves to `undefined` instead of rejecting when the
	 * row collides with an existing unique key. Any other error is rethrown.
	 */
	async tryInsert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		try {
			return await this.insert(partialRow, where);
		} catch (err: any) {
			// ER_DUP_ENTRY: MySQL; 23505: Postgres SQLSTATE for unique_violation
			if (err?.code === 'ER_DUP_ENTRY' || err?.code === '23505') {
				return undefined;
			}
			throw err;
		}
	}
	/**
	 * Inserts `partialRow`, or applies `partialUpdate` if the key already exists.
	 *
	 * @throws Error on Postgres when the table has no single-column primary key,
	 * since `ON CONFLICT` needs a conflict target
	 */
	upsert(partialRow: PartialOrSQL<Row>, partialUpdate = partialRow, where?: SQLStatement) {
		if (this.db.type === 'pg') {
			if (!this.primaryKeyName) throw new Error(`Cannot upsert() without a single-column primary key`);
			// the key name must be quoted so it is resolved as an identifier, not bound as a parameter
			return this.queryExec(
			)`INSERT INTO "${this.name}" (${partialRow as any}) ON CONFLICT ("${this.primaryKeyName}") DO UPDATE SET ${partialUpdate as any} ${where}`;
		}
		return this.queryExec(
		)`INSERT INTO "${this.name}" (${partialRow as any}) ON DUPLICATE KEY UPDATE ${partialUpdate as any} ${where}`;
	}
	/**
	 * @throws Error on Postgres when the table has no single-column primary key
	 */
	replace(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		if (this.db.type === 'pg') {
			if (!this.primaryKeyName) throw new Error(`Cannot replace() without a single-column primary key`);
			return this.queryExec(
			)`INSERT INTO "${this.name}" (${partialRow as any}) ON CONFLICT ("${this.primaryKeyName}") DO UPDATE SET ${partialRow as any} ${where}`;
		}
		return this.queryExec()`REPLACE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}
	get(primaryKey: BasicSQLValue, entries?: (keyof Row & string)[] | SQLStatement) {
		if (!this.primaryKeyName) throw new Error(`Cannot get() without a single-column primary key`);
		return this.selectOne(entries)`WHERE "${this.primaryKeyName}" = ${primaryKey}`;
	}
	delete(primaryKey: BasicSQLValue) {
		if (!this.primaryKeyName) throw new Error(`Cannot delete() without a single-column primary key`);
		return this.deleteAll()`WHERE "${this.primaryKeyName}" = ${primaryKey}`;
	}
	update(primaryKey: BasicSQLValue, data: PartialOrSQL<Row>) {
		if (!this.primaryKeyName) throw new Error(`Cannot update() without a single-column primary key`);
		return this.updateAll(data)`WHERE "${this.primaryKeyName}" = ${primaryKey}`;
	}
}

/** `Database` backed by a `mysql2` connection pool. */
export class MySQLDatabase extends Database<mysql.Pool, mysql.OkPacket> {
	override type = 'mysql' as const;
	/**
	 * @param config `mysql2` pool options, plus an optional table-name `prefix`
	 * @throws Error if the `mysql2` module could not be loaded
	 */
	constructor(config: mysql.PoolOptions & { prefix?: string }) {
		if (!mysqlRuntime) throw new Error(`Install the 'mysql2' module to use a MySQL database`);
		const prefix = config.prefix || "";
		if (config.prefix) {
			// copy before deleting so the caller's object isn't mutated
			config = { ...config };
			delete config.prefix;
		}
		super(mysqlRuntime.createPool(config), prefix);
	}
	override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
		let sql = query.sql[0];
		const values: BasicSQLValue[] = [];
		for (let i = 0; i < query.values.length; i++) {
			const value = query.values[i];
			const following = query.sql[i + 1];
			if (following.startsWith('`') || following.startsWith('"')) {
				// quoted in the template: swap both quote characters for a properly escaped identifier
				sql = sql.slice(0, -1) + this.escapeId(`${value}`) + following.slice(1);
			} else {
				sql += '?' + following;
				values.push(value);
			}
		}
		return [sql, values];
	}
	override _query(query: string, values: BasicSQLValue[]): Promise<any> {
		return new Promise((resolve, reject) => {
			this.connection.query(query, values, (e, results: any) => {
				if (e) {
					const wrapped: Error & { code?: string } =
						new Error(`${e.message} (${query}) (${values}) [${e.code}]`);
					// keep the driver's error code so callers can branch on `err.code`
					wrapped.code = e.code;
					return reject(wrapped);
				}
				if (Array.isArray(results)) {
					for (const row of results) {
						for (const col in row) {
							if (Buffer.isBuffer(row[col])) row[col] = row[col].toString();
						}
					}
				}
				return resolve(results);
			});
		});
	}
	override _queryExec(sql: string, values: BasicSQLValue[]): Promise<mysql.OkPacket> {
		return this._query(sql, values);
	}
	override escapeId(id: string) {
		if (!mysqlRuntime) throw new Error(`Install the 'mysql2' module to use a MySQL database`);
		return mysqlRuntime.escapeId(id);
	}
}

/** `Database` backed by a `pg` connection pool. */
export class PGDatabase extends Database<pg.Pool, { affectedRows: number | null }> {
	override type = 'pg' as const;
	/**
	 * @param config `pg` pool options; when falsy, no pool is created and `connection` is `null`
	 * @throws Error if the `pg` module could not be loaded
	 */
	constructor(config: pg.PoolConfig) {
		if (!pgRuntime) throw new Error(`Install the 'pg' module to use a Postgres database`);
		super(config ? new pgRuntime.Pool(config) : null!);
	}
	override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
		let sql = query.sql[0];
		const values: BasicSQLValue[] = [];
		let paramCount = 0;
		for (let i = 0; i < query.values.length; i++) {
			const value = query.values[i];
			const following = query.sql[i + 1];
			if (following.startsWith('`') || following.startsWith('"')) {
				// quoted in the template: swap both quote characters for a properly escaped identifier
				sql = sql.slice(0, -1) + this.escapeId(`${value}`) + following.slice(1);
			} else {
				// placeholders are numbered from $1, counting bound values only
				paramCount++;
				sql += `$${paramCount}` + following;
				values.push(value);
			}
		}
		return [sql, values];
	}
	override _query(query: string, values: BasicSQLValue[]) {
		return this.connection.query(query, values).then(res => res.rows);
	}
	override _queryExec(query: string, values: BasicSQLValue[]) {
		return this.connection.query(query, values).then(res => ({ affectedRows: res.rowCount }));
	}
	override escapeId(id: string) {
		if (!pgRuntime) throw new Error(`Install the 'pg' module to use a Postgres database`);
		// @ts-expect-error @types/pg really needs to be updated
		return pgRuntime.escapeIdentifier(id);
	}
}