From eb8b60d2cfcbc8c04a290e141f62709a1b381d4c Mon Sep 17 00:00:00 2001 From: steffen Date: Sun, 5 May 2019 20:05:26 +0000 Subject: [PATCH] adding basic support for SSL. need to add nonblocking --- Makefile | 15 +++- UDPTCPNetwork.h | 34 +++++++- ssl.cc | 223 ++++++++++++++++++++++++++++++++++++++++++++++++ test-ssl.cc | 149 ++++++++++++++++++++++++++++++++ 4 files changed, 415 insertions(+), 6 deletions(-) create mode 100644 ssl.cc create mode 100644 test-ssl.cc diff --git a/Makefile b/Makefile index 8398197..7de214c 100644 --- a/Makefile +++ b/Makefile @@ -6,13 +6,13 @@ ETCPREFIX=/etc CXX=g++ CXXFLAGS= -ggdb -fPIC -pg -Wno-write-strings -I./ -std=c++11 -LDFLAGS= -lm -lc -pg +LDFLAGS= -lm -lc -pg -lssl -lcrypto DEFAULT_TCPPORT=6131 DEFAULT_UDPPORT=6131 DEFAULT_SERVER=localhost -OBJLIB=network.o udp.o tcp.o unix.o +OBJLIB=network.o udp.o tcp.o unix.o ssl.o INCLIB=config.h UDPTCPNetwork.h OBJLIB_NAME=UDPTCPNetwork TARGET=lib$(OBJLIB_NAME).so.$(VERSION) @@ -20,14 +20,21 @@ TARGET=lib$(OBJLIB_NAME).so.$(VERSION) DISTNAME=libUDPTCPNetwork-$(VERSION) DEPENDFILE=.depend -all: dep $(TARGET) test-udp test-tcp +all: dep $(TARGET) test-udp test-tcp test-ssl test-tcp: $(TARGET) test-tcp.o config.h $(CXX) test-tcp.o -o $@ $(LDFLAGS) -lUDPTCPNetwork -L./ -I./ +test-ssl: $(TARGET) test-ssl.o config.h + $(CXX) test-ssl.o -o $@ $(LDFLAGS) -lUDPTCPNetwork -L./ -I./ + test-udp: $(TARGET) test-udp.o config.h $(CXX) test-udp.o -o $@ $(LDFLAGS) -lUDPTCPNetwork -L./ -I./ +keygen: + # openssl req -nodes -new -newkey rsa:2048 -sha256 -out csr.pem -keyout privkey.pem + openssl req -x509 -sha256 -nodes -days 365 -newkey rsa:2048 -keyout privkey.pem -out cert.pem + install: $(TARGET) cp -f $(TARGET) $(PREFIX)/lib/ ln -sf $(TARGET) $(PREFIX)/lib/lib$(OBJLIB_NAME).so @@ -51,6 +58,7 @@ dep: clean: rm test-tcp -rf rm test-udp -rf + rm test-ssl -rf rm -rf gmon.out rm *.s -rf rm *.o -rf @@ -61,6 +69,7 @@ clean: rm -rf *.so rm -rf *.a rm -rf *.so.* + rm -rf *.pem -rf cleanall: clean diff --git a/UDPTCPNetwork.h b/UDPTCPNetwork.h index 1617ee7..19d3061 100644 --- a/UDPTCPNetwork.h +++ b/UDPTCPNetwork.h @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include @@ -64,13 +66,10 @@ public: class TCP { private: int sock; -// struct sockaddr_storage localaddr; -// struct sockaddr_storage remoteaddr; string remote_host; string remote_port; int readcnt; int writecnt; - public: TCP(); TCP(int s); @@ -105,6 +104,33 @@ public: }; +/************************************************************************ + * SSL functions + */ +class SSLSocket { +private: + int readcnt; + int writecnt; + string certfile; + string keyfile; + SSL *ssl; + SSL_CTX *ctx; + int NewServerCTX(); + int NewClientCTX(); + const string GetSSLErrorText(int err); +public: + SSLSocket(); + ~SSLSocket(); + + int SetCertificat(string certf, string keyf); + int Connect(int sockfd); + int Accept(int sockfd); + long int Read(char *buffer, long int len); + long int Write(char *buffer, long int len); + int Close(); // returns socket +}; + + /************************************************************************ * unix socket related functions @@ -136,5 +162,7 @@ public: int GetSocket() { return sock; }; }; + + #endif diff --git a/ssl.cc b/ssl.cc new file mode 100644 index 0000000..047715a --- /dev/null +++ b/ssl.cc @@ -0,0 +1,223 @@ +/* + * + */ + +#include +#include /* close() */ +#include +#include /* memset() */ +#include +#include +#include +#include +#include +#include + +#include "UDPTCPNetwork.h" + +static int ssl_init = 0; + +SSLSocket::SSLSocket() { + readcnt = 0; + writecnt = 0; + certfile = ""; + keyfile = ""; + ctx = NULL; + ssl = NULL; + if (ssl_init == 0) { + ssl_init = 1; + SSL_library_init(); + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + } +}; + + +SSLSocket::~SSLSocket() { + if (ctx) SSL_CTX_free(ctx); + ctx = NULL; + if (ssl) SSL_free(ssl); + ssl = NULL; +}; + + +int SSLSocket::NewServerCTX() { + struct stat st; + + if (stat (certfile.c_str(), &st)) return 0; + if (stat (keyfile.c_str(), &st)) return 0; + + if (ctx) SSL_CTX_free(ctx); + ctx = NULL; + + ctx = SSL_CTX_new(TLSv1_2_server_method()); + + if (SSL_CTX_use_certificate_file(ctx, certfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) { + ERR_print_errors_fp(stderr); + if (ctx) SSL_CTX_free(ctx); + ctx = NULL; + errno = EPROTO; + return 0; + } + + if ( SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) { + ERR_print_errors_fp(stderr); + if (ctx) SSL_CTX_free(ctx); + ctx = NULL; + errno = EPROTO; + return 0; + } + + if (!SSL_CTX_check_private_key(ctx)) { + if (ctx) SSL_CTX_free(ctx); + ctx = NULL; + errno = EPROTO; + return 0; + } + + return 1; +}; + + +int SSLSocket::NewClientCTX() { + if (ctx) SSL_CTX_free(ctx); + ctx = NULL; + ctx = SSL_CTX_new(TLSv1_2_client_method()); + return 1; +}; + + +int SSLSocket::SetCertificat(string certf, string keyf) { + certfile = certf; + keyfile = keyf; + + return 1; +}; + + +int SSLSocket::Connect (int sockfd) { + NewClientCTX(); + + ssl = SSL_new(ctx); + SSL_set_fd (ssl, sockfd); + if (SSL_connect(ssl) == -1 ) { + printf ("ssl connect: error\n"); + ERR_print_errors_fp(stderr); + exit (1); + } + + return 1; +}; + + + +int SSLSocket::Accept (int sockfd) { + int err; + NewServerCTX(); + + ssl = SSL_new(ctx); + SSL_set_fd (ssl, sockfd); + if (SSL_accept(ssl) == -1 ) { + err = SSL_get_error(ssl, -1); + printf ("%s %s:%d error: %s\n", __FUNCTION__, __FILE__, __LINE__, GetSSLErrorText(err).c_str()); + return 0; + } + + return 1; +}; + + +const string SSLSocket::GetSSLErrorText(int err) { + string s; + + switch (err) { + case SSL_ERROR_NONE: + s = "SSL_ERROR_NONE"; + break; + + case SSL_ERROR_SSL: + s = "SSL_ERROR_SSL"; + break; + + case SSL_ERROR_WANT_READ: + s = "SSL_ERROR_WANT_READ"; + break; + + case SSL_ERROR_WANT_WRITE: + s = "SSL_ERROR_WANT_WRITE"; + break; + + case SSL_ERROR_SYSCALL: + s = "SSL_ERROR_SYSCALL"; + break; + + case SSL_ERROR_WANT_CONNECT: + s = "SSL_ERROR_WANT_CONNECT"; + break; + + case SSL_ERROR_ZERO_RETURN: + s = "SSL_ERROR_ZERO_RETURN"; + break; + + case SSL_ERROR_WANT_ACCEPT: + s = "SSL_ERROR_WANT_ACCEPT"; + break; + default: + s = "SSL_ERROR unknown " + to_string(err); + break; + } + + return s; +} + + +// +// close ssl and return socket. +int SSLSocket::Close () { + int sock = 0; + + if (ssl) { + sock = SSL_get_fd(ssl); + SSL_free(ssl); + ssl = NULL; + } + + if (ctx) { + SSL_CTX_free(ctx); + ctx = NULL; + } + + return sock; +}; + + +// +// need to add timeout currently data is blocking +long int SSLSocket::Read (char *buffer, long int len) { + int ret; + + if (!ssl) { + errno = EPROTO; + return -1; + } + + ret = SSL_read(ssl, buffer, len); + + return ret; +}; + + +// +// need to add timeout currently data is blocking +long int SSLSocket::Write (char *buffer, long int len) { + int ret; + + if (!ssl) { + errno = EPROTO; + return -1; + } + + ret = SSL_write(ssl, buffer, len); + + return ret; +}; diff --git a/test-ssl.cc b/test-ssl.cc new file mode 100644 index 0000000..a36f54b --- /dev/null +++ b/test-ssl.cc @@ -0,0 +1,149 @@ + +#include +#include + +#include "UDPTCPNetwork.h" + +#define DEFAULT_PORT 12345 + + +void server () { + TCP tcpserver; + TCP *connection; + SSLSocket ssl; + int i, timeout; + pid_t pid; + char buffer[NET_BUFFERSIZE]; + + // + // start the server + if (tcpserver.Listen(DEFAULT_PORT) != 1) { + printf ("cloud not start the tcp server\n"); + exit (1); + } + + // + // init SSL + if (ssl.SetCertificat("cert.pem", "privkey.pem") != 1) { + printf ("SetCertificat error:%s\n", strerror(errno)); + exit (1); + } + + // + // check for connections + for (timeout = 10; timeout > 0; timeout--) { + connection = tcpserver.Accept(); + if (connection != NULL) { + // + // someone connected - create new process + // take care of parallel processing (parent is always the server) + // + printf (" server: got a connection forking new process\n"); + pid = fork(); + if (pid == 0) { + // + // child process - always close server since it will handeled + // by the parent process. Make sure the client exits and never + // returns. + + tcpserver.Close(); + if (ssl.Accept(connection->GetSocket()) != 1) { + printf ("could not establish SSL connection:%s\n", strerror(errno)); + exit (1); + } + i = ssl.Read(buffer, NET_BUFFERSIZE); + if (i > 0) { + int c; + + printf (" server: got: '%s'\n", buffer); + for (c = 0; c < i; c++) buffer[c] = toupper(buffer[c]); + ssl.Write(buffer, i); + } + // + // just delete the class object, it will close the client connection + ssl.Close(); + delete (connection); + + // + // exit child process + exit (1); + } + else { + // + // parent process - just close the client connection + // it will be handeled by the child process. + delete (connection); + } + } + sleep (1); + } +}; + + +void client () { + TCP tcpclient; + SSLSocket ssl; + + char buffer[NET_BUFFERSIZE]; + int i; + + sleep (1); // wait one second to start the server + + // + // connect to the server + if (tcpclient.Connect ("localhost", DEFAULT_PORT) != 1) { + printf ("cloud not connect to server\n"); + exit (1); + } + + if (ssl.Connect(tcpclient.GetSocket()) != 1) { + printf ("could not establish SSL connection:%s\n", strerror(errno)); + exit (1); + } + + // + // send some data + snprintf (buffer, NET_BUFFERSIZE, "nur ein kleiner Test."); + printf ("client:send '%s' to the server.\n", buffer); + if (ssl.Write(buffer, strlen (buffer)) != strlen (buffer)) { + printf ("could not send all data.\n"); + exit (1); + } + + // + // read some data (wait maximum 10x1000ms) + for (i = 10; i > 0; i--) + if (ssl.Read(buffer, NET_BUFFERSIZE) > 0) { + printf ("client:got '%s' from server.\n", buffer); + break; + } + + // + // close connection + ssl.Close(); + tcpclient.Close(); +}; + + + +int main (int argc, char **argv) { + pid_t pid; + + pid = fork(); + if (pid == 0) { // child process + printf ("start client\n"); + client(); + printf ("start client\n"); + client(); + printf ("start client\n"); + client(); + printf ("start client\n"); + client(); + } + else { // parent process + server(); + } + + return 0; +}; +