b56f9ffc80e5282b30fada1c6c4d664ecce7bfd6
[src/xds/xds-agent.git] / lib / session / session.go
1 package session
2
3 import (
4         "encoding/base64"
5         "strconv"
6         "time"
7
8         "github.com/Sirupsen/logrus"
9         "github.com/gin-gonic/gin"
10         "github.com/googollee/go-socket.io"
11         uuid "github.com/satori/go.uuid"
12         "github.com/syncthing/syncthing/lib/sync"
13 )
14
15 const sessionCookieName = "xds-agent-sid"
16 const sessionHeaderName = "XDS-AGENT-SID"
17
18 const sessionMonitorTime = 10 // Time (in seconds) to schedule monitoring session tasks
19
20 const initSessionMaxAge = 10 // Initial session max age in seconds
21 const maxSessions = 100000   // Maximum number of sessions in sessMap map
22
23 const secureCookie = false // TODO: see https://github.com/astaxie/beego/blob/master/session/session.go#L218
24
25 // ClientSession contains the info of a user/client session
26 type ClientSession struct {
27         ID       string
28         WSID     string // only one WebSocket per client/session
29         MaxAge   int64
30         IOSocket *socketio.Socket
31
32         // private
33         expireAt time.Time
34         useCount int64
35 }
36
37 // Sessions holds client sessions
38 type Sessions struct {
39         router       *gin.Engine
40         cookieMaxAge int64
41         sessMap      map[string]ClientSession
42         mutex        sync.Mutex
43         log          *logrus.Logger
44         stop         chan struct{} // signals intentional stop
45 }
46
47 // NewClientSessions .
48 func NewClientSessions(router *gin.Engine, log *logrus.Logger, cookieMaxAge string) *Sessions {
49         ckMaxAge, err := strconv.ParseInt(cookieMaxAge, 10, 0)
50         if err != nil {
51                 ckMaxAge = 0
52         }
53         s := Sessions{
54                 router:       router,
55                 cookieMaxAge: ckMaxAge,
56                 sessMap:      make(map[string]ClientSession),
57                 mutex:        sync.NewMutex(),
58                 log:          log,
59                 stop:         make(chan struct{}),
60         }
61         s.router.Use(s.Middleware())
62
63         // Start monitoring of sessions Map (use to manage expiration and cleanup)
64         go s.monitorSessMap()
65
66         return &s
67 }
68
69 // Stop sessions management
70 func (s *Sessions) Stop() {
71         close(s.stop)
72 }
73
74 // Middleware is used to managed session
75 func (s *Sessions) Middleware() gin.HandlerFunc {
76         return func(c *gin.Context) {
77                 // FIXME Add CSRF management
78
79                 // Get session
80                 sess := s.Get(c)
81                 if sess == nil {
82                         // Allocate a new session key and put in cookie
83                         sess = s.newSession("")
84                 } else {
85                         s.refresh(sess.ID)
86                 }
87
88                 // Set session in cookie and in header
89                 // Do not set Domain to localhost (http://stackoverflow.com/questions/1134290/cookies-on-localhost-with-explicit-domain)
90                 c.SetCookie(sessionCookieName, sess.ID, int(sess.MaxAge), "/", "",
91                         secureCookie, false)
92                 c.Header(sessionHeaderName, sess.ID)
93
94                 // Save session id in gin metadata
95                 c.Set(sessionCookieName, sess.ID)
96
97                 c.Next()
98         }
99 }
100
101 // Get returns the client session for a specific ID
102 func (s *Sessions) Get(c *gin.Context) *ClientSession {
103         var sid string
104
105         // First get from gin metadata
106         v, exist := c.Get(sessionCookieName)
107         if v != nil {
108                 sid = v.(string)
109         }
110
111         // Then look in cookie
112         if !exist || sid == "" {
113                 sid, _ = c.Cookie(sessionCookieName)
114         }
115
116         // Then look in Header
117         if sid == "" {
118                 sid = c.Request.Header.Get(sessionCookieName)
119         }
120         if sid != "" {
121                 s.mutex.Lock()
122                 defer s.mutex.Unlock()
123                 if key, ok := s.sessMap[sid]; ok {
124                         // TODO: return a copy ???
125                         return &key
126                 }
127         }
128         return nil
129 }
130
131 // IOSocketGet Get socketio definition from sid
132 func (s *Sessions) IOSocketGet(sid string) *socketio.Socket {
133         s.mutex.Lock()
134         defer s.mutex.Unlock()
135         sess, ok := s.sessMap[sid]
136         if ok {
137                 return sess.IOSocket
138         }
139         return nil
140 }
141
142 // UpdateIOSocket updates the IO Socket definition for of a session
143 func (s *Sessions) UpdateIOSocket(sid string, so *socketio.Socket) error {
144         s.mutex.Lock()
145         defer s.mutex.Unlock()
146         if _, ok := s.sessMap[sid]; ok {
147                 sess := s.sessMap[sid]
148                 if so == nil {
149                         // Could be the case when socketio is closed/disconnected
150                         sess.WSID = ""
151                 } else {
152                         sess.WSID = (*so).Id()
153                 }
154                 sess.IOSocket = so
155                 s.sessMap[sid] = sess
156         }
157         return nil
158 }
159
160 // nesSession Allocate a new client session
161 func (s *Sessions) newSession(prefix string) *ClientSession {
162         uuid := prefix + uuid.NewV4().String()
163         id := base64.URLEncoding.EncodeToString([]byte(uuid))
164         se := ClientSession{
165                 ID:       id,
166                 WSID:     "",
167                 MaxAge:   initSessionMaxAge,
168                 IOSocket: nil,
169                 expireAt: time.Now().Add(time.Duration(initSessionMaxAge) * time.Second),
170                 useCount: 0,
171         }
172         s.mutex.Lock()
173         defer s.mutex.Unlock()
174
175         s.sessMap[se.ID] = se
176
177         s.log.Debugf("NEW session (%d): %s", len(s.sessMap), id)
178         return &se
179 }
180
181 // refresh Move this session ID to the head of the list
182 func (s *Sessions) refresh(sid string) {
183         s.mutex.Lock()
184         defer s.mutex.Unlock()
185
186         sess := s.sessMap[sid]
187         sess.useCount++
188         if sess.MaxAge < s.cookieMaxAge && sess.useCount > 1 {
189                 sess.MaxAge = s.cookieMaxAge
190                 sess.expireAt = time.Now().Add(time.Duration(sess.MaxAge) * time.Second)
191         }
192
193         // TODO - Add flood detection (like limit_req of nginx)
194         // (delayed request when to much requests in a short period of time)
195
196         s.sessMap[sid] = sess
197 }
198
199 func (s *Sessions) monitorSessMap() {
200         const dbgFullTrace = false // for debugging
201
202         for {
203                 select {
204                 case <-s.stop:
205                         s.log.Debugln("Stop monitorSessMap")
206                         return
207                 case <-time.After(sessionMonitorTime * time.Second):
208                         if dbgFullTrace {
209                                 s.log.Debugf("Sessions Map size: %d", len(s.sessMap))
210                                 s.log.Debugf("Sessions Map : %v", s.sessMap)
211                         }
212
213                         if len(s.sessMap) > maxSessions {
214                                 s.log.Errorln("TOO MUCH sessions, cleanup old ones !")
215                         }
216
217                         s.mutex.Lock()
218                         for _, ss := range s.sessMap {
219                                 if ss.expireAt.Sub(time.Now()) < 0 {
220                                         s.log.Debugf("Delete expired session id: %s", ss.ID)
221                                         delete(s.sessMap, ss.ID)
222                                 }
223                         }
224                         s.mutex.Unlock()
225                 }
226         }
227 }