internal/database: CopyUpsert: support dropping a column

You can't do a CopyFrom on a table with a generated column: Postgres
complains about the column value being null. To fix this, drop the column
from the temporary table before copying into it.

Change-Id: Ia52f59af6d026b3fcdaafe3c7865a2eb85deb179
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/305830
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
This commit is contained in:
Jonathan Amsterdam 2021-03-30 10:25:18 -04:00
Родитель ace022fde6
Коммит 127896bc9e
2 изменённых файлов: 67 добавлений и 23 удалений

Просмотреть файл

@ -23,11 +23,13 @@ import (
// src is the source of the rows to upsert.
// conflictColumns are the columns that might conflict (i.e. that have a UNIQUE
// constraint).
// If dropColumn is non-empty, that column will be dropped from the temporary
// table before copying. Use dropColumn for generated ID columns.
//
// CopyUpsert works by first creating a temporary table, populating it with
// CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its
// rows into the original table.
func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) {
func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string, dropColumn string) (err error) {
defer derrors.Wrap(&err, "CopyUpsert(%q)", table)
if !db.InTransaction() {
@ -46,8 +48,11 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr
tempTable := fmt.Sprintf("__%s_copy", table)
stmt := fmt.Sprintf(`
DROP TABLE IF EXISTS %s;
CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP
CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP;
`, tempTable, table)
if dropColumn != "" {
stmt += fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tempTable, dropColumn)
}
_, err = conn.Exec(ctx, stmt)
if err != nil {
return err
@ -55,12 +60,12 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr
start := time.Now()
n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src)
if err != nil {
return err
return fmt.Errorf("CopyFrom: %w", err)
}
log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start))
conflictAction := buildUpsertConflictAction(columns, conflictColumns)
query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction)
cols := strings.Join(columns, ", ")
query := fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err)
start = time.Now()
ctag, err := conn.Exec(ctx, query)
@ -72,11 +77,6 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr
})
}
// buildCopyUpsertQuery assembles the INSERT ... SELECT statement that moves
// rows from tempTable into table. The same comma-separated column list is used
// for both the INSERT target and the SELECT, and conflictAction is appended
// verbatim after a single space.
func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string {
	columnList := strings.Join(columns, ", ")
	insertClause := fmt.Sprintf("INSERT INTO %s (%s)", table, columnList)
	selectClause := fmt.Sprintf("SELECT %s FROM %s", columnList, tempTable)
	return strings.Join([]string{insertClause, selectClause, conflictAction}, " ")
}
// A RowItem is a row of values or an error.
type RowItem struct {
Values []interface{}

Просмотреть файл

@ -15,18 +15,8 @@ import (
)
func TestCopyUpsert(t *testing.T) {
pgxOnly(t)
ctx := context.Background()
conn, err := testDB.db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
conn.Raw(func(c interface{}) error {
if _, ok := c.(*stdlib.Conn); !ok {
t.Skip("skipping; DB driver not pgx")
}
return nil
})
for _, stmt := range []string{
`DROP TABLE IF EXISTS test_streaming_upsert`,
`CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`,
@ -40,8 +30,8 @@ func TestCopyUpsert(t *testing.T) {
{3, "baz"}, // new row
{1, "moo"}, // replace "foo" with "moo"
}
err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"})
err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "")
})
if err != nil {
t.Fatal(err)
@ -66,3 +56,57 @@ func TestCopyUpsert(t *testing.T) {
}
}
// TestCopyUpsertGeneratedColumn exercises CopyUpsert against a table whose id
// column is GENERATED ALWAYS AS IDENTITY, passing "id" as the column to drop.
// It seeds two rows, upserts one new row and one row that conflicts on key,
// and then checks the full table contents, including the generated ids.
func TestCopyUpsertGeneratedColumn(t *testing.T) {
	pgxOnly(t)
	ctx := context.Background()

	// Recreate the table from scratch and seed it with keys 11 and 12.
	const setup = `
DROP TABLE IF EXISTS test_copy_gen;
CREATE TABLE test_copy_gen (id bigint PRIMARY KEY GENERATED ALWAYS AS IDENTITY, key INT, value TEXT, UNIQUE (key));
INSERT INTO test_copy_gen (key, value) VALUES (11, 'foo'), (12, 'bar')`
	if _, err := testDB.Exec(ctx, setup); err != nil {
		t.Fatal(err)
	}

	upsert := func(tx *DB) error {
		src := pgx.CopyFromRows([][]interface{}{
			{13, "baz"}, // new row
			{11, "moo"}, // replace "foo" with "moo"
		})
		return tx.CopyUpsert(ctx, "test_copy_gen", []string{"key", "value"}, src, []string{"key"}, "id")
	}
	if err := testDB.Transact(ctx, sql.LevelDefault, upsert); err != nil {
		t.Fatal(err)
	}

	type row struct {
		ID    int64
		Key   int
		Value string
	}
	var got []row
	if err := testDB.CollectStructs(ctx, &got, `SELECT * FROM test_copy_gen ORDER BY ID`); err != nil {
		t.Fatal(err)
	}
	want := []row{
		{ID: 1, Key: 11, Value: "moo"},
		{ID: 2, Key: 12, Value: "bar"},
		{ID: 3, Key: 13, Value: "baz"},
	}
	if !cmp.Equal(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}
}
// pgxOnly skips the calling test unless the underlying driver connection of
// the test database is a pgx *stdlib.Conn. It fails the test if a connection
// cannot be obtained or inspected.
func pgxOnly(t *testing.T) {
	t.Helper()
	conn, err := testDB.db.Conn(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	// Release the connection when done so repeated calls do not leak
	// connections. The Close error is deliberately ignored: nothing useful
	// can be done with it in a test helper.
	// NOTE(review): assumes testDB.db is a *database/sql.DB — confirm.
	defer conn.Close()

	// Only record the driver type inside the callback. Calling t.Skip there
	// would unwind (via runtime.Goexit) through the Raw call itself, so the
	// skip decision is made after Raw has returned.
	isPgx := false
	err = conn.Raw(func(c interface{}) error {
		_, isPgx = c.(*stdlib.Conn)
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
	if !isPgx {
		t.Skip("skipping; DB driver not pgx")
	}
}