[DISCO-3043] Relevancy: multi armed bandit API

adding init, select and update functions to be
used in thompson sampling
This commit is contained in:
Ben Dean-Kawamura 2024-10-23 14:50:44 -04:00 коммит произвёл Temisan Iwere
Родитель 6159184773
Коммит 506dac4e13
7 изменённых файлов: 469 добавлений и 1 удалений

Просмотреть файл

@ -1,5 +1,10 @@
# v134.0 (In progress)
## ✨ What's New ✨
### Relevancy
- Added init, select and update methods for Thompson Sampling (multi-armed bandit)
## 🦊 What's Changed 🦊
### Glean

13
Cargo.lock сгенерированный
Просмотреть файл

@ -2865,6 +2865,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@ -3486,6 +3487,16 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
]
[[package]]
name = "rand_rccrypto"
version = "0.1.0"
@ -3626,6 +3637,8 @@ dependencies = [
"log",
"md-5",
"parking_lot",
"rand",
"rand_distr",
"remote_settings",
"rusqlite",
"serde",

Просмотреть файл

@ -15,6 +15,8 @@ sql-support = { path = "../support/sql" }
log = "0.4"
md-5 = "0.10"
parking_lot = ">=0.11,<=0.12"
rand = "0.8"
rand_distr = "0.4"
rusqlite = { workspace = true, features = ["bundled"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"

Просмотреть файл

@ -3,6 +3,7 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
use crate::Error::BanditNotFound;
use crate::{
interest::InterestVectorKind,
schema::RelevancyConnectionInitializer,
@ -172,11 +173,81 @@ impl<'a> RelevancyDao<'a> {
}
Ok(interest_vec)
}
/// Initializes a multi-armed bandit record in the database for a specific bandit and arm.
///
/// This method inserts a new record into the `multi_armed_bandit` table with default probability
/// distribution parameters (`alpha` and `beta` set to 1) and usage counters (`impressions` and
/// `clicks` set to 0) for the specified `bandit` and `arm`. If a record for this bandit-arm pair
/// already exists, the insertion is ignored, preserving the existing data.
pub fn initialize_multi_armed_bandit(&mut self, bandit: &str, arm: &str) -> Result<()> {
let mut new_statement = self.conn.prepare(
"INSERT OR IGNORE INTO multi_armed_bandit (bandit, arm, alpha, beta, impressions, clicks) VALUES (?, ?, ?, ?, ?, ?)"
)?;
new_statement.execute((bandit, arm, 1, 1, 0, 0))?;
Ok(())
}
/// Retrieves the Beta distribution parameters (`alpha` and `beta`) for a specific arm in a bandit model.
///
/// If the specified `bandit` and `arm` do not exist in the table, an error is returned indicating
/// that the record was not found.
pub fn retrieve_bandit_arm_beta_distribution(
&self,
bandit: &str,
arm: &str,
) -> Result<(usize, usize)> {
let mut stmt = self
.conn
.prepare("SELECT alpha, beta FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
let mut result = stmt.query((&bandit, &arm))?;
match result.next()? {
Some(row) => Ok((row.get(0)?, row.get(1)?)),
None => Err(BanditNotFound {
bandit: bandit.to_string(),
arm: arm.to_string(),
}),
}
}
/// Updates the Beta distribution parameters and counters for a specific arm in a bandit model based on user interaction.
///
/// This method updates the `alpha` or `beta` parameters in the `multi_armed_bandit` table for the specified
/// `bandit` and `arm` based on whether the arm was selected by the user. If `selected` is true, it increments
/// both the `alpha` (indicating success) and the `clicks` and `impressions` counters. If `selected` is false,
/// it increments `beta` (indicating failure) and only the `impressions` counter. This approach adjusts the
/// distribution parameters to reflect the arm's performance over time.
pub fn update_bandit_arm_data(&self, bandit: &str, arm: &str, selected: bool) -> Result<()> {
let mut stmt = if selected {
self
.conn
.prepare("UPDATE multi_armed_bandit SET alpha=alpha+1, clicks=clicks+1, impressions=impressions+1 WHERE bandit=? AND arm=?")?
} else {
self
.conn
.prepare("UPDATE multi_armed_bandit SET beta=beta+1, impressions=impressions+1 WHERE bandit=? AND arm=?")?
};
let result = stmt.execute((&bandit, &arm))?;
if result == 0 {
return Err(BanditNotFound {
bandit: bandit.to_string(),
arm: arm.to_string(),
});
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
use rusqlite::params;
#[test]
fn test_store_frecency_user_interest_vector() {
@ -229,4 +300,222 @@ mod test {
interest_vec2,
);
}
#[test]
fn test_initialize_multi_armed_bandit() -> Result<()> {
let db = RelevancyDb::new_for_test();
let bandit = "provider".to_string();
let arm = "weather".to_string();
db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
let result = db.read(|dao| {
let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
stmt.query_row(params![&bandit, &arm], |row| {
let alpha: usize = row.get(0)?;
let beta: usize = row.get(1)?;
let impressions: usize = row.get(2)?;
let clicks: usize = row.get(3)?;
Ok((alpha, beta, impressions, clicks))
}).map_err(|e| e.into())
})?;
assert_eq!(result.0, 1); // Default alpha
assert_eq!(result.1, 1); // Default beta
assert_eq!(result.2, 0); // Default impressions
assert_eq!(result.3, 0); // Default clicks
Ok(())
}
#[test]
fn test_initialize_multi_armed_bandit_existing_data() -> Result<()> {
let db = RelevancyDb::new_for_test();
let bandit = "provider".to_string();
let arm = "weather".to_string();
db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
let result = db.read(|dao| {
let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
stmt.query_row(params![&bandit, &arm], |row| {
let alpha: usize = row.get(0)?;
let beta: usize = row.get(1)?;
let impressions: usize = row.get(2)?;
let clicks: usize = row.get(3)?;
Ok((alpha, beta, impressions, clicks))
}).map_err(|e| e.into())
})?;
assert_eq!(result.0, 1); // Default alpha
assert_eq!(result.1, 1); // Default beta
assert_eq!(result.2, 0); // Default impressions
assert_eq!(result.3, 0); // Default clicks
db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
let (alpha, beta) =
db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
assert_eq!(alpha, 2);
assert_eq!(beta, 1);
// this should be a no-op since the same bandit-arm has already been initialized
db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
let (alpha, beta) =
db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
// alpha & beta values for the bandit-arm should remain unchanged
assert_eq!(alpha, 2);
assert_eq!(beta, 1);
Ok(())
}
#[test]
fn test_retrieve_bandit_arm_beta_distribution() -> Result<()> {
let db = RelevancyDb::new_for_test();
let bandit = "provider".to_string();
let arm = "weather".to_string();
db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
let (alpha, beta) =
db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
assert_eq!(alpha, 2);
assert_eq!(beta, 3);
Ok(())
}
#[test]
fn test_retrieve_bandit_arm_beta_distribution_not_found() -> Result<()> {
let db = RelevancyDb::new_for_test();
let bandit = "provider".to_string();
let arm = "weather".to_string();
let result = db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm));
match result {
Ok((alpha, beta)) => panic!(
"Expected BanditNotFound error, but got Ok result with alpha: {} and beta: {}",
alpha, beta
),
Err(BanditNotFound { bandit: b, arm: a }) => {
assert_eq!(b, bandit);
assert_eq!(a, arm);
}
_ => {}
}
Ok(())
}
#[test]
fn test_update_bandit_arm_data_selected() -> Result<()> {
let db = RelevancyDb::new_for_test();
let bandit = "provider".to_string();
let arm = "weather".to_string();
db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
let result = db.read(|dao| {
let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
stmt.query_row(params![&bandit, &arm], |row| {
let alpha: usize = row.get(0)?;
let beta: usize = row.get(1)?;
let impressions: usize = row.get(2)?;
let clicks: usize = row.get(3)?;
Ok((alpha, beta, impressions, clicks))
}).map_err(|e| e.into())
})?;
assert_eq!(result.0, 1);
assert_eq!(result.1, 1);
assert_eq!(result.2, 0);
assert_eq!(result.3, 0);
db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, true))?;
let (alpha, beta) =
db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
assert_eq!(alpha, 2);
assert_eq!(beta, 1);
Ok(())
}
#[test]
fn test_update_bandit_arm_data_not_selected() -> Result<()> {
let db = RelevancyDb::new_for_test();
let bandit = "provider".to_string();
let arm = "weather".to_string();
db.read_write(|dao| dao.initialize_multi_armed_bandit(&bandit, &arm))?;
let result = db.read(|dao| {
let mut stmt = dao.conn.prepare("SELECT alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
stmt.query_row(params![&bandit, &arm], |row| {
let alpha: usize = row.get(0)?;
let beta: usize = row.get(1)?;
let impressions: usize = row.get(2)?;
let clicks: usize = row.get(3)?;
Ok((alpha, beta, impressions, clicks))
}).map_err(|e| e.into())
})?;
assert_eq!(result.0, 1);
assert_eq!(result.1, 1);
assert_eq!(result.2, 0);
assert_eq!(result.3, 0);
db.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, false))?;
let (alpha, beta) =
db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
assert_eq!(alpha, 1);
assert_eq!(beta, 2);
Ok(())
}
#[test]
fn test_update_bandit_arm_data_not_found() -> Result<()> {
let db = RelevancyDb::new_for_test();
let bandit = "provider".to_string();
let arm = "weather".to_string();
let result = db.read(|dao| dao.update_bandit_arm_data(&bandit, &arm, false));
match result {
Ok(()) => panic!("Expected BanditNotFound error, but got Ok result"),
Err(BanditNotFound { bandit: b, arm: a }) => {
assert_eq!(b, bandit);
assert_eq!(a, arm);
}
_ => {}
}
Ok(())
}
}

