Fix bug in vtexplain around JOINs (#12383)

Signed-off-by: Andres Taylor <andres@planetscale.com>
Co-authored-by: Andres Taylor <andres@planetscale.com>
This commit is contained in:
vitess-bot[bot] 2023-03-20 10:35:10 +02:00 коммит произвёл GitHub
Родитель e3d889be05
Коммит 3b0ccd0a02
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 220 добавлений и 176 удалений

Просмотреть файл

@ -45,3 +45,8 @@ select ID from t1
1 ks_unsharded/-: select ID from t1 limit 10001
----------------------------------------------------------------------
select t1.id, t2.c2 from t1 join t2 on t1.id = t2.t1_id where t2.c2 in (1)
1 ks_unsharded/-: select t1.id, t2.c2 from t1 join t2 on t1.id = t2.t1_id where t2.c2 in (1) limit 10001
----------------------------------------------------------------------

6
go/vt/vtexplain/testdata/test-schema.sql поставляемый
Просмотреть файл

@ -4,6 +4,12 @@ create table t1 (
floatval float not null default 0,
primary key (id)
);
-- t2 is the second table of the unsharded test keyspace. The JOIN test query
-- matches its t1_id column against t1.id and filters on c2.
create table t2 (
id bigint(20) unsigned not null,
-- join column: compared with t1.id in the JOIN test query
t1_id bigint(20) unsigned not null default 0,
-- nullable column used in the `where t2.c2 in (1)` predicate of the test query
c2 bigint(20) null,
primary key (id)
);
create table user (
id bigint,

1
go/vt/vtexplain/testdata/test-vschema.json поставляемый
Просмотреть файл

@ -3,6 +3,7 @@
"sharded": false,
"tables": {
"t1": {},
"t2": {},
"table_not_in_schema": {}
}
},

Просмотреть файл

@ -6,3 +6,4 @@ update t1 set floatval = 9.99;
delete from t1 where id = 100;
insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14;
select ID from t1;
select t1.id, t2.c2 from t1 join t2 on t1.id = t2.t1_id where t2.c2 in (1);

Просмотреть файл

@ -504,185 +504,11 @@ func (t *explainTablet) HandleQuery(c *mysql.Conn, query string, callback func(*
}
switch sqlparser.Preview(query) {
case sqlparser.StmtSelect:
// Parse the select statement to figure out the table and columns
// that were referenced so that the synthetic response has the
// expected field names and types.
stmt, err := sqlparser.Parse(query)
var err error
result, err = t.handleSelect(query)
if err != nil {
return err
}
var selStmt *sqlparser.Select
switch stmt := stmt.(type) {
case *sqlparser.Select:
selStmt = stmt
case *sqlparser.Union:
selStmt = sqlparser.GetFirstSelect(stmt)
default:
return fmt.Errorf("vtexplain: unsupported statement type +%v", reflect.TypeOf(stmt))
}
// Gen4 supports more complex queries so we now need to
// handle multiple FROM clauses
tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From))
for _, from := range selStmt.From {
tables = append(tables, getTables(from)...)
}
tableColumnMap := map[sqlparser.IdentifierCS]map[string]querypb.Type{}
for _, table := range tables {
if table == nil {
continue
}
tableName := sqlparser.String(sqlparser.GetTableName(table.Expr))
columns, exists := t.vte.getGlobalTabletEnv().tableColumns[tableName]
if !exists && tableName != "" && tableName != "dual" {
return fmt.Errorf("unable to resolve table name %s", tableName)
}
colTypeMap := map[string]querypb.Type{}
if table.As.IsEmpty() {
tableColumnMap[sqlparser.GetTableName(table.Expr)] = colTypeMap
} else {
tableColumnMap[table.As] = colTypeMap
}
for k, v := range columns {
if colType, exists := colTypeMap[k]; exists {
if colType != v {
return fmt.Errorf("column type mismatch for column : %s, types: %d vs %d", k, colType, v)
}
continue
}
colTypeMap[k] = v
}
}
colNames := make([]string, 0, 4)
colTypes := make([]querypb.Type, 0, 4)
for _, node := range selStmt.SelectExprs {
switch node := node.(type) {
case *sqlparser.AliasedExpr:
colNames, colTypes = inferColTypeFromExpr(node.Expr, tableColumnMap, colNames, colTypes)
case *sqlparser.StarExpr:
if node.TableName.Name.IsEmpty() {
// SELECT *
for _, colTypeMap := range tableColumnMap {
for col, colType := range colTypeMap {
colNames = append(colNames, col)
colTypes = append(colTypes, colType)
}
}
} else {
// SELECT tableName.*
colTypeMap := tableColumnMap[node.TableName.Name]
for col, colType := range colTypeMap {
colNames = append(colNames, col)
colTypes = append(colTypes, colType)
}
}
}
}
// the query against lookup table is in-query, handle it specifically
var inColName string
inVal := make([]sqltypes.Value, 0, 10)
rowCount := 1
if selStmt.Where != nil {
switch v := selStmt.Where.Expr.(type) {
case *sqlparser.ComparisonExpr:
if v.Operator == sqlparser.InOp {
switch c := v.Left.(type) {
case *sqlparser.ColName:
colName := strings.ToLower(c.Name.String())
colType := tableColumnMap[sqlparser.GetTableName(selStmt.From[0].(*sqlparser.AliasedTableExpr).Expr)][colName]
switch values := v.Right.(type) {
case sqlparser.ValTuple:
for _, val := range values {
switch v := val.(type) {
case *sqlparser.Literal:
value, err := evalengine.LiteralToValue(v)
if err != nil {
return err
}
// Cast the value in the tuple to the expected value of the column
castedValue, err := evalengine.Cast(value, colType)
if err != nil {
return err
}
// Check if we have a duplicate value
isNewValue := true
for _, v := range inVal {
result, err := evalengine.NullsafeCompare(v, value, collations.Default())
if err != nil {
return err
}
if result == 0 {
isNewValue = false
break
}
}
if isNewValue {
inVal = append(inVal, castedValue)
}
}
}
rowCount = len(inVal)
}
inColName = strings.ToLower(c.Name.String())
}
}
}
}
fields := make([]*querypb.Field, len(colNames))
rows := make([][]sqltypes.Value, 0, rowCount)
for i, col := range colNames {
colType := colTypes[i]
fields[i] = &querypb.Field{
Name: col,
Type: colType,
}
}
for j := 0; j < rowCount; j++ {
values := make([]sqltypes.Value, len(colNames))
for i, col := range colNames {
// Generate a fake value for the given column. For the column in the IN clause,
// use the provided values in the query, For numeric types,
// use the column index. For all other types, just shortcut to using
// a string type that encodes the column name + index.
colType := colTypes[i]
if len(inVal) > j && col == inColName {
values[i], _ = sqltypes.NewValue(querypb.Type_VARBINARY, inVal[j].Raw())
} else if sqltypes.IsIntegral(colType) {
values[i] = sqltypes.NewInt32(int32(i + 1))
} else if sqltypes.IsFloat(colType) {
values[i] = sqltypes.NewFloat64(1.0 + float64(i))
} else {
values[i] = sqltypes.NewVarChar(fmt.Sprintf("%s_val_%d", col, i+1))
}
}
rows = append(rows, values)
}
result = &sqltypes.Result{
Fields: fields,
InsertID: 0,
Rows: rows,
}
resultJSON, _ := json.MarshalIndent(result, "", " ")
log.V(100).Infof("query %s result %s\n", query, string(resultJSON))
case sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtSet,
sqlparser.StmtSavepoint, sqlparser.StmtSRollback, sqlparser.StmtRelease:
result = &sqltypes.Result{}
@ -699,6 +525,211 @@ func (t *explainTablet) HandleQuery(c *mysql.Conn, query string, callback func(*
return callback(result)
}
// handleSelect builds the synthetic result the explain tablet returns for a
// SELECT statement.
//
// The query is parsed (for a UNION only the first SELECT is inspected), the
// tables referenced in its FROM clauses are resolved against the global
// tablet environment to learn the column types, and a result with fake rows
// is generated whose field names and types match the select expressions.
//
// It returns an error when the query cannot be parsed, when the statement is
// neither a SELECT nor a UNION, when a referenced table is unknown, or when
// the WHERE analysis fails. The returned result is nil whenever the error is
// non-nil.
func (t *explainTablet) handleSelect(query string) (*sqltypes.Result, error) {
	// Parse the select statement to figure out the table and columns
	// that were referenced so that the synthetic response has the
	// expected field names and types.
	stmt, err := sqlparser.Parse(query)
	if err != nil {
		return nil, err
	}

	var selStmt *sqlparser.Select
	switch stmt := stmt.(type) {
	case *sqlparser.Select:
		selStmt = stmt
	case *sqlparser.Union:
		selStmt = sqlparser.GetFirstSelect(stmt)
	default:
		return nil, fmt.Errorf("vtexplain: unsupported statement type +%v", reflect.TypeOf(stmt))
	}

	// Gen4 supports more complex queries so we now need to
	// handle multiple FROM clauses.
	// The slice starts empty (length 0) with room for one entry per FROM
	// expression, so that append does not leave nil placeholders in front.
	tables := make([]*sqlparser.AliasedTableExpr, 0, len(selStmt.From))
	for _, from := range selStmt.From {
		tables = append(tables, getTables(from)...)
	}

	// tableColumnMap maps every referenced table (keyed by its alias when it
	// has one, by its name otherwise) to a column-name -> column-type map.
	tableColumnMap := map[sqlparser.IdentifierCS]map[string]querypb.Type{}
	for _, table := range tables {
		if table == nil {
			continue
		}

		tableName := sqlparser.String(sqlparser.GetTableName(table.Expr))
		columns, exists := t.vte.getGlobalTabletEnv().tableColumns[tableName]
		if !exists && tableName != "" && tableName != "dual" {
			return nil, fmt.Errorf("unable to resolve table name %s", tableName)
		}

		// Each table gets its own private copy of the column types.
		colTypeMap := make(map[string]querypb.Type, len(columns))
		for k, v := range columns {
			colTypeMap[k] = v
		}

		if table.As.IsEmpty() {
			tableColumnMap[sqlparser.GetTableName(table.Expr)] = colTypeMap
		} else {
			tableColumnMap[table.As] = colTypeMap
		}
	}

	colNames, colTypes := t.analyzeExpressions(selStmt, tableColumnMap)

	inColName, inVal, rowCount, _, err := t.analyzeWhere(selStmt, tableColumnMap)
	if err != nil {
		return nil, err
	}

	fields := make([]*querypb.Field, len(colNames))
	for i, col := range colNames {
		fields[i] = &querypb.Field{
			Name: col,
			Type: colTypes[i],
		}
	}

	rows := make([][]sqltypes.Value, 0, rowCount)
	for j := 0; j < rowCount; j++ {
		values := make([]sqltypes.Value, len(colNames))
		for i, col := range colNames {
			// Generate a fake value for the given column. For the column in the IN clause,
			// use the provided values in the query, For numeric types,
			// use the column index. For all other types, just shortcut to using
			// a string type that encodes the column name + index.
			colType := colTypes[i]
			if len(inVal) > j && col == inColName {
				// NOTE(review): the error is dropped on purpose here as in the
				// original code; presumably NewValue cannot fail for VARBINARY — confirm.
				values[i], _ = sqltypes.NewValue(querypb.Type_VARBINARY, inVal[j].Raw())
			} else if sqltypes.IsIntegral(colType) {
				values[i] = sqltypes.NewInt32(int32(i + 1))
			} else if sqltypes.IsFloat(colType) {
				values[i] = sqltypes.NewFloat64(1.0 + float64(i))
			} else {
				values[i] = sqltypes.NewVarChar(fmt.Sprintf("%s_val_%d", col, i+1))
			}
		}
		rows = append(rows, values)
	}

	result := &sqltypes.Result{
		Fields:   fields,
		InsertID: 0,
		Rows:     rows,
	}

	// The JSON rendering is only used for verbose logging, so a marshalling
	// failure is deliberately ignored.
	resultJSON, _ := json.MarshalIndent(result, "", " ")
	log.V(100).Infof("query %s result %s\n", query, string(resultJSON))
	return result, nil
}
// analyzeWhere inspects the WHERE clause of selStmt for the lookup-style
// pattern `<column> IN (<literal>, ...)`.
//
// When the WHERE clause is absent or is not such an IN comparison, it returns
// an empty inColName, a nil inVal and rowCount 1. Otherwise it returns the
// lower-cased column name, the distinct literal values of the tuple cast to
// the column type, and rowCount equal to the number of those values, so that
// the caller generates one synthetic row per IN value.
//
// The result return value is never populated by this function; it is always
// nil. A non-nil error is returned when a literal cannot be converted, cast
// or compared.
func (t *explainTablet) analyzeWhere(selStmt *sqlparser.Select, tableColumnMap map[sqlparser.IdentifierCS]map[string]querypb.Type) (inColName string, inVal []sqltypes.Value, rowCount int, result *sqltypes.Result, err error) {
	// the query against lookup table is in-query, handle it specifically
	rowCount = 1
	if selStmt.Where == nil {
		return "", nil, rowCount, nil, nil
	}

	comparison, ok := selStmt.Where.Expr.(*sqlparser.ComparisonExpr)
	if !ok || comparison.Operator != sqlparser.InOp {
		return "", nil, rowCount, nil, nil
	}

	column, ok := comparison.Left.(*sqlparser.ColName)
	if !ok {
		return "", nil, rowCount, nil, nil
	}
	colName := strings.ToLower(column.Name.String())

	// Default to VARCHAR when the column type cannot be resolved, e.g. when
	// the first FROM expression is a join rather than a plain table.
	// NOTE(review): the lookup uses the table name, while handleSelect keys
	// tableColumnMap by alias for aliased tables — confirm whether aliased
	// tables should resolve here as well.
	colType := querypb.Type_VARCHAR
	if len(selStmt.From) > 0 {
		if aliased, isAliased := selStmt.From[0].(*sqlparser.AliasedTableExpr); isAliased {
			if typ, found := tableColumnMap[sqlparser.GetTableName(aliased.Expr)][colName]; found {
				colType = typ
			}
		}
	}

	values, ok := comparison.Right.(sqlparser.ValTuple)
	if !ok {
		return "", nil, rowCount, nil, nil
	}

	for _, val := range values {
		lit, isLiteral := val.(*sqlparser.Literal)
		if !isLiteral {
			continue
		}

		value, convErr := evalengine.LiteralToValue(lit)
		if convErr != nil {
			return "", nil, 0, nil, convErr
		}

		// Cast the value in the tuple to the expected value of the column
		castedValue, castErr := evalengine.Cast(value, colType)
		if castErr != nil {
			return "", nil, 0, nil, castErr
		}

		// Check if we have a duplicate value
		isNewValue := true
		for _, seen := range inVal {
			cmp, cmpErr := evalengine.NullsafeCompare(seen, value, collations.Default())
			if cmpErr != nil {
				return "", nil, 0, nil, cmpErr
			}
			if cmp == 0 {
				isNewValue = false
				break
			}
		}
		if isNewValue {
			inVal = append(inVal, castedValue)
		}
	}

	// One synthetic row per distinct IN value: handleSelect indexes inVal by
	// row number, so the row count has to follow the number of values.
	rowCount = len(inVal)

	return colName, inVal, rowCount, nil, nil
}
// analyzeExpressions derives the output column names and their types from the
// select expressions of selStmt.
//
// Aliased expressions are delegated to inferColTypeFromExpr. A bare `*`
// contributes every column of every table in tableColumnMap, and `table.*`
// contributes the columns of that one table. Columns coming from a star
// expression are appended in map iteration order. The two returned slices are
// parallel: the i-th type belongs to the i-th name.
func (t *explainTablet) analyzeExpressions(selStmt *sqlparser.Select, tableColumnMap map[sqlparser.IdentifierCS]map[string]querypb.Type) ([]string, []querypb.Type) {
	names := make([]string, 0, 4)
	kinds := make([]querypb.Type, 0, 4)

	// appendColumns adds every column of a single table to the output.
	appendColumns := func(columns map[string]querypb.Type) {
		for name, kind := range columns {
			names = append(names, name)
			kinds = append(kinds, kind)
		}
	}

	for _, selectExpr := range selStmt.SelectExprs {
		if aliased, ok := selectExpr.(*sqlparser.AliasedExpr); ok {
			names, kinds = inferColTypeFromExpr(aliased.Expr, tableColumnMap, names, kinds)
			continue
		}

		star, ok := selectExpr.(*sqlparser.StarExpr)
		if !ok {
			// Any other kind of select expression contributes no columns.
			continue
		}

		if !star.TableName.Name.IsEmpty() {
			// SELECT tableName.*
			appendColumns(tableColumnMap[star.TableName.Name])
			continue
		}

		// SELECT *
		for _, columns := range tableColumnMap {
			appendColumns(columns)
		}
	}

	return names, kinds
}
func getTables(node sqlparser.SQLNode) []*sqlparser.AliasedTableExpr {
var tables []*sqlparser.AliasedTableExpr
switch expr := node.(type) {