Signed-off-by: Jacques Grove <aquarapid@gmail.com>
This commit is contained in:
Jacques Grove 2021-09-17 12:50:26 -07:00
Родитель 101563b99d
Коммит 2d024326fa
6 изменённых файлов: 104 добавлений и 90 удалений

Просмотреть файл

@ -18,6 +18,7 @@ limitations under the License.
package netutil
import (
"bytes"
"fmt"
"math/rand"
"net"
@ -189,3 +190,60 @@ func ResolveIPv4Addrs(addr string) ([]string, error) {
}
return result, nil
}
func dnsLookup(host string) ([]net.IP, error) {
addrs, err := net.LookupHost(host)
if err != nil {
return nil, fmt.Errorf("Error looking up dns name [%v]: (%v)", host, err)
}
naddr := make([]net.IP, len(addrs))
for i, a := range addrs {
naddr[i] = net.ParseIP(a)
}
sort.Slice(naddr, func(i, j int) bool {
return bytes.Compare(naddr[i], naddr[j]) < 0
})
return naddr, nil
}
// DNSTracker is a closure that persists state for
// tracking changes in the DNS resolution of a target dns name
// returns true if the DNS name resolution has changed
// If there is a lookup problem, we pretend nothing has changed
func DNSTracker(host string) func() (bool, error) {
dnsName := host
var addrs []net.IP
if dnsName != "" {
addrs, _ = dnsLookup(dnsName)
}
return func() (bool, error) {
if dnsName == "" {
return false, nil
}
newaddrs, err := dnsLookup(dnsName)
if err != nil {
return false, err
}
if len(newaddrs) == 0 { // Should not happen, but just in case
return false, fmt.Errorf("Connection DNS for %s reporting as empty, ignoring", dnsName)
}
if !addrEqual(addrs, newaddrs) {
addrs = newaddrs
return true, fmt.Errorf("Connection DNS for %s has changed; old: [%v] new: [%v]", dnsName, addrs, newaddrs)
}
return false, nil
}
}
func addrEqual(a, b []net.IP) bool {
if len(a) != len(b) {
return false
}
for idx, v := range a {
if !net.IP.Equal(v, b[idx]) {
return false
}
}
return true
}

Просмотреть файл

