/* * */ #include #include /* close() */ #include #include /* memset() */ #include #include #include #include #include #include #include "UDPTCPNetwork.h" static int ssl_init = 0; SSLSocket::SSLSocket() { readcnt = 0; writecnt = 0; certfile = ""; keyfile = ""; ctx = NULL; ssl = NULL; if (ssl_init == 0) { ssl_init = 1; SSL_library_init(); OpenSSL_add_all_algorithms(); SSL_load_error_strings(); } }; SSLSocket::~SSLSocket() { if (ctx) SSL_CTX_free(ctx); ctx = NULL; if (ssl) SSL_free(ssl); ssl = NULL; }; int SSLSocket::NewServerCTX() { struct stat st; if (stat (certfile.c_str(), &st)) return 0; if (stat (keyfile.c_str(), &st)) return 0; if (ctx) SSL_CTX_free(ctx); ctx = NULL; ctx = SSL_CTX_new(TLSv1_2_server_method()); if (SSL_CTX_use_certificate_file(ctx, certfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) { ERR_print_errors_fp(stderr); if (ctx) SSL_CTX_free(ctx); ctx = NULL; errno = EPROTO; return 0; } if ( SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) { ERR_print_errors_fp(stderr); if (ctx) SSL_CTX_free(ctx); ctx = NULL; errno = EPROTO; return 0; } if (!SSL_CTX_check_private_key(ctx)) { if (ctx) SSL_CTX_free(ctx); ctx = NULL; errno = EPROTO; return 0; } return 1; }; int SSLSocket::NewClientCTX() { if (ctx) SSL_CTX_free(ctx); ctx = NULL; ctx = SSL_CTX_new(TLSv1_2_client_method()); return 1; }; int SSLSocket::SetCertificat(string certf, string keyf) { certfile = certf; keyfile = keyf; return 1; }; int SSLSocket::Connect (int sockfd) { NewClientCTX(); ssl = SSL_new(ctx); SSL_set_fd (ssl, sockfd); if (SSL_connect(ssl) == -1 ) { printf ("ssl connect: error\n"); ERR_print_errors_fp(stderr); exit (1); } return 1; }; int SSLSocket::Accept (int sockfd) { int err; NewServerCTX(); ssl = SSL_new(ctx); SSL_set_fd (ssl, sockfd); if (SSL_accept(ssl) == -1 ) { err = SSL_get_error(ssl, -1); printf ("%s %s:%d error: %s\n", __FUNCTION__, __FILE__, __LINE__, GetSSLErrorText(err).c_str()); return 0; } return 1; }; const string SSLSocket::GetSSLErrorText(int err) { string s; switch (err) { case SSL_ERROR_NONE: s = "SSL_ERROR_NONE"; break; case SSL_ERROR_SSL: s = "SSL_ERROR_SSL"; break; case SSL_ERROR_WANT_READ: s = "SSL_ERROR_WANT_READ"; break; case SSL_ERROR_WANT_WRITE: s = "SSL_ERROR_WANT_WRITE"; break; case SSL_ERROR_SYSCALL: s = "SSL_ERROR_SYSCALL"; break; case SSL_ERROR_WANT_CONNECT: s = "SSL_ERROR_WANT_CONNECT"; break; case SSL_ERROR_ZERO_RETURN: s = "SSL_ERROR_ZERO_RETURN"; break; case SSL_ERROR_WANT_ACCEPT: s = "SSL_ERROR_WANT_ACCEPT"; break; default: s = "SSL_ERROR unknown " + to_string(err); break; } return s; } // // close ssl and return socket. int SSLSocket::Close () { int sock = 0; if (ssl) { sock = SSL_get_fd(ssl); SSL_free(ssl); ssl = NULL; } if (ctx) { SSL_CTX_free(ctx); ctx = NULL; } return sock; }; // // need to add timeout currently data is blocking long int SSLSocket::Read (char *buffer, long int len) { int ret; if (!ssl) { errno = EPROTO; return -1; } ret = SSL_read(ssl, buffer, len); return ret; }; // // need to add timeout currently data is blocking long int SSLSocket::Write (char *buffer, long int len) { int ret; if (!ssl) { errno = EPROTO; return -1; } ret = SSL_write(ssl, buffer, len); return ret; };