Switch from codegangsta/cli to urfave/cli package.
[src/xds/xds-agent.git] / lib / agent / sessions.go
1 package agent
2
3 import (
4         "encoding/base64"
5         "strconv"
6         "time"
7
8         "github.com/gin-gonic/gin"
9         "github.com/googollee/go-socket.io"
10         uuid "github.com/satori/go.uuid"
11         "github.com/syncthing/syncthing/lib/sync"
12 )
13
14 const sessionCookieName = "xds-agent-sid"
15 const sessionHeaderName = "XDS-AGENT-SID"
16
17 const sessionMonitorTime = 10 // Time (in seconds) to schedule monitoring session tasks
18
19 const initSessionMaxAge = 10 // Initial session max age in seconds
20 const maxSessions = 100000   // Maximum number of sessions in sessMap map
21
22 const secureCookie = false // TODO: see https://github.com/astaxie/beego/blob/master/session/session.go#L218
23
24 // ClientSession contains the info of a user/client session
25 type ClientSession struct {
26         ID       string
27         WSID     string // only one WebSocket per client/session
28         MaxAge   int64
29         IOSocket *socketio.Socket
30
31         // private
32         expireAt time.Time
33         useCount int64
34 }
35
36 // Sessions holds client sessions
37 type Sessions struct {
38         *Context
39         cookieMaxAge int64
40         sessMap      map[string]ClientSession
41         mutex        sync.Mutex
42         stop         chan struct{} // signals intentional stop
43 }
44
45 // NewClientSessions .
46 func NewClientSessions(ctx *Context, cookieMaxAge string) *Sessions {
47         ckMaxAge, err := strconv.ParseInt(cookieMaxAge, 10, 0)
48         if err != nil {
49                 ckMaxAge = 0
50         }
51         s := Sessions{
52                 Context:      ctx,
53                 cookieMaxAge: ckMaxAge,
54                 sessMap:      make(map[string]ClientSession),
55                 mutex:        sync.NewMutex(),
56                 stop:         make(chan struct{}),
57         }
58         s.webServer.router.Use(s.Middleware())
59
60         // Start monitoring of sessions Map (use to manage expiration and cleanup)
61         go s.monitorSessMap()
62
63         return &s
64 }
65
66 // Stop sessions management
67 func (s *Sessions) Stop() {
68         close(s.stop)
69 }
70
71 // Middleware is used to managed session
72 func (s *Sessions) Middleware() gin.HandlerFunc {
73         return func(c *gin.Context) {
74                 // FIXME Add CSRF management
75
76                 // Get session
77                 sess := s.Get(c)
78                 if sess == nil {
79                         // Allocate a new session key and put in cookie
80                         sess = s.newSession("")
81                 } else {
82                         s.refresh(sess.ID)
83                 }
84
85                 // Set session in cookie and in header
86                 // Do not set Domain to localhost (http://stackoverflow.com/questions/1134290/cookies-on-localhost-with-explicit-domain)
87                 c.SetCookie(sessionCookieName, sess.ID, int(sess.MaxAge), "/", "",
88                         secureCookie, false)
89                 c.Header(sessionHeaderName, sess.ID)
90
91                 // Save session id in gin metadata
92                 c.Set(sessionCookieName, sess.ID)
93
94                 c.Next()
95         }
96 }
97
98 // Get returns the client session for a specific ID
99 func (s *Sessions) Get(c *gin.Context) *ClientSession {
100         var sid string
101
102         // First get from gin metadata
103         v, exist := c.Get(sessionCookieName)
104         if v != nil {
105                 sid = v.(string)
106         }
107
108         // Then look in cookie
109         if !exist || sid == "" {
110                 sid, _ = c.Cookie(sessionCookieName)
111         }
112
113         // Then look in Header
114         if sid == "" {
115                 sid = c.Request.Header.Get(sessionCookieName)
116         }
117         if sid != "" {
118                 s.mutex.Lock()
119                 defer s.mutex.Unlock()
120                 if key, ok := s.sessMap[sid]; ok {
121                         // TODO: return a copy ???
122                         return &key
123                 }
124         }
125         return nil
126 }
127
128 // IOSocketGet Get socketio definition from sid
129 func (s *Sessions) IOSocketGet(sid string) *socketio.Socket {
130         s.mutex.Lock()
131         defer s.mutex.Unlock()
132         sess, ok := s.sessMap[sid]
133         if ok {
134                 return sess.IOSocket
135         }
136         return nil
137 }
138
139 // UpdateIOSocket updates the IO Socket definition for of a session
140 func (s *Sessions) UpdateIOSocket(sid string, so *socketio.Socket) error {
141         s.mutex.Lock()
142         defer s.mutex.Unlock()
143         if _, ok := s.sessMap[sid]; ok {
144                 sess := s.sessMap[sid]
145                 if so == nil {
146                         // Could be the case when socketio is closed/disconnected
147                         sess.WSID = ""
148                 } else {
149                         sess.WSID = (*so).Id()
150                 }
151                 sess.IOSocket = so
152                 s.sessMap[sid] = sess
153         }
154         return nil
155 }
156
157 // newSession Allocate a new client session
158 func (s *Sessions) newSession(prefix string) *ClientSession {
159         uuid := prefix + uuid.NewV4().String()
160         id := base64.URLEncoding.EncodeToString([]byte(uuid))
161         se := ClientSession{
162                 ID:       id,
163                 WSID:     "",
164                 MaxAge:   initSessionMaxAge,
165                 IOSocket: nil,
166                 expireAt: time.Now().Add(time.Duration(initSessionMaxAge) * time.Second),
167                 useCount: 0,
168         }
169         s.mutex.Lock()
170         defer s.mutex.Unlock()
171
172         s.sessMap[se.ID] = se
173
174         s.Log.Debugf("NEW session (%d): %s", len(s.sessMap), id)
175         return &se
176 }
177
178 // refresh Move this session ID to the head of the list
179 func (s *Sessions) refresh(sid string) {
180         s.mutex.Lock()
181         defer s.mutex.Unlock()
182
183         sess := s.sessMap[sid]
184         sess.useCount++
185         if sess.MaxAge < s.cookieMaxAge && sess.useCount > 1 {
186                 sess.MaxAge = s.cookieMaxAge
187                 sess.expireAt = time.Now().Add(time.Duration(sess.MaxAge) * time.Second)
188         }
189
190         // TODO - Add flood detection (like limit_req of nginx)
191         // (delayed request when to much requests in a short period of time)
192
193         s.sessMap[sid] = sess
194 }
195
196 func (s *Sessions) monitorSessMap() {
197         for {
198                 select {
199                 case <-s.stop:
200                         s.Log.Debugln("Stop monitorSessMap")
201                         return
202                 case <-time.After(sessionMonitorTime * time.Second):
203                         if s.LogLevelSilly {
204                                 s.Log.Debugf("Sessions Map size: %d", len(s.sessMap))
205                                 s.Log.Debugf("Sessions Map : %v", s.sessMap)
206                         }
207
208                         if len(s.sessMap) > maxSessions {
209                                 s.Log.Errorln("TOO MUCH sessions, cleanup old ones !")
210                         }
211
212                         s.mutex.Lock()
213                         for _, ss := range s.sessMap {
214                                 if ss.expireAt.Sub(time.Now()) < 0 {
215                                         if s.LogLevelSilly {
216                                                 s.Log.Debugf("Delete expired session id: %s", ss.ID)
217                                         }
218                                         delete(s.sessMap, ss.ID)
219                                 }
220                         }
221                         s.mutex.Unlock()
222                 }
223         }
224 }