2020-07-14 20:42:03 +03:00
|
|
|
// Copyright 2020 The Go Authors. All rights reserved.
|
|
|
|
// Use of this source code is governed by a BSD-style
|
|
|
|
// license that can be found in the LICENSE file.
|
|
|
|
|
2021-09-02 18:38:42 +03:00
|
|
|
package relui
|
2020-07-14 20:42:03 +03:00
|
|
|
|
|
|
|
import (
|
2020-12-03 03:25:34 +03:00
|
|
|
"context"
|
2021-09-10 22:42:39 +03:00
|
|
|
"database/sql"
|
2021-08-25 22:48:02 +03:00
|
|
|
"errors"
|
|
|
|
"fmt"
|
2020-07-14 20:42:03 +03:00
|
|
|
|
2021-09-10 22:42:39 +03:00
|
|
|
"github.com/golang-migrate/migrate/v4"
|
|
|
|
dbpgx "github.com/golang-migrate/migrate/v4/database/pgx"
|
|
|
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
2021-08-25 22:48:02 +03:00
|
|
|
"github.com/jackc/pgx/v4"
|
2020-07-14 20:42:03 +03:00
|
|
|
)
|
|
|
|
|
2021-08-25 22:48:02 +03:00
|
|
|
// errDBNotExist is the sentinel error returned by DropDB when the
// database named in its config is not present.
var errDBNotExist = errors.New("database does not exist")
|
|
|
|
|
2021-09-10 22:42:39 +03:00
|
|
|
// InitDB creates and applies all migrations to the database specified
|
|
|
|
// in conn.
|
|
|
|
//
|
|
|
|
// If the database does not exist, one will be created using the
|
|
|
|
// credentials provided.
|
|
|
|
//
|
|
|
|
// Any key/value or URI string compatible with libpq is valid.
|
|
|
|
func InitDB(ctx context.Context, conn string) error {
|
|
|
|
cfg, err := pgx.ParseConfig(conn)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("pgx.ParseConfig() = %w", err)
|
|
|
|
}
|
|
|
|
if err := CreateDBIfNotExists(ctx, cfg); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2021-09-30 21:27:54 +03:00
|
|
|
if err := MigrateDB(conn, false); err != nil {
|
2021-09-10 22:42:39 +03:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// MigrateDB applies all migrations to the database specified in conn.
|
|
|
|
//
|
2021-09-30 21:27:54 +03:00
|
|
|
// Any key/value or URI string compatible with libpq is a valid conn.
|
|
|
|
// If downUp is true, all migrations will be run, then the down and up
|
|
|
|
// migrations of the final migration are run.
|
|
|
|
func MigrateDB(conn string, downUp bool) error {
|
2021-09-10 22:42:39 +03:00
|
|
|
cfg, err := pgx.ParseConfig(conn)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("pgx.ParseConfig() = %w", err)
|
|
|
|
}
|
|
|
|
db, err := sql.Open("pgx", conn)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("sql.Open(%q, _) = %v, %w", "pgx", db, err)
|
|
|
|
}
|
|
|
|
mcfg := &dbpgx.Config{
|
|
|
|
MigrationsTable: "migrations",
|
|
|
|
DatabaseName: cfg.Database,
|
|
|
|
}
|
|
|
|
mdb, err := dbpgx.WithInstance(db, mcfg)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("dbpgx.WithInstance(_, %v) = %v, %w", mcfg, mdb, err)
|
|
|
|
}
|
|
|
|
mfs, err := iofs.New(migrations, "migrations")
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("iofs.New(%v, %q) = %v, %w", migrations, "migrations", mfs, err)
|
|
|
|
}
|
|
|
|
m, err := migrate.NewWithInstance("iofs", mfs, "pgx", mdb)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("migrate.NewWithInstance(%q, %v, %q, %v) = %v, %w", "iofs", migrations, "pgx", mdb, m, err)
|
|
|
|
}
|
|
|
|
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
|
|
|
return fmt.Errorf("m.Up() = %w", err)
|
|
|
|
}
|
2021-09-30 21:27:54 +03:00
|
|
|
if downUp {
|
|
|
|
if err := m.Steps(-1); err != nil {
|
|
|
|
return fmt.Errorf("m.Steps(%d) = %w", -1, err)
|
|
|
|
}
|
|
|
|
if err := m.Up(); err != nil {
|
|
|
|
return fmt.Errorf("m.Up() = %w", err)
|
|
|
|
}
|
|
|
|
}
|
2021-09-10 22:42:39 +03:00
|
|
|
db.Close()
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2021-08-25 22:48:02 +03:00
|
|
|
// ConnectMaintenanceDB connects to the maintenance database using the
|
|
|
|
// credentials from cfg. If maintDB is an empty string, the database
|
|
|
|
// with the name cfg.User will be used.
|
|
|
|
func ConnectMaintenanceDB(ctx context.Context, cfg *pgx.ConnConfig, maintDB string) (*pgx.Conn, error) {
|
|
|
|
cfg = cfg.Copy()
|
2021-11-15 22:23:47 +03:00
|
|
|
if maintDB == "" {
|
|
|
|
maintDB = "postgres"
|
|
|
|
}
|
2021-08-25 22:48:02 +03:00
|
|
|
cfg.Database = maintDB
|
|
|
|
return pgx.ConnectConfig(ctx, cfg)
|
|
|
|
}
|
|
|
|
|
|
|
|
// CreateDBIfNotExists checks whether the given dbName is an existing
|
|
|
|
// database, and creates one if not.
|
|
|
|
func CreateDBIfNotExists(ctx context.Context, cfg *pgx.ConnConfig) error {
|
|
|
|
exists, err := checkIfDBExists(ctx, cfg)
|
|
|
|
if err != nil || exists {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
conn, err := ConnectMaintenanceDB(ctx, cfg, "")
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("ConnectMaintenanceDB = %w", err)
|
|
|
|
}
|
|
|
|
createSQL := fmt.Sprintf("CREATE DATABASE %s", pgx.Identifier{cfg.Database}.Sanitize())
|
|
|
|
if _, err := conn.Exec(ctx, createSQL); err != nil {
|
|
|
|
return fmt.Errorf("conn.Exec(%q) = %w", createSQL, err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// DropDB drops the database specified in cfg. An error returned if
|
|
|
|
// the database does not exist.
|
|
|
|
func DropDB(ctx context.Context, cfg *pgx.ConnConfig) error {
|
|
|
|
exists, err := checkIfDBExists(ctx, cfg)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("p.checkIfDBExists() = %w", err)
|
|
|
|
}
|
|
|
|
if !exists {
|
|
|
|
return errDBNotExist
|
|
|
|
}
|
|
|
|
conn, err := ConnectMaintenanceDB(ctx, cfg, "")
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("ConnectMaintenanceDB = %w", err)
|
|
|
|
}
|
|
|
|
dropSQL := fmt.Sprintf("DROP DATABASE %s", pgx.Identifier{cfg.Database}.Sanitize())
|
|
|
|
if _, err := conn.Exec(ctx, dropSQL); err != nil {
|
|
|
|
return fmt.Errorf("conn.Exec(%q) = %w", dropSQL, err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func checkIfDBExists(ctx context.Context, cfg *pgx.ConnConfig) (bool, error) {
|
|
|
|
conn, err := ConnectMaintenanceDB(ctx, cfg, "")
|
|
|
|
if err != nil {
|
|
|
|
return false, fmt.Errorf("ConnectMaintenanceDB = %w", err)
|
|
|
|
}
|
|
|
|
row := conn.QueryRow(ctx, "SELECT 1 from pg_database WHERE datname=$1 LIMIT 1", cfg.Database)
|
|
|
|
var exists int
|
|
|
|
if err := row.Scan(&exists); err != nil && err != pgx.ErrNoRows {
|
|
|
|
return false, fmt.Errorf("row.Scan() = %w", err)
|
|
|
|
}
|
|
|
|
return exists == 1, nil
|
|
|
|
}
|