diff --git a/servo/components/net_traits/hosts.rs b/servo/components/net_traits/hosts.rs index 0a42be555725..c3c48ce99389 100644 --- a/servo/components/net_traits/hosts.rs +++ b/servo/components/net_traits/hosts.rs @@ -8,9 +8,10 @@ use std::env; use std::fs::File; use std::io::{BufReader, Read}; use std::net::IpAddr; +use std::sync::Mutex; lazy_static! { - static ref HOST_TABLE: Option> = create_host_table(); + static ref HOST_TABLE: Mutex>> = Mutex::new(create_host_table()); } fn create_host_table() -> Option> { @@ -34,13 +35,17 @@ fn create_host_table() -> Option> { return Some(parse_hostsfile(&lines)); } +pub fn replace_host_table(table: HashMap) { + *HOST_TABLE.lock().unwrap() = Some(table); +} + pub fn parse_hostsfile(hostsfile_content: &str) -> HashMap { let mut host_table = HashMap::new(); for line in hostsfile_content.split('\n') { - let ip_host: Vec<&str> = line.trim().split(|c: char| c == ' ' || c == '\t').collect(); - if ip_host.len() > 1 { - if let Ok(address) = ip_host[0].parse::() { - for token in ip_host.iter().skip(1) { + let mut ip_host = line.trim().split(|c: char| c == ' ' || c == '\t'); + if let Some(ip) = ip_host.next() { + if let Ok(address) = ip.parse::() { + for token in ip_host { if token.as_bytes()[0] == b'#' { break; } @@ -53,7 +58,7 @@ pub fn parse_hostsfile(hostsfile_content: &str) -> HashMap { } pub fn replace_hosts(url: &ServoUrl) -> ServoUrl { - HOST_TABLE.as_ref().map_or_else(|| url.clone(), |host_table| { + HOST_TABLE.lock().unwrap().as_ref().map_or_else(|| url.clone(), |host_table| { host_replacement(host_table, url) }) } diff --git a/servo/tests/unit/net/http_loader.rs b/servo/tests/unit/net/http_loader.rs index f3af14eeb6ac..e7b2055fa3c7 100644 --- a/servo/tests/unit/net/http_loader.rs +++ b/servo/tests/unit/net/http_loader.rs @@ -20,6 +20,7 @@ use hyper::method::Method; use hyper::mime::{Mime, SubLevel, TopLevel}; use hyper::server::{Request as HyperRequest, Response as HyperResponse}; use hyper::status::StatusCode; +use hyper::uri::RequestUri; use make_server; use msg::constellation_msg::{PipelineId, TEST_PIPELINE_ID}; use net::cookie::Cookie; @@ -31,11 +32,13 @@ use net::test::{HttpRequest, HttpRequestFactory, HttpState, LoadError, UIProvide use net::test::{HttpResponse, LoadErrorType}; use net_traits::{CookieSource, IncludeSubdomains, LoadContext, LoadData}; use net_traits::{CustomResponse, LoadOrigin, Metadata, NetworkError, ReferrerPolicy}; +use net_traits::hosts::replace_host_table; use net_traits::request::{Request, RequestInit, CredentialsMode, Destination}; use net_traits::response::ResponseBody; use new_fetch_context; use servo_url::ServoUrl; use std::borrow::Cow; +use std::collections::HashMap; use std::io::{self, Cursor, Read, Write}; use std::rc::Rc; use std::sync::{Arc, Mutex, RwLock, mpsc}; @@ -1386,75 +1389,79 @@ fn test_load_errors_when_cancelled() { #[test] fn test_redirect_from_x_to_y_provides_y_cookies_from_y() { - let url_x = ServoUrl::parse("http://mozilla.com").unwrap(); - let url_y = ServoUrl::parse("http://mozilla.org").unwrap(); - - struct Factory; - - impl HttpRequestFactory for Factory { - type R = MockRequest; - - fn create(&self, url: ServoUrl, _: Method, headers: Headers) -> Result { - if url.domain().unwrap() == "mozilla.com" { - let mut expected_headers_x = Headers::new(); - expected_headers_x.set_raw("Cookie".to_owned(), - vec![<[_]>::to_vec("mozillaIsNot=dotCom".as_bytes())]); - assert_headers_included(&expected_headers_x, &headers); - - Ok(MockRequest::new( - ResponseType::Redirect("http://mozilla.org".to_owned()))) - } else if url.domain().unwrap() == "mozilla.org" { - let mut expected_headers_y = Headers::new(); - expected_headers_y.set_raw( - "Cookie".to_owned(), - vec![<[_]>::to_vec("mozillaIs=theBest".as_bytes())]); - assert_headers_included(&expected_headers_y, &headers); - - Ok(MockRequest::new( - ResponseType::Text(<[_]>::to_vec("Yay!".as_bytes())))) - } else { - panic!("unexpected host {:?}", url) - } + let shared_url_y = Arc::new(Mutex::new(None::)); + let shared_url_y_clone = shared_url_y.clone(); + let handler = move |request: HyperRequest, mut response: HyperResponse| { + let path = match request.uri { + RequestUri::AbsolutePath(path) => path, + uri => panic!("Unexpected uri: {:?}", uri), + }; + if path == "/com/" { + assert_eq!(request.headers.get(), + Some(&CookieHeader(vec![CookiePair::new("mozillaIsNot".to_owned(), "dotOrg".to_owned())]))); + let location = shared_url_y.lock().unwrap().as_ref().unwrap().to_string(); + response.headers_mut().set(Location(location)); + *response.status_mut() = StatusCode::MovedPermanently; + response.send(b"").unwrap(); + } else if path == "/org/" { + assert_eq!(request.headers.get(), + Some(&CookieHeader(vec![CookiePair::new("mozillaIs".to_owned(), "theBest".to_owned())]))); + response.send(b"Yay!").unwrap(); + } else { + panic!("unexpected path {:?}", path) } - } + }; + let (mut server, url) = make_server(handler); + let port = url.port().unwrap(); - let load_data = LoadData::new(LoadContext::Browsing, url_x.clone(), &HttpTest); + assert_eq!(url.host_str(), Some("localhost")); + let ip = "127.0.0.1".parse().unwrap(); + let mut host_table = HashMap::new(); + host_table.insert("mozilla.com".to_owned(), ip); + host_table.insert("mozilla.org".to_owned(), ip); - let http_state = HttpState::new(); - let ui_provider = TestProvider::new(); + replace_host_table(host_table); + let url_x = ServoUrl::parse(&format!("http://mozilla.com:{}/com/", port)).unwrap(); + let url_y = ServoUrl::parse(&format!("http://mozilla.org:{}/org/", port)).unwrap(); + *shared_url_y_clone.lock().unwrap() = Some(url_y.clone()); + + let context = new_fetch_context(None); { - let mut cookie_jar = http_state.cookie_jar.write().unwrap(); - let cookie_x_url = url_x.clone(); + let mut cookie_jar = context.state.cookie_jar.write().unwrap(); let cookie_x = Cookie::new_wrapped( - CookiePair::new("mozillaIsNot".to_owned(), "dotCom".to_owned()), - &cookie_x_url, + CookiePair::new("mozillaIsNot".to_owned(), "dotOrg".to_owned()), + &url_x, CookieSource::HTTP ).unwrap(); cookie_jar.push(cookie_x, CookieSource::HTTP); - let cookie_y_url = url_y.clone(); let cookie_y = Cookie::new_wrapped( CookiePair::new("mozillaIs".to_owned(), "theBest".to_owned()), - &cookie_y_url, + &url_y, CookieSource::HTTP ).unwrap(); cookie_jar.push(cookie_y, CookieSource::HTTP); } - match load(&load_data, - &ui_provider, &http_state, - None, - &Factory, - DEFAULT_USER_AGENT.into(), - &CancellationListener::new(None), None) { - Err(e) => panic!("expected to follow a redirect {:?}", e), - Ok(mut lr) => { - let response = read_response(&mut lr); - assert_eq!(response, "Yay!".to_owned()); - } - } + let request = Request::from_init(RequestInit { + url: url_x.clone(), + method: Method::Get, + destination: Destination::Document, + origin: url_x.clone(), + pipeline_id: Some(TEST_PIPELINE_ID), + credentials_mode: CredentialsMode::Include, + .. RequestInit::default() + }); + let response = fetch(Rc::new(request), &mut None, &context); + + let _ = server.close(); + + let response = response.to_actual(); + assert!(response.status.unwrap().is_success()); + assert_eq!(*response.body.lock().unwrap(), + ResponseBody::Done(b"Yay!".to_vec())); } #[test]