зеркало из https://github.com/github/vitess-gh.git
fix intermittent segfault under mysql error conditions
This commit is contained in:
Родитель
2fdd372e17
Коммит
898507dd0d
|
@ -4,8 +4,105 @@
|
|||
|
||||
package mysql
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include <mysql.h>
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include <mysql.h>
|
||||
|
||||
// The vt_mysql_* functions are needed to honor the mysql client library's
|
||||
// assumption that all threads using mysql functions have first called
|
||||
// my_thread_init().
|
||||
//
|
||||
// Any goroutine may run on any thread at any time. However, a Cgo call
|
||||
// is guaranteed to run on a single thread for its duration. Therefore,
|
||||
// calling my_thread_init() before every actual operation is sufficient.
|
||||
// Multiple calls to my_thread_init() are guarded interanlly in mysql and
|
||||
// the call is cheap.
|
||||
|
||||
MYSQL* vt_mysql_real_connect(MYSQL *mysql, const char *host, const char *user, const char *passwd, const char *db, unsigned int port, const char *unix_socket, unsigned long client_flag) {
|
||||
my_thread_init();
|
||||
return mysql_real_connect(mysql, host, user, passwd, db, port, unix_socket, client_flag);
|
||||
}
|
||||
|
||||
int vt_mysql_set_character_set(MYSQL *mysql, const char *csname) {
|
||||
my_thread_init();
|
||||
return mysql_set_character_set(mysql, csname);
|
||||
}
|
||||
|
||||
int vt_mysql_real_query(MYSQL *mysql, const char *stmt_str, unsigned long length) {
|
||||
my_thread_init();
|
||||
return mysql_real_query(mysql, stmt_str, length);
|
||||
}
|
||||
|
||||
MYSQL_RES *vt_mysql_store_result(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
return mysql_store_result(mysql);
|
||||
}
|
||||
|
||||
unsigned int vt_mysql_field_count(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
return mysql_field_count(mysql);
|
||||
}
|
||||
|
||||
my_ulonglong vt_mysql_affected_rows(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
return mysql_affected_rows(mysql);
|
||||
}
|
||||
|
||||
my_ulonglong vt_mysql_insert_id(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
return mysql_insert_id(mysql);
|
||||
}
|
||||
|
||||
void vt_mysql_free_result(MYSQL_RES *result) {
|
||||
my_thread_init();
|
||||
mysql_free_result(result);
|
||||
}
|
||||
|
||||
unsigned long vt_mysql_thread_id(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
return mysql_thread_id(mysql);
|
||||
}
|
||||
|
||||
void vt_mysql_close(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
mysql_close(mysql);
|
||||
}
|
||||
|
||||
MYSQL_FIELD *vt_mysql_fetch_fields(MYSQL_RES *result) {
|
||||
my_thread_init();
|
||||
return mysql_fetch_fields(result);
|
||||
}
|
||||
|
||||
unsigned int vt_mysql_num_fields(MYSQL_RES *result) {
|
||||
my_thread_init();
|
||||
return mysql_num_fields(result);
|
||||
}
|
||||
|
||||
my_ulonglong vt_mysql_num_rows(MYSQL_RES *result) {
|
||||
my_thread_init();
|
||||
return mysql_num_rows(result);
|
||||
}
|
||||
|
||||
MYSQL_ROW vt_mysql_fetch_row(MYSQL_RES *result) {
|
||||
my_thread_init();
|
||||
return mysql_fetch_row(result);
|
||||
}
|
||||
|
||||
unsigned long *vt_mysql_fetch_lengths(MYSQL_RES *result) {
|
||||
my_thread_init();
|
||||
return mysql_fetch_lengths(result);
|
||||
}
|
||||
|
||||
unsigned int vt_mysql_errno(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
return mysql_errno(mysql);
|
||||
}
|
||||
|
||||
const char *vt_mysql_error(MYSQL *mysql) {
|
||||
my_thread_init();
|
||||
return mysql_error(mysql);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
|
@ -77,12 +174,12 @@ func Connect(info map[string]interface{}) (conn *Connection, err error) {
|
|||
|
||||
conn = &Connection{}
|
||||
conn.handle = C.mysql_init(nil)
|
||||
if C.mysql_real_connect(conn.handle, host, uname, pass, dbname, C.uint(port), unix_socket, 0) == nil {
|
||||
if C.vt_mysql_real_connect(conn.handle, host, uname, pass, dbname, C.uint(port), unix_socket, 0) == nil {
|
||||
defer conn.Close()
|
||||
return nil, conn.lastError(nil)
|
||||
}
|
||||
|
||||
if C.mysql_set_character_set(conn.handle, charset) != 0 {
|
||||
if C.vt_mysql_set_character_set(conn.handle, charset) != 0 {
|
||||
defer conn.Close()
|
||||
return nil, conn.lastError(nil)
|
||||
}
|
||||
|
@ -93,23 +190,23 @@ func (self *Connection) ExecuteFetch(query []byte, maxrows int, wantfields bool)
|
|||
defer handleError(&err)
|
||||
self.validate()
|
||||
|
||||
if C.mysql_real_query(self.handle, (*C.char)(unsafe.Pointer(&query[0])), C.ulong(len(query))) != 0 {
|
||||
if C.vt_mysql_real_query(self.handle, (*C.char)(unsafe.Pointer(&query[0])), C.ulong(len(query))) != 0 {
|
||||
return nil, self.lastError(query)
|
||||
}
|
||||
|
||||
result := C.mysql_store_result(self.handle)
|
||||
result := C.vt_mysql_store_result(self.handle)
|
||||
if result == nil {
|
||||
if int(C.mysql_field_count(self.handle)) != 0 { // Query was supposed to return data, but it didn't
|
||||
if int(C.vt_mysql_field_count(self.handle)) != 0 { // Query was supposed to return data, but it didn't
|
||||
return nil, self.lastError(query)
|
||||
}
|
||||
qr = &QueryResult{}
|
||||
qr.RowsAffected = uint64(C.mysql_affected_rows(self.handle))
|
||||
qr.InsertId = uint64(C.mysql_insert_id(self.handle))
|
||||
qr.RowsAffected = uint64(C.vt_mysql_affected_rows(self.handle))
|
||||
qr.InsertId = uint64(C.vt_mysql_insert_id(self.handle))
|
||||
return qr, nil
|
||||
}
|
||||
defer C.mysql_free_result(result)
|
||||
defer C.vt_mysql_free_result(result)
|
||||
qr = &QueryResult{}
|
||||
qr.RowsAffected = uint64(C.mysql_affected_rows(self.handle))
|
||||
qr.RowsAffected = uint64(C.vt_mysql_affected_rows(self.handle))
|
||||
if qr.RowsAffected > uint64(maxrows) {
|
||||
return nil, &SqlError{0, fmt.Sprintf("Row count exceeded %d", maxrows), string(query)}
|
||||
}
|
||||
|
@ -124,14 +221,14 @@ func (self *Connection) Id() int64 {
|
|||
if self.handle == nil {
|
||||
return 0
|
||||
}
|
||||
return int64(C.mysql_thread_id(self.handle))
|
||||
return int64(C.vt_mysql_thread_id(self.handle))
|
||||
}
|
||||
|
||||
func (self *Connection) Close() {
|
||||
if self.handle == nil {
|
||||
return
|
||||
}
|
||||
C.mysql_close(self.handle)
|
||||
C.vt_mysql_close(self.handle)
|
||||
self.handle = nil
|
||||
}
|
||||
|
||||
|
@ -140,8 +237,8 @@ func (self *Connection) IsClosed() bool {
|
|||
}
|
||||
|
||||
func (self *Connection) buildFields(result *C.MYSQL_RES) (fields []Field) {
|
||||
nfields := int(C.mysql_num_fields(result))
|
||||
cfieldsptr := C.mysql_fetch_fields(result)
|
||||
nfields := int(C.vt_mysql_num_fields(result))
|
||||
cfieldsptr := C.vt_mysql_fetch_fields(result)
|
||||
cfields := (*[1 << 30]C.MYSQL_FIELD)(unsafe.Pointer(cfieldsptr))
|
||||
arena := hack.NewStringArena(1024) // prealloc a reasonable amount of space
|
||||
fields = make([]Field, nfields)
|
||||
|
@ -155,12 +252,12 @@ func (self *Connection) buildFields(result *C.MYSQL_RES) (fields []Field) {
|
|||
}
|
||||
|
||||
func (self *Connection) fetchAll(result *C.MYSQL_RES) (rows [][]interface{}) {
|
||||
rowCount := int(C.mysql_num_rows(result))
|
||||
rowCount := int(C.vt_mysql_num_rows(result))
|
||||
if rowCount == 0 {
|
||||
return nil
|
||||
}
|
||||
rows = make([][]interface{}, rowCount)
|
||||
colCount := int(C.mysql_num_fields(result))
|
||||
colCount := int(C.vt_mysql_num_fields(result))
|
||||
for i := 0; i < rowCount; i++ {
|
||||
rows[i] = self.fetchNext(result, colCount)
|
||||
}
|
||||
|
@ -168,12 +265,12 @@ func (self *Connection) fetchAll(result *C.MYSQL_RES) (rows [][]interface{}) {
|
|||
}
|
||||
|
||||
func (self *Connection) fetchNext(result *C.MYSQL_RES, colCount int) (row []interface{}) {
|
||||
rowPtr := (*[1 << 30]*[1 << 30]byte)(unsafe.Pointer(C.mysql_fetch_row(result)))
|
||||
rowPtr := (*[1 << 30]*[1 << 30]byte)(unsafe.Pointer(C.vt_mysql_fetch_row(result)))
|
||||
if rowPtr == nil {
|
||||
panic(self.lastError(nil))
|
||||
}
|
||||
row = make([]interface{}, colCount)
|
||||
lengths := (*[1 << 30]uint64)(unsafe.Pointer(C.mysql_fetch_lengths(result)))
|
||||
lengths := (*[1 << 30]uint64)(unsafe.Pointer(C.vt_mysql_fetch_lengths(result)))
|
||||
totalLength := uint64(0)
|
||||
for i := 0; i < colCount; i++ {
|
||||
totalLength += (*lengths)[i]
|
||||
|
@ -191,8 +288,8 @@ func (self *Connection) fetchNext(result *C.MYSQL_RES, colCount int) (row []inte
|
|||
}
|
||||
|
||||
func (self *Connection) lastError(query []byte) error {
|
||||
if err := C.mysql_error(self.handle); *err != 0 {
|
||||
return &SqlError{Num: int(C.mysql_errno(self.handle)), Message: C.GoString(err), Query: string(query)}
|
||||
if err := C.vt_mysql_error(self.handle); *err != 0 {
|
||||
return &SqlError{Num: int(C.vt_mysql_errno(self.handle)), Message: C.GoString(err), Query: string(query)}
|
||||
}
|
||||
return &SqlError{0, "Dummy", string(query)}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче