Set SUB_VERSION as current commit ID.
[src/xds/xds-agent.git] / lib / common / httpclient.go
1 package common
2
3 import (
4         "bytes"
5         "crypto/tls"
6         "encoding/json"
7         "errors"
8         "fmt"
9         "io/ioutil"
10         "net/http"
11         "strings"
12
13         "github.com/Sirupsen/logrus"
14 )
15
16 type HTTPClient struct {
17         httpClient http.Client
18         endpoint   string
19         apikey     string
20         username   string
21         password   string
22         id         string
23         csrf       string
24         conf       HTTPClientConfig
25         logger     *logrus.Logger
26 }
27
28 type HTTPClientConfig struct {
29         URLPrefix           string
30         HeaderAPIKeyName    string
31         Apikey              string
32         HeaderClientKeyName string
33         CsrfDisable         bool
34 }
35
36 const (
37         logError   = 1
38         logWarning = 2
39         logInfo    = 3
40         logDebug   = 4
41 )
42
43 // Inspired by syncthing/cmd/cli
44
45 const insecure = false
46
47 // HTTPNewClient creates a new HTTP client to deal with Syncthing
48 func HTTPNewClient(baseURL string, cfg HTTPClientConfig) (*HTTPClient, error) {
49
50         // Create w new Http client
51         httpClient := http.Client{
52                 Transport: &http.Transport{
53                         TLSClientConfig: &tls.Config{
54                                 InsecureSkipVerify: insecure,
55                         },
56                 },
57         }
58         client := HTTPClient{
59                 httpClient: httpClient,
60                 endpoint:   baseURL,
61                 apikey:     cfg.Apikey,
62                 conf:       cfg,
63                 /* TODO - add user + pwd support
64                 username:   c.GlobalString("username"),
65                 password:   c.GlobalString("password"),
66                 */
67         }
68
69         if client.apikey == "" {
70                 if err := client.getCidAndCsrf(); err != nil {
71                         return nil, err
72                 }
73         }
74         return &client, nil
75 }
76
77 // SetLogger Define the logger to use
78 func (c *HTTPClient) SetLogger(log *logrus.Logger) {
79         c.logger = log
80 }
81
82 func (c *HTTPClient) log(level int, format string, args ...interface{}) {
83         if c.logger != nil {
84                 switch level {
85                 case logError:
86                         c.logger.Errorf(format, args...)
87                         break
88                 case logWarning:
89                         c.logger.Warningf(format, args...)
90                         break
91                 case logInfo:
92                         c.logger.Infof(format, args...)
93                         break
94                 default:
95                         c.logger.Debugf(format, args...)
96                         break
97                 }
98         }
99 }
100
101 // Send request to retrieve Client id and/or CSRF token
102 func (c *HTTPClient) getCidAndCsrf() error {
103         request, err := http.NewRequest("GET", c.endpoint, nil)
104         if err != nil {
105                 return err
106         }
107         if _, err := c.handleRequest(request); err != nil {
108                 return err
109         }
110         if c.id == "" {
111                 return errors.New("Failed to get device ID")
112         }
113         if !c.conf.CsrfDisable && c.csrf == "" {
114                 return errors.New("Failed to get CSRF token")
115         }
116         return nil
117 }
118
119 // GetClientID returns the id
120 func (c *HTTPClient) GetClientID() string {
121         return c.id
122 }
123
124 // formatURL Build full url by concatenating all parts
125 func (c *HTTPClient) formatURL(endURL string) string {
126         url := c.endpoint
127         if !strings.HasSuffix(url, "/") {
128                 url += "/"
129         }
130         url += strings.TrimLeft(c.conf.URLPrefix, "/")
131         if !strings.HasSuffix(url, "/") {
132                 url += "/"
133         }
134         return url + strings.TrimLeft(endURL, "/")
135 }
136
137 // HTTPGet Send a Get request to client and return an error object
138 func (c *HTTPClient) HTTPGet(url string, data *[]byte) error {
139         _, err := c.HTTPGetWithRes(url, data)
140         return err
141 }
142
143 // HTTPGetWithRes Send a Get request to client and return both response and error
144 func (c *HTTPClient) HTTPGetWithRes(url string, data *[]byte) (*http.Response, error) {
145         request, err := http.NewRequest("GET", c.formatURL(url), nil)
146         if err != nil {
147                 return nil, err
148         }
149         res, err := c.handleRequest(request)
150         if err != nil {
151                 return res, err
152         }
153         if res.StatusCode != 200 {
154                 return res, errors.New(res.Status)
155         }
156
157         *data = c.responseToBArray(res)
158
159         return res, nil
160 }
161
162 // HTTPPost Send a POST request to client and return an error object
163 func (c *HTTPClient) HTTPPost(url string, body string) error {
164         _, err := c.HTTPPostWithRes(url, body)
165         return err
166 }
167
168 // HTTPPostWithRes Send a POST request to client and return both response and error
169 func (c *HTTPClient) HTTPPostWithRes(url string, body string) (*http.Response, error) {
170         request, err := http.NewRequest("POST", c.formatURL(url), bytes.NewBufferString(body))
171         if err != nil {
172                 return nil, err
173         }
174         res, err := c.handleRequest(request)
175         if err != nil {
176                 return res, err
177         }
178         if res.StatusCode != 200 {
179                 return res, errors.New(res.Status)
180         }
181         return res, nil
182 }
183
184 func (c *HTTPClient) responseToBArray(response *http.Response) []byte {
185         defer response.Body.Close()
186         bytes, err := ioutil.ReadAll(response.Body)
187         if err != nil {
188                 // TODO improved error reporting
189                 fmt.Println("ERROR: " + err.Error())
190         }
191         return bytes
192 }
193
194 func (c *HTTPClient) handleRequest(request *http.Request) (*http.Response, error) {
195         if c.conf.HeaderAPIKeyName != "" && c.apikey != "" {
196                 request.Header.Set(c.conf.HeaderAPIKeyName, c.apikey)
197         }
198         if c.conf.HeaderClientKeyName != "" && c.id != "" {
199                 request.Header.Set(c.conf.HeaderClientKeyName, c.id)
200         }
201         if c.username != "" || c.password != "" {
202                 request.SetBasicAuth(c.username, c.password)
203         }
204         if c.csrf != "" {
205                 request.Header.Set("X-CSRF-Token-"+c.id[:5], c.csrf)
206         }
207
208         c.log(logDebug, "HTTP %s %v", request.Method, request.URL)
209
210         response, err := c.httpClient.Do(request)
211         if err != nil {
212                 return nil, err
213         }
214
215         // Detect client ID change
216         cid := response.Header.Get(c.conf.HeaderClientKeyName)
217         if cid != "" && c.id != cid {
218                 c.id = cid
219         }
220
221         // Detect CSR token change
222         for _, item := range response.Cookies() {
223                 if item.Name == "CSRF-Token-"+c.id[:5] {
224                         c.csrf = item.Value
225                         goto csrffound
226                 }
227         }
228         // OK CSRF found
229 csrffound:
230
231         if response.StatusCode == 404 {
232                 return nil, errors.New("Invalid endpoint or API call")
233         } else if response.StatusCode == 401 {
234                 return nil, errors.New("Invalid username or password")
235         } else if response.StatusCode == 403 {
236                 if c.apikey == "" {
237                         // Request a new Csrf for next requests
238                         c.getCidAndCsrf()
239                         return nil, errors.New("Invalid CSRF token")
240                 }
241                 return nil, errors.New("Invalid API key")
242         } else if response.StatusCode != 200 {
243                 data := make(map[string]interface{})
244                 // Try to decode error field of APIError struct
245                 json.Unmarshal(c.responseToBArray(response), &data)
246                 if err, found := data["error"]; found {
247                         return nil, fmt.Errorf(err.(string))
248                 } else {
249                         body := strings.TrimSpace(string(c.responseToBArray(response)))
250                         if body != "" {
251                                 return nil, fmt.Errorf(body)
252                         }
253                 }
254                 return nil, errors.New("Unknown HTTP status returned: " + response.Status)
255         }
256         return response, nil
257 }