summaryrefslogtreecommitdiff
path: root/src/utils/common/netio.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/utils/common/netio.c')
-rw-r--r--src/utils/common/netio.c465
1 files changed, 465 insertions, 0 deletions
diff --git a/src/utils/common/netio.c b/src/utils/common/netio.c
new file mode 100644
index 0000000..7ac251d
--- /dev/null
+++ b/src/utils/common/netio.c
@@ -0,0 +1,465 @@
+/* Copyright (C) 2011 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#include <config.h>
+#include "utils/common/netio.h"
+
+#include <stdlib.h> // free
+#include <netdb.h> // addrinfo
+#include <poll.h> // poll
+#include <fcntl.h> // fcntl
+#include <sys/socket.h> // AF_INET (BSD)
+#include <netinet/in.h> // ntohl (BSD)
+#include <arpa/inet.h> // inet_ntop
+#include <unistd.h> // close
+#ifdef HAVE_SYS_UIO_H // struct iovec (OpenBSD)
+#include <sys/uio.h>
+#endif // HAVE_SYS_UIO_H
+
+#include "utils/common/msg.h" // WARN
+#include "common/descriptor.h" // KNOT_CLASS_IN
+#include "common/errcode.h" // KNOT_E
+
+server_t* server_create(const char *name, const char *service)
+{
+ if (name == NULL || service == NULL) {
+ DBG_NULL;
+ return NULL;
+ }
+
+ // Create output structure.
+ server_t *server = calloc(1, sizeof(server_t));
+
+ // Check output.
+ if (server == NULL) {
+ return NULL;
+ }
+
+ // Fill output.
+ server->name = strdup(name);
+ server->service = strdup(service);
+
+ if (server->name == NULL || server->service == NULL) {
+ server_free(server);
+ return NULL;
+ }
+
+ // Return result.
+ return server;
+}
+
+void server_free(server_t *server)
+{
+ if (server == NULL) {
+ DBG_NULL;
+ return;
+ }
+
+ free(server->name);
+ free(server->service);
+ free(server);
+}
+
+int get_iptype(const ip_t ip)
+{
+ switch (ip) {
+ case IP_4:
+ return AF_INET;
+ case IP_6:
+ return AF_INET6;
+ default:
+ return AF_UNSPEC;
+ }
+}
+
+int get_socktype(const protocol_t proto, const uint16_t type)
+{
+ switch (proto) {
+ case PROTO_TCP:
+ return SOCK_STREAM;
+ case PROTO_UDP:
+ return SOCK_DGRAM;
+ default:
+ if (type == KNOT_RRTYPE_AXFR || type == KNOT_RRTYPE_IXFR) {
+ return SOCK_STREAM;
+ } else {
+ return SOCK_DGRAM;
+ }
+ }
+}
+
+const char* get_sockname(const int socktype)
+{
+ const char *proto;
+
+ switch (socktype) {
+ case SOCK_STREAM:
+ proto = "TCP";
+ break;
+ case SOCK_DGRAM:
+ proto = "UDP";
+ break;
+ default:
+ proto = "UNKNOWN";
+ break;
+ }
+
+ return proto;
+}
+
+static int get_addr(const server_t *server,
+ const int iptype,
+ const int socktype,
+ struct addrinfo **info)
+{
+ struct addrinfo hints;
+
+ // Set connection hints.
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = iptype;
+ hints.ai_socktype = socktype;
+
+ // Get connection parameters.
+ if (getaddrinfo(server->name, server->service, &hints, info) != 0) {
+ ERR("can't resolve address %s#%s\n",
+ server->name, server->service);
+ return -1;
+ }
+
+ return 0;
+}
+
+static void get_addr_str(const struct sockaddr_storage *ss,
+ const int socktype,
+ char **dst)
+{
+ char addr[INET6_ADDRSTRLEN] = "NULL";
+ char buf[128] = "NULL";
+ uint16_t port;
+
+ // Get network address string and port number.
+ if (ss->ss_family == AF_INET) {
+ struct sockaddr_in *s = (struct sockaddr_in *)ss;
+ inet_ntop(ss->ss_family, &s->sin_addr, addr, sizeof(addr));
+ port = ntohs(s->sin_port);
+ } else {
+ struct sockaddr_in6 *s = (struct sockaddr_in6 *)ss;
+ inet_ntop(ss->ss_family, &s->sin6_addr, addr, sizeof(addr));
+ port = ntohs(s->sin6_port);
+ }
+
+ // Free previous string if any.
+ free(*dst);
+ *dst = NULL;
+
+ // Write formated information string.
+ int ret = snprintf(buf, sizeof(buf), "%s#%u(%s)", addr, port,
+ get_sockname(socktype));
+ if (ret > 0) {
+ *dst = strdup(buf);
+ } else {
+ *dst = strdup("NULL");
+ }
+}
+
+int net_init(const server_t *local,
+ const server_t *remote,
+ const int iptype,
+ const int socktype,
+ const int wait,
+ net_t *net)
+{
+ if (remote == NULL || net == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ // Clean network structure.
+ memset(net, 0, sizeof(*net));
+
+ // Get remote address list.
+ if (get_addr(remote, iptype, socktype, &net->remote_info) != 0) {
+ return KNOT_NET_EADDR;
+ }
+
+ // Set current remote address.
+ net->srv = net->remote_info;
+
+ // Get local address if specified.
+ if (local != NULL) {
+ if (get_addr(local, iptype, socktype, &net->local_info) != 0) {
+ return KNOT_NET_EADDR;
+ }
+ }
+
+ // Store network parameters.
+ net->iptype = iptype;
+ net->socktype = socktype;
+ net->wait = wait;
+ net->local = local;
+ net->remote = remote;
+
+ return KNOT_EOK;
+}
+
+int net_connect(net_t *net)
+{
+ struct pollfd pfd;
+ int sockfd, cs, err = 0;
+ socklen_t err_len = sizeof(err);
+
+ if (net == NULL || net->srv == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ // Set remote information string.
+ get_addr_str((struct sockaddr_storage *)net->srv->ai_addr,
+ net->socktype, &net->remote_str);
+
+ // Create socket.
+ sockfd = socket(net->srv->ai_family, net->socktype, 0);
+ if (sockfd == -1) {
+ WARN("can't create socket for %s\n", net->remote_str);
+ return KNOT_NET_ESOCKET;
+ }
+
+ // Initialize poll descriptor structure.
+ pfd.fd = sockfd;
+ pfd.events = POLLOUT;
+ pfd.revents = 0;
+
+ // Set non-blocking socket.
+ if (fcntl(sockfd, F_SETFL, O_NONBLOCK) == -1) {
+ WARN("can't set non-blocking socket for %s\n", net->remote_str);
+ return KNOT_NET_ESOCKET;
+ }
+
+ // Bind address to socket if specified.
+ if (net->local_info != NULL) {
+ // Set local information string.
+ get_addr_str((struct sockaddr_storage *)net->local_info->ai_addr,
+ net->socktype, &net->local_str);
+
+ if (bind(sockfd, net->local_info->ai_addr,
+ net->local_info->ai_addrlen) == -1) {
+ WARN("can't assign address %s\n", net->local_str);
+ return KNOT_NET_ESOCKET;
+ }
+ }
+
+ if (net->socktype == SOCK_STREAM) {
+ // Connect using socket.
+ if (connect(sockfd, net->srv->ai_addr, net->srv->ai_addrlen)
+ == -1 && errno != EINPROGRESS) {
+ WARN("can't connect to %s\n", net->remote_str);
+ close(sockfd);
+ return KNOT_NET_ECONNECT;
+ }
+
+ // Check for connection timeout.
+ if (poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("connection timeout for %s\n", net->remote_str);
+ close(sockfd);
+ return KNOT_NET_ECONNECT;
+ }
+
+ // Check if NB socket is writeable.
+ cs = getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &err, &err_len);
+ if (cs < 0 || err != 0) {
+ WARN("can't connect to %s\n", net->remote_str);
+ close(sockfd);
+ return KNOT_NET_ECONNECT;
+ }
+ }
+
+ // Store socket descriptor.
+ net->sockfd = sockfd;
+
+ return KNOT_EOK;
+}
+
+int net_send(const net_t *net, const uint8_t *buf, const size_t buf_len)
+{
+ if (net == NULL || buf == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ if (net->socktype == SOCK_STREAM) {
+ struct iovec iov[2];
+
+ // Leading packet length bytes.
+ uint16_t pktsize = htons(buf_len);
+
+ iov[0].iov_base = &pktsize;
+ iov[0].iov_len = sizeof(pktsize);
+ iov[1].iov_base = (uint8_t *)buf;
+ iov[1].iov_len = buf_len;
+
+ // Compute packet total length.
+ ssize_t total = iov[0].iov_len + iov[1].iov_len;
+
+ // Send data.
+ if (writev(net->sockfd, iov, 2) != total) {
+ WARN("can't send query to %s\n", net->remote_str);
+ return KNOT_NET_ESEND;
+ }
+ } else {
+ // Send data.
+ if (sendto(net->sockfd, buf, buf_len, 0, net->srv->ai_addr,
+ net->srv->ai_addrlen) != (ssize_t)buf_len) {
+ WARN("can't send query to %s\n", net->remote_str);
+ return KNOT_NET_ESEND;
+ }
+ }
+
+ return KNOT_EOK;
+}
+
+int net_receive(const net_t *net, uint8_t *buf, const size_t buf_len)
+{
+ ssize_t ret;
+ struct pollfd pfd;
+
+ if (net == NULL || buf == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ // Initialize poll descriptor structure.
+ pfd.fd = net->sockfd;
+ pfd.events = POLLIN;
+ pfd.revents = 0;
+
+ if (net->socktype == SOCK_STREAM) {
+ uint16_t msg_len;
+ uint32_t total = 0;
+
+ // Receive TCP message header.
+ while (total < sizeof(msg_len)) {
+ if (poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("response timeout for %s\n",
+ net->remote_str);
+ return KNOT_NET_ETIMEOUT;
+ }
+
+ // Receive piece of message.
+ ret = recv(net->sockfd, (uint8_t *)&msg_len + total,
+ sizeof(msg_len) - total, 0);
+ if (ret <= 0) {
+ WARN("can't receive reply from %s\n",
+ net->remote_str);
+ return KNOT_NET_ERECV;
+ }
+
+ total += ret;
+ }
+
+ // Convert number to host format.
+ msg_len = ntohs(msg_len);
+
+ total = 0;
+
+ // Receive whole answer message by parts.
+ while (total < msg_len) {
+ if (poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("response timeout for %s\n",
+ net->remote_str);
+ return KNOT_NET_ETIMEOUT;
+ }
+
+ // Receive piece of message.
+ ret = recv(net->sockfd, buf + total, msg_len - total, 0);
+ if (ret <= 0) {
+ WARN("can't receive reply from %s\n",
+ net->remote_str);
+ return KNOT_NET_ERECV;
+ }
+
+ total += ret;
+ }
+
+ return total;
+ } else {
+ struct sockaddr_storage from;
+
+ // Receive replies unless correct reply or timeout.
+ while (true) {
+ socklen_t from_len = sizeof(from);
+
+ // Wait for datagram data.
+ if (poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("response timeout for %s\n",
+ net->remote_str);
+ return KNOT_NET_ETIMEOUT;
+ }
+
+ // Receive whole UDP datagram.
+ ret = recvfrom(net->sockfd, buf, buf_len, 0,
+ (struct sockaddr *)&from, &from_len);
+ if (ret <= 0) {
+ WARN("can't receive reply from %s\n",
+ net->remote_str);
+ return KNOT_NET_ERECV;
+ }
+
+ // Compare reply address with the remote one.
+ if (from_len > sizeof(from) ||
+ memcmp(&from, net->srv->ai_addr, from_len) != 0) {
+ char *src = NULL;
+ get_addr_str(&from, net->socktype, &src);
+ WARN("unexpected reply source %s\n", src);
+ free(src);
+ continue;
+ }
+
+ return ret;
+ }
+ }
+
+ return KNOT_NET_ERECV;
+}
+
+void net_close(net_t *net)
+{
+ if (net == NULL) {
+ DBG_NULL;
+ return;
+ }
+
+ close(net->sockfd);
+ net->sockfd = -1;
+}
+
+void net_clean(net_t *net)
+{
+ if (net == NULL) {
+ DBG_NULL;
+ return;
+ }
+
+ free(net->local_str);
+ free(net->remote_str);
+
+ if (net->local_info != NULL) {
+ freeaddrinfo(net->local_info);
+ }
+
+ if (net->remote_info != NULL) {
+ freeaddrinfo(net->remote_info);
+ }
+}