d02ea3ce410bf20d5ffd35f8c2a109003fe8b857
[src/xds/xds-agent.git] / lib / agent / sessions.go
1 /*
2  * Copyright (C) 2017 "IoT.bzh"
3  * Author Sebastien Douheret <sebastien@iot.bzh>
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *   http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package agent
19
20 import (
21         "encoding/base64"
22         "strconv"
23         "time"
24
25         "github.com/gin-gonic/gin"
26         "github.com/googollee/go-socket.io"
27         uuid "github.com/satori/go.uuid"
28         "github.com/syncthing/syncthing/lib/sync"
29 )
30
31 const sessionCookieName = "xds-agent-sid"
32 const sessionHeaderName = "XDS-AGENT-SID"
33
34 const sessionMonitorTime = 10 // Time (in seconds) to schedule monitoring session tasks
35
36 const initSessionMaxAge = 10 // Initial session max age in seconds
37 const maxSessions = 100000   // Maximum number of sessions in sessMap map
38
39 const secureCookie = false // TODO: see https://github.com/astaxie/beego/blob/master/session/session.go#L218
40
41 // ClientSession contains the info of a user/client session
42 type ClientSession struct {
43         ID       string
44         WSID     string // only one WebSocket per client/session
45         MaxAge   int64
46         IOSocket *socketio.Socket
47
48         // private
49         expireAt time.Time
50         useCount int64
51 }
52
53 // Sessions holds client sessions
54 type Sessions struct {
55         *Context
56         cookieMaxAge int64
57         sessMap      map[string]ClientSession
58         mutex        sync.Mutex
59         stop         chan struct{} // signals intentional stop
60 }
61
62 // NewClientSessions .
63 func NewClientSessions(ctx *Context, cookieMaxAge string) *Sessions {
64         ckMaxAge, err := strconv.ParseInt(cookieMaxAge, 10, 0)
65         if err != nil {
66                 ckMaxAge = 0
67         }
68         s := Sessions{
69                 Context:      ctx,
70                 cookieMaxAge: ckMaxAge,
71                 sessMap:      make(map[string]ClientSession),
72                 mutex:        sync.NewMutex(),
73                 stop:         make(chan struct{}),
74         }
75         s.webServer.router.Use(s.Middleware())
76
77         // Start monitoring of sessions Map (use to manage expiration and cleanup)
78         go s.monitorSessMap()
79
80         return &s
81 }
82
83 // Stop sessions management
84 func (s *Sessions) Stop() {
85         close(s.stop)
86 }
87
88 // Middleware is used to managed session
89 func (s *Sessions) Middleware() gin.HandlerFunc {
90         return func(c *gin.Context) {
91                 // FIXME Add CSRF management
92
93                 // Get session
94                 sess := s.Get(c)
95                 if sess == nil {
96                         // Allocate a new session key and put in cookie
97                         sess = s.newSession("")
98                 } else {
99                         s.refresh(sess.ID)
100                 }
101
102                 // Set session in cookie and in header
103                 // Do not set Domain to localhost (http://stackoverflow.com/questions/1134290/cookies-on-localhost-with-explicit-domain)
104                 c.SetCookie(sessionCookieName, sess.ID, int(sess.MaxAge), "/", "",
105                         secureCookie, false)
106                 c.Header(sessionHeaderName, sess.ID)
107
108                 // Save session id in gin metadata
109                 c.Set(sessionCookieName, sess.ID)
110
111                 c.Next()
112         }
113 }
114
115 // Get returns the client session for a specific ID
116 func (s *Sessions) Get(c *gin.Context) *ClientSession {
117         var sid string
118
119         // First get from gin metadata
120         v, exist := c.Get(sessionCookieName)
121         if v != nil {
122                 sid = v.(string)
123         }
124
125         // Then look in cookie
126         if !exist || sid == "" {
127                 sid, _ = c.Cookie(sessionCookieName)
128         }
129
130         // Then look in Header
131         if sid == "" {
132                 sid = c.Request.Header.Get(sessionCookieName)
133         }
134         if sid != "" {
135                 s.mutex.Lock()
136                 defer s.mutex.Unlock()
137                 if key, ok := s.sessMap[sid]; ok {
138                         // TODO: return a copy ???
139                         return &key
140                 }
141         }
142         return nil
143 }
144
145 // GetID returns the session or an empty string
146 func (s *Sessions) GetID(c *gin.Context) string {
147         if sess := s.Get(c); sess != nil {
148                 return sess.ID
149         }
150         return ""
151 }
152
153 // IOSocketGet Get socketio definition from sid
154 func (s *Sessions) IOSocketGet(sid string) *socketio.Socket {
155         s.mutex.Lock()
156         defer s.mutex.Unlock()
157         sess, ok := s.sessMap[sid]
158         if ok {
159                 return sess.IOSocket
160         }
161         return nil
162 }
163
164 // UpdateIOSocket updates the IO Socket definition for of a session
165 func (s *Sessions) UpdateIOSocket(sid string, so *socketio.Socket) error {
166         s.mutex.Lock()
167         defer s.mutex.Unlock()
168         if _, ok := s.sessMap[sid]; ok {
169                 sess := s.sessMap[sid]
170                 if so == nil {
171                         // Could be the case when socketio is closed/disconnected
172                         sess.WSID = ""
173                 } else {
174                         sess.WSID = (*so).Id()
175                 }
176                 sess.IOSocket = so
177                 s.sessMap[sid] = sess
178         }
179         return nil
180 }
181
182 // newSession Allocate a new client session
183 func (s *Sessions) newSession(prefix string) *ClientSession {
184         uuid := prefix + uuid.NewV4().String()
185         id := base64.URLEncoding.EncodeToString([]byte(uuid))
186         se := ClientSession{
187                 ID:       id,
188                 WSID:     "",
189                 MaxAge:   initSessionMaxAge,
190                 IOSocket: nil,
191                 expireAt: time.Now().Add(time.Duration(initSessionMaxAge) * time.Second),
192                 useCount: 0,
193         }
194         s.mutex.Lock()
195         defer s.mutex.Unlock()
196
197         s.sessMap[se.ID] = se
198
199         s.Log.Debugf("NEW session (%d): %s", len(s.sessMap), id)
200         return &se
201 }
202
203 // refresh Move this session ID to the head of the list
204 func (s *Sessions) refresh(sid string) {
205         s.mutex.Lock()
206         defer s.mutex.Unlock()
207
208         sess := s.sessMap[sid]
209         sess.useCount++
210         if sess.MaxAge < s.cookieMaxAge && sess.useCount > 1 {
211                 sess.MaxAge = s.cookieMaxAge
212                 sess.expireAt = time.Now().Add(time.Duration(sess.MaxAge) * time.Second)
213         }
214
215         // TODO - Add flood detection (like limit_req of nginx)
216         // (delayed request when to much requests in a short period of time)
217
218         s.sessMap[sid] = sess
219 }
220
221 func (s *Sessions) monitorSessMap() {
222         for {
223                 select {
224                 case <-s.stop:
225                         s.Log.Debugln("Stop monitorSessMap")
226                         return
227                 case <-time.After(sessionMonitorTime * time.Second):
228                         if s.LogLevelSilly {
229                                 s.Log.Debugf("Sessions Map size: %d", len(s.sessMap))
230                                 s.Log.Debugf("Sessions Map : %v", s.sessMap)
231                         }
232
233                         if len(s.sessMap) > maxSessions {
234                                 s.Log.Errorln("TOO MUCH sessions, cleanup old ones !")
235                         }
236
237                         s.mutex.Lock()
238                         for _, ss := range s.sessMap {
239                                 if ss.expireAt.Sub(time.Now()) < 0 {
240                                         if s.LogLevelSilly {
241                                                 s.Log.Debugf("Delete expired session id: %s", ss.ID)
242                                         }
243                                         delete(s.sessMap, ss.ID)
244                                 }
245                         }
246                         s.mutex.Unlock()
247                 }
248         }
249 }