adding basic support for SSL. need to add nonblocking

origin
steffen 6 years ago
parent a51d1c9bef
commit eb8b60d2cf

@ -6,13 +6,13 @@ ETCPREFIX=/etc
CXX=g++
CXXFLAGS= -ggdb -fPIC -pg -Wno-write-strings -I./ -std=c++11
LDFLAGS= -lm -lc -pg
LDFLAGS= -lm -lc -pg -lssl -lcrypto
DEFAULT_TCPPORT=6131
DEFAULT_UDPPORT=6131
DEFAULT_SERVER=localhost
OBJLIB=network.o udp.o tcp.o unix.o
OBJLIB=network.o udp.o tcp.o unix.o ssl.o
INCLIB=config.h UDPTCPNetwork.h
OBJLIB_NAME=UDPTCPNetwork
TARGET=lib$(OBJLIB_NAME).so.$(VERSION)
@ -20,14 +20,21 @@ TARGET=lib$(OBJLIB_NAME).so.$(VERSION)
DISTNAME=libUDPTCPNetwork-$(VERSION)
DEPENDFILE=.depend
all: dep $(TARGET) test-udp test-tcp
all: dep $(TARGET) test-udp test-tcp test-ssl
test-tcp: $(TARGET) test-tcp.o config.h
$(CXX) test-tcp.o -o $@ $(LDFLAGS) -lUDPTCPNetwork -L./ -I./
test-ssl: $(TARGET) test-ssl.o config.h
$(CXX) test-ssl.o -o $@ $(LDFLAGS) -lUDPTCPNetwork -L./ -I./
test-udp: $(TARGET) test-udp.o config.h
$(CXX) test-udp.o -o $@ $(LDFLAGS) -lUDPTCPNetwork -L./ -I./
keygen:
# openssl req -nodes -new -newkey rsa:2048 -sha256 -out csr.pem -keyout privkey.pem
openssl req -x509 -sha256 -nodes -days 365 -newkey rsa:2048 -keyout privkey.pem -out cert.pem
install: $(TARGET)
cp -f $(TARGET) $(PREFIX)/lib/
ln -sf $(TARGET) $(PREFIX)/lib/lib$(OBJLIB_NAME).so
@ -51,6 +58,7 @@ dep:
clean:
rm test-tcp -rf
rm test-udp -rf
rm test-ssl -rf
rm -rf gmon.out
rm *.s -rf
rm *.o -rf
@ -61,6 +69,7 @@ clean:
rm -rf *.so
rm -rf *.a
rm -rf *.so.*
rm -rf *.pem -rf
cleanall: clean