@ -24,6 +24,8 @@ import (
"sync"
"time"
"vitess.io/vitess/go/vt/log"
"context"
"vitess.io/vitess/go/sync2"
@ -52,7 +54,7 @@ type Factory func(context.Context) (Resource, error)
// RefreshCheck is a function used to determine if a resource pool should be
// refreshed (i.e. closed and reopened)
type RefreshCheck func() bool
type RefreshCheck func() (bool, error)
// Resource defines the interface that every resource must provide.
// Thread synchronization between Close() and IsClosed()
@ -81,7 +83,7 @@ type ResourcePool struct {
logWait func(time.Time)
refreshCheck RefreshCheck
refreshFrequency time.Duration
refreshInterval time.Duration
refreshStop chan struct{}
refreshTicker *time.Ticker
refreshWg sync.WaitGroup
@ -103,8 +105,9 @@ type resourceWrapper struct {
// An idleTimeout of 0 means that there is no timeout.
// A non-zero value of prefillParallelism causes the pool to be pre-filled.
// The value specifies how many resources can be opened in parallel.
// TODO: document refreshCheck and refreshFrequency
func NewResourcePool(factory Factory, refreshCheck RefreshCheck, refreshFrequency time.Duration, capacity, maxCap int, idleTimeout time.Duration, prefillParallelism int, logWait func(time.Time)) *ResourcePool {
// refreshCheck is a function we consult at refreshInterval
// intervals to determine if the pool should be drained and reopened
func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Duration, prefillParallelism int, logWait func(time.Time), refreshCheck RefreshCheck, refreshInterval time.Duration) *ResourcePool {
if capacity <= 0 || maxCap <= 0 || capacity > maxCap {
panic(errors.New("invalid/out of range capacity"))
}
@ -154,8 +157,8 @@ func NewResourcePool(factory Factory, refreshCheck RefreshCheck, refreshFrequenc
rp.idleTimer.Start(rp.closeIdleResources)
}
if refreshCheck != nil && refreshFrequency > 0 {
rp.refreshFrequency = refreshFrequency
if refreshCheck != nil && refreshInterval > 0 {
rp.refreshInterval = refreshInterval
rp.refreshCheck = refreshCheck
rp.startRefreshTicker()
}
@ -164,7 +167,7 @@ func NewResourcePool(factory Factory, refreshCheck RefreshCheck, refreshFrequenc
}
func (rp *ResourcePool) startRefreshTicker() {
rp.refreshTicker = time.NewTicker(rp.refreshFrequency)
rp.refreshTicker = time.NewTicker(rp.refreshInterval)
rp.refreshStop = make(chan struct{})
rp.refreshWg.Add(1)
go func() {
@ -172,7 +175,11 @@ func (rp *ResourcePool) startRefreshTicker() {
for {
select {
case <-rp.refreshTicker.C:
if rp.refreshCheck() {
val, err := rp.refreshCheck()
if err != nil {
log.Info(err)
}
if val {
go rp.reopen()
return
}

Просмотреть файл

@ -65,7 +65,7 @@ func TestOpen(t *testing.T) {
count.Set(0)
waitStarts = waitStarts[:0]
p := NewResourcePool(PoolFactory, nil, 0, 6, 6, time.Second, 0, logWait)
p := NewResourcePool(PoolFactory, 6, 6, time.Second, 0, logWait, nil, 0)
p.SetCapacity(5)
var resources [10]Resource
@ -218,12 +218,12 @@ func TestOpen(t *testing.T) {
func TestPrefill(t *testing.T) {
lastID.Set(0)
count.Set(0)
p := NewResourcePool(PoolFactory, nil, 0, 5, 5, time.Second, 1, logWait)
p := NewResourcePool(PoolFactory, 5, 5, time.Second, 1, logWait, nil, 0)
defer p.Close()
if p.Active() != 5 {
t.Errorf("p.Active(): %d, want 5", p.Active())
}
p = NewResourcePool(FailFactory, nil, 0, 5, 5, time.Second, 1, logWait)
p = NewResourcePool(FailFactory, 5, 5, time.Second, 1, logWait, nil, 0)
defer p.Close()
if p.Active() != 0 {
t.Errorf("p.Active(): %d, want 0", p.Active())
@ -238,7 +238,7 @@ func TestPrefillTimeout(t *testing.T) {
defer func() { prefillTimeout = saveTimeout }()
start := time.Now()
p := NewResourcePool(SlowFailFactory, nil, 0, 5, 5, time.Second, 1, logWait)
p := NewResourcePool(SlowFailFactory, 5, 5, time.Second, 1, logWait, nil, 0)
defer p.Close()
if elapsed := time.Since(start); elapsed > 20*time.Millisecond {
t.Errorf("elapsed: %v, should be around 10ms", elapsed)
@ -254,7 +254,7 @@ func TestShrinking(t *testing.T) {
count.Set(0)
waitStarts = waitStarts[:0]
p := NewResourcePool(PoolFactory, nil, 0, 5, 5, time.Second, 0, logWait)
p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, nil, 0)
var resources [10]Resource
// Leave one empty slot in the pool
for i := 0; i < 4; i++ {
@ -396,7 +396,7 @@ func TestClosing(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
p := NewResourcePool(PoolFactory, nil, 0, 5, 5, time.Second, 0, logWait)
p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, nil, 0)
var resources [10]Resource
for i := 0; i < 5; i++ {
r, err := p.Get(ctx)
@ -444,10 +444,10 @@ func TestReopen(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
refreshCheck := func() bool {
return true
refreshCheck := func() (bool, error) {
return true, nil
}
p := NewResourcePool(PoolFactory, refreshCheck, 500*time.Millisecond, 5, 5, time.Second, 0, logWait)
p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, refreshCheck, 500*time.Millisecond)
var resources [10]Resource
for i := 0; i < 5; i++ {
r, err := p.Get(ctx)
@ -487,7 +487,7 @@ func TestIdleTimeout(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
p := NewResourcePool(PoolFactory, nil, 0, 1, 1, 10*time.Millisecond, 0, logWait)
p := NewResourcePool(PoolFactory, 1, 1, 10*time.Millisecond, 0, logWait, nil, 0)
defer p.Close()
r, err := p.Get(ctx)
@ -598,7 +598,7 @@ func TestIdleTimeoutCreateFail(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
p := NewResourcePool(PoolFactory, nil, 0, 1, 1, 10*time.Millisecond, 0, logWait)
p := NewResourcePool(PoolFactory, 1, 1, 10*time.Millisecond, 0, logWait, nil, 0)
defer p.Close()
r, err := p.Get(ctx)
if err != nil {
@ -619,7 +619,7 @@ func TestCreateFail(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
p := NewResourcePool(FailFactory, nil, 0, 5, 5, time.Second, 0, logWait)
p := NewResourcePool(FailFactory, 5, 5, time.Second, 0, logWait, nil, 0)
defer p.Close()
if _, err := p.Get(ctx); err.Error() != "Failed" {
t.Errorf("Expecting Failed, received %v", err)
@ -635,7 +635,7 @@ func TestCreateFailOnPut(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
p := NewResourcePool(PoolFactory, nil, 0, 5, 5, time.Second, 0, logWait)
p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, nil, 0)
defer p.Close()
_, err := p.Get(ctx)
if err != nil {
@ -652,7 +652,7 @@ func TestSlowCreateFail(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
p := NewResourcePool(SlowFailFactory, nil, 0, 2, 2, time.Second, 0, logWait)
p := NewResourcePool(SlowFailFactory, 2, 2, time.Second, 0, logWait, nil, 0)
defer p.Close()
ch := make(chan bool)
// The third Get should not wait indefinitely
@ -674,7 +674,7 @@ func TestTimeout(t *testing.T) {
ctx := context.Background()
lastID.Set(0)
count.Set(0)
p := NewResourcePool(PoolFactory, nil, 0, 1, 1, time.Second, 0, logWait)
p := NewResourcePool(PoolFactory, 1, 1, time.Second, 0, logWait, nil, 0)
defer p.Close()
r, err := p.Get(ctx)
if err != nil {
@ -693,7 +693,7 @@ func TestTimeout(t *testing.T) {
func TestExpired(t *testing.T) {
lastID.Set(0)
count.Set(0)
p := NewResourcePool(PoolFactory, nil, 0, 1, 1, time.Second, 0, logWait)
p := NewResourcePool(PoolFactory, 1, 1, time.Second, 0, logWait, nil, 0)
defer p.Close()
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second))
r, err := p.Get(ctx)

Просмотреть файл

@ -44,7 +44,7 @@ type RPCPool struct {
// will not be called).
func NewRPCPool(size int, waitTimeout time.Duration, logWait func(time.Time)) *RPCPool {
return &RPCPool{
rp: NewResourcePool(rpcResourceFactory, nil, 0, size, size, 0, size, logWait),
rp: NewResourcePool(rpcResourceFactory, size, size, 0, size, logWait, nil, 0),
waitTimeout: waitTimeout,
}
}

Просмотреть файл

@ -22,19 +22,18 @@ object to pool these DBConnections.
package dbconnpool
import (
"bytes"
"errors"
"net"
"sort"
"sync"
"time"
"vitess.io/vitess/go/netutil"
"context"
"vitess.io/vitess/go/pools"
"vitess.io/vitess/go/stats"
"vitess.io/vitess/go/vt/dbconfigs"
"vitess.io/vitess/go/vt/log"
)
var (
@ -90,57 +89,6 @@ func (cp *ConnectionPool) pool() (p *pools.ResourcePool) {
return p
}
func lookup(host string) []net.IP {
addrs, err := net.LookupHost(host)
if err != nil {
log.Errorf("Error looking up dns name [%v]: (%v)\n", host, err)
return nil
}
naddr := make([]net.IP, len(addrs))
for i, a := range addrs {
naddr[i] = net.ParseIP(a)
}
sort.Slice(naddr, func(i, j int) bool {
return bytes.Compare(naddr[i], naddr[j]) < 0
})
return naddr
}
// DNSTracker is a closure that persists state for
// tracking changes in the DNS resolution of a target dns name
func DNSTracker(host string) func() bool {
dnsName := host
var addrs []net.IP
if dnsName != "" {
addrs = lookup(dnsName)
}
return func() bool {
if dnsName == "" {
return false
}
newaddrs := lookup(dnsName)
if !addrEqual(addrs, newaddrs) {
log.Infof("Connection DNS has changed; old: [%v] new: [%v]\n", addrs, newaddrs)
addrs = newaddrs
return true
}
return false
}
}
func addrEqual(a, b []net.IP) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if !net.IP.Equal(v, b[i]) {
return false
}
}
return true
}
// Open must be called before starting to use the pool.
//
// For instance:
@ -149,18 +97,17 @@ func addrEqual(a, b []net.IP) bool {
// ...
// conn, err := pool.Get()
// ...
// TODO: fix comment
func (cp *ConnectionPool) Open(info dbconfigs.Connector) {
var f pools.RefreshCheck
var refreshCheck pools.RefreshCheck
if net.ParseIP(info.Host()) == nil {
f = DNSTracker(info.Host())
refreshCheck = netutil.DNSTracker(info.Host())
} else {
f = nil
refreshCheck = nil
}
cp.mu.Lock()
defer cp.mu.Unlock()
cp.info = info
cp.connections = pools.NewResourcePool(cp.connect, f, cp.resolutionFrequency, cp.capacity, cp.capacity, cp.idleTimeout, 0, nil)
cp.connections = pools.NewResourcePool(cp.connect, cp.capacity, cp.capacity, cp.idleTimeout, 0, nil, refreshCheck, cp.resolutionFrequency)
}
// connect is used by the resource pool to create a new Resource.

Просмотреть файл

@ -21,6 +21,8 @@ import (
"sync"
"time"
"vitess.io/vitess/go/netutil"
"context"
"vitess.io/vitess/go/pools"
@ -114,12 +116,12 @@ func (cp *Pool) Open(appParams, dbaParams, appDebugParams dbconfigs.Connector) {
var refreshCheck pools.RefreshCheck
if net.ParseIP(appParams.Host()) == nil {
refreshCheck = dbconnpool.DNSTracker(appParams.Host())
refreshCheck = netutil.DNSTracker(appParams.Host())
} else {
refreshCheck = nil
}
cp.connections = pools.NewResourcePool(f, refreshCheck, *mysqlctl.PoolDynamicHostnameResolution, cp.capacity, cp.capacity, cp.idleTimeout, cp.prefillParallelism, cp.getLogWaitCallback())
cp.connections = pools.NewResourcePool(f, cp.capacity, cp.capacity, cp.idleTimeout, cp.prefillParallelism, cp.getLogWaitCallback(), refreshCheck, *mysqlctl.PoolDynamicHostnameResolution)
cp.appDebugParams = appDebugParams
cp.dbaPool.Open(dbaParams)