- Completed Addition function in arithmetic.go to match behavior similar to mySQL

- Added possible tests for Addition function within arithmetic_test.go

Signed-off-by: Rasika Kale <rasika@planetscale.com>
This commit is contained in:
Rasika Kale 2019-08-06 13:17:21 -07:00
Родитель c7bd871ffa
Коммит 0307a2e372
2 изменённых файлов: 187 добавлений и 31 удалений

Просмотреть файл

@ -19,6 +19,8 @@ package sqltypes
import (
"bytes"
"fmt"
"math"
"strconv"
querypb "vitess.io/vitess/go/vt/proto/query"
@ -40,36 +42,37 @@ type numeric struct {
var zeroBytes = []byte("0")
//const maxUintVal = 18446744073709551615
//const maxIntVal = 9223372036854775807
//Addition adds two values together
//if v1 or v2 is null, then it returns null
/*
func Addition(v1, v2 Value) Value {
func Addition(v1, v2 Value) (Value, error) {
if v1.IsNull() {
return NULL
return NULL, nil
}
if v2.IsNull() {
return NULL
return NULL, nil
}
lv1, err := newNumeric(v1)
if err != nil {
return NULL
return NULL, err
}
lv2, err := newNumeric(v2)
if err != nil {
return NULL
return NULL, err
}
lresult, err := addNumeric(lv1, lv2)
lresult, err := addNumericWithError(lv1, lv2)
if err != nil {
return NULL
return NULL, err
}
return castFromNumeric(lresult, lresult.typ)
return castFromNumeric(lresult, lresult.typ), nil
}
function to make
*/
// NullsafeAdd adds two Values in a null-safe manner. A null value
// is treated as 0. If both values are null, then a null is returned.
@ -98,10 +101,7 @@ func NullsafeAdd(v1, v2 Value, resultType querypb.Type) Value {
if err != nil {
return NULL //, err
}
lresult, err := addNumeric(lv1, lv2)
if err != nil {
return NULL //, err
}
lresult := addNumeric(lv1, lv2)
//fmt.Printf("resultType = %v, lresult = %v\n", lresult.typ, lresult)
return castFromNumeric(lresult, resultType)
}
@ -355,22 +355,43 @@ func newIntegralNumeric(v Value) (numeric, error) {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str)
}
func addNumeric(v1, v2 numeric) (numeric, error) {
func addNumeric(v1, v2 numeric) numeric {
v1, v2 = prioritize(v1, v2)
switch v1.typ {
case Int64:
return intPlusInt(v1.ival, v2.ival), nil
return intPlusInt(v1.ival, v2.ival)
case Uint64:
switch v2.typ {
case Int64:
return uintPlusInt(v1.uval, v2.ival)
case Uint64:
return uintPlusUint(v1.uval, v2.uval), nil
return uintPlusUint(v1.uval, v2.uval)
}
case Float64:
return floatPlusAny(v1.fval, v2)
}
panic("unreachable")
}
func addNumericWithError(v1, v2 numeric) (numeric, error) {
v1, v2 = prioritize(v1, v2)
//fmt.Printf("v1 = %v\n", v1.uval)
//fmt.Printf("v2 = %v\n", v2.uval)
switch v1.typ {
case Int64:
return intPlusIntWithError(v1.ival, v2.ival)
case Uint64:
switch v2.typ {
case Int64:
return uintPlusIntWithError(v1.uval, v2.ival)
case Uint64:
return uintPlusUintWithError(v1.uval, v2.uval)
}
case Float64:
return floatPlusAny(v1.fval, v2), nil
}
panic("unreachable")
}
// prioritize reorders the input parameters
@ -385,6 +406,7 @@ func prioritize(v1, v2 numeric) (altv1, altv2 numeric) {
if v2.typ == Float64 {
return v2, v1
}
}
return v1, v2
}
@ -403,8 +425,6 @@ overflow:
return numeric{typ: Float64, fval: float64(v1) + float64(v2)}
}
/*
function to make
func intPlusIntWithError(v1, v2 int64) (numeric, error) {
result := v1 + v2
if v1 > 0 && v2 > 0 && result < 0 {
@ -416,15 +436,35 @@ func intPlusIntWithError(v1, v2 int64) (numeric, error) {
return numeric{typ: Int64, ival: result}, nil
overflow:
return numeric{}, vterrors.Errorf(vtrpcpb.Code_)
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2)
}
*/
func uintPlusInt(v1 uint64, v2 int64) (numeric, error) {
func uintPlusInt(v1 uint64, v2 int64) numeric {
//if v2 < 0 {
// return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "cannot add a negative number to an unsigned integer: %d, %d", v1, v2)
// }
return uintPlusUint(v1, uint64(v2)), nil
return uintPlusUint(v1, uint64(v2))
}
func uintPlusIntWithError(v1 uint64, v2 int64) (numeric, error) {
if v2 >= math.MaxInt64 && v1 > 0 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2)
}
if v1 >= math.MaxUint64 && v2 > 0 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2)
}
//result := int64(v1) + v2
//return numeric{typ: Int64, ival: result}, nil
// readon to convert to int -> uint is because for numeric operators (such as + or -)
//where one of the operands is an unsigned integer, the result is unsigned by default.
return uintPlusUintWithError(v1, uint64(v2))
}
func uintPlusUint(v1, v2 uint64) numeric {
@ -436,6 +476,14 @@ func uintPlusUint(v1, v2 uint64) numeric {
return numeric{typ: Uint64, uval: result}
}
func uintPlusUintWithError(v1, v2 uint64) (numeric, error) {
result := v1 + v2
if result < v2 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2)
}
return numeric{typ: Uint64, uval: result}, nil
}
func floatPlusAny(v1 float64, v2 numeric) numeric {
switch v2.typ {
case Int64:

Просмотреть файл

@ -14,6 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
//-9,223,372,036,854,775,808
//-18,446,744,073,709,551,615
package sqltypes
import (
@ -28,6 +31,109 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)
func TestAddition(t *testing.T) {
tcases := []struct {
v1, v2 Value
out Value
err error
}{{
//All Nulls
v1: NULL,
v2: NULL,
out: NULL,
}, {
// First value null.
v1: NewInt32(1),
v2: NULL,
out: NULL,
}, {
// Second value null.
v1: NULL,
v2: NewInt32(1),
out: NULL,
}, {
// case with negatives
v1: NewInt64(-1),
v2: NewInt64(-2),
out: NewInt64(-3),
}, {
// testing for overflow int64
v1: NewInt64(9223372036854775807),
v2: NewUint64(2),
err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 2 + 9223372036854775807"),
}, {
v1: NewInt64(-2),
v2: NewUint64(1),
//out: NewInt64(-1),
out: NewUint64(18446744073709551615),
}, {
v1: NewInt64(9223372036854775807),
v2: NewInt64(-2),
out: NewInt64(9223372036854775805),
}, {
//Normal case
v1: NewUint64(1),
v2: NewUint64(2),
out: NewUint64(3),
}, {
//testing for overflow uint64
v1: NewUint64(18446744073709551615),
v2: NewUint64(2),
err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"),
}, {
//int64 underflow
v1: NewInt64(-9223372036854775807),
v2: NewInt64(-2),
err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775807 + -2"),
}, {
//checking int64 max value can be returned
v1: NewInt64(9223372036854775807),
v2: NewUint64(0),
out: NewUint64(9223372036854775807),
}, {
// testing whether uint64 max value can be returned
v1: NewUint64(18446744073709551615),
v2: NewInt64(0),
out: NewUint64(18446744073709551615),
}, {
v1: NewInt64(-3),
v2: NewUint64(1),
out: NewUint64(18446744073709551614),
}, {
//how is this okay? Because v1 is greater than max int64 value
v1: NewUint64(9223372036854775808),
v2: NewInt64(1),
out: NewUint64(9223372036854775809),
}}
for _, tcase := range tcases {
got, err := Addition(tcase.v1, tcase.v2)
if !vterrors.Equals(err, tcase.err) {
t.Errorf("Addition(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err))
}
if tcase.err != nil {
continue
}
if !reflect.DeepEqual(got, tcase.out) {
t.Errorf("Addition(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out))
}
}
}
func TestAdd(t *testing.T) {
tcases := []struct {
v1, v2 Value
@ -637,13 +743,15 @@ func TestAddNumeric(t *testing.T) {
out: numeric{typ: Float64, fval: 18446744073709551617},
}}
for _, tcase := range tcases {
got, err := addNumeric(tcase.v1, tcase.v2)
if !vterrors.Equals(err, tcase.err) {
t.Errorf("addNumeric(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err))
}
if tcase.err != nil {
continue
}
got := addNumeric(tcase.v1, tcase.v2)
/*
if !vterrors.Equals(err, tcase.err) {
t.Errorf("addNumeric(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err))
}
if tcase.err != nil {
continue
}
*/
if got != tcase.out {
t.Errorf("addNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out)