diff --git a/go/vt/vitessdriver/driver.go b/go/vt/vitessdriver/driver.go index cf92f8575e..fb8272cd12 100644 --- a/go/vt/vitessdriver/driver.go +++ b/go/vt/vitessdriver/driver.go @@ -80,11 +80,11 @@ func OpenShardForStreaming(address, keyspace, shard, tabletType string, timeout // It allows to pass in a Configuration struct to control all possible // settings of the Vitess Go SQL driver. func OpenWithConfiguration(c Configuration) (*sql.DB, error) { - jsonBytes, err := json.Marshal(c) + json, err := c.toJSON() if err != nil { return nil, err } - return sql.Open("vitess", string(jsonBytes)) + return sql.Open("vitess", json) } type drv struct { @@ -172,11 +172,31 @@ type Configuration struct { } func newDefaultConfiguration() Configuration { - return Configuration{ - Protocol: "grpc", - TabletType: "master", - Streaming: false, + c := Configuration{} + c.setDefaults() + return c +} + +// toJSON converts Configuration to the JSON string which is required by the +// Vitess driver. Default values for empty fields will be set. +func (c Configuration) toJSON() (string, error) { + c.setDefaults() + jsonBytes, err := json.Marshal(c) + if err != nil { + return "", err } + return string(jsonBytes), nil +} + +// setDefaults sets the default values for empty fields. +func (c *Configuration) setDefaults() { + if c.Protocol == "" { + c.Protocol = "grpc" + } + if c.TabletType == "" { + c.TabletType = "master" + } + // c.Streaming = false is enforced by Go's zero value. } type conn struct { diff --git a/go/vt/vitessdriver/driver_test.go b/go/vt/vitessdriver/driver_test.go index 9b4a3ea2c8..32feaa9730 100644 --- a/go/vt/vitessdriver/driver_test.go +++ b/go/vt/vitessdriver/driver_test.go @@ -218,6 +218,46 @@ func TestExec(t *testing.T) { } } +func TestConfigurationToJSON(t *testing.T) { + var testcases = []struct { + desc string + config Configuration + json string + }{ + { + desc: "all fields set", + config: Configuration{ + Protocol: "some-invalid-protocol", + Keyspace: "ks2", + Shard: "-80", + TabletType: "replica", + Streaming: true, + Timeout: 1 * time.Second, + }, + json: `{"Protocol":"some-invalid-protocol","Address":"","Keyspace":"ks2","Shard":"-80","tablet_type":"replica","Streaming":true,"Timeout":1000000000}`, + }, + { + desc: "default fields are empty", + config: Configuration{ + Keyspace: "ks2", + Shard: "-80", + Timeout: 1 * time.Second, + }, + json: `{"Protocol":"grpc","Address":"","Keyspace":"ks2","Shard":"-80","tablet_type":"master","Streaming":false,"Timeout":1000000000}`, + }, + } + + for _, tc := range testcases { + json, err := tc.config.toJSON() + if err != nil { + t.Errorf("%v: JSON conversion should have succeeded but did not: %v", tc.desc, err) + } + if json != tc.json { + t.Errorf("%v: Configuration.JSON(): got: %v want: %v Configuration: %v", tc.desc, json, tc.json, tc.config) + } + } +} + func TestExecStreamingNotAllowed(t *testing.T) { db, err := OpenForStreaming(testAddress, "rdonly", 30*time.Second) if err != nil {