fix intermittent segfault under mysql error conditions

This commit is contained in:
Mike Solomon 2012-07-06 00:33:31 -07:00
Родитель 2fdd372e17
Коммит 898507dd0d
1 изменённых файлов: 118 добавлений и 21 удалений

Просмотреть файл

@ -4,8 +4,105 @@
package mysql
// #include <stdlib.h>
// #include <mysql.h>
/*
#include <stdlib.h>
#include <mysql.h>
// The vt_mysql_* functions are needed to honor the mysql client library's
// assumption that all threads using mysql functions have first called
// my_thread_init().
//
// Any goroutine may run on any thread at any time. However, a Cgo call
// is guaranteed to run on a single thread for its duration. Therefore,
// calling my_thread_init() before every actual operation is sufficient.
// Multiple calls to my_thread_init() are guarded interanlly in mysql and
// the call is cheap.
MYSQL* vt_mysql_real_connect(MYSQL *mysql, const char *host, const char *user, const char *passwd, const char *db, unsigned int port, const char *unix_socket, unsigned long client_flag) {
my_thread_init();
return mysql_real_connect(mysql, host, user, passwd, db, port, unix_socket, client_flag);
}
int vt_mysql_set_character_set(MYSQL *mysql, const char *csname) {
my_thread_init();
return mysql_set_character_set(mysql, csname);
}
int vt_mysql_real_query(MYSQL *mysql, const char *stmt_str, unsigned long length) {
my_thread_init();
return mysql_real_query(mysql, stmt_str, length);
}
MYSQL_RES *vt_mysql_store_result(MYSQL *mysql) {
my_thread_init();
return mysql_store_result(mysql);
}
unsigned int vt_mysql_field_count(MYSQL *mysql) {
my_thread_init();
return mysql_field_count(mysql);
}
my_ulonglong vt_mysql_affected_rows(MYSQL *mysql) {
my_thread_init();
return mysql_affected_rows(mysql);
}
my_ulonglong vt_mysql_insert_id(MYSQL *mysql) {
my_thread_init();
return mysql_insert_id(mysql);
}
void vt_mysql_free_result(MYSQL_RES *result) {
my_thread_init();
mysql_free_result(result);
}
unsigned long vt_mysql_thread_id(MYSQL *mysql) {
my_thread_init();
return mysql_thread_id(mysql);
}
void vt_mysql_close(MYSQL *mysql) {
my_thread_init();
mysql_close(mysql);
}
MYSQL_FIELD *vt_mysql_fetch_fields(MYSQL_RES *result) {
my_thread_init();
return mysql_fetch_fields(result);
}
unsigned int vt_mysql_num_fields(MYSQL_RES *result) {
my_thread_init();
return mysql_num_fields(result);
}
my_ulonglong vt_mysql_num_rows(MYSQL_RES *result) {
my_thread_init();
return mysql_num_rows(result);
}
MYSQL_ROW vt_mysql_fetch_row(MYSQL_RES *result) {
my_thread_init();
return mysql_fetch_row(result);
}
unsigned long *vt_mysql_fetch_lengths(MYSQL_RES *result) {
my_thread_init();
return mysql_fetch_lengths(result);
}
unsigned int vt_mysql_errno(MYSQL *mysql) {
my_thread_init();
return mysql_errno(mysql);
}
const char *vt_mysql_error(MYSQL *mysql) {
my_thread_init();
return mysql_error(mysql);
}
*/
import "C"
import (
@ -77,12 +174,12 @@ func Connect(info map[string]interface{}) (conn *Connection, err error) {
conn = &Connection{}
conn.handle = C.mysql_init(nil)
if C.mysql_real_connect(conn.handle, host, uname, pass, dbname, C.uint(port), unix_socket, 0) == nil {
if C.vt_mysql_real_connect(conn.handle, host, uname, pass, dbname, C.uint(port), unix_socket, 0) == nil {
defer conn.Close()
return nil, conn.lastError(nil)
}
if C.mysql_set_character_set(conn.handle, charset) != 0 {
if C.vt_mysql_set_character_set(conn.handle, charset) != 0 {
defer conn.Close()
return nil, conn.lastError(nil)
}
@ -93,23 +190,23 @@ func (self *Connection) ExecuteFetch(query []byte, maxrows int, wantfields bool)
defer handleError(&err)
self.validate()
if C.mysql_real_query(self.handle, (*C.char)(unsafe.Pointer(&query[0])), C.ulong(len(query))) != 0 {
if C.vt_mysql_real_query(self.handle, (*C.char)(unsafe.Pointer(&query[0])), C.ulong(len(query))) != 0 {
return nil, self.lastError(query)
}
result := C.mysql_store_result(self.handle)
result := C.vt_mysql_store_result(self.handle)
if result == nil {
if int(C.mysql_field_count(self.handle)) != 0 { // Query was supposed to return data, but it didn't
if int(C.vt_mysql_field_count(self.handle)) != 0 { // Query was supposed to return data, but it didn't
return nil, self.lastError(query)
}
qr = &QueryResult{}
qr.RowsAffected = uint64(C.mysql_affected_rows(self.handle))
qr.InsertId = uint64(C.mysql_insert_id(self.handle))
qr.RowsAffected = uint64(C.vt_mysql_affected_rows(self.handle))
qr.InsertId = uint64(C.vt_mysql_insert_id(self.handle))
return qr, nil
}
defer C.mysql_free_result(result)
defer C.vt_mysql_free_result(result)
qr = &QueryResult{}
qr.RowsAffected = uint64(C.mysql_affected_rows(self.handle))
qr.RowsAffected = uint64(C.vt_mysql_affected_rows(self.handle))
if qr.RowsAffected > uint64(maxrows) {
return nil, &SqlError{0, fmt.Sprintf("Row count exceeded %d", maxrows), string(query)}
}
@ -124,14 +221,14 @@ func (self *Connection) Id() int64 {
if self.handle == nil {
return 0
}
return int64(C.mysql_thread_id(self.handle))
return int64(C.vt_mysql_thread_id(self.handle))
}
func (self *Connection) Close() {
if self.handle == nil {
return
}
C.mysql_close(self.handle)
C.vt_mysql_close(self.handle)
self.handle = nil
}
@ -140,8 +237,8 @@ func (self *Connection) IsClosed() bool {
}
func (self *Connection) buildFields(result *C.MYSQL_RES) (fields []Field) {
nfields := int(C.mysql_num_fields(result))
cfieldsptr := C.mysql_fetch_fields(result)
nfields := int(C.vt_mysql_num_fields(result))
cfieldsptr := C.vt_mysql_fetch_fields(result)
cfields := (*[1 << 30]C.MYSQL_FIELD)(unsafe.Pointer(cfieldsptr))
arena := hack.NewStringArena(1024) // prealloc a reasonable amount of space
fields = make([]Field, nfields)
@ -155,12 +252,12 @@ func (self *Connection) buildFields(result *C.MYSQL_RES) (fields []Field) {
}
func (self *Connection) fetchAll(result *C.MYSQL_RES) (rows [][]interface{}) {
rowCount := int(C.mysql_num_rows(result))
rowCount := int(C.vt_mysql_num_rows(result))
if rowCount == 0 {
return nil
}
rows = make([][]interface{}, rowCount)
colCount := int(C.mysql_num_fields(result))
colCount := int(C.vt_mysql_num_fields(result))
for i := 0; i < rowCount; i++ {
rows[i] = self.fetchNext(result, colCount)
}
@ -168,12 +265,12 @@ func (self *Connection) fetchAll(result *C.MYSQL_RES) (rows [][]interface{}) {
}
func (self *Connection) fetchNext(result *C.MYSQL_RES, colCount int) (row []interface{}) {
rowPtr := (*[1 << 30]*[1 << 30]byte)(unsafe.Pointer(C.mysql_fetch_row(result)))
rowPtr := (*[1 << 30]*[1 << 30]byte)(unsafe.Pointer(C.vt_mysql_fetch_row(result)))
if rowPtr == nil {
panic(self.lastError(nil))
}
row = make([]interface{}, colCount)
lengths := (*[1 << 30]uint64)(unsafe.Pointer(C.mysql_fetch_lengths(result)))
lengths := (*[1 << 30]uint64)(unsafe.Pointer(C.vt_mysql_fetch_lengths(result)))
totalLength := uint64(0)
for i := 0; i < colCount; i++ {
totalLength += (*lengths)[i]
@ -191,8 +288,8 @@ func (self *Connection) fetchNext(result *C.MYSQL_RES, colCount int) (row []inte
}
func (self *Connection) lastError(query []byte) error {
if err := C.mysql_error(self.handle); *err != 0 {
return &SqlError{Num: int(C.mysql_errno(self.handle)), Message: C.GoString(err), Query: string(query)}
if err := C.vt_mysql_error(self.handle); *err != 0 {
return &SqlError{Num: int(C.vt_mysql_errno(self.handle)), Message: C.GoString(err), Query: string(query)}
}
return &SqlError{0, "Dummy", string(query)}
}