This commit is contained in:
James Lamb 2023-10-26 12:58:21 -05:00 committed by GitHub
Parent 2d358d5d20
Commit fcf76bceb9
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 8 additions and 8 deletions


@@ -755,7 +755,7 @@ def test_ranking_prediction_early_stopping():
# (in our example it is simply the ordering by some feature correlated with relevance, e.g., 34)
# and clicks on that document (new_label=1) with some probability 'pclick' depending on its true relevance;
# at each position the user may stop the traversal with some probability pstop. For the non-clicked documents,
-# new_label=0. Thus the generated new labels are biased towards the baseline ranker.
+# new_label=0. Thus the generated new labels are biased towards the baseline ranker.
# The positions of the documents in the ranked lists produced by the baseline are returned.
def simulate_position_bias(file_dataset_in, file_query_in, file_dataset_out, baseline_feature):
    # a mapping of a document's true relevance (defined on a 5-grade scale) into the probability of clicking it
@@ -772,7 +772,7 @@ def simulate_position_bias(file_dataset_in, file_query_in, file_dataset_out, bas
            return 0.9
    # an instantiation of a cascade model where the user stops with probability 0.2 after observing each document
    pstop = 0.2
    f_dataset_in = open(file_dataset_in, 'r')
    f_dataset_out = open(file_dataset_out, 'w')
    random.seed(10)
@@ -780,19 +780,19 @@ def simulate_position_bias(file_dataset_in, file_query_in, file_dataset_out, bas
    for line in open(file_query_in):
        docs_num = int(line)
        lines = []
-        index_values = []
+        index_values = []
        positions = [0] * docs_num
        for index in range(docs_num):
            features = f_dataset_in.readline().split()
            lines.append(features)
            val = 0.0
            for feature_val in features:
-                feature_val_split = feature_val.split(":")
+                feature_val_split = feature_val.split(":")
                if int(feature_val_split[0]) == baseline_feature:
                    val = float(feature_val_split[1])
            index_values.append([index, val])
        index_values.sort(key=lambda x: -x[1])
-        stop = False
+        stop = False
        for pos in range(docs_num):
            index = index_values[pos][0]
            new_label = 0
@@ -800,7 +800,7 @@ def simulate_position_bias(file_dataset_in, file_query_in, file_dataset_out, bas
                label = int(lines[index][0])
                pclick = get_pclick(label)
                if random.random() < pclick:
-                    new_label = 1
+                    new_label = 1
                stop = random.random() < pstop
            lines[index][0] = str(new_label)
            positions[index] = pos
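
For context, the simulation in the hunks above follows a standard cascade click model: the user scans the baseline ranking top to bottom, clicks each document with a relevance-dependent probability, and abandons the list with probability pstop after observing each document. A minimal self-contained sketch of the same idea (a hypothetical helper, not part of this diff; get_pclick stands for any mapping from relevance grade to click probability, as in the test):

import random

def simulate_clicks(relevance_grades, get_pclick, pstop=0.2):
    # relevance_grades are listed in the order the baseline ranker showed them
    clicks = [0] * len(relevance_grades)
    for pos, grade in enumerate(relevance_grades):
        if random.random() < get_pclick(grade):
            clicks[pos] = 1  # clicked documents get new_label=1
        if random.random() < pstop:
            break  # user stops scanning; remaining documents stay new_label=0
    return clicks
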
@@ -843,7 +843,7 @@ def test_ranking_with_position_information_with_file(tmp_path):
    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params)
    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
    gbm_unbiased_with_file = lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
    # unbiased LambdaMART should outperform plain LambdaMART on the dataset with position bias
    assert gbm_baseline.best_score['valid_0']['ndcg@3'] + 0.03 <= gbm_unbiased_with_file.best_score['valid_0']['ndcg@3']
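
The "with_file" variant of this test relies on LightGBM reading per-row positions from a sidecar file next to the training data, analogous to the ".query" file convention. A rough sketch of how such a file could be produced from the simulation above (the writing loop is an illustrative assumption, not code from this diff; 'rank.train.orig' is a hypothetical input name, and simulate_position_bias returns the baseline positions per its docstring comment):

positions = simulate_position_bias('rank.train.orig', 'rank.train.query',
                                   'rank.train', baseline_feature=34)
with open('rank.train.position', 'w') as f:
    for pos in positions:
        f.write(f"{pos}\n")  # one integer position per data row
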
@@ -853,7 +853,7 @@ def test_ranking_with_position_information_with_file(tmp_path):
    file.close()
    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params)
    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
-    with pytest.raises(lgb.basic.LightGBMError, match="Positions size \(3006\) doesn't match data size"):
+    with pytest.raises(lgb.basic.LightGBMError, match=r"Positions size \(3006\) doesn't match data size"):
        lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
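
This last hunk is the commit's one functional change: pytest.raises(match=...) treats the pattern as a regular expression (it is applied with re.search), so the literal parentheses in the error message must be escaped, and writing the pattern as a raw string keeps "\(" from being parsed as a string escape sequence (an invalid escape in a plain string triggers a warning on newer CPython). A small illustration of the same pattern outside pytest:

import re

msg = "Positions size (3006) doesn't match data size"
# raw string: the backslashes reach the regex engine untouched
assert re.search(r"Positions size \(3006\) doesn't match data size", msg)
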