/* * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ use dashmap::DashMap; use flexi_logger::{Logger, LoggerHandle}; use futures::future::select; use futures::{pin_mut, FutureExt}; use futures::{SinkExt, StreamExt}; use http_auth_basic::Credentials; use notify_push::config::{Bind, Config}; use notify_push::message::DEBOUNCE_ENABLE; use notify_push::{listen_loop, serve, App}; use once_cell::sync::Lazy; use redis::AsyncCommands; use smallvec::alloc::sync::Arc; use sqlx::AnyPool; use std::net::SocketAddr; use std::sync::atomic::{AtomicU16, Ordering}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; use tokio::task::spawn; use tokio::time::timeout; use tokio::time::{sleep, Duration}; use tokio_stream::wrappers::TcpListenerStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use warp::http::StatusCode; use warp::{Filter, Reply}; static LAST_PORT: AtomicU16 = AtomicU16::new(1024); async fn listen_available_port() -> Option { for _ in LAST_PORT.load(Ordering::SeqCst)..65535 { let port = LAST_PORT.fetch_add(1, Ordering::SeqCst); if let Ok(tcp) = TcpListener::bind(("127.0.0.1", port)).await { return Some(tcp); } } None } struct Services { redis: SocketAddr, nextcloud: SocketAddr, _redis_shutdown: oneshot::Sender<()>, _nextcloud_shutdown: oneshot::Sender<()>, users: Arc>, db: AnyPool, } static LOG_HANDLE: Lazy = Lazy::new(|| Logger::try_with_str("").unwrap().start().unwrap()); impl Services { pub async fn new() -> Self { sqlx::any::install_default_drivers(); DEBOUNCE_ENABLE.store(false, Ordering::SeqCst); let redis_tcp = listen_available_port() .await .expect("Can't find open port for redis"); let nextcloud_tcp = listen_available_port() .await .expect("Can't find open port for nextcloud mock"); let redis_addr = redis_tcp .local_addr() .expect("Failed to get redis socket address"); let nextcloud_addr = nextcloud_tcp .local_addr() .expect("Failed to get nextcloud mock socket address"); // use the port in the db name to prevent collisions let db = AnyPool::connect(&format!( "sqlite:file:memory{}?mode=memory&cache=shared", nextcloud_addr.port() )) .await .expect("Failed to connect sqlite database"); sqlx::query("CREATE TABLE oc_filecache(fileid BIGINT, path TEXT)") .execute(&db) .await .unwrap(); sqlx::query("CREATE INDEX fc_id ON oc_filecache (fileid)") .execute(&db) .await .unwrap(); sqlx::query("CREATE TABLE oc_mounts(storage_id BIGINT, root_id BIGINT, user_id TEXT)") .execute(&db) .await .unwrap(); sqlx::query("CREATE INDEX mount_storage ON oc_mounts (storage_id)") .execute(&db) .await .unwrap(); sqlx::query("CREATE INDEX mount_root ON oc_mounts (root_id)") .execute(&db) .await .unwrap(); let users: Arc> = Arc::default(); let users_filter = users.clone(); let users_filter = warp::any().map(move || users_filter.clone()); let uid = warp::any() .and(warp::header::("authorization")) .and(users_filter) .map(|auth, users: Arc>| { let credentials = match Credentials::from_header(auth) { Ok(credentials) => credentials, Err(_) => return Box::new(StatusCode::BAD_REQUEST) as Box, }; match users.get(&credentials.user_id) { Some(pass) if pass.value() == &credentials.password => { Box::new(credentials.user_id) } _ => Box::new(StatusCode::UNAUTHORIZED), } }); let (redis_shutdown, redis_shutdown_rx) = oneshot::channel(); let (nextcloud_shutdown, nextcloud_shutdown_rx) = oneshot::channel(); spawn(async move { warp::serve(uid) .serve_incoming_with_graceful_shutdown( TcpListenerStream::new(nextcloud_tcp), nextcloud_shutdown_rx.map(|_| ()), ) .await; }); spawn(async move { mini_redis::server::run(redis_tcp, redis_shutdown_rx) .await .ok(); }); Self { redis: redis_addr, nextcloud: nextcloud_addr, _redis_shutdown: redis_shutdown, _nextcloud_shutdown: nextcloud_shutdown, users, db, } } fn config(&self) -> Config { Config { database: "sqlite::memory:?cache=shared".parse().unwrap(), database_prefix: "oc_".to_string(), redis: vec![format!("redis://{}", self.redis).parse().unwrap()], nextcloud_url: format!("http://{}/", self.nextcloud), metrics_bind: None, log_level: "".to_string(), bind: Bind::Tcp(self.nextcloud), allow_self_signed: false, no_ansi: false, tls: None, max_debounce_time: 15, max_connection_time: 0, } } async fn app(&self) -> App { let config = self.config(); App::with_connection(self.db.clone(), config, LOG_HANDLE.clone(), false) .await .unwrap() } async fn spawn_server(&self) -> ServerHandle { let app = Arc::new(self.app().await); let addr = async { let tcp = listen_available_port().await.unwrap(); tcp.local_addr() } .await .unwrap(); let (serve_tx, serve_rx) = oneshot::channel(); let (listen_tx, listen_rx) = oneshot::channel(); let bind = Bind::Tcp(addr); spawn(async move { let serve = serve(app.clone(), bind, serve_rx, None, 15, 0).unwrap(); let listen = listen_loop(app.clone(), listen_rx); pin_mut!(serve); pin_mut!(listen); select(serve, listen).await; }); sleep(Duration::from_millis(10)).await; ServerHandle { _serve_handle: serve_tx, _listen_handle: listen_tx, port: addr.port(), } } async fn redis_client(&self) -> redis::aio::MultiplexedConnection { let client = redis::Client::open(self.config().redis.first().unwrap().clone()).unwrap(); client.get_multiplexed_async_connection().await.unwrap() } fn add_user(&self, username: &str, password: &str) { self.users.insert(username.into(), password.into()); } async fn add_storage_mapping(&self, username: &str, storage: u32, root: u32) { sqlx::query("INSERT INTO oc_mounts(storage_id, root_id, user_id) VALUES(?, ?, ?)") .bind(storage as i64) .bind(root as i64) .bind(username) .execute(&self.db) .await .unwrap(); } async fn add_filecache_item(&self, fileid: u32, path: &str) { sqlx::query("INSERT INTO oc_filecache(fileid, path) VALUES(?, ?)") .bind(fileid as i64) .bind(path) .execute(&self.db) .await .unwrap(); } } struct ServerHandle { _serve_handle: oneshot::Sender<()>, _listen_handle: oneshot::Sender<()>, port: u16, } impl ServerHandle { async fn connect(&self) -> WebSocketStream> { tokio_tungstenite::connect_async(format!("ws://127.0.0.1:{}/ws", self.port)) .await .unwrap() .0 } async fn connect_auth( &self, username: &str, password: &str, ) -> WebSocketStream> { let mut client = tokio_tungstenite::connect_async(format!("ws://127.0.0.1:{}/ws", self.port)) .await .unwrap() .0; client.send(Message::Text(username.into())).await.unwrap(); client.send(Message::Text(password.into())).await.unwrap(); assert_next_message(&mut client, "authenticated").await; client } } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_auth() { let services = Services::new().await; services.add_user("foo", "bar"); let server_handle = services.spawn_server().await; let mut client = server_handle.connect().await; client.send(Message::Text("foo".into())).await.unwrap(); client.send(Message::Text("bar".into())).await.unwrap(); assert_next_message(&mut client, "authenticated").await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_auth_failure() { let services = Services::new().await; services.add_user("foo", "bar"); let server_handle = services.spawn_server().await; let mut client = server_handle.connect().await; client.send(Message::Text("foo".into())).await.unwrap(); client.send(Message::Text("not_bar".into())).await.unwrap(); assert_next_message(&mut client, "err: Invalid credentials").await; } async fn assert_next_message( client: &mut WebSocketStream>, expected: &str, ) { sleep(Duration::from_millis(100)).await; assert_eq!( timeout(Duration::from_millis(200), client.next()) .await .unwrap() .unwrap() .unwrap(), Message::Text(expected.to_string()) ); } async fn assert_no_message(client: &mut WebSocketStream>) { sleep(Duration::from_millis(5)).await; assert!(timeout(Duration::from_millis(10), client.next()) .await .is_err()); } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_activity() { let services = Services::new().await; services.add_user("foo", "bar"); let server_handle = services.spawn_server().await; let mut client = server_handle.connect_auth("foo", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>("notify_activity", r#"{"user":"foo"}"#) .await .unwrap(); assert_next_message(&mut client, "notify_activity").await; std::mem::forget(services); } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_activity_other_user() { let services = Services::new().await; services.add_user("foo", "bar"); let server_handle = services.spawn_server().await; let mut client = server_handle.connect_auth("foo", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>("notify_activity", r#"{"user":"someone_else"}"#) .await .unwrap(); assert_no_message(&mut client).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_file() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_filecache_item(10, "foo").await; services.add_filecache_item(11, "foo/bar").await; services.add_storage_mapping("foo", 10, 11).await; let server_handle = services.spawn_server().await; let mut client = server_handle.connect_auth("foo", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>( "notify_storage_update", r#"{"storage":10, "path":"foo/bar", "file_id":5}"#, ) .await .unwrap(); assert_next_message(&mut client, "notify_file").await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_file_different_storage() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_filecache_item(10, "foo").await; services.add_filecache_item(11, "foo/bar").await; services.add_storage_mapping("foo", 10, 11).await; let server_handle = services.spawn_server().await; let mut client = server_handle.connect_auth("foo", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>( "notify_storage_update", r#"{"storage":11, "path":"foo/bar", "file_id":5}"#, ) .await .unwrap(); assert_no_message(&mut client).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_file_multiple() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_user("foo2", "bar"); services.add_user("foo3", "bar"); services.add_filecache_item(10, "foo").await; services.add_filecache_item(11, "foo/bar").await; services.add_filecache_item(12, "foo/outside").await; services.add_storage_mapping("foo", 10, 10).await; services.add_storage_mapping("foo2", 10, 11).await; services.add_storage_mapping("foo2", 10, 12).await; let server_handle = services.spawn_server().await; let mut client1 = server_handle.connect_auth("foo", "bar").await; let mut client2 = server_handle.connect_auth("foo2", "bar").await; let mut client3 = server_handle.connect_auth("foo3", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>( "notify_storage_update", r#"{"storage":10, "path":"foo/bar", "file_id":5}"#, ) .await .unwrap(); assert_next_message(&mut client1, "notify_file").await; assert_next_message(&mut client2, "notify_file").await; assert_no_message(&mut client3).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_pre_auth() { let services = Services::new().await; let server_handle = services.spawn_server().await; sleep(Duration::from_millis(500)).await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>("notify_pre_auth", r#"{"user":"foo", "token": "token"}"#) .await .unwrap(); sleep(Duration::from_millis(100)).await; let mut client = server_handle.connect_auth("", "token").await; // verify that we are the correct user redis .publish::<_, _, ()>("notify_activity", r#"{"user":"foo"}"#) .await .unwrap(); assert_next_message(&mut client, "notify_activity").await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_notification() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_user("foo2", "bar"); let server_handle = services.spawn_server().await; let mut client1 = server_handle.connect_auth("foo", "bar").await; let mut client2 = server_handle.connect_auth("foo2", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>("notify_notification", r#"{"user":"foo"}"#) .await .unwrap(); assert_next_message(&mut client1, "notify_notification").await; assert_no_message(&mut client2).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_share() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_user("foo2", "bar"); let server_handle = services.spawn_server().await; let mut client1 = server_handle.connect_auth("foo", "bar").await; let mut client2 = server_handle.connect_auth("foo2", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>("notify_user_share_created", r#"{"user":"foo"}"#) .await .unwrap(); assert_next_message(&mut client1, "notify_file").await; assert_no_message(&mut client2).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_group() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_user("foo2", "bar"); let server_handle = services.spawn_server().await; let mut client1 = server_handle.connect_auth("foo", "bar").await; let mut client2 = server_handle.connect_auth("foo2", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>( "notify_group_membership_update", r#"{"user":"foo", "group":"asd"}"#, ) .await .unwrap(); assert_next_message(&mut client1, "notify_file").await; assert_no_message(&mut client2).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_custom() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_user("foo2", "bar"); let server_handle = services.spawn_server().await; let mut client1 = server_handle.connect_auth("foo", "bar").await; let mut client2 = server_handle.connect_auth("foo2", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>( "notify_custom", r#"{"user":"foo", "message":"my_custom_message"}"#, ) .await .unwrap(); assert_next_message(&mut client1, "my_custom_message").await; assert_no_message(&mut client2).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_notify_custom_body() { let services = Services::new().await; services.add_user("foo", "bar"); services.add_user("foo2", "bar"); let server_handle = services.spawn_server().await; let mut client1 = server_handle.connect_auth("foo", "bar").await; let mut client2 = server_handle.connect_auth("foo2", "bar").await; let mut redis = services.redis_client().await; redis .publish::<_, _, ()>( "notify_custom", r#"{"user":"foo", "message":"my_custom_message", "body": [1,2,3]}"#, ) .await .unwrap(); assert_next_message(&mut client1, "my_custom_message [1,2,3]").await; assert_no_message(&mut client2).await; }