60b7b8a1bf17fa986399ad58abc371aa072393bd
[src/xds/xds-server.git] / lib / session / session.go
1 package session
2
3 import (
4         "encoding/base64"
5         "strconv"
6         "time"
7
8         "github.com/Sirupsen/logrus"
9         "github.com/gin-gonic/gin"
10         "github.com/googollee/go-socket.io"
11         uuid "github.com/satori/go.uuid"
12         "github.com/syncthing/syncthing/lib/sync"
13 )
14
15 const sessionCookieName = "xds-sid"
16 const sessionHeaderName = "XDS-SID"
17
18 const sessionMonitorTime = 10 // Time (in seconds) to schedule monitoring session tasks
19
20 const initSessionMaxAge = 10 // Initial session max age in seconds
21 const maxSessions = 100000   // Maximum number of sessions in sessMap map
22
23 const secureCookie = false // TODO: see https://github.com/astaxie/beego/blob/master/session/session.go#L218
24
25 // ClientSession contains the info of a user/client session
26 type ClientSession struct {
27         ID       string
28         WSID     string // only one WebSocket per client/session
29         MaxAge   int64
30         IOSocket *socketio.Socket
31
32         // private
33         expireAt time.Time
34         useCount int64
35 }
36
37 // Sessions holds client sessions
38 type Sessions struct {
39         router        *gin.Engine
40         cookieMaxAge  int64
41         sessMap       map[string]ClientSession
42         mutex         sync.Mutex
43         log           *logrus.Logger
44         LogLevelSilly bool
45         stop          chan struct{} // signals intentional stop
46 }
47
48 // NewClientSessions .
49 func NewClientSessions(router *gin.Engine, log *logrus.Logger, cookieMaxAge string, sillyLog bool) *Sessions {
50         ckMaxAge, err := strconv.ParseInt(cookieMaxAge, 10, 0)
51         if err != nil {
52                 ckMaxAge = 0
53         }
54         s := Sessions{
55                 router:        router,
56                 cookieMaxAge:  ckMaxAge,
57                 sessMap:       make(map[string]ClientSession),
58                 mutex:         sync.NewMutex(),
59                 log:           log,
60                 LogLevelSilly: sillyLog,
61                 stop:          make(chan struct{}),
62         }
63         s.router.Use(s.Middleware())
64
65         // Start monitoring of sessions Map (use to manage expiration and cleanup)
66         go s.monitorSessMap()
67
68         return &s
69 }
70
71 // Stop sessions management
72 func (s *Sessions) Stop() {
73         close(s.stop)
74 }
75
76 // Middleware is used to managed session
77 func (s *Sessions) Middleware() gin.HandlerFunc {
78         return func(c *gin.Context) {
79                 // FIXME Add CSRF management
80
81                 // Get session
82                 sess := s.Get(c)
83                 if sess == nil {
84                         // Allocate a new session key and put in cookie
85                         sess = s.newSession("")
86                 } else {
87                         s.refresh(sess.ID)
88                 }
89
90                 // Set session in cookie and in header
91                 // Do not set Domain to localhost (http://stackoverflow.com/questions/1134290/cookies-on-localhost-with-explicit-domain)
92                 c.SetCookie(sessionCookieName, sess.ID, int(sess.MaxAge), "/", "",
93                         secureCookie, false)
94                 c.Header(sessionHeaderName, sess.ID)
95
96                 // Save session id in gin metadata
97                 c.Set(sessionCookieName, sess.ID)
98
99                 c.Next()
100         }
101 }
102
103 // Get returns the client session for a specific ID
104 func (s *Sessions) Get(c *gin.Context) *ClientSession {
105         var sid string
106
107         // First get from gin metadata
108         v, exist := c.Get(sessionCookieName)
109         if v != nil {
110                 sid = v.(string)
111         }
112
113         // Then look in cookie
114         if !exist || sid == "" {
115                 sid, _ = c.Cookie(sessionCookieName)
116         }
117
118         // Then look in Header
119         if sid == "" {
120                 sid = c.Request.Header.Get(sessionCookieName)
121         }
122         if sid != "" {
123                 s.mutex.Lock()
124                 defer s.mutex.Unlock()
125                 if key, ok := s.sessMap[sid]; ok {
126                         // TODO: return a copy ???
127                         return &key
128                 }
129         }
130         return nil
131 }
132
133 // IOSocketGet Get socketio definition from sid
134 func (s *Sessions) IOSocketGet(sid string) *socketio.Socket {
135         s.mutex.Lock()
136         defer s.mutex.Unlock()
137         sess, ok := s.sessMap[sid]
138         if ok {
139                 return sess.IOSocket
140         }
141         return nil
142 }
143
144 // UpdateIOSocket updates the IO Socket definition for of a session
145 func (s *Sessions) UpdateIOSocket(sid string, so *socketio.Socket) error {
146         s.mutex.Lock()
147         defer s.mutex.Unlock()
148         if _, ok := s.sessMap[sid]; ok {
149                 sess := s.sessMap[sid]
150                 if so == nil {
151                         // Could be the case when socketio is closed/disconnected
152                         sess.WSID = ""
153                 } else {
154                         sess.WSID = (*so).Id()
155                 }
156                 sess.IOSocket = so
157                 s.sessMap[sid] = sess
158         }
159         return nil
160 }
161
162 // nesSession Allocate a new client session
163 func (s *Sessions) newSession(prefix string) *ClientSession {
164         uuid := prefix + uuid.NewV4().String()
165         id := base64.URLEncoding.EncodeToString([]byte(uuid))
166         se := ClientSession{
167                 ID:       id,
168                 WSID:     "",
169                 MaxAge:   initSessionMaxAge,
170                 IOSocket: nil,
171                 expireAt: time.Now().Add(time.Duration(initSessionMaxAge) * time.Second),
172                 useCount: 0,
173         }
174         s.mutex.Lock()
175         defer s.mutex.Unlock()
176
177         s.sessMap[se.ID] = se
178
179         s.log.Debugf("NEW session (%d): %s", len(s.sessMap), id)
180         return &se
181 }
182
183 // refresh Move this session ID to the head of the list
184 func (s *Sessions) refresh(sid string) {
185         s.mutex.Lock()
186         defer s.mutex.Unlock()
187
188         sess := s.sessMap[sid]
189         sess.useCount++
190         if sess.MaxAge < s.cookieMaxAge && sess.useCount > 1 {
191                 sess.MaxAge = s.cookieMaxAge
192                 sess.expireAt = time.Now().Add(time.Duration(sess.MaxAge) * time.Second)
193         }
194
195         // TODO - Add flood detection (like limit_req of nginx)
196         // (delayed request when to much requests in a short period of time)
197
198         s.sessMap[sid] = sess
199 }
200
201 func (s *Sessions) monitorSessMap() {
202         for {
203                 select {
204                 case <-s.stop:
205                         s.log.Debugln("Stop monitorSessMap")
206                         return
207                 case <-time.After(sessionMonitorTime * time.Second):
208                         if s.LogLevelSilly {
209                                 s.log.Debugf("Sessions Map size: %d", len(s.sessMap))
210                                 s.log.Debugf("Sessions Map : %v", s.sessMap)
211                         }
212
213                         if len(s.sessMap) > maxSessions {
214                                 s.log.Errorln("TOO MUCH sessions, cleanup old ones !")
215                         }
216
217                         s.mutex.Lock()
218                         for _, ss := range s.sessMap {
219                                 if ss.expireAt.Sub(time.Now()) < 0 {
220                                         s.log.Debugf("Delete expired session id: %s", ss.ID)
221                                         delete(s.sessMap, ss.ID)
222                                 }
223                         }
224                         s.mutex.Unlock()
225                 }
226         }
227 }