diff options
author | David Sherret <dsherret@users.noreply.github.com> | 2023-03-28 17:49:00 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-28 21:49:00 +0000 |
commit | 6fb6b0c1f302e8637c96131c9ffc4c4b9f3f5f0f (patch) | |
tree | dff55c1b345f317ebd3ec5a3b62c26ed27d5830c /test_util/src | |
parent | c65149c0a072fa710098b14776c6cd3cc8a204d6 (diff) |
chore: restore pty tests and make them run on the Linux CI (#18424)
1. Rewrites the tests to be more back and forth rather than getting the
output all at once (which I believe was causing the hangs on linux and
maybe mac)
2. Runs the pty tests on the linux ci.
3. Fixes a bunch of tests that were just wrong.
4. Adds timeouts on the pty tests.
Diffstat (limited to 'test_util/src')
-rw-r--r-- | test_util/src/builders.rs | 151 | ||||
-rw-r--r-- | test_util/src/lib.rs | 101 | ||||
-rw-r--r-- | test_util/src/pty.rs | 399 |
3 files changed, 414 insertions, 237 deletions
diff --git a/test_util/src/builders.rs b/test_util/src/builders.rs index 9b300b911..84befb57a 100644 --- a/test_util/src/builders.rs +++ b/test_util/src/builders.rs @@ -10,7 +10,6 @@ use std::process::Command; use std::process::Stdio; use std::rc::Rc; -use backtrace::Backtrace; use os_pipe::pipe; use pretty_assertions::assert_eq; @@ -20,6 +19,7 @@ use crate::env_vars_for_npm_tests_no_sync_download; use crate::http_server; use crate::lsp::LspClientBuilder; use crate::new_deno_dir; +use crate::pty::Pty; use crate::strip_ansi_codes; use crate::testdata_path; use crate::wildcard_match; @@ -268,34 +268,29 @@ impl TestCommandBuilder { self } - pub fn run(&self) -> TestCommandOutput { - fn read_pipe_to_string(mut pipe: os_pipe::PipeReader) -> String { - let mut output = String::new(); - pipe.read_to_string(&mut output).unwrap(); - output - } - - fn sanitize_output(text: String, args: &[String]) -> String { - let mut text = strip_ansi_codes(&text).to_string(); - // deno test's output capturing flushes with a zero-width space in order to - // synchronize the output pipes. Occassionally this zero width space - // might end up in the output so strip it from the output comparison here. - if args.first().map(|s| s.as_str()) == Some("test") { - text = text.replace('\u{200B}', ""); - } - text - } - + fn build_cwd(&self) -> PathBuf { let cwd = self.cwd.as_ref().or(self.context.cwd.as_ref()); - let cwd = if self.context.use_temp_cwd { + if self.context.use_temp_cwd { assert!(cwd.is_none()); self.context.temp_dir.path().to_owned() } else if let Some(cwd_) = cwd { self.context.testdata_dir.join(cwd_) } else { self.context.testdata_dir.clone() - }; - let args = if self.args_vec.is_empty() { + } + } + + fn build_command_path(&self) -> PathBuf { + let command_name = &self.command_name; + if command_name == "deno" { + deno_exe_path() + } else { + PathBuf::from(command_name) + } + } + + fn build_args(&self) -> Vec<String> { + if self.args_vec.is_empty() { std::borrow::Cow::Owned( self .args @@ -314,21 +309,58 @@ impl TestCommandBuilder { .map(|arg| { arg.replace("$TESTDATA", &self.context.testdata_dir.to_string_lossy()) }) - .collect::<Vec<_>>(); - let command_name = &self.command_name; - let mut command = if command_name == "deno" { - Command::new(deno_exe_path()) - } else { - Command::new(command_name) - }; - command.env("DENO_DIR", self.context.deno_dir.path()); + .collect::<Vec<_>>() + } + + pub fn with_pty(&self, mut action: impl FnMut(Pty)) { + if !Pty::is_supported() { + return; + } + + let args = self.build_args(); + let args = args.iter().map(|s| s.as_str()).collect::<Vec<_>>(); + let mut envs = self.envs.clone(); + if !envs.contains_key("NO_COLOR") { + // set this by default for pty tests + envs.insert("NO_COLOR".to_string(), "1".to_string()); + } + action(Pty::new( + &self.build_command_path(), + &args, + &self.build_cwd(), + Some(envs), + )) + } - println!("command {} {}", command_name, args.join(" ")); + pub fn run(&self) -> TestCommandOutput { + fn read_pipe_to_string(mut pipe: os_pipe::PipeReader) -> String { + let mut output = String::new(); + pipe.read_to_string(&mut output).unwrap(); + output + } + + fn sanitize_output(text: String, args: &[String]) -> String { + let mut text = strip_ansi_codes(&text).to_string(); + // deno test's output capturing flushes with a zero-width space in order to + // synchronize the output pipes. Occassionally this zero width space + // might end up in the output so strip it from the output comparison here. + if args.first().map(|s| s.as_str()) == Some("test") { + text = text.replace('\u{200B}', ""); + } + text + } + + let cwd = self.build_cwd(); + let args = self.build_args(); + let mut command = Command::new(self.build_command_path()); + + println!("command {} {}", self.command_name, args.join(" ")); println!("command cwd {:?}", &cwd); command.args(args.iter()); if self.env_clear { command.env_clear(); } + command.env("DENO_DIR", self.context.deno_dir.path()); command.envs({ let mut envs = self.context.envs.clone(); for (key, value) in &self.envs { @@ -423,13 +455,10 @@ impl Drop for TestCommandOutput { fn drop(&mut self) { fn panic_unasserted_output(text: &str) { println!("OUTPUT\n{text}\nOUTPUT"); - panic!( - concat!( - "The non-empty text of the command was not asserted at {}. ", - "Call `output.skip_output_check()` to skip if necessary.", - ), - failed_position() - ); + panic!(concat!( + "The non-empty text of the command was not asserted. ", + "Call `output.skip_output_check()` to skip if necessary.", + ),); } if std::thread::panicking() { @@ -438,9 +467,8 @@ impl Drop for TestCommandOutput { // force the caller to assert these if !*self.asserted_exit_code.borrow() && self.exit_code != Some(0) { panic!( - "The non-zero exit code of the command was not asserted: {:?} at {}.", + "The non-zero exit code of the command was not asserted: {:?}", self.exit_code, - failed_position(), ) } @@ -511,6 +539,7 @@ impl TestCommandOutput { .expect("call .split_output() on the builder") } + #[track_caller] pub fn assert_exit_code(&self, expected_exit_code: i32) -> &Self { let actual_exit_code = self.exit_code(); @@ -518,26 +547,22 @@ impl TestCommandOutput { if *exit_code != expected_exit_code { self.print_output(); panic!( - "bad exit code, expected: {:?}, actual: {:?} at {}", - expected_exit_code, - exit_code, - failed_position(), + "bad exit code, expected: {:?}, actual: {:?}", + expected_exit_code, exit_code, ); } } else { self.print_output(); if let Some(signal) = self.signal() { panic!( - "process terminated by signal, expected exit code: {:?}, actual signal: {:?} at {}", + "process terminated by signal, expected exit code: {:?}, actual signal: {:?}", actual_exit_code, signal, - failed_position(), ); } else { panic!( - "process terminated without status code on non unix platform, expected exit code: {:?} at {}", + "process terminated without status code on non unix platform, expected exit code: {:?}", actual_exit_code, - failed_position(), ); } } @@ -554,14 +579,17 @@ impl TestCommandOutput { } } + #[track_caller] pub fn assert_matches_text(&self, expected_text: impl AsRef<str>) -> &Self { self.inner_assert_matches_text(self.combined_output(), expected_text) } + #[track_caller] pub fn assert_matches_file(&self, file_path: impl AsRef<Path>) -> &Self { self.inner_assert_matches_file(self.combined_output(), file_path) } + #[track_caller] pub fn assert_stdout_matches_text( &self, expected_text: impl AsRef<str>, @@ -569,6 +597,7 @@ impl TestCommandOutput { self.inner_assert_matches_text(self.stdout(), expected_text) } + #[track_caller] pub fn assert_stdout_matches_file( &self, file_path: impl AsRef<Path>, @@ -576,6 +605,7 @@ impl TestCommandOutput { self.inner_assert_matches_file(self.stdout(), file_path) } + #[track_caller] pub fn assert_stderr_matches_text( &self, expected_text: impl AsRef<str>, @@ -583,6 +613,7 @@ impl TestCommandOutput { self.inner_assert_matches_text(self.stderr(), expected_text) } + #[track_caller] pub fn assert_stderrr_matches_file( &self, file_path: impl AsRef<Path>, @@ -590,6 +621,7 @@ impl TestCommandOutput { self.inner_assert_matches_file(self.stderr(), file_path) } + #[track_caller] fn inner_assert_matches_text( &self, actual: &str, @@ -597,15 +629,16 @@ impl TestCommandOutput { ) -> &Self { let expected = expected.as_ref(); if !expected.contains("[WILDCARD]") { - assert_eq!(actual, expected, "at {}", failed_position()); + assert_eq!(actual, expected); } else if !wildcard_match(expected, actual) { println!("OUTPUT START\n{actual}\nOUTPUT END"); println!("EXPECTED START\n{expected}\nEXPECTED END"); - panic!("pattern match failed at {}", failed_position()); + panic!("pattern match failed"); } self } + #[track_caller] fn inner_assert_matches_file( &self, actual: &str, @@ -620,21 +653,3 @@ impl TestCommandOutput { self.inner_assert_matches_text(actual, expected_text) } } - -fn failed_position() -> String { - let backtrace = Backtrace::new(); - - for frame in backtrace.frames() { - for symbol in frame.symbols() { - if let Some(filename) = symbol.filename() { - if !filename.to_string_lossy().ends_with("builders.rs") { - let line_num = symbol.lineno().unwrap_or(0); - let line_col = symbol.colno().unwrap_or(0); - return format!("{}:{}:{}", filename.display(), line_num, line_col); - } - } - } - } - - "<unknown>".to_string() -} diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs index d4effd88b..b38d72cd9 100644 --- a/test_util/src/lib.rs +++ b/test_util/src/lib.rs @@ -16,6 +16,7 @@ use hyper::StatusCode; use lazy_static::lazy_static; use npm::CUSTOM_NPM_PACKAGE_CACHE; use pretty_assertions::assert_eq; +use pty::Pty; use regex::Regex; use rustls::Certificate; use rustls::PrivateKey; @@ -24,7 +25,6 @@ use std::collections::HashMap; use std::convert::Infallible; use std::env; use std::io; -use std::io::Read; use std::io::Write; use std::mem::replace; use std::net::SocketAddr; @@ -92,13 +92,8 @@ pub const PERMISSION_VARIANTS: [&str; 5] = pub const PERMISSION_DENIED_PATTERN: &str = "PermissionDenied"; lazy_static! { - // STRIP_ANSI_RE and strip_ansi_codes are lifted from the "console" crate. - // Copyright 2017 Armin Ronacher <armin.ronacher@active-4.com>. MIT License. - static ref STRIP_ANSI_RE: Regex = Regex::new( - r"[\x1b\x9b][\[()#;?]*(?:[0-9]{1,4}(?:;[0-9]{0,4})*)?[0-9A-PRZcf-nqry=><]" - ).unwrap(); - - static ref GUARD: Mutex<HttpServerCount> = Mutex::new(HttpServerCount::default()); + static ref GUARD: Mutex<HttpServerCount> = + Mutex::new(HttpServerCount::default()); } pub fn env_vars_for_npm_tests_no_sync_download() -> Vec<(String, String)> { @@ -1758,7 +1753,7 @@ pub fn http_server() -> HttpServerGuard { /// Helper function to strip ansi codes. pub fn strip_ansi_codes(s: &str) -> std::borrow::Cow<str> { - STRIP_ANSI_RE.replace_all(s, "") + console_static_text::ansi::strip_ansi_codes(s) } pub fn run( @@ -2171,82 +2166,8 @@ pub fn pattern_match(pattern: &str, s: &str, wildcard: &str) -> bool { t.1.is_empty() } -pub enum PtyData { - Input(&'static str), - Output(&'static str), -} - -pub fn test_pty2(args: &str, data: Vec<PtyData>) { - use std::io::BufRead; - - with_pty(&args.split_whitespace().collect::<Vec<_>>(), |console| { - let mut buf_reader = std::io::BufReader::new(console); - for d in data.iter() { - match d { - PtyData::Input(s) => { - println!("INPUT {}", s.escape_debug()); - buf_reader.get_mut().write_text(s); - - // Because of tty echo, we should be able to read the same string back. - assert!(s.ends_with('\n')); - let mut echo = String::new(); - buf_reader.read_line(&mut echo).unwrap(); - println!("ECHO: {}", echo.escape_debug()); - - // Windows may also echo the previous line, so only check the end - assert_ends_with!(normalize_text(&echo), normalize_text(s)); - } - PtyData::Output(s) => { - let mut line = String::new(); - if s.ends_with('\n') { - buf_reader.read_line(&mut line).unwrap(); - } else { - // assumes the buffer won't have overlapping virtual terminal sequences - while normalize_text(&line).len() < normalize_text(s).len() { - let mut buf = [0; 64 * 1024]; - let bytes_read = buf_reader.read(&mut buf).unwrap(); - assert!(bytes_read > 0); - let buf_str = std::str::from_utf8(&buf) - .unwrap() - .trim_end_matches(char::from(0)); - line += buf_str; - } - } - println!("OUTPUT {}", line.escape_debug()); - assert_eq!(normalize_text(&line), normalize_text(s)); - } - } - } - }); - - // This normalization function is not comprehensive - // and may need to updated as new scenarios emerge. - fn normalize_text(text: &str) -> String { - lazy_static! { - static ref MOVE_CURSOR_RIGHT_ONE_RE: Regex = - Regex::new(r"\x1b\[1C").unwrap(); - static ref FOUND_SEQUENCES_RE: Regex = - Regex::new(r"(\x1b\]0;[^\x07]*\x07)*(\x08)*(\x1b\[\d+X)*").unwrap(); - static ref CARRIAGE_RETURN_RE: Regex = - Regex::new(r"[^\n]*\r([^\n])").unwrap(); - } - - // any "move cursor right" sequences should just be a space - let text = MOVE_CURSOR_RIGHT_ONE_RE.replace_all(text, " "); - // replace additional virtual terminal sequences that strip ansi codes doesn't catch - let text = FOUND_SEQUENCES_RE.replace_all(&text, ""); - // strip any ansi codes, which also strips more terminal sequences - let text = strip_ansi_codes(&text); - // get rid of any text that is overwritten with only a carriage return - let text = CARRIAGE_RETURN_RE.replace_all(&text, "$1"); - // finally, trim surrounding whitespace - text.trim().to_string() - } -} - -pub fn with_pty(deno_args: &[&str], mut action: impl FnMut(Box<dyn pty::Pty>)) { - if !atty::is(atty::Stream::Stdin) || !atty::is(atty::Stream::Stderr) { - eprintln!("Ignoring non-tty environment."); +pub fn with_pty(deno_args: &[&str], mut action: impl FnMut(Pty)) { + if !Pty::is_supported() { return; } @@ -2257,14 +2178,12 @@ pub fn with_pty(deno_args: &[&str], mut action: impl FnMut(Box<dyn pty::Pty>)) { "DENO_DIR".to_string(), deno_dir.path().to_string_lossy().to_string(), ); - let pty = pty::create_pty( - &deno_exe_path().to_string_lossy().to_string(), + action(Pty::new( + &deno_exe_path(), deno_args, - testdata_path(), + &testdata_path(), Some(env_vars), - ); - - action(pty); + )) } pub struct WrkOutput { diff --git a/test_util/src/pty.rs b/test_util/src/pty.rs index f3bb2829f..80d06881e 100644 --- a/test_util/src/pty.rs +++ b/test_util/src/pty.rs @@ -1,36 +1,253 @@ // Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. use std::collections::HashMap; +use std::collections::HashSet; use std::io::Read; +use std::io::Write; use std::path::Path; +use std::time::Duration; +use std::time::Instant; + +use crate::strip_ansi_codes; + +/// Points to know about when writing pty tests: +/// +/// - Consecutive writes cause issues where you might write while a prompt +/// is not showing. So when you write, always `.expect(...)` on the output. +/// - Similar to the last point, using `.expect(...)` can help make the test +/// more deterministic. If the test is flaky, try adding more `.expect(...)`s +pub struct Pty { + pty: Box<dyn SystemPty>, + read_bytes: Vec<u8>, + last_index: usize, +} + +impl Pty { + pub fn new( + program: &Path, + args: &[&str], + cwd: &Path, + env_vars: Option<HashMap<String, String>>, + ) -> Self { + let pty = create_pty(program, args, cwd, env_vars); + let mut pty = Self { + pty, + read_bytes: Vec::new(), + last_index: 0, + }; + if args[0] == "repl" && !args.contains(&"--quiet") { + // wait for the repl to start up before writing to it + pty.expect("exit using ctrl+d, ctrl+c, or close()"); + } + pty + } + + pub fn is_supported() -> bool { + let is_mac_or_windows = cfg!(target_os = "macos") || cfg!(windows); + if is_mac_or_windows && std::env::var("CI").is_ok() { + // the pty tests give a ENOTTY error for Mac and don't really start up + // on the windows CI for some reason so ignore them for now + eprintln!("Ignoring windows CI."); + false + } else { + true + } + } + + #[track_caller] + pub fn write_raw(&mut self, line: impl AsRef<str>) { + let line = if cfg!(windows) { + line.as_ref().replace('\n', "\r\n") + } else { + line.as_ref().to_string() + }; + if let Err(err) = self.pty.write(line.as_bytes()) { + panic!("{:#}", err) + } + self.pty.flush().unwrap(); + } + + #[track_caller] + pub fn write_line(&mut self, line: impl AsRef<str>) { + self.write_line_raw(&line); + + // expect what was written to show up in the output + // due to "pty echo" + for line in line.as_ref().lines() { + self.expect(line); + } + } + + /// Writes a line without checking if it's in the output. + #[track_caller] + pub fn write_line_raw(&mut self, line: impl AsRef<str>) { + self.write_raw(format!("{}\n", line.as_ref())); + } + + #[track_caller] + pub fn read_until(&mut self, end_text: impl AsRef<str>) -> String { + self.read_until_with_advancing(|text| { + text + .find(end_text.as_ref()) + .map(|index| index + end_text.as_ref().len()) + }) + } + + #[track_caller] + pub fn expect(&mut self, text: impl AsRef<str>) { + self.read_until(text.as_ref()); + } + + #[track_caller] + pub fn expect_any(&mut self, texts: &[&str]) { + self.read_until_with_advancing(|text| { + for find_text in texts { + if let Some(index) = text.find(find_text) { + return Some(index); + } + } + None + }); + } + + /// Consumes and expects to find all the text until a timeout is hit. + #[track_caller] + pub fn expect_all(&mut self, texts: &[&str]) { + let mut pending_texts: HashSet<&&str> = HashSet::from_iter(texts); + let mut max_index: Option<usize> = None; + self.read_until_with_advancing(|text| { + for pending_text in pending_texts.clone() { + if let Some(index) = text.find(pending_text) { + let index = index + pending_text.len(); + match &max_index { + Some(current) => { + if *current < index { + max_index = Some(index); + } + } + None => { + max_index = Some(index); + } + } + pending_texts.remove(pending_text); + } + } + if pending_texts.is_empty() { + max_index + } else { + None + } + }); + } + + /// Expects the raw text to be found, which may include ANSI codes. + /// Note: this expects the raw bytes in any output that has already + /// occurred or may occur within the next few seconds. + #[track_caller] + pub fn expect_raw_in_current_output(&mut self, text: impl AsRef<str>) { + self.read_until_condition(|pty| { + let data = String::from_utf8_lossy(&pty.read_bytes); + data.contains(text.as_ref()) + }); + } + + #[track_caller] + fn read_until_with_advancing( + &mut self, + mut condition: impl FnMut(&str) -> Option<usize>, + ) -> String { + let mut final_text = String::new(); + self.read_until_condition(|pty| { + let text = pty.next_text(); + if let Some(end_index) = condition(&text) { + pty.last_index += end_index; + final_text = text[..end_index].to_string(); + true + } else { + false + } + }); + final_text + } + + #[track_caller] + fn read_until_condition( + &mut self, + mut condition: impl FnMut(&mut Self) -> bool, + ) { + let timeout_time = + Instant::now().checked_add(Duration::from_secs(5)).unwrap(); + while Instant::now() < timeout_time { + self.fill_more_bytes(); + if condition(self) { + return; + } + } -pub trait Pty: Read { - fn write_text(&mut self, text: &str); + let text = self.next_text(); + eprintln!( + "------ Start Full Text ------\n{:?}\n------- End Full Text -------", + String::from_utf8_lossy(&self.read_bytes) + ); + eprintln!("Next text: {:?}", text); + panic!("Timed out.") + } - fn write_line(&mut self, text: &str) { - self.write_text(&format!("{text}\n")); + fn next_text(&self) -> String { + let text = String::from_utf8_lossy(&self.read_bytes).to_string(); + let text = strip_ansi_codes(&text); + text[self.last_index..].to_string() } - /// Reads the output to the EOF. - fn read_all_output(&mut self) -> String { - let mut text = String::new(); - self.read_to_string(&mut text).unwrap(); - text + fn fill_more_bytes(&mut self) { + let mut buf = [0; 256]; + if let Ok(count) = self.pty.read(&mut buf) { + self.read_bytes.extend(&buf[..count]); + } else { + std::thread::sleep(Duration::from_millis(10)); + } } } +trait SystemPty: Read + Write {} + #[cfg(unix)] -pub fn create_pty( - program: impl AsRef<Path>, +fn setup_pty(master: &pty2::fork::Master) { + use nix::fcntl::fcntl; + use nix::fcntl::FcntlArg; + use nix::fcntl::OFlag; + use nix::sys::termios; + use nix::sys::termios::tcgetattr; + use nix::sys::termios::tcsetattr; + use nix::sys::termios::SetArg; + use std::os::fd::AsRawFd; + + let fd = master.as_raw_fd(); + let mut term = tcgetattr(fd).unwrap(); + // disable cooked mode + term.local_flags.remove(termios::LocalFlags::ICANON); + tcsetattr(fd, SetArg::TCSANOW, &term).unwrap(); + + // turn on non-blocking mode so we get timeouts + let flags = fcntl(fd, FcntlArg::F_GETFL).unwrap(); + let new_flags = OFlag::from_bits_truncate(flags) | OFlag::O_NONBLOCK; + fcntl(fd, FcntlArg::F_SETFL(new_flags)).unwrap(); +} + +#[cfg(unix)] +fn create_pty( + program: &Path, args: &[&str], - cwd: impl AsRef<Path>, + cwd: &Path, env_vars: Option<HashMap<String, String>>, -) -> Box<dyn Pty> { +) -> Box<dyn SystemPty> { let fork = pty2::fork::Fork::from_ptmx().unwrap(); if fork.is_parent().is_ok() { + let master = fork.is_parent().unwrap(); + setup_pty(&master); Box::new(unix::UnixPty { fork }) } else { - std::process::Command::new(program.as_ref()) + std::process::Command::new(program) .current_dir(cwd) .args(args) .envs(env_vars.unwrap_or_default()) @@ -47,7 +264,7 @@ mod unix { use std::io::Read; use std::io::Write; - use super::Pty; + use super::SystemPty; pub struct UnixPty { pub fork: pty2::fork::Fork, @@ -55,46 +272,55 @@ mod unix { impl Drop for UnixPty { fn drop(&mut self) { - self.fork.wait().unwrap(); - } - } + use nix::sys::signal::kill; + use nix::sys::signal::Signal; + use nix::unistd::Pid; - impl Pty for UnixPty { - fn write_text(&mut self, text: &str) { - let mut master = self.fork.is_parent().unwrap(); - master.write_all(text.as_bytes()).unwrap(); + if let pty2::fork::Fork::Parent(child_pid, _) = self.fork { + let pid = Pid::from_raw(child_pid); + kill(pid, Signal::SIGTERM).unwrap() + } } } + impl SystemPty for UnixPty {} + impl Read for UnixPty { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { let mut master = self.fork.is_parent().unwrap(); master.read(buf) } } + + impl Write for UnixPty { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + let mut master = self.fork.is_parent().unwrap(); + master.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + let mut master = self.fork.is_parent().unwrap(); + master.flush() + } + } } #[cfg(target_os = "windows")] -pub fn create_pty( - program: impl AsRef<Path>, +fn create_pty( + program: &Path, args: &[&str], - cwd: impl AsRef<Path>, + cwd: &Path, env_vars: Option<HashMap<String, String>>, -) -> Box<dyn Pty> { - let pty = windows::WinPseudoConsole::new( - program, - args, - &cwd.as_ref().to_string_lossy(), - env_vars, - ); +) -> Box<dyn SystemPty> { + let pty = windows::WinPseudoConsole::new(program, args, cwd, env_vars); Box::new(pty) } #[cfg(target_os = "windows")] mod windows { use std::collections::HashMap; + use std::io::ErrorKind; use std::io::Read; - use std::io::Write; use std::path::Path; use std::ptr; use std::time::Duration; @@ -105,11 +331,13 @@ mod windows { use winapi::shared::winerror::S_OK; use winapi::um::consoleapi::ClosePseudoConsole; use winapi::um::consoleapi::CreatePseudoConsole; + use winapi::um::fileapi::FlushFileBuffers; use winapi::um::fileapi::ReadFile; use winapi::um::fileapi::WriteFile; use winapi::um::handleapi::DuplicateHandle; use winapi::um::handleapi::INVALID_HANDLE_VALUE; use winapi::um::namedpipeapi::CreatePipe; + use winapi::um::namedpipeapi::PeekNamedPipe; use winapi::um::processthreadsapi::CreateProcessW; use winapi::um::processthreadsapi::DeleteProcThreadAttributeList; use winapi::um::processthreadsapi::GetCurrentProcess; @@ -127,7 +355,7 @@ mod windows { use winapi::um::winnt::DUPLICATE_SAME_ACCESS; use winapi::um::winnt::HANDLE; - use super::Pty; + use super::SystemPty; macro_rules! assert_win_success { ($expression:expr) => { @@ -138,6 +366,15 @@ mod windows { }; } + macro_rules! handle_err { + ($expression:expr) => { + let success = $expression; + if success != TRUE { + return Err(std::io::Error::last_os_error()); + } + }; + } + pub struct WinPseudoConsole { stdin_write_handle: WinHandle, stdout_read_handle: WinHandle, @@ -149,9 +386,9 @@ mod windows { impl WinPseudoConsole { pub fn new( - program: impl AsRef<Path>, + program: &Path, args: &[&str], - cwd: &str, + cwd: &Path, maybe_env_vars: Option<HashMap<String, String>>, ) -> Self { // https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session @@ -184,15 +421,19 @@ mod windows { let mut proc_info: PROCESS_INFORMATION = std::mem::zeroed(); let command = format!( "\"{}\" {}", - program.as_ref().to_string_lossy(), - args.join(" ") + program.to_string_lossy(), + args + .iter() + .map(|a| format!("\"{}\"", a)) + .collect::<Vec<_>>() + .join(" ") ) .trim() .to_string(); - let mut application_str = - to_windows_str(&program.as_ref().to_string_lossy()); + let mut application_str = to_windows_str(&program.to_string_lossy()); let mut command_str = to_windows_str(&command); - let mut cwd = to_windows_str(cwd); + let cwd = cwd.to_string_lossy().replace('/', "\\"); + let mut cwd = to_windows_str(&cwd); assert_win_success!(CreateProcessW( application_str.as_mut_ptr(), @@ -242,45 +483,47 @@ mod windows { impl Read for WinPseudoConsole { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { - loop { - let mut bytes_read = 0; - // SAFETY: - // winapi call - let success = unsafe { - ReadFile( - self.stdout_read_handle.as_raw_handle(), - buf.as_mut_ptr() as _, - buf.len() as u32, - &mut bytes_read, - ptr::null_mut(), - ) - }; - - // ignore zero-byte writes - let is_zero_byte_write = bytes_read == 0 && success == TRUE; - if !is_zero_byte_write { - return Ok(bytes_read as usize); - } + // don't do a blocking read in order to support timing out + let mut bytes_available = 0; + // SAFETY: winapi call + handle_err!(unsafe { + PeekNamedPipe( + self.stdout_read_handle.as_raw_handle(), + ptr::null_mut(), + 0, + ptr::null_mut(), + &mut bytes_available, + ptr::null_mut(), + ) + }); + if bytes_available == 0 { + return Err(std::io::Error::new(ErrorKind::WouldBlock, "Would block.")); } - } - } - impl Pty for WinPseudoConsole { - fn write_text(&mut self, text: &str) { - // windows pseudo console requires a \r\n to do a newline - let newline_re = regex::Regex::new("\r?\n").unwrap(); - self - .write_all(newline_re.replace_all(text, "\r\n").as_bytes()) - .unwrap(); + let mut bytes_read = 0; + // SAFETY: winapi call + handle_err!(unsafe { + ReadFile( + self.stdout_read_handle.as_raw_handle(), + buf.as_mut_ptr() as _, + buf.len() as u32, + &mut bytes_read, + ptr::null_mut(), + ) + }); + + Ok(bytes_read as usize) } } + impl SystemPty for WinPseudoConsole {} + impl std::io::Write for WinPseudoConsole { fn write(&mut self, buffer: &[u8]) -> std::io::Result<usize> { let mut bytes_written = 0; // SAFETY: // winapi call - assert_win_success!(unsafe { + handle_err!(unsafe { WriteFile( self.stdin_write_handle.as_raw_handle(), buffer.as_ptr() as *const _, @@ -293,6 +536,10 @@ mod windows { } fn flush(&mut self) -> std::io::Result<()> { + // SAFETY: winapi call + handle_err!(unsafe { + FlushFileBuffers(self.stdin_write_handle.as_raw_handle()) + }); Ok(()) } } @@ -307,12 +554,10 @@ mod windows { } pub fn duplicate(&self) -> WinHandle { - // SAFETY: - // winapi call + // SAFETY: winapi call let process_handle = unsafe { GetCurrentProcess() }; let mut duplicate_handle = ptr::null_mut(); - // SAFETY: - // winapi call + // SAFETY: winapi call assert_win_success!(unsafe { DuplicateHandle( process_handle, @@ -410,8 +655,7 @@ mod windows { impl Drop for ProcThreadAttributeList { fn drop(&mut self) { - // SAFETY: - // winapi call + // SAFETY: winapi call unsafe { DeleteProcThreadAttributeList(self.as_mut_ptr()) }; } } @@ -420,8 +664,7 @@ mod windows { let mut read_handle = std::ptr::null_mut(); let mut write_handle = std::ptr::null_mut(); - // SAFETY: - // Creating an anonymous pipe with winapi. + // SAFETY: Creating an anonymous pipe with winapi. assert_win_success!(unsafe { CreatePipe(&mut read_handle, &mut write_handle, ptr::null_mut(), 0) }); |