cli: add streams to rpc, generic 'spawn' command (#179732)
* cli: apply improvements from integrated wsl branch * cli: add streams to rpc, generic 'spawn' command For the "exec server" concept, fyi @aeschli. * update clippy and apply fixes * fix unused imports :(
This commit is contained in:
Родитель
bb7570f4f8
Коммит
2d8ff25c85
|
@ -146,9 +146,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
|
|||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.2.1"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db"
|
||||
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
|
||||
|
||||
[[package]]
|
||||
name = "cache-padded"
|
||||
|
@ -230,6 +230,7 @@ dependencies = [
|
|||
"async-trait",
|
||||
"atty",
|
||||
"base64",
|
||||
"bytes",
|
||||
"cfg-if",
|
||||
"chrono",
|
||||
"clap",
|
||||
|
|
|
@ -17,7 +17,7 @@ clap = { version = "3.0", features = ["derive", "env"] }
|
|||
open = { version = "2.1.0" }
|
||||
reqwest = { version = "0.11.9", default-features = false, features = ["json", "stream", "native-tls"] }
|
||||
tokio = { version = "1.24.2", features = ["full"] }
|
||||
tokio-util = { version = "0.7", features = ["compat"] }
|
||||
tokio-util = { version = "0.7", features = ["compat", "codec"] }
|
||||
flate2 = { version = "1.0.22" }
|
||||
zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] }
|
||||
regex = { version = "1.5.5" }
|
||||
|
@ -54,6 +54,7 @@ thiserror = "1.0"
|
|||
cfg-if = "1.0.0"
|
||||
pin-project = "1.0"
|
||||
console = "0.15"
|
||||
bytes = "1.4"
|
||||
|
||||
[build-dependencies]
|
||||
serde = { version = "1.0" }
|
||||
|
|
|
@ -190,7 +190,7 @@ pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Resul
|
|||
let auth = Auth::new(&ctx.paths, ctx.log.clone());
|
||||
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
|
||||
dt.rename_tunnel(&rename_args.name).await?;
|
||||
ctx.log.result(&format!(
|
||||
ctx.log.result(format!(
|
||||
"Successfully renamed this gateway to {}",
|
||||
&rename_args.name
|
||||
));
|
||||
|
@ -287,7 +287,7 @@ pub async fn prune(ctx: CommandContext) -> Result<i32, AnyError> {
|
|||
.filter(|s| s.get_running_pid().is_none())
|
||||
.try_for_each(|s| {
|
||||
ctx.log
|
||||
.result(&format!("Deleted {}", s.server_dir.display()));
|
||||
.result(format!("Deleted {}", s.server_dir.display()));
|
||||
s.delete()
|
||||
})
|
||||
.map_err(AnyError::from)?;
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use indicatif::ProgressBar;
|
||||
|
||||
use crate::{
|
||||
|
@ -17,7 +19,7 @@ use super::{args::StandaloneUpdateArgs, CommandContext};
|
|||
pub async fn update(ctx: CommandContext, args: StandaloneUpdateArgs) -> Result<i32, AnyError> {
|
||||
let update_service = UpdateService::new(
|
||||
ctx.log.clone(),
|
||||
ReqwestSimpleHttp::with_client(ctx.http.clone()),
|
||||
Arc::new(ReqwestSimpleHttp::with_client(ctx.http.clone())),
|
||||
);
|
||||
let update_service = SelfUpdate::new(&update_service)?;
|
||||
|
||||
|
|
|
@ -58,5 +58,5 @@ pub async fn show(ctx: CommandContext) -> Result<i32, AnyError> {
|
|||
}
|
||||
|
||||
fn print_now_using(log: &log::Logger, version: &RequestedVersion, path: &Path) {
|
||||
log.result(&format!("Now using {} from {}", version, path.display()));
|
||||
log.result(format!("Now using {} from {}", version, path.display()));
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
|
|||
mut msg_rx: impl Receivable<Vec<u8>>,
|
||||
mut shutdown_rx: Barrier<S>,
|
||||
) -> io::Result<Option<S>> {
|
||||
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
|
||||
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
|
||||
let mut read = BufReader::new(read);
|
||||
|
||||
let mut read_buf = String::new();
|
||||
|
@ -84,7 +84,18 @@ pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
|
|||
let write_tx = write_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Some(v) = fut.await {
|
||||
write_tx.send(v).ok();
|
||||
let _ = write_tx.send(v).await;
|
||||
}
|
||||
});
|
||||
},
|
||||
MaybeSync::Stream((dto, fut)) => {
|
||||
if let Some(dto) = dto {
|
||||
dispatcher.register_stream(write_tx.clone(), dto).await;
|
||||
}
|
||||
let write_tx = write_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Some(v) = fut.await {
|
||||
let _ = write_tx.send(v).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
@ -27,21 +27,19 @@ pub fn next_counter() -> u32 {
|
|||
|
||||
// Log level
|
||||
#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub enum Level {
|
||||
Trace = 0,
|
||||
Debug,
|
||||
Info,
|
||||
#[default]
|
||||
Info,
|
||||
Warn,
|
||||
Error,
|
||||
Critical,
|
||||
Off,
|
||||
}
|
||||
|
||||
impl Default for Level {
|
||||
fn default() -> Self {
|
||||
Level::Info
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl fmt::Display for Level {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
|
|
|
@ -8,6 +8,7 @@ use tokio::{
|
|||
pin,
|
||||
sync::mpsc,
|
||||
};
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
use crate::{
|
||||
rpc::{self, MaybeSync, Serialization},
|
||||
|
@ -38,7 +39,6 @@ pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
|
|||
rpc::RpcBuilder::new(MsgPackSerializer {})
|
||||
}
|
||||
|
||||
#[allow(clippy::read_zero_byte_vec)] // false positive
|
||||
pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
|
||||
dispatcher: rpc::RpcDispatcher<MsgPackSerializer, C>,
|
||||
read: impl AsyncRead + Unpin,
|
||||
|
@ -46,34 +46,45 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
|
|||
mut msg_rx: impl Receivable<Vec<u8>>,
|
||||
mut shutdown_rx: Barrier<S>,
|
||||
) -> io::Result<Option<S>> {
|
||||
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
|
||||
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
|
||||
let mut read = BufReader::new(read);
|
||||
let mut decode_buf = vec![];
|
||||
let mut decoder = U32PrefixedCodec {};
|
||||
let mut decoder_buf = bytes::BytesMut::new();
|
||||
|
||||
let shutdown_fut = shutdown_rx.wait();
|
||||
pin!(shutdown_fut);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
u = read.read_u32() => {
|
||||
let msg_length = u? as usize;
|
||||
decode_buf.resize(msg_length, 0);
|
||||
tokio::select! {
|
||||
r = read.read_exact(&mut decode_buf) => match dispatcher.dispatch(&decode_buf[..r?]) {
|
||||
r = read.read_buf(&mut decoder_buf) => {
|
||||
r?;
|
||||
|
||||
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
|
||||
match dispatcher.dispatch(&frame) {
|
||||
MaybeSync::Sync(Some(v)) => {
|
||||
write_tx.send(v).ok();
|
||||
let _ = write_tx.send(v).await;
|
||||
},
|
||||
MaybeSync::Sync(None) => continue,
|
||||
MaybeSync::Future(fut) => {
|
||||
let write_tx = write_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Some(v) = fut.await {
|
||||
write_tx.send(v).ok();
|
||||
let _ = write_tx.send(v).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
r = &mut shutdown_fut => return Ok(r.ok()),
|
||||
MaybeSync::Stream((stream, fut)) => {
|
||||
if let Some(stream) = stream {
|
||||
dispatcher.register_stream(write_tx.clone(), stream).await;
|
||||
}
|
||||
let write_tx = write_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Some(v) = fut.await {
|
||||
let _ = write_tx.send(v).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
},
|
||||
Some(m) = write_rx.recv() => {
|
||||
|
@ -88,3 +99,33 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
|
|||
write.flush().await?;
|
||||
}
|
||||
}
|
||||
|
||||
/// Reader that reads length-prefixed msgpack messages in a cancellation-safe
|
||||
/// way using Tokio's codecs.
|
||||
pub struct U32PrefixedCodec {}
|
||||
|
||||
const U32_SIZE: usize = 4;
|
||||
|
||||
impl tokio_util::codec::Decoder for U32PrefixedCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if src.len() < 4 {
|
||||
src.reserve(U32_SIZE - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut be_bytes = [0; U32_SIZE];
|
||||
be_bytes.copy_from_slice(&src[..U32_SIZE]);
|
||||
let required_len = U32_SIZE + (u32::from_be_bytes(be_bytes) as usize);
|
||||
if src.len() < required_len {
|
||||
src.reserve(required_len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let msg = src[U32_SIZE..].to_vec();
|
||||
src.resize(0, 0);
|
||||
Ok(Some(msg))
|
||||
}
|
||||
}
|
||||
|
|
218
cli/src/rpc.rs
218
cli/src/rpc.rs
|
@ -15,17 +15,26 @@ use std::{
|
|||
use crate::log;
|
||||
use futures::{future::BoxFuture, Future, FutureExt};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
|
||||
sync::{mpsc, oneshot},
|
||||
};
|
||||
|
||||
use crate::util::errors::AnyError;
|
||||
|
||||
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> Option<Vec<u8>>>;
|
||||
pub type AsyncMethod =
|
||||
Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>>>;
|
||||
pub type Duplex = Arc<
|
||||
dyn Send
|
||||
+ Sync
|
||||
+ Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>),
|
||||
>;
|
||||
|
||||
pub enum Method {
|
||||
Sync(SyncMethod),
|
||||
Async(AsyncMethod),
|
||||
Duplex(Duplex),
|
||||
}
|
||||
|
||||
/// Serialization is given to the RpcBuilder and defines how data gets serialized
|
||||
|
@ -81,6 +90,12 @@ pub struct RpcMethodBuilder<S, C> {
|
|||
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DuplexStreamStarted {
|
||||
pub for_request_id: u32,
|
||||
pub stream_id: u32,
|
||||
}
|
||||
|
||||
impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
/// Registers a synchronous rpc call that returns its result directly.
|
||||
pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
|
||||
|
@ -179,14 +194,105 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
|||
);
|
||||
}
|
||||
|
||||
/// Registers an async rpc call that returns a Future containing a duplex
|
||||
/// stream that should be handled by the client.
|
||||
pub fn register_duplex<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
|
||||
where
|
||||
P: DeserializeOwned + Send + 'static,
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
Fut: Future<Output = Result<R, AnyError>> + Send,
|
||||
F: (Fn(DuplexStream, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
|
||||
{
|
||||
let serial = self.serializer.clone();
|
||||
let context = self.context.clone();
|
||||
self.methods.insert(
|
||||
method_name,
|
||||
Method::Duplex(Arc::new(move |id, body| {
|
||||
let param = match serial.deserialize::<RequestParams<P>>(body) {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
return (
|
||||
None,
|
||||
future::ready(id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: 0,
|
||||
message: format!("{:?}", err),
|
||||
},
|
||||
})
|
||||
}))
|
||||
.boxed(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let callback = callback.clone();
|
||||
let serial = serial.clone();
|
||||
let context = context.clone();
|
||||
let stream_id = next_message_id();
|
||||
let (client, server) = tokio::io::duplex(8192);
|
||||
|
||||
let fut = async move {
|
||||
match callback(server, param.params, context).await {
|
||||
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
|
||||
Err(err) => id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: -1,
|
||||
message: format!("{:?}", err),
|
||||
},
|
||||
})
|
||||
}),
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
Some(StreamDto {
|
||||
req_id: id.unwrap_or(0),
|
||||
stream_id,
|
||||
duplex: client,
|
||||
}),
|
||||
fut.boxed(),
|
||||
)
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
/// Builds into a usable, sync rpc dispatcher.
|
||||
pub fn build(self, log: log::Logger) -> RpcDispatcher<S, C> {
|
||||
pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
|
||||
let streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>> =
|
||||
Arc::new(tokio::sync::Mutex::new(HashMap::new()));
|
||||
|
||||
let s1 = streams.clone();
|
||||
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
|
||||
let s1 = s1.clone();
|
||||
async move {
|
||||
s1.lock().await.remove(&m.stream);
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
let s2 = streams.clone();
|
||||
self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
|
||||
let s2 = s2.clone();
|
||||
async move {
|
||||
let mut lock = s2.lock().await;
|
||||
if let Some(stream) = lock.get_mut(&m.stream) {
|
||||
let _ = stream.write_all(&m.segment).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
RpcDispatcher {
|
||||
log,
|
||||
context: self.context,
|
||||
calls: self.calls,
|
||||
serializer: self.serializer,
|
||||
methods: Arc::new(self.methods),
|
||||
streams,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -281,6 +387,7 @@ pub struct RpcDispatcher<S, C> {
|
|||
serializer: Arc<S>,
|
||||
methods: Arc<HashMap<&'static str, Method>>,
|
||||
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
|
||||
streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>>,
|
||||
}
|
||||
|
||||
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
|
||||
|
@ -310,6 +417,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
|
|||
match method {
|
||||
Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
|
||||
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
|
||||
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
|
||||
None => MaybeSync::Sync(id.map(|id| {
|
||||
self.serializer.serialize(&ErrorResponse {
|
||||
id,
|
||||
|
@ -333,11 +441,91 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Registers a stream call returned from dispatch().
|
||||
pub async fn register_stream(
|
||||
&self,
|
||||
write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
|
||||
dto: StreamDto,
|
||||
) {
|
||||
let stream_id = dto.stream_id;
|
||||
let for_request_id = dto.req_id;
|
||||
let (mut read, write) = tokio::io::split(dto.duplex);
|
||||
let serial = self.serializer.clone();
|
||||
|
||||
self.streams.lock().await.insert(dto.stream_id, write);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let r = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_STARTED,
|
||||
params: DuplexStreamStarted {
|
||||
stream_id,
|
||||
for_request_id,
|
||||
},
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if r.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut buf = Vec::with_capacity(4096);
|
||||
loop {
|
||||
match read.read_buf(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => {
|
||||
let r = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_DATA,
|
||||
params: StreamDataParams {
|
||||
segment: &buf[..n],
|
||||
stream: stream_id,
|
||||
},
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if r.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
buf.truncate(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_ENDED,
|
||||
params: StreamEndedParams { stream: stream_id },
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
pub fn context(&self) -> Arc<C> {
|
||||
self.context.clone()
|
||||
}
|
||||
}
|
||||
|
||||
const METHOD_STREAM_STARTED: &str = "stream_started";
|
||||
const METHOD_STREAM_DATA: &str = "stream_data";
|
||||
const METHOD_STREAM_ENDED: &str = "stream_ended";
|
||||
|
||||
trait AssertIsSync: Sync {}
|
||||
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}
|
||||
|
||||
|
@ -349,6 +537,25 @@ struct PartialIncoming {
|
|||
pub error: Option<ResponseError>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StreamDataIncomingParams {
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub segment: Vec<u8>,
|
||||
pub stream: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct StreamDataParams<'a> {
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub segment: &'a [u8],
|
||||
pub stream: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct StreamEndedParams {
|
||||
pub stream: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct FullRequest<M: AsRef<str>, P> {
|
||||
pub id: Option<u32>,
|
||||
|
@ -384,7 +591,14 @@ enum Outcome {
|
|||
Error(ResponseError),
|
||||
}
|
||||
|
||||
pub struct StreamDto {
|
||||
stream_id: u32,
|
||||
req_id: u32,
|
||||
duplex: DuplexStream,
|
||||
}
|
||||
|
||||
pub enum MaybeSync {
|
||||
Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
|
||||
Future(BoxFuture<'static, Option<Vec<u8>>>),
|
||||
Sync(Option<Vec<u8>>),
|
||||
}
|
||||
|
|
|
@ -86,8 +86,8 @@ impl<'a> SelfUpdate<'a> {
|
|||
// Try to rename the old CLI to the tempdir, where it can get cleaned up by the
|
||||
// OS later. However, this can fail if the tempdir is on a different drive
|
||||
// than the installation dir. In this case just rename it to ".old".
|
||||
if fs::rename(&target_path, &tempdir.path().join("old-code-cli")).is_err() {
|
||||
fs::rename(&target_path, &target_path.with_extension(".old"))
|
||||
if fs::rename(&target_path, tempdir.path().join("old-code-cli")).is_err() {
|
||||
fs::rename(&target_path, target_path.with_extension(".old"))
|
||||
.map_err(|e| wrap(e, "failed to rename old CLI"))?;
|
||||
}
|
||||
|
||||
|
@ -132,7 +132,7 @@ fn copy_updated_cli_to_path(unzipped_content: &Path, staging_path: &Path) -> Res
|
|||
let archive_file = unzipped_files[0]
|
||||
.as_ref()
|
||||
.map_err(|e| wrap(e, "error listing update files"))?;
|
||||
fs::copy(&archive_file.path(), staging_path)
|
||||
fs::copy(archive_file.path(), staging_path)
|
||||
.map_err(|e| wrap(e, "error copying to staging file"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -140,7 +140,7 @@ fn copy_updated_cli_to_path(unzipped_content: &Path, staging_path: &Path) -> Res
|
|||
#[cfg(target_os = "windows")]
|
||||
fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> {
|
||||
let permissions = from.metadata()?.permissions();
|
||||
fs::set_permissions(&to, permissions)?;
|
||||
fs::set_permissions(to, permissions)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ use crate::util::command::{capture_command, kill_tree};
|
|||
use crate::util::errors::{
|
||||
wrap, AnyError, ExtensionInstallFailed, MissingEntrypointError, WrappedError,
|
||||
};
|
||||
use crate::util::http::{self, SimpleHttp};
|
||||
use crate::util::http::{self, BoxedHttp};
|
||||
use crate::util::io::SilentCopyProgress;
|
||||
use crate::util::machine::process_exists;
|
||||
use crate::{debug, info, log, span, spanf, trace, warning};
|
||||
|
@ -176,7 +176,7 @@ impl ServerParamsRaw {
|
|||
pub async fn resolve(
|
||||
self,
|
||||
log: &log::Logger,
|
||||
http: impl SimpleHttp + Send + Sync + 'static,
|
||||
http: BoxedHttp,
|
||||
) -> Result<ResolvedServerParams, AnyError> {
|
||||
Ok(ResolvedServerParams {
|
||||
release: self.get_or_fetch_commit_id(log, http).await?,
|
||||
|
@ -187,7 +187,7 @@ impl ServerParamsRaw {
|
|||
async fn get_or_fetch_commit_id(
|
||||
&self,
|
||||
log: &log::Logger,
|
||||
http: impl SimpleHttp + Send + Sync + 'static,
|
||||
http: BoxedHttp,
|
||||
) -> Result<Release, AnyError> {
|
||||
let target = match self.headless {
|
||||
true => TargetKind::Server,
|
||||
|
@ -287,7 +287,7 @@ async fn install_server_if_needed(
|
|||
log: &log::Logger,
|
||||
paths: &ServerPaths,
|
||||
release: &Release,
|
||||
http: impl SimpleHttp + Send + Sync + 'static,
|
||||
http: BoxedHttp,
|
||||
existing_archive_path: Option<PathBuf>,
|
||||
) -> Result<(), AnyError> {
|
||||
if paths.executable.exists() {
|
||||
|
@ -321,7 +321,7 @@ async fn download_server(
|
|||
path: &Path,
|
||||
release: &Release,
|
||||
log: &log::Logger,
|
||||
http: impl SimpleHttp + Send + Sync + 'static,
|
||||
http: BoxedHttp,
|
||||
) -> Result<PathBuf, AnyError> {
|
||||
let response = UpdateService::new(log.clone(), http)
|
||||
.get_download_stream(release)
|
||||
|
@ -403,20 +403,20 @@ async fn do_extension_install_on_running_server(
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ServerBuilder<'a, Http: SimpleHttp + Send + Sync + Clone> {
|
||||
pub struct ServerBuilder<'a> {
|
||||
logger: &'a log::Logger,
|
||||
server_params: &'a ResolvedServerParams,
|
||||
last_used: LastUsedServers<'a>,
|
||||
server_paths: ServerPaths,
|
||||
http: Http,
|
||||
http: BoxedHttp,
|
||||
}
|
||||
|
||||
impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Http> {
|
||||
impl<'a> ServerBuilder<'a> {
|
||||
pub fn new(
|
||||
logger: &'a log::Logger,
|
||||
server_params: &'a ResolvedServerParams,
|
||||
launcher_paths: &'a LauncherPaths,
|
||||
http: Http,
|
||||
http: BoxedHttp,
|
||||
) -> Self {
|
||||
Self {
|
||||
logger,
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
use crate::async_pipe::get_socket_rw_stream;
|
||||
use crate::constants::CONTROL_PORT;
|
||||
use crate::log;
|
||||
use crate::msgpack_rpc::U32PrefixedCodec;
|
||||
use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization};
|
||||
use crate::self_update::SelfUpdate;
|
||||
use crate::state::LauncherPaths;
|
||||
|
@ -12,7 +13,8 @@ use crate::tunnels::protocol::HttpRequestParams;
|
|||
use crate::tunnels::socket_signal::CloseReason;
|
||||
use crate::update_service::{Platform, UpdateService};
|
||||
use crate::util::errors::{
|
||||
wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError,
|
||||
wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError,
|
||||
NoAttachedServerError,
|
||||
};
|
||||
use crate::util::http::{
|
||||
DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp,
|
||||
|
@ -24,11 +26,14 @@ use crate::util::sync::{new_barrier, Barrier};
|
|||
use opentelemetry::trace::SpanKind;
|
||||
use opentelemetry::KeyValue;
|
||||
use std::collections::HashMap;
|
||||
use std::process::Stdio;
|
||||
use tokio::pin;
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
use super::code_server::{
|
||||
|
@ -40,8 +45,8 @@ use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
|
|||
use super::protocol::{
|
||||
CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, ForwardParams,
|
||||
ForwardResult, GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog,
|
||||
ServerMessageParams, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult,
|
||||
VersionParams,
|
||||
ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams,
|
||||
UpdateResult, VersionParams,
|
||||
};
|
||||
use super::server_bridge::ServerBridge;
|
||||
use super::server_multiplexer::ServerMultiplexer;
|
||||
|
@ -73,7 +78,7 @@ struct HandlerContext {
|
|||
/// install platform for the VS Code server
|
||||
platform: Platform,
|
||||
/// http client to make download/update requests
|
||||
http: FallbackSimpleHttp,
|
||||
http: Arc<FallbackSimpleHttp>,
|
||||
/// requests being served by the client
|
||||
http_requests: HttpRequestsMap,
|
||||
}
|
||||
|
@ -196,7 +201,7 @@ pub async fn serve(
|
|||
],
|
||||
);
|
||||
cx.span().end();
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -247,7 +252,10 @@ async fn process_socket(
|
|||
server_bridges: server_bridges.clone(),
|
||||
port_forwarding,
|
||||
platform,
|
||||
http: FallbackSimpleHttp::new(ReqwestSimpleHttp::new(), http_delegated),
|
||||
http: Arc::new(FallbackSimpleHttp::new(
|
||||
ReqwestSimpleHttp::new(),
|
||||
http_delegated,
|
||||
)),
|
||||
http_requests: http_requests.clone(),
|
||||
});
|
||||
|
||||
|
@ -276,6 +284,9 @@ async fn process_socket(
|
|||
rpc.register_async("unforward", |p: UnforwardParams, c| async move {
|
||||
handle_unforward(&c.log, &c.port_forwarding, p).await
|
||||
});
|
||||
rpc.register_duplex("spawn", |stream, p: SpawnParams, c| async move {
|
||||
handle_spawn(&c.log, stream, p).await
|
||||
});
|
||||
rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| {
|
||||
if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) {
|
||||
req.initial_response(p.status_code, p.headers);
|
||||
|
@ -393,20 +404,20 @@ async fn handle_socket_read(
|
|||
rx_counter: Arc<AtomicUsize>,
|
||||
rpc: &RpcDispatcher<MsgPackSerializer, HandlerContext>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
let mut socket_reader = BufReader::new(readhalf);
|
||||
let mut decode_buf = vec![];
|
||||
let mut readhalf = BufReader::new(readhalf);
|
||||
let mut decoder = U32PrefixedCodec {};
|
||||
let mut decoder_buf = bytes::BytesMut::new();
|
||||
|
||||
loop {
|
||||
let read = read_next(
|
||||
&mut socket_reader,
|
||||
&rx_counter,
|
||||
&mut closer,
|
||||
&mut decode_buf,
|
||||
)
|
||||
.await;
|
||||
let read_len = tokio::select! {
|
||||
r = readhalf.read_buf(&mut decoder_buf) => r,
|
||||
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
|
||||
}?;
|
||||
|
||||
match read {
|
||||
Ok(len) => match rpc.dispatch(&decode_buf[..len]) {
|
||||
rx_counter.fetch_add(read_len, Ordering::Relaxed);
|
||||
|
||||
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
|
||||
match rpc.dispatch(&frame) {
|
||||
MaybeSync::Sync(Some(v)) => {
|
||||
if socket_tx.send(SocketSignal::Send(v)).await.is_err() {
|
||||
return Ok(());
|
||||
|
@ -421,34 +432,22 @@ async fn handle_socket_read(
|
|||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
|
||||
Err(e) => return Err(e),
|
||||
MaybeSync::Stream((stream, fut)) => {
|
||||
if let Some(stream) = stream {
|
||||
rpc.register_stream(socket_tx.clone(), stream).await;
|
||||
}
|
||||
let socket_tx = socket_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Some(v) = fut.await {
|
||||
socket_tx.send(SocketSignal::Send(v)).await.ok();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads and handles the next data packet. Returns the next packet to dispatch,
|
||||
/// or an error (including EOF).
|
||||
async fn read_next(
|
||||
socket_reader: &mut BufReader<impl AsyncRead + Unpin>,
|
||||
rx_counter: &Arc<AtomicUsize>,
|
||||
closer: &mut Barrier<()>,
|
||||
decode_buf: &mut Vec<u8>,
|
||||
) -> Result<usize, std::io::Error> {
|
||||
let msg_length = tokio::select! {
|
||||
u = socket_reader.read_u32() => u? as usize,
|
||||
_ = closer.wait() => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
|
||||
};
|
||||
decode_buf.resize(msg_length, 0);
|
||||
rx_counter.fetch_add(msg_length + 4 /* u32 */, Ordering::Relaxed);
|
||||
|
||||
tokio::select! {
|
||||
r = socket_reader.read_exact(decode_buf) => r,
|
||||
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ServerOutputSink {
|
||||
tx: mpsc::Sender<SocketSignal>,
|
||||
|
@ -487,7 +486,9 @@ async fn handle_serve(
|
|||
};
|
||||
|
||||
let resolved = if params.use_local_download {
|
||||
params_raw.resolve(&c.log, c.http.delegated()).await
|
||||
params_raw
|
||||
.resolve(&c.log, Arc::new(c.http.delegated()))
|
||||
.await
|
||||
} else {
|
||||
params_raw.resolve(&c.log, c.http.clone()).await
|
||||
}?;
|
||||
|
@ -518,7 +519,7 @@ async fn handle_serve(
|
|||
&install_log,
|
||||
&resolved,
|
||||
&c.launcher_paths,
|
||||
c.http.delegated(),
|
||||
Arc::new(c.http.delegated()),
|
||||
);
|
||||
do_setup!(sb)
|
||||
} else {
|
||||
|
@ -606,7 +607,7 @@ fn handle_prune(paths: &LauncherPaths) -> Result<Vec<String>, AnyError> {
|
|||
}
|
||||
|
||||
async fn handle_update(
|
||||
http: &FallbackSimpleHttp,
|
||||
http: &Arc<FallbackSimpleHttp>,
|
||||
log: &log::Logger,
|
||||
did_update: &AtomicBool,
|
||||
params: &UpdateParams,
|
||||
|
@ -732,3 +733,83 @@ async fn handle_call_server_http(
|
|||
.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_spawn(
|
||||
log: &log::Logger,
|
||||
mut duplex: DuplexStream,
|
||||
params: SpawnParams,
|
||||
) -> Result<SpawnResult, AnyError> {
|
||||
debug!(
|
||||
log,
|
||||
"requested to spawn {} with args {:?}", params.command, params.args
|
||||
);
|
||||
|
||||
let mut p = tokio::process::Command::new(¶ms.command)
|
||||
.args(¶ms.args)
|
||||
.envs(¶ms.env)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(CodeError::ProcessSpawnFailed)?;
|
||||
|
||||
let mut stdout = p.stdout.take().unwrap();
|
||||
let mut stderr = p.stderr.take().unwrap();
|
||||
let mut stdin = p.stdin.take().unwrap();
|
||||
let (tx, mut rx) = mpsc::channel(4);
|
||||
|
||||
macro_rules! copy_stream_to {
|
||||
($target:expr) => {
|
||||
let tx = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0; 4096];
|
||||
loop {
|
||||
let n = match $target.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => return,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if !tx.send(buf[..n].to_vec()).await.is_ok() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
copy_stream_to!(stdout);
|
||||
copy_stream_to!(stderr);
|
||||
|
||||
let mut stdin_buf = vec![0; 4096];
|
||||
let closed = p.wait();
|
||||
pin!(closed);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(n) = duplex.read(&mut stdin_buf) => {
|
||||
let _ = stdin.write_all(&stdin_buf[..n]).await;
|
||||
},
|
||||
Some(m) = rx.recv() => {
|
||||
let _ = duplex.write_all(&m).await;
|
||||
},
|
||||
r = &mut closed => {
|
||||
let r = match r {
|
||||
Ok(e) => SpawnResult {
|
||||
message: e.to_string(),
|
||||
exit_code: e.code().unwrap_or(-1),
|
||||
},
|
||||
Err(e) => SpawnResult {
|
||||
message: e.to_string(),
|
||||
exit_code: -1,
|
||||
},
|
||||
};
|
||||
|
||||
debug!(
|
||||
log,
|
||||
"spawned command {} exited with code {}", params.command, r.exit_code
|
||||
);
|
||||
|
||||
return Ok(r)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ impl ServerPaths {
|
|||
|
||||
// VS Code Server pid
|
||||
pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> {
|
||||
write(&self.pidfile, &format!("{}", pid)).map_err(|e| {
|
||||
write(&self.pidfile, format!("{}", pid)).map_err(|e| {
|
||||
wrap(
|
||||
e,
|
||||
format!("error writing process id into {}", self.pidfile.display()),
|
||||
|
|
|
@ -158,6 +158,20 @@ impl Default for VersionParams {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SpawnParams {
|
||||
pub command: String,
|
||||
pub args: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SpawnResult {
|
||||
pub message: String,
|
||||
pub exit_code: i32,
|
||||
}
|
||||
|
||||
pub mod singleton {
|
||||
use crate::log;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
|
@ -59,7 +59,7 @@ impl CliServiceManager for WindowsService {
|
|||
};
|
||||
|
||||
for arg in args {
|
||||
add_arg(*arg);
|
||||
add_arg(arg);
|
||||
}
|
||||
|
||||
add_arg("--log-to-file");
|
||||
|
|
|
@ -22,6 +22,12 @@ pub enum SocketSignal {
|
|||
CloseWith(CloseReason),
|
||||
}
|
||||
|
||||
impl From<Vec<u8>> for SocketSignal {
|
||||
fn from(v: Vec<u8>) -> Self {
|
||||
SocketSignal::Send(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl SocketSignal {
|
||||
pub fn from_message<T>(msg: &T) -> Self
|
||||
where
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{
|
||||
|
@ -139,7 +141,12 @@ async fn handle_serve(
|
|||
},
|
||||
};
|
||||
|
||||
let sb = ServerBuilder::new(&c.log, &resolved, &c.launcher_paths, c.http.clone());
|
||||
let sb = ServerBuilder::new(
|
||||
&c.log,
|
||||
&resolved,
|
||||
&c.launcher_paths,
|
||||
Arc::new(c.http.clone()),
|
||||
);
|
||||
let code_server = match sb.get_running().await? {
|
||||
Some(AnyCodeServer::Socket(s)) => s,
|
||||
Some(_) => return Err(MismatchedLaunchModeError().into()),
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use std::path::Path;
|
||||
use std::{fmt, path::Path};
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
|
@ -11,19 +11,20 @@ use crate::{
|
|||
constants::VSCODE_CLI_UPDATE_ENDPOINT,
|
||||
debug, log, options, spanf,
|
||||
util::{
|
||||
errors::{AnyError, UnsupportedPlatformError, UpdatesNotConfigured, WrappedError},
|
||||
http::{SimpleHttp, SimpleResponse},
|
||||
errors::{AnyError, CodeError, UpdatesNotConfigured, WrappedError},
|
||||
http::{BoxedHttp, SimpleResponse},
|
||||
io::ReportCopyProgress,
|
||||
},
|
||||
};
|
||||
|
||||
/// Implementation of the VS Code Update service for use in the CLI.
|
||||
pub struct UpdateService {
|
||||
client: Box<dyn SimpleHttp + Send + Sync + 'static>,
|
||||
client: BoxedHttp,
|
||||
log: log::Logger,
|
||||
}
|
||||
|
||||
/// Describes a specific release, can be created manually or returned from the update service.
|
||||
#[derive(Clone, Eq, PartialEq)]
|
||||
pub struct Release {
|
||||
pub name: String,
|
||||
pub platform: Platform,
|
||||
|
@ -53,11 +54,8 @@ fn quality_download_segment(quality: options::Quality) -> &'static str {
|
|||
}
|
||||
|
||||
impl UpdateService {
|
||||
pub fn new(log: log::Logger, http: impl SimpleHttp + Send + Sync + 'static) -> Self {
|
||||
UpdateService {
|
||||
client: Box::new(http),
|
||||
log,
|
||||
}
|
||||
pub fn new(log: log::Logger, http: BoxedHttp) -> Self {
|
||||
UpdateService { client: http, log }
|
||||
}
|
||||
|
||||
pub async fn get_release_by_semver_version(
|
||||
|
@ -71,7 +69,7 @@ impl UpdateService {
|
|||
VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?;
|
||||
let download_segment = target
|
||||
.download_segment(platform)
|
||||
.ok_or(UnsupportedPlatformError())?;
|
||||
.ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?;
|
||||
let download_url = format!(
|
||||
"{}/api/versions/{}/{}/{}",
|
||||
update_endpoint,
|
||||
|
@ -113,7 +111,7 @@ impl UpdateService {
|
|||
VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?;
|
||||
let download_segment = target
|
||||
.download_segment(platform)
|
||||
.ok_or(UnsupportedPlatformError())?;
|
||||
.ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?;
|
||||
let download_url = format!(
|
||||
"{}/api/latest/{}/{}",
|
||||
update_endpoint,
|
||||
|
@ -150,7 +148,7 @@ impl UpdateService {
|
|||
let download_segment = release
|
||||
.target
|
||||
.download_segment(release.platform)
|
||||
.ok_or(UnsupportedPlatformError())?;
|
||||
.ok_or_else(|| CodeError::UnsupportedPlatform(release.platform.to_string()))?;
|
||||
|
||||
let download_url = format!(
|
||||
"{}/commit:{}/{}/{}",
|
||||
|
@ -208,7 +206,7 @@ impl TargetKind {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub enum Platform {
|
||||
LinuxAlpineX64,
|
||||
LinuxAlpineARM64,
|
||||
|
@ -306,3 +304,20 @@ impl Platform {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Platform {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
Platform::LinuxAlpineARM64 => "LinuxAlpineARM64",
|
||||
Platform::LinuxAlpineX64 => "LinuxAlpineX64",
|
||||
Platform::LinuxX64 => "LinuxX64",
|
||||
Platform::LinuxARM64 => "LinuxARM64",
|
||||
Platform::LinuxARM32 => "LinuxARM32",
|
||||
Platform::DarwinX64 => "DarwinX64",
|
||||
Platform::DarwinARM64 => "DarwinARM64",
|
||||
Platform::WindowsX64 => "WindowsX64",
|
||||
Platform::WindowsX86 => "WindowsX86",
|
||||
Platform::WindowsARM64 => "WindowsARM64",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,29 +2,47 @@
|
|||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
use super::errors::{wrap, AnyError, CommandFailed, WrappedError};
|
||||
use std::{borrow::Cow, ffi::OsStr, process::Stdio};
|
||||
use super::errors::CodeError;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
ffi::OsStr,
|
||||
process::{Output, Stdio},
|
||||
};
|
||||
use tokio::process::Command;
|
||||
|
||||
pub async fn capture_command_and_check_status(
|
||||
command_str: impl AsRef<OsStr>,
|
||||
args: &[impl AsRef<OsStr>],
|
||||
) -> Result<std::process::Output, AnyError> {
|
||||
) -> Result<std::process::Output, CodeError> {
|
||||
let output = capture_command(&command_str, args).await?;
|
||||
|
||||
check_output_status(output, || {
|
||||
format!(
|
||||
"{} {}",
|
||||
command_str.as_ref().to_string_lossy(),
|
||||
args.iter()
|
||||
.map(|a| a.as_ref().to_string_lossy())
|
||||
.collect::<Vec<Cow<'_, str>>>()
|
||||
.join(" ")
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn check_output_status(
|
||||
output: Output,
|
||||
cmd_str: impl FnOnce() -> String,
|
||||
) -> Result<std::process::Output, CodeError> {
|
||||
if !output.status.success() {
|
||||
return Err(CommandFailed {
|
||||
command: format!(
|
||||
"{} {}",
|
||||
command_str.as_ref().to_string_lossy(),
|
||||
args.iter()
|
||||
.map(|a| a.as_ref().to_string_lossy())
|
||||
.collect::<Vec<Cow<'_, str>>>()
|
||||
.join(" ")
|
||||
),
|
||||
output,
|
||||
}
|
||||
.into());
|
||||
return Err(CodeError::CommandFailed {
|
||||
command: cmd_str(),
|
||||
code: output.status.code().unwrap_or(-1),
|
||||
output: String::from_utf8_lossy(if output.stderr.is_empty() {
|
||||
&output.stdout
|
||||
} else {
|
||||
&output.stderr
|
||||
})
|
||||
.into(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
|
@ -33,7 +51,7 @@ pub async fn capture_command_and_check_status(
|
|||
pub async fn capture_command<A, I, S>(
|
||||
command_str: A,
|
||||
args: I,
|
||||
) -> Result<std::process::Output, WrappedError>
|
||||
) -> Result<std::process::Output, CodeError>
|
||||
where
|
||||
A: AsRef<OsStr>,
|
||||
I: IntoIterator<Item = S>,
|
||||
|
@ -45,27 +63,23 @@ where
|
|||
.stdout(Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
wrap(
|
||||
e,
|
||||
format!(
|
||||
"failed to execute command '{}'",
|
||||
command_str.as_ref().to_string_lossy()
|
||||
),
|
||||
)
|
||||
.map_err(|e| CodeError::CommandFailed {
|
||||
command: command_str.as_ref().to_string_lossy().to_string(),
|
||||
code: -1,
|
||||
output: e.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Kills and processes and all of its children.
|
||||
#[cfg(target_os = "windows")]
|
||||
pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
|
||||
pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> {
|
||||
capture_command("taskkill", &["/t", "/pid", &process_id.to_string()]).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Kills and processes and all of its children.
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
|
||||
pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> {
|
||||
use futures::future::join_all;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
|
||||
|
@ -82,7 +96,11 @@ pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
|
|||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| wrap(e, "error enumerating process tree"))?;
|
||||
.map_err(|e| CodeError::CommandFailed {
|
||||
command: format!("pgrep -P {}", parent_id),
|
||||
code: -1,
|
||||
output: e.to_string(),
|
||||
})?;
|
||||
|
||||
let mut kill_futures = vec![tokio::spawn(
|
||||
async move { kill_single_pid(parent_id).await },
|
||||
|
|
|
@ -258,18 +258,6 @@ impl std::fmt::Display for RefreshTokenNotAvailableError {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UnsupportedPlatformError();
|
||||
|
||||
impl std::fmt::Display for UnsupportedPlatformError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"This operation is not supported on your current platform"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NoInstallInUserProvidedPath(pub String);
|
||||
|
||||
|
@ -419,28 +407,6 @@ impl std::fmt::Display for OAuthError {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CommandFailed {
|
||||
pub output: std::process::Output,
|
||||
pub command: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CommandFailed {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Failed to run command \"{}\" (code {}): {}",
|
||||
self.command,
|
||||
self.output.status,
|
||||
String::from_utf8_lossy(if self.output.stderr.is_empty() {
|
||||
&self.output.stdout
|
||||
} else {
|
||||
&self.output.stderr
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Makes an "AnyError" enum that contains any of the given errors, in the form
|
||||
// `enum AnyError { FooError(FooError) }` (when given `makeAnyError!(FooError)`).
|
||||
// Useful to easily deal with application error types without making tons of "From"
|
||||
|
@ -500,6 +466,20 @@ pub enum CodeError {
|
|||
#[cfg(windows)]
|
||||
#[error("could not get windows app lock: {0:?}")]
|
||||
AppLockFailed(std::io::Error),
|
||||
#[error("failed to run command \"{command}\" (code {code}): {output}")]
|
||||
CommandFailed {
|
||||
command: String,
|
||||
code: i32,
|
||||
output: String,
|
||||
},
|
||||
|
||||
#[error("platform not currently supported: {0}")]
|
||||
UnsupportedPlatform(String),
|
||||
#[error("This machine not meet {name}'s prerequisites, expected either...: {bullets}")]
|
||||
PrerequisitesFailed { name: &'static str, bullets: String },
|
||||
|
||||
#[error("failed to spawn process: {0:?}")]
|
||||
ProcessSpawnFailed(std::io::Error)
|
||||
}
|
||||
|
||||
makeAnyError!(
|
||||
|
@ -518,7 +498,6 @@ makeAnyError!(
|
|||
ExtensionInstallFailed,
|
||||
MismatchedLaunchModeError,
|
||||
NoAttachedServerError,
|
||||
UnsupportedPlatformError,
|
||||
RefreshTokenNotAvailableError,
|
||||
NoInstallInUserProvidedPath,
|
||||
UserCancelledInstallation,
|
||||
|
@ -530,7 +509,6 @@ makeAnyError!(
|
|||
UpdatesNotConfigured,
|
||||
CorruptDownload,
|
||||
MissingHomeDirectory,
|
||||
CommandFailed,
|
||||
OAuthError,
|
||||
InvalidRpcDataError,
|
||||
CodeError
|
||||
|
|
|
@ -16,7 +16,7 @@ use hyper::{
|
|||
HeaderMap, StatusCode,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{io, pin::Pin, str::FromStr, task::Poll};
|
||||
use std::{io, pin::Pin, str::FromStr, sync::Arc, task::Poll};
|
||||
use tokio::{
|
||||
fs,
|
||||
io::{AsyncRead, AsyncReadExt},
|
||||
|
@ -116,6 +116,8 @@ pub trait SimpleHttp {
|
|||
) -> Result<SimpleResponse, AnyError>;
|
||||
}
|
||||
|
||||
pub type BoxedHttp = Arc<dyn SimpleHttp + Send + Sync + 'static>;
|
||||
|
||||
// Implementation of SimpleHttp that uses a reqwest client.
|
||||
#[derive(Clone)]
|
||||
pub struct ReqwestSimpleHttp {
|
||||
|
@ -324,7 +326,6 @@ impl AsyncRead for DelegatedReader {
|
|||
|
||||
/// Simple http implementation that falls back to delegated http if
|
||||
/// making a direct reqwest fails.
|
||||
#[derive(Clone)]
|
||||
pub struct FallbackSimpleHttp {
|
||||
native: ReqwestSimpleHttp,
|
||||
delegated: DelegatedSimpleHttp,
|
||||
|
|
|
@ -7,13 +7,12 @@ use std::cmp::Ordering;
|
|||
use super::command::capture_command;
|
||||
use crate::constants::QUALITYLESS_SERVER_NAME;
|
||||
use crate::update_service::Platform;
|
||||
use crate::util::errors::SetupError;
|
||||
use lazy_static::lazy_static;
|
||||
use regex::bytes::Regex as BinRegex;
|
||||
use regex::Regex;
|
||||
use tokio::fs;
|
||||
|
||||
use super::errors::AnyError;
|
||||
use super::errors::CodeError;
|
||||
|
||||
lazy_static! {
|
||||
static ref LDCONFIG_STDC_RE: Regex = Regex::new(r"libstdc\+\+.* => (.+)").unwrap();
|
||||
|
@ -41,19 +40,18 @@ impl PreReqChecker {
|
|||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
pub async fn verify(&self) -> Result<Platform, AnyError> {
|
||||
use crate::constants::QUALITYLESS_PRODUCT_NAME;
|
||||
pub async fn verify(&self) -> Result<Platform, CodeError> {
|
||||
Platform::env_default().ok_or_else(|| {
|
||||
SetupError(format!(
|
||||
"{} is not supported on this platform",
|
||||
QUALITYLESS_PRODUCT_NAME
|
||||
CodeError::UnsupportedPlatform(format!(
|
||||
"{} {}",
|
||||
std::env::consts::OS,
|
||||
std::env::consts::ARCH
|
||||
))
|
||||
.into()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub async fn verify(&self) -> Result<Platform, AnyError> {
|
||||
pub async fn verify(&self) -> Result<Platform, CodeError> {
|
||||
let (is_nixos, gnu_a, gnu_b, or_musl) = tokio::join!(
|
||||
check_is_nixos(),
|
||||
check_glibc_version(),
|
||||
|
@ -96,10 +94,10 @@ impl PreReqChecker {
|
|||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
Err(AnyError::from(SetupError(format!(
|
||||
"This machine not meet {}'s prerequisites, expected either...\n{}",
|
||||
QUALITYLESS_SERVER_NAME, bullets,
|
||||
))))
|
||||
Err(CodeError::PrerequisitesFailed {
|
||||
bullets,
|
||||
name: QUALITYLESS_SERVER_NAME,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,9 +4,11 @@
|
|||
*--------------------------------------------------------------------------------------------*/
|
||||
use async_trait::async_trait;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
use tokio::sync::{
|
||||
broadcast, mpsc,
|
||||
watch::{self, error::RecvError},
|
||||
use tokio::{
|
||||
sync::{
|
||||
broadcast, mpsc,
|
||||
watch::{self, error::RecvError},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
Загрузка…
Ссылка в новой задаче