SPTAG/Tools/nni-auto-tune/main.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
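#
# NNI trial script for auto-tuning SPTAG: it builds an index with the
# build-time parameters suggested by the NNI tuner, grid-searches the
# query-time parameters, and reports the best recall/QPS trade-off to NNI.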
import nni
import h5py
import time
import numpy as np
import os
from model import Sptag, BruteForceBLAS
from runner import run_individual_query
from dataset import DataReader, HDF5Reader
import argparse
import json
import shutil
import itertools
from multiprocess import Pool, Process
import multiprocess
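

# Distance cutoff for distance-based recall: any returned point whose
# distance is within epsilon of the k-th true nearest distance counts as a
# hit, which also tolerates ties at the k-th position.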
def knn_threshold(data, k, epsilon):
    return data[k - 1] + epsilon


def get_recall_from_distance(dataset_distances,
                             run_distances,
                             k,
                             epsilon=1e-3):
    recalls = np.zeros(len(run_distances))
    for i in range(len(run_distances)):
        t = knn_threshold(dataset_distances[i], k, epsilon)
        actual = 0
        for d in run_distances[i][:k]:
            if d <= t:
                actual += 1
        recalls[i] = actual
    return (np.mean(recalls) / float(k), np.std(recalls) / float(k), recalls)


def get_recall_from_index(dataset_index, run_index, k):
    recalls = np.zeros(len(run_index))
    for i in range(len(run_index)):
        actual = 0
        for d in run_index[i][:k]:
            # need to convert to string because labels loaded from the
            # label file are strings by default
            if str(d) in dataset_index[i][:k]:
                actual += 1
        recalls[i] = actual
    return (np.mean(recalls) / float(k), np.std(recalls) / float(k), recalls)
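

# Throughput proxy: the inverse of the best per-query search time reported
# by the runner in attrs.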
def queries_per_second(attrs):
    return 1.0 / attrs["best_search_time"]


def compute_metrics(groundtruth, attrs, results, k, from_index=False):
    if from_index:
        mean, std, recalls = get_recall_from_index(groundtruth, results, k)
    else:
        mean, std, recalls = get_recall_from_distance(groundtruth, results, k)
    qps = queries_per_second(attrs)
    print('mean: %12.3f, std: %12.3f, qps: %12.3f' % (mean, std, qps))
    return mean, qps
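

# Yields every combination of the given parameter choices (the Cartesian
# product over the choice lists): permutations of repeated choice indices,
# deduplicated via set(), cover every index tuple, and tuples that point
# past the end of a shorter choice list are skipped.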
def grid_search(params):
    param_num = len(params)
    params = list(params.items())
    max_param_choices = max([len(p[1]) for p in params])
    temp = []
    for i in range(max_param_choices):
        temp += [i for _ in range(param_num)]
    for c in set(itertools.permutations(temp, param_num)):
        res = {}
        for i in range(len(c)):
            if c[i] >= len(params[i][1]):
                break
            else:
                res[params[i][0]] = params[i][1][c[i]]
        if len(res) == param_num:
            yield res


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--train_file',
        help='the data file to load training points from; can be a text '
        'file, a binary file or an ann-benchmarks format hdf5 file',
        default='glove-100-angular.hdf5')
    parser.add_argument(
        '--query_file',
        help='the data file to load query points from; if train_file is an '
        'ann-benchmarks format hdf5 file, this should be None',
        default=None)
    parser.add_argument(
        '--label_file',
        help='the data file to load groundtruth indices from; '
        'only text files are supported',
        default=None)
    parser.add_argument('--algorithm',
                        help='the name of the SPTAG algorithm',
                        default="BKT")
    parser.add_argument("--k",
                        default=10,
                        type=int,
                        help="the number of near neighbours to search for")
    parser.add_argument("--distance",
                        default='angular',
                        help="the type of distance for searching")
    parser.add_argument(
        "--max_build_time",
        default=-1,
        type=int,
        help="the limit on index build time in seconds; -1 means no limit")
    parser.add_argument(
        "--max_memory",
        default=-1,
        type=int,
        help="the limit on memory use during search in bytes; "
        "-1 means no limit")
    parser.add_argument("--dim",
                        default=100,
                        type=int,
                        help="the dimension of the training vectors")
    parser.add_argument("--input_type",
                        default="float32",
                        help="the data type of input vectors")
    parser.add_argument("--data_type",
                        default="float32",
                        help="the data type for building and searching "
                        "in SPTAG")
    args = parser.parse_args()
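
    # ann-benchmarks hdf5 files bundle the train/test vectors, groundtruth
    # and distance metric; other formats need --query_file, rely on
    # --distance/--dim, and take their groundtruth either from --label_file
    # or from the brute-force computation below.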
    if args.train_file.endswith(".hdf5"):
        # an ann-benchmarks format hdf5 file contains everything we need,
        # so args like --distance are ignored
        data_reader = HDF5Reader(args.train_file, args.data_type)
        X_train, X_test = data_reader.readallbatches()
        distance = data_reader.distance
        dimension = data_reader.featuredim
        label = data_reader.label
    else:
        X_train = DataReader(args.train_file,
                             args.dim,
                             batchsize=-1,
                             datatype=args.input_type,
                             targettype=args.data_type).readbatch()[1]
        X_test = DataReader(args.query_file,
                            args.dim,
                            batchsize=-1,
                            datatype=args.input_type,
                            targettype=args.data_type).readbatch()[1]
        distance = args.distance
        dimension = args.dim
        label = []
        if args.label_file is None:
            # if the groundtruth is not provided,
            # compute groundtruth distances by brute force
            bf = BruteForceBLAS(distance)
            bf.fit(X_train)
            for i, x in enumerate(X_test):
                if i % 1000 == 0:
                    print('%d/%d...' % (i, len(X_test)))
                res = list(bf.query_with_distances(x, args.k))
                res.sort(key=lambda t: t[-1])
                label.append([d for _, d in res])
        else:
            label = []
            # we assume the groundtruth indices are separated by spaces
            with open(args.label_file, 'r') as f:
                for line in f:
                    label.append(line.strip().split())
    print('got a train set of size (%d * %d)' % (X_train.shape[0], dimension))
    print('got %d queries' % len(X_test))
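
    # Build the index in a child process so it can be killed if it runs
    # past --max_build_time; `multiprocess` (the dill-based fork of
    # multiprocessing) is used here, presumably because the wrapper objects
    # are not picklable by the standard library. The child saves the index
    # to disk and the parent reloads it.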
    para = nni.get_next_parameter()
    algo = Sptag(args.algorithm, distance)
    t0 = time.time()
    if args.max_build_time > 0:
        pool = Pool(1)
        results = pool.apply_async(algo.fit,
                                   kwds=dict(X=X_train,
                                             para=para,
                                             data_type=args.data_type,
                                             save_index=True))
        try:
            # wait at most max_build_time seconds for the build to complete
            results.get(args.max_build_time)
            algo.load('index')
            shutil.rmtree("index")
            pool.close()
            pool.join()
        except multiprocess.TimeoutError:  # kill the subprocess on timeout
            print("Aborting due to timeout", args.max_build_time)
            pool.terminate()
            nni.report_final_result({
                'default': -1,
                "recall": 0,
                "qps": 0,
                "build_time": args.max_build_time
            })
            return
    else:
        algo.fit(X=X_train, para=para, data_type=args.data_type)
    build_time = time.time() - t0
    print('Built index in', build_time)
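
    # Query-time parameters are cheap to vary once the index is built, so
    # they are grid-searched inside this trial rather than left to the tuner.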
    search_param_choices = {
        "NumberOfInitialDynamicPivots": [1, 2, 4, 8, 16, 32, 50],
        "MaxCheck": [512, 3200, 5120, 8192, 12800, 16400, 19600],
        "NumberOfOtherDynamicPivots": [1, 2, 4, 8, 10]
    }
    best_metric = -1
    best_res = {}
    for i, search_params in enumerate(grid_search(search_param_choices)):
        algo.set_query_arguments(search_params)
        try:
            attrs, results = run_individual_query(algo,
                                                  X_train,
                                                  X_test,
                                                  distance,
                                                  args.k,
                                                  max_mem=args.max_memory)
        except MemoryError:
            print("Aborting due to exceeding the memory limit")
            nni.report_final_result({
                'default': -1,
                "recall": 0,
                "qps": 0,
                "build_time": build_time
            })
            return
        # pad results so every query has exactly k neighbors and distances
        neighbors = [0 for _ in results]
        distances = [0 for _ in results]
        for idx, (t, ds) in enumerate(results):
            neighbors[idx] = [n for n, d in ds] + [-1] * (args.k - len(ds))
            distances[idx] = ([d for n, d in ds] +
                              [float('inf')] * (args.k - len(ds)))
        if args.label_file is None:
            recalls_mean, qps = compute_metrics(label, attrs, distances,
                                                args.k)
        else:
            recalls_mean, qps = compute_metrics(label,
                                                attrs,
                                                neighbors,
                                                args.k,
                                                from_index=True)
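        # Single objective for NNI: -log10(1 - recall) grows quickly as
        # recall approaches 1, while the 0.1 * log10(qps) term breaks ties
        # in favor of faster configurations.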
        combined_metric = -1 * np.log10(1 - recalls_mean) + 0.1 * np.log10(qps)
        res = {
            "default": combined_metric,
            "recall": recalls_mean,
            "qps": qps,
            "build_time": build_time
        }
        if combined_metric > best_metric:
            best_metric = combined_metric
            best_res = res.copy()
        res["build_params"] = para
        res["search_params"] = search_params
        experiment_id = nni.get_experiment_id()
        result_dir = os.path.join('results', args.train_file.split('.')[0])
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        trial_id = nni.get_trial_id()
        with open(
                os.path.join(
                    result_dir,
                    "result_" + str(trial_id) + ' ' + str(i) + ".json"),
                "w") as f:
            json.dump(res, f)
    nni.report_final_result(best_res)


if __name__ == '__main__':
    main()