Move lease fd send/receive to dlm-protocol
[src/drm-lease-manager.git] / drm-lease-manager / lease-server.c
index e05e8e4..c81d363 100644 (file)
@@ -14,6 +14,8 @@
  */
 
 #include "lease-server.h"
+
+#include "dlm-protocol.h"
 #include "log.h"
 #include "socket-path.h"
 
 
 #define SOCK_LOCK_SUFFIX ".lock"
 
+/* ACTIVE_CLIENTS
+ * An 'active' client is one that either
+ *  - owns a lease, or
+ *  - is requesting ownership of a lease (which will
+ *    disconnect the current owner if granted)
+ *
+ * There can only be at most one of each kind of client at the same
+ * time. Any other client connections are queued in the
+ * listen() backlog, waiting to be accept()'ed.
+ */
+#define ACTIVE_CLIENTS 2
+
 struct ls_socket {
        int fd;
+       bool is_server;
+       union {
+               struct ls_server *server;
+               struct ls_client *client;
+       };
+};
+
+struct ls_client {
+       struct ls_socket socket;
        struct ls_server *serv;
+       bool is_connected;
 };
 
 struct ls_server {
@@ -43,9 +67,7 @@ struct ls_server {
        int server_socket_lock;
 
        struct ls_socket listen;
-       struct ls_socket client;
-
-       bool is_client_connected;
+       struct ls_client clients[ACTIVE_CLIENTS];
 };
 
 struct ls {
@@ -55,37 +77,63 @@ struct ls {
        int nservers;
 };
 
-static bool client_connect(struct ls *ls, struct ls_server *serv)
+static void client_connect(struct ls *ls, struct ls_server *serv)
 {
        int cfd = accept(serv->listen.fd, NULL, NULL);
        if (cfd < 0) {
                DEBUG_LOG("accept failed on %s: %s\n", serv->address.sun_path,
                          strerror(errno));
-               return false;
+               return;
        }
 
-       if (serv->is_client_connected) {
-               WARN_LOG("Client already connected on %s\n",
-                        serv->address.sun_path);
+       struct ls_client *client = NULL;
+
+       for (int i = 0; i < ACTIVE_CLIENTS; i++) {
+               if (!serv->clients[i].is_connected) {
+                       client = &serv->clients[i];
+                       break;
+               }
+       }
+       if (!client) {
                close(cfd);
-               return false;
+               return;
        }
 
-       serv->client.fd = cfd;
-       serv->client.serv = serv;
+       client->socket.fd = cfd;
 
        struct epoll_event ev = {
-           .events = POLLHUP,
-           .data.ptr = &serv->client,
+           .events = POLLIN,
+           .data.ptr = &client->socket,
        };
        if (epoll_ctl(ls->epoll_fd, EPOLL_CTL_ADD, cfd, &ev)) {
                DEBUG_LOG("epoll_ctl add failed: %s\n", strerror(errno));
                close(cfd);
-               return false;
+               return;
        }
 
-       serv->is_client_connected = true;
-       return true;
+       client->is_connected = true;
+}
+
+static int parse_client_request(struct ls_socket *client)
+{
+       int ret = -1;
+       struct dlm_client_request hdr;
+       if (!receive_dlm_client_request(client->fd, &hdr))
+               return ret;
+
+       switch (hdr.opcode) {
+       case DLM_GET_LEASE:
+               ret = LS_REQ_GET_LEASE;
+               break;
+       case DLM_RELEASE_LEASE:
+               ret = LS_REQ_RELEASE_LEASE;
+               break;
+       default:
+               ERROR_LOG("Unexpected client request received\n");
+               break;
+       };
+
+       return ret;
 }
 
 static int create_socket_lock(struct sockaddr_un *addr)
@@ -140,7 +188,7 @@ static bool server_setup(struct ls *ls, struct ls_server *serv,
 
        address->sun_family = AF_UNIX;
 
-       int server_socket = socket(PF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
+       int server_socket = socket(PF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK, 0);
        if (server_socket < 0) {
                DEBUG_LOG("Socket creation failed: %s\n", strerror(errno));
                return false;
@@ -161,11 +209,18 @@ static bool server_setup(struct ls *ls, struct ls_server *serv,
                return false;
        }
 
-       serv->is_client_connected = false;
+       for (int i = 0; i < ACTIVE_CLIENTS; i++) {
+               struct ls_client *client = &serv->clients[i];
+               client->serv = serv;
+               client->socket.client = client;
+       }
+
        serv->lease_handle = lease_handle;
        serv->server_socket_lock = socket_lock;
+
        serv->listen.fd = server_socket;
-       serv->listen.serv = serv;
+       serv->listen.server = serv;
+       serv->listen.is_server = true;
 
        struct epoll_event ev = {
            .events = POLLIN,
@@ -193,7 +248,10 @@ static void server_shutdown(struct ls *ls, struct ls_server *serv)
 
        epoll_ctl(ls->epoll_fd, EPOLL_CTL_DEL, serv->listen.fd, NULL);
        close(serv->listen.fd);
-       ls_disconnect_client(ls, serv);
+
+       for (int i = 0; i < ACTIVE_CLIENTS; i++)
+               ls_disconnect_client(ls, &serv->clients[i]);
+
        close(serv->server_socket_lock);
 }
 
@@ -261,75 +319,60 @@ bool ls_get_request(struct ls *ls, struct ls_req *req)
                struct ls_socket *sock = ev.data.ptr;
                assert(sock);
 
-               struct ls_server *server = sock->serv;
-               req->lease_handle = server->lease_handle;
-               req->server = server;
+               if (sock->is_server) {
+                       if (ev.events & POLLIN)
+                               client_connect(ls, sock->server);
+                       continue;
+               }
 
-               if (sock == &server->listen) {
-                       if (!(ev.events & POLLIN))
-                               continue;
-                       if (client_connect(ls, server))
-                               request = LS_REQ_GET_LEASE;
-               } else if (sock == &server->client) {
-                       if (!(ev.events & POLLHUP))
-                               continue;
+               if (ev.events & POLLIN)
+                       request = parse_client_request(sock);
+
+               if (request < 0 && (ev.events & POLLHUP))
                        request = LS_REQ_RELEASE_LEASE;
-               } else {
-                       ERROR_LOG("Internal error: Invalid socket context\n");
-                       return false;
-               }
+
+               struct ls_client *client = sock->client;
+               struct ls_server *server = client->serv;
+
+               req->lease_handle = server->lease_handle;
+               req->client = client;
+               req->type = request;
        }
-       req->type = request;
        return true;
 }
 
-bool ls_send_fd(struct ls *ls, struct ls_server *serv, int fd)
+bool ls_send_fd(struct ls *ls, struct ls_client *client, int fd)
 {
        assert(ls);
-       assert(serv);
+       assert(client);
+
+       struct ls_server *serv = client->serv;
 
        if (fd < 0)
                return false;
 
-       char data[1];
-       struct iovec iov = {
-           .iov_base = data,
-           .iov_len = sizeof(data),
-       };
-
-       char ctrl_buf[CMSG_SPACE(sizeof(int))] = {0};
-
-       struct msghdr msg = {
-           .msg_iov = &iov,
-           .msg_iovlen = 1,
-           .msg_controllen = sizeof(ctrl_buf),
-           .msg_control = ctrl_buf,
-       };
-
-       struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
-       cmsg->cmsg_level = SOL_SOCKET;
-       cmsg->cmsg_type = SCM_RIGHTS;
-       cmsg->cmsg_len = CMSG_LEN(sizeof(int));
-       *((int *)CMSG_DATA(cmsg)) = fd;
-
-       if (sendmsg(serv->client.fd, &msg, 0) < 0) {
+       if (!send_lease_fd(client->socket.fd, fd)) {
                DEBUG_LOG("sendmsg failed on %s: %s\n", serv->address.sun_path,
                          strerror(errno));
                return false;
        }
 
-       INFO_LOG("Lease request granted on %s\n", serv->address.sun_path);
+       if (fd > 0)
+               INFO_LOG("Lease request granted on %s\n",
+                        serv->address.sun_path);
+
        return true;
 }
 
-void ls_disconnect_client(struct ls *ls, struct ls_server *serv)
+void ls_disconnect_client(struct ls *ls, struct ls_client *client)
 {
        assert(ls);
-       assert(serv);
-       if (!serv->is_client_connected)
+       assert(client);
+
+       if (!client->is_connected)
                return;
 
-       epoll_ctl(ls->epoll_fd, EPOLL_CTL_DEL, serv->client.fd, NULL);
-       close(serv->client.fd);
-       serv->is_client_connected = false;
+       epoll_ctl(ls->epoll_fd, EPOLL_CTL_DEL, client->socket.fd, NULL);
+       close(client->socket.fd);
+       client->is_connected = false;
 }