Search code examples
Tags: c, sockets, ssl, openssl, nonblocking

How to use SSL_read() and SSL_accept() on a non-blocking socket


I have a client/server application whose communication needs to be upgraded to SSL. Currently I have a network socket that sends and receives TCP data.

  1. The client issues a TCP connect() to the server.
  2. The server already implements the accept part: after accepting the connection, it enters a select() loop, waiting for further operations.

What I have tried:

After the TCP connect() completes, I pass the resulting fd to OpenSSL with SSL_set_fd(ssl, fd). On the server side, after the network accept(), I switch the connection to non-blocking mode and call SSL_accept(). On the client side I call SSL_connect(), which succeeds (I have taken care of the certificates and everything else needed).

What I need to understand:

  1. SSL_accept() returns -1 with SSL_ERROR_WANT_READ. Some people suggested putting it in a `while` loop and waiting for the SSL accept to complete. That works, but only after looping over SSL_accept() many times. My confusion is whether I should loop here or go back to the select loop. When I go back to the select loop, select() returns immediately, probably because of data sent by SSL_connect(). Is this the right way?

  2. On the client side I send X bytes with SSL_write(), which succeeds. When select() reports the socket as readable, I call SSL_read(). It reads fewer than X bytes, so I call SSL_read() again, only to see 0 bytes returned in successive reads. Same question here: how long should I loop around SSL_read()? Do I have to loop at all, or should I go back to select() and wait?

  3. What happens if I pass SSL_read(ssl, buf, bytes) a byte count larger than what was received, and how should I handle that?

I have tried SSL_pending() after the first read, but it always returns 0, whereas there is obviously data missing.

Client code

    /* Client side from the question: create the SSL object on the connected fd. */
    ssl = SSL_new(ctx);
    SSL_set_fd(ssl, fd);

    /*
     * NOTE(review): SSL_connect() returns 1 on success, 0 when the handshake
     * failed but was shut down in a controlled way, and < 0 on a fatal error
     * or when a non-blocking socket needs more I/O (SSL_get_error() then
     * reports SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE). The `< 0` test
     * below therefore treats 0 as success; `!= 1` plus SSL_get_error() would
     * distinguish "retry later" from a real failure.
     */
    if (SSL_connect(ssl) < 0 )
    {
        /* Log failure */
        /* NOTE(review): ssl is not freed on this path. */
        return(-1);
    } else {

        ssl_write_return = SSL_write(ssl, msg , req_len);

        switch(SSL_get_error(ssl, ssl_write_return)) 
        {
            
            case SSL_ERROR_NONE:
               ...
            default:
               ... 
            /*
             * NOTE(review): SSL_free() sits inside the switch after the
             * default label, so whether it runs for SSL_ERROR_NONE depends
             * on the elided code above (e.g. a `break`).
             */
            SSL_free(ssl);
        }
    }
    SSL_CTX_free(ctx);

Server code

 ssl = SSL_new(ctx);
    /*
     * NOTE(review): server-side handshake from the question. On a
     * non-blocking socket, SSL_accept() returns -1 with SSL_ERROR_WANT_READ
     * (or SSL_ERROR_WANT_WRITE) until enough handshake bytes have arrived
     * from the client (or could be sent). The `continue`s below retry at
     * once, which busy-loops. The usual non-blocking pattern:
     *   1. Remember that the handshake is still in progress.
     *   2. Return to the select() loop.
     *   3. Wait until the fd is readable (WANT_READ) or writable (WANT_WRITE).
     *   4. Call SSL_accept() again; repeat until it returns 1.
     * select() waking up immediately after such a return is expected: the
     * client's handshake data is what makes the fd readable.
     */
    SSL_set_fd(ssl, session->fd);    
    while(TRUE){
            if ((ssl_accept_ret =SSL_accept(ssl)) != 1){
                
                        /* Logged on every retry, including WANT_READ/WANT_WRITE. */
                        log ("ssl_accept failed with %d\n", ssl_accept_ret);
                switch(SSL_get_error(ssl,ssl_accept_ret )){
                    case SSL_ERROR_NONE:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        break;
                    case SSL_ERROR_SSL:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        break;
                    case SSL_ERROR_WANT_READ:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));
                        /* Retries SSL_accept() immediately (busy loop). */
                        continue;            
                    case SSL_ERROR_WANT_WRITE:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        /* Retries SSL_accept() immediately (busy loop). */
                        continue;        
                    case SSL_ERROR_SYSCALL:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        break;
                    case SSL_ERROR_ZERO_RETURN:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));
                        break;
                    case SSL_ERROR_WANT_CONNECT:
                    /*
                     * NOTE(review): before C23 a case label must be followed
                     * by a statement, so this needs e.g. `break;` to compile.
                     */
                }
                /* Every outcome other than WANT_READ/WANT_WRITE ends here. */
                return(-1);
            } else {
                log ("ssl_accept was successful with %d\n", ssl_accept_ret);
                return 0;
            }
        }

Server read code

/*
 * NOTE(review): server-side read loop from the question.
 *
 * SSL_read() returns decrypted application data from TLS records. A record
 * can only be decrypted once it has arrived completely, so one readable
 * event from select() may yield fewer bytes than the peer passed to
 * SSL_write(). On a non-blocking socket, SSL_ERROR_WANT_READ (or
 * SSL_ERROR_WANT_WRITE) means "call again once the fd is readable (or
 * writable)". The right reaction is to go back to select(), not to retry
 * immediately.
 *
 * SSL_pending() only counts bytes OpenSSL has already processed and
 * buffered. Bytes still in the kernel socket buffer, or in a partially
 * received record, are not included, which is why it reports 0 here.
 */
while(TRUE){
            /*
             * Passing a length larger than what has arrived is fine:
             * SSL_read() returns the number of bytes actually stored in buf
             * (at most sizeof(buf)).
             */
            ret = SSL_read(session->ssl, buf, sizeof(buf));
            if (ret<=0){

                
                switch(SSL_get_error(session->ssl,ret)){

                    case SSL_ERROR_NONE:                       
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        break;
                    case SSL_ERROR_ZERO_RETURN:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        break;
                    case SSL_ERROR_WANT_READ:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));
                        /* Re-calls SSL_read() at once: busy loop until data arrives. */
                        continue;            
                    case SSL_ERROR_WANT_WRITE:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));
                        /* Re-calls SSL_read() at once: busy loop until writable. */
                        continue;        
                    case SSL_ERROR_SYSCALL:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                            break;
                    case SSL_ERROR_SSL:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        break;
                    default:
                        ERR_error_string_n(ERR_get_error(), err_msg, sizeof(err_msg));  
                        break;
                }
                /* NOTE(review): missing `;` - does not compile as written. */
                exit_select_loop()
            } else {
                
                /*
                 * NOTE(review): %s assumes buf is NUL-terminated; SSL_read()
                 * does not terminate it (print with "%.*s", ret, buf instead).
                 */
                log( "ssl_read was successful with %d and %s \n", ret,buf);
                /*
                 * NOTE(review): this do-while always performs a second
                 * SSL_read(), which overwrites buf and ret from the first
                 * call. On a non-blocking socket with nothing buffered it
                 * will usually return -1 (SSL_ERROR_WANT_READ).
                 *
                 * In the condition, `|| SSL_ERROR_WANT_READ` ORs in a non-zero
                 * constant, so the parenthesised test is always true. The loop
                 * therefore continues only while SSL_pending() != 0.
                 */
                do{
                     ret = SSL_read(session->ssl, buf, sizeof(buf));
                   
                      log("ssl_read  %d and %s \n", ret,buf);
                }while(SSL_pending(session->ssl)!=0 && (SSL_get_error(session->ssl, ret) == SSL_ERROR_WANT_WRITE || SSL_ERROR_WANT_READ));
            } 
        }

Solution

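  • In non-blocking mode, SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE from SSL_accept(), SSL_connect() or SSL_read() means "call the same function again once the socket is readable / writable". Go back to select(), wait for that condition, and repeat the call instead of spinning in a tight loop. I found working SSL code on a website; you can try it: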

    #include <stdio.h>
    #include <unistd.h>
    #include <string.h>
    #include <strings.h>
    #include <errno.h>
    #include <netinet/in.h>
    #include <netdb.h>
    #include <sys/socket.h>
    #include <sys/types.h>
    #include <sys/time.h>
    #include <sys/select.h>
    #include <netinet/tcp.h>
    #include <ctype.h>
    
    #include <openssl/ssl.h>
    #include <openssl/err.h>
    
    //#include "imapfilter.h"
    //#include "session.h"
    
    
    /* Presumably a listen() backlog size; not referenced in this excerpt. */
    #define TCP_DEFAULTBACK_LOG 15
    /* Response timeout for the HTTP check; units not shown here (presumably seconds). */
    #define HTTP_RESPON_TIMEOUT 15
    /*
     * Formats an HTTP request into buf with snprintf(). sizeof(buf) means
     * buf must be an array, not a pointer. The active variant hard-codes
     * "CONNECT 192.168.1.1"; serv/port only fill the Host header. The
     * commented-out line is an alternative GET-based request.
     */
    //#define FORMAT_HTTPCHECK_REQ(buf,serv,port) snprintf(buf,sizeof(buf)-1,"GET /api/v1/echo/ HTTP/1.1\r\nHost: %s:%d\r\n\r\n",serv,port)
    #define FORMAT_HTTPCHECK_REQ(buf,serv,port) snprintf(buf,sizeof(buf)-1,"CONNECT 192.168.1.1 HTTP/1.1\r\nHost: %s:%d\r\n\r\n",serv,port)
    
    /*
     * SSL_CTX objects consulted by open_secure_connection().
     *
     * With OpenSSL >= 1.1.0 a single context (sslctx) is used. Older builds
     * keep one context per protocol, chosen by the sslproto string ("ssl3",
     * "tls1", "tls1.1", "tls1.2"), with ssl23ctx as the fallback. Each
     * per-protocol context is compiled in only if the OpenSSL build still
     * provides that method.
     *
     * NOTE(review): all of them start out NULL and nothing in this excerpt
     * creates them. Confirm they are initialised (e.g. with SSL_CTX_new())
     * before open_secure_connection() runs; otherwise it reports "protocol
     * version not supported by current build" and fails.
     */
    #if OPENSSL_VERSION_NUMBER >= 0x1010000fL
    SSL_CTX *sslctx = NULL;
    #else
    SSL_CTX *ssl23ctx = NULL;
    #ifndef OPENSSL_NO_SSL3_METHOD
    SSL_CTX *ssl3ctx = NULL;
    #endif
    #ifndef OPENSSL_NO_TLS1_METHOD
    SSL_CTX *tls1ctx = NULL;
    #endif
    #ifndef OPENSSL_NO_TLS1_1_METHOD
    SSL_CTX *tls11ctx = NULL;
    #endif
    #ifndef OPENSSL_NO_TLS1_2_METHOD
    SSL_CTX *tls12ctx = NULL;
    #endif
    #endif
    
    
    
    /*
     * IMAP session state: the connected socket and the OpenSSL connection
     * that open_secure_connection() creates on top of it.
     */
    struct session {
        int socket;         /* Connected socket descriptor. */
        SSL *sslconn;       /* SSL/TLS connection bound to `socket`. */
    };
    typedef struct session session;

    /* Diagnostics are simply printed via printf. */
    #define error printf
    
    /* Forward declarations: connection setup/teardown and socket I/O. */
    static int open_connection(session *ssn, const char *serv, uint16_t port,
        const char *sslproto);
    static int close_connection(session *ssn);
    static ssize_t socket_read(session *ssn, char *buf, size_t len,
        long timeout, int timeoutfail, int *interrupt);
    static ssize_t socket_write(session *ssn, const char *buf, size_t len);

    /* Forward declarations: the secure (SSL/TLS) counterparts. */
    static int open_secure_connection(session *ssn, const char *serv,
        const char *sslproto);
    static int close_secure_connection(session *ssn);
    static ssize_t socket_secure_read(session *ssn, char *buf, size_t len);
    static ssize_t socket_secure_write(session *ssn, const char *buf,
        size_t len);
    
    
    
    /*
     * Connect to mail server.
     *
     * Resolves serv/port with getaddrinfo() and tries each returned address
     * in turn until a TCP connect() succeeds. The connected descriptor is
     * stored in ssn->socket. When sslproto is non-NULL, SSL/TLS is then set
     * up on that socket with open_secure_connection(). If that fails, the
     * connection is torn down with close_connection().
     *
     * Returns the connected socket descriptor, or -1 on failure.
     */
    static int
    open_connection(session *ssn,const char* serv,uint16_t port,const char* sslproto)
    {
        struct addrinfo hints, *res, *ressave;
        char portstr[32];
        int n, sockfd;

        /* getaddrinfo() takes the service as a string: format the numeric port. */
        snprintf(portstr, sizeof(portstr), "%u", (unsigned int)port);

        memset(&hints, 0, sizeof(hints));
        hints.ai_family = AF_UNSPEC;      /* IPv4 or IPv6 */
        hints.ai_socktype = SOCK_STREAM;  /* TCP */

        n = getaddrinfo(serv, portstr, &hints, &res);
        /*
         * Any non-zero result is an EAI_* error. Those codes are negative on
         * glibc but positive on other systems (e.g. BSD, macOS), so do not
         * test for n < 0.
         */
        if (n != 0) {
            error("getaddrinfo; %s\n", gai_strerror(n));
            return -1;
        }

        ressave = res;
        sockfd = -1;

        for (; res != NULL; res = res->ai_next) {
            sockfd = socket(res->ai_family, res->ai_socktype,
                res->ai_protocol);
            if (sockfd < 0)
                continue;

            if (connect(sockfd, res->ai_addr, res->ai_addrlen) == 0)
                break;

            /* This address failed: release the descriptor before moving on. */
            close(sockfd);
            sockfd = -1;
        }

        if (ressave)
            freeaddrinfo(ressave);

        if (sockfd < 0) {
            error("error while initiating connection to %s at port %d\n",
                serv, port);
            return -1;
        }

        ssn->socket = sockfd;

        if (sslproto) {
            if (open_secure_connection(ssn,serv,sslproto) == -1) {
                close_connection(ssn);
                return -1;
            }
        }

        return ssn->socket;
    }
    
    
    /*
     * Block until fd is ready for the I/O direction OpenSSL asked for.
     *
     * If want_write is non-zero (SSL_ERROR_WANT_WRITE), wait for fd to become
     * writable; otherwise (SSL_ERROR_WANT_READ) wait for it to become
     * readable. There is no timeout, mirroring what a blocking socket would
     * do. Waits interrupted by a signal (EINTR) are restarted.
     *
     * Returns 0 once fd is ready, or -1 on error with errno set.
     */
    static int
    wait_for_socket_io(int fd, int want_write)
    {
        fd_set fds;
        int r;

        /* FD_SET() on a descriptor outside [0, FD_SETSIZE) is undefined. */
        if (fd < 0 || fd >= FD_SETSIZE) {
            errno = EBADF;
            return -1;
        }

        do {
            FD_ZERO(&fds);
            FD_SET(fd, &fds);
            r = select(fd + 1, want_write ? NULL : &fds,
                want_write ? &fds : NULL, NULL, NULL);
        } while (r < 0 && errno == EINTR);

        return (r < 0) ? -1 : 0;
    }

    /*
     * Initialize SSL/TLS connection.
     *
     * Steps:
     *   1. Pick the SSL_CTX: sslctx on OpenSSL >= 1.1.0; on older builds the
     *      per-protocol context named by sslproto, falling back to ssl23ctx.
     *   2. Create ssn->sslconn on the already connected ssn->socket.
     *   3. Set SNI to serv.
     *   4. Run the client handshake.
     *
     * The socket may be blocking or non-blocking. When SSL_connect() reports
     * SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE, the function waits in
     * select() for the socket to become readable / writable and then retries
     * the handshake, rather than spinning.
     *
     * Returns 0 on success. On failure it returns -1, frees any SSL object it
     * created and leaves ssn->sslconn NULL. The socket itself stays open for
     * the caller to close.
     */
    static int
    open_secure_connection(session *ssn,const char* serv,const char* sslproto)
    {
        int r;
        unsigned long e;    /* ERR_get_error() returns unsigned long */
        SSL_CTX *ctx = NULL;

        /* Ensure the fail path never frees a stale or uninitialised pointer. */
        ssn->sslconn = NULL;

    #if OPENSSL_VERSION_NUMBER >= 0x1010000fL
        (void)sslproto;     /* a single context serves every protocol version */
        if (sslctx)
            ctx = sslctx;
    #else
        if (ssl23ctx)
            ctx = ssl23ctx;

        if (sslproto) {
    #ifndef OPENSSL_NO_SSL3_METHOD
            if (ssl3ctx && !strcasecmp(sslproto, "ssl3"))
                ctx = ssl3ctx;
    #endif
    #ifndef OPENSSL_NO_TLS1_METHOD
            if (tls1ctx && !strcasecmp(sslproto, "tls1"))
                ctx = tls1ctx;
    #endif
    #ifndef OPENSSL_NO_TLS1_1_METHOD
            if (tls11ctx && !strcasecmp(sslproto, "tls1.1"))
                ctx = tls11ctx;
    #endif
    #ifndef OPENSSL_NO_TLS1_2_METHOD
            if (tls12ctx && !strcasecmp(sslproto, "tls1.2"))
                ctx = tls12ctx;
    #endif
        }
    #endif

        if (ctx == NULL) {
            error("initiating SSL connection to %s; protocol version "
                  "not supported by current build\n", serv);
            goto fail;
        }

        if (!(ssn->sslconn = SSL_new(ctx)))
            goto fail;

    #if OPENSSL_VERSION_NUMBER >= 0x1000000fL
        r = SSL_set_tlsext_host_name(ssn->sslconn, serv);
        if (r == 0) {
            error("failed setting the Server Name Indication (SNI) to "
                "%s; %s\n", serv,
                ERR_error_string(ERR_get_error(), NULL));
            goto fail;
        }
    #endif

        /* SSL_set_fd() returns 1 on success, 0 if the socket BIO could not be set up. */
        if (SSL_set_fd(ssn->sslconn, ssn->socket) != 1) {
            error("initiating SSL connection to %s; %s\n",
                serv, ERR_error_string(ERR_get_error(), NULL));
            goto fail;
        }

        for (;;) {
            int err;

            if ((r = SSL_connect(ssn->sslconn)) > 0)
                break;

            err = SSL_get_error(ssn->sslconn, r);
            switch (err) {
            case SSL_ERROR_WANT_READ:
            case SSL_ERROR_WANT_WRITE:
                /*
                 * The handshake needs more data (or send-buffer space). With
                 * a non-blocking socket, wait for that instead of calling
                 * SSL_connect() again in a tight loop.
                 */
                if (wait_for_socket_io(ssn->socket,
                    err == SSL_ERROR_WANT_WRITE) == -1) {
                    error("initiating SSL connection to %s; %s\n",
                        serv, strerror(errno));
                    goto fail;
                }
                break;
            case SSL_ERROR_NONE:
            case SSL_ERROR_WANT_CONNECT:
            case SSL_ERROR_WANT_ACCEPT:
            case SSL_ERROR_WANT_X509_LOOKUP:
                /* Retryable conditions: call SSL_connect() again. */
                break;
            case SSL_ERROR_ZERO_RETURN:
                error("initiating SSL connection to %s; the "
                    "connection has been closed cleanly\n",
                    serv);
                goto fail;
            case SSL_ERROR_SYSCALL:
                e = ERR_get_error();
                if (e == 0 && r == 0)
                    error("initiating SSL connection to %s; EOF in "
                        "violation of the protocol\n", serv);
                else if (e == 0 && r == -1)
                    error("initiating SSL connection to %s; %s\n",
                        serv, strerror(errno));
                else
                    error("initiating SSL connection to %s; %s\n",
                        serv, ERR_error_string(e, NULL));
                goto fail;
            case SSL_ERROR_SSL:
                error("initiating SSL connection to %s; %s\n",
                    serv, ERR_error_string(ERR_get_error(),
                    NULL));
                goto fail;
            default:
                /* Unhandled code: fail rather than retrying forever. */
                error("initiating SSL connection to %s; unexpected SSL "
                    "error %d\n", serv, err);
                goto fail;
            }
        }
        /*
         * NOTE: this function never consults SSL_get_verify_result(). Unless
         * the SSL_CTX was configured for peer verification, the server
         * certificate is not checked here.
         * TODO: ignore cert if (get_option_boolean("certificates") && get_cert(ssn) == -1)
         *  goto fail;
         */

        return 0;

    fail:
        /*
         * Release the SSL object (and the socket BIO attached by SSL_set_fd)
         * so it does not leak. The file descriptor itself is left open.
         */
        if (ssn->sslconn != NULL) {
            SSL_free(ssn->sslconn);
            ssn->sslconn = NULL;
        }

        return -1;
    }