diff --git a/Cargo.lock b/Cargo.lock index 4ff840d1..ae288c0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -685,7 +685,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "libflate" -version = "0.1.16" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "adler32 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1310,7 +1310,7 @@ dependencies = [ "futures 0.1.23 (registry+https://github.com/rust-lang/crates.io-index)", "hyper 0.11.27 (registry+https://github.com/rust-lang/crates.io-index)", "hyper-tls 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", - "libflate 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)", + "libflate 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)", "mime_guess 2.0.0-alpha.6 (registry+https://github.com/rust-lang/crates.io-index)", "native-tls 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1361,7 +1361,7 @@ dependencies = [ "term 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", - "tiny_http 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", + "tiny_http 0.6.0 (git+https://github.com/aidanhs/tiny-http-sccache.git?rev=a14fa0a)", "url 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1477,6 +1477,7 @@ dependencies = [ "toml 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "url 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)", "uuid 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)", + "void 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", "walkdir 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", "which 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1797,13 +1798,14 @@ dependencies = [ [[package]] name = "tiny_http" version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" +source = "git+https://github.com/aidanhs/tiny-http-sccache.git?rev=a14fa0a#a14fa0ab963be252c0c608e2516ef30252d6a7e2" dependencies = [ "ascii 0.8.7 (registry+https://github.com/rust-lang/crates.io-index)", "chrono 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "chunked_transfer 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "encoding 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)", + "openssl 0.10.11 (registry+https://github.com/rust-lang/crates.io-index)", "url 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -2408,7 +2410,7 @@ dependencies = [ "checksum lazy_static 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca488b89a5657b0a2ecd45b95609b3e848cf1755da332a0da46e2b2b1cb371a7" "checksum lazycell 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a6f08839bc70ef4a3fe1d566d5350f519c5912ea86be0df1740a7d247c7fc0ef" "checksum libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)" = "76e3a3ef172f1a0b9a9ff0dd1491ae5e6c948b94479a3021819ba7d860c8645d" -"checksum libflate 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "7d4b4c7aff5bac19b956f693d0ea0eade8066deb092186ae954fa6ba14daab98" +"checksum 
libflate 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)" = "21138fc6669f438ed7ae3559d5789a5f0ba32f28c1f0608d1e452b0bb06ee936" "checksum libmount 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "d9d45f88f32c57ebf3688ada41414dc700aab97ad58e26cbcda6af50da53559a" "checksum linked-hash-map 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "bda158e0dabeb97ee8a401f4d17e479d6b891a14de0bba79d5cc2d4d325b5e48" "checksum local-encoding 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e1ceb20f39ff7ae42f3ff9795f3986b1daad821caaa1e1732a0944103a5a1a66" @@ -2524,7 +2526,7 @@ dependencies = [ "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b" "checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865" "checksum time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "d825be0eb33fda1a7e68012d51e9c7f451dc1a69391e7fdc197060bb8c56667b" -"checksum tiny_http 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a442681f9f72e440be192700eeb2861e4174b9983f16f4877c93a134cb5e5f63" +"checksum tiny_http 0.6.0 (git+https://github.com/aidanhs/tiny-http-sccache.git?rev=a14fa0a)" = "" "checksum tokio 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "fbb6a6e9db2702097bfdfddcb09841211ad423b86c75b5ddaca1d62842ac492c" "checksum tokio-codec 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "881e9645b81c2ce95fcb799ded2c29ffb9f25ef5bef909089a420e5961dd8ccb" "checksum tokio-core 0.1.17 (registry+https://github.com/rust-lang/crates.io-index)" = "aeeffbbb94209023feaef3c196a41cbcdafa06b4a6f893f68779bb5e53796f71" diff --git a/Cargo.toml b/Cargo.toml index 7769802a..196d584c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ rand = "0.5" redis = { version = "0.9.0", optional = true } regex = "1" # Exact dependency since we use the unstable async API +# If updating this, make sure to update dev-dependencies reqwest = { version = "=0.8.8", features = ["unstable"], optional = true } retry = "0.4.0" ring = "0.13.2" @@ -83,7 +84,12 @@ arraydeque = { version = "0.4", optional = true } crossbeam-utils = { version = "0.5", optional = true } libmount = { version = "0.1.10", optional = true } nix = { version = "0.11.0", optional = true } -rouille = { version = "2.2", optional = true, default-features = false } +rouille = { version = "2.2", optional = true, default-features = false, features = ["ssl"] } +void = { version = "1", optional = true } + +[patch.crates-io] +# Waiting for https://github.com/tiny-http/tiny-http/pull/151 +tiny_http = { git = "https://github.com/aidanhs/tiny-http-sccache.git", rev = "a14fa0a" } [dev-dependencies] assert_cmd = "0.9" @@ -93,6 +99,8 @@ escargot = "0.3" itertools = "0.7" predicates = "0.9.0" selenium-rs = "0.1" +# Must match the version of request in dependencies +reqwest = { version = "=0.8.8" } [target.'cfg(unix)'.dependencies] daemonize = "0.3" @@ -118,7 +126,7 @@ unstable = [] # Enables distributed support in the sccache client dist-client = ["ar", "flate2", "hyper", "reqwest", "rust-crypto", "url"] # Enables the sccache-dist binary -dist-server = ["arraydeque", "crossbeam-utils", "jsonwebtoken", "flate2", "libmount", "nix", "openssl", "reqwest", "rouille"] +dist-server = ["arraydeque", "crossbeam-utils", "jsonwebtoken", "flate2", "libmount", "nix", "openssl", "reqwest", "rouille", 
"void"] # Enables dist tests with external requirements dist-tests = [] diff --git a/scripts/extratest.sh b/scripts/extratest.sh new file mode 100755 index 00000000..a2306625 --- /dev/null +++ b/scripts/extratest.sh @@ -0,0 +1,97 @@ +#!/bin/bash +set -o errexit +set -o pipefail +set -o nounset +set -o xtrace + +#CARGO="cargo --color=always" +CARGO="cargo" + +gnutarget=x86_64-unknown-linux-gnu +wintarget=x86_64-pc-windows-gnu + +gnutarget() { + unset OPENSSL_DIR + export OPENSSL_STATIC=1 + target=$gnutarget +} +wintarget() { + export OPENSSL_DIR=$(pwd)/openssl-win + export OPENSSL_STATIC=1 + target=$wintarget +} + +# all-windows doesn't work as redis-rs build.rs has issues (checks for cfg!(unix)) + +if [ "$1" = checkall ]; then + $CARGO check --target $target --all-targets --features 'all dist-client dist-server dist-tests' + $CARGO check --target $target --all-targets --features 'all dist-client dist-server' + $CARGO check --target $target --all-targets --features 'all dist-client dist-tests' + $CARGO check --target $target --all-targets --features 'all dist-server dist-tests' + $CARGO check --target $target --all-targets --features 'all dist-client' + $CARGO check --target $target --all-targets --features 'all dist-server' + $CARGO check --target $target --all-targets --features 'all dist-tests' + $CARGO check --target $target --all-targets --features 'all' + $CARGO check --target $target --all-targets --features 'dist-client dist-server dist-tests' + $CARGO check --target $target --all-targets --features 'dist-client dist-server' + $CARGO check --target $target --all-targets --features 'dist-client dist-tests' + $CARGO check --target $target --all-targets --features 'dist-server dist-tests' + $CARGO check --target $target --all-targets --features 'dist-client' + $CARGO check --target $target --all-targets --features 'dist-server' + $CARGO check --target $target --all-targets --features 'dist-tests' + $CARGO check --target $target --all-targets --features '' + $CARGO check --target $target --all-targets --no-default-features --features 'all dist-client dist-server dist-tests' + $CARGO check --target $target --all-targets --no-default-features --features 'all dist-client dist-server' + $CARGO check --target $target --all-targets --no-default-features --features 'all dist-client dist-tests' + $CARGO check --target $target --all-targets --no-default-features --features 'all dist-server dist-tests' + $CARGO check --target $target --all-targets --no-default-features --features 'all dist-client' + $CARGO check --target $target --all-targets --no-default-features --features 'all dist-server' + $CARGO check --target $target --all-targets --no-default-features --features 'all dist-tests' + $CARGO check --target $target --all-targets --no-default-features --features 'all' + $CARGO check --target $target --all-targets --no-default-features --features 'dist-client dist-server dist-tests' + $CARGO check --target $target --all-targets --no-default-features --features 'dist-client dist-server' + $CARGO check --target $target --all-targets --no-default-features --features 'dist-client dist-tests' + $CARGO check --target $target --all-targets --no-default-features --features 'dist-server dist-tests' + $CARGO check --target $target --all-targets --no-default-features --features 'dist-client' + $CARGO check --target $target --all-targets --no-default-features --features 'dist-server' + $CARGO check --target $target --all-targets --no-default-features --features 'dist-tests' + $CARGO check --target $target 
--all-targets --no-default-features --features '' + wintarget + $CARGO check --target $target --all-targets --features 'dist-client' + #$CARGO check --target $target --all-targets --features 'all-windows dist-client' + #$CARGO check --target $target --all-targets --features 'all-windows' + $CARGO check --target $target --all-targets --features '' + + +elif [ "$1" = test ]; then + # Musl tests segfault due to https://github.com/mozilla/sccache/issues/256#issuecomment-399254715 + gnutarget + VERBOSE= + NOCAPTURE= + NORUN= + TESTTHREADS= + #VERBOSE="--verbose" + #NORUN=--no-run + #NOCAPTURE=--nocapture + TESTTHREADS="--test-threads 1" + + # Since integration tests start up the sccache server they must be run sequentially. This only matters + # if you have multiple test functions in one file. + + set +x + if ! which docker; then + echo -e "WARNING: =====\n\ndocker not present, some tests will fail\n\n=====\n\n\n\n\n" + sleep 5 + fi + if ! which icecc-create-env; then + echo -e "WARNING: =====\n\nicecc-create-env not present, some tests will fail\n\n=====\n\n\n\n\n" + sleep 5 + fi + set -x + + RUST_BACKTRACE=1 $CARGO test $NORUN --target $target --features 'all dist-client dist-server dist-tests' $VERBOSE -- $NOCAPTURE $TESTTHREADS test_dist_nobuilder + +else + echo invalid command + exit 1 +fi diff --git a/src/bin/sccache-dist/build.rs b/src/bin/sccache-dist/build.rs index 3a15a361..de40db56 100644 --- a/src/bin/sccache-dist/build.rs +++ b/src/bin/sccache-dist/build.rs @@ -18,26 +18,54 @@ use libmount::Overlay; use lru_disk_cache::Error as LruError; use nix; use sccache::dist::{ - BuildResult, CompileCommand, InputsReader, OutputData, TcCache, Toolchain, + BuildResult, CompileCommand, InputsReader, OutputData, ProcessOutput, TcCache, Toolchain, BuilderIncoming, }; -use std::collections::HashMap; +use std::collections::{HashMap, hash_map}; use std::fs; use std::io; use std::iter; use std::path::{self, Path, PathBuf}; -use std::process::{Command, Output, Stdio}; +use std::process::{ChildStdin, Command, Output, Stdio}; use std::sync::{Mutex}; use tar; use errors::*; -fn check_output(output: &Output) { - if !output.status.success() { - error!("===========\n{}\n==========\n\n\n\n=========\n{}\n===============\n\n\n", - String::from_utf8_lossy(&output.stdout), String::from_utf8_lossy(&output.stderr)); - panic!() +trait CommandExt { + fn check_stdout_trim(&mut self) -> Result; + fn check_piped(&mut self, pipe: &mut FnMut(&mut ChildStdin) -> Result<()>) -> Result<()>; + fn check_run(&mut self) -> Result<()>; +} + +impl CommandExt for Command { + fn check_stdout_trim(&mut self) -> Result { + let output = self.output().chain_err(|| "Failed to start command")?; + check_output(&output)?; + let stdout = String::from_utf8(output.stdout).chain_err(|| "Output from listing containers not UTF8")?; + Ok(stdout.trim().to_owned()) } + // Should really take a FnOnce/FnBox + fn check_piped(&mut self, pipe: &mut FnMut(&mut ChildStdin) -> Result<()>) -> Result<()> { + let mut process = self.stdin(Stdio::piped()).spawn().chain_err(|| "Failed to start command")?; + let mut stdin = process.stdin.take().expect("Requested piped stdin but not present"); + pipe(&mut stdin).chain_err(|| "Failed to pipe input to process")?; + let output = process.wait_with_output().chain_err(|| "Failed to wait for process to return")?; + check_output(&output) + } + fn check_run(&mut self) -> Result<()> { + let output = self.output().chain_err(|| "Failed to start command")?; + check_output(&output) + } +} + +fn check_output(output: &Output) -> 
Result<()> { + if !output.status.success() { + warn!("===========\n{}\n==========\n\n\n\n=========\n{}\n===============\n\n\n", + String::from_utf8_lossy(&output.stdout), String::from_utf8_lossy(&output.stderr)); + bail!("Command failed with status {}", output.status) + } + Ok(()) } fn join_suffix>(path: &Path, suffix: P) -> PathBuf { @@ -76,24 +104,30 @@ impl OverlayBuilder { dir, toolchain_dir_map: Mutex::new(HashMap::new()), }; - ret.cleanup(); - fs::create_dir(&ret.dir).unwrap(); - fs::create_dir(ret.dir.join("builds")).unwrap(); - fs::create_dir(ret.dir.join("toolchains")).unwrap(); + ret.cleanup()?; + fs::create_dir(&ret.dir).chain_err(|| "Failed to create base directory for builder")?; + fs::create_dir(ret.dir.join("builds")).chain_err(|| "Failed to create builder builds directory")?; + fs::create_dir(ret.dir.join("toolchains")).chain_err(|| "Failed to create builder toolchains directory")?; Ok(ret) } - fn cleanup(&self) { + fn cleanup(&self) -> Result<()> { if self.dir.exists() { - fs::remove_dir_all(&self.dir).unwrap() + fs::remove_dir_all(&self.dir).chain_err(|| "Failed to clean up builder directory")? } + Ok(()) } fn prepare_overlay_dirs(&self, tc: &Toolchain, tccache: &Mutex) -> Result { let (toolchain_dir, id) = { let mut toolchain_dir_map = self.toolchain_dir_map.lock().unwrap(); // Create the toolchain dir (if necessary) while we have an exclusive lock - if !toolchain_dir_map.contains_key(tc) { + if toolchain_dir_map.contains_key(tc) { + // TODO: use if let when sccache can use NLL + let entry = toolchain_dir_map.get_mut(tc).expect("Key missing after checking"); + entry.1 += 1; + entry.clone() + } else { trace!("Creating toolchain directory for {}", tc.archive_id); let toolchain_dir = self.dir.join("toolchains").join(&tc.archive_id); fs::create_dir(&toolchain_dir)?; @@ -102,14 +136,14 @@ impl OverlayBuilder { let toolchain_rdr = match tccache.get(tc) { Ok(rdr) => rdr, Err(LruError::FileNotInCache) => bail!("expected toolchain {}, but not available", tc.archive_id), - Err(e) => return Err(Error::with_chain(e, "failed to get toolchain from cache")), + Err(e) => return Err(Error::from(e).chain_err(|| "failed to get toolchain from cache")), }; tar::Archive::new(GzDecoder::new(toolchain_rdr)).unpack(&toolchain_dir)?; - assert!(toolchain_dir_map.insert(tc.clone(), (toolchain_dir, 0)).is_none()) + + let entry = (toolchain_dir, 1); + assert!(toolchain_dir_map.insert(tc.clone(), entry.clone()).is_none()); + entry } - let entry = toolchain_dir_map.get_mut(tc).unwrap(); - entry.1 += 1; - entry.clone() }; trace!("Creating build directory for {}-{}", tc.archive_id, id); @@ -118,7 +152,7 @@ impl OverlayBuilder { Ok(OverlaySpec { build_dir, toolchain_dir }) } - fn perform_build(bubblewrap: &Path, compile_command: CompileCommand, inputs_rdr: InputsReader, output_paths: Vec, overlay: &OverlaySpec) -> BuildResult { + fn perform_build(bubblewrap: &Path, compile_command: CompileCommand, inputs_rdr: InputsReader, output_paths: Vec, overlay: &OverlaySpec) -> Result { trace!("Compile environment: {:?}", compile_command.env_vars); trace!("Compile command: {:?} {:?}", compile_command.executable, compile_command.arguments); @@ -126,40 +160,47 @@ impl OverlayBuilder { // Now mounted filesystems will be automatically unmounted when this thread dies // (and tmpfs filesystems will be completely destroyed) - nix::sched::unshare(nix::sched::CloneFlags::CLONE_NEWNS).unwrap(); + nix::sched::unshare(nix::sched::CloneFlags::CLONE_NEWNS) + .chain_err(|| "Failed to enter a new Linux namespace")?; // Make sure 
that all future mount changes are private to this namespace // TODO: shouldn't need to add these annotations let source: Option<&str> = None; let fstype: Option<&str> = None; let data: Option<&str> = None; - nix::mount::mount(source, "/", fstype, nix::mount::MsFlags::MS_REC | nix::mount::MsFlags::MS_PRIVATE, data).unwrap(); + // Turn / into a 'slave', so it receives mounts from real root, but doesn't propogate back + nix::mount::mount(source, "/", fstype, nix::mount::MsFlags::MS_REC | nix::mount::MsFlags::MS_PRIVATE, data) + .chain_err(|| "Failed to turn / into a slave")?; let work_dir = overlay.build_dir.join("work"); let upper_dir = overlay.build_dir.join("upper"); let target_dir = overlay.build_dir.join("target"); - fs::create_dir(&work_dir).unwrap(); - fs::create_dir(&upper_dir).unwrap(); - fs::create_dir(&target_dir).unwrap(); + fs::create_dir(&work_dir).chain_err(|| "Failed to create overlay work directory")?; + fs::create_dir(&upper_dir).chain_err(|| "Failed to create overlay upper directory")?; + fs::create_dir(&target_dir).chain_err(|| "Failed to create overlay target directory")?; let () = Overlay::writable( iter::once(overlay.toolchain_dir.as_path()), upper_dir, work_dir, &target_dir, - ).mount().unwrap(); + // This error is unfortunately not Send + ).mount().map_err(|e| Error::from(e.to_string())).chain_err(|| "Failed to mount overlay FS")?; trace!("copying in inputs"); // Note that we don't unpack directly into the upperdir since there overlayfs has some // special marker files that we don't want to create by accident (or malicious intent) - tar::Archive::new(inputs_rdr).unpack(&target_dir).unwrap(); + tar::Archive::new(inputs_rdr).unpack(&target_dir).chain_err(|| "Failed to unpack inputs to overlay")?; let CompileCommand { executable, arguments, env_vars, cwd } = compile_command; let cwd = Path::new(&cwd); trace!("creating output directories"); - fs::create_dir_all(join_suffix(&target_dir, cwd)).unwrap(); + fs::create_dir_all(join_suffix(&target_dir, cwd)).chain_err(|| "Failed to create cwd")?; for path in output_paths.iter() { - fs::create_dir_all(join_suffix(&target_dir, cwd.join(Path::new(path).parent().unwrap()))).unwrap(); + // If it doesn't have a parent, nothing needs creating + let output_parent = if let Some(p) = Path::new(path).parent() { p } else { continue }; + fs::create_dir_all(join_suffix(&target_dir, cwd.join(output_parent))) + .chain_err(|| "Failed to create an output directory")?; } trace!("performing compile"); @@ -199,7 +240,7 @@ impl OverlayBuilder { cmd.arg("--"); cmd.arg(executable); cmd.args(arguments); - let compile_output = cmd.output().unwrap(); + let compile_output = cmd.output().chain_err(|| "Failed to retrieve output from compile")?; trace!("compile_output: {:?}", compile_output); let mut outputs = vec![]; @@ -208,28 +249,36 @@ impl OverlayBuilder { let abspath = join_suffix(&target_dir, cwd.join(&path)); // Resolve in case it's relative since we copy it from the root level match fs::File::open(abspath) { Ok(mut file) => { - let output = OutputData::from_reader(file); + let output = OutputData::try_from_reader(file) + .chain_err(|| "Failed to read output file")?; outputs.push((path, output)) }, Err(e) => { if e.kind() == io::ErrorKind::NotFound { debug!("Missing output path {:?}", path) } else { - panic!(e) + return Err(Error::from(e).chain_err(|| "Failed to open output file")) } }, } } - BuildResult { output: compile_output.into(), outputs } + let compile_output = ProcessOutput::try_from(compile_output) + .chain_err(|| "Failed to convert 
compilation exit status")?; + Ok(BuildResult { output: compile_output, outputs }) - }).join().unwrap() }) + // Bizarrely there's no way to actually get any information from a thread::Result::Err + }).join().unwrap_or_else(|_e| Err(Error::from("Build thread exited unsuccessfully"))) }) } + // Failing during cleanup is pretty unexpected, but we can still return the successful compile + // TODO: if too many of these fail, we should mark this builder as faulty fn finish_overlay(&self, _tc: &Toolchain, overlay: OverlaySpec) { // TODO: collect toolchain directories let OverlaySpec { build_dir, toolchain_dir: _ } = overlay; - fs::remove_dir_all(build_dir).unwrap(); + if let Err(e) = fs::remove_dir_all(&build_dir) { + error!("Failed to remove build directory {}: {}", build_dir.display(), e); + } } } @@ -243,11 +292,27 @@ impl BuilderIncoming for OverlayBuilder { debug!("Finishing with overlay"); self.finish_overlay(&tc, overlay); debug!("Returning result"); - Ok(res) + res.chain_err(|| "Compilation execution failed") } } const BASE_DOCKER_IMAGE: &str = "aidanhs/busybox"; +// Make sure sh doesn't exec the final command, since we need it to do +// init duties (reaping zombies). Also, because we kill -9 -1, that kills +// the sleep (it's not a builtin) so it needs to be a loop. +const DOCKER_SHELL_INIT: &str = "while true; do /busybox sleep 365d && /busybox true; done"; + +// Check the diff and clean up the FS +fn docker_diff(cid: &str) -> Result<String> { + Command::new("docker").args(&["diff", cid]).check_stdout_trim() + .chain_err(|| "Failed to Docker diff container") +} + +// Force remove the container +fn docker_rm(cid: &str) -> Result<()> { + Command::new("docker").args(&["rm", "-f", &cid]).check_run() + .chain_err(|| "Failed to force delete container") +} pub struct DockerBuilder { image_map: Mutex<HashMap<Toolchain, String>>, @@ -258,128 +323,114 @@ impl DockerBuilder { // TODO: this should accept a unique string, e.g. 
inode of the tccache directory // having locked a pidfile, or at minimum should loudly detect other running // instances - pidfile in /tmp - pub fn new() -> Self { + pub fn new() -> Result { info!("Creating docker builder"); let ret = Self { image_map: Mutex::new(HashMap::new()), container_lists: Mutex::new(HashMap::new()), }; - ret.cleanup(); - ret + ret.cleanup()?; + Ok(ret) } // TODO: this should really reclaim, and should check in the image map and container lists, so // that when things are removed from there it becomes a form of GC - fn cleanup(&self) { + fn cleanup(&self) -> Result<()> { info!("Performing initial Docker cleanup"); - let containers = { - let output = Command::new("docker").args(&["ps", "-a", "--format", "{{.ID}} {{.Image}}"]).output().unwrap(); - check_output(&output); - let stdout = String::from_utf8(output.stdout).unwrap(); - stdout.trim().to_owned() - }; + let containers = Command::new("docker").args(&["ps", "-a", "--format", "{{.ID}} {{.Image}}"]).check_stdout_trim() + .chain_err(|| "Unable to list all Docker containers")?; if containers != "" { let mut containers_to_rm = vec![]; for line in containers.split(|c| c == '\n') { let mut iter = line.splitn(2, ' '); - let container_id = iter.next().unwrap(); - let image_name = iter.next().unwrap(); - if iter.next() != None { panic!() } + let container_id = iter.next().ok_or_else(|| Error::from("Malformed container listing - no container ID"))?; + let image_name = iter.next().ok_or_else(|| Error::from("Malformed container listing - no image name"))?; + if iter.next() != None { bail!("Malformed container listing - third field on row") } if image_name.starts_with("sccache-builder-") { containers_to_rm.push(container_id) } } if !containers_to_rm.is_empty() { - let output = Command::new("docker").args(&["rm", "-f"]).args(containers_to_rm).output().unwrap(); - check_output(&output) + Command::new("docker").args(&["rm", "-f"]).args(containers_to_rm).check_run() + .chain_err(|| "Failed to start command to remove old containers")?; } } - let images = { - let output = Command::new("docker").args(&["images", "--format", "{{.ID}} {{.Repository}}"]).output().unwrap(); - check_output(&output); - let stdout = String::from_utf8(output.stdout).unwrap(); - stdout.trim().to_owned() - }; + let images = Command::new("docker").args(&["images", "--format", "{{.ID}} {{.Repository}}"]).check_stdout_trim() + .chain_err(|| "Failed to list all docker images")?; if images != "" { let mut images_to_rm = vec![]; for line in images.split(|c| c == '\n') { let mut iter = line.splitn(2, ' '); - let image_id = iter.next().unwrap(); - let image_name = iter.next().unwrap(); - if iter.next() != None { panic!() } + let image_id = iter.next().ok_or_else(|| Error::from("Malformed image listing - no image ID"))?; + let image_name = iter.next().ok_or_else(|| Error::from("Malformed image listing - no image name"))?; + if iter.next() != None { bail!("Malformed image listing - third field on row") } if image_name.starts_with("sccache-builder-") { images_to_rm.push(image_id) } } if !images_to_rm.is_empty() { - let output = Command::new("docker").args(&["rmi"]).args(images_to_rm).output().unwrap(); - check_output(&output) + Command::new("docker").args(&["rmi"]).args(images_to_rm).check_run() + .chain_err(|| "Failed to remove image")? 
} } info!("Completed initial Docker cleanup"); + Ok(()) } // If we have a spare running container, claim it and remove it from the available list, // otherwise try and create a new container (possibly creating the Docker image along // the way) - fn get_container(&self, tc: &Toolchain, tccache: &Mutex) -> String { + fn get_container(&self, tc: &Toolchain, tccache: &Mutex) -> Result { let container = { let mut map = self.container_lists.lock().unwrap(); map.entry(tc.clone()).or_insert_with(Vec::new).pop() }; match container { - Some(cid) => cid, + Some(cid) => Ok(cid), None => { // TODO: can improve parallelism (of creating multiple images at a time) by using another // (more fine-grained) mutex around the entry value and checking if its empty a second time let image = { let mut map = self.image_map.lock().unwrap(); - map.entry(tc.clone()).or_insert_with(|| { - info!("Creating Docker image for {:?} (may block requests)", tc); - Self::make_image(tc, tccache) - }).clone() + match map.entry(tc.clone()) { + hash_map::Entry::Occupied(e) => e.get().clone(), + hash_map::Entry::Vacant(e) => { + info!("Creating Docker image for {:?} (may block requests)", tc); + let image = Self::make_image(tc, tccache)?; + e.insert(image.clone()); + image + }, + } }; Self::start_container(&image) }, } } - fn finish_container(&self, tc: &Toolchain, cid: String) { - // TODO: collect images - + fn clean_container(&self, cid: &str) -> Result<()> { // Clean up any running processes - let output = Command::new("docker").args(&["exec", &cid, "/busybox", "kill", "-9", "-1"]).output().unwrap(); - check_output(&output); + Command::new("docker").args(&["exec", &cid, "/busybox", "kill", "-9", "-1"]).check_run() + .chain_err(|| "Failed to run kill on all processes in container")?; - // Check the diff and clean up the FS - fn dodiff(cid: &str) -> String { - let output = Command::new("docker").args(&["diff", cid]).output().unwrap(); - check_output(&output); - let stdout = String::from_utf8(output.stdout).unwrap(); - stdout.trim().to_owned() - } - let diff = dodiff(&cid); + let diff = docker_diff(&cid)?; if diff != "" { - let mut shoulddelete = false; let mut lastpath = None; for line in diff.split(|c| c == '\n') { let mut iter = line.splitn(2, ' '); - let changetype = iter.next().unwrap(); - let changepath = iter.next().unwrap(); - if iter.next() != None { panic!() } + let changetype = iter.next().ok_or_else(|| Error::from("Malformed container diff - no change type"))?; + let changepath = iter.next().ok_or_else(|| Error::from("Malformed container diff - no change path"))?; + if iter.next() != None { bail!("Malformed container diff - third field on row") } // TODO: If files are created in this dir, it gets marked as modified. 
// A similar thing applies to /root or /build etc if changepath == "/tmp" { continue } if changetype != "A" { - warn!("Deleting container {}: path {} had a non-A changetype of {}", &cid, changepath, changetype); - shoulddelete = true; - break + bail!("Path {} had a non-A changetype of {}", changepath, changetype); } // Docker diff paths are in alphabetical order and we do `rm -rf`, so we might be able to skip // calling Docker more than necessary (since it's slow) @@ -389,80 +440,88 @@ impl DockerBuilder { } } lastpath = Some(changepath.clone()); - let output = Command::new("docker").args(&["exec", &cid, "/busybox", "rm", "-rf", changepath]).output().unwrap(); - check_output(&output); + if let Err(e) = Command::new("docker").args(&["exec", &cid, "/busybox", "rm", "-rf", changepath]).check_run() { + // We do a final check anyway, so just continue + warn!("Failed to remove added path in a container: {}", e) + } } - let newdiff = dodiff(&cid); + let newdiff = docker_diff(&cid)?; // See note about changepath == "/tmp" above - if !shoulddelete && newdiff != "" && newdiff != "C /tmp" { - warn!("Deleted files, but container still has a diff: {:?}", newdiff); - shoulddelete = true - } - - if shoulddelete { - let output = Command::new("docker").args(&["rm", "-f", &cid]).output().unwrap(); - check_output(&output); - return + if newdiff != "" && newdiff != "C /tmp" { + bail!("Attempted to delete files, but container still has a diff: {:?}", newdiff); } } - // Good as new, add it back to the container list - trace!("Reclaimed container"); - self.container_lists.lock().unwrap().get_mut(tc).unwrap().push(cid); + Ok(()) } - fn make_image(tc: &Toolchain, tccache: &Mutex) -> String { - let cid = { - let output = Command::new("docker").args(&["create", BASE_DOCKER_IMAGE, "/busybox", "true"]).output().unwrap(); - check_output(&output); - let stdout = String::from_utf8(output.stdout).unwrap(); - stdout.trim().to_owned() - }; + // Failing during cleanup is pretty unexpected, but we can still return the successful compile + // TODO: if too many of these fail, we should mark this builder as faulty + fn finish_container(&self, tc: &Toolchain, cid: String) { + // TODO: collect images + + if let Err(e) = self.clean_container(&cid) { + info!("Failed to clean container {}: {}", cid, e); + if let Err(e) = docker_rm(&cid) { + warn!("Failed to remove container {} after failed clean: {}", cid, e); + } + return + } + + // Good as new, add it back to the container list + if let Some(entry) = self.container_lists.lock().unwrap().get_mut(tc) { + debug!("Reclaimed container {}", cid); + entry.push(cid) + } else { + warn!("Was ready to reclaim container {} but toolchain went missing", cid); + if let Err(e) = docker_rm(&cid) { + warn!("Failed to remove container {}: {}", cid, e); + } + } + } + + fn make_image(tc: &Toolchain, tccache: &Mutex) -> Result { + let cid = Command::new("docker").args(&["create", BASE_DOCKER_IMAGE, "/busybox", "true"]).check_stdout_trim() + .chain_err(|| "Failed to create docker container")?; let mut tccache = tccache.lock().unwrap(); - let toolchain_rdr = match tccache.get(tc) { + let mut toolchain_rdr = match tccache.get(tc) { Ok(rdr) => rdr, - Err(LruError::FileNotInCache) => panic!("expected toolchain, but not available"), - Err(e) => panic!("{}", e), + Err(LruError::FileNotInCache) => bail!("Expected to find toolchain {}, but not available", tc.archive_id), + Err(e) => return Err(Error::from(e).chain_err(|| format!("Failed to use toolchain {}", tc.archive_id))), }; trace!("Copying in toolchain"); - 
let mut process = Command::new("docker").args(&["cp", "-", &format!("{}:/", cid)]).stdin(Stdio::piped()).spawn().unwrap(); - io::copy(&mut {toolchain_rdr}, &mut process.stdin.take().unwrap()).unwrap(); - let output = process.wait_with_output().unwrap(); - check_output(&output); + Command::new("docker").args(&["cp", "-", &format!("{}:/", cid)]) + .check_piped(&mut |stdin| { io::copy(&mut toolchain_rdr, stdin)?; Ok(()) }) + .chain_err(|| "Failed to copy toolchain tar into container")?; + drop(toolchain_rdr); let imagename = format!("sccache-builder-{}", &tc.archive_id); - let output = Command::new("docker").args(&["commit", &cid, &imagename]).output().unwrap(); - check_output(&output); + Command::new("docker").args(&["commit", &cid, &imagename]).check_run() + .chain_err(|| "Failed to commit container after build")?; - let output = Command::new("docker").args(&["rm", "-f", &cid]).output().unwrap(); - check_output(&output); + Command::new("docker").args(&["rm", "-f", &cid]).check_run() + .chain_err(|| "Failed to remove temporary build container")?; - imagename + Ok(imagename) } - fn start_container(image: &str) -> String { - // Make sure sh doesn't exec the final command, since we need it to do - // init duties (reaping zombies). Also, because we kill -9 -1, that kills - // the sleep (it's not a builtin) so it needs to be a loop. - let output = Command::new("docker") - .args(&["run", "-d", image, "/busybox", "sh", "-c", "while true; do /busybox sleep 365d && /busybox true; done"]).output().unwrap(); - check_output(&output); - let stdout = String::from_utf8(output.stdout).unwrap(); - stdout.trim().to_owned() + fn start_container(image: &str) -> Result { + Command::new("docker").args(&["run", "-d", image, "/busybox", "sh", "-c", DOCKER_SHELL_INIT]).check_stdout_trim() + .chain_err(|| "Failed to run container") } - fn perform_build(compile_command: CompileCommand, inputs_rdr: InputsReader, output_paths: Vec, cid: &str) -> BuildResult { + fn perform_build(compile_command: CompileCommand, mut inputs_rdr: InputsReader, output_paths: Vec, cid: &str) -> Result { trace!("Compile environment: {:?}", compile_command.env_vars); trace!("Compile command: {:?} {:?}", compile_command.executable, compile_command.arguments); trace!("copying in inputs"); - let mut process = Command::new("docker").args(&["cp", "-", &format!("{}:/", cid)]).stdin(Stdio::piped()).spawn().unwrap(); - io::copy(&mut {inputs_rdr}, &mut process.stdin.take().unwrap()).unwrap(); - let output = process.wait_with_output().unwrap(); - check_output(&output); + Command::new("docker").args(&["cp", "-", &format!("{}:/", cid)]) + .check_piped(&mut |stdin| { io::copy(&mut inputs_rdr, stdin)?; Ok(()) }) + .chain_err(|| "Failed to copy inputs tar into container")?; + drop(inputs_rdr); let CompileCommand { executable, arguments, env_vars, cwd } = compile_command; let cwd = Path::new(&cwd); @@ -472,10 +531,12 @@ impl DockerBuilder { let mut cmd = Command::new("docker"); cmd.args(&["exec", cid, "/busybox", "mkdir", "-p"]).arg(cwd); for path in output_paths.iter() { - cmd.arg(cwd.join(Path::new(path).parent().unwrap())); + // If it doesn't have a parent, nothing needs creating + let output_parent = if let Some(p) = Path::new(path).parent() { p } else { continue }; + cmd.arg(cwd.join(output_parent)); } - let output = cmd.output().unwrap(); - check_output(&output); + cmd.check_run() + .chain_err(|| "Failed to create directories required for compile in container")?; trace!("performing compile"); // TODO: likely shouldn't perform the compile as root in the 
container @@ -497,7 +558,7 @@ impl DockerBuilder { cmd.arg(cwd); cmd.arg(executable); cmd.args(arguments); - let compile_output = cmd.output().unwrap(); + let compile_output = cmd.output().chain_err(|| "Failed to start executing compile")?; trace!("compile_output: {:?}", compile_output); let mut outputs = vec![]; @@ -505,15 +566,20 @@ impl DockerBuilder { for path in output_paths { let abspath = cwd.join(&path); // Resolve in case it's relative since we copy it from the root level // TODO: this isn't great, but cp gives it out as a tar - let output = Command::new("docker").args(&["exec", cid, "/busybox", "cat"]).arg(abspath).output().unwrap(); + let output = Command::new("docker").args(&["exec", cid, "/busybox", "cat"]).arg(abspath).output() + .chain_err(|| "Failed to start command to retrieve output file")?; if output.status.success() { - outputs.push((path, OutputData::from_reader(&*output.stdout))) + let output = OutputData::try_from_reader(&*output.stdout) + .expect("Failed to read compress output stdout"); + outputs.push((path, output)) } else { debug!("Missing output path {:?}", path) } } - BuildResult { output: compile_output.into(), outputs } + let compile_output = ProcessOutput::try_from(compile_output) + .chain_err(|| "Failed to convert compilation exit status")?; + Ok(BuildResult { output: compile_output, outputs }) } } @@ -522,9 +588,11 @@ impl BuilderIncoming for DockerBuilder { // From Server fn run_build(&self, tc: Toolchain, command: CompileCommand, outputs: Vec, inputs_rdr: InputsReader, tccache: &Mutex) -> Result { debug!("Finding container"); - let cid = self.get_container(&tc, tccache); + let cid = self.get_container(&tc, tccache) + .chain_err(|| "Failed to get a container for build")?; debug!("Performing build with container {}", cid); - let res = Self::perform_build(command, inputs_rdr, outputs, &cid); + let res = Self::perform_build(command, inputs_rdr, outputs, &cid) + .chain_err(|| "Failed to perform build")?; debug!("Finishing with container {}", cid); self.finish_container(&tc, cid); debug!("Returning result"); diff --git a/src/bin/sccache-dist/main.rs b/src/bin/sccache-dist/main.rs index 7ede449a..6662c49b 100644 --- a/src/bin/sccache-dist/main.rs +++ b/src/bin/sccache-dist/main.rs @@ -21,6 +21,7 @@ extern crate sccache; extern crate serde_derive; extern crate serde_json; extern crate tar; +extern crate void; use arraydeque::ArrayDeque; use clap::{App, Arg, SubCommand}; @@ -32,9 +33,9 @@ use sccache::config::{ }; use sccache::dist::{ self, - CompileCommand, InputsReader, JobId, JobAlloc, JobState, JobComplete, ServerId, Toolchain, ToolchainReader, - AllocJobResult, AssignJobResult, HeartbeatServerResult, RunJobResult, StatusResult, SubmitToolchainResult, UpdateJobStateResult, - BuilderIncoming, SchedulerIncoming, SchedulerOutgoing, ServerIncoming, ServerOutgoing, + CompileCommand, InputsReader, JobId, JobAlloc, JobState, JobComplete, ServerId, ServerNonce, Toolchain, ToolchainReader, + AllocJobResult, AssignJobResult, HeartbeatServerResult, RunJobResult, SchedulerStatusResult, SubmitToolchainResult, UpdateJobStateResult, + BuilderIncoming, JobAuthorizer, SchedulerIncoming, SchedulerOutgoing, ServerIncoming, ServerOutgoing, TcCache, }; use std::collections::{btree_map, BTreeMap, HashMap}; @@ -44,7 +45,7 @@ use std::net::SocketAddr; use std::path::Path; use std::sync::Mutex; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::time::{UNIX_EPOCH, Duration, Instant, SystemTime}; +use std::time::Instant; use errors::*; @@ -56,6 +57,7 @@ mod errors { use base64; 
use jwt; use lru_disk_cache; + use openssl; use sccache; error_chain! { @@ -64,6 +66,7 @@ mod errors { Io(io::Error); Jwt(jwt::errors::Error); Lru(lru_disk_cache::Error); + Openssl(openssl::error::ErrorStack); } links { @@ -73,43 +76,10 @@ mod errors { } mod build; +mod token_check; pub const INSECURE_DIST_SERVER_TOKEN: &str = "dangerously_insecure_server"; -// https://auth0.com/docs/jwks -#[derive(Debug)] -#[derive(Serialize, Deserialize)] -struct Jwks { - keys: Vec, -} - -#[derive(Debug)] -#[derive(Serialize, Deserialize)] -struct Jwk { - kid: String, - kty: String, - n: String, - e: String, -} - -impl Jwk { - // https://github.com/lawliet89/biscuit/issues/96#issuecomment-399149872 - fn to_der_pkcs1(&self) -> Result> { - if self.kty != "RSA" { - bail!("Cannot handle non-RSA JWK") - } - - // JWK is big-endian, openssl bignum from_slice is big-endian - let n = base64::decode_config(&self.n, base64::URL_SAFE).unwrap(); - let e = base64::decode_config(&self.e, base64::URL_SAFE).unwrap(); - let n_bn = openssl::bn::BigNum::from_slice(&n).unwrap(); - let e_bn = openssl::bn::BigNum::from_slice(&e).unwrap(); - let pubkey = openssl::rsa::Rsa::from_public_components(n_bn, e_bn).unwrap(); - let der: Vec = pubkey.public_key_to_der_pkcs1().unwrap(); - Ok(der) - } -} - enum Command { Auth(AuthSubcommand), Scheduler(scheduler_config::Config), @@ -121,8 +91,6 @@ enum AuthSubcommand { JwtHS256ServerToken { secret_key: String, server_id: ServerId }, } -enum Void {} - // Only supported on x86_64 Linux machines #[cfg(all(target_os = "linux", target_arch = "x86_64"))] fn main() { @@ -144,6 +112,9 @@ fn main() { } Err(e) => { println!("sccache: {}", e); + for e in e.iter().skip(1) { + println!("caused by: {}", e); + } get_app().print_help().unwrap(); println!(""); 1 @@ -183,7 +154,7 @@ fn parse() -> Result { AuthSubcommand::Base64 { num_bytes: 256 / 8 } }, ("generate-jwt-hs256-server-token", Some(matches)) => { - let server_id = ServerId(value_t_or_exit!(matches, "server", SocketAddr)); + let server_id = ServerId::new(value_t_or_exit!(matches, "server", SocketAddr)); let secret_key = if let Some(config_path) = matches.value_of("config").map(Path::new) { if let Some(config) = scheduler_config::from_path(config_path)? { match config.server_auth { @@ -195,7 +166,7 @@ fn parse() -> Result { bail!("Could not load config") } } else { - matches.value_of("secret-key").unwrap().to_owned() + matches.value_of("secret-key").expect("missing secret-key in parsed subcommand").to_owned() }; AuthSubcommand::JwtHS256ServerToken { secret_key, server_id } }, @@ -210,7 +181,7 @@ fn parse() -> Result { }) } ("scheduler", Some(matches)) => { - let config_path = Path::new(matches.value_of("config").unwrap()); + let config_path = Path::new(matches.value_of("config").expect("missing config in parsed subcommand")); if let Some(config) = scheduler_config::from_path(config_path)? { Command::Scheduler(config) } else { @@ -218,7 +189,7 @@ fn parse() -> Result { } }, ("server", Some(matches)) => { - let config_path = Path::new(matches.value_of("config").unwrap()); + let config_path = Path::new(matches.value_of("config").expect("missing config in parsed subcommand")); if let Some(config) = server_config::from_path(config_path)? 
{ Command::Server(config) } else { @@ -229,178 +200,6 @@ fn parse() -> Result { }) } -// Check a JWT is valid -fn check_jwt_validity(audience: &str, issuer: &str, kid_to_pkcs1: &HashMap>, token: &str) -> Result<()> { - let header = jwt::decode_header(token).chain_err(|| "Could not decode jwt header")?; - trace!("Validating JWT in scheduler"); - // Prepare validation - let kid = header.kid.chain_err(|| "No kid found")?; - let pkcs1 = kid_to_pkcs1.get(&kid).chain_err(|| "kid not found in jwks")?; - let mut validation = jwt::Validation::new(header.alg); - validation.set_audience(&audience); - validation.iss = Some(issuer.to_owned()); - #[derive(Deserialize)] - struct Claims {} - // Decode the JWT, discarding any claims - we just care about validity - let _tokendata = jwt::decode::(token, pkcs1, &validation) - .chain_err(|| "Unable to validate and decode jwt")?; - Ok(()) -} - -// https://infosec.mozilla.org/guidelines/iam/openid_connect#session-handling -const MOZ_SESSION_TIMEOUT: Duration = Duration::from_secs(60 * 15); -const MOZ_USERINFO_ENDPOINT: &str = "https://auth.mozilla.auth0.com/userinfo"; - -// Mozilla-specific check by forwarding the token onto the auth0 userinfo endpoint -fn check_mozilla(auth_cache: &Mutex>, client: &reqwest::Client, required_groups: &[String], token: &str) -> Result<()> { - // azp == client_id - // { - // "iss": "https://auth.mozilla.auth0.com/", - // "sub": "ad|Mozilla-LDAP|asayers", - // "aud": [ - // "sccache", - // "https://auth.mozilla.auth0.com/userinfo" - // ], - // "iat": 1541103283, - // "exp": 1541708083, - // "azp": "F1VVD6nRTckSVrviMRaOdLBWIk1AvHYo", - // "scope": "openid" - // } - #[derive(Deserialize)] - struct MozillaToken { - exp: u64, - sub: String, - } - // We don't really do any validation here (just forwarding on) so it's ok to unsafely decode - let unsafe_token = jwt::dangerous_unsafe_decode::(token).chain_err(|| "Unable to decode jwt")?; - let user = unsafe_token.claims.sub; - trace!("Validating token for user {} with mozilla", user); - if UNIX_EPOCH + Duration::from_secs(unsafe_token.claims.exp) < SystemTime::now() { - bail!("JWT expired") - } - // If the token is cached and not expired, return it - { - let mut auth_cache = auth_cache.lock().unwrap(); - if let Some(cached_at) = auth_cache.get(token) { - if cached_at.elapsed() < MOZ_SESSION_TIMEOUT { - return Ok(()) - } - } - auth_cache.remove(token); - } - - debug!("User {} not in cache, validating via auth0 endpoint", user); - // Retrieve the groups from the auth0 /userinfo endpoint, which Mozilla rules populate with groups - // https://github.com/mozilla-iam/auth0-deploy/blob/6889f1dde12b84af50bb4b2e2f00d5e80d5be33f/rules/CIS-Claims-fixups.js#L158-L168 - let url = reqwest::Url::parse(MOZ_USERINFO_ENDPOINT).unwrap(); - let header = reqwest::header::Authorization(reqwest::header::Bearer { token: token.to_owned() }); - let mut res = client.get(url.clone()).header(header).send().unwrap(); - let res_text = res.text().unwrap(); - if !res.status().is_success() { - bail!("JWT forwarded to {} returned {}: {}", url, res.status(), res_text) - } - - // The API didn't return a HTTP error code, let's check the response - let () = check_mozilla_profile(&user, required_groups, &res_text) - .chain_err(|| format!("Validation of the user profile failed for {}", user))?; - - // Validation success, cache the token - debug!("Validation for user {} succeeded, caching", user); - { - let mut auth_cache = auth_cache.lock().unwrap(); - auth_cache.insert(token.to_owned(), Instant::now()); - } - Ok(()) -} - -fn 
check_mozilla_profile(user: &str, required_groups: &[String], profile: &str) -> Result<()> { - #[derive(Deserialize)] - struct UserInfo { - sub: String, - #[serde(rename = "https://sso.mozilla.com/claim/groups")] - groups: Vec, - } - let profile: UserInfo = serde_json::from_str(profile) - .chain_err(|| format!("Could not parse profile: {}", profile))?; - if user != profile.sub { - bail!("User {} retrieved in profile is different to desired user {}", profile.sub, user) - } - for group in required_groups.iter() { - if !profile.groups.contains(group) { - bail!("User {} is not a member of required group {}", user, group) - } - } - Ok(()) -} - -#[test] -fn test_auth_verify_check_mozilla_profile() { - // A successful response - let profile = r#"{ - "sub": "ad|Mozilla-LDAP|asayers", - "https://sso.mozilla.com/claim/groups": [ - "everyone", - "hris_dept_firefox", - "hris_individual_contributor", - "hris_nonmanagers", - "hris_is_staff", - "hris_workertype_contractor" - ], - "https://sso.mozilla.com/claim/README_FIRST": "Please refer to https://github.com/mozilla-iam/person-api in order to query Mozilla IAM CIS user profile data" - }"#; - - // If the user has been deactivated since the token was issued. Note this may be partnered with an error code - // response so may never reach validation - let profile_fail = r#"{ - "error": "unauthorized", - "error_description": "user is blocked" - }"#; - - assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &["hris_dept_firefox".to_owned()], profile).is_ok()); - assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &[], profile).is_ok()); - assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &["hris_the_ceo".to_owned()], profile).is_err()); - - assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &[], profile_fail).is_err()); -} - -// Don't check a token is valid (it may not even be a JWT) just forward it to -// an API and check for success -fn check_token_forwarding(url: &str, maybe_auth_cache: &Option, Duration)>>, client: &reqwest::Client, token: &str) -> Result<()> { - #[derive(Deserialize)] - struct Token { - exp: u64, - } - let unsafe_token = jwt::dangerous_unsafe_decode::(token).chain_err(|| "Unable to decode jwt")?; - trace!("Validating token by forwarding to {}", url); - if UNIX_EPOCH + Duration::from_secs(unsafe_token.claims.exp) < SystemTime::now() { - bail!("JWT expired") - } - // If the token is cached and not cache has not expired, return it - if let Some(ref auth_cache) = maybe_auth_cache { - let mut auth_cache = auth_cache.lock().unwrap(); - let (ref mut auth_cache, cache_duration) = *auth_cache; - if let Some(cached_at) = auth_cache.get(token) { - if cached_at.elapsed() < cache_duration { - return Ok(()) - } - } - auth_cache.remove(token); - } - // Make a request to another API, which as a side effect should actually check the token - let header = reqwest::header::Authorization(reqwest::header::Bearer { token: token.to_owned() }); - let res = client.get(url).header(header).send().unwrap(); - if !res.status().is_success() { - bail!("JWT forwarded to {} returned {}", url, res.status()); - } - // Cache the token - if let Some(ref auth_cache) = maybe_auth_cache { - let mut auth_cache = auth_cache.lock().unwrap(); - let (ref mut auth_cache, _) = *auth_cache; - auth_cache.insert(token.to_owned(), Instant::now()); - } - Ok(()) -} - fn create_server_token(server_id: ServerId, auth_token: &str) -> String { format!("{} {}", server_id.addr(), auth_token) } @@ -408,7 +207,7 @@ fn check_server_token(server_token: &str, auth_token: &str) -> 
Option let mut split = server_token.splitn(2, |c| c == ' '); let server_addr = split.next().and_then(|addr| addr.parse().ok())?; match split.next() { - Some(t) if t == auth_token => Some(ServerId(server_addr)), + Some(t) if t == auth_token => Some(ServerId::new(server_addr)), Some(_) | None => None, } @@ -419,8 +218,8 @@ fn check_server_token(server_token: &str, auth_token: &str) -> Option struct ServerJwt { server_id: ServerId, } -fn create_jwt_server_token(server_id: ServerId, header: &jwt::Header, key: &[u8]) -> String { - jwt::encode(&header, &ServerJwt { server_id }, key).unwrap() +fn create_jwt_server_token(server_id: ServerId, header: &jwt::Header, key: &[u8]) -> Result { + jwt::encode(&header, &ServerJwt { server_id }, key).map_err(Into::into) } fn dangerous_unsafe_extract_jwt_server_token(server_token: &str) -> Option { jwt::dangerous_unsafe_decode::(&server_token) @@ -437,7 +236,7 @@ fn run(command: Command) -> Result { match command { Command::Auth(AuthSubcommand::Base64 { num_bytes }) => { let mut bytes = vec![0; num_bytes]; - let mut rng = rand::OsRng::new().unwrap(); + let mut rng = rand::OsRng::new().chain_err(|| "Failed to initialise a random number generator")?; rng.fill_bytes(&mut bytes); // As long as it can be copied, it doesn't matter if this is base64 or hex etc println!("{}", base64::encode_config(&bytes, base64::URL_SAFE_NO_PAD)); @@ -446,60 +245,25 @@ fn run(command: Command) -> Result { Command::Auth(AuthSubcommand::JwtHS256ServerToken { secret_key, server_id }) => { let header = jwt::Header::new(jwt::Algorithm::HS256); let secret_key = base64::decode_config(&secret_key, base64::URL_SAFE_NO_PAD)?; - println!("{}", create_jwt_server_token(server_id, &header, &secret_key)); + let token = create_jwt_server_token(server_id, &header, &secret_key) + .chain_err(|| "Failed to create server token")?; + println!("{}", token); Ok(0) }, - Command::Scheduler(scheduler_config::Config { client_auth, server_auth }) => { - let check_client_auth: dist::http::ClientAuthCheck = match client_auth { - scheduler_config::ClientAuth::Insecure => Box::new(move |s| s == INSECURE_DIST_CLIENT_TOKEN), - scheduler_config::ClientAuth::Token { token } => Box::new(move |s| s == token), - scheduler_config::ClientAuth::JwtValidate { audience, issuer, jwks_url } => { - let mut res = reqwest::get(&jwks_url).unwrap(); - if !res.status().is_success() { - bail!("Could not retrieve JWKs, HTTP error: {}", res.status()) - } - let jwks: Jwks = res.json().unwrap(); - let kid_to_pkcs1 = jwks.keys.into_iter() - .map(|k| k.to_der_pkcs1().map(|pkcs1| (k.kid, pkcs1)).unwrap()) - .collect(); - Box::new(move |s| { - match check_jwt_validity(&audience, &issuer, &kid_to_pkcs1, s) { - Ok(()) => true, - Err(e) => { - warn!("JWT validation failed: {}", e); - false - }, - } - }) - }, - scheduler_config::ClientAuth::Mozilla { required_groups } => { - let auth_cache: Mutex> = Mutex::new(HashMap::new()); - let client = reqwest::Client::new(); - Box::new(move |s| { - match check_mozilla(&auth_cache, &client, &required_groups, s) { - Ok(()) => true, - Err(e) => { - warn!("JWT validation failed: {}", e); - false - }, - } - }) - }, - scheduler_config::ClientAuth::ProxyToken { url, cache_secs } => { - let maybe_auth_cache: Option, Duration)>> = - cache_secs.map(|secs| Mutex::new((HashMap::new(), Duration::from_secs(secs)))); - let client = reqwest::Client::new(); - Box::new(move |s| { - match check_token_forwarding(&url, &maybe_auth_cache, &client, s) { - Ok(()) => true, - Err(e) => { - warn!("JWT validation failed: {}", e); - 
false - }, - } - }) - }, + Command::Scheduler(scheduler_config::Config { public_addr, client_auth, server_auth }) => { + let check_client_auth: Box = match client_auth { + scheduler_config::ClientAuth::Insecure => + Box::new(token_check::EqCheck::new(INSECURE_DIST_CLIENT_TOKEN.to_owned())), + scheduler_config::ClientAuth::Token { token } => + Box::new(token_check::EqCheck::new(token)), + scheduler_config::ClientAuth::JwtValidate { audience, issuer, jwks_url } => + Box::new(token_check::ValidJWTCheck::new(audience.to_owned(), issuer.to_owned(), &jwks_url) + .chain_err(|| "Failed to create a checker for valid JWTs")?), + scheduler_config::ClientAuth::Mozilla { required_groups } => + Box::new(token_check::MozillaCheck::new(required_groups)), + scheduler_config::ClientAuth::ProxyToken { url, cache_secs } => + Box::new(token_check::ProxyTokenCheck::new(url, cache_secs)), }; let check_server_auth: dist::http::ServerAuthCheck = match server_auth { @@ -522,18 +286,19 @@ fn run(command: Command) -> Result { }; let scheduler = Scheduler::new(); - let http_scheduler = dist::http::Scheduler::new(scheduler, check_client_auth, check_server_auth); - let _: Void = http_scheduler.start(); + let http_scheduler = dist::http::Scheduler::new(public_addr, scheduler, check_client_auth, check_server_auth); + void::unreachable(http_scheduler.start()?); }, - Command::Server(server_config::Config { builder, cache_dir, public_addr, scheduler_addr, scheduler_auth, toolchain_cache_size }) => { + Command::Server(server_config::Config { builder, cache_dir, public_addr, scheduler_url, scheduler_auth, toolchain_cache_size }) => { let builder: Box> = match builder { - server_config::BuilderType::Docker => Box::new(build::DockerBuilder::new()), + server_config::BuilderType::Docker => + Box::new(build::DockerBuilder::new().chain_err(|| "Docker builder failed to start")?), server_config::BuilderType::Overlay { bwrap_path, build_dir } => Box::new(build::OverlayBuilder::new(bwrap_path, build_dir).chain_err(|| "Overlay builder failed to start")?) }; - let server_id = ServerId(public_addr); + let server_id = ServerId::new(public_addr); let scheduler_auth = match scheduler_auth { server_config::SchedulerAuth::Insecure => { warn!("Server starting with DANGEROUSLY_INSECURE scheduler authentication"); @@ -551,9 +316,11 @@ fn run(command: Command) -> Result { } }; - let server = Server::new(builder, &cache_dir, toolchain_cache_size); - let http_server = dist::http::Server::new(scheduler_addr, scheduler_auth, server); - let _: Void = http_server.start(); + let server = Server::new(builder, &cache_dir, toolchain_cache_size) + .chain_err(|| "Failed to create sccache server instance")?; + let http_server = dist::http::Server::new(public_addr, scheduler_url.to_url(), scheduler_auth, server) + .chain_err(|| "Failed to create sccache HTTP server instance")?; + void::unreachable(http_server.start()?) 
}, } } @@ -593,7 +360,8 @@ struct ServerDetails { jobs_assigned: usize, last_seen: Instant, num_cpus: usize, - generate_job_auth: Box String + Send>, + server_nonce: ServerNonce, + job_authorizer: Box, } impl Scheduler { @@ -636,25 +404,22 @@ impl SchedulerIncoming for Scheduler { info!("Job {} created and assigned to server {:?}", job_id, server_id); assert!(jobs.insert(job_id, JobDetail { server_id, state: JobState::Pending }).is_none()); - let auth = (server_details.generate_job_auth)(job_id); + let auth = server_details.job_authorizer.generate_token(job_id) + .map_err(Error::from) + .chain_err(|| "Could not create an auth token for this job")?; (job_id, server_id, auth) } else { let msg = format!("Insufficient capacity across {} available servers", num_servers); return Ok(AllocJobResult::Fail { msg }) } }; - let AssignJobResult { need_toolchain } = requester.do_assign_job(server_id, job_id, tc, auth.clone()).chain_err(|| "assign job failed")?; - if !need_toolchain { - // LOCKS - let mut jobs = self.jobs.lock().unwrap(); - - jobs.get_mut(&job_id).unwrap().state = JobState::Ready - } + let AssignJobResult { need_toolchain } = requester.do_assign_job(server_id, job_id, tc, auth.clone()) + .chain_err(|| "assign job failed")?; let job_alloc = JobAlloc { auth, job_id, server_id }; Ok(AllocJobResult::Success { job_alloc, need_toolchain }) } - fn handle_heartbeat_server(&self, server_id: ServerId, num_cpus: usize, generate_job_auth: Box String + Send>) -> Result { + fn handle_heartbeat_server(&self, server_id: ServerId, server_nonce: ServerNonce, num_cpus: usize, job_authorizer: Box) -> Result { if num_cpus == 0 { bail!("Invalid number of CPUs (0) specified in heartbeat") } @@ -662,15 +427,22 @@ impl SchedulerIncoming for Scheduler { // LOCKS let mut servers = self.servers.lock().unwrap(); - let mut is_new = false; - servers.entry(server_id) - .and_modify(|details| details.last_seen = Instant::now()) - .or_insert_with(|| { - info!("Registered new server {:?}", server_id); - is_new = true; - ServerDetails { jobs_assigned: 0, num_cpus, generate_job_auth, last_seen: Instant::now() } - }); - Ok(HeartbeatServerResult { is_new }) + match servers.get_mut(&server_id) { + Some(ref mut details) if details.server_nonce == server_nonce => { + details.last_seen = Instant::now(); + return Ok(HeartbeatServerResult { is_new: false }) + }, + _ => (), + } + info!("Registered new server {:?}", server_id); + servers.insert(server_id, ServerDetails { + last_seen: Instant::now(), + jobs_assigned: 0, + num_cpus, + server_nonce, + job_authorizer, + }); + Ok(HeartbeatServerResult { is_new: true }) } fn handle_update_job_state(&self, job_id: JobId, server_id: ServerId, job_state: JobState) -> Result { @@ -693,7 +465,11 @@ impl SchedulerIncoming for Scheduler { (JobState::Started, JobState::Complete) => { let (job_id, job_entry) = entry.remove_entry(); finished_jobs.push_back((job_id, job_entry)); - servers.get_mut(&server_id).unwrap().jobs_assigned -= 1 + if let Some(entry) = servers.get_mut(&server_id) { + entry.jobs_assigned -= 1 + } else { + bail!("Job was marked as finished, but server is not known to scheduler") + } }, (from, to) => { bail!("Invalid job state transition from {} to {}", from, to) @@ -706,10 +482,10 @@ impl SchedulerIncoming for Scheduler { Ok(UpdateJobStateResult::Success) } - fn handle_status(&self) -> Result { + fn handle_status(&self) -> Result { let servers = self.servers.lock().unwrap(); - Ok(StatusResult { + Ok(SchedulerStatusResult { num_servers: servers.len(), }) } @@ -722,22 +498,25 @@ pub 
struct Server { } impl Server { - pub fn new(builder: Box>, cache_dir: &Path, toolchain_cache_size: u64) -> Server { - Server { + pub fn new(builder: Box>, cache_dir: &Path, toolchain_cache_size: u64) -> Result { + let cache = TcCache::new(&cache_dir.join("tc"), toolchain_cache_size) + .chain_err(|| "Failed to create toolchain cache")?; + Ok(Server { builder, - cache: Mutex::new(TcCache::new(&cache_dir.join("tc"), toolchain_cache_size).unwrap()), + cache: Mutex::new(cache), job_toolchains: Mutex::new(HashMap::new()), - } + }) } } impl ServerIncoming for Server { type Error = Error; - fn handle_assign_job(&self, job_id: JobId, tc: Toolchain) -> Result { + fn handle_assign_job(&self, requester: &ServerOutgoing, job_id: JobId, tc: Toolchain) -> Result { let need_toolchain = !self.cache.lock().unwrap().contains_toolchain(&tc); assert!(self.job_toolchains.lock().unwrap().insert(job_id, tc).is_none()); if !need_toolchain { - // TODO: can start prepping the container now + requester.do_update_job_state(job_id, JobState::Ready).chain_err(|| "Updating job state failed")?; + // TODO: can start prepping the build environment now } Ok(AssignJobResult { need_toolchain }) } diff --git a/src/bin/sccache-dist/token_check.rs b/src/bin/sccache-dist/token_check.rs new file mode 100644 index 00000000..d6ee9bd1 --- /dev/null +++ b/src/bin/sccache-dist/token_check.rs @@ -0,0 +1,336 @@ +use base64; +use jwt; +use openssl; +use reqwest; +use sccache::dist::http::{ClientAuthCheck, ClientVisibleMsg}; +use serde_json; +use std::collections::HashMap; +use std::result::Result as StdResult; +use std::sync::Mutex; +use std::time::{UNIX_EPOCH, Duration, Instant, SystemTime}; + +use errors::*; + +// https://auth0.com/docs/jwks +#[derive(Debug)] +#[derive(Serialize, Deserialize)] +pub struct Jwks { + pub keys: Vec, +} + +#[derive(Debug)] +#[derive(Serialize, Deserialize)] +pub struct Jwk { + pub kid: String, + kty: String, + n: String, + e: String, +} + +impl Jwk { + // https://github.com/lawliet89/biscuit/issues/96#issuecomment-399149872 + pub fn to_der_pkcs1(&self) -> Result> { + if self.kty != "RSA" { + bail!("Cannot handle non-RSA JWK") + } + + // JWK is big-endian, openssl bignum from_slice is big-endian + let n = base64::decode_config(&self.n, base64::URL_SAFE).chain_err(|| "Failed to base64 decode n")?; + let e = base64::decode_config(&self.e, base64::URL_SAFE).chain_err(|| "Failed to base64 decode e")?; + let n_bn = openssl::bn::BigNum::from_slice(&n).chain_err(|| "Failed to create openssl bignum from n")?; + let e_bn = openssl::bn::BigNum::from_slice(&e).chain_err(|| "Failed to create openssl bignum from e")?; + let pubkey = openssl::rsa::Rsa::from_public_components(n_bn, e_bn) + .chain_err(|| "Failed to create pubkey from n and e")?; + let der: Vec = pubkey.public_key_to_der_pkcs1() + .chain_err(|| "Failed to convert public key to der pkcs1")?; + Ok(der) + } +} + +// Check a token is equal to a fixed string +pub struct EqCheck { + s: String, +} + +impl ClientAuthCheck for EqCheck { + fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> { + if self.s == token { + Ok(()) + } else { + warn!("User token {} != expected token {}", token, self.s); + Err(ClientVisibleMsg::from_nonsensitive("Fixed token mismatch".to_owned())) + } + } +} + +impl EqCheck { + pub fn new(s: String) -> Self { + Self { s } + } +} + +// https://infosec.mozilla.org/guidelines/iam/openid_connect#session-handling +const MOZ_SESSION_TIMEOUT: Duration = Duration::from_secs(60 * 15); +const MOZ_USERINFO_ENDPOINT: &str = 
"https://auth.mozilla.auth0.com/userinfo"; + +// Mozilla-specific check by forwarding the token onto the auth0 userinfo endpoint +pub struct MozillaCheck { + auth_cache: Mutex>, // token, token_expiry + client: reqwest::Client, + required_groups: Vec, +} + +impl ClientAuthCheck for MozillaCheck { + fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> { + self.check_mozilla(token) + .map_err(|e| { + warn!("Mozilla token validation failed: {}", e); + ClientVisibleMsg::from_nonsensitive("Failed to validate Mozilla OAuth token".to_owned()) + }) + } +} + +impl MozillaCheck { + pub fn new(required_groups: Vec) -> Self { + Self { + auth_cache: Mutex::new(HashMap::new()), + client: reqwest::Client::new(), + required_groups, + } + } + + fn check_mozilla(&self, token: &str) -> Result<()> { + // azp == client_id + // { + // "iss": "https://auth.mozilla.auth0.com/", + // "sub": "ad|Mozilla-LDAP|asayers", + // "aud": [ + // "sccache", + // "https://auth.mozilla.auth0.com/userinfo" + // ], + // "iat": 1541103283, + // "exp": 1541708083, + // "azp": "F1VVD6nRTckSVrviMRaOdLBWIk1AvHYo", + // "scope": "openid" + // } + #[derive(Deserialize)] + struct MozillaToken { + exp: u64, + sub: String, + } + // We don't really do any validation here (just forwarding on) so it's ok to unsafely decode + let unsafe_token = jwt::dangerous_unsafe_decode::(token).chain_err(|| "Unable to decode jwt")?; + let user = unsafe_token.claims.sub; + trace!("Validating token for user {} with mozilla", user); + if UNIX_EPOCH + Duration::from_secs(unsafe_token.claims.exp) < SystemTime::now() { + bail!("JWT expired") + } + // If the token is cached and not expired, return it + { + let mut auth_cache = self.auth_cache.lock().unwrap(); + if let Some(cached_at) = auth_cache.get(token) { + if cached_at.elapsed() < MOZ_SESSION_TIMEOUT { + return Ok(()) + } + } + auth_cache.remove(token); + } + + debug!("User {} not in cache, validating via auth0 endpoint", user); + // Retrieve the groups from the auth0 /userinfo endpoint, which Mozilla rules populate with groups + // https://github.com/mozilla-iam/auth0-deploy/blob/6889f1dde12b84af50bb4b2e2f00d5e80d5be33f/rules/CIS-Claims-fixups.js#L158-L168 + let url = reqwest::Url::parse(MOZ_USERINFO_ENDPOINT).expect("Failed to parse MOZ_USERINFO_ENDPOINT"); + let header = reqwest::header::Authorization(reqwest::header::Bearer { token: token.to_owned() }); + let mut res = self.client.get(url.clone()).header(header).send() + .chain_err(|| "Failed to make request to mozilla userinfo")?; + let res_text = res.text() + .chain_err(|| "Failed to interpret response from mozilla userinfo as string")?; + if !res.status().is_success() { + bail!("JWT forwarded to {} returned {}: {}", url, res.status(), res_text) + } + + // The API didn't return a HTTP error code, let's check the response + let () = check_mozilla_profile(&user, &self.required_groups, &res_text) + .chain_err(|| format!("Validation of the user profile failed for {}", user))?; + + // Validation success, cache the token + debug!("Validation for user {} succeeded, caching", user); + { + let mut auth_cache = self.auth_cache.lock().unwrap(); + auth_cache.insert(token.to_owned(), Instant::now()); + } + Ok(()) + } +} + +fn check_mozilla_profile(user: &str, required_groups: &[String], profile: &str) -> Result<()> { + #[derive(Deserialize)] + struct UserInfo { + sub: String, + #[serde(rename = "https://sso.mozilla.com/claim/groups")] + groups: Vec, + } + let profile: UserInfo = serde_json::from_str(profile) + .chain_err(|| format!("Could not parse 
profile: {}", profile))?; + if user != profile.sub { + bail!("User {} retrieved in profile is different to desired user {}", profile.sub, user) + } + for group in required_groups.iter() { + if !profile.groups.contains(group) { + bail!("User {} is not a member of required group {}", user, group) + } + } + Ok(()) +} + +#[test] +fn test_auth_verify_check_mozilla_profile() { + // A successful response + let profile = r#"{ + "sub": "ad|Mozilla-LDAP|asayers", + "https://sso.mozilla.com/claim/groups": [ + "everyone", + "hris_dept_firefox", + "hris_individual_contributor", + "hris_nonmanagers", + "hris_is_staff", + "hris_workertype_contractor" + ], + "https://sso.mozilla.com/claim/README_FIRST": "Please refer to https://github.com/mozilla-iam/person-api in order to query Mozilla IAM CIS user profile data" + }"#; + + // If the user has been deactivated since the token was issued. Note this may be partnered with an error code + // response so may never reach validation + let profile_fail = r#"{ + "error": "unauthorized", + "error_description": "user is blocked" + }"#; + + assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &["hris_dept_firefox".to_owned()], profile).is_ok()); + assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &[], profile).is_ok()); + assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &["hris_the_ceo".to_owned()], profile).is_err()); + + assert!(check_mozilla_profile("ad|Mozilla-LDAP|asayers", &[], profile_fail).is_err()); +} + +// Don't check a token is valid (it may not even be a JWT) just forward it to +// an API and check for success +pub struct ProxyTokenCheck { + client: reqwest::Client, + maybe_auth_cache: Option, Duration)>>, + url: String, +} + +impl ClientAuthCheck for ProxyTokenCheck { + fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> { + match self.check_token_with_forwarding(token) { + Ok(()) => Ok(()), + Err(e) => { + warn!("Proxying token validation failed: {}", e); + Err(ClientVisibleMsg::from_nonsensitive("Validation with token forwarding failed".to_owned())) + }, + } + } +} + +impl ProxyTokenCheck { + pub fn new(url: String, cache_secs: Option) -> Self { + let maybe_auth_cache: Option, Duration)>> = + cache_secs.map(|secs| Mutex::new((HashMap::new(), Duration::from_secs(secs)))); + Self { + client: reqwest::Client::new(), + maybe_auth_cache, + url, + } + } + + fn check_token_with_forwarding(&self, token: &str) -> Result<()> { + #[derive(Deserialize)] + struct Token { + exp: u64, + } + let unsafe_token = jwt::dangerous_unsafe_decode::(token).chain_err(|| "Unable to decode jwt")?; + trace!("Validating token by forwarding to {}", self.url); + if UNIX_EPOCH + Duration::from_secs(unsafe_token.claims.exp) < SystemTime::now() { + bail!("JWT expired") + } + // If the token is cached and not cache has not expired, return it + if let Some(ref auth_cache) = self.maybe_auth_cache { + let mut auth_cache = auth_cache.lock().unwrap(); + let (ref mut auth_cache, cache_duration) = *auth_cache; + if let Some(cached_at) = auth_cache.get(token) { + if cached_at.elapsed() < cache_duration { + return Ok(()) + } + } + auth_cache.remove(token); + } + // Make a request to another API, which as a side effect should actually check the token + let header = reqwest::header::Authorization(reqwest::header::Bearer { token: token.to_owned() }); + let res = self.client.get(&self.url).header(header).send() + .chain_err(|| "Failed to make request to proxying url")?; + if !res.status().is_success() { + bail!("JWT forwarded to {} returned {}", self.url, res.status()); 
+ } + // Cache the token + if let Some(ref auth_cache) = self.maybe_auth_cache { + let mut auth_cache = auth_cache.lock().unwrap(); + let (ref mut auth_cache, _) = *auth_cache; + auth_cache.insert(token.to_owned(), Instant::now()); + } + Ok(()) + } +} + +// Check a JWT is valid +pub struct ValidJWTCheck { + audience: String, + issuer: String, + kid_to_pkcs1: HashMap>, +} + +impl ClientAuthCheck for ValidJWTCheck { + fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> { + match self.check_jwt_validity(token) { + Ok(()) => Ok(()), + Err(e) => { + warn!("JWT validation failed: {}", e); + Err(ClientVisibleMsg::from_nonsensitive("JWT could not be validated".to_owned())) + }, + } + } +} + +impl ValidJWTCheck { + pub fn new(audience: String, issuer: String, jwks_url: &str) -> Result { + let mut res = reqwest::get(jwks_url) + .chain_err(|| "Failed to make request to JWKs url")?; + if !res.status().is_success() { + bail!("Could not retrieve JWKs, HTTP error: {}", res.status()) + } + let jwks: Jwks = res.json() + .chain_err(|| "Failed to parse JWKs json")?; + let kid_to_pkcs1 = jwks.keys.into_iter() + .map(|k| k.to_der_pkcs1().map(|pkcs1| (k.kid, pkcs1))) + .collect::>() + .chain_err(|| "Failed to convert JWKs into pkcs1")?; + Ok(Self { audience, issuer, kid_to_pkcs1}) + } + + fn check_jwt_validity(&self, token: &str) -> Result<()> { + let header = jwt::decode_header(token).chain_err(|| "Could not decode jwt header")?; + trace!("Validating JWT in scheduler"); + // Prepare validation + let kid = header.kid.chain_err(|| "No kid found")?; + let pkcs1 = self.kid_to_pkcs1.get(&kid).chain_err(|| "kid not found in jwks")?; + let mut validation = jwt::Validation::new(header.alg); + validation.set_audience(&self.audience); + validation.iss = Some(self.issuer.clone()); + #[derive(Deserialize)] + struct Claims {} + // Decode the JWT, discarding any claims - we just care about validity + let _tokendata = jwt::decode::(token, pkcs1, &validation) + .chain_err(|| "Unable to validate and decode jwt")?; + Ok(()) + } +} diff --git a/src/cmdline.rs b/src/cmdline.rs index e423e465..b5e373fb 100644 --- a/src/cmdline.rs +++ b/src/cmdline.rs @@ -172,9 +172,12 @@ pub fn parse() -> Result { } else if dist_auth { Ok(Command::DistAuth) } else if package_toolchain { - let mut values = matches.values_of_os("package-toolchain").unwrap(); + let mut values = matches.values_of_os("package-toolchain").expect("Parsed package-toolchain but no values"); assert!(values.len() == 2); - let (executable, out) = (values.next().unwrap(), values.next().unwrap()); + let (executable, out) = ( + values.next().expect("package-toolchain missing value 1"), + values.next().expect("package-toolchain missing value 2") + ); Ok(Command::PackageToolchain(executable.into(), out.into())) } else if let Some(mut args) = cmd { if let Some(exe) = args.next() { diff --git a/src/compiler/c.rs b/src/compiler/c.rs index 26e0b266..639edc84 100644 --- a/src/compiler/c.rs +++ b/src/compiler/c.rs @@ -348,7 +348,8 @@ impl pkg::InputsPackager for CInputsPackager { let CInputsPackager { input_path, mut path_transformer, preprocessed_input } = *{self}; let input_path = pkg::simplify_path(&input_path)?; - let dist_input_path = path_transformer.to_dist(&input_path).unwrap(); + let dist_input_path = path_transformer.to_dist(&input_path) + .chain_err(|| format!("unable to transform input path {}", input_path.display()))?; let mut builder = tar::Builder::new(wtr); diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 651a018a..8b51ce46 100644 --- 
a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -38,7 +38,7 @@ use std::borrow::Cow; use std::collections::HashMap; use std::ffi::OsString; use std::fmt; -#[cfg(unix)] +#[cfg(any(feature = "dist-client", unix))] use std::fs; use std::fs::File; use std::io::prelude::*; @@ -129,7 +129,7 @@ pub trait CompilerHasher: fmt::Debug + Send + 'static /// Look up a cached compile result in `storage`. If not found, run the /// compile and store the result. fn get_cached_or_compile(self: Box, - dist_client: Arc, + dist_client: Option>, creator: T, storage: Arc, arguments: Vec, @@ -143,7 +143,7 @@ pub trait CompilerHasher: fmt::Debug + Send + 'static let out_pretty = self.output_pretty().into_owned(); debug!("[{}]: get_cached_or_compile: {:?}", out_pretty, arguments); let start = Instant::now(); - let result = self.generate_hash_key(&creator, cwd.clone(), env_vars, dist_client.may_dist(), &pool); + let result = self.generate_hash_key(&creator, cwd.clone(), env_vars, dist_client.is_some(), &pool); Box::new(result.then(move |res| -> SFuture<_> { debug!("[{}]: generate_hash_key took {}", out_pretty, fmt_duration_as_secs(&start.elapsed())); let (key, compilation, weak_toolchain_key) = match res { @@ -244,7 +244,7 @@ pub trait CompilerHasher: fmt::Debug + Send + 'static let start = Instant::now(); let compile = dist_or_local_compile(dist_client, creator, cwd, compilation, weak_toolchain_key, out_pretty.clone()); - Box::new(compile.and_then(move |(cacheable, compiler_result)| { + Box::new(compile.and_then(move |(cacheable, dist_type, compiler_result)| { let duration = start.elapsed(); if !compiler_result.status.success() { debug!("[{}]: Compiled but failed, not storing in cache", @@ -296,7 +296,7 @@ pub trait CompilerHasher: fmt::Debug + Send + 'static }) }); let future = Box::new(future); - Ok((CompileResult::CacheMiss(miss_type, duration, future), compiler_result)) + Ok((CompileResult::CacheMiss(miss_type, dist_type, duration, future), compiler_result)) }).chain_err(move || { format!("failed to store `{}` to cache", o) })) @@ -315,116 +315,176 @@ pub trait CompilerHasher: fmt::Debug + Send + 'static } #[cfg(not(feature = "dist-client"))] -fn dist_or_local_compile(_dist_client: Arc, +fn dist_or_local_compile(_dist_client: Option>, creator: T, _cwd: PathBuf, compilation: Box, _weak_toolchain_key: String, out_pretty: String) - -> SFuture<(Cacheable, process::Output)> + -> SFuture<(Cacheable, DistType, process::Output)> where T: CommandCreatorSync { - debug!("[{}]: Compiling locally", out_pretty); let mut path_transformer = dist::PathTransformer::new(); - let (compile_cmd, _dist_compile_cmd, cacheable) = compilation.generate_compile_commands(&mut path_transformer).unwrap(); + let compile_commands = compilation.generate_compile_commands(&mut path_transformer) + .chain_err(|| "Failed to generate compile commands"); + let (compile_cmd, _dist_compile_cmd, cacheable) = match compile_commands { + Ok(cmds) => cmds, + Err(e) => return f_err(e), + }; + + debug!("[{}]: Compiling locally", out_pretty); Box::new(compile_cmd.execute(&creator) - .map(move |o| (cacheable, o))) + .map(move |o| (cacheable, DistType::NoDist, o))) } #[cfg(feature = "dist-client")] -fn dist_or_local_compile(dist_client: Arc, +fn dist_or_local_compile(dist_client: Option>, creator: T, cwd: PathBuf, compilation: Box, weak_toolchain_key: String, out_pretty: String) - -> SFuture<(Cacheable, process::Output)> + -> SFuture<(Cacheable, DistType, process::Output)> where T: CommandCreatorSync { use futures::future; - use std::error::Error as 
StdError; use std::io; + let mut path_transformer = dist::PathTransformer::new(); + let compile_commands = compilation.generate_compile_commands(&mut path_transformer) + .chain_err(|| "Failed to generate compile commands"); + let (compile_cmd, dist_compile_cmd, cacheable) = match compile_commands { + Ok(cmds) => cmds, + Err(e) => return f_err(e), + }; + + let dist_client = match dist_client { + Some(dc) => dc, + None => { + debug!("[{}]: Compiling locally", out_pretty); + return Box::new(compile_cmd.execute(&creator) + .map(move |o| (cacheable, DistType::NoDist, o))) + } + }; + debug!("[{}]: Attempting distributed compilation", out_pretty); let compile_out_pretty = out_pretty.clone(); let compile_out_pretty2 = out_pretty.clone(); let compile_out_pretty3 = out_pretty.clone(); - let mut path_transformer = dist::PathTransformer::new(); - let (compile_cmd, dist_compile_cmd, cacheable) = compilation.generate_compile_commands(&mut path_transformer).unwrap(); + let compile_out_pretty4 = out_pretty.clone(); let local_executable = compile_cmd.executable.clone(); // TODO: the number of map_errs is subideal, but there's no futures-based carrier trait AFAIK Box::new(future::result(dist_compile_cmd.ok_or_else(|| "Could not create distributed compile command".into())) .and_then(move |dist_compile_cmd| { debug!("[{}]: Creating distributed compile request", compile_out_pretty); let dist_output_paths = compilation.outputs() - .map(|(_key, path)| path_transformer.to_dist_assert_abs(&cwd.join(path))) + .map(|(_key, path)| path_transformer.to_dist_abs(&cwd.join(path))) .collect::>() - .unwrap(); + .ok_or_else(|| Error::from("Failed to adapt an output path for distributed compile"))?; compilation.into_dist_packagers(path_transformer) .map(|packagers| (dist_compile_cmd, packagers, dist_output_paths)) }) .and_then(move |(mut dist_compile_cmd, (inputs_packager, toolchain_packager, outputs_rewriter), dist_output_paths)| { debug!("[{}]: Identifying dist toolchain for {:?}", compile_out_pretty2, local_executable); - // TODO: put on a thread - let (dist_toolchain, maybe_dist_compile_executable) = - ftry!(dist_client.put_toolchain(&local_executable, &weak_toolchain_key, toolchain_packager)); - if let Some(dist_compile_executable) = maybe_dist_compile_executable { - dist_compile_cmd.executable = dist_compile_executable; - } - - debug!("[{}]: Requesting allocation", compile_out_pretty2); - Box::new(dist_client.do_alloc_job(dist_toolchain.clone()).map_err(Into::into) + dist_client.put_toolchain(&local_executable, &weak_toolchain_key, toolchain_packager) + .and_then(|(dist_toolchain, maybe_dist_compile_executable)| { + if let Some(dist_compile_executable) = maybe_dist_compile_executable { + dist_compile_cmd.executable = dist_compile_executable; + } + Ok((dist_client, dist_compile_cmd, dist_toolchain, inputs_packager, outputs_rewriter, dist_output_paths)) + }) + }) + .and_then(move |(dist_client, dist_compile_cmd, dist_toolchain, inputs_packager, outputs_rewriter, dist_output_paths)| { + debug!("[{}]: Requesting allocation", compile_out_pretty3); + dist_client.do_alloc_job(dist_toolchain.clone()).chain_err(|| "failed to allocate job") .and_then(move |jares| { let alloc = match jares { dist::AllocJobResult::Success { job_alloc, need_toolchain: true } => { - debug!("[{}]: Sending toolchain", compile_out_pretty2); + debug!("[{}]: Sending toolchain {} for job {}", + compile_out_pretty3, dist_toolchain.archive_id, job_alloc.job_id); Box::new(dist_client.do_submit_toolchain(job_alloc.clone(), dist_toolchain) - .map(move |res| { 
+ .and_then(move |res| { match res { - dist::SubmitToolchainResult::Success => job_alloc, - dist::SubmitToolchainResult::JobNotFound | - dist::SubmitToolchainResult::CannotCache => panic!(), + dist::SubmitToolchainResult::Success => Ok(job_alloc), + dist::SubmitToolchainResult::JobNotFound => + bail!("Job {} not found on server", job_alloc.job_id), + dist::SubmitToolchainResult::CannotCache => + bail!("Toolchain for job {} could not be cached by server", job_alloc.job_id), } - }).chain_err(|| "Could not submit toolchain")) + }) + .chain_err(|| "Could not submit toolchain")) }, dist::AllocJobResult::Success { job_alloc, need_toolchain: false } => f_ok(job_alloc), dist::AllocJobResult::Fail { msg } => - f_err(Error::with_chain(Error::from("Failed to allocate job"), msg)), + f_err(Error::from("Failed to allocate job").chain_err(|| msg)), }; alloc .and_then(move |job_alloc| { - debug!("[{}]: Running job", compile_out_pretty2); + let job_id = job_alloc.job_id; + debug!("[{}]: Running job", compile_out_pretty3); dist_client.do_run_job(job_alloc, dist_compile_cmd, dist_output_paths, inputs_packager) - .map_err(Into::into) + .map(move |res| (job_id, res)) + .chain_err(|| "could not run distributed compilation job") }) }) - .map(move |(jres, path_transformer)| { + .and_then(move |(job_id, (jres, path_transformer))| { let jc = match jres { dist::RunJobResult::Complete(jc) => jc, - dist::RunJobResult::JobNotFound => panic!(), + dist::RunJobResult::JobNotFound => bail!("Job {} not found on server", job_id), }; info!("fetched {:?}", jc.outputs.iter().map(|&(ref p, ref bs)| (p, bs.lens().to_string())).collect::>()); - let mut output_paths = vec![]; + let mut output_paths: Vec = vec![]; + macro_rules! try_or_cleanup { + ($v:expr) => {{ + match $v { + Ok(v) => v, + Err(e) => { + // Do our best to clear up. 
We may end up deleting a file that we just wrote over + // the top of, but it's better to clear up too much than too little + for local_path in output_paths.iter() { + if let Err(e) = fs::remove_file(local_path) { + if e.kind() != io::ErrorKind::NotFound { + warn!("{} while attempting to clear up {}", e, local_path.display()) + } + } + } + return Err(e) + }, + } + }}; + } + for (path, output_data) in jc.outputs { let len = output_data.lens().actual; - let local_path = path_transformer.to_local(&path); - let mut file = File::create(&local_path).unwrap(); - let count = io::copy(&mut output_data.into_reader(), &mut file).unwrap(); + let local_path = try_or_cleanup!(path_transformer.to_local(&path) + .chain_err(|| format!("unable to transform output path {}", path))); + output_paths.push(local_path); + // Do this first so cleanup works correctly + let local_path = output_paths.last().expect("nothing in vec after push"); + + let mut file = try_or_cleanup!(File::create(&local_path) + .chain_err(|| format!("Failed to create output file {}", local_path.display()))); + let count = try_or_cleanup!(io::copy(&mut output_data.into_reader(), &mut file) + .chain_err(|| format!("Failed to write output to {}", local_path.display()))); + assert!(count == len); - output_paths.push((path, local_path)) } - outputs_rewriter.handle_outputs(&path_transformer, output_paths).unwrap(); - jc.output.into() + try_or_cleanup!(outputs_rewriter.handle_outputs(&path_transformer, &output_paths) + .chain_err(|| "failed to rewrite outputs from compile")); + Ok((DistType::Ok, jc.output.into())) }) - ) }) // Something failed, do a local compilation .or_else(move |e| { - let cause = e.cause().map(|c| format!(": {}", c)).unwrap_or_else(String::new); - info!("[{}]: Could not perform distributed compile, falling back to local: {}{}", compile_out_pretty3, e, cause); - compile_cmd.execute(&creator) + let mut errmsg = e.to_string(); + for cause in e.iter() { + errmsg.push_str(": "); + errmsg.push_str(&cause.to_string()); + } + warn!("[{}]: Could not perform distributed compile, falling back to local: {}", compile_out_pretty4, errmsg); + compile_cmd.execute(&creator).map(|o| (DistType::Error, o)) }) - .map(move |o| (cacheable, o)) + .map(move |(dt, o)| (cacheable, dt, o)) ) } @@ -457,16 +517,15 @@ pub trait Compilation { #[cfg(feature = "dist-client")] pub trait OutputsRewriter { - fn handle_outputs(self: Box, path_transformer: &dist::PathTransformer, output_paths: Vec<(String, PathBuf)>) - -> Result<()>; + /// Perform any post-compilation handling of outputs, given a Vec of the dist_path and local_path + fn handle_outputs(self: Box, path_transformer: &dist::PathTransformer, output_paths: &[PathBuf]) -> Result<()>; } #[cfg(feature = "dist-client")] pub struct NoopOutputsRewriter; #[cfg(feature = "dist-client")] impl OutputsRewriter for NoopOutputsRewriter { - fn handle_outputs(self: Box, _path_transformer: &dist::PathTransformer, _output_paths: Vec<(String, PathBuf)>) - -> Result<()> { + fn handle_outputs(self: Box, _path_transformer: &dist::PathTransformer, _output_paths: &[PathBuf]) -> Result<()> { Ok(()) } } @@ -513,6 +572,17 @@ macro_rules! try_or_cannot_cache { }}; } +/// Specifics about distributed compilation. +#[derive(Debug, PartialEq)] +pub enum DistType { + /// Distribution was not enabled. + NoDist, + /// Distributed compile success. + Ok, + /// Distributed compile failed. + Error, +} + /// Specifics about cache misses. 
#[derive(Debug, PartialEq)] pub enum MissType { @@ -542,7 +612,7 @@ pub enum CompileResult { /// /// The `CacheWriteFuture` will resolve when the result is finished /// being stored in the cache. - CacheMiss(MissType, Duration, SFuture), + CacheMiss(MissType, DistType, Duration, SFuture), /// Not in cache, but the compilation result was determined to be not cacheable. NotCacheable, /// Not in cache, but compilation failed. @@ -568,7 +638,7 @@ impl fmt::Debug for CompileResult { match self { &CompileResult::Error => write!(f, "CompileResult::Error"), &CompileResult::CacheHit(ref d) => write!(f, "CompileResult::CacheHit({:?})", d), - &CompileResult::CacheMiss(ref m, ref d, _) => write!(f, "CompileResult::CacheMiss({:?}, {:?}, _)", d, m), + &CompileResult::CacheMiss(ref m, ref dt, ref d, _) => write!(f, "CompileResult::CacheMiss({:?}, {:?}, {:?}, _)", d, m, dt), &CompileResult::NotCacheable => write!(f, "CompileResult::NotCacheable"), &CompileResult::CompileFailed => write!(f, "CompileResult::CompileFailed"), } @@ -581,7 +651,7 @@ impl PartialEq for CompileResult { match (self, other) { (&CompileResult::Error, &CompileResult::Error) => true, (&CompileResult::CacheHit(_), &CompileResult::CacheHit(_)) => true, - (&CompileResult::CacheMiss(ref m, _, _), &CompileResult::CacheMiss(ref n, _, _)) => m == n, + (&CompileResult::CacheMiss(ref m, ref dt, _, _), &CompileResult::CacheMiss(ref n, ref dt2, _, _)) => m == n && dt == dt2, (&CompileResult::NotCacheable, &CompileResult::NotCacheable) => true, (&CompileResult::CompileFailed, &CompileResult::CompileFailed) => true, _ => false, @@ -814,7 +884,6 @@ mod test { use super::*; use cache::Storage; use cache::disk::DiskCache; - use dist; use futures::Future; use futures_cpupool::CpuPool; use mock_command::*; @@ -926,7 +995,7 @@ LLVM version: 6.0", ""))); } #[test] - fn test_compiler_get_cached_or_compile_uncached() { + fn test_compiler_get_cached_or_compile() { use env_logger; drop(env_logger::try_init()); let creator = new_creator(); @@ -934,7 +1003,7 @@ LLVM version: 6.0", ""))); let pool = CpuPool::new(1); let core = Core::new().unwrap(); let handle = core.handle(); - let dist_client = Arc::new(dist::NoopClient); + let dist_client = None; let storage = DiskCache::new(&f.tempdir.path().join("cache"), u64::MAX, &pool); @@ -977,7 +1046,7 @@ LLVM version: 6.0", ""))); // Ensure that the object file was created. assert_eq!(true, fs::metadata(&obj).and_then(|m| Ok(m.len() > 0)).unwrap()); match cached { - CompileResult::CacheMiss(MissType::Normal, _, f) => { + CompileResult::CacheMiss(MissType::Normal, DistType::NoDist, _, f) => { // wait on cache write future so we don't race with it! f.wait().unwrap(); } @@ -1009,7 +1078,8 @@ LLVM version: 6.0", ""))); } #[test] - fn test_compiler_get_cached_or_compile_cached() { + #[cfg(feature = "dist-client")] + fn test_compiler_get_cached_or_compile_dist() { use env_logger; drop(env_logger::try_init()); let creator = new_creator(); @@ -1017,7 +1087,6 @@ LLVM version: 6.0", ""))); let pool = CpuPool::new(1); let core = Core::new().unwrap(); let handle = core.handle(); - let dist_client = Arc::new(dist::NoopClient); let storage = DiskCache::new(&f.tempdir.path().join("cache"), u64::MAX, &pool); @@ -1034,13 +1103,8 @@ LLVM version: 6.0", ""))); const COMPILER_STDOUT : &'static [u8] = b"compiler stdout"; const COMPILER_STDERR : &'static [u8] = b"compiler stderr"; let obj = f.tempdir.path().join("foo.o"); - let o = obj.clone(); - next_command_calls(&creator, move |_| { - // Pretend to compile something. 
- let mut f = File::create(&o)?; - f.write_all(b"file contents")?; - Ok(MockChild::new(exit_status(0), COMPILER_STDOUT, COMPILER_STDERR)) - }); + // Dist client will do the compilation + let dist_client = Some(test_dist::OneshotClient::new(0, COMPILER_STDOUT.to_owned(), COMPILER_STDERR.to_owned())); let cwd = f.tempdir.path(); let arguments = ovec!["-c", "foo.c", "-o", "foo.o"]; let hasher = match c.parse_arguments(&arguments, ".".as_ref()) { @@ -1060,13 +1124,12 @@ LLVM version: 6.0", ""))); // Ensure that the object file was created. assert_eq!(true, fs::metadata(&obj).and_then(|m| Ok(m.len() > 0)).unwrap()); match cached { - CompileResult::CacheMiss(MissType::Normal, _, f) => { + CompileResult::CacheMiss(MissType::Normal, DistType::Ok, _, f) => { // wait on cache write future so we don't race with it! f.wait().unwrap(); } _ => assert!(false, "Unexpected compile result: {:?}", cached), } - assert_eq!(exit_status(0), res.status); assert_eq!(COMPILER_STDOUT, res.stdout.as_slice()); assert_eq!(COMPILER_STDERR, res.stderr.as_slice()); @@ -1103,7 +1166,7 @@ LLVM version: 6.0", ""))); let pool = CpuPool::new(1); let core = Core::new().unwrap(); let handle = core.handle(); - let dist_client = Arc::new(dist::NoopClient); + let dist_client = None; let storage = MockStorage::new(); let storage: Arc = Arc::new(storage); // Pretend to be GCC. @@ -1145,7 +1208,7 @@ LLVM version: 6.0", ""))); // Ensure that the object file was created. assert_eq!(true, fs::metadata(&obj).and_then(|m| Ok(m.len() > 0)).unwrap()); match cached { - CompileResult::CacheMiss(MissType::CacheReadError, _, f) => { + CompileResult::CacheMiss(MissType::CacheReadError, DistType::NoDist, _, f) => { // wait on cache write future so we don't race with it! f.wait().unwrap(); } @@ -1166,7 +1229,7 @@ LLVM version: 6.0", ""))); let pool = CpuPool::new(1); let core = Core::new().unwrap(); let handle = core.handle(); - let dist_client = Arc::new(dist::NoopClient); + let dist_client = None; let storage = DiskCache::new(&f.tempdir.path().join("cache"), u64::MAX, &pool); @@ -1213,7 +1276,7 @@ LLVM version: 6.0", ""))); // Ensure that the object file was created. assert_eq!(true, fs::metadata(&obj).and_then(|m| Ok(m.len() > 0)).unwrap()); match cached { - CompileResult::CacheMiss(MissType::Normal, _, f) => { + CompileResult::CacheMiss(MissType::Normal, DistType::NoDist, _, f) => { // wait on cache write future so we don't race with it! f.wait().unwrap(); } @@ -1236,7 +1299,7 @@ LLVM version: 6.0", ""))); // Ensure that the object file was created. assert_eq!(true, fs::metadata(&obj).and_then(|m| Ok(m.len() > 0)).unwrap()); match cached { - CompileResult::CacheMiss(MissType::ForcedRecache, _, f) => { + CompileResult::CacheMiss(MissType::ForcedRecache, DistType::NoDist, _, f) => { // wait on cache write future so we don't race with it! 
f.wait().unwrap(); } @@ -1256,7 +1319,7 @@ LLVM version: 6.0", ""))); let pool = CpuPool::new(1); let core = Core::new().unwrap(); let handle = core.handle(); - let dist_client = Arc::new(dist::NoopClient); + let dist_client = None; let storage = DiskCache::new(&f.tempdir.path().join("cache"), u64::MAX, &pool); @@ -1291,4 +1354,282 @@ LLVM version: 6.0", ""))); assert_eq!(b"", res.stdout.as_slice()); assert_eq!(PREPROCESSOR_STDERR, res.stderr.as_slice()); } + + #[test] + #[cfg(feature = "dist-client")] + fn test_compiler_get_cached_or_compile_dist_error() { + use env_logger; + drop(env_logger::try_init()); + let creator = new_creator(); + let f = TestFixture::new(); + let pool = CpuPool::new(1); + let core = Core::new().unwrap(); + let handle = core.handle(); + let dist_clients = vec![ + test_dist::ErrorPutToolchainClient::new(), + test_dist::ErrorAllocJobClient::new(), + test_dist::ErrorSubmitToolchainClient::new(), + test_dist::ErrorRunJobClient::new(), + ]; + let storage = DiskCache::new(&f.tempdir.path().join("cache"), + u64::MAX, + &pool); + let storage: Arc = Arc::new(storage); + // Pretend to be GCC. + next_command(&creator, Ok(MockChild::new(exit_status(0), "gcc", ""))); + let c = get_compiler_info(&creator, + &f.bins[0], + &[], + &pool).wait().unwrap(); + const COMPILER_STDOUT: &'static [u8] = b"compiler stdout"; + const COMPILER_STDERR: &'static [u8] = b"compiler stderr"; + // The compiler should be invoked twice, since we're forcing + // recaching. + let obj = f.tempdir.path().join("foo.o"); + for _ in dist_clients.iter() { + // The preprocessor invocation. + next_command(&creator, Ok(MockChild::new(exit_status(0), "preprocessor output", ""))); + // The compiler invocation. + let o = obj.clone(); + next_command_calls(&creator, move |_| { + // Pretend to compile something. + let mut f = File::create(&o)?; + f.write_all(b"file contents")?; + Ok(MockChild::new(exit_status(0), COMPILER_STDOUT, COMPILER_STDERR)) + }); + } + let cwd = f.tempdir.path(); + let arguments = ovec!["-c", "foo.c", "-o", "foo.o"]; + let hasher = match c.parse_arguments(&arguments, ".".as_ref()) { + CompilerArguments::Ok(h) => h, + o @ _ => panic!("Bad result from parse_arguments: {:?}", o), + }; + // All these dist clients will fail, but should still result in successful compiles + for dist_client in dist_clients { + if obj.is_file() { + fs::remove_file(&obj).unwrap(); + } + let hasher = hasher.clone(); + let (cached, res) = hasher.get_cached_or_compile(Some(dist_client.clone()), + creator.clone(), + storage.clone(), + arguments.clone(), + cwd.to_path_buf(), + vec![], + CacheControl::ForceRecache, + pool.clone(), + handle.clone()).wait().unwrap(); + // Ensure that the object file was created. + assert_eq!(true, fs::metadata(&obj).and_then(|m| Ok(m.len() > 0)).unwrap()); + match cached { + CompileResult::CacheMiss(MissType::ForcedRecache, DistType::Error, _, f) => { + // wait on cache write future so we don't race with it! 
+ f.wait().unwrap(); + } + _ => assert!(false, "Unexpected compile result: {:?}", cached), + } + assert_eq!(exit_status(0), res.status); + assert_eq!(COMPILER_STDOUT, res.stdout.as_slice()); + assert_eq!(COMPILER_STDERR, res.stderr.as_slice()); + } + } +} + +#[cfg(test)] +#[cfg(feature = "dist-client")] +mod test_dist { + use dist::pkg; + use dist::{ + self, + + CompileCommand, + PathTransformer, + + JobId, ServerId, + JobAlloc, Toolchain, OutputData, ProcessOutput, + + AllocJobResult, RunJobResult, SubmitToolchainResult, JobComplete, + }; + use std::cell::Cell; + use std::path::Path; + use std::sync::Arc; + + use errors::*; + + pub struct ErrorPutToolchainClient; + impl ErrorPutToolchainClient { + pub fn new() -> Arc { + Arc::new(ErrorPutToolchainClient) + } + } + impl dist::Client for ErrorPutToolchainClient { + fn do_alloc_job(&self, _: Toolchain) -> SFuture { + unreachable!() + } + fn do_submit_toolchain(&self, _: JobAlloc, _: Toolchain) -> SFuture { + unreachable!() + } + fn do_run_job(&self, _: JobAlloc, _: CompileCommand, _: Vec, _: Box) -> SFuture<(RunJobResult, PathTransformer)> { + unreachable!() + } + fn put_toolchain(&self, _: &Path, _: &str, _: Box) -> SFuture<(Toolchain, Option)> { + f_err("put toolchain failure") + } + } + + pub struct ErrorAllocJobClient { + tc: Toolchain, + } + impl ErrorAllocJobClient { + pub fn new() -> Arc { + Arc::new(Self { + tc: Toolchain { archive_id: "somearchiveid".to_owned() }, + }) + } + } + impl dist::Client for ErrorAllocJobClient { + fn do_alloc_job(&self, tc: Toolchain) -> SFuture { + assert_eq!(self.tc, tc); + f_err("alloc job failure") + } + fn do_submit_toolchain(&self, _: JobAlloc, _: Toolchain) -> SFuture { + unreachable!() + } + fn do_run_job(&self, _: JobAlloc, _: CompileCommand, _: Vec, _: Box) -> SFuture<(RunJobResult, PathTransformer)> { + unreachable!() + } + fn put_toolchain(&self, _: &Path, _: &str, _: Box) -> SFuture<(Toolchain, Option)> { + f_ok((self.tc.clone(), None)) + } + } + + pub struct ErrorSubmitToolchainClient { + has_started: Cell, + tc: Toolchain, + } + impl ErrorSubmitToolchainClient { + pub fn new() -> Arc { + Arc::new(Self { + has_started: Cell::new(false), + tc: Toolchain { archive_id: "somearchiveid".to_owned() }, + }) + } + } + impl dist::Client for ErrorSubmitToolchainClient { + fn do_alloc_job(&self, tc: Toolchain) -> SFuture { + assert!(!self.has_started.replace(true)); + assert_eq!(self.tc, tc); + f_ok(AllocJobResult::Success { + job_alloc: JobAlloc { auth: "abcd".to_owned(), job_id: JobId(0), server_id: ServerId::new(([0, 0, 0, 0], 1).into()) }, + need_toolchain: true, + }) + } + fn do_submit_toolchain(&self, job_alloc: JobAlloc, tc: Toolchain) -> SFuture { + assert_eq!(job_alloc.job_id, JobId(0)); + assert_eq!(self.tc, tc); + f_err("submit toolchain failure") + } + fn do_run_job(&self, _: JobAlloc, _: CompileCommand, _: Vec, _: Box) -> SFuture<(RunJobResult, PathTransformer)> { + unreachable!() + } + fn put_toolchain(&self, _: &Path, _: &str, _: Box) -> SFuture<(Toolchain, Option)> { + f_ok((self.tc.clone(), None)) + } + } + + pub struct ErrorRunJobClient { + has_started: Cell, + tc: Toolchain, + } + impl ErrorRunJobClient { + pub fn new() -> Arc { + Arc::new(Self { + has_started: Cell::new(false), + tc: Toolchain { archive_id: "somearchiveid".to_owned() }, + }) + } + } + impl dist::Client for ErrorRunJobClient { + fn do_alloc_job(&self, tc: Toolchain) -> SFuture { + assert!(!self.has_started.replace(true)); + assert_eq!(self.tc, tc); + f_ok(AllocJobResult::Success { + job_alloc: JobAlloc { auth: 
"abcd".to_owned(), job_id: JobId(0), server_id: ServerId::new(([0, 0, 0, 0], 1).into()) }, + need_toolchain: true, + }) + } + fn do_submit_toolchain(&self, job_alloc: JobAlloc, tc: Toolchain) -> SFuture { + assert_eq!(job_alloc.job_id, JobId(0)); + assert_eq!(self.tc, tc); + f_ok(SubmitToolchainResult::Success) + } + fn do_run_job(&self, job_alloc: JobAlloc, command: CompileCommand, _: Vec, _: Box) -> SFuture<(RunJobResult, PathTransformer)> { + assert_eq!(job_alloc.job_id, JobId(0)); + assert_eq!(command.executable, "/overridden/compiler"); + f_err("run job failure") + } + fn put_toolchain(&self, _: &Path, _: &str, _: Box) -> SFuture<(Toolchain, Option)> { + f_ok((self.tc.clone(), Some("/overridden/compiler".to_owned()))) + } + } + + pub struct OneshotClient { + has_started: Cell, + tc: Toolchain, + output: ProcessOutput, + } + + impl OneshotClient { + pub fn new(code: i32, stdout: Vec, stderr: Vec) -> Arc { + Arc::new(Self { + has_started: Cell::new(false), + tc: Toolchain { archive_id: "somearchiveid".to_owned() }, + output: ProcessOutput::fake_output(code, stdout, stderr), + }) + } + } + + impl dist::Client for OneshotClient { + fn do_alloc_job(&self, tc: Toolchain) -> SFuture { + assert!(!self.has_started.replace(true)); + assert_eq!(self.tc, tc); + + f_ok(AllocJobResult::Success { + job_alloc: JobAlloc { + auth: "abcd".to_owned(), + job_id: JobId(0), + server_id: ServerId::new(([0, 0, 0, 0], 1).into()), + }, + need_toolchain: true, + }) + } + fn do_submit_toolchain(&self, job_alloc: JobAlloc, tc: Toolchain) -> SFuture { + assert_eq!(job_alloc.job_id, JobId(0)); + assert_eq!(self.tc, tc); + + f_ok(SubmitToolchainResult::Success) + } + fn do_run_job(&self, job_alloc: JobAlloc, command: CompileCommand, outputs: Vec, inputs_packager: Box) -> SFuture<(RunJobResult, PathTransformer)> { + assert_eq!(job_alloc.job_id, JobId(0)); + assert_eq!(command.executable, "/overridden/compiler"); + + let mut inputs = vec![]; + let path_transformer = inputs_packager.write_inputs(&mut inputs).unwrap(); + let outputs = outputs.into_iter() + .map(|name| { + let data = format!("some data in {}", name); + let data = OutputData::try_from_reader(data.as_bytes()).unwrap(); + (name, data) + }) + .collect(); + let result = RunJobResult::Complete(JobComplete { + output: self.output.clone(), + outputs, + }); + f_ok((result, path_transformer)) + } + fn put_toolchain(&self, _: &Path, _: &str, _: Box) -> SFuture<(Toolchain, Option)> { + f_ok((self.tc.clone(), Some("/overridden/compiler".to_owned()))) + } + } } diff --git a/src/compiler/gcc.rs b/src/compiler/gcc.rs index 86083cb8..b13823f7 100644 --- a/src/compiler/gcc.rs +++ b/src/compiler/gcc.rs @@ -524,7 +524,7 @@ pub fn generate_compile_commands(path_transformer: &mut dist::PathTransformer, executable: path_transformer.to_dist(&executable)?, arguments: arguments, env_vars: dist::osstring_tuples_to_strings(env_vars)?, - cwd: path_transformer.to_dist_assert_abs(cwd)?, + cwd: path_transformer.to_dist_abs(cwd)?, }) })(); @@ -1044,11 +1044,12 @@ mod test { // Compiler invocation. 
next_command(&creator, Ok(MockChild::new(exit_status(0), "", ""))); let mut path_transformer = dist::PathTransformer::new(); - let (command, _, cacheable) = generate_compile_commands(&mut path_transformer, - &compiler, - &parsed_args, - f.tempdir.path(), - &[]).unwrap(); + let (command, dist_command, cacheable) = generate_compile_commands(&mut path_transformer, + &compiler, + &parsed_args, + f.tempdir.path(), + &[]).unwrap(); + assert!(dist_command.is_some()); let _ = command.execute(&creator).wait(); assert_eq!(Cacheable::Yes, cacheable); // Ensure that we ran all processes. diff --git a/src/compiler/msvc.rs b/src/compiler/msvc.rs index 7fce6154..95047135 100644 --- a/src/compiler/msvc.rs +++ b/src/compiler/msvc.rs @@ -885,11 +885,12 @@ mod test { // Compiler invocation. next_command(&creator, Ok(MockChild::new(exit_status(0), "", ""))); let mut path_transformer = dist::PathTransformer::new(); - let (command, _, cacheable) = generate_compile_commands(&mut path_transformer, - &compiler, - &parsed_args, - f.tempdir.path(), - &[]).unwrap(); + let (command, dist_command, cacheable) = generate_compile_commands(&mut path_transformer, + &compiler, + &parsed_args, + f.tempdir.path(), + &[]).unwrap(); + assert!(dist_command.is_some()); let _ = command.execute(&creator).wait(); assert_eq!(Cacheable::Yes, cacheable); // Ensure that we ran all processes. @@ -917,11 +918,12 @@ mod test { // Compiler invocation. next_command(&creator, Ok(MockChild::new(exit_status(0), "", ""))); let mut path_transformer = dist::PathTransformer::new(); - let (command, _, cacheable) = generate_compile_commands(&mut path_transformer, - &compiler, - &parsed_args, - f.tempdir.path(), - &[]).unwrap(); + let (command, dist_command, cacheable) = generate_compile_commands(&mut path_transformer, + &compiler, + &parsed_args, + f.tempdir.path(), + &[]).unwrap(); + assert!(dist_command.is_some()); let _ = command.execute(&creator).wait(); assert_eq!(Cacheable::No, cacheable); // Ensure that we ran all processes. 
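The gcc and msvc test changes above assert that generate_compile_commands now also yields a distributable command, and the compiler.rs changes earlier in this patch thread a DistType through dist_or_local_compile and CompileResult::CacheMiss. As a rough, hedged sketch of that decision flow (not part of the patch — the types below are simplified hypothetical stand-ins for sccache's real CompileCommand/dist structs):

// A self-contained sketch of the fallback behaviour: prefer a distributed
// compile command when one was generated, otherwise (or on any dist failure)
// run the local command, recording which path was taken as a DistType.
#[derive(Debug, PartialEq)]
enum DistType {
    NoDist, // distribution not enabled or no dist command produced
    Ok,     // distributed compile succeeded
    Error,  // distributed compile failed, fell back to local
}

struct LocalCommand; // hypothetical stand-in for the local compile command
struct DistCommand;  // hypothetical stand-in for the dist compile command

fn run_local(_cmd: &LocalCommand) -> Result<i32, String> {
    // Pretend the local compiler exited successfully.
    Ok(0)
}

fn run_dist(_cmd: &DistCommand) -> Result<i32, String> {
    // Pretend the scheduler had no capacity for this job.
    Err("failed to allocate job".to_owned())
}

fn compile(local: LocalCommand, dist: Option<DistCommand>) -> (DistType, Result<i32, String>) {
    match dist {
        None => (DistType::NoDist, run_local(&local)),
        Some(d) => match run_dist(&d) {
            Ok(status) => (DistType::Ok, Ok(status)),
            // Something failed, do a local compilation instead.
            Err(_) => (DistType::Error, run_local(&local)),
        },
    }
}

fn main() {
    let (dist_type, result) = compile(LocalCommand, None);
    assert_eq!(dist_type, DistType::NoDist);
    assert!(result.is_ok());

    let (dist_type, result) = compile(LocalCommand, Some(DistCommand));
    assert_eq!(dist_type, DistType::Error);
    assert!(result.is_ok());
}

Carrying the DistType alongside the cache-miss result (as CompileResult::CacheMiss now does in this patch) is what later lets statistics distinguish genuine distributed builds from silent local fallbacks.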
diff --git a/src/compiler/rust.rs b/src/compiler/rust.rs index a6fdfafa..8fc15285 100644 --- a/src/compiler/rust.rs +++ b/src/compiler/rust.rs @@ -1200,7 +1200,7 @@ impl Compilation for RustCompilation { executable: path_transformer.to_dist(&sysroot_executable)?, arguments: dist_arguments, env_vars, - cwd: path_transformer.to_dist_assert_abs(cwd)?, + cwd: path_transformer.to_dist_abs(cwd)?, }) })(); @@ -1277,7 +1277,8 @@ impl pkg::InputsPackager for RustInputsPackager { } } } - let dist_input_path = path_transformer.to_dist(&input_path).unwrap(); + let dist_input_path = path_transformer.to_dist(&input_path) + .chain_err(|| format!("unable to transform input path {}", input_path.display()))?; tar_inputs.push((input_path, dist_input_path)) } @@ -1296,12 +1297,12 @@ impl pkg::InputsPackager for RustInputsPackager { let dir_entries = match fs::read_dir(crate_link_path) { Ok(iter) => iter, Err(ref e) if e.kind() == io::ErrorKind::NotFound => continue, - Err(e) => return Err(Error::with_chain(e, "Failed to read dir entries in crate link path")), + Err(e) => return Err(Error::from(e).chain_err(|| "Failed to read dir entries in crate link path")), }; for entry in dir_entries { let entry = match entry { Ok(entry) => entry, - Err(e) => return Err(Error::with_chain(e, "Error during iteration over crate link path")), + Err(e) => return Err(Error::from(e).chain_err(|| "Error during iteration over crate link path")), }; let path = entry.path(); @@ -1342,7 +1343,8 @@ impl pkg::InputsPackager for RustInputsPackager { } // This is a lib that may be of interest during compilation - let dist_path = path_transformer.to_dist(&path).unwrap(); + let dist_path = path_transformer.to_dist(&path) + .chain_err(|| format!("unable to transform lib path {}", path.display()))?; tar_crate_libs.push((path, dist_path)) } } @@ -1455,34 +1457,34 @@ struct RustOutputsRewriter { #[cfg(feature = "dist-client")] impl OutputsRewriter for RustOutputsRewriter { - fn handle_outputs(self: Box, path_transformer: &dist::PathTransformer, output_paths: Vec<(String, PathBuf)>) + fn handle_outputs(self: Box, path_transformer: &dist::PathTransformer, output_paths: &[PathBuf]) -> Result<()> { - use std::io::{Seek, Write}; + use std::io::Write; - // Outputs in dep files (the files at the beginning of lines) are untransformed - remap-path-prefix is documented - // to only apply to 'inputs'. + // Outputs in dep files (the files at the beginning of lines) are untransformed at this point - + // remap-path-prefix is documented to only apply to 'inputs'. 
trace!("Pondering on rewriting dep file {:?}", self.dep_info); if let Some(dep_info) = self.dep_info { - for (_dep_info_dist_path, dep_info_local_path) in output_paths { + for dep_info_local_path in output_paths { trace!("Comparing with {}", dep_info_local_path.display()); - if dep_info == dep_info_local_path { - error!("Replacing using the transformer {:?}", path_transformer); - // Found the dep info file - let mut f = fs::OpenOptions::new() - .read(true) - .write(true) - .open(dep_info)?; + if dep_info == *dep_info_local_path { + info!("Replacing using the transformer {:?}", path_transformer); + // Found the dep info file, read it in + let mut f = fs::File::open(&dep_info).chain_err(|| "Failed to open dep info file")?; let mut deps = String::new(); - f.read_to_string(&mut deps)?; + {f}.read_to_string(&mut deps)?; + // Replace all the output paths, at the beginning of lines for (local_path, dist_path) in get_path_mappings(path_transformer) { let re_str = format!("(?m)^{}", regex::escape(&dist_path)); - error!("RE replacing {} with {} in {}", re_str, local_path.to_str().unwrap(), deps); + let local_path_str = local_path.to_str() + .chain_err(|| format!("could not convert {} to string for RE replacement", local_path.display()))?; + error!("RE replacing {} with {} in {}", re_str, local_path_str, deps); let re = regex::Regex::new(&re_str).expect("Invalid regex"); - deps = re.replace_all(&deps, local_path.to_str().unwrap()).into_owned(); + deps = re.replace_all(&deps, local_path_str).into_owned(); } - f.seek(io::SeekFrom::Start(0))?; - f.write_all(deps.as_bytes())?; - f.set_len(deps.len() as u64)?; + // Write the depinfo file + let mut f = fs::File::create(&dep_info).chain_err(|| "Failed to recreate dep info file")?; + {f}.write_all(deps.as_bytes())?; return Ok(()) } } @@ -1493,6 +1495,53 @@ impl OutputsRewriter for RustOutputsRewriter { } } +#[test] +#[cfg(all(feature = "dist-client", target_os = "windows"))] +fn test_rust_outputs_rewriter() { + use compiler::compiler::OutputsRewriter; + use std::io::Write; + use test::utils::create_file; + + let mut pt = dist::PathTransformer::new(); + pt.to_dist(Path::new("c:\\")).unwrap(); + let mappings: Vec<_> = pt.disk_mappings().collect(); + assert!(mappings.len() == 1); + let linux_prefix = &mappings[0].1; + + let depinfo_data = format!("{prefix}/sccache/target/x86_64-unknown-linux-gnu/debug/deps/sccache_dist-c6f3229b9ef0a5c3.rmeta: src/bin/sccache-dist/main.rs src/bin/sccache-dist/build.rs src/bin/sccache-dist/token_check.rs + +{prefix}/sccache/target/x86_64-unknown-linux-gnu/debug/deps/sccache_dist-c6f3229b9ef0a5c3.d: src/bin/sccache-dist/main.rs src/bin/sccache-dist/build.rs src/bin/sccache-dist/token_check.rs + +src/bin/sccache-dist/main.rs: +src/bin/sccache-dist/build.rs: +src/bin/sccache-dist/token_check.rs: +", prefix=linux_prefix); + + let depinfo_resulting_data = format!("{prefix}/sccache/target/x86_64-unknown-linux-gnu/debug/deps/sccache_dist-c6f3229b9ef0a5c3.rmeta: src/bin/sccache-dist/main.rs src/bin/sccache-dist/build.rs src/bin/sccache-dist/token_check.rs + +{prefix}/sccache/target/x86_64-unknown-linux-gnu/debug/deps/sccache_dist-c6f3229b9ef0a5c3.d: src/bin/sccache-dist/main.rs src/bin/sccache-dist/build.rs src/bin/sccache-dist/token_check.rs + +src/bin/sccache-dist/main.rs: +src/bin/sccache-dist/build.rs: +src/bin/sccache-dist/token_check.rs: +", prefix="c:"); + + let tempdir = TempDir::new("sccache_test").unwrap(); + let tempdir = tempdir.path(); + let depinfo_file = create_file(tempdir, "depinfo.d", |mut f| { + 
f.write_all(depinfo_data.as_bytes()) + }).unwrap(); + + let ror = Box::new(RustOutputsRewriter { + dep_info: Some(depinfo_file.clone()), + }); + let () = ror.handle_outputs(&pt, &[depinfo_file.clone()]).unwrap(); + + let mut s = String::new(); + fs::File::open(depinfo_file).unwrap().read_to_string(&mut s).unwrap(); + assert_eq!(s, depinfo_resulting_data) +} + #[cfg(feature = "dist-client")] #[derive(Debug)] @@ -1663,6 +1712,7 @@ fn parse_rustc_z_ls(stdout: &str) -> Result> { let mut libstring_splits = libstring.rsplitn(2, '-'); // Rustc prints strict hash value (rather than extra filename as it likely should be) + // https://github.com/rust-lang/rust/pull/55555 let _svh = libstring_splits.next().ok_or_else(|| "No hash in lib string from rustc -Z ls")?; let libname = libstring_splits.next().expect("Zero strings from libstring split"); assert!(libstring_splits.next().is_none()); diff --git a/src/config.rs b/src/config.rs index 25d82230..d21b5bff 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,13 +14,16 @@ use directories::ProjectDirs; use regex::Regex; +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +use reqwest; +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +use serde::ser::{Serialize, Serializer}; use serde::de::{Deserialize, DeserializeOwned, Deserializer}; use serde_json; use std::collections::HashMap; use std::env; use std::io::{Read, Write}; use std::fs::{self, File}; -use std::net::IpAddr; use std::path::{Path, PathBuf}; use std::result::Result as StdResult; use std::str::FromStr; @@ -66,7 +69,7 @@ fn default_disk_cache_size() -> u64 { TEN_GIGS } fn default_toolchain_cache_size() -> u64 { TEN_GIGS } pub fn parse_size(val: &str) -> Option { - let re = Regex::new(r"^(\d+)([KMGT])$").unwrap(); + let re = Regex::new(r"^(\d+)([KMGT])$").expect("Fixed regex parse failure"); re.captures(val) .and_then(|caps| { caps.get(1) @@ -84,6 +87,48 @@ pub fn parse_size(val: &str) -> Option { }) } +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct HTTPUrl(reqwest::Url); +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +impl Serialize for HTTPUrl { + fn serialize(&self, serializer: S) -> StdResult where S: Serializer { + serializer.serialize_str(self.0.as_str()) + } +} +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +impl<'a> Deserialize<'a> for HTTPUrl { + fn deserialize(deserializer: D) -> StdResult where D: Deserializer<'a> { + use serde::de::Error; + let helper: String = Deserialize::deserialize(deserializer)?; + let url = parse_http_url(&helper).map_err(D::Error::custom)?; + Ok(HTTPUrl(url)) + } +} +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +fn parse_http_url(url: &str) -> Result { + use std::net::SocketAddr; + let url = if let Ok(sa) = url.parse::() { + warn!("Url {} has no scheme, assuming http", url); + reqwest::Url::parse(&format!("http://{}", sa)) + } else { + reqwest::Url::parse(url) + }.map_err(|e| format!("{}", e))?; + if url.scheme() != "http" && url.scheme() != "https" { + bail!("url not http or https") + } + // TODO: relative url handling just hasn't been implemented and tested + if url.path() != "/" { + bail!("url has a relative path (currently unsupported)") + } + Ok(url) +} +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +impl HTTPUrl { + pub fn from_url(u: reqwest::Url) -> Self { HTTPUrl(u) } + pub fn to_url(&self) -> reqwest::Url { self.0.clone() } +} + #[derive(Debug, PartialEq, Eq)] #[derive(Serialize, Deserialize)] 
#[serde(deny_unknown_fields)] @@ -162,12 +207,12 @@ pub enum CacheType { #[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct CacheConfigs { - azure: Option, - disk: Option, - gcs: Option, - memcached: Option, - redis: Option, - s3: Option, + pub azure: Option, + pub disk: Option, + pub gcs: Option, + pub memcached: Option, + pub redis: Option, + pub s3: Option, } impl CacheConfigs { @@ -204,7 +249,7 @@ impl CacheConfigs { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] #[serde(tag = "type")] @@ -221,12 +266,15 @@ pub enum DistToolchainConfig { }, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Serialize)] #[serde(tag = "type")] pub enum DistAuth { + #[serde(rename = "token")] Token { token: String }, + #[serde(rename = "oauth2_code_grant_pkce")] Oauth2CodeGrantPKCE { client_id: String, auth_url: String, token_url: String }, + #[serde(rename = "oauth2_implicit")] Oauth2Implicit { client_id: String, auth_url: String }, } @@ -279,7 +327,10 @@ impl Default for DistAuth { #[serde(deny_unknown_fields)] pub struct DistConfig { pub auth: DistAuth, - pub scheduler_addr: Option, + #[cfg(any(feature = "dist-client", feature = "dist-server"))] + pub scheduler_url: Option, + #[cfg(not(any(feature = "dist-client", feature = "dist-server")))] + pub scheduler_url: Option, pub cache_dir: PathBuf, pub toolchains: Vec, pub toolchain_cache_size: u64, @@ -289,7 +340,7 @@ impl Default for DistConfig { fn default() -> Self { Self { auth: Default::default(), - scheduler_addr: Default::default(), + scheduler_url: Default::default(), cache_dir: default_dist_cache_dir(), toolchains: Default::default(), toolchain_cache_size: default_toolchain_cache_size(), @@ -430,7 +481,9 @@ impl Config { .expect("Unable to get config directory"); dirs.config_dir().join("config") }); - let file_conf = try_read_config_file(&file_conf_path)?.unwrap_or_default(); + let file_conf = try_read_config_file(&file_conf_path) + .chain_err(|| "Failed to load config file")? + .unwrap_or_default(); Ok(Config::from_env_and_file_configs(env_conf, file_conf)) } @@ -515,7 +568,8 @@ impl CachedConfig { Self::save_file_config(&Default::default()) .chain_err(|| format!("Unable to create cached config file at {}", file_conf_path.display()))? } - try_read_config_file(&file_conf_path)? + try_read_config_file(&file_conf_path) + .chain_err(|| "Failed to load cached config file")? 
.ok_or_else(|| format!("Failed to load from {}", file_conf_path.display()).into()) } fn save_file_config(c: &CachedFileConfig) -> Result<()> { @@ -527,6 +581,7 @@ impl CachedConfig { #[cfg(feature = "dist-server")] pub mod scheduler { + use std::net::SocketAddr; use std::path::Path; use errors::*; @@ -565,19 +620,21 @@ pub mod scheduler { #[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Config { + pub public_addr: SocketAddr, pub client_auth: ClientAuth, pub server_auth: ServerAuth, } pub fn from_path(conf_path: &Path) -> Result> { - super::try_read_config_file(&conf_path) + super::try_read_config_file(&conf_path).chain_err(|| "Failed to load scheduler config file") } } #[cfg(feature = "dist-server")] pub mod server { - use std::net::{IpAddr, SocketAddr}; + use std::net::SocketAddr; use std::path::{Path, PathBuf}; + use super::HTTPUrl; use errors::*; @@ -615,14 +672,14 @@ pub mod server { pub builder: BuilderType, pub cache_dir: PathBuf, pub public_addr: SocketAddr, - pub scheduler_addr: IpAddr, + pub scheduler_url: HTTPUrl, pub scheduler_auth: SchedulerAuth, #[serde(default = "default_toolchain_cache_size")] pub toolchain_cache_size: u64, } pub fn from_path(conf_path: &Path) -> Result> { - super::try_read_config_file(&conf_path) + super::try_read_config_file(&conf_path).chain_err(|| "Failed to load server config file") } } diff --git a/src/dist/cache.rs b/src/dist/cache.rs index 83b681ad..c856f2be 100644 --- a/src/dist/cache.rs +++ b/src/dist/cache.rs @@ -36,13 +36,13 @@ mod client { } // TODO: possibly shouldn't be public - #[cfg(feature = "dist-client")] pub struct ClientToolchains { cache_dir: PathBuf, cache: Mutex, - // Lookup from dist toolchain -> toolchain details - custom_toolchains: Mutex>, + // Lookup from dist toolchain -> path to custom toolchain archive + custom_toolchain_archives: Mutex>, // Lookup from local path -> toolchain details + // The Option could be populated on startup, but it's lazy for efficiency custom_toolchain_paths: Mutex)>>, // Toolchains configured to not be distributed disabled_toolchains: HashSet, @@ -57,26 +57,32 @@ mod client { weak_map: Mutex>, } - #[cfg(feature = "dist-client")] impl ClientToolchains { - pub fn new(cache_dir: &Path, cache_size: u64, toolchain_configs: &[config::DistToolchainConfig]) -> Self { + pub fn new(cache_dir: &Path, cache_size: u64, toolchain_configs: &[config::DistToolchainConfig]) -> Result { let cache_dir = cache_dir.to_owned(); - fs::create_dir_all(&cache_dir).unwrap(); + fs::create_dir_all(&cache_dir).chain_err(|| "failed to create top level toolchain cache dir")?; let toolchain_creation_dir = cache_dir.join("toolchain_tmp"); if toolchain_creation_dir.exists() { - fs::remove_dir_all(&toolchain_creation_dir).unwrap() + fs::remove_dir_all(&toolchain_creation_dir).chain_err(|| "failed to clean up temporary toolchain creation directory")? } - fs::create_dir(&toolchain_creation_dir).unwrap(); + fs::create_dir(&toolchain_creation_dir).chain_err(|| "failed to create temporary toolchain creation directory")?; let weak_map_path = cache_dir.join("weak_map.json"); if !weak_map_path.exists() { - fs::File::create(&weak_map_path).unwrap().write_all(b"{}").unwrap() + fs::File::create(&weak_map_path) + .and_then(|mut f| f.write_all(b"{}")) + .chain_err(|| "failed to create new toolchain weak map file")? 
} - let weak_map = serde_json::from_reader(fs::File::open(weak_map_path).unwrap()).unwrap(); + let weak_map = fs::File::open(weak_map_path) + .map_err(Error::from) + .and_then(|f| serde_json::from_reader(f).map_err(Error::from)) + .chain_err(|| "failed to load toolchain weak map")?; let tc_cache_dir = cache_dir.join("tc"); - let cache = Mutex::new(TcCache::new(&tc_cache_dir, cache_size).unwrap()); + let cache = TcCache::new(&tc_cache_dir, cache_size) + .map(Mutex::new) + .chain_err(|| "failed to initialise a toolchain cache")?; // Load in toolchain configuration let mut custom_toolchain_paths = HashMap::new(); @@ -90,51 +96,52 @@ mod client { compiler_executable: archive_compiler_executable.clone(), }; if custom_toolchain_paths.insert(compiler_executable.clone(), (custom_tc, None)).is_some() { - panic!("Multiple toolchains for {}", compiler_executable.display()) + bail!("Multiple toolchains for {}", compiler_executable.display()) } if disabled_toolchains.contains(compiler_executable) { - panic!("Override for toolchain {} conflicts with it being disabled") + bail!("Override for toolchain {} conflicts with it being disabled", compiler_executable.display()) } }, config::DistToolchainConfig::NoDist { compiler_executable } => { debug!("Disabling toolchain {}", compiler_executable.display()); if !disabled_toolchains.insert(compiler_executable.clone()) { - panic!("Disabled toolchain {} multiple times", compiler_executable.display()) + bail!("Disabled toolchain {} multiple times", compiler_executable.display()) } if custom_toolchain_paths.contains_key(compiler_executable) { - panic!("Override for toolchain {} conflicts with it being disabled") + bail!("Override for toolchain {} conflicts with it being disabled", compiler_executable.display()) } }, } } let custom_toolchain_paths = Mutex::new(custom_toolchain_paths); - Self { + Ok(Self { cache_dir, cache, - custom_toolchains: Mutex::new(HashMap::new()), + custom_toolchain_archives: Mutex::new(HashMap::new()), custom_toolchain_paths, disabled_toolchains, // TODO: shouldn't clear on restart, but also should have some // form of pruning weak_map: Mutex::new(weak_map), - } + }) } // Get the bytes of a toolchain tar // TODO: by this point the toolchain should be known to exist - pub fn get_toolchain(&self, tc: &Toolchain) -> Option { + pub fn get_toolchain(&self, tc: &Toolchain) -> Result> { // TODO: be more relaxed about path casing and slashes on Windows - let file = if let Some(custom_tc) = self.custom_toolchains.lock().unwrap().get(tc) { - fs::File::open(&custom_tc.archive).unwrap() + let file = if let Some(custom_tc_archive) = self.custom_toolchain_archives.lock().unwrap().get(tc) { + fs::File::open(custom_tc_archive) + .chain_err(|| format!("could not open file for toolchain {}", custom_tc_archive.display()))? 
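`get_toolchain` now returns `Result<Option<File>>` so that "not in the cache" stays distinguishable from a genuine I/O failure. A small illustration of that shape using only std types (not sccache code):

use std::fs::File;
use std::io;
use std::path::Path;

// Stand-in for the cache lookup: Ok(None) means "not cached",
// Err(..) means something actually went wrong.
fn get_cached(path: &Path) -> io::Result<Option<File>> {
    match File::open(path) {
        Ok(f) => Ok(Some(f)),
        Err(ref e) if e.kind() == io::ErrorKind::NotFound => Ok(None),
        Err(e) => Err(e),
    }
}

fn main() -> io::Result<()> {
    match get_cached(Path::new("toolchain.tar"))? {
        Some(_f) => println!("toolchain found in cache"),
        None => println!("toolchain not cached; would package it now"),
    }
    Ok(())
}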
} else { match self.cache.lock().unwrap().get_file(tc) { Ok(file) => file, - Err(LruError::FileNotInCache) => return None, - Err(e) => panic!("{}", e), + Err(LruError::FileNotInCache) => return Ok(None), + Err(e) => return Err(Error::from(e).chain_err(|| "error while retrieving toolchain from cache")), } }; - Some(file) + Ok(Some(file)) } // If the toolchain doesn't already exist, create it and insert into the cache pub fn put_toolchain(&self, compiler_path: &Path, weak_key: &str, toolchain_packager: Box) -> Result<(Toolchain, Option)> { @@ -143,7 +150,7 @@ mod client { } if let Some(tc_and_compiler_path) = self.get_custom_toolchain(compiler_path) { debug!("Using custom toolchain for {:?}", compiler_path); - let (tc, compiler_path) = tc_and_compiler_path.unwrap(); + let (tc, compiler_path) = tc_and_compiler_path?; return Ok((tc, Some(compiler_path))) } if let Some(archive_id) = self.weak_to_strong(weak_key) { @@ -157,7 +164,7 @@ mod client { let tmpfile = tempfile::NamedTempFile::new_in(self.cache_dir.join("toolchain_tmp"))?; toolchain_packager.write_pkg(tmpfile.reopen()?).chain_err(|| "Could not package toolchain")?; let tc = cache.insert_file(tmpfile.path())?; - self.record_weak(weak_key.to_owned(), tc.archive_id.clone()); + self.record_weak(weak_key.to_owned(), tc.archive_id.clone())?; Ok((tc, None)) } @@ -171,7 +178,14 @@ mod client { }; let tc = Toolchain { archive_id }; *maybe_tc = Some(tc.clone()); - assert!(self.custom_toolchains.lock().unwrap().insert(tc.clone(), custom_tc.clone()).is_none()); + // If this entry already exists, someone has two custom toolchains with the same strong hash + if let Some(old_path) = self.custom_toolchain_archives.lock().unwrap().insert(tc.clone(), custom_tc.archive.clone()) { + // Log a warning if the user has identical toolchains at two different locations - it's + // not strictly wrong, but it is a bit odd + if old_path != custom_tc.archive { + warn!("Detected interchangable toolchain archives at {} and {}", old_path.display(), custom_tc.archive.display()) + } + } Some(Ok((tc, custom_tc.compiler_executable.clone()))) }, None => None, @@ -181,11 +195,114 @@ mod client { fn weak_to_strong(&self, weak_key: &str) -> Option { self.weak_map.lock().unwrap().get(weak_key).map(String::to_owned) } - fn record_weak(&self, weak_key: String, key: String) { + fn record_weak(&self, weak_key: String, key: String) -> Result<()> { let mut weak_map = self.weak_map.lock().unwrap(); weak_map.insert(weak_key, key); let weak_map_path = self.cache_dir.join("weak_map.json"); - serde_json::to_writer(fs::File::create(weak_map_path).unwrap(), &*weak_map).unwrap() + fs::File::create(weak_map_path).map_err(Error::from) + .and_then(|f| serde_json::to_writer(f, &*weak_map).map_err(Error::from)) + .chain_err(|| "failed to enter toolchain in weak map") + } + } + + #[cfg(test)] + mod test { + use config; + use std::io::Write; + use tempdir::TempDir; + use test::utils::create_file; + + use super::ClientToolchains; + + struct PanicToolchainPackager; + impl PanicToolchainPackager { + fn new() -> Box { Box::new(PanicToolchainPackager) } + } + #[cfg(all(target_os = "linux", target_arch = "x86_64"))] + impl ::dist::pkg::ToolchainPackager for PanicToolchainPackager { + fn write_pkg(self: Box, _f: ::std::fs::File) -> ::errors::Result<()> { + panic!("should not have called packager") + } + } + + #[test] + fn test_client_toolchains_custom() { + let td = TempDir::new("sccache").unwrap(); + + let ct1 = create_file(td.path(), "ct1", |mut f| f.write_all(b"toolchain_contents")).unwrap(); + + let 
client_toolchains = ClientToolchains::new(&td.path().join("cache"), 1024, &[ + config::DistToolchainConfig::PathOverride { + compiler_executable: "/my/compiler".into(), + archive: ct1, + archive_compiler_executable: "/my/compiler/in_archive".into(), + }, + ]).unwrap(); + + let (_tc, newpath) = client_toolchains.put_toolchain("/my/compiler".as_ref(), "weak_key", PanicToolchainPackager::new()).unwrap(); + assert!(newpath.unwrap() == "/my/compiler/in_archive"); + } + + #[test] + fn test_client_toolchains_custom_multiuse_archive() { + let td = TempDir::new("sccache").unwrap(); + + let ct1 = create_file(td.path(), "ct1", |mut f| f.write_all(b"toolchain_contents")).unwrap(); + + let client_toolchains = ClientToolchains::new(&td.path().join("cache"), 1024, &[ + config::DistToolchainConfig::PathOverride { + compiler_executable: "/my/compiler".into(), + archive: ct1.clone(), + archive_compiler_executable: "/my/compiler/in_archive".into(), + }, + // Uses the same archive, but a maps a different external compiler to a different achive compiler + config::DistToolchainConfig::PathOverride { + compiler_executable: "/my/compiler2".into(), + archive: ct1.clone(), + archive_compiler_executable: "/my/compiler2/in_archive".into(), + }, + // Uses the same archive, but a maps a different external compiler to the same achive compiler as the first + config::DistToolchainConfig::PathOverride { + compiler_executable: "/my/compiler3".into(), + archive: ct1, + archive_compiler_executable: "/my/compiler/in_archive".into(), + }, + ]).unwrap(); + + let (_tc, newpath) = client_toolchains.put_toolchain("/my/compiler".as_ref(), "weak_key", PanicToolchainPackager::new()).unwrap(); + assert!(newpath.unwrap() == "/my/compiler/in_archive"); + let (_tc, newpath) = client_toolchains.put_toolchain("/my/compiler2".as_ref(), "weak_key2", PanicToolchainPackager::new()).unwrap(); + assert!(newpath.unwrap() == "/my/compiler2/in_archive"); + let (_tc, newpath) = client_toolchains.put_toolchain("/my/compiler3".as_ref(), "weak_key2", PanicToolchainPackager::new()).unwrap(); + assert!(newpath.unwrap() == "/my/compiler/in_archive"); + } + + #[test] + fn test_client_toolchains_nodist() { + let td = TempDir::new("sccache").unwrap(); + + let client_toolchains = ClientToolchains::new(&td.path().join("cache"), 1024, &[ + config::DistToolchainConfig::NoDist { compiler_executable: "/my/compiler".into() }, + ]).unwrap(); + + assert!(client_toolchains.put_toolchain("/my/compiler".as_ref(), "weak_key", PanicToolchainPackager::new()).is_err()); + } + + #[test] + fn test_client_toolchains_custom_nodist_conflict() { + let td = TempDir::new("sccache").unwrap(); + + let ct1 = create_file(td.path(), "ct1", |mut f| f.write_all(b"toolchain_contents")).unwrap(); + + let client_toolchains = ClientToolchains::new(&td.path().join("cache"), 1024, &[ + config::DistToolchainConfig::PathOverride { + compiler_executable: "/my/compiler".into(), + archive: ct1, + archive_compiler_executable: "/my/compiler".into(), + }, + config::DistToolchainConfig::NoDist { compiler_executable: "/my/compiler".into() }, + ]); + assert!(client_toolchains.is_err()) } } } diff --git a/src/dist/client_auth.rs b/src/dist/client_auth.rs index edaa128e..da43d555 100644 --- a/src/dist/client_auth.rs +++ b/src/dist/client_auth.rs @@ -1,11 +1,13 @@ +use error_chain::ChainedError; use futures::sync::oneshot; use futures::Future; use hyper; -use hyper::{Body, Request, Response, Server}; +use hyper::{Body, Request, Response, Server, StatusCode}; use hyper::header::{ContentLength, ContentType}; 
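The multiuse-archive test above exercises the new insert-and-inspect behaviour: registering a second archive under the same strong hash is tolerated, but two different archive paths for one hash earn a warning. A generic sketch of that pattern, with a stand-in string key instead of the real `Toolchain` type:

use std::collections::HashMap;
use std::path::PathBuf;

fn main() {
    // Keyed by the toolchain's strong hash (a placeholder string here); the
    // value is the path of the archive that provides it.
    let mut archives: HashMap<String, PathBuf> = HashMap::new();
    archives.insert("strong-hash-1".to_owned(), PathBuf::from("/toolchains/a.tar"));

    // insert() hands back the previous path, so a mismatch can be reported
    // without treating the duplicate registration as an error.
    let new_path = PathBuf::from("/toolchains/b.tar");
    if let Some(old_path) = archives.insert("strong-hash-1".to_owned(), new_path.clone()) {
        if old_path != new_path {
            eprintln!(
                "warning: interchangeable toolchain archives at {} and {}",
                old_path.display(),
                new_path.display()
            );
        }
    }
}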
use hyper::server::{Http, NewService, const_service, service_fn}; use serde::Serialize; use serde_json; +use std::collections::HashMap; use std::io; use std::net::{ToSocketAddrs, TcpStream}; use std::sync::mpsc; @@ -15,12 +17,38 @@ use uuid::Uuid; use errors::*; -type BoxFut = Box, Error = hyper::Error> + Send>; - // These (arbitrary) ports need to be registered as valid redirect urls in the oauth provider you're using pub const VALID_PORTS: &[u16] = &[12731, 32492, 56909]; -// Warn if the token will expire in under this amount of time -const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60); +// If token is valid for under this amount of time, print a warning +const MIN_TOKEN_VALIDITY: Duration = Duration::from_secs(2 * 24 * 60 * 60); +const MIN_TOKEN_VALIDITY_WARNING: &str = "two days"; + +trait ServeFn: Fn(Request) -> Box, Error = hyper::Error>> + Copy + 'static {} +impl ServeFn for T where T: Fn(Request) -> Box, Error = hyper::Error>> + Copy + 'static {} + +fn serve_sfuture(serve: fn(Request) -> SFuture>) -> impl ServeFn { + move |req: Request| { + let uri = req.uri().to_owned(); + Box::new(serve(req) + .or_else(move |e| { + let body = e.display_chain().to_string(); + eprintln!("Error during a request to {} on the client auth web server\n{}", uri, body); + let len = body.len(); + Ok(Response::new() + .with_status(StatusCode::InternalServerError) + .with_body(body) + .with_header(ContentType::text()) + .with_header(ContentLength(len as u64))) + })) as Box> + } +} + +fn query_pairs(url: &str) -> Result> { + // Url::parse operates on absolute URLs, so ensure there's a prefix + let url = Url::parse("http://unused_base").expect("Failed to parse fake url prefix") + .join(url).chain_err(|| "Failed to parse url while extracting query params")?; + Ok(url.query_pairs().map(|(k, v)| (k.into_owned(), v.into_owned())).collect()) +} fn html_response(body: &'static str) -> Response { Response::new() @@ -29,13 +57,13 @@ fn html_response(body: &'static str) -> Response { .with_header(ContentLength(body.len() as u64)) } -fn json_response(data: &T) -> Response { - let body = serde_json::to_vec(data).unwrap(); +fn json_response(data: &T) -> Result { + let body = serde_json::to_vec(data).chain_err(|| "Failed to serialize to JSON")?; let len = body.len(); - Response::new() + Ok(Response::new() .with_body(body) .with_header(ContentType::json()) - .with_header(ContentLength(len as u64)) + .with_header(ContentLength(len as u64))) } const REDIRECT_WITH_AUTH_JSON: &str = r##" @@ -78,7 +106,7 @@ mod code_grant_pkce { use std::sync::Mutex; use std::sync::mpsc; use std::time::{Duration, Instant}; - use super::{ONE_DAY, REDIRECT_WITH_AUTH_JSON, BoxFut, html_response, json_response}; + use super::{MIN_TOKEN_VALIDITY, MIN_TOKEN_VALIDITY_WARNING, REDIRECT_WITH_AUTH_JSON, query_pairs, html_response, json_response}; use url::Url; use errors::*; @@ -128,9 +156,9 @@ mod code_grant_pkce { pub static ref STATE: Mutex> = Mutex::new(None); } - pub fn generate_verifier_and_challenge() -> (String, String) { + pub fn generate_verifier_and_challenge() -> Result<(String, String)> { let mut code_verifier_bytes = vec![0; NUM_CODE_VERIFIER_BYTES]; - let mut rng = rand::OsRng::new().unwrap(); + let mut rng = rand::OsRng::new().chain_err(|| "Failed to initialise a random number generator")?; rng.fill_bytes(&mut code_verifier_bytes); let code_verifier = base64::encode_config(&code_verifier_bytes, base64::URL_SAFE_NO_PAD); let mut hasher = HASHER::new(); @@ -138,7 +166,7 @@ mod code_grant_pkce { let mut code_challenge_bytes = vec![0; 
hasher.output_bytes()]; hasher.result(&mut code_challenge_bytes); let code_challenge = base64::encode_config(&code_challenge_bytes, base64::URL_SAFE_NO_PAD); - (code_verifier, code_challenge) + Ok((code_verifier, code_challenge)) } pub fn finish_url(client_id: &str, url: &mut Url, redirect_uri: &str, state: &str, code_challenge: &str) { @@ -174,7 +202,7 @@ mod code_grant_pkce { "##; - pub fn serve(req: Request) -> BoxFut { + pub fn serve(req: Request) -> SFuture { let mut state = STATE.lock().unwrap(); let state = state.as_mut().unwrap(); debug!("Handling {} {}", req.method(), req.uri()); @@ -183,14 +211,14 @@ mod code_grant_pkce { html_response(REDIRECT_WITH_AUTH_JSON) }, (&Method::Get, "/auth_detail.json") => { - json_response(&state.auth_url) + ftry!(json_response(&state.auth_url)) }, (&Method::Get, "/redirect") => { - let url = Url::parse("http://unused_base").unwrap().join(req.uri().as_ref()).unwrap(); - let query_pairs = url.query_pairs().map(|(k, v)| (k.into_owned(), v.into_owned())).collect(); - let (code, auth_state) = handle_code_response(query_pairs).unwrap(); + let query_pairs = ftry!(query_pairs(req.uri().as_ref())); + let (code, auth_state) = ftry!(handle_code_response(query_pairs) + .chain_err(|| "Failed to handle response from redirect")); if auth_state != state.auth_state_value { - panic!("Mismatched auth states") + return f_err("Mismatched auth states after redirect") } // Deliberately in reverse order for a 'happens-before' relationship state.code_tx.send(code).unwrap(); @@ -214,9 +242,10 @@ mod code_grant_pkce { bail!("Sending code to {} failed, HTTP error: {}", token_url, res.status()) } - let (token, expires_at) = handle_token_response(res.json().unwrap())?; - if expires_at - Instant::now() < ONE_DAY * 2 { - warn!("Token retrieved expires in under two days") + let (token, expires_at) = handle_token_response(res.json().chain_err(|| "Failed to parse token response as JSON")?)?; + if expires_at - Instant::now() < MIN_TOKEN_VALIDITY { + warn!("Token retrieved expires in under {}", MIN_TOKEN_VALIDITY_WARNING); + eprintln!("Token retrieved expires in under {}", MIN_TOKEN_VALIDITY_WARNING); } Ok(token) } @@ -230,7 +259,7 @@ mod implicit { use std::sync::Mutex; use std::sync::mpsc; use std::time::{Duration, Instant}; - use super::{ONE_DAY, REDIRECT_WITH_AUTH_JSON, BoxFut, html_response, json_response}; + use super::{MIN_TOKEN_VALIDITY, MIN_TOKEN_VALIDITY_WARNING, REDIRECT_WITH_AUTH_JSON, query_pairs, html_response, json_response}; use url::Url; use errors::*; @@ -309,7 +338,7 @@ mod implicit { "##; - pub fn serve(req: Request) -> BoxFut { + pub fn serve(req: Request) -> SFuture { let mut state = STATE.lock().unwrap(); let state = state.as_mut().unwrap(); debug!("Handling {} {}", req.method(), req.uri()); @@ -318,25 +347,26 @@ mod implicit { html_response(REDIRECT_WITH_AUTH_JSON) }, (&Method::Get, "/auth_detail.json") => { - json_response(&state.auth_url) + ftry!(json_response(&state.auth_url)) }, (&Method::Get, "/redirect") => { html_response(SAVE_AUTH_AFTER_REDIRECT) }, (&Method::Post, "/save_auth") => { - let url = Url::parse("http://unused_base").unwrap().join(req.uri().as_ref()).unwrap(); - let query_pairs = url.query_pairs().map(|(k, v)| (k.into_owned(), v.into_owned())).collect(); - let (token, expires_at, auth_state) = handle_response(query_pairs).unwrap(); + let query_pairs = ftry!(query_pairs(req.uri().as_ref())); + let (token, expires_at, auth_state) = ftry!(handle_response(query_pairs) + .chain_err(|| "Failed to save auth after redirect")); if auth_state != 
state.auth_state_value { - panic!("Mismatched auth states") + return f_err("Mismatched auth states after redirect") } - if expires_at - Instant::now() < ONE_DAY * 2 { - warn!("Token retrieved expires in under two days") + if expires_at - Instant::now() < MIN_TOKEN_VALIDITY { + warn!("Token retrieved expires in under {}", MIN_TOKEN_VALIDITY_WARNING); + eprintln!("Token retrieved expires in under {}", MIN_TOKEN_VALIDITY_WARNING); } // Deliberately in reverse order for a 'happens-before' relationship state.token_tx.send(token).unwrap(); state.shutdown_tx.take().unwrap().send(()).unwrap(); - json_response(&"") + ftry!(json_response(&"")) }, _ => { warn!("Route not found"); @@ -348,11 +378,11 @@ mod implicit { } } -fn try_serve(serve: fn(Request) -> BoxFut) -> Result, Error=hyper::error::Error> + 'static, Body>> { +fn try_serve(serve: impl ServeFn) -> Result, Error=hyper::Error>, Body>> { // Try all the valid ports for &port in VALID_PORTS { - let mut addrs = ("localhost", port).to_socket_addrs().unwrap(); - let addr = addrs.next().unwrap(); + let mut addrs = ("localhost", port).to_socket_addrs().expect("Failed to interpret localhost address to listen on"); + let addr = addrs.next().expect("Expected at least one address in parsed socket address"); // Hyper binds with reuseaddr and reuseport so binding won't fail as you'd expect on Linux match TcpStream::connect(addr) { @@ -361,7 +391,7 @@ fn try_serve(serve: fn(Request) -> BoxFut) -> Result (), Err(e) => { - return Err(Error::with_chain(e, format!("Failed to bind to {}", addr))) + return Err(Error::from(e).chain_err(|| format!("Failed to check {} is available for binding", addr))) }, } @@ -374,7 +404,7 @@ fn try_serve(serve: fn(Request) -> BoxFut) -> Result { - return Err(Error::with_chain(e, format!("Failed to bind to {}", addr))) + return Err(Error::from(e).chain_err(|| format!("Failed to bind to {}", addr))) }, } } @@ -383,12 +413,12 @@ fn try_serve(serve: fn(Request) -> BoxFut) -> Result Result { - let server = try_serve(code_grant_pkce::serve)?; - let port = server.local_addr().unwrap().port(); + let server = try_serve(serve_sfuture(code_grant_pkce::serve))?; + let port = server.local_addr().chain_err(|| "Failed to retrieve local address of server")?.port(); let redirect_uri = format!("http://localhost:{}/redirect", port); let auth_state_value = Uuid::new_v4().simple().to_string(); - let (verifier, challenge) = code_grant_pkce::generate_verifier_and_challenge(); + let (verifier, challenge) = code_grant_pkce::generate_verifier_and_challenge()?; code_grant_pkce::finish_url(client_id, &mut auth_url, &redirect_uri, &auth_state_value, &challenge); info!("Listening on http://localhost:{} with 1 thread.", port); @@ -413,8 +443,8 @@ pub fn get_token_oauth2_code_grant_pkce(client_id: &str, mut auth_url: Url, toke // https://auth0.com/docs/api-auth/tutorials/implicit-grant pub fn get_token_oauth2_implicit(client_id: &str, mut auth_url: Url) -> Result { - let server = try_serve(implicit::serve)?; - let port = server.local_addr().unwrap().port(); + let server = try_serve(serve_sfuture(implicit::serve))?; + let port = server.local_addr().chain_err(|| "Failed to retrieve local address of server")?.port(); let redirect_uri = format!("http://localhost:{}/redirect", port); let auth_state_value = Uuid::new_v4().simple().to_string(); diff --git a/src/dist/http.rs b/src/dist/http.rs index 9725d558..0da88160 100644 --- a/src/dist/http.rs +++ b/src/dist/http.rs @@ -11,47 +11,27 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
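For context on `try_serve` above: it walks a fixed whitelist of redirect ports and serves on the first one that is free. The real code probes availability with `TcpStream::connect` because hyper binds with SO_REUSEADDR/SO_REUSEPORT, but the idea can be sketched more simply with a plain bind attempt:

use std::net::{SocketAddr, TcpListener};

// Same whitelist as client_auth.rs; these ports must be registered as valid
// redirect URLs with the OAuth provider being used.
const VALID_PORTS: &[u16] = &[12731, 32492, 56909];

fn first_free_port() -> Option<(u16, TcpListener)> {
    for &port in VALID_PORTS {
        let addr = SocketAddr::from(([127, 0, 0, 1], port));
        // A plain bind stands in for the connect-probe the real code performs.
        if let Ok(listener) = TcpListener::bind(addr) {
            return Some((port, listener));
        }
    }
    None
}

fn main() {
    match first_free_port() {
        Some((port, _listener)) => {
            println!("redirect server would listen on http://localhost:{}/redirect", port)
        }
        None => eprintln!("all whitelisted OAuth redirect ports are in use"),
    }
}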
// See the License for the specific language governing permissions and // limitations under the License. -#![allow(unused)] - #[cfg(feature = "dist-client")] pub use self::client::Client; #[cfg(feature = "dist-server")] -pub use self::server::{Scheduler, ClientAuthCheck, ServerAuthCheck}; +pub use self::server::{Scheduler, ClientAuthCheck, ClientVisibleMsg, ServerAuthCheck}; #[cfg(feature = "dist-server")] pub use self::server::Server; -//#[allow(unused)] mod common { use bincode; + #[cfg(feature = "dist-client")] use futures::{Future, Stream}; use reqwest; use serde; - use std::net::{IpAddr, SocketAddr}; - use dist::{JobId, CompileCommand}; + #[cfg(feature = "dist-server")] + use std::collections::HashMap; + use std::fmt; + + use dist; use errors::*; - const SCHEDULER_PORT: u16 = 10500; - const SERVER_PORT: u16 = 10501; - - // TODO: move this into the config module - pub struct Cfg; - - impl Cfg { - pub fn scheduler_listen_addr() -> SocketAddr { - let ip_addr = "0.0.0.0".parse().unwrap(); - SocketAddr::new(ip_addr, SCHEDULER_PORT) - } - pub fn scheduler_connect_addr(scheduler_addr: IpAddr) -> SocketAddr { - SocketAddr::new(scheduler_addr, SCHEDULER_PORT) - } - - pub fn server_listen_addr() -> SocketAddr { - let ip_addr = "0.0.0.0".parse().unwrap(); - SocketAddr::new(ip_addr, SERVER_PORT) - } - } - // Note that content-length is necessary due to https://github.com/tiny-http/tiny-http/issues/147 pub trait ReqwestRequestBuilderExt { fn bincode(&mut self, bincode: &T) -> Result<&mut Self>; @@ -60,7 +40,7 @@ mod common { } impl ReqwestRequestBuilderExt for reqwest::RequestBuilder { fn bincode(&mut self, bincode: &T) -> Result<&mut Self> { - let bytes = bincode::serialize(bincode)?; + let bytes = bincode::serialize(bincode).chain_err(|| "Failed to serialize body to bincode")?; Ok(self.bytes(bytes)) } fn bytes(&mut self, bytes: Vec) -> &mut Self { @@ -74,7 +54,7 @@ mod common { } impl ReqwestRequestBuilderExt for reqwest::unstable::async::RequestBuilder { fn bincode(&mut self, bincode: &T) -> Result<&mut Self> { - let bytes = bincode::serialize(bincode)?; + let bytes = bincode::serialize(bincode).chain_err(|| "Failed to serialize body to bincode")?; Ok(self.bytes(bytes)) } fn bytes(&mut self, bytes: Vec) -> &mut Self { @@ -91,13 +71,14 @@ mod common { let mut res = req.send()?; let status = res.status(); let mut body = vec![]; - res.copy_to(&mut body).unwrap(); + res.copy_to(&mut body).chain_err(|| "error reading response body")?; if !status.is_success() { Err(format!("Error {} (Headers={:?}): {}", status.as_u16(), res.headers(), String::from_utf8_lossy(&body)).into()) } else { bincode::deserialize(&body).map_err(Into::into) } } + #[cfg(feature = "dist-client")] pub fn bincode_req_fut(req: &mut reqwest::unstable::async::RequestBuilder) -> SFuture { Box::new(req.send().map_err(Into::into) .and_then(|res| { @@ -120,23 +101,96 @@ mod common { #[derive(Eq, PartialEq)] #[serde(deny_unknown_fields)] pub struct JobJwt { - pub job_id: JobId, + pub job_id: dist::JobId, } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] + pub enum AllocJobHttpResponse { + Success { job_alloc: dist::JobAlloc, need_toolchain: bool, cert_digest: Vec }, + Fail { msg: String }, + } + impl AllocJobHttpResponse { + #[cfg(feature = "dist-server")] + pub fn from_alloc_job_result(res: dist::AllocJobResult, certs: &HashMap, Vec)>) -> Self { + match res { + dist::AllocJobResult::Success { job_alloc, need_toolchain } => { + if let Some((digest, _)) = certs.get(&job_alloc.server_id) { + 
AllocJobHttpResponse::Success { job_alloc, need_toolchain, cert_digest: digest.to_owned() } + } else { + AllocJobHttpResponse::Fail { msg: format!("missing certificates for server {}", job_alloc.server_id.addr()) } + } + }, + dist::AllocJobResult::Fail { msg } => AllocJobHttpResponse::Fail { msg }, + } + } + } + + #[derive(Clone, Debug, Serialize, Deserialize)] + #[serde(deny_unknown_fields)] + pub struct ServerCertificateHttpResponse { + pub cert_digest: Vec, + pub cert_pem: Vec, + } + + #[derive(Clone, Serialize, Deserialize)] + #[serde(deny_unknown_fields)] pub struct HeartbeatServerHttpRequest { pub jwt_key: Vec, pub num_cpus: usize, + pub server_nonce: dist::ServerNonce, + pub cert_digest: Vec, + pub cert_pem: Vec, + } + // cert_pem is quite long so elide it (you can retrieve it by hitting the server url anyway) + impl fmt::Debug for HeartbeatServerHttpRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let HeartbeatServerHttpRequest { jwt_key, num_cpus, server_nonce, cert_digest, cert_pem } = self; + write!(f, "HeartbeatServerHttpRequest {{ jwt_key: {:?}, num_cpus: {:?}, server_nonce: {:?}, cert_digest: {:?}, cert_pem: [...{} bytes...] }}", jwt_key, num_cpus, server_nonce, cert_digest, cert_pem.len()) + } } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct RunJobHttpRequest { - pub command: CompileCommand, + pub command: dist::CompileCommand, pub outputs: Vec, } } +pub mod urls { + use dist::{JobId, ServerId}; + use reqwest; + + pub fn scheduler_alloc_job(scheduler_url: &reqwest::Url) -> reqwest::Url { + scheduler_url.join("/api/v1/scheduler/alloc_job").expect("failed to create alloc job url") + } + pub fn scheduler_server_certificate(scheduler_url: &reqwest::Url, server_id: ServerId) -> reqwest::Url { + scheduler_url.join(&format!("/api/v1/scheduler/server_certificate/{}", server_id.addr())).expect("failed to create server certificate url") + } + pub fn scheduler_heartbeat_server(scheduler_url: &reqwest::Url) -> reqwest::Url { + scheduler_url.join("/api/v1/scheduler/heartbeat_server").expect("failed to create heartbeat url") + } + pub fn scheduler_job_state(scheduler_url: &reqwest::Url, job_id: JobId) -> reqwest::Url { + scheduler_url.join(&format!("/api/v1/scheduler/job_state/{}", job_id)).expect("failed to create job state url") + } + pub fn scheduler_status(scheduler_url: &reqwest::Url) -> reqwest::Url { + scheduler_url.join("/api/v1/scheduler/status").expect("failed to create alloc job url") + } + + pub fn server_assign_job(server_id: ServerId, job_id: JobId) -> reqwest::Url { + let url = format!("https://{}/api/v1/distserver/assign_job/{}", server_id.addr(), job_id); + reqwest::Url::parse(&url).expect("failed to create assign job url") + } + pub fn server_submit_toolchain(server_id: ServerId, job_id: JobId) -> reqwest::Url { + let url = format!("https://{}/api/v1/distserver/submit_toolchain/{}", server_id.addr(), job_id); + reqwest::Url::parse(&url).expect("failed to create submit toolchain url") + } + pub fn server_run_job(server_id: ServerId, job_id: JobId) -> reqwest::Url { + let url = format!("https://{}/api/v1/distserver/run_job/{}", server_id.addr(), job_id); + reqwest::Url::parse(&url).expect("failed to create run job url") + } +} + #[cfg(feature = "dist-server")] mod server { use bincode; @@ -144,44 +198,110 @@ mod server { use flate2::read::ZlibDecoder as ZlibReadDecoder; use jwt; use num_cpus; + use openssl; use rand::{self, RngCore}; use reqwest; use rouille; use serde; use serde_json; use std; + use 
std::collections::HashMap; use std::io::Read; - use std::net::{IpAddr, SocketAddr}; + use std::net::SocketAddr; + use std::result::Result as StdResult; use std::sync::atomic; + use std::sync::Mutex; use std::thread; use std::time::Duration; + use void::Void; use dist::{ self, - ServerId, JobId, Toolchain, + ServerId, ServerNonce, JobId, Toolchain, ToolchainReader, InputsReader, + JobAuthorizer, AllocJobResult, AssignJobResult, HeartbeatServerResult, RunJobResult, - StatusResult, + SchedulerStatusResult, SubmitToolchainResult, UpdateJobStateResult, JobState, }; use super::common::{ - Cfg, ReqwestRequestBuilderExt, bincode_req, JobJwt, + AllocJobHttpResponse, + ServerCertificateHttpResponse, HeartbeatServerHttpRequest, RunJobHttpRequest, }; + use super::urls; use errors::*; - pub type ClientAuthCheck = Box bool + Send + Sync>; + fn create_https_cert_and_privkey(addr: SocketAddr) -> Result<(Vec, Vec, Vec)> { + let rsa_key = openssl::rsa::Rsa::::generate(2048) + .chain_err(|| "failed to generate rsa privkey")?; + let privkey_pem = rsa_key.private_key_to_pem() + .chain_err(|| "failed to create pem from rsa privkey")?; + let privkey: openssl::pkey::PKey = openssl::pkey::PKey::from_rsa(rsa_key) + .chain_err(|| "failed to create openssl pkey from rsa privkey")?; + let mut builder = openssl::x509::X509::builder() + .chain_err(|| "failed to create x509 builder")?; + + // Populate the certificate with the necessary parts, mostly from mkcert in openssl + builder.set_version(2) + .chain_err(|| "failed to set x509 version")?; + let serial_number = openssl::bn::BigNum::from_u32(0).and_then(|bn| bn.to_asn1_integer()) + .chain_err(|| "failed to create openssl asn1 0")?; + builder.set_serial_number(serial_number.as_ref()) + .chain_err(|| "failed to set x509 serial number")?; + let not_before = openssl::asn1::Asn1Time::days_from_now(0) + .chain_err(|| "failed to create openssl not before asn1")?; + builder.set_not_before(not_before.as_ref()) + .chain_err(|| "failed to set not before on x509")?; + let not_after = openssl::asn1::Asn1Time::days_from_now(365) + .chain_err(|| "failed to create openssl not after asn1")?; + builder.set_not_after(not_after.as_ref()) + .chain_err(|| "failed to set not after on x509")?; + builder.set_pubkey(privkey.as_ref()) + .chain_err(|| "failed to set pubkey for x509")?; + + // Add the SubjectAlternativeName + let extension = openssl::x509::extension::SubjectAlternativeName::new() + .ip(&addr.ip().to_string()) + .build(&builder.x509v3_context(None, None)) + .chain_err(|| "failed to build SAN extension for x509")?; + builder.append_extension(extension) + .chain_err(|| "failed to append SAN extension for x509")?; + + // Finish the certificate + builder.sign(&privkey, openssl::hash::MessageDigest::sha1()) + .chain_err(|| "failed to sign x509 with sha1")?; + let cert: openssl::x509::X509 = builder.build(); + let cert_pem = cert.to_pem() + .chain_err(|| "failed to create pem from x509")?; + let cert_digest = cert.digest(openssl::hash::MessageDigest::sha1()) + .chain_err(|| "failed to create digest of x509 certificate")? 
+ .as_ref().to_owned(); + + Ok((cert_digest, cert_pem, privkey_pem)) + } + + // Messages that are non-sensitive and can be sent to the client + #[derive(Debug)] + pub struct ClientVisibleMsg(String); + impl ClientVisibleMsg { + pub fn from_nonsensitive(s: String) -> Self { ClientVisibleMsg(s) } + } + + pub trait ClientAuthCheck: Send + Sync { + fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg>; + } pub type ServerAuthCheck = Box Option + Send + Sync>; const JWT_KEY_LENGTH: usize = 256 / 8; @@ -253,18 +373,6 @@ mod server { } } - // Based on rouille::Response::json - pub fn bincode_response(content: &T) -> rouille::Response where T: serde::Serialize { - let data = bincode::serialize(content).unwrap(); - - rouille::Response { - status_code: 200, - headers: vec![("Content-Type".into(), "application/octet-stream".into())], - data: rouille::ResponseBody::from_data(data), - upgrade: None, - } - } - // Based on try_or_400 in rouille, but with logging #[derive(Serialize)] pub struct ErrJson<'a> { @@ -278,7 +386,7 @@ mod server { ErrJson { description: err.description(), cause } } fn into_data(self) -> String { - serde_json::to_string(&self).unwrap() + serde_json::to_string(&self).expect("infallible serialization for ErrJson failed") } } macro_rules! try_or_err_and_log { @@ -309,14 +417,17 @@ mod server { macro_rules! try_or_500_log { ($reqid:expr, $result:expr) => { try_or_err_and_log!($reqid, 500, $result) }; } - fn make_401(short_err: &str) -> rouille::Response { + fn make_401_with_body(short_err: &str, body: ClientVisibleMsg) -> rouille::Response { rouille::Response { status_code: 401, headers: vec![("WWW-Authenticate".into(), format!("Bearer error=\"{}\"", short_err).into())], - data: rouille::ResponseBody::empty(), + data: rouille::ResponseBody::from_data(body.0), upgrade: None, } } + fn make_401(short_err: &str) -> rouille::Response { + make_401_with_body(short_err, ClientVisibleMsg(String::new())) + } fn bearer_http_auth(request: &rouille::Request) -> Option<&str> { let header = request.header("Authorization")?; @@ -329,50 +440,105 @@ mod server { split.next() } - macro_rules! try_jwt_or_401 { - ($request:ident, $key:expr, $valid_claims:expr) => {{ - let claims: Result<_> = match bearer_http_auth($request) { - Some(token) => { - jwt::decode(&token, $key, &JWT_VALIDATION) - .map_err(Into::into) - .and_then(|res| { - fn identical_t(_: &T, _: &T) {} - let valid_claims = $valid_claims; - identical_t(&res.claims, &valid_claims); - if res.claims == valid_claims { Ok(()) } else { Err("invalid claims".into()) } - }) - }, - None => Err("no Authorization header".into()), + + // Based on rouille::Response::json + pub fn bincode_response(content: &T) -> rouille::Response where T: serde::Serialize { + let data = bincode::serialize(content).chain_err(|| "Failed to serialize response body"); + let data = try_or_500_log!("bincode body serialization", data); + + rouille::Response { + status_code: 200, + headers: vec![("Content-Type".into(), "application/octet-stream".into())], + data: rouille::ResponseBody::from_data(data), + upgrade: None, + } + } + + // Verification of job auth in a request + macro_rules! 
job_auth_or_401 { + ($request:ident, $job_authorizer:expr, $job_id:expr) => {{ + let verify_result = match bearer_http_auth($request) { + Some(token) => $job_authorizer.verify_token($job_id, token), + None => Err("no Authorization header".to_owned()), }; - match claims { + match verify_result { Ok(()) => (), Err(err) => { + let err = Error::from(err); let json = ErrJson::from_err(&err); - let mut res = make_401("invalid_jwt"); - res.data = rouille::ResponseBody::from_data(json.into_data()); - return res + return make_401_with_body("invalid_jwt", ClientVisibleMsg(json.into_data())) }, } }}; } + // Generation and verification of job auth + struct JWTJobAuthorizer { + server_key: Vec, + } + impl JWTJobAuthorizer { + fn new(server_key: Vec) -> Box { + Box::new(Self { server_key }) + } + } + impl dist::JobAuthorizer for JWTJobAuthorizer { + fn generate_token(&self, job_id: JobId) -> StdResult { + let claims = JobJwt { job_id }; + jwt::encode(&JWT_HEADER, &claims, &self.server_key) + .map_err(|e| format!("Failed to create JWT for job: {}", e)) + } + fn verify_token(&self, job_id: JobId, token: &str) -> StdResult<(), String> { + let valid_claims = JobJwt { job_id }; + jwt::decode(&token, &self.server_key, &JWT_VALIDATION) + .map_err(|e| format!("JWT decode failed: {}", e)) + .and_then(|res| { + fn identical_t(_: &T, _: &T) {} + identical_t(&res.claims, &valid_claims); + if res.claims == valid_claims { Ok(()) } else { Err("mismatched claims".to_owned()) } + }) + } + } + + #[test] + fn test_job_token_verification() { + let ja = JWTJobAuthorizer::new(vec![1,2,2]); + + let job_id = JobId(55); + let token = ja.generate_token(job_id).unwrap(); + + let job_id2 = JobId(56); + let token2 = ja.generate_token(job_id2).unwrap(); + + let ja2 = JWTJobAuthorizer::new(vec![1,2,3]); + + // Check tokens are deterministic + assert_eq!(token, ja.generate_token(job_id).unwrap()); + // Check token verification works + assert!(ja.verify_token(job_id, &token).is_ok()); + assert!(ja.verify_token(job_id, &token2).is_err()); + assert!(ja.verify_token(job_id2, &token).is_err()); + assert!(ja.verify_token(job_id2, &token2).is_ok()); + // Check token verification with a different key fails + assert!(ja2.verify_token(job_id, &token).is_err()); + assert!(ja2.verify_token(job_id2, &token2).is_err()); + } pub struct Scheduler { + public_addr: SocketAddr, handler: S, // Is this client permitted to use the scheduler? - check_client_auth: ClientAuthCheck, + check_client_auth: Box, // Do we believe the server is who they appear to be? check_server_auth: ServerAuthCheck, } impl Scheduler { - pub fn new(handler: S, check_client_auth: ClientAuthCheck, check_server_auth: ServerAuthCheck) -> Self { - Self { handler, check_client_auth, check_server_auth } + pub fn new(public_addr: SocketAddr, handler: S, check_client_auth: Box, check_server_auth: ServerAuthCheck) -> Self { + Self { public_addr, handler, check_client_auth, check_server_auth } } - pub fn start(self) -> ! { - let Self { handler, check_client_auth, check_server_auth } = self; - let requester = SchedulerRequester { client: reqwest::Client::new() }; - let addr = Cfg::scheduler_listen_addr(); + pub fn start(self) -> Result { + let Self { public_addr, handler, check_client_auth, check_server_auth } = self; + let requester = SchedulerRequester { client: Mutex::new(reqwest::Client::new()) }; macro_rules! 
check_server_auth_or_401 { ($request:ident) => {{ @@ -384,20 +550,65 @@ mod server { }; } - info!("Scheduler listening for clients on {}", addr); + fn maybe_update_certs(client: &mut reqwest::Client, certs: &mut HashMap, Vec)>, server_id: ServerId, cert_digest: Vec, cert_pem: Vec) -> Result<()> { + if let Some((saved_cert_digest, _)) = certs.get(&server_id) { + if saved_cert_digest == &cert_digest { + return Ok(()) + } + } + info!("Adding new certificate for {} to scheduler", server_id.addr()); + let mut client_builder = reqwest::ClientBuilder::new(); + // Add all the certificates we know about + client_builder.add_root_certificate(reqwest::Certificate::from_pem(&cert_pem) + .chain_err(|| "failed to interpret pem as certificate")?); + for (_, cert_pem) in certs.values() { + client_builder.add_root_certificate(reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert")); + } + // Finish the clients + let new_client = client_builder.build().chain_err(|| "failed to create a HTTP client")?; + // Use the updated certificates + *client = new_client; + certs.insert(server_id, (cert_digest, cert_pem)); + Ok(()) + } + + info!("Scheduler listening for clients on {}", public_addr); let request_count = atomic::AtomicUsize::new(0); - let server = rouille::Server::new(addr, move |request| { + // From server_id -> cert_digest, cert_pem + let server_certificates: Mutex, Vec)>> = Default::default(); + + let server = rouille::Server::new(public_addr, move |request| { let req_id = request_count.fetch_add(1, atomic::Ordering::SeqCst); trace!("Req {} ({}): {:?}", req_id, request.remote_addr(), request); let response = (|| router!(request, (POST) (/api/v1/scheduler/alloc_job) => { - if !bearer_http_auth(request).map_or(false, &*check_client_auth) { - return make_401("invalid_bearer_token") + let bearer_auth = match bearer_http_auth(request) { + Some(s) => s, + None => return make_401("no_bearer_auth"), + }; + match check_client_auth.check(bearer_auth) { + Ok(()) => (), + Err(client_msg) => { + warn!("Bearer auth failed: {:?}", client_msg); + return make_401_with_body("bearer_auth_failed", client_msg) + }, } let toolchain = try_or_400_log!(req_id, bincode_input(request)); trace!("Req {}: alloc_job: {:?}", req_id, toolchain); - let res: AllocJobResult = try_or_500_log!(req_id, handler.handle_alloc_job(&requester, toolchain)); + let alloc_job_res: AllocJobResult = try_or_500_log!(req_id, handler.handle_alloc_job(&requester, toolchain)); + let certs = server_certificates.lock().unwrap(); + let res = AllocJobHttpResponse::from_alloc_job_result(alloc_job_res, &certs); + bincode_response(&res) + }, + (GET) (/api/v1/scheduler/server_certificate/{server_id: ServerId}) => { + let certs = server_certificates.lock().unwrap(); + let (cert_digest, cert_pem) = try_or_500_log!(req_id, certs.get(&server_id) + .ok_or_else(|| Error::from("server cert not available"))); + let res = ServerCertificateHttpResponse { + cert_digest: cert_digest.clone(), + cert_pem: cert_pem.clone(), + }; bincode_response(&res) }, (POST) (/api/v1/scheduler/heartbeat_server) => { @@ -405,12 +616,18 @@ mod server { let heartbeat_server = try_or_400_log!(req_id, bincode_input(request)); trace!("Req {}: heartbeat_server: {:?}", req_id, heartbeat_server); - let HeartbeatServerHttpRequest { num_cpus, jwt_key } = heartbeat_server; - let generate_job_auth = Box::new(move |job_id| { - let claims = JobJwt { job_id }; - jwt::encode(&JWT_HEADER, &claims, &jwt_key).unwrap() - }); - let res: HeartbeatServerResult = 
handler.handle_heartbeat_server(server_id, num_cpus, generate_job_auth).unwrap(); + let HeartbeatServerHttpRequest { num_cpus, jwt_key, server_nonce, cert_digest, cert_pem } = heartbeat_server; + try_or_500_log!(req_id, maybe_update_certs( + &mut requester.client.lock().unwrap(), + &mut server_certificates.lock().unwrap(), + server_id, cert_digest, cert_pem + )); + let job_authorizer = JWTJobAuthorizer::new(jwt_key); + let res: HeartbeatServerResult = try_or_500_log!(req_id, handler.handle_heartbeat_server( + server_id, server_nonce, + num_cpus, + job_authorizer + )); bincode_response(&res) }, (POST) (/api/v1/scheduler/job_state/{job_id: JobId}) => { @@ -418,71 +635,98 @@ mod server { let job_state = try_or_400_log!(req_id, bincode_input(request)); trace!("Req {}: job state: {:?}", req_id, job_state); - let res: UpdateJobStateResult = handler.handle_update_job_state(job_id, server_id, job_state).unwrap(); + let res: UpdateJobStateResult = try_or_500_log!(req_id, handler.handle_update_job_state( + job_id, server_id, job_state + )); bincode_response(&res) }, (GET) (/api/v1/scheduler/status) => { - let res: StatusResult = handler.handle_status().unwrap(); + let res: SchedulerStatusResult = try_or_500_log!(req_id, handler.handle_status()); bincode_response(&res) }, _ => { warn!("Unknown request {:?}", request); rouille::Response::empty_404() }, - )) (); + ))(); trace!("Res {}: {:?}", req_id, response); response - }).unwrap(); + }).map_err(|e| Error::with_boxed_chain(e, ErrorKind::Msg("Failed to start http server for sccache scheduler".to_owned())))?; server.run(); - unreachable!() + panic!("Rouille server terminated") } } struct SchedulerRequester { - client: reqwest::Client, + client: Mutex, } impl dist::SchedulerOutgoing for SchedulerRequester { fn do_assign_job(&self, server_id: ServerId, job_id: JobId, tc: Toolchain, auth: String) -> Result { - let url = format!("http://{}/api/v1/distserver/assign_job/{}", server_id.addr(), job_id); - bincode_req(self.client.post(&url).bearer_auth(auth).bincode(&tc)?) + let url = urls::server_assign_job(server_id, job_id); + let mut req = self.client.lock().unwrap().post(url); + bincode_req(req.bearer_auth(auth).bincode(&tc)?) 
+ .chain_err(|| "POST to scheduler assign_job failed") } } pub struct Server { - scheduler_addr: SocketAddr, + public_addr: SocketAddr, + scheduler_url: reqwest::Url, scheduler_auth: String, - handler: S, + // HTTPS pieces all the builders will use for connection encryption + cert_digest: Vec, + cert_pem: Vec, + privkey_pem: Vec, + // Key used to sign any requests relating to jobs jwt_key: Vec, + // Randomly generated nonce to allow the scheduler to detect server restarts + server_nonce: ServerNonce, + handler: S, } impl Server { - pub fn new(scheduler_addr: IpAddr, scheduler_auth: String, handler: S) -> Self { + pub fn new(public_addr: SocketAddr, scheduler_url: reqwest::Url, scheduler_auth: String, handler: S) -> Result { + let (cert_digest, cert_pem, privkey_pem) = create_https_cert_and_privkey(public_addr) + .chain_err(|| "failed to create HTTPS certificate for server")?; let mut jwt_key = vec![0; JWT_KEY_LENGTH]; - let mut rng = rand::OsRng::new().unwrap(); + let mut rng = rand::OsRng::new().chain_err(|| "Failed to initialise a random number generator")?; rng.fill_bytes(&mut jwt_key); - Self { - scheduler_addr: Cfg::scheduler_connect_addr(scheduler_addr), + let server_nonce = ServerNonce::from_rng(&mut rng); + + Ok(Self { + public_addr, + scheduler_url, scheduler_auth, + cert_digest, + cert_pem, + privkey_pem, jwt_key, + server_nonce, handler, - } + }) } - pub fn start(self) -> ! { - let Self { scheduler_addr, scheduler_auth, jwt_key, handler } = self; - let requester = ServerRequester { client: reqwest::Client::new(), scheduler_addr, scheduler_auth: scheduler_auth.clone() }; - let addr = Cfg::server_listen_addr(); + pub fn start(self) -> Result { + let Self { public_addr, scheduler_url, scheduler_auth, cert_digest, cert_pem, privkey_pem, jwt_key, server_nonce, handler } = self; + let heartbeat_req = HeartbeatServerHttpRequest { + num_cpus: num_cpus::get(), + jwt_key: jwt_key.clone(), + server_nonce, + cert_digest, + cert_pem: cert_pem.clone(), + }; + let job_authorizer = JWTJobAuthorizer::new(jwt_key); + let heartbeat_url = urls::scheduler_heartbeat_server(&scheduler_url); + let requester = ServerRequester { client: reqwest::Client::new(), scheduler_url, scheduler_auth: scheduler_auth.clone() }; // TODO: detect if this panics - let heartbeat_req = HeartbeatServerHttpRequest { num_cpus: num_cpus::get(), jwt_key: jwt_key.clone() }; thread::spawn(move || { - let url = format!("http://{}:{}/api/v1/scheduler/heartbeat_server", scheduler_addr.ip(), scheduler_addr.port()); let client = reqwest::Client::new(); loop { trace!("Performing heartbeat"); - match bincode_req(client.post(&url).bearer_auth(scheduler_auth.clone()).bincode(&heartbeat_req).unwrap()) { + match bincode_req(client.post(heartbeat_url.clone()).bearer_auth(scheduler_auth.clone()).bincode(&heartbeat_req).expect("failed to serialize heartbeat")) { Ok(HeartbeatServerResult { is_new }) => { trace!("Heartbeat success is_new={}", is_new); // TODO: if is_new, terminate all running jobs @@ -496,37 +740,40 @@ mod server { } }); - info!("Server listening for clients on {}", addr); + info!("Server listening for clients on {}", public_addr); let request_count = atomic::AtomicUsize::new(0); - let server = rouille::Server::new(addr, move |request| { + + let server = rouille::Server::new_ssl(public_addr, move |request| { let req_id = request_count.fetch_add(1, atomic::Ordering::SeqCst); trace!("Req {} ({}): {:?}", req_id, request.remote_addr(), request); let response = (|| router!(request, (POST) (/api/v1/distserver/assign_job/{job_id: 
JobId}) => { - try_jwt_or_401!(request, &jwt_key, JobJwt { job_id }); + job_auth_or_401!(request, &job_authorizer, job_id); let toolchain = try_or_400_log!(req_id, bincode_input(request)); trace!("Req {}: assign_job({}): {:?}", req_id, job_id, toolchain); - let res: AssignJobResult = try_or_500_log!(req_id, handler.handle_assign_job(job_id, toolchain)); + let res: AssignJobResult = try_or_500_log!(req_id, handler.handle_assign_job(&requester, job_id, toolchain)); bincode_response(&res) }, (POST) (/api/v1/distserver/submit_toolchain/{job_id: JobId}) => { - try_jwt_or_401!(request, &jwt_key, JobJwt { job_id }); + job_auth_or_401!(request, &job_authorizer, job_id); trace!("Req {}: submit_toolchain({})", req_id, job_id); - let mut body = request.data().unwrap(); + let mut body = request.data().expect("body was already read in submit_toolchain"); let toolchain_rdr = ToolchainReader(Box::new(body)); let res: SubmitToolchainResult = try_or_500_log!(req_id, handler.handle_submit_toolchain(&requester, job_id, toolchain_rdr)); bincode_response(&res) }, (POST) (/api/v1/distserver/run_job/{job_id: JobId}) => { - try_jwt_or_401!(request, &jwt_key, JobJwt { job_id }); + job_auth_or_401!(request, &job_authorizer, job_id); - let mut body = request.data().unwrap(); - let bincode_length = body.read_u32::().unwrap() as u64; + let mut body = request.data().expect("body was already read in run_job"); + let bincode_length = try_or_500_log!(req_id, body.read_u32::() + .chain_err(|| "failed to read run job input length")) as u64; let mut bincode_reader = body.take(bincode_length); - let runjob = bincode::deserialize_from(&mut bincode_reader).unwrap(); + let runjob = try_or_500_log!(req_id, bincode::deserialize_from(&mut bincode_reader) + .chain_err(|| "failed to deserialize run job request")); trace!("Req {}: run_job({}): {:?}", req_id, job_id, runjob); let RunJobHttpRequest { command, outputs } = runjob; let body = bincode_reader.into_inner(); @@ -543,23 +790,24 @@ mod server { ))(); trace!("Res {}: {:?}", req_id, response); response - }).unwrap(); + }, cert_pem, privkey_pem).map_err(|e| Error::with_boxed_chain(e, ErrorKind::Msg("Failed to start http server for sccache server".to_owned())))?; server.run(); - unreachable!() + panic!("Rouille server terminated") } } struct ServerRequester { client: reqwest::Client, - scheduler_addr: SocketAddr, + scheduler_url: reqwest::Url, scheduler_auth: String, } impl dist::ServerOutgoing for ServerRequester { fn do_update_job_state(&self, job_id: JobId, state: JobState) -> Result { - let url = format!("http://{}/api/v1/scheduler/job_state/{}", self.scheduler_addr, job_id); - bincode_req(self.client.post(&url).bearer_auth(self.scheduler_auth.clone()).bincode(&state)?) + let url = urls::scheduler_job_state(&self.scheduler_url, job_id); + bincode_req(self.client.post(url).bearer_auth(self.scheduler_auth.clone()).bincode(&state)?) 
+ .chain_err(|| "POST to scheduler job_state failed") } } } @@ -573,13 +821,13 @@ mod client { use dist::pkg::{InputsPackager, ToolchainPackager}; use flate2::Compression; use flate2::write::ZlibEncoder as ZlibWriteEncoder; - use futures::{Future, Stream}; + use futures::Future; use futures_cpupool::CpuPool; use reqwest; - use std::fs; + use std::collections::HashMap; use std::io::Write; - use std::net::{IpAddr, SocketAddr}; use std::path::Path; + use std::sync::{Arc, Mutex}; use std::time::Duration; use super::super::cache; use tokio_core; @@ -590,80 +838,153 @@ mod client { AllocJobResult, JobAlloc, RunJobResult, SubmitToolchainResult, }; use super::common::{ - Cfg, ReqwestRequestBuilderExt, bincode_req, bincode_req_fut, + AllocJobHttpResponse, + ServerCertificateHttpResponse, RunJobHttpRequest, }; + use super::urls; use errors::*; const REQUEST_TIMEOUT_SECS: u64 = 600; pub struct Client { auth_token: String, - scheduler_addr: SocketAddr, + scheduler_url: reqwest::Url, + // cert_digest -> cert_pem + server_certs: Arc, Vec>>>, // TODO: this should really only use the async client, but reqwest async bodies are extremely limited // and only support owned bytes, which means the whole toolchain would end up in memory - client: reqwest::Client, - client_async: reqwest::unstable::async::Client, + client: Arc>, + client_async: Arc>, + handle: tokio_core::reactor::Handle, pool: CpuPool, - tc_cache: cache::ClientToolchains, + tc_cache: Arc, } impl Client { - pub fn new(handle: &tokio_core::reactor::Handle, pool: &CpuPool, scheduler_addr: IpAddr, cache_dir: &Path, cache_size: u64, toolchain_configs: &[config::DistToolchainConfig], auth_token: String) -> Self { + pub fn new(handle: tokio_core::reactor::Handle, pool: &CpuPool, scheduler_url: reqwest::Url, cache_dir: &Path, cache_size: u64, toolchain_configs: &[config::DistToolchainConfig], auth_token: String) -> Result { let timeout = Duration::new(REQUEST_TIMEOUT_SECS, 0); - let client = reqwest::ClientBuilder::new().timeout(timeout).build().unwrap(); - let client_async = reqwest::unstable::async::ClientBuilder::new().timeout(timeout).build(handle).unwrap(); - Self { + let client = reqwest::ClientBuilder::new().timeout(timeout).build() + .chain_err(|| "failed to create a HTTP client")?; + let client_async = reqwest::unstable::async::ClientBuilder::new().timeout(timeout).build(&handle) + .chain_err(|| "failed to create an async HTTP client")?; + let client_toolchains = cache::ClientToolchains::new(cache_dir, cache_size, toolchain_configs) + .chain_err(|| "failed to initialise client toolchains")?; + Ok(Self { auth_token, - scheduler_addr: Cfg::scheduler_connect_addr(scheduler_addr), - client, - client_async, + scheduler_url, + server_certs: Default::default(), + client: Arc::new(Mutex::new(client)), + client_async: Arc::new(Mutex::new(client_async)), + handle, pool: pool.clone(), - tc_cache: cache::ClientToolchains::new(cache_dir, cache_size, toolchain_configs), + tc_cache: Arc::new(client_toolchains), + }) + } + + fn update_certs(client: &mut reqwest::Client, client_async: &mut reqwest::unstable::async::Client, handle: tokio_core::reactor::Handle, certs: &mut HashMap, Vec>, cert_digest: Vec, cert_pem: Vec) -> Result<()> { + let mut client_builder = reqwest::ClientBuilder::new(); + let mut client_async_builder = reqwest::unstable::async::ClientBuilder::new(); + // Add all the certificates we know about + client_builder.add_root_certificate(reqwest::Certificate::from_pem(&cert_pem) + .chain_err(|| "failed to interpret pem as certificate")?); + 
client_async_builder.add_root_certificate(reqwest::Certificate::from_pem(&cert_pem) + .chain_err(|| "failed to interpret pem as certificate")?); + for cert_pem in certs.values() { + client_builder.add_root_certificate(reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert")); + client_async_builder.add_root_certificate(reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert")); } + // Finish the clients + let timeout = Duration::new(REQUEST_TIMEOUT_SECS, 0); + let new_client = client_builder.timeout(timeout).build() + .chain_err(|| "failed to create a HTTP client")?; + let new_client_async = client_async_builder.timeout(timeout).build(&handle) + .chain_err(|| "failed to create an async HTTP client")?; + // Use the updated certificates + *client = new_client; + *client_async = new_client_async; + certs.insert(cert_digest, cert_pem); + Ok(()) } } impl dist::Client for Client { fn do_alloc_job(&self, tc: Toolchain) -> SFuture { - let url = format!("http://{}/api/v1/scheduler/alloc_job", self.scheduler_addr); - Box::new(f_res(self.client_async.post(&url).bearer_auth(self.auth_token.clone()).bincode(&tc).map(bincode_req_fut)).and_then(|r| r)) + let scheduler_url = self.scheduler_url.clone(); + let url = urls::scheduler_alloc_job(&scheduler_url); + let mut req = self.client_async.lock().unwrap().post(url); + ftry!(req.bearer_auth(self.auth_token.clone()).bincode(&tc)); + + let client = self.client.clone(); + let client_async = self.client_async.clone(); + let handle = self.handle.clone(); + let server_certs = self.server_certs.clone(); + Box::new(bincode_req_fut(&mut req).map_err(|e| e.chain_err(|| "POST to scheduler alloc_job failed")).and_then(move |res| { + match res { + AllocJobHttpResponse::Success { job_alloc, need_toolchain, cert_digest } => { + let server_id = job_alloc.server_id; + let alloc_job_res = f_ok(AllocJobResult::Success { job_alloc, need_toolchain }); + if server_certs.lock().unwrap().contains_key(&cert_digest) { + return alloc_job_res + } + info!("Need to request new certificate for server {}", server_id.addr()); + let url = urls::scheduler_server_certificate(&scheduler_url, server_id); + let mut req = client_async.lock().unwrap().get(url); + Box::new(bincode_req_fut(&mut req).map_err(|e| e.chain_err(|| "GET to scheduler server_certificate failed")) + .and_then(move |res: ServerCertificateHttpResponse| { + ftry!(Self::update_certs( + &mut client.lock().unwrap(), &mut client_async.lock().unwrap(), + handle, + &mut server_certs.lock().unwrap(), + res.cert_digest, res.cert_pem, + )); + alloc_job_res + })) + }, + AllocJobHttpResponse::Fail { msg } => { + f_ok(AllocJobResult::Fail { msg }) + }, + } + })) } fn do_submit_toolchain(&self, job_alloc: JobAlloc, tc: Toolchain) -> SFuture { - if let Some(toolchain_file) = self.tc_cache.get_toolchain(&tc) { - let url = format!("http://{}/api/v1/distserver/submit_toolchain/{}", job_alloc.server_id.addr(), job_alloc.job_id); - let mut req = self.client.post(&url); + match self.tc_cache.get_toolchain(&tc) { + Ok(Some(toolchain_file)) => { + let url = urls::server_submit_toolchain(job_alloc.server_id, job_alloc.job_id); + let mut req = self.client.lock().unwrap().post(url); - Box::new(self.pool.spawn_fn(move || { - req.bearer_auth(job_alloc.auth.clone()).body(toolchain_file); - bincode_req(&mut req) - })) - } else { - f_err("couldn't find toolchain locally") + Box::new(self.pool.spawn_fn(move || { + req.bearer_auth(job_alloc.auth.clone()).body(toolchain_file); + bincode_req(&mut req) + })) + }, + Ok(None) => 
f_err("couldn't find toolchain locally"), + Err(e) => f_err(e), } } fn do_run_job(&self, job_alloc: JobAlloc, command: CompileCommand, outputs: Vec, inputs_packager: Box) -> SFuture<(RunJobResult, PathTransformer)> { - let url = format!("http://{}/api/v1/distserver/run_job/{}", job_alloc.server_id.addr(), job_alloc.job_id); - let mut req = self.client.post(&url); + let url = urls::server_run_job(job_alloc.server_id, job_alloc.job_id); + let mut req = self.client.lock().unwrap().post(url); Box::new(self.pool.spawn_fn(move || { - let bincode = bincode::serialize(&RunJobHttpRequest { command, outputs }).unwrap(); + let bincode = bincode::serialize(&RunJobHttpRequest { command, outputs }) + .chain_err(|| "failed to serialize run job request")?; let bincode_length = bincode.len(); let mut body = vec![]; - body.write_u32::(bincode_length as u32).unwrap(); - body.write(&bincode).unwrap(); + body.write_u32::(bincode_length as u32).expect("Infallible write of bincode length to vec failed"); + body.write(&bincode).expect("Infallible write of bincode body to vec failed"); let path_transformer; { let mut compressor = ZlibWriteEncoder::new(&mut body, Compression::fast()); path_transformer = inputs_packager.write_inputs(&mut compressor).chain_err(|| "Could not write inputs for compilation")?; - compressor.flush().unwrap(); + compressor.flush().chain_err(|| "failed to flush compressor")?; trace!("Compressed inputs from {} -> {}", compressor.total_in(), compressor.total_out()); - compressor.finish().unwrap(); + compressor.finish().chain_err(|| "failed to finish compressor")?; } req.bearer_auth(job_alloc.auth.clone()).bytes(body); @@ -671,11 +992,11 @@ mod client { })) } - fn put_toolchain(&self, compiler_path: &Path, weak_key: &str, toolchain_packager: Box) -> Result<(Toolchain, Option)> { - self.tc_cache.put_toolchain(compiler_path, weak_key, toolchain_packager) - } - fn may_dist(&self) -> bool { - true + fn put_toolchain(&self, compiler_path: &Path, weak_key: &str, toolchain_packager: Box) -> SFuture<(Toolchain, Option)> { + let compiler_path = compiler_path.to_owned(); + let weak_key = weak_key.to_owned(); + let tc_cache = self.tc_cache.clone(); + Box::new(self.pool.spawn_fn(move || tc_cache.put_toolchain(&compiler_path, &weak_key, toolchain_packager))) } } } diff --git a/src/dist/mod.rs b/src/dist/mod.rs index af2c5f6d..caab7f98 100644 --- a/src/dist/mod.rs +++ b/src/dist/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. 
use compiler; +use rand::{self, RngCore}; use std::fmt; use std::io::{self, Read}; use std::net::SocketAddr; @@ -56,17 +57,17 @@ mod path_transform { use std::path::{Component, Components, Path, PathBuf, Prefix, PrefixComponent}; use std::str; - fn take_prefix<'a>(components: &'a mut Components) -> PrefixComponent<'a> { - let prefix = components.next().unwrap(); + fn take_prefix<'a>(components: &'a mut Components) -> Option> { + let prefix = components.next()?; let pc = match prefix { Component::Prefix(pc) => pc, - _ => panic!("unrecognised start to path: {:?}", prefix), + _ => return None, }; - let root = components.next().unwrap(); + let root = components.next()?; if root != Component::RootDir { - panic!("unexpected non-root component in path starting {:?}", prefix) + return None } - pc + Some(pc) } fn transform_prefix_component(pc: PrefixComponent) -> Option { @@ -78,7 +79,7 @@ mod path_transform { Prefix::VerbatimDisk(diskchar) => { assert!(diskchar.is_ascii_alphabetic()); let diskchar = diskchar.to_ascii_uppercase(); - Some(format!("/prefix/disk-{}", str::from_utf8(&[diskchar]).unwrap())) + Some(format!("/prefix/disk-{}", str::from_utf8(&[diskchar]).expect("invalid disk char"))) }, Prefix::Verbatim(_) | Prefix::VerbatimUNC(_, _) | @@ -98,20 +99,23 @@ mod path_transform { dist_to_local_path: HashMap::new(), } } - pub fn to_dist_assert_abs(&mut self, p: &Path) -> Option { - if !p.is_absolute() { panic!("non absolute path {:?}", p) } + pub fn to_dist_abs(&mut self, p: &Path) -> Option { + if !p.is_absolute() { return None } self.to_dist(p) } pub fn to_dist(&mut self, p: &Path) -> Option { let mut components = p.components(); + // Extract the prefix (e.g. "C:/") if present let maybe_dist_prefix = if p.is_absolute() { - let pc = take_prefix(&mut components); + let pc = take_prefix(&mut components) + .expect("could not take prefix from absolute path"); Some(transform_prefix_component(pc)?) 
} else { None }; + // Reconstruct the path (minus the prefix) as a Linux path let mut dist_suffix = String::new(); for component in components { let part = match component { @@ -152,7 +156,8 @@ mod path_transform { continue } let mut components = local_path.components(); - let mut local_prefix = take_prefix(&mut components); + let mut local_prefix = take_prefix(&mut components) + .expect("could not take prefix from absolute path"); let local_prefix_component = Component::Prefix(local_prefix); let local_prefix_path: &Path = local_prefix_component.as_ref(); let mappings = if let Prefix::VerbatimDisk(_) = local_prefix.kind() { @@ -163,17 +168,73 @@ mod path_transform { if mappings.contains_key(local_prefix_path) { continue } - let dist_prefix = transform_prefix_component(local_prefix).unwrap(); + let dist_prefix = transform_prefix_component(local_prefix) + .expect("prefix already in tracking map could not be transformed"); mappings.insert(local_prefix_path.to_owned(), dist_prefix); } // Prioritise normal mappings for the same disk, as verbatim mappings can // look odd to users normal_mappings.into_iter().chain(verbatim_mappings) } - pub fn to_local(&self, p: &str) -> PathBuf { - self.dist_to_local_path.get(p).unwrap().clone() + pub fn to_local(&self, p: &str) -> Option { + self.dist_to_local_path.get(p).cloned() } } + + #[test] + fn test_basic() { + let mut pt = PathTransformer::new(); + assert_eq!(pt.to_dist(Path::new("C:/a")).unwrap(), "/prefix/disk-C/a"); + assert_eq!(pt.to_dist(Path::new(r#"C:\a\b.c"#)).unwrap(), "/prefix/disk-C/a/b.c"); + assert_eq!(pt.to_dist(Path::new("X:/other.c")).unwrap(), "/prefix/disk-X/other.c"); + let mut disk_mappings: Vec<_> = pt.disk_mappings().collect(); + disk_mappings.sort(); + assert_eq!( + disk_mappings, + &[ + (Path::new("C:").into(), "/prefix/disk-C".into()), + (Path::new("X:").into(), "/prefix/disk-X".into()), + ] + ); + assert_eq!(pt.to_local("/prefix/disk-C/a").unwrap(), Path::new("C:/a")); + assert_eq!(pt.to_local("/prefix/disk-C/a/b.c").unwrap(), Path::new("C:/a/b.c")); + assert_eq!(pt.to_local("/prefix/disk-X/other.c").unwrap(), Path::new("X:/other.c")); + } + + #[test] + fn test_relative_paths() { + let mut pt = PathTransformer::new(); + assert_eq!(pt.to_dist(Path::new("a/b")).unwrap(), "a/b"); + assert_eq!(pt.to_dist(Path::new(r#"a\b"#)).unwrap(), "a/b"); + assert_eq!(pt.to_local("a/b").unwrap(), Path::new("a/b")); + } + + #[test] + fn test_verbatim_disks() { + let mut pt = PathTransformer::new(); + assert_eq!(pt.to_dist(Path::new("X:/other.c")).unwrap(), "/prefix/disk-X/other.c"); + pt.to_dist(Path::new(r#"\\?\X:\out\other.o"#)); + assert_eq!(pt.to_local("/prefix/disk-X/other.c").unwrap(), Path::new("X:/other.c")); + assert_eq!(pt.to_local("/prefix/disk-X/out/other.o").unwrap(), Path::new(r#"\\?\X:\out\other.o"#)); + let disk_mappings: Vec<_> = pt.disk_mappings().collect(); + // Verbatim disks should come last + assert_eq!( + disk_mappings, + &[ + (Path::new("X:").into(), "/prefix/disk-X".into()), + (Path::new(r#"\\?\X:"#).into(), "/prefix/disk-X".into()), + ] + ); + } + + #[test] + fn test_slash_directions() { + let mut pt = PathTransformer::new(); + assert_eq!(pt.to_dist(Path::new("C:/a")).unwrap(), "/prefix/disk-C/a"); + assert_eq!(pt.to_dist(Path::new("C:\\a")).unwrap(), "/prefix/disk-C/a"); + assert_eq!(pt.to_local("/prefix/disk-C/a").unwrap(), Path::new("C:/a")); + assert_eq!(pt.disk_mappings().count(), 1); + } } #[cfg(unix)] @@ -186,8 +247,8 @@ mod path_transform { impl PathTransformer { pub fn new() -> Self { PathTransformer } - pub fn 
to_dist_assert_abs(&mut self, p: &Path) -> Option { - if !p.is_absolute() { panic!("non absolute path {:?}", p) } + pub fn to_dist_abs(&mut self, p: &Path) -> Option { + if !p.is_absolute() { return None } self.to_dist(p) } pub fn to_dist(&mut self, p: &Path) -> Option { @@ -196,8 +257,8 @@ mod path_transform { pub fn disk_mappings(&self) -> impl Iterator { iter::empty() } - pub fn to_local(&self, p: &str) -> PathBuf { - PathBuf::from(p) + pub fn to_local(&self, p: &str) -> Option { + Some(PathBuf::from(p)) } } } @@ -257,12 +318,31 @@ impl FromStr for JobId { #[derive(Hash, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] -pub struct ServerId(pub SocketAddr); +pub struct ServerId(SocketAddr); impl ServerId { + pub fn new(addr: SocketAddr) -> Self { + ServerId(addr) + } pub fn addr(&self) -> SocketAddr { self.0 } } +impl FromStr for ServerId { + type Err = ::Err; + fn from_str(s: &str) -> ::std::result::Result { + SocketAddr::from_str(s).map(ServerId) + } +} +#[derive(Eq, PartialEq)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ServerNonce(u64); +impl ServerNonce { + pub fn from_rng(rng: &mut rand::OsRng) -> Self { + ServerNonce(rng.next_u64()) + } +} + #[derive(Hash, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -294,17 +374,30 @@ pub struct CompileCommand { pub cwd: String, } -// process::Output is not serialize +// process::Output is not serialize so we have a custom Output type. However, +// we cannot encode all information in here, such as Unix signals, as the other +// end may not understand them (e.g. if it's Windows) #[derive(Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct ProcessOutput { - code: Option, // TODO: extract the extra info from the UnixCommandExt + code: i32, stdout: Vec, stderr: Vec, } -impl From for ProcessOutput { - fn from(o: process::Output) -> Self { - ProcessOutput { code: o.status.code(), stdout: o.stdout, stderr: o.stderr } +impl ProcessOutput { + #[cfg(unix)] + pub fn try_from(o: process::Output) -> Result { + let process::Output { status, stdout, stderr } = o; + let code = match (status.code(), status.signal()) { + (Some(c), _) => c, + (None, Some(s)) => bail!("Process status {} terminated with signal {}", status, s), + (None, None) => bail!("Process status {} has no exit code or signal", status), + }; + Ok(ProcessOutput { code, stdout, stderr }) + } + #[cfg(test)] + pub fn fake_output(code: i32, stdout: Vec, stderr: Vec) -> Self { + Self { code, stdout, stderr } } } #[cfg(unix)] @@ -317,13 +410,14 @@ fn exit_status(code: i32) -> process::ExitStatus { } #[cfg(windows)] fn exit_status(code: i32) -> process::ExitStatus { - // TODO: this is probably a subideal conversion + // TODO: this is probably a subideal conversion - it's not clear how Unix exit codes map to + // Windows exit codes (other than 0 being a success) process::ExitStatus::from_raw(code as u32) } impl From for process::Output { fn from(o: ProcessOutput) -> Self { // TODO: handle signals, i.e. 
None code - process::Output { status: exit_status(o.code.unwrap()), stdout: o.stdout, stderr: o.stderr } + process::Output { status: exit_status(o.code), stdout: o.stdout, stderr: o.stderr } } } @@ -331,14 +425,14 @@ impl From for process::Output { #[serde(deny_unknown_fields)] pub struct OutputData(Vec, u64); impl OutputData { - #[cfg(feature = "dist-server")] - pub fn from_reader(r: R) -> Self { + #[cfg(any(feature = "dist-server", all(feature = "dist-client", test)))] + pub fn try_from_reader(r: R) -> io::Result { use flate2::Compression; use flate2::read::ZlibEncoder as ZlibReadEncoder; let mut compressor = ZlibReadEncoder::new(r, Compression::fast()); let mut res = vec![]; - io::copy(&mut compressor, &mut res).unwrap(); - OutputData(res, compressor.total_in()) + io::copy(&mut compressor, &mut res)?; + Ok(OutputData(res, compressor.total_in())) } pub fn lens(&self) -> OutputDataLens { OutputDataLens { actual: self.1, compressed: self.0.len() as u64 } @@ -366,7 +460,7 @@ impl fmt::Display for OutputDataLens { // AllocJob -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct JobAlloc { pub auth: String, @@ -422,9 +516,9 @@ pub struct JobComplete { // Status -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] -pub struct StatusResult { +pub struct SchedulerStatusResult { pub num_servers: usize, } @@ -479,24 +573,31 @@ pub trait ServerOutgoing { fn do_update_job_state(&self, job_id: JobId, state: JobState) -> Result; } +// Trait to handle the creation and verification of job authorization tokens +#[cfg(feature = "dist-server")] +pub trait JobAuthorizer: Send { + fn generate_token(&self, job_id: JobId) -> ExtResult; + fn verify_token(&self, job_id: JobId, token: &str) -> ExtResult<(), String>; +} + #[cfg(feature = "dist-server")] pub trait SchedulerIncoming: Send + Sync { type Error: ::std::error::Error; // From Client fn handle_alloc_job(&self, requester: &SchedulerOutgoing, tc: Toolchain) -> ExtResult; // From Server - fn handle_heartbeat_server(&self, server_id: ServerId, num_cpus: usize, generate_job_auth: Box String + Send>) -> ExtResult; + fn handle_heartbeat_server(&self, server_id: ServerId, server_nonce: ServerNonce, num_cpus: usize, job_authorizer: Box) -> ExtResult; // From Server fn handle_update_job_state(&self, job_id: JobId, server_id: ServerId, job_state: JobState) -> ExtResult; // From anyone - fn handle_status(&self) -> ExtResult; + fn handle_status(&self) -> ExtResult; } #[cfg(feature = "dist-server")] pub trait ServerIncoming: Send + Sync { type Error: ::std::error::Error; // From Scheduler - fn handle_assign_job(&self, job_id: JobId, tc: Toolchain) -> ExtResult; + fn handle_assign_job(&self, requester: &ServerOutgoing, job_id: JobId, tc: Toolchain) -> ExtResult; // From Client fn handle_submit_toolchain(&self, requester: &ServerOutgoing, job_id: JobId, tc_rdr: ToolchainReader) -> ExtResult; // From Client @@ -519,29 +620,5 @@ pub trait Client { fn do_submit_toolchain(&self, job_alloc: JobAlloc, tc: Toolchain) -> SFuture; // To Server fn do_run_job(&self, job_alloc: JobAlloc, command: CompileCommand, outputs: Vec, inputs_packager: Box) -> SFuture<(RunJobResult, PathTransformer)>; - fn put_toolchain(&self, compiler_path: &Path, weak_key: &str, toolchain_packager: Box) -> Result<(Toolchain, Option)>; - fn may_dist(&self) -> bool; -} - -///////// - -pub struct NoopClient; - -impl Client for NoopClient { - fn do_alloc_job(&self, 
_tc: Toolchain) -> SFuture { - f_ok(AllocJobResult::Fail { msg: "Using NoopClient".to_string() }) - } - fn do_submit_toolchain(&self, _job_alloc: JobAlloc, _tc: Toolchain) -> SFuture { - panic!("NoopClient"); - } - fn do_run_job(&self, _job_alloc: JobAlloc, _command: CompileCommand, _outputs: Vec, _inputs_packager: Box) -> SFuture<(RunJobResult, PathTransformer)> { - panic!("NoopClient"); - } - - fn put_toolchain(&self, _compiler_path: &Path, _weak_key: &str, _toolchain_packager: Box) -> Result<(Toolchain, Option)> { - bail!("NoopClient"); - } - fn may_dist(&self) -> bool { - false - } + fn put_toolchain(&self, compiler_path: &Path, weak_key: &str, toolchain_packager: Box) -> SFuture<(Toolchain, Option)>; } diff --git a/src/dist/pkg.rs b/src/dist/pkg.rs index a2a2ce0a..5aa8dded 100644 --- a/src/dist/pkg.rs +++ b/src/dist/pkg.rs @@ -23,7 +23,7 @@ use errors::*; pub use self::toolchain_imp::*; -pub trait ToolchainPackager { +pub trait ToolchainPackager: Send { fn write_pkg(self: Box, f: fs::File) -> Result<()>; } @@ -44,7 +44,7 @@ mod toolchain_imp { // Distributed client, but an unsupported platform for toolchain packaging so // create a failing implementation that will conflict with any others. - impl ToolchainPackager for T { + impl ToolchainPackager for T { fn write_pkg(self: Box, _f: fs::File) -> Result<()> { bail!("Automatic packaging not supported on this platform") } @@ -54,7 +54,7 @@ mod toolchain_imp { #[cfg(all(target_os = "linux", target_arch = "x86_64"))] mod toolchain_imp { use std::collections::BTreeMap; - use std::io::Write; + use std::io::{Read, Write}; use std::fs; use std::path::{Component, Path, PathBuf}; use std::process; @@ -88,7 +88,9 @@ mod toolchain_imp { if self.file_set.contains_key(&tar_path) { continue } - remaining.extend(find_ldd_libraries(&obj_path)?); + let ldd_libraries = find_ldd_libraries(&obj_path) + .chain_err(|| format!("Failed to analyse {} with ldd", obj_path.display()))?; + remaining.extend(ldd_libraries); self.file_set.insert(tar_path, obj_path); } Ok(()) @@ -146,13 +148,43 @@ mod toolchain_imp { // libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f6877711000) // /lib64/ld-linux-x86-64.so.2 (0x00007f6878171000) // libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f68774f4000) + // + // Elf executables can be statically or dynamically linked, and position independant (PIE) or not: + // - dynamic + PIE = ET_DYN, ldd stdouts something like the list above and exits with code 0 + // - dynamic + non-PIE = ET_EXEC, ldd stdouts something like the list above and exits with code 0 + // - static + PIE = ET_DYN, ldd stdouts something like "\tstatically linked" or + // "\tldd (0x7f79ef662000)" and exits with code 0 + // - static + non-PIE = ET_EXEC, ldd stderrs something like "\tnot a dynamic executable" or + // "ldd: a.out: Not a valid dynamic program" and exits with code 1 + // fn find_ldd_libraries(executable: &Path) -> Result> { let process::Output { status, stdout, stderr } = process::Command::new("ldd").arg(executable).output()?; - // Not a file ldd understands + // Not a file ldd can handle. 
This can be a non-executable, or a static non-PIE if !status.success() { - bail!(format!("ldd failed to run on {}", executable.to_string_lossy())) + // Best-effort detection of static non-PIE + let mut elf = fs::File::open(executable)?; + let mut elf_bytes = [0; 0x12]; + elf.read_exact(&mut elf_bytes)?; + if elf_bytes[..0x4] != [0x7f, 0x45, 0x4c, 0x46] { + bail!("Elf magic not found") + } + let little_endian = match elf_bytes[0x5] { + 1 => true, + 2 => false, + _ => bail!("Invalid endianness in elf header"), + }; + let e_type = if little_endian { + (elf_bytes[0x11] as u16) << 8 | elf_bytes[0x10] as u16 + } else { + (elf_bytes[0x10] as u16) << 8 | elf_bytes[0x11] as u16 + }; + if e_type != 0x02 { + bail!("ldd failed on a non-ET_EXEC elf") + } + // It appears to be an ET_EXEC, good enough for us + return Ok(vec![]) } if !stderr.is_empty() { @@ -160,9 +192,12 @@ mod toolchain_imp { } let stdout = str::from_utf8(&stdout).map_err(|_| "ldd output not utf8")?; + Ok(parse_ldd_output(stdout)) + } - // If it's static the output will be a line like "not a dynamic executable", so be forgiving - // in the parsing here and treat parsing oddities as an empty list. + // If it's a static PIE the output will be a line like "\tstatically linked", so be forgiving + // in the parsing here and treat parsing oddities as an empty list. + fn parse_ldd_output(stdout: &str) -> Vec { let mut libs = vec![]; for line in stdout.lines() { let line = line.trim(); @@ -195,7 +230,38 @@ mod toolchain_imp { libs.push(libpath) } - Ok(libs) + libs + } + + #[test] + fn test_ldd_parse() { + let ubuntu_ls_output = "\tlinux-vdso.so.1 => (0x00007fffcfffe000) +\tlibselinux.so.1 => /lib/x86_64-linux-gnu/libselinux.so.1 (0x00007f69caa6b000) +\tlibc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f69ca6a1000) +\tlibpcre.so.3 => /lib/x86_64-linux-gnu/libpcre.so.3 (0x00007f69ca431000) +\tlibdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f69ca22d000) +\t/lib64/ld-linux-x86-64.so.2 (0x00007f69cac8d000) +\tlibpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f69ca010000) +"; + assert_eq!(parse_ldd_output(ubuntu_ls_output).iter().map(|p| p.to_str().unwrap()).collect::>(), &[ + "/lib/x86_64-linux-gnu/libselinux.so.1", + "/lib/x86_64-linux-gnu/libc.so.6", + "/lib/x86_64-linux-gnu/libpcre.so.3", + "/lib/x86_64-linux-gnu/libdl.so.2", + "/lib64/ld-linux-x86-64.so.2", + "/lib/x86_64-linux-gnu/libpthread.so.0", + ]) + } + + #[test] + fn test_ldd_parse_static() { + let static_outputs = &[ + "\tstatically linked", // glibc ldd output + "\tldd (0x7f79ef662000)", // musl ldd output + ]; + for static_output in static_outputs { + assert_eq!(parse_ldd_output(static_output).len(), 0) + } } } diff --git a/src/lib.rs b/src/lib.rs index 85b217b1..556b55b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,8 @@ extern crate toml; #[cfg(any(feature = "azure", feature = "gcs", feature = "dist-client"))] extern crate url; extern crate uuid; +#[cfg(feature = "void")] +extern crate void; extern crate walkdir; extern crate which; #[cfg(windows)] @@ -145,6 +147,9 @@ pub fn main() { }, Err(e) => { println!("sccache: {}", e); + for e in e.iter().skip(1) { + println!("caused by: {}", e); + } cmdline::get_app().print_help().unwrap(); println!(""); 1 diff --git a/src/server.rs b/src/server.rs index f6a064a1..58dbb38e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -22,9 +22,12 @@ use compiler::{ CompilerArguments, CompilerHasher, CompileResult, + DistType, MissType, get_compiler_info, }; +#[cfg(feature = "dist-client")] +use config; use 
config::Config; use dist; use filetime::FileTime; @@ -46,11 +49,17 @@ use std::env; use std::ffi::{OsStr, OsString}; use std::fs::metadata; use std::io::{self, Write}; +#[cfg(feature = "dist-client")] +use std::mem; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr}; use std::path::PathBuf; use std::process::{Output, ExitStatus}; use std::rc::Rc; use std::sync::Arc; +#[cfg(feature = "dist-client")] +use std::sync::Mutex; +#[cfg(feature = "dist-client")] +use std::time::Instant; use std::time::Duration; use std::u64; use tokio_core::net::TcpListener; @@ -69,6 +78,11 @@ use errors::*; /// If the server is idle for this many seconds, shut down. const DEFAULT_IDLE_TIMEOUT: u64 = 600; +/// If the dist client couldn't be created, retry creation at this number +/// of seconds from now (or later) +#[cfg(feature = "dist-client")] +const DIST_CLIENT_RECREATE_TIMEOUT: Duration = Duration::from_secs(30); + /// Result of background server startup. #[derive(Debug, Serialize, Deserialize)] pub enum ServerStartup { @@ -127,6 +141,159 @@ fn get_signal(_status: ExitStatus) -> i32 { panic!("no signals on windows") } +pub struct DistClientContainer { + // The actual dist client state + #[cfg(feature = "dist-client")] + state: Mutex, +} + +#[cfg(feature = "dist-client")] +struct DistClientConfig { + // Reusable items tied to an SccacheServer instance + handle: Handle, + pool: CpuPool, + + // From the static dist configuration + scheduler_url: Option, + auth: config::DistAuth, + cache_dir: PathBuf, + toolchain_cache_size: u64, + toolchains: Vec, +} + +#[cfg(feature = "dist-client")] +enum DistClientState { + #[cfg(feature = "dist-client")] + Some(Arc), + #[cfg(feature = "dist-client")] + RetryCreateAt(DistClientConfig, Instant), + Disabled, +} + +#[cfg(not(feature = "dist-client"))] +impl DistClientContainer { + #[cfg(not(feature = "dist-client"))] + fn new(config: &Config, _: &CpuPool, _: Handle) -> Self { + if let Some(_) = config.dist.scheduler_url { + warn!("Scheduler address configured but dist feature disabled, disabling distributed sccache") + } + Self {} + } + + pub fn new_disabled() -> Self { + Self {} + } + + + fn get_client(&self) -> Option> { + None + } +} + +#[cfg(feature = "dist-client")] +impl DistClientContainer { + fn new(config: &Config, pool: &CpuPool, handle: Handle) -> Self { + let config = DistClientConfig { + handle, + pool: pool.clone(), + + scheduler_url: config.dist.scheduler_url.clone(), + auth: config.dist.auth.clone(), + cache_dir: config.dist.cache_dir.clone(), + toolchain_cache_size: config.dist.toolchain_cache_size, + toolchains: config.dist.toolchains.clone(), + }; + let state = Self::create_state(config); + Self { state: Mutex::new(state) } + } + + pub fn new_disabled() -> Self { + Self { state: Mutex::new(DistClientState::Disabled) } + } + + + fn get_client(&self) -> Option> { + let mut guard = self.state.lock(); + let state = guard.as_mut().unwrap(); + let state: &mut DistClientState = &mut **state; + Self::maybe_recreate_state(state); + match state { + DistClientState::Some(dc) => Some(dc.clone()), + DistClientState::Disabled | + DistClientState::RetryCreateAt(_, _) => None, + } + } + + fn maybe_recreate_state(state: &mut DistClientState) { + if let DistClientState::RetryCreateAt(_, instant) = *state { + if instant > Instant::now() { + return + } + let config = match mem::replace(state, DistClientState::Disabled) { + DistClientState::RetryCreateAt(config, _) => config, + _ => unreachable!(), + }; + info!("Attempting to recreate the dist client"); + *state = 
Self::create_state(config) + } + } + + // Attempt to recreate the dist client + fn create_state(config: DistClientConfig) -> DistClientState { + macro_rules! try_or_retry_later { + ($v:expr) => {{ + match $v { + Ok(v) => v, + Err(e) => { + use error_chain::ChainedError; + error!("{}", e.display_chain()); + return DistClientState::RetryCreateAt(config, Instant::now() + DIST_CLIENT_RECREATE_TIMEOUT) + }, + } + }}; + } + // TODO: NLL would avoid this clone + match config.scheduler_url.clone() { + Some(addr) => { + let url = addr.to_url(); + info!("Enabling distributed sccache to {}", url); + let auth_token = match &config.auth { + config::DistAuth::Token { token } => Ok(token.to_owned()), + config::DistAuth::Oauth2CodeGrantPKCE { client_id: _, auth_url, token_url: _ } | + config::DistAuth::Oauth2Implicit { client_id: _, auth_url } => + Self::get_cached_config_auth_token(auth_url), + }; + // TODO: NLL would let us move this inside the previous match + let auth_token = try_or_retry_later!(auth_token.chain_err(|| "could not load client auth token")); + let dist_client = dist::http::Client::new( + config.handle.clone(), + &config.pool, + url, + &config.cache_dir.join("client"), + config.toolchain_cache_size, + &config.toolchains, + auth_token, + ); + let dist_client = try_or_retry_later!(dist_client.chain_err(|| "failure during dist client creation")); + info!("Successfully created dist client"); + DistClientState::Some(Arc::new(dist_client)) + }, + None => { + info!("No scheduler address configured, disabling distributed sccache"); + DistClientState::Disabled + }, + } + } + + fn get_cached_config_auth_token(auth_url: &str) -> Result { + let cached_config = config::CachedConfig::load()?; + cached_config.with(|c| { + c.dist.auth_tokens.get(auth_url).map(String::to_owned) + }).ok_or_else(|| Error::from(format!("token for url {} not present in cached config", auth_url))) + } +} + + /// Start an sccache server, listening on `port`. 
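The `DistClientContainer` added above is a small lazily-retried state machine: a failed client creation is parked as `RetryCreateAt(config, deadline)` and re-attempted the next time a client is requested after the deadline has passed. A stripped-down sketch of that pattern, with placeholder `Config` and `DistClient` types standing in for the real `DistClientConfig` and `dist::http::Client`:

use std::mem;
use std::time::{Duration, Instant};

const RETRY_AFTER: Duration = Duration::from_secs(30);

struct Config;      // stand-in for DistClientConfig
struct DistClient;  // stand-in for dist::http::Client

enum State {
    Some(DistClient),
    RetryCreateAt(Config, Instant),
    Disabled,
}

// Creation either succeeds or parks the config together with a retry deadline.
fn create_state(config: Config) -> State {
    match try_create(&config) {
        Ok(client) => State::Some(client),
        Err(_) => State::RetryCreateAt(config, Instant::now() + RETRY_AFTER),
    }
}

// Called on every client lookup: once the deadline has passed, try again.
fn maybe_recreate_state(state: &mut State) {
    if let State::RetryCreateAt(_, deadline) = *state {
        if deadline > Instant::now() {
            return;
        }
        let config = match mem::replace(state, State::Disabled) {
            State::RetryCreateAt(config, _) => config,
            _ => unreachable!(),
        };
        *state = create_state(config);
    }
}

fn try_create(_config: &Config) -> Result<DistClient, ()> {
    Ok(DistClient) // placeholder for the real dist::http::Client::new(...)
}

Keeping the configuration inside the `RetryCreateAt` variant lets the retry reuse the settings captured at server startup.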
/// /// Spins an event loop handling client connections until a client @@ -136,46 +303,7 @@ pub fn start_server(config: &Config, port: u16) -> Result<()> { let client = unsafe { Client::new() }; let core = Core::new()?; let pool = CpuPool::new(20); - let dist_client: Arc = match config.dist.scheduler_addr { - #[cfg(feature = "dist-client")] - Some(addr) => { - use config; - info!("Enabling distributed sccache to {}", addr); - let auth_token = match &config.dist.auth { - config::DistAuth::Token { token } => token.to_owned(), - config::DistAuth::Oauth2CodeGrantPKCE { client_id: _, auth_url, token_url: _ } => { - let cached_config = config::CachedConfig::load().unwrap(); - cached_config.with(|c| { - c.dist.auth_tokens.get(auth_url).unwrap().to_owned() - }) - }, - config::DistAuth::Oauth2Implicit { client_id: _, auth_url } => { - let cached_config = config::CachedConfig::load().unwrap(); - cached_config.with(|c| { - c.dist.auth_tokens.get(auth_url).unwrap().to_owned() - }) - }, - }; - Arc::new(dist::http::Client::new( - &core.handle(), - &pool, - addr, - &config.dist.cache_dir.join("client"), - config.dist.toolchain_cache_size, - &config.dist.toolchains, - auth_token, - )) - }, - #[cfg(not(feature = "dist-client"))] - Some(_) => { - warn!("Scheduler address configured but dist feature disabled, disabling distributed sccache"); - Arc::new(dist::NoopClient) - }, - None => { - info!("No scheduler address configured, disabling distributed sccache"); - Arc::new(dist::NoopClient) - }, - }; + let dist_client = DistClientContainer::new(config, &pool, core.handle()); let storage = storage_from_config(config, &pool, &core.handle()); let res = SccacheServer::::new(port, pool, core, client, dist_client, storage); let notify = env::var_os("SCCACHE_STARTUP_NOTIFY"); @@ -210,7 +338,7 @@ impl SccacheServer { pool: CpuPool, core: Core, client: Client, - dist_client: Arc, + dist_client: DistClientContainer, storage: Arc) -> Result> { let handle = core.handle(); let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port); @@ -363,7 +491,7 @@ struct SccacheService { stats: Rc>, /// Distributed sccache client - dist_client: Arc, + dist_client: Rc, /// Cache storage. 
storage: Arc, @@ -459,7 +587,7 @@ impl Service for SccacheService impl SccacheService where C: CommandCreatorSync, { - pub fn new(dist_client: Arc, + pub fn new(dist_client: DistClientContainer, storage: Arc, handle: Handle, client: &Client, @@ -468,7 +596,7 @@ impl SccacheService info: ActiveInfo) -> SccacheService { SccacheService { stats: Rc::new(RefCell::new(ServerStats::default())), - dist_client, + dist_client: Rc::new(dist_client), storage: storage, compilers: Rc::new(RefCell::new(HashMap::new())), pool: pool, @@ -627,7 +755,7 @@ impl SccacheService }; let out_pretty = hasher.output_pretty().into_owned(); let color_mode = hasher.color_mode(); - let result = hasher.get_cached_or_compile(self.dist_client.clone(), + let result = hasher.get_cached_or_compile(self.dist_client.get_client(), self.creator.clone(), self.storage.clone(), arguments, @@ -652,7 +780,12 @@ impl SccacheService stats.cache_hits += 1; stats.cache_read_hit_duration += duration; }, - CompileResult::CacheMiss(miss_type, duration, future) => { + CompileResult::CacheMiss(miss_type, dist_type, duration, future) => { + match dist_type { + DistType::NoDist => {}, + DistType::Ok => stats.dist_compiles += 1, + DistType::Error => stats.dist_errors += 1, + } match miss_type { MissType::Normal => {} MissType::ForcedRecache => { @@ -784,6 +917,10 @@ pub struct ServerStats { pub compile_fails: u64, /// Counts of reasons why compiles were not cached. pub not_cached: HashMap, + /// The count of compilations that were successfully distributed + pub dist_compiles: u64, + /// The count of compilations that were distributed but failed and had to be re-run locally + pub dist_errors: u64, } /// Info and stats about the server. @@ -817,6 +954,8 @@ impl Default for ServerStats { cache_read_miss_duration: Duration::new(0, 0), compile_fails: u64::default(), not_cached: HashMap::new(), + dist_compiles: u64::default(), + dist_errors: u64::default(), } } } @@ -861,6 +1000,8 @@ impl ServerStats { set_stat!(stats_vec, self.requests_not_cacheable, "Non-cacheable calls"); set_stat!(stats_vec, self.requests_not_compile, "Non-compilation calls"); set_stat!(stats_vec, self.requests_unsupported_compiler, "Unsupported compiler calls"); + set_stat!(stats_vec, self.dist_compiles, "Successful distributed compilations"); + set_stat!(stats_vec, self.dist_errors, "Failed distributed compilations"); set_duration_stat!(stats_vec, self.cache_write_duration, self.cache_writes, "Average cache write"); set_duration_stat!(stats_vec, self.cache_read_miss_duration, self.cache_misses, "Average cache read miss"); set_duration_stat!(stats_vec, self.cache_read_hit_duration, self.cache_hits, "Average cache read hit"); diff --git a/src/test/tests.rs b/src/test/tests.rs index 000ec98c..97ff0e0e 100644 --- a/src/test/tests.rs +++ b/src/test/tests.rs @@ -21,13 +21,13 @@ use ::commands::{ request_shutdown, request_stats, }; -use dist::NoopClient; use env_logger; use futures::sync::oneshot::{self, Sender}; use futures_cpupool::CpuPool; use jobserver::Client; use ::mock_command::*; use ::server::{ + DistClientContainer, ServerMessage, SccacheServer, }; @@ -76,14 +76,14 @@ fn run_server_thread(cache_dir: &Path, options: T) .and_then(|o| o.cache_size.as_ref()) .map(|s| *s) .unwrap_or(u64::MAX); - let pool = CpuPool::new(1); - let dist_client = Arc::new(NoopClient); - let storage = Arc::new(DiskCache::new(&cache_dir, cache_size, &pool)); - // Create a server on a background thread, get some useful bits from it. 
let (tx, rx) = mpsc::channel(); let (shutdown_tx, shutdown_rx) = oneshot::channel(); let handle = thread::spawn(move || { + let pool = CpuPool::new(1); + let dist_client = DistClientContainer::new_disabled(); + let storage = Arc::new(DiskCache::new(&cache_dir, cache_size, &pool)); + let core = Core::new().unwrap(); let client = unsafe { Client::new() }; let srv = SccacheServer::new(0, pool, core, client, dist_client, storage).unwrap(); diff --git a/tests/dist.rs b/tests/dist.rs new file mode 100644 index 00000000..b87d9082 --- /dev/null +++ b/tests/dist.rs @@ -0,0 +1,202 @@ +#![cfg(all(feature = "dist-client", feature = "dist-server"))] + +extern crate assert_cmd; +#[macro_use] +extern crate error_chain; +#[macro_use] +extern crate log; +extern crate sccache; +extern crate serde_json; +extern crate tempdir; + +use assert_cmd::prelude::*; +use sccache::config::HTTPUrl; +use harness::{ + sccache_command, + start_local_daemon, stop_local_daemon, + get_stats, + write_json_cfg, write_source, +}; +use sccache::dist::{ + AssignJobResult, + CompileCommand, + InputsReader, + JobId, + JobState, + RunJobResult, + ServerIncoming, + ServerOutgoing, + SubmitToolchainResult, + Toolchain, + ToolchainReader, +}; +use std::ffi::OsStr; +use std::path::Path; +use tempdir::TempDir; + +use sccache::errors::*; + +mod harness; + +fn basic_compile(tmpdir: &Path, sccache_cfg_path: &Path, sccache_cached_cfg_path: &Path) { + let envs: Vec<(_, &OsStr)> = vec![ + ("RUST_BACKTRACE", "1".as_ref()), + ("RUST_LOG", "sccache=trace".as_ref()), + ("SCCACHE_CONF", sccache_cfg_path.as_ref()), + ("SCCACHE_CACHED_CONF", sccache_cached_cfg_path.as_ref()), + ]; + let source_file = "x.c"; + let obj_file = "x.o"; + write_source(tmpdir, source_file, "int x() { return 5; }"); + sccache_command() + .args(&["gcc", "-c"]).arg(tmpdir.join(source_file)).arg("-o").arg(tmpdir.join(obj_file)) + .envs(envs) + .assert() + .success(); +} + +pub fn dist_test_sccache_client_cfg(tmpdir: &Path, scheduler_url: HTTPUrl) -> sccache::config::FileConfig { + let mut sccache_cfg = harness::sccache_client_cfg(tmpdir); + sccache_cfg.cache.disk.as_mut().unwrap().size = 0; + sccache_cfg.dist.scheduler_url = Some(scheduler_url); + sccache_cfg +} + +#[test] +#[cfg_attr(not(feature = "dist-tests"), ignore)] +fn test_dist_basic() { + let tmpdir = TempDir::new("sccache_dist_test").unwrap(); + let tmpdir = tmpdir.path(); + let sccache_dist = harness::sccache_dist_path(); + + let mut system = harness::DistSystem::new(&sccache_dist, tmpdir); + system.add_scheduler(); + system.add_server(); + + let sccache_cfg = dist_test_sccache_client_cfg(tmpdir, system.scheduler_url()); + let sccache_cfg_path = tmpdir.join("sccache-cfg.json"); + write_json_cfg(tmpdir, "sccache-cfg.json", &sccache_cfg); + let sccache_cached_cfg_path = tmpdir.join("sccache-cached-cfg"); + + stop_local_daemon(); + start_local_daemon(&sccache_cfg_path, &sccache_cached_cfg_path); + basic_compile(tmpdir, &sccache_cfg_path, &sccache_cached_cfg_path); + + get_stats(|info| { + assert_eq!(1, info.stats.dist_compiles); + assert_eq!(0, info.stats.dist_errors); + assert_eq!(1, info.stats.compile_requests); + assert_eq!(1, info.stats.requests_executed); + assert_eq!(0, info.stats.cache_hits); + assert_eq!(1, info.stats.cache_misses); + }); +} + +#[test] +#[cfg_attr(not(feature = "dist-tests"), ignore)] +fn test_dist_restartedserver() { + let tmpdir = TempDir::new("sccache_dist_test").unwrap(); + let tmpdir = tmpdir.path(); + let sccache_dist = harness::sccache_dist_path(); + + let mut system = 
harness::DistSystem::new(&sccache_dist, tmpdir); + system.add_scheduler(); + let server_handle = system.add_server(); + + let sccache_cfg = dist_test_sccache_client_cfg(tmpdir, system.scheduler_url()); + let sccache_cfg_path = tmpdir.join("sccache-cfg.json"); + write_json_cfg(tmpdir, "sccache-cfg.json", &sccache_cfg); + let sccache_cached_cfg_path = tmpdir.join("sccache-cached-cfg"); + + stop_local_daemon(); + start_local_daemon(&sccache_cfg_path, &sccache_cached_cfg_path); + basic_compile(tmpdir, &sccache_cfg_path, &sccache_cached_cfg_path); + + system.restart_server(&server_handle); + basic_compile(tmpdir, &sccache_cfg_path, &sccache_cached_cfg_path); + + get_stats(|info| { + assert_eq!(2, info.stats.dist_compiles); + assert_eq!(0, info.stats.dist_errors); + assert_eq!(2, info.stats.compile_requests); + assert_eq!(2, info.stats.requests_executed); + assert_eq!(0, info.stats.cache_hits); + assert_eq!(2, info.stats.cache_misses); + }); +} + +#[test] +#[cfg_attr(not(feature = "dist-tests"), ignore)] +fn test_dist_nobuilder() { + let tmpdir = TempDir::new("sccache_dist_test").unwrap(); + let tmpdir = tmpdir.path(); + let sccache_dist = harness::sccache_dist_path(); + + let mut system = harness::DistSystem::new(&sccache_dist, tmpdir); + system.add_scheduler(); + + let sccache_cfg = dist_test_sccache_client_cfg(tmpdir, system.scheduler_url()); + let sccache_cfg_path = tmpdir.join("sccache-cfg.json"); + write_json_cfg(tmpdir, "sccache-cfg.json", &sccache_cfg); + let sccache_cached_cfg_path = tmpdir.join("sccache-cached-cfg"); + + stop_local_daemon(); + start_local_daemon(&sccache_cfg_path, &sccache_cached_cfg_path); + basic_compile(tmpdir, &sccache_cfg_path, &sccache_cached_cfg_path); + + get_stats(|info| { + assert_eq!(0, info.stats.dist_compiles); + assert_eq!(1, info.stats.dist_errors); + assert_eq!(1, info.stats.compile_requests); + assert_eq!(1, info.stats.requests_executed); + assert_eq!(0, info.stats.cache_hits); + assert_eq!(1, info.stats.cache_misses); + }); +} + +struct FailingServer; +impl ServerIncoming for FailingServer { + type Error = Error; + fn handle_assign_job(&self, requester: &ServerOutgoing, job_id: JobId, _tc: Toolchain) -> Result { + let need_toolchain = false; + requester.do_update_job_state(job_id, JobState::Ready).chain_err(|| "Updating job state failed")?; + Ok(AssignJobResult { need_toolchain }) + } + fn handle_submit_toolchain(&self, _requester: &ServerOutgoing, _job_id: JobId, _tc_rdr: ToolchainReader) -> Result { + panic!("should not have submitted toolchain") + } + fn handle_run_job(&self, requester: &ServerOutgoing, job_id: JobId, _command: CompileCommand, _outputs: Vec, _inputs_rdr: InputsReader) -> Result { + requester.do_update_job_state(job_id, JobState::Started).chain_err(|| "Updating job state failed")?; + bail!("internal build failure") + } +} + +#[test] +#[cfg_attr(not(feature = "dist-tests"), ignore)] +fn test_dist_failingserver() { + let tmpdir = TempDir::new("sccache_dist_test").unwrap(); + let tmpdir = tmpdir.path(); + let sccache_dist = harness::sccache_dist_path(); + + let mut system = harness::DistSystem::new(&sccache_dist, tmpdir); + system.add_scheduler(); + system.add_custom_server(FailingServer); + + let sccache_cfg = dist_test_sccache_client_cfg(tmpdir, system.scheduler_url()); + let sccache_cfg_path = tmpdir.join("sccache-cfg.json"); + write_json_cfg(tmpdir, "sccache-cfg.json", &sccache_cfg); + let sccache_cached_cfg_path = tmpdir.join("sccache-cached-cfg"); + + stop_local_daemon(); + start_local_daemon(&sccache_cfg_path, 
&sccache_cached_cfg_path); + basic_compile(tmpdir, &sccache_cfg_path, &sccache_cached_cfg_path); + + get_stats(|info| { + assert_eq!(0, info.stats.dist_compiles); + assert_eq!(1, info.stats.dist_errors); + assert_eq!(1, info.stats.compile_requests); + assert_eq!(1, info.stats.requests_executed); + assert_eq!(0, info.stats.cache_hits); + assert_eq!(1, info.stats.cache_misses); + }); +} diff --git a/tests/harness/Dockerfile.sccache-dist b/tests/harness/Dockerfile.sccache-dist new file mode 100644 index 00000000..2eaf87eb --- /dev/null +++ b/tests/harness/Dockerfile.sccache-dist @@ -0,0 +1,15 @@ +FROM ubuntu:18.04 as bwrap-build +RUN apt-get update && \ + apt-get install -y wget xz-utils gcc libcap-dev make && \ + apt-get clean +RUN wget -q -O - https://github.com/projectatomic/bubblewrap/releases/download/v0.3.1/bubblewrap-0.3.1.tar.xz | \ + tar -xJ +RUN cd /bubblewrap-0.3.1 && \ + ./configure --disable-man && \ + make + +FROM aidanhs/ubuntu-docker:18.04-17.03.2-ce +RUN apt-get update && \ + apt-get install libcap2 libssl1.0.0 && \ + apt-get clean +COPY --from=bwrap-build /bubblewrap-0.3.1/bwrap /bwrap diff --git a/tests/harness/mod.rs b/tests/harness/mod.rs new file mode 100644 index 00000000..d13eac7a --- /dev/null +++ b/tests/harness/mod.rs @@ -0,0 +1,538 @@ +extern crate assert_cmd; +extern crate bincode; +extern crate env_logger; +extern crate escargot; +#[cfg(feature = "dist-server")] +extern crate nix; +extern crate predicates; +#[cfg(feature = "dist-server")] +extern crate reqwest; +extern crate sccache; +extern crate serde; +extern crate serde_json; +extern crate uuid; +#[cfg(feature = "dist-server")] +extern crate void; + +#[cfg(feature = "dist-server")] +use std::env; +use std::fs; +use std::io::Write; +use std::net::{self, IpAddr, SocketAddr}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Output, Stdio}; +use std::str; +use std::thread; +use std::time::{Duration, Instant}; +#[cfg(any(feature = "dist-client", feature = "dist-server"))] +use sccache::config::HTTPUrl; +use sccache::dist::{self, SchedulerStatusResult, ServerId}; +use sccache::server::ServerInfo; + +use self::assert_cmd::prelude::*; +use self::escargot::CargoBuild; +#[cfg(feature = "dist-server")] +use self::nix::{ + sys::{ + signal::Signal, + wait::{WaitPidFlag, WaitStatus}, + }, + unistd::{ForkResult, Pid}, +}; +use self::predicates::prelude::*; +use self::serde::Serialize; +use self::uuid::Uuid; + +#[cfg(feature = "dist-server")] +macro_rules! matches { + ($expression:expr, $($pattern:tt)+) => { + match $expression { + $($pattern)+ => true, + _ => false + } + } +} + +const CONTAINER_NAME_PREFIX: &str = "sccache_dist_test"; +const DIST_IMAGE: &str = "sccache_dist_test_image"; +const DIST_DOCKERFILE: &str = include_str!("Dockerfile.sccache-dist"); +const DIST_IMAGE_BWRAP_PATH: &str = "/bwrap"; +const MAX_STARTUP_WAIT: Duration = Duration::from_secs(5); + +const DIST_SERVER_TOKEN: &str = "THIS IS THE TEST TOKEN"; + +const CONFIGS_CONTAINER_PATH: &str = "/sccache-bits"; +const BUILD_DIR_CONTAINER_PATH: &str = "/sccache-bits/build-dir"; +const SCHEDULER_PORT: u16 = 10500; +const SERVER_PORT: u16 = 12345; // arbitrary + +const TC_CACHE_SIZE: u64 = 1 * 1024 * 1024 * 1024; // 1 gig + +pub fn start_local_daemon(cfg_path: &Path, cached_cfg_path: &Path) { + // Don't run this with run() because on Windows `wait_with_output` + // will hang because the internal server process is not detached. 
+ sccache_command() + .arg("--start-server") + .env("SCCACHE_CONF", cfg_path) + .env("SCCACHE_CACHED_CONF", cached_cfg_path) + .status() + .unwrap() + .success(); +} +pub fn stop_local_daemon() { + trace!("sccache --stop-server"); + drop(sccache_command() + .arg("--stop-server") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status()); +} + +pub fn get_stats(f: F) { + sccache_command() + .args(&["--show-stats", "--stats-format=json"]) + .assert() + .success() + .stdout(predicate::function(move |output: &[u8]| { + let s = str::from_utf8(output).expect("Output not UTF-8"); + f(serde_json::from_str(s).expect("Failed to parse JSON stats")); + true + })); +} + +#[allow(unused)] +pub fn zero_stats() { + trace!("sccache --zero-stats"); + drop(sccache_command() + .arg("--zero-stats") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status()); +} + +pub fn write_json_cfg(path: &Path, filename: &str, contents: &T) { + let p = path.join(filename); + let mut f = fs::File::create(&p).unwrap(); + f.write_all(&serde_json::to_vec(contents).unwrap()).unwrap(); +} + +pub fn write_source(path: &Path, filename: &str, contents: &str) { + let p = path.join(filename); + let mut f = fs::File::create(&p).unwrap(); + f.write_all(contents.as_bytes()).unwrap(); +} + +// Alter an sccache command to override any environment variables that could adversely +// affect test execution +fn blankslate_sccache(mut cmd: Command) -> Command { + cmd + .env("SCCACHE_CONF", "nonexistent_conf_path") + .env("SCCACHE_CACHED_CONF", "nonexistent_cached_conf_path"); + cmd +} + +#[cfg(not(feature = "dist-client"))] +pub fn sccache_command() -> Command { + blankslate_sccache(CargoBuild::new() + .bin("sccache") + .current_release() + .current_target() + .run() + .unwrap() + .command()) +} + +#[cfg(feature = "dist-client")] +pub fn sccache_command() -> Command { + blankslate_sccache(CargoBuild::new() + .bin("sccache") + // This should just inherit from the feature list we're compiling with to avoid recompilation + // https://github.com/assert-rs/assert_cmd/issues/44#issuecomment-418485128 + .arg("--features").arg("dist-client dist-server") + .current_release() + .current_target() + .run() + .unwrap() + .command()) +} + +#[cfg(feature = "dist-server")] +pub fn sccache_dist_path() -> PathBuf { + CargoBuild::new() + .bin("sccache-dist") + // This should just inherit from the feature list we're compiling with to avoid recompilation + // https://github.com/assert-rs/assert_cmd/issues/44#issuecomment-418485128 + .arg("--features").arg("dist-client dist-server") + .current_release() + .current_target() + .run() + .unwrap() + .path() + .to_owned() +} + +pub fn sccache_client_cfg(tmpdir: &Path) -> sccache::config::FileConfig { + let cache_relpath = "client-cache"; + let dist_cache_relpath = "client-dist-cache"; + fs::create_dir(tmpdir.join(cache_relpath)).unwrap(); + fs::create_dir(tmpdir.join(dist_cache_relpath)).unwrap(); + + let mut disk_cache: sccache::config::DiskCacheConfig = Default::default(); + disk_cache.dir = tmpdir.join(cache_relpath); + sccache::config::FileConfig { + cache: sccache::config::CacheConfigs { + azure: None, + disk: Some(disk_cache), + gcs: None, + memcached: None, + redis: None, + s3: None, + }, + dist: sccache::config::DistConfig { + auth: Default::default(), // dangerously_insecure + scheduler_url: None, + cache_dir: tmpdir.join(dist_cache_relpath), + toolchains: vec![], + toolchain_cache_size: TC_CACHE_SIZE, + }, + } +} +#[cfg(feature = "dist-server")] +fn sccache_scheduler_cfg() -> 
sccache::config::scheduler::Config { + sccache::config::scheduler::Config { + public_addr: SocketAddr::from(([0, 0, 0, 0], SCHEDULER_PORT)), + client_auth: sccache::config::scheduler::ClientAuth::Insecure, + server_auth: sccache::config::scheduler::ServerAuth::Token { token: DIST_SERVER_TOKEN.to_owned() }, + } +} +#[cfg(feature = "dist-server")] +fn sccache_server_cfg(tmpdir: &Path, scheduler_url: HTTPUrl, server_ip: IpAddr) -> sccache::config::server::Config { + let relpath = "server-cache"; + fs::create_dir(tmpdir.join(relpath)).unwrap(); + + sccache::config::server::Config { + builder: sccache::config::server::BuilderType::Overlay { + build_dir: BUILD_DIR_CONTAINER_PATH.into(), + bwrap_path: DIST_IMAGE_BWRAP_PATH.into(), + }, + cache_dir: Path::new(CONFIGS_CONTAINER_PATH).join(relpath), + public_addr: SocketAddr::new(server_ip, SERVER_PORT), + scheduler_url, + scheduler_auth: sccache::config::server::SchedulerAuth::Token { token: DIST_SERVER_TOKEN.to_owned() }, + toolchain_cache_size: TC_CACHE_SIZE, + } +} + +// TODO: this is copied from the sccache-dist binary - it's not clear where would be a better place to put the +// code so that it can be included here +#[cfg(feature = "dist-server")] +fn create_server_token(server_id: ServerId, auth_token: &str) -> String { + format!("{} {}", server_id.addr(), auth_token) +} + +#[cfg(feature = "dist-server")] +pub enum ServerHandle { + Container { cid: String, url: HTTPUrl }, + Process { pid: Pid, url: HTTPUrl }, +} + +#[cfg(feature = "dist-server")] +pub struct DistSystem { + sccache_dist: PathBuf, + tmpdir: PathBuf, + + scheduler_name: Option, + server_names: Vec, + server_pids: Vec, +} + +#[cfg(feature = "dist-server")] +impl DistSystem { + pub fn new(sccache_dist: &Path, tmpdir: &Path) -> Self { + // Make sure the docker image is available, building it if necessary + let mut child = Command::new("docker") + .args(&["build", "-q", "-t", DIST_IMAGE, "-"]) + .stdin(Stdio::piped()) + .spawn().unwrap(); + child.stdin.as_mut().unwrap().write_all(DIST_DOCKERFILE.as_bytes()).unwrap(); + let output = child.wait_with_output().unwrap(); + check_output(&output); + + let tmpdir = tmpdir.join("distsystem"); + fs::create_dir(&tmpdir).unwrap(); + + Self { + sccache_dist: sccache_dist.to_owned(), + tmpdir, + + scheduler_name: None, + server_names: vec![], + server_pids: vec![], + } + } + + pub fn add_scheduler(&mut self) { + let scheduler_cfg_relpath = "scheduler-cfg.json"; + let scheduler_cfg_path = self.tmpdir.join(scheduler_cfg_relpath); + let scheduler_cfg_container_path = Path::new(CONFIGS_CONTAINER_PATH).join(scheduler_cfg_relpath); + let scheduler_cfg = sccache_scheduler_cfg(); + fs::File::create(&scheduler_cfg_path).unwrap().write_all(&serde_json::to_vec(&scheduler_cfg).unwrap()).unwrap(); + + // Create the scheduler + let scheduler_name = make_container_name("scheduler"); + let output = Command::new("docker") + .args(&[ + "run", + "--name", &scheduler_name, + "-e", "RUST_LOG=sccache=trace", + "-e", "RUST_BACKTRACE=1", + "-v", &format!("{}:/sccache-dist", self.sccache_dist.to_str().unwrap()), + "-v", &format!("{}:{}", self.tmpdir.to_str().unwrap(), CONFIGS_CONTAINER_PATH), + "-d", + DIST_IMAGE, + "bash", "-c", &format!(r#" + set -o errexit && + exec /sccache-dist scheduler --config {cfg} + "#, cfg=scheduler_cfg_container_path.to_str().unwrap()), + ]).output().unwrap(); + self.scheduler_name = Some(scheduler_name); + + check_output(&output); + + let scheduler_url = self.scheduler_url(); + wait_for_http(scheduler_url, Duration::from_millis(100), 
MAX_STARTUP_WAIT); + wait_for(|| { + let status = self.scheduler_status(); + if matches!(self.scheduler_status(), SchedulerStatusResult { num_servers: 0 }) { Ok(()) } else { Err(format!("{:?}", status)) } + }, Duration::from_millis(100), MAX_STARTUP_WAIT); + } + + pub fn add_server(&mut self) -> ServerHandle { + let server_cfg_relpath = format!("server-cfg-{}.json", self.server_names.len()); + let server_cfg_path = self.tmpdir.join(&server_cfg_relpath); + let server_cfg_container_path = Path::new(CONFIGS_CONTAINER_PATH).join(server_cfg_relpath); + + let server_name = make_container_name("server"); + let output = Command::new("docker") + .args(&[ + "run", + // Important for the bubblewrap builder + "--privileged", + "--name", &server_name, + "-e", "RUST_LOG=sccache=debug", + "-e", "RUST_BACKTRACE=1", + "-v", &format!("{}:/sccache-dist", self.sccache_dist.to_str().unwrap()), + "-v", &format!("{}:{}", self.tmpdir.to_str().unwrap(), CONFIGS_CONTAINER_PATH), + "-d", + DIST_IMAGE, + "bash", "-c", &format!(r#" + set -o errexit && + while [ ! -f {cfg}.ready ]; do sleep 0.1; done && + exec /sccache-dist server --config {cfg} + "#, cfg=server_cfg_container_path.to_str().unwrap()), + ]).output().unwrap(); + self.server_names.push(server_name.clone()); + + check_output(&output); + + let server_ip = self.container_ip(&server_name); + let server_cfg = sccache_server_cfg(&self.tmpdir, self.scheduler_url(), server_ip); + fs::File::create(&server_cfg_path).unwrap().write_all(&serde_json::to_vec(&server_cfg).unwrap()).unwrap(); + fs::File::create(format!("{}.ready", server_cfg_path.to_str().unwrap())).unwrap(); + + let url = HTTPUrl::from_url(reqwest::Url::parse(&format!("https://{}:{}", server_ip, SERVER_PORT)).unwrap()); + let handle = ServerHandle::Container { cid: server_name, url }; + self.wait_server_ready(&handle); + handle + } + + pub fn add_custom_server(&mut self, handler: S) -> ServerHandle { + let server_addr = { + let ip = self.host_interface_ip(); + let listener = net::TcpListener::bind(SocketAddr::from((ip, 0))).unwrap(); + listener.local_addr().unwrap() + }; + let token = create_server_token(ServerId::new(server_addr), DIST_SERVER_TOKEN); + let server = dist::http::Server::new(server_addr, self.scheduler_url().to_url(), token, handler).unwrap(); + let pid = match nix::unistd::fork().unwrap() { + ForkResult::Parent { child } => { + self.server_pids.push(child); + child + }, + ForkResult::Child => { + env::set_var("RUST_LOG", "sccache=trace"); + env_logger::try_init().unwrap(); + void::unreachable(server.start().unwrap()) + }, + }; + + let url = HTTPUrl::from_url(reqwest::Url::parse(&format!("https://{}", server_addr)).unwrap()); + let handle = ServerHandle::Process { pid, url }; + self.wait_server_ready(&handle); + handle + } + + pub fn restart_server(&mut self, handle: &ServerHandle) { + match handle { + ServerHandle::Container { cid, url: _ } => { + let output = Command::new("docker").args(&["restart", cid]).output().unwrap(); + check_output(&output); + }, + ServerHandle::Process { pid: _, url: _ } => { + // TODO: pretty easy, just no need yet + panic!("restart not yet implemented for pids") + }, + } + self.wait_server_ready(handle) + } + + pub fn wait_server_ready(&mut self, handle: &ServerHandle) { + let url = match handle { + ServerHandle::Container { cid: _, url } | + ServerHandle::Process { pid: _, url } => url.clone(), + }; + wait_for_http(url, Duration::from_millis(100), MAX_STARTUP_WAIT); + wait_for(|| { + let status = self.scheduler_status(); + if matches!(self.scheduler_status(), 
SchedulerStatusResult { num_servers: 1 }) { Ok(()) } else { Err(format!("{:?}", status)) } + }, Duration::from_millis(100), MAX_STARTUP_WAIT); + } + + pub fn scheduler_url(&self) -> HTTPUrl { + let ip = self.container_ip(self.scheduler_name.as_ref().unwrap()); + let url = format!("http://{}:{}", ip, SCHEDULER_PORT); + HTTPUrl::from_url(reqwest::Url::parse(&url).unwrap()) + } + + fn scheduler_status(&self) -> SchedulerStatusResult { + let res = reqwest::get(dist::http::urls::scheduler_status(&self.scheduler_url().to_url())).unwrap(); + assert!(res.status().is_success()); + bincode::deserialize_from(res).unwrap() + } + + fn container_ip(&self, name: &str) -> IpAddr { + let output = Command::new("docker") + .args(&["inspect", "--format", "{{ .NetworkSettings.IPAddress }}", name]) + .output().unwrap(); + check_output(&output); + let stdout = String::from_utf8(output.stdout).unwrap(); + stdout.trim().to_owned().parse().unwrap() + } + + // The interface that the host sees on the docker network (typically 'docker0') + fn host_interface_ip(&self) -> IpAddr { + let output = Command::new("docker") + .args(&["inspect", "--format", "{{ .NetworkSettings.Gateway }}", self.scheduler_name.as_ref().unwrap()]) + .output().unwrap(); + check_output(&output); + let stdout = String::from_utf8(output.stdout).unwrap(); + stdout.trim().to_owned().parse().unwrap() + } +} + +// If you want containers to hang around (e.g. for debugging), commend out the "rm -f" lines +#[cfg(feature = "dist-server")] +impl Drop for DistSystem { + fn drop(&mut self) { + let mut did_err = false; + + // Panicking halfway through drop would either abort (if it's a double panic) or leave us with + // resources that aren't yet cleaned up. Instead, do as much as possible then decide what to do + // at the end - panic (if not already doing so) or let the panic continue + macro_rules! droperr { + ($e:expr) => { + match $e { + Ok(()) => (), + Err(e) => { + did_err = true; + eprintln!("Error with {}: {}", stringify!($e), e) + }, + } + } + } + + let mut logs = vec![]; + let mut outputs = vec![]; + let mut exits = vec![]; + + if let Some(scheduler_name) = self.scheduler_name.as_ref() { + droperr!(Command::new("docker").args(&["logs", scheduler_name]).output().map(|o| logs.push((scheduler_name, o)))); + droperr!(Command::new("docker").args(&["kill", scheduler_name]).output().map(|o| outputs.push((scheduler_name, o)))); + droperr!(Command::new("docker").args(&["rm", "-f", scheduler_name]).output().map(|o| outputs.push((scheduler_name, o)))); + } + for server_name in self.server_names.iter() { + droperr!(Command::new("docker").args(&["logs", server_name]).output().map(|o| logs.push((server_name, o)))); + droperr!(Command::new("docker").args(&["kill", server_name]).output().map(|o| outputs.push((server_name, o)))); + droperr!(Command::new("docker").args(&["rm", "-f", server_name]).output().map(|o| outputs.push((server_name, o)))); + } + for &pid in self.server_pids.iter() { + droperr!(nix::sys::signal::kill(pid, Signal::SIGINT)); + thread::sleep(Duration::from_millis(100)); + let mut killagain = true; // Default to trying to kill again, e.g. 
if there was an error waiting on the pid + droperr!(nix::sys::wait::waitpid(pid, Some(WaitPidFlag::WNOHANG)).map(|ws| if ws != WaitStatus::StillAlive { killagain = false; exits.push(ws) })); + if killagain { + eprintln!("SIGINT didn't kill process, trying SIGKILL"); + droperr!(nix::sys::signal::kill(pid, Signal::SIGKILL)); + droperr!(nix::sys::wait::waitpid(pid, Some(WaitPidFlag::WNOHANG)) + .map_err(|e| e.to_string()) + .and_then(|ws| if ws == WaitStatus::StillAlive { Err("process alive after sigkill".to_owned()) } else { exits.push(ws); Ok(()) })); + } + } + + for (container, Output { status, stdout, stderr }) in logs { + println!("LOGS == ({}) ==\n> {} <:\n## STDOUT\n{}\n\n## STDERR\n{}\n====", + status, container, String::from_utf8_lossy(&stdout), String::from_utf8_lossy(&stderr)); + } + for (container, Output { status, stdout, stderr }) in outputs { + println!("OUTPUTS == ({}) ==\n> {} <:\n## STDOUT\n{}\n\n## STDERR\n{}\n====", + status, container, String::from_utf8_lossy(&stdout), String::from_utf8_lossy(&stderr)); + } + for exit in exits { + println!("EXIT: {:?}", exit) + } + + if did_err && !thread::panicking() { + panic!("Encountered failures during dist system teardown") + } + } +} + +fn make_container_name(tag: &str) -> String { + format!("{}_{}_{}", CONTAINER_NAME_PREFIX, tag, Uuid::new_v4().hyphenated()) +} + +fn check_output(output: &Output) { + if !output.status.success() { + println!("[BEGIN OUTPUT]\n===========\n{}\n==========\n\n\n\n=========\n{}\n===============\n[FIN OUTPUT]\n\n", + String::from_utf8_lossy(&output.stdout), String::from_utf8_lossy(&output.stderr)); + panic!() + } +} + +#[cfg(feature = "dist-server")] +fn wait_for_http(url: HTTPUrl, interval: Duration, max_wait: Duration) { + // TODO: after upgrading to reqwest >= 0.9, use 'danger_accept_invalid_certs' and stick with that rather than tcp + wait_for(|| { + //match reqwest::get(url.to_url()) { + match net::TcpStream::connect(url.to_url()) { + Ok(_) => Ok(()), + Err(e) => Err(e.to_string()), + } + }, interval, max_wait) +} + +fn wait_for Result<(), String>>(f: F, interval: Duration, max_wait: Duration) { + let start = Instant::now(); + let mut lasterr; + loop { + match f() { + Ok(()) => return, + Err(e) => lasterr = e, + } + if start.elapsed() > max_wait { + break + } + thread::sleep(interval) + } + panic!("wait timed out, last error result: {}", lasterr) +} diff --git a/tests/oauth.rs b/tests/oauth.rs index 11fe4319..fe9b5b20 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -56,7 +56,7 @@ fn config_with_dist_auth(tmpdir: &Path, auth_config: sccache::config::DistAuth) cache: Default::default(), dist: sccache::config::DistConfig { auth: auth_config, - scheduler_addr: None, + scheduler_url: None, cache_dir: tmpdir.join("unused-cache"), toolchains: vec![], toolchain_cache_size: 0, diff --git a/tests/system.rs b/tests/system.rs index d165a4c9..e591579e 100644 --- a/tests/system.rs +++ b/tests/system.rs @@ -30,9 +30,15 @@ extern crate which; use assert_cmd::prelude::*; use escargot::CargoBuild; +use harness::{ + sccache_command, + sccache_client_cfg, + start_local_daemon, stop_local_daemon, + write_json_cfg, write_source, + get_stats, zero_stats, +}; use log::Level::Trace; use predicates::prelude::*; -use sccache::server::ServerInfo; use std::collections::HashMap; use std::env; use std::ffi::{OsStr,OsString}; @@ -49,6 +55,8 @@ use std::str; use tempdir::TempDir; use which::which_in; +mod harness; + #[derive(Clone)] struct Compiler { pub name: &'static str, @@ -66,34 +74,6 @@ const COMPILERS: &'static [&'static 
str] = &["clang"]; //TODO: could test gcc when targeting mingw. -fn sccache_command() -> Command { - CargoBuild::new() - .bin("sccache") - .current_release() - .current_target() - .run() - .unwrap() - .command() -} - -fn stop() { - trace!("sccache --stop-server"); - drop(sccache_command() - .arg("--stop-server") - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status()); -} - -fn zero_stats() { - trace!("sccache --zero-stats"); - drop(sccache_command() - .arg("--zero-stats") - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status()); -} - macro_rules! vec_from { ( $t:ty, $( $x:expr ),* ) => { vec!($( Into::<$t>::into(&$x), )*) @@ -108,24 +88,6 @@ fn compile_cmdline>(compiler: &str, exe: T, input: &str, output: } } -fn get_stats(f: F) { - sccache_command() - .args(&["--show-stats", "--stats-format=json"]) - .assert() - .success() - .stdout(predicate::function(move |output: &[u8]| { - let s = str::from_utf8(output).expect("Output not UTF-8"); - f(serde_json::from_str(s).expect("Failed to parse JSON stats")); - true - })); -} - -fn write_source(path: &Path, filename: &str, contents: &str) { - let p = path.join(filename); - let mut f = File::create(&p).unwrap(); - f.write_all(contents.as_bytes()).unwrap(); -} - const INPUT: &'static str = "test.c"; const INPUT_ERR: &'static str = "test_err.c"; const OUTPUT: &'static str = "test.o"; @@ -369,24 +331,18 @@ fn test_sccache_command() { warn!("No compilers found, skipping test"); } else { // Ensure there's no existing sccache server running. - stop(); - // Create a subdir for the cache. - let cache = tempdir.path().join("cache"); - fs::create_dir_all(&cache).unwrap(); + stop_local_daemon(); + // Create the configurations + let sccache_cfg = sccache_client_cfg(tempdir.path()); + write_json_cfg(tempdir.path(), "sccache-cfg.json", &sccache_cfg); + let sccache_cached_cfg_path = tempdir.path().join("sccache-cached-cfg"); // Start a server. trace!("start server"); - // Don't run this with run() because on Windows `wait_with_output` - // will hang because the internal server process is not detached. - sccache_command() - .arg("--start-server") - .env("SCCACHE_DIR", &cache) - .status() - .unwrap() - .success(); + start_local_daemon(&tempdir.path().join("sccache-cfg.json"), &sccache_cached_cfg_path); for compiler in compilers { run_sccache_command_tests(compiler, tempdir.path()); zero_stats(); } - stop(); + stop_local_daemon(); } }