зеркало из https://github.com/golang/pkgsite.git
162 строки
4.3 KiB
Go
162 строки
4.3 KiB
Go
// Copyright 2019 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"reflect"
|
|
"unicode"
|
|
"unicode/utf8"
|
|
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// StructScanner returns a function that, when called on a
|
|
// struct pointer of its argument type, returns a slice of arguments suitable for
|
|
// Row.Scan or Rows.Scan. The call to either Scan will populate the exported
|
|
// fields of the struct in the order they appear in the type definition.
|
|
//
|
|
// StructScanner panics if p is not a struct or a pointer to a struct.
|
|
// The function it returns will panic if its argument is not a pointer
|
|
// to a struct.
|
|
//
|
|
// Example:
|
|
//
|
|
// type Player struct { Name string; Score int }
|
|
// playerScanArgs := database.StructScanner(Player{})
|
|
// err := db.RunQuery(ctx, "SELECT name, score FROM players", func(rows *sql.Rows) error {
|
|
// var p Player
|
|
// if err := rows.Scan(playerScanArgs(&p)...); err != nil {
|
|
// return err
|
|
// }
|
|
// // use p
|
|
// return nil
|
|
// })
|
|
func StructScanner[T any]() func(p *T) []any {
|
|
return structScannerForType[T]()
|
|
}
|
|
|
|
// fieldInfo records what the scanner needs to remember about one exported
// struct field.
type fieldInfo struct {
	// num is the field's index within the struct, as accepted by
	// reflect.Value.Field.
	num int
	// kind is the reflect.Kind of the field's declared type.
	kind reflect.Kind
}
|
|
|
|
func structScannerForType[T any]() func(p *T) []any {
|
|
var x T
|
|
t := reflect.TypeOf(x)
|
|
if t.Kind() != reflect.Struct {
|
|
panic(fmt.Sprintf("%s is not a struct", t))
|
|
}
|
|
|
|
// Collect the numbers of the exported fields.
|
|
var fieldInfos []fieldInfo
|
|
for i := 0; i < t.NumField(); i++ {
|
|
r, _ := utf8.DecodeRuneInString(t.Field(i).Name)
|
|
if unicode.IsUpper(r) {
|
|
fieldInfos = append(fieldInfos, fieldInfo{i, t.Field(i).Type.Kind()})
|
|
}
|
|
}
|
|
// Return a function that gets pointers to the exported fields.
|
|
return func(p *T) []any {
|
|
v := reflect.ValueOf(p).Elem()
|
|
var ps []any
|
|
for _, info := range fieldInfos {
|
|
p := v.Field(info.num).Addr().Interface()
|
|
switch info.kind {
|
|
case reflect.Slice:
|
|
if _, ok := p.(*[]byte); !ok {
|
|
p = pq.Array(p)
|
|
}
|
|
case reflect.Ptr:
|
|
p = NullPtr(p)
|
|
default:
|
|
}
|
|
ps = append(ps, p)
|
|
}
|
|
return ps
|
|
}
|
|
}
|
|
|
|
// NullPtr is for scanning nullable database columns into pointer variables or
|
|
// fields. When given a pointer to a pointer to some type T, it returns a
|
|
// value that can be passed to a Scan function. If the corresponding column is
|
|
// nil, the variable will be set to nil. Otherwise, it will be set to a newly
|
|
// allocated pointer to the column value.
|
|
func NullPtr(p any) nullPtr {
|
|
v := reflect.ValueOf(p)
|
|
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Ptr {
|
|
panic("NullPtr arg must be pointer to pointer")
|
|
}
|
|
return nullPtr{v}
|
|
}
|
|
|
|
// nullPtr wraps a **T so it can be used as the destination of a Scan (see
// the Scan method) and as a query argument (see the Value method).
type nullPtr struct {
	// ptr is a pointer to a pointer to something: **T
	ptr reflect.Value
}

// Scan stores value into the variable that n wraps.
//
// Treating n.ptr as a variable v of type **T: a nil value sets *v to nil;
// any other value is stored in a newly allocated T and *v is set to point to
// it. Scan returns an error, leaving *v unchanged, if value's type is not
// assignable to T.
func (n nullPtr) Scan(value any) error {
	target := n.ptr.Elem()   // *v, a variable of type *T
	ptrType := target.Type() // *T
	if value == nil {
		target.Set(reflect.Zero(ptrType)) // *v = nil
		return nil
	}

	// A []byte handed to Scan is owned by the driver and is only valid until
	// the next call to Scan, so keep a private copy instead of aliasing it.
	if b, ok := value.([]byte); ok && b != nil {
		c := make([]byte, len(b))
		copy(c, b)
		value = c
	}

	elemType := ptrType.Elem() // T
	val := reflect.ValueOf(value)
	// reflect.Value.Set panics on a type mismatch; report it as a scan
	// error instead.
	if !val.Type().AssignableTo(elemType) {
		return fmt.Errorf("database: cannot scan value of type %T into %s", value, ptrType)
	}
	p := reflect.New(elemType) // p := new(T)
	p.Elem().Set(val)          // *p = value
	target.Set(p)              // *v = p
	return nil
}

// Value returns nil if the wrapped pointer is nil, and otherwise the value it
// points to. The returned error is always nil.
func (n nullPtr) Value() (driver.Value, error) {
	target := n.ptr.Elem() // *v, a variable of type *T
	if target.IsNil() {
		return nil, nil
	}
	return target.Elem().Interface(), nil
}
|
|
|
|
// CollectStructs scans the rows from the query into structs and returns a slice of them.
|
|
// Example:
|
|
//
|
|
// type Player struct { Name string; Score int }
|
|
// var players []Player
|
|
// err := db.CollectStructs(ctx, &players, "SELECT name, score FROM players")
|
|
func CollectStructs[T any](ctx context.Context, db *DB, query string, args ...any) ([]T, error) {
|
|
scanner := structScannerForType[T]()
|
|
var ts []T
|
|
err := db.RunQuery(ctx, query, func(rows *sql.Rows) error {
|
|
var s T
|
|
if err := rows.Scan(scanner(&s)...); err != nil {
|
|
return err
|
|
}
|
|
ts = append(ts, s)
|
|
return nil
|
|
}, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ts, nil
|
|
}
|
|
|
|
func CollectStructPtrs[T any](ctx context.Context, db *DB, query string, args ...any) ([]*T, error) {
|
|
scanner := structScannerForType[T]()
|
|
var ts []*T
|
|
err := db.RunQuery(ctx, query, func(rows *sql.Rows) error {
|
|
var s T
|
|
if err := rows.Scan(scanner(&s)...); err != nil {
|
|
return err
|
|
}
|
|
ts = append(ts, &s)
|
|
return nil
|
|
}, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ts, nil
|
|
}
|