diff --git a/Cargo.lock b/Cargo.lock index 5744f5ce4d36..6389e132b6e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1036,6 +1036,21 @@ dependencies = [ "libdbus-sys", ] +[[package]] +name = "defaultagent-static" +version = "0.1.0" +dependencies = [ + "log", + "serde", + "serde_derive", + "serde_json", + "url", + "viaduct", + "winapi 0.3.7", + "wineventlog", + "wio", +] + [[package]] name = "deflate" version = "0.7.19" diff --git a/Cargo.toml b/Cargo.toml index 1e02e148c799..e74b9b32999a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "toolkit/crashreporter/rust", "toolkit/library/gtest/rust", "toolkit/library/rust/", + "toolkit/mozapps/defaultagent/rust", ] # Excluded crates may be built as dependencies, but won't be considered members diff --git a/toolkit/mozapps/defaultagent/moz.build b/toolkit/mozapps/defaultagent/moz.build index 343918543e12..09e392589735 100644 --- a/toolkit/mozapps/defaultagent/moz.build +++ b/toolkit/mozapps/defaultagent/moz.build @@ -8,6 +8,8 @@ Program("default-browser-agent") SPHINX_TREES['default-browser-agent'] = "docs" +DIRS += ['rust'] + UNIFIED_SOURCES += [ "/mfbt/Poison.cpp", "/mfbt/Unused.cpp", @@ -30,6 +32,7 @@ SOURCES += [ ] USE_LIBS += [ + "defaultagent-static", "jsoncpp", ] @@ -51,6 +54,9 @@ OS_LIBS += [ "shell32", "shlwapi", "taskschd", + "userenv", + "wininet", + "ws2_32", ] DEFINES["NS_NO_XPCOM"] = True diff --git a/toolkit/mozapps/defaultagent/rust/Cargo.toml b/toolkit/mozapps/defaultagent/rust/Cargo.toml new file mode 100644 index 000000000000..992cf7908289 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "defaultagent-static" +version = "0.1.0" +authors = ["The Mozilla Install/Update Team "] +edition = "2018" +description = "FFI to Rust for use in Firefox's default browser agent." +repository = "https://github.com/mozilla/defaultagent-static" +license = "MPL-2.0" + +[dependencies] +log = { version = "0.4", features = ["std"] } +serde = "1.0" +serde_derive = "1.0" +serde_json = "1.0" +url = "2.1" +viaduct = { git = "https://github.com/mozilla/application-services", rev = "61dcc364ac0d6d0816ab88a494bbf20d824b009b"} +wineventlog = { path = "../../../components/updateagent/wineventlog"} +wio = "0.2" +winapi = { version = "0.3", features = ["errhandlingapi", "handleapi", "minwindef", "winerror", "wininet", "winuser"] } + +[lib] +crate-type = ["staticlib"] diff --git a/toolkit/mozapps/defaultagent/rust/moz.build b/toolkit/mozapps/defaultagent/rust/moz.build new file mode 100644 index 000000000000..8b290d067756 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/moz.build @@ -0,0 +1,7 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +RustLibrary('defaultagent-static') diff --git a/toolkit/mozapps/defaultagent/rust/src/lib.rs b/toolkit/mozapps/defaultagent/rust/src/lib.rs new file mode 100644 index 000000000000..54c4f348f57c --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/src/lib.rs @@ -0,0 +1,167 @@ +/* -*- Mode: rust; rust-indent-offset: 4 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#![allow(non_snake_case)] + +use std::ffi::{CStr, OsString}; +use std::os::raw::c_char; + +#[macro_use] +extern crate serde_derive; +#[macro_use] +extern crate log; +use winapi::shared::ntdef::HRESULT; +use winapi::shared::winerror::{HRESULT_FROM_WIN32, S_OK}; +use wio::wide::FromWide; + +mod viaduct_wininet; +use viaduct_wininet::WinInetBackend; + +// HRESULT with 0x80000000 is an error, 0x20000000 set is a customer error code. +#[allow(overflowing_literals)] +const HR_NETWORK_ERROR: HRESULT = 0x80000000 | 0x20000000 | 0x1; +#[allow(overflowing_literals)] +const HR_SETTINGS_ERROR: HRESULT = 0x80000000 | 0x20000000 | 0x2; + +#[derive(Debug, Deserialize)] +pub struct EnabledRecord { + // Unknown fields are ignored by serde: see the docs for `#[serde(deny_unknown_fields)]`. + pub(crate) enabled: bool, +} + +pub enum Error { + /// A backend error with an attached Windows error code from `GetLastError()`. + WindowsError(u32), + + /// A network or otherwise transient error. + NetworkError, + + /// A configuration or settings data error that probably requires code, configuration, or + /// server-side changes to address. + SettingsError, +} + +impl From for Error { + fn from(_err: viaduct::UnexpectedStatus) -> Self { + Error::NetworkError + } +} + +impl From for Error { + fn from(err: viaduct::Error) -> Self { + match err { + viaduct::Error::NetworkError(_) => Error::NetworkError, + viaduct::Error::BackendError(raw) => { + // If we have a string that's a hex error code like + // "0xabcde", that's a Windows error. + if raw.starts_with("0x") { + let without_prefix = raw.trim_start_matches("0x"); + let parse_result = u32::from_str_radix(without_prefix, 16); + if let Ok(parsed) = parse_result { + return Error::WindowsError(parsed); + } + } + Error::SettingsError + } + _ => Error::SettingsError, + } + } +} + +impl From for Error { + fn from(_err: serde_json::Error) -> Self { + Error::SettingsError + } +} + +impl From for Error { + fn from(_err: url::ParseError) -> Self { + Error::SettingsError + } +} + +fn is_agent_remote_disabled>(url: S) -> Result { + // Be careful setting the viaduct backend twice. If the backend + // has been set already, assume that it's our backend: we may as + // well at least try to continue. + match viaduct::set_backend(&WinInetBackend) { + Ok(_) => {} + Err(viaduct::Error::SetBackendError) => {} + e => e?, + } + + let url = url::Url::parse(url.as_ref())?; + let req = viaduct::Request::new(viaduct::Method::Get, url); + let resp = req.send()?; + + let resp = resp.require_success()?; + + let body: serde_json::Value = resp.json()?; + let data = body.get("data").ok_or(Error::SettingsError)?; + let record: EnabledRecord = serde_json::from_value(data.clone())?; + + let disabled = !record.enabled; + Ok(disabled) +} + +// This is an easy way to consume `MOZ_APP_DISPLAYNAME` from Rust code. +extern "C" { + #[no_mangle] + static gWinEventLogSourceName: *const u16; +} + +#[allow(dead_code)] +#[no_mangle] +extern "C" fn IsAgentRemoteDisabledRust(szUrl: *const c_char, lpdwDisabled: *mut u32) -> HRESULT { + let wineventlog_name = unsafe { OsString::from_wide_ptr_null(gWinEventLogSourceName) }; + let logger = wineventlog::EventLogger::new(&wineventlog_name); + // It's fine to initialize logging twice. + let _ = log::set_boxed_logger(Box::new(logger)); + log::set_max_level(log::LevelFilter::Info); + + // Use an IIFE for `?`. + let disabled_result = (|| { + if lpdwDisabled.is_null() { + return Err(Error::SettingsError); + } + + let url = unsafe { CStr::from_ptr(szUrl).to_str().map(|x| x.to_string()) } + .map_err(|_| Error::SettingsError)?; + + info!("Using remote settings URL: {}", url); + + is_agent_remote_disabled(url) + })(); + + match disabled_result { + Err(e) => { + return match e { + Error::WindowsError(errno) => { + let hr = HRESULT_FROM_WIN32(errno); + error!("Error::WindowsError({}) (HRESULT: 0x{:x})", errno, hr); + hr + } + Error::NetworkError => { + let hr = HR_NETWORK_ERROR; + error!("Error::NetworkError (HRESULT: 0x{:x})", hr); + hr + } + Error::SettingsError => { + let hr = HR_SETTINGS_ERROR; + error!("Error::SettingsError (HRESULT: 0x{:x})", hr); + hr + } + }; + } + + Ok(remote_disabled) => { + // We null-checked `lpdwDisabled` earlier, but just to be safe. + if !lpdwDisabled.is_null() { + unsafe { *lpdwDisabled = if remote_disabled { 1 } else { 0 } }; + } + return S_OK; + } + } +} diff --git a/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/internet_handle.rs b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/internet_handle.rs new file mode 100644 index 000000000000..85f4254c8839 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/internet_handle.rs @@ -0,0 +1,53 @@ +// Licensed under the Apache License, Version 2.0 +// or the MIT license +// , at your option. +// All files in the project carrying such notice may not be copied, modified, or distributed +// except according to those terms. + +//! Wrapping and automatically closing Internet handles. Copy-pasted from +//! [comedy-rs](https://github.com/agashlin/comedy-rs/blob/c244b91e9237c887f6a7bc6cd03db98b51966494/src/handle.rs). + +use winapi::shared::minwindef::DWORD; +use winapi::shared::ntdef::NULL; +use winapi::um::errhandlingapi::GetLastError; +use winapi::um::wininet::{InternetCloseHandle, HINTERNET}; + +/// Check and automatically close a Windows `HINTERNET`. +#[repr(transparent)] +#[derive(Debug)] +pub struct InternetHandle(HINTERNET); + +impl InternetHandle { + /// Take ownership of a `HINTERNET`, which will be closed with `InternetCloseHandle` upon drop. + /// Returns an error in case of `NULL`. + /// + /// # Safety + /// + /// `h` should be the only copy of the handle. `GetLastError()` is called to + /// return an error, so the last Windows API called on this thread should have been + /// what produced the invalid handle. + pub unsafe fn new(h: HINTERNET) -> Result { + if h == NULL { + Err(GetLastError()) + } else { + Ok(InternetHandle(h)) + } + } + + /// Obtains the raw `HINTERNET` without transferring ownership. + /// + /// Do __not__ close this handle because it is still owned by the `InternetHandle`. + /// + /// Do __not__ use this handle beyond the lifetime of the `InternetHandle`. + pub fn as_raw(&self) -> HINTERNET { + self.0 + } +} + +impl Drop for InternetHandle { + fn drop(&mut self) { + unsafe { + InternetCloseHandle(self.0); + } + } +} diff --git a/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/mod.rs b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/mod.rs new file mode 100644 index 000000000000..1abd170f8f06 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/mod.rs @@ -0,0 +1,257 @@ +// Licensed under the Apache License, Version 2.0 +// or the MIT license +// , at your option. +// All files in the project carrying such notice may not be copied, modified, or distributed +// except according to those terms. + +use winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER; +use winapi::um::errhandlingapi::GetLastError; +use winapi::um::wininet; +use wio::wide::ToWide; + +use viaduct::Backend; + +mod internet_handle; +use internet_handle::InternetHandle; + +pub struct WinInetBackend; + +/// Errors +fn to_viaduct_error(e: u32) -> viaduct::Error { + // Like "0xabcde". + viaduct::Error::BackendError(format!("{:#x}", e)) +} + +fn get_status(req: wininet::HINTERNET) -> Result { + let mut status: u32 = 0; + let mut size: u32 = std::mem::size_of::() as u32; + let result = unsafe { + wininet::HttpQueryInfoW( + req, + wininet::HTTP_QUERY_STATUS_CODE | wininet::HTTP_QUERY_FLAG_NUMBER, + &mut status as *mut _ as *mut _, + &mut size, + std::ptr::null_mut(), + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + + Ok(status as u16) +} + +fn get_headers(req: wininet::HINTERNET) -> Result { + // We follow https://docs.microsoft.com/en-us/windows/win32/wininet/retrieving-http-headers. + // + // Per + // https://docs.microsoft.com/en-us/windows/win32/api/wininet/nf-wininet-httpqueryinfoa: + // The `HttpQueryInfoA` function represents headers as ISO-8859-1 characters + // not ANSI characters. + let mut size: u32 = 0; + + let result = unsafe { + wininet::HttpQueryInfoA( + req, + wininet::HTTP_QUERY_RAW_HEADERS, + std::ptr::null_mut(), + &mut size, + std::ptr::null_mut(), + ) + }; + if 0 == result { + let error = unsafe { GetLastError() }; + if error == wininet::ERROR_HTTP_HEADER_NOT_FOUND { + return Ok(viaduct::Headers::new()); + } else if error != ERROR_INSUFFICIENT_BUFFER { + return Err(to_viaduct_error(error)); + } + } + + let mut buffer = vec![0 as u8; size as usize]; + let result = unsafe { + wininet::HttpQueryInfoA( + req, + wininet::HTTP_QUERY_RAW_HEADERS, + buffer.as_mut_ptr() as *mut _, + &mut size, + std::ptr::null_mut(), + ) + }; + if 0 == result { + let error = unsafe { GetLastError() }; + if error == wininet::ERROR_HTTP_HEADER_NOT_FOUND { + return Ok(viaduct::Headers::new()); + } else { + return Err(to_viaduct_error(error)); + } + } + + // The API returns all of the headers as a single char buffer in + // ISO-8859-1 encoding. Each header is terminated by '\0' and + // there's a trailing '\0' terminator as well. + // + // We want UTF-8. It's not worth include a non-trivial encoding + // library like `encoding_rs` just for these headers, so let's use + // the fact that ISO-8859-1 and UTF-8 intersect on the lower 7 bits + // and decode lossily. It will at least be reasonably clear when + // there is an encoding issue. + let allheaders = String::from_utf8_lossy(&buffer); + + let mut headers = viaduct::Headers::new(); + for header in allheaders.split(0 as char) { + let mut it = header.splitn(2, ":"); + if let (Some(name), Some(value)) = (it.next(), it.next()) { + headers.insert(name.trim().to_string(), value.trim().to_string())?; + } + } + + return Ok(headers); +} + +fn get_body(req: wininet::HINTERNET) -> Result, viaduct::Error> { + let mut body = Vec::new(); + + const BUFFER_SIZE: usize = 65535; + let mut buffer: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE]; + + loop { + let mut bytes_downloaded: u32 = 0; + let result = unsafe { + wininet::InternetReadFile( + req, + buffer.as_mut_ptr() as *mut _, + BUFFER_SIZE as u32, + &mut bytes_downloaded, + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + if bytes_downloaded == 0 { + break; + } + + body.extend_from_slice(&buffer[0..bytes_downloaded as usize]); + } + Ok(body) +} + +impl Backend for WinInetBackend { + fn send(&self, request: viaduct::Request) -> Result { + viaduct::note_backend("wininet.dll"); + + let request_method = request.method; + let url = request.url; + + let session = unsafe { + InternetHandle::new(wininet::InternetOpenW( + "DefaultAgent/1.0".to_wide_null().as_ptr(), + wininet::INTERNET_OPEN_TYPE_PRECONFIG, + std::ptr::null_mut(), + std::ptr::null_mut(), + 0, + )) + } + .map_err(to_viaduct_error)?; + + // Consider asserting the scheme here too, for documentation purposes. + // Viaduct itself only allows HTTPS at this time, but that might change. + let host = url + .host_str() + .ok_or(viaduct::Error::BackendError("no host".to_string()))?; + + let conn = unsafe { + InternetHandle::new(wininet::InternetConnectW( + session.as_raw(), + host.to_wide_null().as_ptr(), + wininet::INTERNET_DEFAULT_HTTPS_PORT as u16, + std::ptr::null_mut(), + std::ptr::null_mut(), + wininet::INTERNET_SERVICE_HTTP, + 0, + 0, + )) + } + .map_err(to_viaduct_error)?; + + let path = url[url::Position::BeforePath..].to_string(); + let req = unsafe { + wininet::HttpOpenRequestW( + conn.as_raw(), + request_method.as_str().to_wide_null().as_ptr(), + path.to_wide_null().as_ptr(), + std::ptr::null_mut(), /* lpszVersion */ + std::ptr::null_mut(), /* lpszReferrer */ + std::ptr::null_mut(), /* lplpszAcceptTypes */ + // Avoid the cache as best we can. + wininet::INTERNET_FLAG_NO_AUTH + | wininet::INTERNET_FLAG_NO_CACHE_WRITE + | wininet::INTERNET_FLAG_NO_COOKIES + | wininet::INTERNET_FLAG_NO_UI + | wininet::INTERNET_FLAG_PRAGMA_NOCACHE + | wininet::INTERNET_FLAG_RELOAD + | wininet::INTERNET_FLAG_SECURE, + 0, + ) + }; + if req.is_null() { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + + for header in request.headers { + // Per + // https://docs.microsoft.com/en-us/windows/win32/api/wininet/nf-wininet-httpaddrequestheadersw, + // "Each header must be terminated by a CR/LF (carriage return/line + // feed) pair." + let h = format!("{}: {}\r\n", header.name(), header.value()); + let result = unsafe { + wininet::HttpAddRequestHeadersW( + req, + h.to_wide_null().as_ptr(), /* lpszHeaders */ + -1i32 as u32, /* dwHeadersLength */ + wininet::HTTP_ADDREQ_FLAG_ADD | wininet::HTTP_ADDREQ_FLAG_REPLACE, /* dwModifiers */ + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + } + + // Future work: support sending a body. + if request.body.is_some() { + return Err(viaduct::Error::BackendError( + "non-empty body is not yet supported".to_string(), + )); + } + + let result = unsafe { + wininet::HttpSendRequestW( + req, + std::ptr::null_mut(), /* lpszHeaders */ + 0, /* dwHeadersLength */ + std::ptr::null_mut(), /* lpOptional */ + 0, /* dwOptionalLength */ + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + + let status = get_status(req)?; + let headers = get_headers(req)?; + + // Not all responses have a body. + let has_body = headers.get_header("content-type").is_some() + || headers.get_header("content-length").is_some(); + let body = if has_body { get_body(req)? } else { Vec::new() }; + + Ok(viaduct::Response { + request_method, + body, + url, + status, + headers, + }) + } +}