зеркало из https://github.com/golang/pkgsite.git
internal/database: CopyUpsert: support dropping a column
You can't do a CopyFrom on a table with a generated column: postgres complains about the column value being null. To fix, drop the column on the temporary table. Change-Id: Ia52f59af6d026b3fcdaafe3c7865a2eb85deb179 Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/305830 Trust: Jonathan Amsterdam <jba@google.com> Run-TryBot: Jonathan Amsterdam <jba@google.com> Reviewed-by: Julie Qiu <julie@golang.org>
This commit is contained in:
Родитель
ace022fde6
Коммит
127896bc9e
|
@ -23,11 +23,13 @@ import (
|
|||
// src is the source of the rows to upsert.
|
||||
// conflictColumns are the columns that might conflict (i.e. that have a UNIQUE
|
||||
// constraint).
|
||||
// If dropColumn is non-empty, that column will be dropped from the temporary
|
||||
// table before copying. Use dropColumn for generated ID columns.
|
||||
//
|
||||
// CopyUpsert works by first creating a temporary table, populating it with
|
||||
// CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its
|
||||
// rows into the original table.
|
||||
func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) {
|
||||
func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string, dropColumn string) (err error) {
|
||||
defer derrors.Wrap(&err, "CopyUpsert(%q)", table)
|
||||
|
||||
if !db.InTransaction() {
|
||||
|
@ -46,8 +48,11 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr
|
|||
tempTable := fmt.Sprintf("__%s_copy", table)
|
||||
stmt := fmt.Sprintf(`
|
||||
DROP TABLE IF EXISTS %s;
|
||||
CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP
|
||||
CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP;
|
||||
`, tempTable, table)
|
||||
if dropColumn != "" {
|
||||
stmt += fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tempTable, dropColumn)
|
||||
}
|
||||
_, err = conn.Exec(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -55,12 +60,12 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr
|
|||
start := time.Now()
|
||||
n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("CopyFrom: %w", err)
|
||||
}
|
||||
log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start))
|
||||
conflictAction := buildUpsertConflictAction(columns, conflictColumns)
|
||||
query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction)
|
||||
|
||||
cols := strings.Join(columns, ", ")
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
|
||||
defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err)
|
||||
start = time.Now()
|
||||
ctag, err := conn.Exec(ctx, query)
|
||||
|
@ -72,11 +77,6 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr
|
|||
})
|
||||
}
|
||||
|
||||
func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string {
|
||||
cols := strings.Join(columns, ", ")
|
||||
return fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction)
|
||||
}
|
||||
|
||||
// A RowItem is a row of values or an error.
|
||||
type RowItem struct {
|
||||
Values []interface{}
|
||||
|
|
|
@ -15,18 +15,8 @@ import (
|
|||
)
|
||||
|
||||
func TestCopyUpsert(t *testing.T) {
|
||||
pgxOnly(t)
|
||||
ctx := context.Background()
|
||||
conn, err := testDB.db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Raw(func(c interface{}) error {
|
||||
if _, ok := c.(*stdlib.Conn); !ok {
|
||||
t.Skip("skipping; DB driver not pgx")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
for _, stmt := range []string{
|
||||
`DROP TABLE IF EXISTS test_streaming_upsert`,
|
||||
`CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`,
|
||||
|
@ -40,8 +30,8 @@ func TestCopyUpsert(t *testing.T) {
|
|||
{3, "baz"}, // new row
|
||||
{1, "moo"}, // replace "foo" with "moo"
|
||||
}
|
||||
err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
|
||||
return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"})
|
||||
err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
|
||||
return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "")
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -66,3 +56,57 @@ func TestCopyUpsert(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCopyUpsertGeneratedColumn(t *testing.T) {
|
||||
pgxOnly(t)
|
||||
ctx := context.Background()
|
||||
stmt := `
|
||||
DROP TABLE IF EXISTS test_copy_gen;
|
||||
CREATE TABLE test_copy_gen (id bigint PRIMARY KEY GENERATED ALWAYS AS IDENTITY, key INT, value TEXT, UNIQUE (key));
|
||||
INSERT INTO test_copy_gen (key, value) VALUES (11, 'foo'), (12, 'bar')`
|
||||
if _, err := testDB.Exec(ctx, stmt); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows := [][]interface{}{
|
||||
{13, "baz"}, // new row
|
||||
{11, "moo"}, // replace "foo" with "moo"
|
||||
}
|
||||
err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
|
||||
return tx.CopyUpsert(ctx, "test_copy_gen", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "id")
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type row struct {
|
||||
ID int64
|
||||
Key int
|
||||
Value string
|
||||
}
|
||||
wantRows := []row{
|
||||
{1, 11, "moo"},
|
||||
{2, 12, "bar"},
|
||||
{3, 13, "baz"},
|
||||
}
|
||||
var gotRows []row
|
||||
if err := testDB.CollectStructs(ctx, &gotRows, `SELECT * FROM test_copy_gen ORDER BY ID`); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !cmp.Equal(gotRows, wantRows) {
|
||||
t.Errorf("got %v, want %v", gotRows, wantRows)
|
||||
}
|
||||
}
|
||||
|
||||
func pgxOnly(t *testing.T) {
|
||||
conn, err := testDB.db.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Raw(func(c interface{}) error {
|
||||
if _, ok := c.(*stdlib.Conn); !ok {
|
||||
t.Skip("skipping; DB driver not pgx")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче