more matrix funcs
This commit is contained in:
Родитель
7d59591edd
Коммит
7f73baad1c
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2018 Unity Technologies Oy
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
63
xgdmatrix.go
63
xgdmatrix.go
|
@ -9,6 +9,7 @@ package xgboost
|
|||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
@ -23,6 +24,26 @@ type XGDMatrix struct {
|
|||
rows int
|
||||
}
|
||||
|
||||
// NumRow get number of rows.
|
||||
func (matrix *XGDMatrix) NumRow() (uint32, error) {
|
||||
var count C.ulong
|
||||
if err := checkError(C.XGDMatrixNumRow(matrix.handle, &count)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint32(count), nil
|
||||
}
|
||||
|
||||
// NumCol get number of cols.
|
||||
func (matrix *XGDMatrix) NumCol() (uint32, error) {
|
||||
var count C.ulong
|
||||
if err := checkError(C.XGDMatrixNumCol(matrix.handle, &count)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint32(count), nil
|
||||
}
|
||||
|
||||
// SetUIntInfo set uint32 vector to a content in info
|
||||
func (matrix *XGDMatrix) SetUIntInfo(field string, values []uint32) error {
|
||||
cstr := C.CString(field)
|
||||
|
@ -49,6 +70,48 @@ func (matrix *XGDMatrix) SetFloatInfo(field string, values []float32) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetFloatInfo get float info vector from matrix
|
||||
func (matrix *XGDMatrix) GetFloatInfo(field string) ([]float32, error) {
|
||||
cstr := C.CString(field)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
|
||||
var outLen C.ulong
|
||||
var outResult *C.float
|
||||
|
||||
if err := checkError(C.XGDMatrixGetFloatInfo(matrix.handle, cstr, &outLen, &outResult)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var list []float32
|
||||
sliceHeader := (*reflect.SliceHeader)((unsafe.Pointer(&list)))
|
||||
sliceHeader.Cap = int(outLen)
|
||||
sliceHeader.Len = int(outLen)
|
||||
sliceHeader.Data = uintptr(unsafe.Pointer(outResult))
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// GetUIntInfo get uint32 info vector from matrix
|
||||
func (matrix *XGDMatrix) GetUIntInfo(field string) ([]uint32, error) {
|
||||
cstr := C.CString(field)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
|
||||
var outLen C.ulong
|
||||
var outResult *C.uint
|
||||
|
||||
if err := checkError(C.XGDMatrixGetUIntInfo(matrix.handle, cstr, &outLen, &outResult)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var list []uint32
|
||||
sliceHeader := (*reflect.SliceHeader)((unsafe.Pointer(&list)))
|
||||
sliceHeader.Cap = int(outLen)
|
||||
sliceHeader.Len = int(outLen)
|
||||
sliceHeader.Data = uintptr(unsafe.Pointer(outResult))
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func xdgMatrixFinalizer(mat *XGDMatrix) {
|
||||
C.XGDMatrixFree(mat.handle)
|
||||
}
|
||||
|
|
|
@ -17,5 +17,37 @@ func TestXGDMatrix(t *testing.T) {
|
|||
t.Error("matrix was not created")
|
||||
}
|
||||
|
||||
err = matrix.SetFloatInfo("label", []float32{1, 2})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
vals, err := matrix.GetFloatInfo("label")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if vals[0] != 1 || vals[1] != 2 {
|
||||
t.Error("Wrong values returned")
|
||||
}
|
||||
|
||||
rowCount, err := matrix.NumRow()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if rowCount != 2 {
|
||||
t.Error("Wrong row count returned")
|
||||
}
|
||||
|
||||
colCount, err := matrix.NumCol()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if colCount != 2 {
|
||||
t.Error("Wrong col count returned")
|
||||
}
|
||||
|
||||
fmt.Printf("%+v\n", data)
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче