Tags: python, windows, ssh, paramiko, openssh

Gracefully abort remote Windows command executed over SSH from Windows Python Paramiko script when Ctrl+C is pressed


I have a follow-up question that builds on the question I asked here: Run multiple commands in different SSH servers in parallel using Python Paramiko, which was already answered.

Thanks to the answer on the link above, my Python script is as follows:

# SSH.py

import paramiko
import argparse
import os

path = "path"
python_script = "worker.py"

# definitions for ssh connection and cluster
ip_list = ['XXX.XXX.XXX.XXX', 'XXX.XXX.XXX.XXX', 'XXX.XXX.XXX.XXX']
port_list = [':XXXX', ':XXXX', ':XXXX']
user_list = ['user', 'user', 'user']
password_list = ['pass', 'pass', 'pass']
node_list = list(map(lambda x: f'-node{x + 1} ', list(range(len(ip_list)))))
cluster = ' '.join([node + ip + port for node, ip, port in zip(node_list, ip_list, port_list)])

# run script on command line of local machine
os.system(f"cd {path} && python {python_script} {cluster} -type worker -index 0 -batch 64 > {path}/logs/'command output'/{ip_list[0]}.log 2>&1")

# loop for IP and password
stdouts = []

clients = []

for i, (ip, user, password) in enumerate(zip(ip_list[1:], user_list[1:], password_list[1:]), 1):
    try:
        print("Open session in: " + ip + "...")
        client = paramiko.SSHClient()
        client.connect(ip, username=user, password=password)
    except paramiko.SSHException:
        print("Connection Failed")
        quit()

    try:
        path = f"C:/Users/{user}/Desktop/temp-ines"

        stdin, stdout, stderr = client.exec_command(
            f"cd {path} && python {python_script} {cluster} -type worker -index {i} -batch 64 > "
            f"C:/Users/{user}/Desktop/{ip}.log 2>&1 &")

        clients.append(client)
        stdouts.append(stdout)
    except paramiko.SSHException:
        print("Cannot run file. Continue with other IPs in list...")
        client.close()
        continue

# Wait for commands to complete

for i in range(len(stdouts)):
    print("hello")
    stdouts[i].read()
    print("hello1")
    clients[i].close()
    print("hello2")

print("\n\n***********************End execution***********************\n\n")

This script, which is run locally, is able to SSH into the servers and run the command (i.e., run a Python script called worker.py and log the command output to a log file). In other words, it gets through the first for loop with no issues.

My issue is related to the second for loop; note the print statements I added there to make this clear. When I run SSH.py locally, this is what I observe:

As you can see, I SSH into each of the servers and then block reading the command output of the first server. The worker.py script can take 30 minutes or so to complete, and the command output is the same on each server, so the script spends those 30 minutes reading the output of the first server, closes its SSH connection, takes a couple of seconds to read the output of the second server (since it is the same and has already been printed in full), closes its connection, and so on. Please see below some of the command-line output, if this helps.

Now, my question is: what if I don't want to wait for the worker.py script to finish, i.e., those entire 30 minutes? I cannot, or do not know how to, raise a KeyboardInterrupt exception. What I have tried is quitting the local SSH.py script. However, as the print statements show, this does not close the SSH connections, although the training, and thus the logging, does stop. In addition, after I quit the local SSH.py script, if I try to delete any of the log files I get an error saying "cannot delete file because it is being used in cmd.exe". This only happens sometimes, and I believe it is caused by the SSH connections not being closed.


First run in python console:

(screenshot of the Python console output)

It hangs: the local Python process and log file are running and saving, but there are no print statements, and no Python process or log file is running or being saved on the servers.

I run it again, so a second process starts:

(screenshot of the Python console output)

Now the first process no longer hangs (Python is running and log files are being saved on the servers), and I can close this second run/process. It is as if the second run/process unblocks the first one.

If I were to run python SSH.py in the terminal it would just hang.

This was not happening before.


Solution

  • If you know that SSHClient.close cleanly closes the connection and aborts the remote command, call it in response to KeyboardInterrupt.

    For this you cannot use the simple solution with stdout.read, as it blocks and prevents handling of Ctrl+C on Windows.

    import time  # needed for time.sleep below

    try:
        while any(x is not None for x in stdouts):
            for i in range(len(stdouts)):
                stdout = stdouts[i]
                if stdout is not None:
                    channel = stdout.channel
                    # To prevent losing output at the end, first test for exit,
                    # then for output
                    exited = channel.exit_status_ready()
                    while channel.recv_ready():
                        s = channel.recv(1024).decode('utf8')
                        print(f"#{i} stdout: {s}")
                    while channel.recv_stderr_ready():
                        s = channel.recv_stderr(1024).decode('utf8')
                        print(f"#{i} stderr: {s}")
                    if exited:
                        print(f"#{i} done")
                        clients[i].close()
                        stdouts[i] = None
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("Aborting")
        for i in range(len(clients)):
            print(f"#{i} closing")
            clients[i].close()
    

    If you do not need to separate stdout and stderr, you can greatly simplify the code by using Channel.set_combine_stderr. See Paramiko ssh die/hang with big output.