Просмотреть файл

@ -42,6 +42,9 @@ pub enum Error {
#[error("Base64 Decode Error: {0}")]
Base64DecodeError(String),
#[error("Error retrieving bandit data for bandit {bandit} and arm {arm}")]
BanditNotFound { bandit: String, arm: String },
}
/// Result enum for the public API

Просмотреть файл

@ -18,6 +18,8 @@ mod rs;
mod schema;
pub mod url_hash;
use rand_distr::{Beta, Distribution};
pub use db::RelevancyDb;
pub use error::{ApiResult, Error, RelevancyApiError, Result};
pub use interest::{Interest, InterestVector};
@ -95,6 +97,72 @@ impl RelevancyStore {
pub fn user_interest_vector(&self) -> ApiResult<InterestVector> {
self.db.read(|dao| dao.get_frecency_user_interest_vector())
}
/// Initializes probability distributions for any uninitialized items (arms) within a bandit model.
///
/// This method takes a `bandit` identifier and a list of `arms` (items) and ensures that each arm
/// in the list has an initialized probability distribution in the database. For each arm, if the
/// probability distribution does not already exist, it will be created, using Beta(1,1) as default,
/// which represents uniform distribution.
#[handle_error(Error)]
pub fn bandit_init(&self, bandit: String, arms: &[String]) -> ApiResult<()> {
self.db.read_write(|dao| {
for arm in arms {
dao.initialize_multi_armed_bandit(&bandit, arm)?;
}
Ok(())
})?;
Ok(())
}
/// Selects the optimal item (arm) to display to the user based on a multi-armed bandit model.
///
/// This method takes in a `bandit` identifier and a list of possible `arms` (items) and uses a
/// Thompson sampling approach to select the arm with the highest probability of success.
/// For each arm, it retrieves the Beta distribution parameters (alpha and beta) from the
/// database, creates a Beta distribution, and samples from it to estimate the arm's probability
/// of success. The arm with the highest sampled probability is selected and returned.
#[handle_error(Error)]
pub fn bandit_select(&self, bandit: String, arms: &[String]) -> ApiResult<String> {
// we should cache the distribution so we don't retrieve each time
let mut best_sample = f64::MIN;
let mut selected_arm = String::new();
for arm in arms {
let (alpha, beta) = self
.db
.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, arm))?;
// this creates a Beta distribution for an alpha & beta pair
let beta_dist = Beta::new(alpha as f64, beta as f64)
.expect("computing betas dist unexpectedly failed");
// Sample from the Beta distribution
let sampled_prob = beta_dist.sample(&mut rand::thread_rng());
if sampled_prob > best_sample {
best_sample = sampled_prob;
selected_arm.clone_from(arm);
}
}
return Ok(selected_arm);
}
/// Updates the bandit model's arm data based on user interaction (selection or non-selection).
///
/// This method takes in a `bandit` identifier, an `arm` identifier, and a `selected` flag.
/// If `selected` is true, it updates the model to reflect a successful selection of the arm,
/// reinforcing its positive reward probability. If `selected` is false, it updates the
/// beta (failure) distribution of the arm, reflecting a lack of selection and reinforcing
/// its likelihood of a negative outcome.
#[handle_error(Error)]
pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> ApiResult<()> {
self.db
.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, selected))?;
Ok(())
}
}
impl RelevancyStore {
@ -147,6 +215,8 @@ mod test {
use crate::url_hash::hash_url;
use super::*;
use rand::Rng;
use std::collections::HashMap;
fn make_fixture() -> Vec<(String, Interest)> {
vec![
@ -207,4 +277,66 @@ mod test {
expected_interest_vector()
);
}
#[test]
fn test_thompson_sampling_convergence() {
let relevancy_store = setup_store("thompson_sampling_convergence");
let arms_to_ctr_map: HashMap<String, f64> = [
("wiki".to_string(), 0.1), // 10% CTR
("geolocation".to_string(), 0.3), // 30% CTR
("weather".to_string(), 0.8), // 80% CTR
]
.into_iter()
.collect();
let arm_names: Vec<String> = arms_to_ctr_map.keys().cloned().collect();
let bandit = "provider".to_string();
// initialize bandit
relevancy_store
.bandit_init(bandit.clone(), &arm_names)
.unwrap();
let mut rng = rand::thread_rng();
// Create a HashMap to map arm names to their selection counts
let mut selection_counts: HashMap<String, usize> =
arm_names.iter().map(|name| (name.clone(), 0)).collect();
// Simulate 1000 rounds of Thompson Sampling
for _ in 0..1000 {
// Use Thompson Sampling to select an arm
let selected_arm_name = relevancy_store
.bandit_select(bandit.clone(), &arm_names)
.expect("Failed to select arm");
// increase the selection count for the selected arm
*selection_counts.get_mut(&selected_arm_name).unwrap() += 1;
// get the true CTR for the selected arm
let true_ctr = &arms_to_ctr_map[&selected_arm_name];
// simulate a click or no-click based on the true CTR
let clicked = rng.gen_bool(*true_ctr);
// update beta distribution for arm based on click/no click
relevancy_store
.bandit_update(bandit.clone(), selected_arm_name, clicked)
.expect("Failed to update beta distribution for arm");
}
//retrieve arm with maximum selection count
let most_selected_arm_name = selection_counts
.iter()
.max_by_key(|(_, count)| *count)
.unwrap()
.0;
assert_eq!(
most_selected_arm_name, "weather",
"Thompson Sampling did not favor the best-performing arm"
);
}
}

Просмотреть файл

@ -13,7 +13,7 @@ use sql_support::open_database::{self, ConnectionInitializer};
/// 1. Bump this version.
/// 2. Add a migration from the old version to the new version in
/// [`RelevancyConnectionInitializer::upgrade_from`].
pub const VERSION: u32 = 14;
pub const VERSION: u32 = 15;
/// The current database schema.
pub const SQL: &str = "
@ -30,6 +30,15 @@ pub const SQL: &str = "
count INTEGER NOT NULL,
PRIMARY KEY (kind, interest_code)
) WITHOUT ROWID;
CREATE TABLE multi_armed_bandit(
bandit TEXT NOT NULL,
arm TEXT NOT NULL,
alpha INTEGER NOT NULL,
beta INTEGER NOT NULL,
clicks INTEGER NOT NULL,
impressions INTEGER NOT NULL,
PRIMARY KEY (bandit, arm)
) WITHOUT ROWID;
";
/// Initializes an SQLite connection to the Relevancy database, performing
@ -73,6 +82,21 @@ impl ConnectionInitializer for RelevancyConnectionInitializer {
)?;
Ok(())
}
14 => {
tx.execute(
"CREATE TABLE multi_armed_bandit(
bandit TEXT NOT NULL,
arm TEXT NOT NULL,
alpha INTEGER NOT NULL,
beta INTEGER NOT NULL,
clicks INTEGER NOT NULL,
impressions INTEGER NOT NULL,
PRIMARY KEY (bandit, arm)
) WITHOUT ROWID;",
(),
)?;
Ok(())
}
_ => Err(open_database::Error::IncompatibleVersion(version)),
}
}