Search code examples
Tags: rust, ssh-tunnel

How to create an SSH tunnel with russh that supports multiple connections?


I am working with Rust and the russh crate to implement an SSH tunnel for accessing a remote server via a local listening port.

Here’s my current implementation code:

// Question's original attempt (a fragment, not self-contained).
// NOTE(review): `rx_clone` is defined outside this snippet; judging by `.changed()` it is a
// tokio watch receiver used as a shutdown signal — confirm.
let mut ssh_client= russh::client::connect(
    Arc::new(Config::default()),
    format!("{}:{}", self.host, self.port),
    IHandler {},
).await?;

ssh_client.authenticate_password(self.username.clone(), self.password.clone()).await?;

// Bind an ephemeral port on 127.0.0.1; its port is also reported as the channel originator below.
let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
let addr = listener.local_addr()?;

// Only ONE direct-tcpip channel is opened, and it is opened before any local connection is
// accepted. That single channel/stream is all the accepted clients could share. The solution
// below instead opens one channel per accepted connection.
let channel = ssh_client.channel_open_direct_tcpip(
    self.forwarding_host.clone(),
    self.forwarding_port as u32,
    Ipv4Addr::LOCALHOST.to_string(),
    addr.port() as u32,
).await?;

let mut remote_stream = channel.into_stream();
tokio::spawn(async move {
    loop {
        if let Ok((mut local_stream, _)) = listener.accept().await {
            // NOTE(review): `remote_stream` is moved into this inner `async move` task on the
            // first iteration, so it cannot be used again on a later iteration of the loop.
            tokio::spawn(async move {
                select! {
                    result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => {
                        // close 
                }
                // NOTE(review): the `}` above closes the `=> {` arm; the `select! {` block itself
                // is never closed, so the fragment does not compile as pasted.
            });
        }
        // `changed().await` suspends until a new value is sent on the watch channel (the close
        // signal). After the first accepted connection, the loop therefore never gets back to
        // `accept()`, which is why later connection attempts hang.
        if rx_clone.changed().await.is_ok() {
            break;
        }
    }
    drop(listener);
    Ok::<(), Error>(())
});

Issue:

The code above successfully establishes an SSH tunnel through the local TcpListener port, but it has a limitation: only one client can connect to this port, and subsequent connection attempts hang. I want the TcpListener to accept multiple concurrent connections.

My environment:

  • russh = "0.45.0"
  • tokio = "1"

Is there a way to make russh support multiple connections? Or how can I modify the code to achieve this requirement?

Goal: a single SSH tunnel that supports multiple local connections.


Solution

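  • Following @maxy's suggestion of

    calling channel_open_direct_tcpip() for every connection,

    I moved the channel-opening call out of the setup code and into the TCP listener's accept block, so each accepted connection gets its own SSH channel. I also changed the exit-signal check from rx_clone.changed().await.is_ok() to rx_clone.has_changed()?. The old call waited until the close signal actually arrived, which stalled the accept loop after the first connection. has_changed() returns immediately.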

    The final code is as follows:

    use anyhow::{Error, Result};
    use async_trait::async_trait;
    use log::{error, info, warn};
    use russh::client::{Config, Handler, Msg, Session};
    use russh::keys::key;
    use russh::{Channel, ChannelId, Disconnect};
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use std::sync::Arc;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::select;
    
    /// Configuration and state for forwarding a local port to a remote destination over SSH.
    #[derive(Clone, Debug)]
    pub struct SshTunnel {
        /// SSH server host name or address.
        pub host: String,
        /// SSH server port.
        pub port: u16,
        /// User name used for password authentication.
        pub username: String,
        /// Password used for password authentication.
        // NOTE(review): this is included in the derived `Debug` output — avoid logging the struct.
        pub password: String,
        /// Host the SSH server is asked to connect to for each tunnelled connection.
        pub forwarding_host: String,
        /// Port on `forwarding_host` the SSH server is asked to connect to.
        pub forwarding_port: u16,
        /// Shutdown signal sender; `close` publishes a value on it.
        tx: tokio::sync::watch::Sender<u8>,
        /// Receiver that `open` clones to hand shutdown receivers to its background tasks.
        rx: tokio::sync::watch::Receiver<u8>,
        /// Set by `open`, cleared by `close`.
        is_connected: bool,
    }
    
    impl SshTunnel {
        pub fn new(host: String, port: u16, username: String, password: String, forwarding_host: String, forwarding_port: u16) -> Self {
            let (tx, rx) = tokio::sync::watch::channel::<u8>(1);
            Self {
                host,
                port,
                username,
                password,
                forwarding_host,
                forwarding_port,
                tx,
                rx,
                is_connected: false,
            }
        }
    
        pub async fn open(&mut self) -> Result<SocketAddr> {
            let mut ssh_client = russh::client::connect(
                Arc::new(Config::default()),
                format!("{}:{}", self.host, self.port),
                IHandler {},
            ).await?;
            ssh_client.authenticate_password(self.username.clone(), self.password.clone()).await?;
            let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
            let addr = listener.local_addr()?;
            let forwarding_host = self.forwarding_host.clone();
            let forwarding_port = self.forwarding_port as u32;
    
            let mut rx_clone = self.rx.clone();
            tokio::spawn(async move {
                loop {
                    let mut rx_clone_clone = rx_clone.clone();
                    if let Ok((mut local_stream, _)) = listener.accept().await {
                        let channel = ssh_client.channel_open_direct_tcpip(
                            forwarding_host.clone(),
                            forwarding_port,
                            Ipv4Addr::LOCALHOST.to_string(),
                            addr.port() as u32,
                        ).await?;
                        let mut remote_stream = channel.into_stream();
                        tokio::spawn(async move {
                            select! {
                                result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => {
                                    if let Err(e) = result {
                                        error!("Error during bidirectional copy: {}", e);
                                    }
                                    warn!("Bidirectional copy stopped");
                                }
                                _ = rx_clone_clone.changed() => {
                                    info!("Received close signal");
                                }
                            }
                            let _ = remote_stream.shutdown().await;
                        });
                    }
                    if rx_clone.has_changed()? {
                        ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?;
                        break;
                    }
                }
                drop(listener);
                info!("Stream closed");
                Ok::<(), Error>(())
            });
    
            self.is_connected = true;
            Ok(addr)
        }
    
        pub async fn close(&mut self) -> Result<()> {
            self.tx.send(0)?;
            self.is_connected = false;
            Ok(())
        }
    
        pub fn is_connected(&self) -> bool {
            self.is_connected
        }
    }
    
    struct IHandler;
    
    #[async_trait]
    impl Handler for IHandler {
        type Error = Error;
        async fn check_server_key(&mut self, _: &key::PublicKey) -> Result<bool, Self::Error> {
            Ok(true)
        }
    }