/* * */ #include #include /* close() */ #include #include /* memset() */ #include #include #include #include #include #include #include #include #include "UDPTCPNetwork.h" static int ssl_init = 0; SSLSocket::SSLSocket() { readcnt = 0; writecnt = 0; certfile = ""; keyfile = ""; sslerror = SSL_ERROR_NONE; timeout = 0; ctx = NULL; ssl = NULL; if (ssl_init == 0) { ssl_init = 1; SSL_library_init(); OpenSSL_add_all_algorithms(); SSL_load_error_strings(); } }; SSLSocket::~SSLSocket() { if (ctx) SSL_CTX_free(ctx); ctx = NULL; if (ssl) SSL_free(ssl); ssl = NULL; }; int SSLSocket::NewServerCTX() { struct stat st; if (stat (certfile.c_str(), &st)) return 0; if (stat (keyfile.c_str(), &st)) return 0; if (ctx) SSL_CTX_free(ctx); ctx = NULL; #ifdef SSLv23_method ctx = SSL_CTX_new(TLS_server_method()); #else ctx = SSL_CTX_new(TLSv1_2_server_method()); #endif if (SSL_CTX_use_certificate_file(ctx, certfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) { ERR_print_errors_fp(stderr); if (ctx) SSL_CTX_free(ctx); ctx = NULL; errno = EPROTO; return 0; } if ( SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) { ERR_print_errors_fp(stderr); if (ctx) SSL_CTX_free(ctx); ctx = NULL; errno = EPROTO; return 0; } if (!SSL_CTX_check_private_key(ctx)) { if (ctx) SSL_CTX_free(ctx); ctx = NULL; errno = EPROTO; return 0; } return 1; }; int SSLSocket::NewClientCTX() { if (ctx) SSL_CTX_free(ctx); ctx = NULL; #ifdef SSLv23_method ctx = SSL_CTX_new(TLS_client_method()); #else ctx = SSL_CTX_new(TLSv1_2_client_method()); #endif return 1; }; int SSLSocket::SetCertificat(string certf, string keyf) { certfile = certf; keyfile = keyf; if (!file_is_readable(certf.c_str())) return 0; if (!file_is_readable(keyf.c_str())) return 0; return 1; }; int SSLSocket::Connect (int sockfd, int block_timeout) { int flags, res; TimeoutReset(); sslerror = SSL_ERROR_NONE; if (NewServerCTX() == 0) { debug ("error on NewServerCTX()\n"); return 0; } timeout = block_timeout; if (sockfd > 0 && block_timeout > 0) { #if defined(_WIN32) || defined(_WIN64) || defined(__CYGWIN__) u_long mode = 1; ioctlsocket(sockfd, FIONBIO, &mode); #else flags = fcntl(sockfd, F_GETFL, 0); fcntl(sockfd, F_SETFL, flags | O_NONBLOCK); #endif } ssl = SSL_new(ctx); if (ssl == NULL) { debug ("SSL_new failed\n"); } SSL_set_fd (ssl, sockfd); do { res = SSL_connect(ssl); if (res == -1) sslerror = SSL_get_error(ssl, -1); } while (res == -1 && TimeoutTime() < timeout && (sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE)); if (res == -1) return 0; return 1; }; /// @brief /// @param sockfd /// @param block_timeout timeout in ms /// @return 1 on success, 0 on error int SSLSocket::Accept (int sockfd, int block_timeout) { int flags, res; TimeoutReset(); sslerror = SSL_ERROR_NONE; timeout = block_timeout; if (NewServerCTX() == 0) { debug ("error on NewServerCTX()\n"); return 0; } if (sockfd > 0 && block_timeout > 0) { #if defined(_WIN32) || defined(_WIN64) || defined(__CYGWIN__) u_long mode = 1; ioctlsocket(sockfd, FIONBIO, &mode); #else flags = fcntl(sockfd, F_GETFL, 0); fcntl(sockfd, F_SETFL, flags | O_NONBLOCK); #endif } ssl = SSL_new(ctx); if (ssl == NULL) { debug ("SSL_new failed\n"); } SSL_set_fd (ssl, sockfd); do { res = SSL_accept(ssl); if (res == -1) sslerror = SSL_get_error(ssl, -1); } while (res == -1 && TimeoutTime() < timeout && (sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE)); return 1; }; const string SSLSocket::GetSSLErrorText(int err) { string s; switch (err) { case SSL_ERROR_NONE: s = "SSL_ERROR_NONE"; break; case SSL_ERROR_SSL: s = "SSL_ERROR_SSL"; break; case SSL_ERROR_WANT_READ: s = "SSL_ERROR_WANT_READ"; break; case SSL_ERROR_WANT_WRITE: s = "SSL_ERROR_WANT_WRITE"; break; case SSL_ERROR_SYSCALL: s = "SSL_ERROR_SYSCALL"; break; case SSL_ERROR_WANT_CONNECT: s = "SSL_ERROR_WANT_CONNECT"; break; case SSL_ERROR_ZERO_RETURN: s = "SSL_ERROR_ZERO_RETURN"; break; case SSL_ERROR_WANT_ACCEPT: s = "SSL_ERROR_WANT_ACCEPT"; break; default: s = "SSL_ERROR unknown " + to_string(err); break; } return s; } // // close ssl and return socket. int SSLSocket::Close () { int sock = 0; int flags; if (ssl) { sock = SSL_get_fd(ssl); SSL_free(ssl); ssl = NULL; } if (ctx) { SSL_CTX_free(ctx); ctx = NULL; } if (sock > 0 && timeout > 0) { #if defined(_WIN32) || defined(_WIN64) || defined(__CYGWIN__) u_long mode = 0; ioctlsocket(sock, FIONBIO, &mode); #else flags = fcntl(sock, F_GETFL, 0); fcntl(sock, F_SETFL, flags & ~(O_NONBLOCK)); #endif } return sock; }; // // long int SSLSocket::Read (char *buffer, long int len) { int ret; sslerror = SSL_ERROR_NONE; TimeoutReset(); if (!ssl) { errno = EPROTO; return -1; } do { ret = SSL_read(ssl, buffer, len); if (ret == -1) sslerror = SSL_get_error(ssl, -1); } while (ret == -1 && TimeoutTime() < timeout && (sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE)); if (ret == 0 && sslerror == 0) return -1; if (ret == -1 && (sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE)) ret = 0; return ret; }; // // long int SSLSocket::Write (char *buffer, long int len) { int ret; sslerror = SSL_ERROR_NONE; TimeoutReset(); if (!ssl) { errno = EPROTO; return -1; } do { ret = SSL_write(ssl, buffer, len); if (ret == -1) sslerror = SSL_get_error(ssl, -1); } while (ret == -1 && TimeoutTime() < timeout && (sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE)); return ret; }; // // Reset Timeout Timer void SSLSocket::TimeoutReset() { gettimeofday (&timeout_start, NULL); }; // // Return time which has past since reset in ms. int SSLSocket::TimeoutTime() { struct timeval tv; gettimeofday (&tv, NULL); return ((tv.tv_sec-timeout_start.tv_sec) * 1000) + ((tv.tv_usec-timeout_start.tv_usec) / 1000); }; /// @brief generate an sha256 hash from the giving string /// @param in string to get an hash sum from /// @return hashsum std::string getsha256sum(std::string *in) { std::string out = ""; EVP_MD_CTX *mdctx; unsigned char *chksum = NULL; unsigned int chksumlen; char hex[] = "0123456789abcdef"; int i; if (in == NULL) return ""; if ((mdctx = EVP_MD_CTX_create()) == NULL) return ""; if (1 != EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL)) { EVP_MD_CTX_destroy(mdctx); return ""; } if (1 != EVP_DigestUpdate(mdctx, in->c_str(), strlen(in->c_str()))) { EVP_MD_CTX_destroy(mdctx); return ""; } if ((chksum = (unsigned char *)OPENSSL_malloc(EVP_MD_size(EVP_sha256()))) == NULL) { EVP_MD_CTX_destroy(mdctx); return ""; } if (1 != EVP_DigestFinal_ex(mdctx, chksum, &chksumlen)) out = ""; else { // // convert data to hex; for (out = "", i = 0; i < chksumlen; i++) { unsigned char c = chksum[i]; out += hex[(c >> 4) & 0x0f]; out += hex[c & 0x0f]; } } OPENSSL_free (chksum); EVP_MD_CTX_destroy(mdctx); return out; }; std::string getrandomtext(int numbytes) { std::string text = ""; unsigned int seed; seed = arc4random(); srand(seed); for (int i = 0; i < numbytes; i++) { char c = ((rand() % (90-48)) + 48); text += c; } return text; }; // 1 2 3 4 5 6 // 0123456789012345678901234567890123456789012345678901234567890123 static char base64table1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; static char base64table2[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; int base64encode(char *ibuffer, int ilen, std::string *output, int urlsafe) { int icnt; // input pos char c1, c2; // two bytes int idx; *output = ""; for (c1 = 0, c2 = 0, icnt = 0; icnt < ilen; icnt++) { // byte 0 c1 = ibuffer[icnt]; idx = c1 >> 2; *output += (urlsafe ? base64table2[idx] : base64table1[idx]); // byte 1 if (++icnt >= ilen) break; c2 = ibuffer[icnt]; idx = (c1 & 0x03) << 4 | (c2 >> 4); *output += (urlsafe ? base64table2[idx] : base64table1[idx]); // byte 2 if (++icnt >= ilen) break; c1 = ibuffer[icnt]; idx = (c2 & 0x07) << 2 | (c1 & 0xC0) >> 6; *output += (urlsafe ? base64table2[idx] : base64table1[idx]); idx = c1 & 0x3F; *output += (urlsafe ? base64table2[idx] : base64table1[idx]); } switch (icnt % 3) { case 0: break; case 1: idx = (c1 & 0x03) << 4; *output += (urlsafe ? base64table2[idx] : base64table1[idx]); *output += '='; *output += '='; break; case 2: idx = (c2 & 0x07) << 2; *output += (urlsafe ? base64table2[idx] : base64table1[idx]); *output += '='; break; } return 1; }; int _base64getidx(char c) { char *p; if (c == '=') return 0; p = strchr (base64table1, c); if (p != NULL) return p-base64table1; p = strchr (base64table2, c); if (p != NULL) return p-base64table2; return -1; }; int base64decode(std::string input, char **obuffer, int *olen) { int ocnt; int icnt; int v1, v2; int ilen = input.length(); char c; // calculate and realloc size of buffer if (*obuffer == NULL || (1 + ilen*3/4) < *olen) { *olen = 2 + ilen*3/4; // increase buffer by two ... null terminated string? *obuffer = (char*) realloc(*obuffer, *olen); } for (v1 = 0, v2 = 0, ocnt = 0, icnt = 0; icnt < ilen && ocnt < *olen; icnt++) { // read byte 1 v1 = _base64getidx(input[icnt]); if (++icnt >= ilen || v1 == -1) return 0; v2 = _base64getidx(input[icnt]); if (v2 == -1) return 0; c = ((v1 << 2) | ((v2 & 0x30) >> 4)) & 0xFF; (*obuffer)[ocnt++] = c; // read byte 2 if (++icnt >= ilen) return 0; if (input[icnt] == '=') break; v1 = _base64getidx(input[icnt]); if (v1 == -1) return 0; c = ((v2 & 0x07) << 4 | (v1 & 0x3C) >> 2) & 0xFF; (*obuffer)[ocnt++] = c; // read byte 3 if (++icnt >= ilen) return 0; if (input[icnt] == '=') break; v2 = _base64getidx(input[icnt]); if (v2 == -1) return 0; c = ((v1 & 0x03) << 6| (v2 & 0x3F)) & 0xFF; (*obuffer)[ocnt++] = c; } *olen = ocnt; (*obuffer)[ocnt] = 0; // nullterminate - just to prevent unseen stuff return 1; };