You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

499 lines
10 KiB

/*
*
*/
#include <stdio.h>
#include <unistd.h> /* close() */
#include <stdlib.h>
#include <string.h> /* memset() */
#include <errno.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h> /* INT_MAX */
#include <sys/stat.h>
#include <sys/time.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include "UDPTCPNetwork.h"
// Set to 1 once the first SSLSocket has run the global OpenSSL setup.
// NOTE(review): plain int flag without any lock - not synchronized.
static int ssl_init = 0;
SSLSocket::SSLSocket() {
	// no TLS objects yet
	ssl = NULL;
	ctx = NULL;
	// no certificate configured, no timeout, no pending error
	certfile = "";
	keyfile = "";
	timeout = 0;
	sslerror = SSL_ERROR_NONE;
	// traffic counters
	writecnt = 0;
	readcnt = 0;
	// one-time library initialisation, done by the first instance only
	if (!ssl_init) {
		ssl_init = 1;
		SSL_library_init();
		OpenSSL_add_all_algorithms();
		SSL_load_error_strings();
	}
};
// Release the SSL context and SSL session object, if any.
// The underlying socket is not touched here.
SSLSocket::~SSLSocket() {
	if (ctx != NULL) {
		SSL_CTX_free(ctx);
		ctx = NULL;
	}
	if (ssl != NULL) {
		SSL_free(ssl);
		ssl = NULL;
	}
};
/// @brief (Re)create the server SSL context and load certfile/keyfile into it.
/// Any previous context is freed first.
/// @return 1 on success; 0 if one of the files cannot be stat()ed (errno from
///         stat()), or if OpenSSL fails to create the context or rejects the
///         certificate/key (OpenSSL errors printed to stderr, errno = EPROTO).
int SSLSocket::NewServerCTX() {
	struct stat st;
	if (stat (certfile.c_str(), &st)) return 0;
	if (stat (keyfile.c_str(), &st)) return 0;
	if (ctx) SSL_CTX_free(ctx);
	ctx = NULL;
#ifdef SSLv23_method
	// In OpenSSL 1.1.0 and later SSLv23_method is a macro alias of TLS_method,
	// so this #ifdef selects the version-flexible method on those releases.
	ctx = SSL_CTX_new(TLS_server_method());
#else
	ctx = SSL_CTX_new(TLSv1_2_server_method());
#endif
	// Every step is checked, including SSL_CTX_new() itself: a NULL context
	// must never reach the SSL_CTX_* calls below.
	if (ctx == NULL
			|| SSL_CTX_use_certificate_file(ctx, certfile.c_str(), SSL_FILETYPE_PEM) <= 0
			|| SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM) <= 0
			|| !SSL_CTX_check_private_key(ctx)) {
		ERR_print_errors_fp(stderr);
		if (ctx) SSL_CTX_free(ctx);
		ctx = NULL;
		errno = EPROTO;
		return 0;
	}
	return 1;
};
int SSLSocket::NewClientCTX() {
if (ctx) SSL_CTX_free(ctx);
ctx = NULL;
#ifdef SSLv23_method
ctx = SSL_CTX_new(TLS_client_method());
#else
ctx = SSL_CTX_new(TLSv1_2_client_method());
#endif
return 1;
};
// Remember the certificate and key file names used by NewServerCTX().
// The names are stored even when a file is unreadable.
// Returns 1 if both files are readable, otherwise 0.
int SSLSocket::SetCertificat(string certf, string keyf) {
	certfile = certf;
	keyfile = keyf;
	bool usable = file_is_readable(certf.c_str()) && file_is_readable(keyf.c_str());
	return usable ? 1 : 0;
};
int SSLSocket::Connect (int sockfd, int block_timeout) {
int flags, res;
TimeoutReset();
sslerror = SSL_ERROR_NONE;
if (NewServerCTX() == 0) {
debug ("error on NewServerCTX()\n");
return 0;
}
timeout = block_timeout;
if (sockfd > 0 && block_timeout > 0) {
#if defined(_WIN32) || defined(_WIN64) || defined(__CYGWIN__)
u_long mode = 1;
ioctlsocket(sockfd, FIONBIO, &mode);
#else
flags = fcntl(sockfd, F_GETFL, 0);
fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
#endif
}
ssl = SSL_new(ctx);
if (ssl == NULL) {
debug ("SSL_new failed\n");
}
SSL_set_fd (ssl, sockfd);
do {
res = SSL_connect(ssl);
if (res == -1) sslerror = SSL_get_error(ssl, -1);
} while (res == -1 && TimeoutTime() < timeout &&
(sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE));
if (res == -1) return 0;
return 1;
};
/// @brief
/// @param sockfd
/// @param block_timeout timeout in ms
/// @return 1 on success, 0 on error
int SSLSocket::Accept (int sockfd, int block_timeout) {
int flags, res;
TimeoutReset();
sslerror = SSL_ERROR_NONE;
timeout = block_timeout;
if (NewServerCTX() == 0) {
debug ("error on NewServerCTX()\n");
return 0;
}
if (sockfd > 0 && block_timeout > 0) {
#if defined(_WIN32) || defined(_WIN64) || defined(__CYGWIN__)
u_long mode = 1;
ioctlsocket(sockfd, FIONBIO, &mode);
#else
flags = fcntl(sockfd, F_GETFL, 0);
fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
#endif
}
ssl = SSL_new(ctx);
if (ssl == NULL) {
debug ("SSL_new failed\n");
}
SSL_set_fd (ssl, sockfd);
do {
res = SSL_accept(ssl);
if (res == -1) sslerror = SSL_get_error(ssl, -1);
} while (res == -1 && TimeoutTime() < timeout &&
(sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE));
return 1;
};
// Map an SSL_get_error() code to its symbolic name; unknown codes are
// rendered as "SSL_ERROR unknown <code>".
const string SSLSocket::GetSSLErrorText(int err) {
	switch (err) {
	case SSL_ERROR_NONE:         return "SSL_ERROR_NONE";
	case SSL_ERROR_SSL:          return "SSL_ERROR_SSL";
	case SSL_ERROR_WANT_READ:    return "SSL_ERROR_WANT_READ";
	case SSL_ERROR_WANT_WRITE:   return "SSL_ERROR_WANT_WRITE";
	case SSL_ERROR_SYSCALL:      return "SSL_ERROR_SYSCALL";
	case SSL_ERROR_WANT_CONNECT: return "SSL_ERROR_WANT_CONNECT";
	case SSL_ERROR_ZERO_RETURN:  return "SSL_ERROR_ZERO_RETURN";
	case SSL_ERROR_WANT_ACCEPT:  return "SSL_ERROR_WANT_ACCEPT";
	default:                     break;
	}
	return "SSL_ERROR unknown " + to_string(err);
}
//
// Free the SSL session and context and hand the plain socket back.
// Returns the socket fd taken from the SSL object, or 0 when there was none.
// If a timeout was in use (timeout > 0), the socket is put back into blocking
// mode, undoing what Connect()/Accept() did for block_timeout > 0.
int SSLSocket::Close () {
	int sock = 0;
	if (ssl != NULL) {
		sock = SSL_get_fd(ssl);
		SSL_free(ssl);
		ssl = NULL;
	}
	if (ctx != NULL) {
		SSL_CTX_free(ctx);
		ctx = NULL;
	}
	bool restore_blocking = (sock > 0) && (timeout > 0);
	if (restore_blocking) {
#if defined(_WIN32) || defined(_WIN64) || defined(__CYGWIN__)
		u_long mode = 0;
		ioctlsocket(sock, FIONBIO, &mode);
#else
		int fl = fcntl(sock, F_GETFL, 0);
		fcntl(sock, F_SETFL, fl & ~(O_NONBLOCK));
#endif
	}
	return sock;
};
//
// Read up to len bytes of decrypted data into buffer, retrying on
// SSL_ERROR_WANT_READ/WANT_WRITE until `timeout` ms (set by Connect()/Accept())
// have passed.
// Returns the number of bytes read, 0 if the timeout expired without data,
// or -1 on error or when the connection was closed (errno = EPROTO if there
// is no SSL session). sslerror holds the error code of the last attempt.
long int SSLSocket::Read (char *buffer, long int len) {
	int ret;
	sslerror = SSL_ERROR_NONE;
	TimeoutReset();
	if (!ssl) {
		errno = EPROTO;
		return -1;
	}
	// SSL_read() takes an int: clamp instead of letting a large long wrap into
	// a negative length (the caller just sees a short read).
	int chunk = (len > INT_MAX) ? INT_MAX : static_cast<int>(len);
	do {
		ret = SSL_read(ssl, buffer, chunk);
		// keep only the latest attempt's error so a WANT_* from an earlier
		// retry cannot disguise a close (ret == 0) as "no data yet"
		sslerror = (ret < 0) ? SSL_get_error(ssl, ret) : SSL_ERROR_NONE;
	} while (ret < 0 && TimeoutTime() < timeout &&
		(sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE));
	if (ret == 0) return -1;	// TLS connection shut down / closed
	if (ret < 0 &&
		(sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE))
		ret = 0;				// timed out, no data available yet
	return ret;
};
//
// Write up to len bytes from buffer over the TLS session, retrying on
// SSL_ERROR_WANT_READ/WANT_WRITE until `timeout` ms (set by Connect()/Accept())
// have passed.
// Returns the SSL_write() result: bytes written on success, <= 0 on error or
// timeout (unlike Read(), a timeout is NOT mapped to 0). Returns -1 with
// errno = EPROTO if there is no SSL session.
long int SSLSocket::Write (char *buffer, long int len) {
	int ret;
	sslerror = SSL_ERROR_NONE;
	TimeoutReset();
	if (!ssl) {
		errno = EPROTO;
		return -1;
	}
	// SSL_write() takes an int: clamp so a large long cannot wrap negative.
	// The same clamped length is reused on every retry, as OpenSSL requires.
	int chunk = (len > INT_MAX) ? INT_MAX : static_cast<int>(len);
	do {
		ret = SSL_write(ssl, buffer, chunk);
		sslerror = (ret < 0) ? SSL_get_error(ssl, ret) : SSL_ERROR_NONE;
	} while (ret < 0 && TimeoutTime() < timeout &&
		(sslerror == SSL_ERROR_WANT_READ || sslerror == SSL_ERROR_WANT_WRITE));
	return ret;
};
//
// Restart the stopwatch that TimeoutTime() measures against
// (stores the current gettimeofday() time in timeout_start).
void SSLSocket::TimeoutReset() {
	struct timeval *start = &timeout_start;
	gettimeofday (start, NULL);
};
//
// Milliseconds elapsed since the last TimeoutReset().
// Based on gettimeofday(), i.e. wall-clock time.
int SSLSocket::TimeoutTime() {
	struct timeval now;
	gettimeofday (&now, NULL);
	auto elapsed_sec = now.tv_sec - timeout_start.tv_sec;
	auto elapsed_usec = now.tv_usec - timeout_start.tv_usec;
	return (elapsed_sec * 1000) + (elapsed_usec / 1000);
};
/// @brief generate an sha256 hash from the giving string
/// @param in string to get an hash sum from
/// @return hashsum
std::string getsha256sum(std::string *in) {
std::string out = "";
EVP_MD_CTX *mdctx;
unsigned char *chksum = NULL;
unsigned int chksumlen;
char hex[] = "0123456789abcdef";
int i;
if (in == NULL) return "";
if ((mdctx = EVP_MD_CTX_create()) == NULL) return "";
if (1 != EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL)) {
EVP_MD_CTX_destroy(mdctx);
return "";
}
if (1 != EVP_DigestUpdate(mdctx, in->c_str(), strlen(in->c_str()))) {
EVP_MD_CTX_destroy(mdctx);
return "";
}
if ((chksum = (unsigned char *)OPENSSL_malloc(EVP_MD_size(EVP_sha256()))) == NULL) {
EVP_MD_CTX_destroy(mdctx);
return "";
}
if (1 != EVP_DigestFinal_ex(mdctx, chksum, &chksumlen)) out = "";
else {
//
// convert data to hex;
for (out = "", i = 0; i < chksumlen; i++) {
unsigned char c = chksum[i];
out += hex[(c >> 4) & 0x0f];
out += hex[c & 0x0f];
}
}
OPENSSL_free (chksum);
EVP_MD_CTX_destroy(mdctx);
return out;
};
std::string getrandomtext(int numbytes) {
std::string text = "";
unsigned int seed;
seed = arc4random();
srand(seed);
for (int i = 0; i < numbytes; i++) {
char c = ((rand() % (90-48)) + 48);
text += c;
}
return text;
};
// Base64 alphabets, indexed by 6-bit value (0..63):
// base64table1 is the standard alphabet (62 = '+', 63 = '/'),
// base64table2 the URL-safe alphabet (62 = '-', 63 = '_').
// Used by base64encode() and _base64getidx().
static char base64table1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
static char base64table2[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
/// @brief base64-encode a byte buffer (RFC 4648, with '=' padding)
/// @param ibuffer input bytes; any value 0x00..0xFF is allowed
/// @param ilen number of bytes in ibuffer (<= 0 encodes to "")
/// @param output receives the encoded text (cleared first)
/// @param urlsafe nonzero selects the URL-safe alphabet ('-', '_')
/// @return 1 on success; 0 if output is NULL, or ibuffer is NULL with ilen > 0
int base64encode(char *ibuffer, int ilen, std::string *output, int urlsafe) {
	if (output == NULL) return 0;
	*output = "";
	if (ilen <= 0) return 1;
	if (ibuffer == NULL) return 0;
	const char *table = urlsafe ? base64table2 : base64table1;
	// Work on unsigned bytes: with (signed) char, bytes >= 0x80 shift to
	// negative values and index before the start of the table.
	const unsigned char *in = reinterpret_cast<const unsigned char *>(ibuffer);
	output->reserve((static_cast<size_t>(ilen) + 2) / 3 * 4);
	int icnt = 0;
	// full 3-byte groups -> 4 characters
	for (; icnt + 2 < ilen; icnt += 3) {
		unsigned int b0 = in[icnt], b1 = in[icnt + 1], b2 = in[icnt + 2];
		*output += table[b0 >> 2];
		*output += table[((b0 & 0x03) << 4) | (b1 >> 4)];
		*output += table[((b1 & 0x0F) << 2) | (b2 >> 6)];	// low 4 bits of b1
		*output += table[b2 & 0x3F];
	}
	// 1 or 2 trailing bytes -> partial group plus '=' padding
	int rest = ilen - icnt;
	if (rest == 1) {
		unsigned int b0 = in[icnt];
		*output += table[b0 >> 2];
		*output += table[(b0 & 0x03) << 4];
		*output += '=';
		*output += '=';
	} else if (rest == 2) {
		unsigned int b0 = in[icnt], b1 = in[icnt + 1];
		*output += table[b0 >> 2];
		*output += table[((b0 & 0x03) << 4) | (b1 >> 4)];
		*output += table[(b1 & 0x0F) << 2];
		*output += '=';
	}
	return 1;
};
/// @brief map one base64 character to its 6-bit value
/// Accepts both the standard ('+', '/') and URL-safe ('-', '_') alphabets.
/// The padding character '=' maps to 0.
/// @return 0..63, or -1 for any other character (including '\0')
int _base64getidx(char c) {
	if (c == '=') return 0;
	// strchr() also "finds" the terminating NUL, which would yield index 64
	if (c == '\0') return -1;
	const char *p = strchr (base64table1, c);
	if (p != NULL) return static_cast<int>(p - base64table1);
	p = strchr (base64table2, c);
	if (p != NULL) return static_cast<int>(p - base64table2);
	return -1;
};
/// @brief decode base64 text (standard or URL-safe alphabet) into bytes
/// Decoding stops at the first '=' padding character.
/// @param input encoded text
/// @param obuffer in/out buffer pointer; grown with realloc() when it is NULL
///        or smaller than the worst-case decoded size (the caller must free()
///        it). On allocation failure the old buffer is left untouched.
/// @param olen in: size of *obuffer (ignored if *obuffer is NULL);
///        out: number of decoded bytes on success
/// @return 1 on success; 0 on malformed/truncated input or allocation failure
///         (then *olen is not a decoded length)
int base64decode(std::string input, char **obuffer, int *olen) {
	if (obuffer == NULL || olen == NULL) return 0;
	int ilen = static_cast<int>(input.length());
	// 4 characters decode to at most 3 bytes; +1 spare byte as before.
	// Computed in long long so ilen * 3 cannot overflow.
	long long needed = 1 + static_cast<long long>(ilen) * 3 / 4;
	// grow when the buffer is too SMALL (the old check was inverted)
	if (*obuffer == NULL || needed > *olen) {
		char *grown = static_cast<char *>(realloc(*obuffer, static_cast<size_t>(needed)));
		if (grown == NULL) return 0;
		*obuffer = grown;
		*olen = static_cast<int>(needed);
	}
	char *out = *obuffer;
	int capacity = *olen;		// >= needed, so the writes below stay in bounds
	int ocnt = 0;
	for (int icnt = 0; icnt < ilen && ocnt < capacity; icnt++) {
		// characters 1+2 -> byte 1
		int v1 = _base64getidx(input[icnt]);
		if (++icnt >= ilen || v1 == -1) return 0;
		int v2 = _base64getidx(input[icnt]);
		if (v2 == -1) return 0;
		out[ocnt++] = static_cast<char>(((v1 << 2) | ((v2 & 0x30) >> 4)) & 0xFF);
		// character 3 -> byte 2 (needs the low 4 bits of character 2)
		if (++icnt >= ilen) return 0;
		if (input[icnt] == '=') break;
		v1 = _base64getidx(input[icnt]);
		if (v1 == -1) return 0;
		out[ocnt++] = static_cast<char>((((v2 & 0x0F) << 4) | ((v1 & 0x3C) >> 2)) & 0xFF);
		// character 4 -> byte 3
		if (++icnt >= ilen) return 0;
		if (input[icnt] == '=') break;
		v2 = _base64getidx(input[icnt]);
		if (v2 == -1) return 0;
		out[ocnt++] = static_cast<char>((((v1 & 0x03) << 6) | (v2 & 0x3F)) & 0xFF);
	}
	*olen = ocnt;
	return 1;
};