@ -5,6 +5,8 @@
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <string>
@ -64,13 +66,10 @@ public:
class TCP {
private:
int sock;
// struct sockaddr_storage localaddr;
// struct sockaddr_storage remoteaddr;
string remote_host;
string remote_port;
int readcnt;
int writecnt;
public:
TCP();
TCP(int s);
@ -105,6 +104,33 @@ public:
};
/************************************************************************
* SSL functions
*/
class SSLSocket {
private:
int readcnt;
int writecnt;
string certfile;
string keyfile;
SSL *ssl;
SSL_CTX *ctx;
int NewServerCTX();
int NewClientCTX();
const string GetSSLErrorText(int err);
public:
SSLSocket();
~SSLSocket();
int SetCertificat(string certf, string keyf);
int Connect(int sockfd);
int Accept(int sockfd);
long int Read(char *buffer, long int len);
long int Write(char *buffer, long int len);
int Close(); // returns socket
};
/************************************************************************
* unix socket related functions
@ -136,5 +162,7 @@ public:
int GetSocket() { return sock; };
};
#endif

223
ssl.cc

@ -0,0 +1,223 @@
/*
*
*/
#include <stdio.h>
#include <unistd.h> /* close() */
#include <stdlib.h>
#include <string.h> /* memset() */
#include <sys/stat.h>
#include <errno.h>
#include <string.h>
#include <errno.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include "UDPTCPNetwork.h"
static int ssl_init = 0;
SSLSocket::SSLSocket() {
readcnt = 0;
writecnt = 0;
certfile = "";
keyfile = "";
ctx = NULL;
ssl = NULL;
if (ssl_init == 0) {
ssl_init = 1;
SSL_library_init();
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
}
};
SSLSocket::~SSLSocket() {
if (ctx) SSL_CTX_free(ctx);
ctx = NULL;
if (ssl) SSL_free(ssl);
ssl = NULL;
};
int SSLSocket::NewServerCTX() {
struct stat st;
if (stat (certfile.c_str(), &st)) return 0;
if (stat (keyfile.c_str(), &st)) return 0;
if (ctx) SSL_CTX_free(ctx);
ctx = NULL;
ctx = SSL_CTX_new(TLSv1_2_server_method());
if (SSL_CTX_use_certificate_file(ctx, certfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) {
ERR_print_errors_fp(stderr);
if (ctx) SSL_CTX_free(ctx);
ctx = NULL;
errno = EPROTO;
return 0;
}
if ( SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM) <= 0 ) {
ERR_print_errors_fp(stderr);
if (ctx) SSL_CTX_free(ctx);
ctx = NULL;
errno = EPROTO;
return 0;
}
if (!SSL_CTX_check_private_key(ctx)) {
if (ctx) SSL_CTX_free(ctx);
ctx = NULL;
errno = EPROTO;
return 0;
}
return 1;
};
int SSLSocket::NewClientCTX() {
if (ctx) SSL_CTX_free(ctx);
ctx = NULL;
ctx = SSL_CTX_new(TLSv1_2_client_method());
return 1;
};
int SSLSocket::SetCertificat(string certf, string keyf) {
certfile = certf;
keyfile = keyf;
return 1;
};
int SSLSocket::Connect (int sockfd) {
NewClientCTX();
ssl = SSL_new(ctx);
SSL_set_fd (ssl, sockfd);
if (SSL_connect(ssl) == -1 ) {
printf ("ssl connect: error\n");
ERR_print_errors_fp(stderr);
exit (1);
}
return 1;
};
int SSLSocket::Accept (int sockfd) {
int err;
NewServerCTX();
ssl = SSL_new(ctx);
SSL_set_fd (ssl, sockfd);
if (SSL_accept(ssl) == -1 ) {
err = SSL_get_error(ssl, -1);
printf ("%s %s:%d error: %s\n", __FUNCTION__, __FILE__, __LINE__, GetSSLErrorText(err).c_str());
return 0;
}
return 1;
};
const string SSLSocket::GetSSLErrorText(int err) {
string s;
switch (err) {
case SSL_ERROR_NONE:
s = "SSL_ERROR_NONE";
break;
case SSL_ERROR_SSL:
s = "SSL_ERROR_SSL";
break;
case SSL_ERROR_WANT_READ:
s = "SSL_ERROR_WANT_READ";
break;
case SSL_ERROR_WANT_WRITE:
s = "SSL_ERROR_WANT_WRITE";
break;
case SSL_ERROR_SYSCALL:
s = "SSL_ERROR_SYSCALL";
break;
case SSL_ERROR_WANT_CONNECT:
s = "SSL_ERROR_WANT_CONNECT";
break;
case SSL_ERROR_ZERO_RETURN:
s = "SSL_ERROR_ZERO_RETURN";
break;
case SSL_ERROR_WANT_ACCEPT:
s = "SSL_ERROR_WANT_ACCEPT";
break;
default:
s = "SSL_ERROR unknown " + to_string(err);
break;
}
return s;
}
//
// close ssl and return socket.
int SSLSocket::Close () {
int sock = 0;
if (ssl) {
sock = SSL_get_fd(ssl);
SSL_free(ssl);
ssl = NULL;
}
if (ctx) {
SSL_CTX_free(ctx);
ctx = NULL;
}
return sock;
};
//
// need to add timeout currently data is blocking
long int SSLSocket::Read (char *buffer, long int len) {
int ret;
if (!ssl) {
errno = EPROTO;
return -1;
}
ret = SSL_read(ssl, buffer, len);
return ret;
};
//
// need to add timeout currently data is blocking
long int SSLSocket::Write (char *buffer, long int len) {
int ret;
if (!ssl) {
errno = EPROTO;
return -1;
}
ret = SSL_write(ssl, buffer, len);
return ret;
};

