lease-server: Allow multiple client connections
[src/drm-lease-manager.git] / drm-lease-manager / lease-server.c
index e05e8e4..c57316e 100644 (file)
 
 #define SOCK_LOCK_SUFFIX ".lock"
 
+/* ACTIVE_CLIENTS
+ * An 'active' client is one that either
+ *  - owns a lease, or
+ *  - is requesting ownership of a lease (which will
+ *    disconnect the current owner if granted)
+ *
+ * There can only be at most one of each kind of client at the same
+ * time. Any other client connections are queued in the
+ * listen() backlog, waiting to be accept()'ed.
+ */
+#define ACTIVE_CLIENTS 2
+
 struct ls_socket {
        int fd;
+       bool is_server;
+       union {
+               struct ls_server *server;
+               struct ls_client *client;
+       };
+};
+
+struct ls_client {
+       struct ls_socket socket;
        struct ls_server *serv;
+       bool is_connected;
 };
 
 struct ls_server {
@@ -43,9 +65,7 @@ struct ls_server {
        int server_socket_lock;
 
        struct ls_socket listen;
-       struct ls_socket client;
-
-       bool is_client_connected;
+       struct ls_client clients[ACTIVE_CLIENTS];
 };
 
 struct ls {
@@ -55,37 +75,42 @@ struct ls {
        int nservers;
 };
 
-static bool client_connect(struct ls *ls, struct ls_server *serv)
+static struct ls_client *client_connect(struct ls *ls, struct ls_server *serv)
 {
        int cfd = accept(serv->listen.fd, NULL, NULL);
        if (cfd < 0) {
                DEBUG_LOG("accept failed on %s: %s\n", serv->address.sun_path,
                          strerror(errno));
-               return false;
+               return NULL;
        }
 
-       if (serv->is_client_connected) {
-               WARN_LOG("Client already connected on %s\n",
-                        serv->address.sun_path);
+       struct ls_client *client = NULL;
+
+       for (int i = 0; i < ACTIVE_CLIENTS; i++) {
+               if (!serv->clients[i].is_connected) {
+                       client = &serv->clients[i];
+                       break;
+               }
+       }
+       if (!client) {
                close(cfd);
-               return false;
+               return NULL;
        }
 
-       serv->client.fd = cfd;
-       serv->client.serv = serv;
+       client->socket.fd = cfd;
 
        struct epoll_event ev = {
            .events = POLLHUP,
-           .data.ptr = &serv->client,
+           .data.ptr = &client->socket,
        };
        if (epoll_ctl(ls->epoll_fd, EPOLL_CTL_ADD, cfd, &ev)) {
                DEBUG_LOG("epoll_ctl add failed: %s\n", strerror(errno));
                close(cfd);
-               return false;
+               return NULL;
        }
 
-       serv->is_client_connected = true;
-       return true;
+       client->is_connected = true;
+       return client;
 }
 
 static int create_socket_lock(struct sockaddr_un *addr)
@@ -161,11 +186,18 @@ static bool server_setup(struct ls *ls, struct ls_server *serv,
                return false;
        }
 
-       serv->is_client_connected = false;
+       for (int i = 0; i < ACTIVE_CLIENTS; i++) {
+               struct ls_client *client = &serv->clients[i];
+               client->serv = serv;
+               client->socket.client = client;
+       }
+
        serv->lease_handle = lease_handle;
        serv->server_socket_lock = socket_lock;
+
        serv->listen.fd = server_socket;
-       serv->listen.serv = serv;
+       serv->listen.server = serv;
+       serv->listen.is_server = true;
 
        struct epoll_event ev = {
            .events = POLLIN,
@@ -193,7 +225,10 @@ static void server_shutdown(struct ls *ls, struct ls_server *serv)
 
        epoll_ctl(ls->epoll_fd, EPOLL_CTL_DEL, serv->listen.fd, NULL);
        close(serv->listen.fd);
-       ls_disconnect_client(ls, serv);
+
+       for (int i = 0; i < ACTIVE_CLIENTS; i++)
+               ls_disconnect_client(ls, &serv->clients[i]);
+
        close(serv->server_socket_lock);
 }
 
@@ -261,32 +296,37 @@ bool ls_get_request(struct ls *ls, struct ls_req *req)
                struct ls_socket *sock = ev.data.ptr;
                assert(sock);
 
-               struct ls_server *server = sock->serv;
-               req->lease_handle = server->lease_handle;
-               req->server = server;
+               struct ls_server *server;
+               struct ls_client *client;
 
-               if (sock == &server->listen) {
+               if (sock->is_server) {
                        if (!(ev.events & POLLIN))
                                continue;
-                       if (client_connect(ls, server))
+
+                       server = sock->server;
+                       client = client_connect(ls, server);
+                       if (client)
                                request = LS_REQ_GET_LEASE;
-               } else if (sock == &server->client) {
+               } else {
                        if (!(ev.events & POLLHUP))
                                continue;
+
+                       client = sock->client;
+                       server = client->serv;
                        request = LS_REQ_RELEASE_LEASE;
-               } else {
-                       ERROR_LOG("Internal error: Invalid socket context\n");
-                       return false;
                }
+
+               req->lease_handle = server->lease_handle;
+               req->client = client;
+               req->type = request;
        }
-       req->type = request;
        return true;
 }
 
-bool ls_send_fd(struct ls *ls, struct ls_server *serv, int fd)
+bool ls_send_fd(struct ls *ls, struct ls_client *client, int fd)
 {
        assert(ls);
-       assert(serv);
+       assert(client);
 
        if (fd < 0)
                return false;
@@ -312,7 +352,9 @@ bool ls_send_fd(struct ls *ls, struct ls_server *serv, int fd)
        cmsg->cmsg_len = CMSG_LEN(sizeof(int));
        *((int *)CMSG_DATA(cmsg)) = fd;
 
-       if (sendmsg(serv->client.fd, &msg, 0) < 0) {
+       struct ls_server *serv = client->serv;
+
+       if (sendmsg(client->socket.fd, &msg, 0) < 0) {
                DEBUG_LOG("sendmsg failed on %s: %s\n", serv->address.sun_path,
                          strerror(errno));
                return false;
@@ -322,14 +364,15 @@ bool ls_send_fd(struct ls *ls, struct ls_server *serv, int fd)
        return true;
 }
 
-void ls_disconnect_client(struct ls *ls, struct ls_server *serv)
+void ls_disconnect_client(struct ls *ls, struct ls_client *client)
 {
        assert(ls);
-       assert(serv);
-       if (!serv->is_client_connected)
+       assert(client);
+
+       if (!client->is_connected)
                return;
 
-       epoll_ctl(ls->epoll_fd, EPOLL_CTL_DEL, serv->client.fd, NULL);
-       close(serv->client.fd);
-       serv->is_client_connected = false;
+       epoll_ctl(ls->epoll_fd, EPOLL_CTL_DEL, client->socket.fd, NULL);
+       close(client->socket.fd);
+       client->is_connected = false;
 }