Redirect stdio over TLS in Rust


I am trying to replicate the "-e" option in ncat to redirect stdio in Rust to a remote ncat listener.

I can do it over TcpStream by using dup2 and then executing the "/bin/sh" command in Rust. However, I do not know how to do it over TLS as redirection seems to require file descriptors, which TlsStream does not seem to provide.
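For reference, the plain-TCP version described above can be done without calling dup2() by hand. When `std::process::Command` is given a `Stdio` built from a file descriptor, it performs the dup2() onto descriptors 0, 1 and 2 in the child itself. This is a minimal sketch that assumes a Unix platform and Rust 1.63+ (for `Stdio: From<OwnedFd>` and `OwnedFd: From<TcpStream>`). `spawn_on_tcp_stream` is a hypothetical helper name.

```rust
use std::io;
use std::net::TcpStream;
use std::os::unix::io::OwnedFd;
use std::process::{Child, Command, Stdio};

/// Hypothetical helper: runs `program` with stdin, stdout and stderr all
/// attached to `stream`, similar to what `ncat -e` does over plain TCP.
fn spawn_on_tcp_stream(program: &str, stream: &TcpStream) -> io::Result<Child> {
    // Each try_clone() is a dup() of the socket; Command then dup2()s the
    // copies onto descriptors 0, 1 and 2 in the child before exec.
    let stdin = Stdio::from(OwnedFd::from(stream.try_clone()?));
    let stdout = Stdio::from(OwnedFd::from(stream.try_clone()?));
    let stderr = Stdio::from(OwnedFd::from(stream.try_clone()?));
    Command::new(program)
        .stdin(stdin)
        .stdout(stdout)
        .stderr(stderr)
        .spawn()
}
```

This only works because a TcpStream *is* a file descriptor. A TLS stream has no descriptor carrying the decrypted bytes, which is exactly the problem raised here.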

Can anyone advise on this?

EDIT 2 Nov 2020

Someone in the Rust forum has kindly shared a solution with me (https://users.rust-lang.org/t/redirect-stdio-pipes-and-file-descriptors/50751/8), and now I am trying to work out how to redirect stdio over the TLS connection.

use std::io;
use std::os::unix::io::AsRawFd;
use std::process::{Command, Stdio};

fn main() -> io::Result<()> {
    let mut command_output = Command::new("/bin/sh")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .expect("cannot execute command");

    // take() the pipes so that command_output can still be waited on later
    let mut command_stdin = command_output.stdin.take().unwrap();
    println!("command_stdin {}", command_stdin.as_raw_fd());

    let copy_stdin_thread = std::thread::spawn(move || {
        io::copy(&mut io::stdin(), &mut command_stdin)
    });

    let mut command_stdout = command_output.stdout.take().unwrap();
    println!("command_stdout {}", command_stdout.as_raw_fd());

    let copy_stdout_thread = std::thread::spawn(move || {
        io::copy(&mut command_stdout, &mut io::stdout())
    });

    // must be `mut`: io::copy() needs a mutable reference to its reader
    let mut command_stderr = command_output.stderr.take().unwrap();
    println!("command_stderr {}", command_stderr.as_raw_fd());

    let copy_stderr_thread = std::thread::spawn(move || {
        io::copy(&mut command_stderr, &mut io::stderr())
    });

    copy_stdin_thread.join().unwrap()?;
    copy_stdout_thread.join().unwrap()?;
    copy_stderr_thread.join().unwrap()?;
    // reap the child process
    command_output.wait()?;
    Ok(())
}

Solution

  • This question and this answer are not specific to Rust.

    You have noticed the important point: the I/O of the redirected process must be file descriptors. One possible solution in your application is to:

    • use socketpair(PF_LOCAL, SOCK_STREAM, 0, fd)
      • this provides two connected bidirectional file descriptors
    • use dup2() on one end of this socketpair for the I/O of the redirected process (as you would do with an unencrypted TCP stream)
    • watch both the other end and the TLS stream (in a select()-like manner for example) in order to
      • receive what becomes available from the socketpair and send it to the TLS stream,
      • receive what becomes available from the TLS stream and send it to the socketpair.
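The socketpair() step can be sketched with the standard library alone. `std::os::unix::net::UnixStream::pair()` creates a connected AF_UNIX/SOCK_STREAM pair (PF_LOCAL is the same family), and `Command` performs the dup2() in the child. This is a minimal sketch that assumes Unix and Rust 1.63+. `spawn_on_socketpair` is a hypothetical name, and the relay between the returned end and the TLS stream (the select() part) is deliberately left out.

```rust
use std::io;
use std::os::unix::io::OwnedFd;
use std::os::unix::net::UnixStream;
use std::process::{Child, Command, Stdio};

/// Hypothetical helper: starts `program` with stdin, stdout and stderr all
/// connected to one end of a socketpair and returns the other end, which
/// the application would then relay to and from the TLS stream.
fn spawn_on_socketpair(program: &str) -> io::Result<(Child, UnixStream)> {
    // UnixStream::pair() wraps socketpair(AF_UNIX, SOCK_STREAM, ...)
    let (parent_end, child_end) = UnixStream::pair()?;
    // Command dup2()s these descriptors onto 0, 1 and 2 in the child
    let stdin = Stdio::from(OwnedFd::from(child_end.try_clone()?));
    let stdout = Stdio::from(OwnedFd::from(child_end.try_clone()?));
    let stderr = Stdio::from(OwnedFd::from(child_end));
    let child = Command::new(program)
        .stdin(stdin)
        .stdout(stdout)
        .stderr(stderr)
        .spawn()?;
    // the temporary Command, and with it our copies of child_end, has been
    // dropped here, so only the child keeps that end of the pair open
    Ok((child, parent_end))
}
```

Because both stdout and stderr share one socket end, their output arrives interleaved on `parent_end`, much as it would with `ncat -e`.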

    Note that select() on a TLS stream (actually on its underlying file descriptor) is a bit tricky. Some bytes may already have been received on that file descriptor and decrypted into the stream's internal buffer without yet being consumed by the application. In that case select() would block even though data is ready. You therefore have to ask the TLS stream whether its reception buffer is empty before calling select() on it again. Using an asynchronous or threaded solution for this watch/recv/send loop is probably easier than relying on a select()-like solution.
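The buffering pitfall described above applies to any user-space read buffer. As an analogy (this is not native-tls itself), std's `BufReader` exposes its unread bytes through `buffer()`. The check below says whether waiting on the underlying descriptor makes sense. With native-tls, the corresponding query would be `TlsStream::buffered_read_size()`. `must_wait_on_fd` is a hypothetical name, and the test assumes Unix.

```rust
use std::io::{BufReader, Read};

/// Hypothetical helper: select()/poll() on the underlying descriptor is only
/// meaningful once the user-space buffer has been drained; otherwise the
/// call could block although data is already available.
fn must_wait_on_fd<R: Read>(reader: &BufReader<R>) -> bool {
    reader.buffer().is_empty()
}
```

In a select() loop, a `false` result means reading from the buffered stream directly (or using a zero timeout) rather than blocking on the descriptor.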


    Edit, after the update to the question

    Since you now have a solution relying on three distinct pipes, you can forget about socketpair().

    The invocation of std::io::copy() in each thread of your example is a simple loop that reads some bytes from its first argument and writes them to the second. Your TlsStream is probably a single structure performing all the encrypted I/O operations (sending as well as receiving), so you will not be able to hand a &mut reference to it to several threads at once.
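To make this concrete, the loop inside io::copy() amounts to something like the following. This is a simplified sketch, not std's actual implementation, and `copy_loop` is a hypothetical name. Each thread in the question runs one such loop and owns its reader and writer. With a single TlsStream, the stdin loop (reading from it) and the stdout/stderr loops (writing to it) would all need `&mut` access to the same object at the same time, which the borrow checker rejects.

```rust
use std::io::{self, ErrorKind, Read, Write};

/// Roughly what io::copy() does: read a chunk from `reader`, write all of
/// it to `writer`, and stop at end-of-file (a read returning 0).
fn copy_loop<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> io::Result<u64> {
    let mut buffer = [0_u8; 8192];
    let mut total = 0_u64;
    loop {
        let n = match reader.read(&mut buffer) {
            Ok(0) => return Ok(total),
            Ok(n) => n,
            // a signal interrupted read(): simply retry
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        writer.write_all(&buffer[..n])?;
        total += n as u64;
    }
}
```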

    The best option is probably to write your own loop that detects new incoming bytes and dispatches them to the appropriate destination. As explained above, I would use select() for that. Unfortunately, as far as I know, in Rust we have to rely on low-level crates such as libc for this (there may be other high-level solutions in the async world that I am not aware of).

    I produced a (not so) minimal example below in order to show the main idea; it is certainly far from being perfect, so « handle with care » ;^) (it relies on native-tls and libc)

    Connecting to it with openssl s_client gives this:

    $ openssl s_client -connect localhost:9876
    CONNECTED(00000003)
    Can't use SSL_get_servername
    ...
        Extended master secret: yes
    ---
    hello
    /bin/sh: line 1: hello: command not found
    df
    Filesystem     1K-blocks      Used Available Use% Mounted on
    dev              4028936         0   4028936   0% /dev
    run              4038472      1168   4037304   1% /run
    /dev/sda5       30832548  22074768   7168532  76% /
    tmpfs            4038472    234916   3803556   6% /dev/shm
    tmpfs               4096         0      4096   0% /sys/fs/cgroup
    tmpfs            4038472         4   4038468   1% /tmp
    /dev/sda6      338368556 219588980 101568392  69% /home
    tmpfs             807692        56    807636   1% /run/user/9223
    exit
    read:errno=0
    
    fn main() {
        let args: Vec<_> = std::env::args().collect();
        let use_simple = args.len() == 2 && args[1] == "s";
    
        let mut file = std::fs::File::open("server.pfx").unwrap();
        let mut identity = vec![];
        use std::io::Read;
        file.read_to_end(&mut identity).unwrap();
        let identity =
            native_tls::Identity::from_pkcs12(&identity, "dummy").unwrap();
    
        let listener = std::net::TcpListener::bind("0.0.0.0:9876").unwrap();
        let acceptor = native_tls::TlsAcceptor::new(identity).unwrap();
        let acceptor = std::sync::Arc::new(acceptor);
    
        for stream in listener.incoming() {
            match stream {
                Ok(stream) => {
                    let acceptor = acceptor.clone();
                    std::thread::spawn(move || {
                        let stream = acceptor.accept(stream).unwrap();
                        if use_simple {
                            simple_client(stream);
                        } else {
                            redirect_shell(stream);
                        }
                    });
                }
                Err(_) => {
                    println!("accept failure");
                    break;
                }
            }
        }
    }
    
    fn simple_client(mut stream: native_tls::TlsStream<std::net::TcpStream>) {
        let mut buffer = [0_u8; 100];
        let mut count = 0;
        loop {
            use std::io::Read;
            if let Ok(sz_r) = stream.read(&mut buffer) {
                if sz_r == 0 {
                    println!("EOF");
                    break;
                }
                println!(
                    "received <{}>",
                    std::str::from_utf8(&buffer[0..sz_r]).unwrap_or("???")
                );
                let reply = format!("message {} is {} bytes long\n", count, sz_r);
                count += 1;
                use std::io::Write;
                if stream.write_all(reply.as_bytes()).is_err() {
                    println!("write failure");
                    break;
                }
            } else {
                println!("read failure");
                break;
            }
        }
    }
    
    fn redirect_shell(mut stream: native_tls::TlsStream<std::net::TcpStream>) {
        // start child process
        let mut child = std::process::Command::new("/bin/sh")
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .expect("cannot execute command");
        // access useful I/O and file descriptors
        let stdin = child.stdin.as_mut().unwrap();
        let stdout = child.stdout.as_mut().unwrap();
        let stderr = child.stderr.as_mut().unwrap();
        use std::os::unix::io::AsRawFd;
        let stream_fd = stream.get_ref().as_raw_fd();
        let stdout_fd = stdout.as_raw_fd();
        let stderr_fd = stderr.as_raw_fd();
        // main send/recv loop
        use std::io::{Read, Write};
        let mut buffer = [0_u8; 100];
        loop {
            // no need to wait for new incoming bytes on tcp-stream
            // if some are already decoded in the tls-stream
            let already_buffered = match stream.buffered_read_size() {
                Ok(sz) if sz > 0 => true,
                _ => false,
            };
            // prepare file descriptors to be watched for by select()
            // (start from a zeroed fd_set: calling assume_init() on
            // uninitialised memory is undefined behaviour)
            let mut fdset: libc::fd_set = unsafe { std::mem::zeroed() };
            let mut max_fd = -1;
            unsafe { libc::FD_ZERO(&mut fdset) };
            unsafe { libc::FD_SET(stdout_fd, &mut fdset) };
            max_fd = std::cmp::max(max_fd, stdout_fd);
            unsafe { libc::FD_SET(stderr_fd, &mut fdset) };
            max_fd = std::cmp::max(max_fd, stderr_fd);
            if !already_buffered {
                // see above
                unsafe { libc::FD_SET(stream_fd, &mut fdset) };
                max_fd = std::cmp::max(max_fd, stream_fd);
            }
            // block this thread until something new happens
            // on these file-descriptors (don't wait if some bytes
            // are already decoded in the tls-stream);
            // an all-zero timeval makes select() return immediately
            let mut zero_timeout: libc::timeval = unsafe { std::mem::zeroed() };
            let ready = unsafe {
                libc::select(
                    max_fd + 1,
                    &mut fdset,
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    if already_buffered {
                        &mut zero_timeout
                    } else {
                        std::ptr::null_mut()
                    },
                )
            };
            if ready < 0 {
                // the content of fdset is unspecified after a failure:
                // retry if select() was only interrupted by a signal
                if std::io::Error::last_os_error().kind()
                    == std::io::ErrorKind::Interrupted
                {
                    continue;
                }
                println!("select failure");
                break;
            }
            // this thread is not blocked any more,
            // try to handle what happened on the file descriptors
            if unsafe { libc::FD_ISSET(stdout_fd, &mut fdset) } {
                // something new happened on stdout,
                // try to receive some bytes and send them through the tls-stream
                if let Ok(sz_r) = stdout.read(&mut buffer) {
                    if sz_r == 0 {
                        println!("EOF detected on stdout");
                        break;
                    }
                    if stream.write_all(&buffer[0..sz_r]).is_err() {
                        println!("write failure on tls-stream");
                        break;
                    }
                } else {
                    println!("read failure on process stdout");
                    break;
                }
            }
            if unsafe { libc::FD_ISSET(stderr_fd, &mut fdset) } {
                // something new happened on stderr,
                // try to receive some bytes and send them through the tls-stream
                if let Ok(sz_r) = stderr.read(&mut buffer) {
                    if sz_r == 0 {
                        println!("EOF detected on stderr");
                        break;
                    }
                    if stream.write_all(&buffer[0..sz_r]).is_err() {
                        println!("write failure on tls-stream");
                        break;
                    }
                } else {
                    println!("read failure on process stderr");
                    break;
                }
            }
            if already_buffered
                || unsafe { libc::FD_ISSET(stream_fd, &mut fdset) }
            {
                // something new happened on the tls-stream
                // (or some bytes were already buffered),
                // try to receive some bytes and send them on stdin
                if let Ok(sz_r) = stream.read(&mut buffer) {
                    if sz_r == 0 {
                        println!("EOF detected on tls-stream");
                        break;
                    }
                    if stdin.write_all(&buffer[0..sz_r]).is_err() {
                        println!("write failure on stdin");
                        break;
                    }
                } else {
                    println!("read failure on tls-stream");
                    break;
                }
            }
        }
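        // close the child's stdin so that the shell sees EOF and exits;
        // otherwise wait() could block forever after an EOF on the tls-stream
        drop(child.stdin.take());
        let _ = child.wait();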
    }