@ -0,0 +1,149 @@
#include <string.h>
#include <unistd.h>
#include "UDPTCPNetwork.h"
#define DEFAULT_PORT 12345
void server () {
TCP tcpserver;
TCP *connection;
SSLSocket ssl;
int i, timeout;
pid_t pid;
char buffer[NET_BUFFERSIZE];
//
// start the server
if (tcpserver.Listen(DEFAULT_PORT) != 1) {
printf ("cloud not start the tcp server\n");
exit (1);
}
//
// init SSL
if (ssl.SetCertificat("cert.pem", "privkey.pem") != 1) {
printf ("SetCertificat error:%s\n", strerror(errno));
exit (1);
}
//
// check for connections
for (timeout = 10; timeout > 0; timeout--) {
connection = tcpserver.Accept();
if (connection != NULL) {
//
// someone connected - create new process
// take care of parallel processing (parent is always the server)
//
printf (" server: got a connection forking new process\n");
pid = fork();
if (pid == 0) {
//
// child process - always close server since it will handeled
// by the parent process. Make sure the client exits and never
// returns.
tcpserver.Close();
if (ssl.Accept(connection->GetSocket()) != 1) {
printf ("could not establish SSL connection:%s\n", strerror(errno));
exit (1);
}
i = ssl.Read(buffer, NET_BUFFERSIZE);
if (i > 0) {
int c;
printf (" server: got: '%s'\n", buffer);
for (c = 0; c < i; c++) buffer[c] = toupper(buffer[c]);
ssl.Write(buffer, i);
}
//
// just delete the class object, it will close the client connection
ssl.Close();
delete (connection);
//
// exit child process
exit (1);
}
else {
//
// parent process - just close the client connection
// it will be handeled by the child process.
delete (connection);
}
}
sleep (1);
}
};
void client () {
TCP tcpclient;
SSLSocket ssl;
char buffer[NET_BUFFERSIZE];
int i;
sleep (1); // wait one second to start the server
//
// connect to the server
if (tcpclient.Connect ("localhost", DEFAULT_PORT) != 1) {
printf ("cloud not connect to server\n");
exit (1);
}
if (ssl.Connect(tcpclient.GetSocket()) != 1) {
printf ("could not establish SSL connection:%s\n", strerror(errno));
exit (1);
}
//
// send some data
snprintf (buffer, NET_BUFFERSIZE, "nur ein kleiner Test.");
printf ("client:send '%s' to the server.\n", buffer);
if (ssl.Write(buffer, strlen (buffer)) != strlen (buffer)) {
printf ("could not send all data.\n");
exit (1);
}
//
// read some data (wait maximum 10x1000ms)
for (i = 10; i > 0; i--)
if (ssl.Read(buffer, NET_BUFFERSIZE) > 0) {
printf ("client:got '%s' from server.\n", buffer);
break;
}
//
// close connection
ssl.Close();
tcpclient.Close();
};
int main (int argc, char **argv) {
pid_t pid;
pid = fork();
if (pid == 0) { // child process
printf ("start client\n");
client();
printf ("start client\n");
client();
printf ("start client\n");
client();
printf ("start client\n");
client();
}
else { // parent process
server();
}
return 0;
};
Loading…
Cancel
Save