go-xgboost/booster.go

116 строки
2.6 KiB
Go

package xgboost
import (
"context"
"errors"
"runtime"
"github.com/Applifier/go-xgboost/core"
)
// Matrix interface for 2D matrix
type Matrix interface {
Data() (data []float32, rowCount, columnCount int)
}
// FloatSliceVector float32 slice backed Matrix implementation
type FloatSliceVector []float32
// Data returns float32 slice as (1, len(data)) matrix
func (fsm FloatSliceVector) Data() (data []float32, rowCount, columnCount int) {
return fsm, 1, len(fsm)
}
// Predictor interface for xgboost predictors
type Predictor interface {
Predict(input Matrix) ([]float32, error)
Close(ctx context.Context) error
}
// NewPredictor returns a new predictor based on given model path, worker count, option mask, ntree_limit and missing value indicator
func NewPredictor(xboostSavedModelPath string, workerCount int, optionMask int, nTreeLimit uint, missingValue float32) (Predictor, error) {
if workerCount <= 0 {
return nil, errors.New("worker count needs to be larger than zero")
}
requestChan := make(chan multiBoosterRequest)
initErrors := make(chan error)
defer close(initErrors)
for i := 0; i < workerCount; i++ {
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
booster, err := core.XGBoosterCreate(nil)
if err != nil {
initErrors <- err
return
}
err = booster.LoadModel(xboostSavedModelPath)
if err != nil {
initErrors <- err
return
}
// No errors occured during init
initErrors <- nil
for req := range requestChan {
data, rowCount, columnCount := req.matrix.Data()
matrix, err := core.XGDMatrixCreateFromMat(data, rowCount, columnCount, missingValue)
if err != nil {
req.resultChan <- multiBoosterResponse{
err: err,
}
continue
}
res, err := booster.Predict(matrix, optionMask, nTreeLimit)
req.resultChan <- multiBoosterResponse{
err: err,
result: res,
}
}
}()
err := <-initErrors
if err != nil {
return nil, err
}
}
return &multiBooster{reqChan: requestChan}, nil
}
type multiBoosterRequest struct {
matrix Matrix
resultChan chan multiBoosterResponse
}
type multiBoosterResponse struct {
err error
result []float32
}
type multiBooster struct {
reqChan chan multiBoosterRequest
}
func (mb *multiBooster) Predict(input Matrix) ([]float32, error) {
resChan := make(chan multiBoosterResponse)
mb.reqChan <- multiBoosterRequest{
matrix: input,
resultChan: resChan,
}
result := <-resChan
return result.result, result.err
}
func (mb *multiBooster) Close(ctx context.Context) error {
close(mb.reqChan)
return nil
}