diff options
author | David Sherret <dsherret@users.noreply.github.com> | 2021-09-20 22:15:44 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-20 22:15:44 -0400 |
commit | 0f23d926019d333572366a4de4f291b848fa6ded (patch) | |
tree | 7767a39d72cabca060740198dc7a3c23f97bf363 | |
parent | 60b68e63f1045a36496257912ef4f32e716a2440 (diff) |
chore(tests): windows pty tests (#12091)
-rw-r--r-- | Cargo.lock | 42 | ||||
-rw-r--r-- | cli/Cargo.toml | 1 | ||||
-rw-r--r-- | cli/tests/integration/repl_tests.rs | 164 | ||||
-rw-r--r-- | cli/tests/integration/run_tests.rs | 5 | ||||
-rw-r--r-- | test_util/Cargo.toml | 4 | ||||
-rw-r--r-- | test_util/src/lib.rs | 85 | ||||
-rw-r--r-- | test_util/src/pty.rs | 442 |
7 files changed, 566 insertions, 177 deletions
diff --git a/Cargo.lock b/Cargo.lock index 953bd2e8d..174da8c87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -597,7 +597,6 @@ dependencies = [ "dprint-plugin-typescript", "encoding_rs", "env_logger", - "exec", "fancy-regex", "flaky_test", "fwdansi", @@ -1184,27 +1183,6 @@ dependencies = [ ] [[package]] -name = "errno" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa68f2fb9cae9d37c9b2b3584aba698a2e97f72d7aef7b9f7aa71d8b54ce46fe" -dependencies = [ - "errno-dragonfly", - "libc", - "winapi 0.3.9", -] - -[[package]] -name = "errno-dragonfly" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14ca354e36190500e1e1fb267c647932382b54053c50b14970856c0b00a35067" -dependencies = [ - "gcc", - "libc", -] - -[[package]] name = "error-code" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1215,16 +1193,6 @@ dependencies = [ ] [[package]] -name = "exec" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "886b70328cba8871bfc025858e1de4be16b1d5088f2ba50b57816f4210672615" -dependencies = [ - "errno 0.2.7", - "libc", -] - -[[package]] name = "fallible-iterator" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1491,12 +1459,6 @@ dependencies = [ ] [[package]] -name = "gcc" -version = "0.3.55" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f5f3913fa0bfe7ee1fd8248b6b9f42a5af4b9d65ec2dd2c3c26132b950ecfc2" - -[[package]] name = "generic-array" version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2636,7 +2598,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f50f3d255966981eb4e4c5df3e983e6f7d163221f547406d83b6a460ff5c5ee8" dependencies = [ - "errno 0.1.8", + "errno", "libc", ] @@ -3881,6 +3843,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-stream", + "atty", "base64 0.13.0", "futures", "hyper", @@ -3894,6 +3857,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-tungstenite", + "winapi 0.3.9", ] [[package]] diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 977d01c53..40e95c8f5 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -100,7 +100,6 @@ trust-dns-client = "0.20.3" trust-dns-server = "0.20.3" [target.'cfg(unix)'.dev-dependencies] -exec = "0.3.1" # Used in test_raw_tty nix = "0.22.1" [package.metadata.winres] diff --git a/cli/tests/integration/repl_tests.rs b/cli/tests/integration/repl_tests.rs index 79c2cf0f5..a8f354598 100644 --- a/cli/tests/integration/repl_tests.rs +++ b/cli/tests/integration/repl_tests.rs @@ -2,27 +2,23 @@ use test_util as util; -#[cfg(unix)] #[test] fn pty_multiline() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b"(\n1 + 2\n)\n").unwrap(); - master.write_all(b"{\nfoo: \"foo\"\n}\n").unwrap(); - master.write_all(b"`\nfoo\n`\n").unwrap(); - master.write_all(b"`\n\\`\n`\n").unwrap(); - master.write_all(b"'{'\n").unwrap(); - master.write_all(b"'('\n").unwrap(); - master.write_all(b"'['\n").unwrap(); - master.write_all(b"/{/\n").unwrap(); - master.write_all(b"/\\(/\n").unwrap(); - master.write_all(b"/\\[/\n").unwrap(); - master.write_all(b"console.log(\"{test1} abc {test2} def {{test3}}\".match(/{([^{].+?)}/));\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); - + util::with_pty(&["repl"], |mut console| { + console.write_line("(\n1 + 2\n)"); + console.write_line("{\nfoo: \"foo\"\n}"); + console.write_line("`\nfoo\n`"); + console.write_line("`\n\\`\n`"); + console.write_line("'{'"); + console.write_line("'('"); + console.write_line("'['"); + console.write_line("/{/"); + console.write_line("/\\(/"); + console.write_line("/\\[/"); + console.write_line("console.log(\"{test1} abc {test2} def {{test3}}\".match(/{([^{].+?)}/));"); + console.write_line("close();"); + + let output = console.read_all_output(); assert!(output.contains('3')); assert!(output.contains("{ foo: \"foo\" }")); assert!(output.contains("\"\\nfoo\\n\"")); @@ -37,109 +33,85 @@ fn pty_multiline() { }); } -#[cfg(unix)] #[test] fn pty_unpaired_braces() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b")\n").unwrap(); - master.write_all(b"]\n").unwrap(); - master.write_all(b"}\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); + util::with_pty(&["repl"], |mut console| { + console.write_line(")"); + console.write_line("]"); + console.write_line("}"); + console.write_line("close();"); + let output = console.read_all_output(); assert!(output.contains("Unexpected token `)`")); assert!(output.contains("Unexpected token `]`")); assert!(output.contains("Unexpected token `}`")); }); } -#[cfg(unix)] #[test] fn pty_bad_input() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b"'\\u{1f3b5}'[0]\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); + util::with_pty(&["repl"], |mut console| { + console.write_line("'\\u{1f3b5}'[0]"); + console.write_line("close();"); + let output = console.read_all_output(); assert!(output.contains("Unterminated string literal")); }); } -#[cfg(unix)] #[test] fn pty_syntax_error_input() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b"('\\u')\n").unwrap(); - master.write_all(b"('\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); + util::with_pty(&["repl"], |mut console| { + console.write_line("('\\u')"); + console.write_line("('"); + console.write_line("close();"); + let output = console.read_all_output(); assert!(output.contains("Unterminated string constant")); assert!(output.contains("Unexpected eof")); }); } -#[cfg(unix)] #[test] fn pty_complete_symbol() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b"Symbol.it\t\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); + util::with_pty(&["repl"], |mut console| { + console.write_line("Symbol.it\t"); + console.write_line("close();"); + let output = console.read_all_output(); assert!(output.contains("Symbol(Symbol.iterator)")); }); } -#[cfg(unix)] #[test] fn pty_complete_declarations() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b"class MyClass {}\n").unwrap(); - master.write_all(b"My\t\n").unwrap(); - master.write_all(b"let myVar;\n").unwrap(); - master.write_all(b"myV\t\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); - + util::with_pty(&["repl"], |mut console| { + console.write_line("class MyClass {}"); + console.write_line("My\t"); + console.write_line("let myVar;"); + console.write_line("myV\t"); + console.write_line("close();"); + + let output = console.read_all_output(); assert!(output.contains("> MyClass")); assert!(output.contains("> myVar")); }); } -#[cfg(unix)] #[test] fn pty_complete_primitives() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b"let func = function test(){}\n").unwrap(); - master.write_all(b"func.appl\t\n").unwrap(); - master.write_all(b"let str = ''\n").unwrap(); - master.write_all(b"str.leng\t\n").unwrap(); - master.write_all(b"false.valueO\t\n").unwrap(); - master.write_all(b"5n.valueO\t\n").unwrap(); - master.write_all(b"let num = 5\n").unwrap(); - master.write_all(b"num.toStrin\t\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); - + util::with_pty(&["repl"], |mut console| { + console.write_line("let func = function test(){}"); + console.write_line("func.appl\t"); + console.write_line("let str = ''"); + console.write_line("str.leng\t"); + console.write_line("false.valueO\t"); + console.write_line("5n.valueO\t"); + console.write_line("let num = 5"); + console.write_line("num.toStrin\t"); + console.write_line("close();"); + + let output = console.read_all_output(); assert!(output.contains("> func.apply")); assert!(output.contains("> str.length")); assert!(output.contains("> 5n.valueOf")); @@ -148,17 +120,13 @@ fn pty_complete_primitives() { }); } -#[cfg(unix)] #[test] fn pty_ignore_symbols() { - use std::io::{Read, Write}; - run_pty_test(|master| { - master.write_all(b"Array.Symbol\t\n").unwrap(); - master.write_all(b"close();\n").unwrap(); - - let mut output = String::new(); - master.read_to_string(&mut output).unwrap(); + util::with_pty(&["repl"], |mut console| { + console.write_line("Array.Symbol\t"); + console.write_line("close();"); + let output = console.read_all_output(); assert!(output.contains("undefined")); assert!( !output.contains("Uncaught TypeError: Array.Symbol is not a function") @@ -166,22 +134,6 @@ fn pty_ignore_symbols() { }); } -#[cfg(unix)] -fn run_pty_test(mut run: impl FnMut(&mut util::pty::fork::Master)) { - use util::pty::fork::*; - let deno_exe = util::deno_exe_path(); - let fork = Fork::from_ptmx().unwrap(); - if let Ok(mut master) = fork.is_parent() { - run(&mut master); - fork.wait().unwrap(); - } else { - std::env::set_var("NO_COLOR", "1"); - let err = exec::Command::new(deno_exe).arg("repl").exec(); - println!("err {}", err); - unreachable!() - } -} - #[test] fn console_log() { let (out, err) = util::run_and_collect_output( diff --git a/cli/tests/integration/run_tests.rs b/cli/tests/integration/run_tests.rs index 04ba10b7b..df92ad422 100644 --- a/cli/tests/integration/run_tests.rs +++ b/cli/tests/integration/run_tests.rs @@ -335,7 +335,6 @@ itest!(_089_run_allow_list { output: "089_run_allow_list.ts.out", }); -#[cfg(unix)] #[test] fn _090_run_permissions_request() { let args = "run --quiet 090_run_permissions_request.ts"; @@ -1726,7 +1725,6 @@ mod permissions { assert!(!err.contains(util::PERMISSION_DENIED_PATTERN)); } - #[cfg(unix)] #[test] fn _061_permissions_request() { let args = "run --quiet 061_permissions_request.ts"; @@ -1742,7 +1740,6 @@ mod permissions { ]); } - #[cfg(unix)] #[test] fn _062_permissions_request_global() { let args = "run --quiet 062_permissions_request_global.ts"; @@ -1766,7 +1763,6 @@ mod permissions { output: "064_permissions_revoke_global.ts.out", }); - #[cfg(unix)] #[test] fn _066_prompt() { let args = "run --quiet --unstable 066_prompt.ts"; @@ -1861,7 +1857,6 @@ itest!(byte_order_mark { output: "byte_order_mark.out", }); -#[cfg(unix)] #[test] fn issue9750() { use util::PtyData::*; diff --git a/test_util/Cargo.toml b/test_util/Cargo.toml index 497f2294a..92523ac81 100644 --- a/test_util/Cargo.toml +++ b/test_util/Cargo.toml @@ -14,6 +14,7 @@ path = "src/test_server.rs" [dependencies] anyhow = "1.0.43" async-stream = "0.3.2" +atty = "0.2.14" base64 = "0.13.0" futures = "0.3.16" hyper = { version = "0.14.12", features = ["server", "http1", "runtime"] } @@ -29,3 +30,6 @@ tokio-tungstenite = "0.14.0" [target.'cfg(unix)'.dependencies] pty = "0.2.2" + +[target.'cfg(windows)'.dependencies] +winapi = { version = "0.3.9", features = ["consoleapi", "handleapi", "namedpipeapi", "winbase", "winerror"] } diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs index 5eaedbaa0..8bfe5caa0 100644 --- a/test_util/src/lib.rs +++ b/test_util/src/lib.rs @@ -44,10 +44,8 @@ use tokio_rustls::rustls::{self, Session}; use tokio_rustls::TlsAcceptor; use tokio_tungstenite::accept_async; -#[cfg(unix)] -pub use pty; - pub mod lsp; +pub mod pty; const PORT: u16 = 4545; const TEST_AUTH_TOKEN: &str = "abcdef123456789"; @@ -1589,62 +1587,97 @@ pub enum PtyData { Output(&'static str), } -#[cfg(unix)] pub fn test_pty2(args: &str, data: Vec<PtyData>) { - use pty::fork::Fork; use std::io::BufRead; - let tests_path = testdata_path(); - let fork = Fork::from_ptmx().unwrap(); - if let Ok(master) = fork.is_parent() { - let mut buf_reader = std::io::BufReader::new(master); - for d in data { + with_pty(&args.split_whitespace().collect::<Vec<_>>(), |console| { + let mut buf_reader = std::io::BufReader::new(console); + for d in data.iter() { match d { PtyData::Input(s) => { println!("INPUT {}", s.escape_debug()); - buf_reader.get_mut().write_all(s.as_bytes()).unwrap(); + buf_reader.get_mut().write_text(s); // Because of tty echo, we should be able to read the same string back. assert!(s.ends_with('\n')); let mut echo = String::new(); buf_reader.read_line(&mut echo).unwrap(); println!("ECHO: {}", echo.escape_debug()); - assert!(echo.starts_with(&s.trim())); + + // Windows may also echo the previous line, so only check the end + assert!(normalize_text(&echo).ends_with(&normalize_text(s))); } PtyData::Output(s) => { let mut line = String::new(); if s.ends_with('\n') { buf_reader.read_line(&mut line).unwrap(); } else { - while s != line { + // assumes the buffer won't have overlapping virtual terminal sequences + while normalize_text(&line).len() < normalize_text(s).len() { let mut buf = [0; 64 * 1024]; - let _n = buf_reader.read(&mut buf).unwrap(); + let bytes_read = buf_reader.read(&mut buf).unwrap(); + assert!(bytes_read > 0); let buf_str = std::str::from_utf8(&buf) .unwrap() .trim_end_matches(char::from(0)); line += buf_str; - assert!(s.starts_with(&line)); } } println!("OUTPUT {}", line.escape_debug()); - assert_eq!(line, s); + assert_eq!(normalize_text(&line), normalize_text(s)); } } } + }); - fork.wait().unwrap(); - } else { - deno_cmd() - .current_dir(tests_path) - .env("NO_COLOR", "1") - .args(args.split_whitespace()) - .spawn() - .unwrap() - .wait() - .unwrap(); + // This normalization function is not comprehensive + // and may need to updated as new scenarios emerge. + fn normalize_text(text: &str) -> String { + lazy_static! { + static ref MOVE_CURSOR_RIGHT_ONE_RE: Regex = + Regex::new(r"\x1b\[1C").unwrap(); + static ref FOUND_SEQUENCES_RE: Regex = + Regex::new(r"(\x1b\]0;[^\x07]*\x07)*(\x08)*(\x1b\[\d+X)*").unwrap(); + static ref CARRIAGE_RETURN_RE: Regex = + Regex::new(r"[^\n]*\r([^\n])").unwrap(); + } + + // any "move cursor right" sequences should just be a space + let text = MOVE_CURSOR_RIGHT_ONE_RE.replace_all(text, " "); + // replace additional virtual terminal sequences that strip ansi codes doesn't catch + let text = FOUND_SEQUENCES_RE.replace_all(&text, ""); + // strip any ansi codes, which also strips more terminal sequences + let text = strip_ansi_codes(&text); + // get rid of any text that is overwritten with only a carriage return + let text = CARRIAGE_RETURN_RE.replace_all(&text, "$1"); + // finally, trim surrounding whitespace + text.trim().to_string() } } +pub fn with_pty(deno_args: &[&str], mut action: impl FnMut(Box<dyn pty::Pty>)) { + if !atty::is(atty::Stream::Stdin) || !atty::is(atty::Stream::Stderr) { + eprintln!("Ignoring non-tty environment."); + return; + } + + let deno_dir = new_deno_dir(); + let mut env_vars = std::collections::HashMap::new(); + env_vars.insert("NO_COLOR".to_string(), "1".to_string()); + env_vars.insert( + "DENO_DIR".to_string(), + deno_dir.path().to_string_lossy().to_string(), + ); + let pty = pty::create_pty( + &deno_exe_path().to_string_lossy().to_string(), + deno_args, + testdata_path(), + Some(env_vars), + ); + + action(pty); +} + pub struct WrkOutput { pub latency: f64, pub requests: u64, diff --git a/test_util/src/pty.rs b/test_util/src/pty.rs new file mode 100644 index 000000000..2fa2ed4cd --- /dev/null +++ b/test_util/src/pty.rs @@ -0,0 +1,442 @@ +use std::collections::HashMap; +use std::io::Read; +use std::path::Path; + +pub trait Pty: Read { + fn write_text(&mut self, text: &str); + + fn write_line(&mut self, text: &str) { + self.write_text(&format!("{}\n", text)); + } + + /// Reads the output to the EOF. + fn read_all_output(&mut self) -> String { + let mut text = String::new(); + self.read_to_string(&mut text).unwrap(); + text + } +} + +#[cfg(unix)] +pub fn create_pty( + program: impl AsRef<Path>, + args: &[&str], + cwd: impl AsRef<Path>, + env_vars: Option<HashMap<String, String>>, +) -> Box<dyn Pty> { + let fork = pty::fork::Fork::from_ptmx().unwrap(); + if fork.is_parent().is_ok() { + Box::new(unix::UnixPty { fork }) + } else { + std::process::Command::new(program.as_ref()) + .current_dir(cwd) + .args(args) + .envs(env_vars.unwrap_or_default()) + .spawn() + .unwrap() + .wait() + .unwrap(); + unreachable!(); + } +} + +#[cfg(unix)] +mod unix { + use std::io::Read; + use std::io::Write; + + use super::Pty; + + pub struct UnixPty { + pub fork: pty::fork::Fork, + } + + impl Drop for UnixPty { + fn drop(&mut self) { + self.fork.wait().unwrap(); + } + } + + impl Pty for UnixPty { + fn write_text(&mut self, text: &str) { + let mut master = self.fork.is_parent().unwrap(); + master.write_all(text.as_bytes()).unwrap(); + } + } + + impl Read for UnixPty { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { + let mut master = self.fork.is_parent().unwrap(); + master.read(buf) + } + } +} + +#[cfg(target_os = "windows")] +pub fn create_pty( + program: impl AsRef<Path>, + args: &[&str], + cwd: impl AsRef<Path>, + env_vars: Option<HashMap<String, String>>, +) -> Box<dyn Pty> { + let pty = windows::WinPseudoConsole::new( + program, + args, + &cwd.as_ref().to_string_lossy().to_string(), + env_vars, + ); + Box::new(pty) +} + +#[cfg(target_os = "windows")] +mod windows { + use std::collections::HashMap; + use std::io::Read; + use std::io::Write; + use std::path::Path; + use std::ptr; + use std::time::Duration; + + use winapi::shared::minwindef::FALSE; + use winapi::shared::minwindef::LPVOID; + use winapi::shared::minwindef::TRUE; + use winapi::shared::winerror::S_OK; + use winapi::um::consoleapi::ClosePseudoConsole; + use winapi::um::consoleapi::CreatePseudoConsole; + use winapi::um::fileapi::ReadFile; + use winapi::um::fileapi::WriteFile; + use winapi::um::handleapi::DuplicateHandle; + use winapi::um::handleapi::INVALID_HANDLE_VALUE; + use winapi::um::namedpipeapi::CreatePipe; + use winapi::um::processthreadsapi::CreateProcessW; + use winapi::um::processthreadsapi::DeleteProcThreadAttributeList; + use winapi::um::processthreadsapi::GetCurrentProcess; + use winapi::um::processthreadsapi::InitializeProcThreadAttributeList; + use winapi::um::processthreadsapi::UpdateProcThreadAttribute; + use winapi::um::processthreadsapi::LPPROC_THREAD_ATTRIBUTE_LIST; + use winapi::um::processthreadsapi::PROCESS_INFORMATION; + use winapi::um::synchapi::WaitForSingleObject; + use winapi::um::winbase::CREATE_UNICODE_ENVIRONMENT; + use winapi::um::winbase::EXTENDED_STARTUPINFO_PRESENT; + use winapi::um::winbase::INFINITE; + use winapi::um::winbase::STARTUPINFOEXW; + use winapi::um::wincontypes::COORD; + use winapi::um::wincontypes::HPCON; + use winapi::um::winnt::DUPLICATE_SAME_ACCESS; + use winapi::um::winnt::HANDLE; + + use super::Pty; + + macro_rules! assert_win_success { + ($expression:expr) => { + let success = $expression; + if success != TRUE { + panic!("{}", std::io::Error::last_os_error().to_string()) + } + }; + } + + pub struct WinPseudoConsole { + stdin_write_handle: WinHandle, + stdout_read_handle: WinHandle, + // keep these alive for the duration of the pseudo console + _process_handle: WinHandle, + _thread_handle: WinHandle, + _attribute_list: ProcThreadAttributeList, + } + + impl WinPseudoConsole { + pub fn new( + program: impl AsRef<Path>, + args: &[&str], + cwd: &str, + maybe_env_vars: Option<HashMap<String, String>>, + ) -> Self { + // https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session + unsafe { + let mut size: COORD = std::mem::zeroed(); + size.X = 800; + size.Y = 500; + let mut console_handle = std::ptr::null_mut(); + let (stdin_read_handle, stdin_write_handle) = create_pipe(); + let (stdout_read_handle, stdout_write_handle) = create_pipe(); + + let result = CreatePseudoConsole( + size, + stdin_read_handle.as_raw_handle(), + stdout_write_handle.as_raw_handle(), + 0, + &mut console_handle, + ); + assert_eq!(result, S_OK); + + let mut environment_vars = maybe_env_vars.map(get_env_vars); + let mut attribute_list = ProcThreadAttributeList::new(console_handle); + let mut startup_info: STARTUPINFOEXW = std::mem::zeroed(); + startup_info.StartupInfo.cb = + std::mem::size_of::<STARTUPINFOEXW>() as u32; + startup_info.lpAttributeList = attribute_list.as_mut_ptr(); + + let mut proc_info: PROCESS_INFORMATION = std::mem::zeroed(); + let command = format!( + "\"{}\" {}", + program.as_ref().to_string_lossy(), + args.join(" ") + ) + .trim() + .to_string(); + let mut application_str = + to_windows_str(&program.as_ref().to_string_lossy()); + let mut command_str = to_windows_str(&command); + let mut cwd = to_windows_str(cwd); + + assert_win_success!(CreateProcessW( + application_str.as_mut_ptr(), + command_str.as_mut_ptr(), + ptr::null_mut(), + ptr::null_mut(), + FALSE, + EXTENDED_STARTUPINFO_PRESENT | CREATE_UNICODE_ENVIRONMENT, + environment_vars + .as_mut() + .map(|v| v.as_mut_ptr() as LPVOID) + .unwrap_or(ptr::null_mut()), + cwd.as_mut_ptr(), + &mut startup_info.StartupInfo, + &mut proc_info, + )); + + // close the handles that the pseudoconsole now has + drop(stdin_read_handle); + drop(stdout_write_handle); + + // start a thread that will close the pseudoconsole on process exit + let thread_handle = WinHandle::new(proc_info.hThread); + std::thread::spawn({ + let thread_handle = thread_handle.duplicate(); + let console_handle = WinHandle::new(console_handle); + move || { + WaitForSingleObject(thread_handle.as_raw_handle(), INFINITE); + // wait for the reading thread to catch up + std::thread::sleep(Duration::from_millis(200)); + // close the console handle which will close the + // stdout pipe for the reader + ClosePseudoConsole(console_handle.into_raw_handle()); + } + }); + + Self { + stdin_write_handle, + stdout_read_handle, + _process_handle: WinHandle::new(proc_info.hProcess), + _thread_handle: thread_handle, + _attribute_list: attribute_list, + } + } + } + } + + impl Read for WinPseudoConsole { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { + unsafe { + loop { + let mut bytes_read = 0; + let success = ReadFile( + self.stdout_read_handle.as_raw_handle(), + buf.as_mut_ptr() as _, + buf.len() as u32, + &mut bytes_read, + ptr::null_mut(), + ); + + // ignore zero-byte writes + let is_zero_byte_write = bytes_read == 0 && success == TRUE; + if !is_zero_byte_write { + return Ok(bytes_read as usize); + } + } + } + } + } + + impl Pty for WinPseudoConsole { + fn write_text(&mut self, text: &str) { + // windows psuedo console requires a \r\n to do a newline + let newline_re = regex::Regex::new("\r?\n").unwrap(); + self + .write_all(newline_re.replace_all(text, "\r\n").as_bytes()) + .unwrap(); + } + } + + impl std::io::Write for WinPseudoConsole { + fn write(&mut self, buffer: &[u8]) -> std::io::Result<usize> { + unsafe { + let mut bytes_written = 0; + assert_win_success!(WriteFile( + self.stdin_write_handle.as_raw_handle(), + buffer.as_ptr() as *const _, + buffer.len() as u32, + &mut bytes_written, + ptr::null_mut(), + )); + Ok(bytes_written as usize) + } + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + } + + struct WinHandle { + inner: HANDLE, + } + + impl WinHandle { + pub fn new(handle: HANDLE) -> Self { + WinHandle { inner: handle } + } + + pub fn duplicate(&self) -> WinHandle { + unsafe { + let process_handle = GetCurrentProcess(); + let mut duplicate_handle = ptr::null_mut(); + assert_win_success!(DuplicateHandle( + process_handle, + self.inner, + process_handle, + &mut duplicate_handle, + 0, + 0, + DUPLICATE_SAME_ACCESS, + )); + + WinHandle::new(duplicate_handle) + } + } + + pub fn as_raw_handle(&self) -> HANDLE { + self.inner + } + + pub fn into_raw_handle(self) -> HANDLE { + let handle = self.inner; + // skip the drop implementation in order to not close the handle + std::mem::forget(self); + handle + } + } + + unsafe impl Send for WinHandle {} + unsafe impl Sync for WinHandle {} + + impl Drop for WinHandle { + fn drop(&mut self) { + unsafe { + if !self.inner.is_null() && self.inner != INVALID_HANDLE_VALUE { + winapi::um::handleapi::CloseHandle(self.inner); + } + } + } + } + + struct ProcThreadAttributeList { + buffer: Vec<u8>, + } + + impl ProcThreadAttributeList { + pub fn new(console_handle: HPCON) -> Self { + unsafe { + // discover size required for the list + let mut size = 0; + let attribute_count = 1; + assert_eq!( + InitializeProcThreadAttributeList( + ptr::null_mut(), + attribute_count, + 0, + &mut size + ), + FALSE + ); + + let mut buffer = vec![0u8; size]; + let attribute_list_ptr = buffer.as_mut_ptr() as _; + + assert_win_success!(InitializeProcThreadAttributeList( + attribute_list_ptr, + attribute_count, + 0, + &mut size, + )); + + const PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE: usize = 0x00020016; + assert_win_success!(UpdateProcThreadAttribute( + attribute_list_ptr, + 0, + PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE, + console_handle, + std::mem::size_of::<HPCON>(), + ptr::null_mut(), + ptr::null_mut(), + )); + + ProcThreadAttributeList { buffer } + } + } + + pub fn as_mut_ptr(&mut self) -> LPPROC_THREAD_ATTRIBUTE_LIST { + self.buffer.as_mut_slice().as_mut_ptr() as *mut _ + } + } + + impl Drop for ProcThreadAttributeList { + fn drop(&mut self) { + unsafe { DeleteProcThreadAttributeList(self.as_mut_ptr()) }; + } + } + + fn create_pipe() -> (WinHandle, WinHandle) { + unsafe { + let mut read_handle = std::ptr::null_mut(); + let mut write_handle = std::ptr::null_mut(); + + assert_win_success!(CreatePipe( + &mut read_handle, + &mut write_handle, + ptr::null_mut(), + 0 + )); + + (WinHandle::new(read_handle), WinHandle::new(write_handle)) + } + } + + fn to_windows_str(str: &str) -> Vec<u16> { + use std::os::windows::prelude::OsStrExt; + std::ffi::OsStr::new(str) + .encode_wide() + .chain(Some(0)) + .collect() + } + + fn get_env_vars(env_vars: HashMap<String, String>) -> Vec<u16> { + // See lpEnvironment: https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-createprocessw + let mut parts = env_vars + .into_iter() + // each environment variable is in the form `name=value\0` + .map(|(key, value)| format!("{}={}\0", key, value)) + .collect::<Vec<_>>(); + + // all strings in an environment block must be case insensitively + // sorted alphabetically by name + // https://docs.microsoft.com/en-us/windows/win32/procthread/changing-environment-variables + parts.sort_by_key(|part| part.to_lowercase()); + + // the entire block is terminated by NULL (\0) + format!("{}\0", parts.join("")) + .encode_utf16() + .collect::<Vec<_>>() + } +} |