Commit 46046129 authored by Jason Williams's avatar Jason Williams
Browse files

Simplifies CSRF logic by appending auto- to our username for auth.

This eliminates all CSRF token setting and retrieval, so I removed all that code as well as the retry logic.
parent 94582d28
Loading
Loading
Loading
Loading
+0 −2
Original line number Diff line number Diff line
@@ -47,8 +47,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
                        throw new DevicePortalException(response);
                    }

                    this.RetrieveCsrfToken(response);

                    using (HttpContent content = response.Content)
                    {
                        dataStream = new MemoryStream();
+0 −91
Original line number Diff line number Diff line
@@ -31,11 +31,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// </summary>
        private static readonly string ContentTypeHeaderName = "Content-Type";

        /// <summary>
        /// Header name for a CSRF-Token.
        /// </summary>
        private static readonly string CsrfTokenName = "CSRF-Token";

        /// <summary>
        /// Header name for a User-Agent.
        /// </summary>
@@ -46,39 +41,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// </summary>
        private static readonly string UserAgentValue = "WindowsDevicePortalWrapper";

        /// <summary>
        /// CSRF token retrieved by GET calls and used on subsequent POST/DELETE/PUT calls.
        /// This token is intended to prevent a security vulnerability from cross site forgery.
        /// </summary>
        private string csrfToken = string.Empty;

        /// <summary>
        /// Applies the CSRF token to the HTTP client.
        /// </summary>
        /// <param name="client">The HTTP client on which to have the header set.</param>
        /// <param name="method">The HTTP method (ex: POST) that will be called on the client.</param>
        private void ApplyCSRFHeader(
            HttpClient client, 
            HttpMethods method)
        {
            string headerName = "X-" + CsrfTokenName;
            string headerValue = this.csrfToken;

            if (string.Compare(method.ToString(), "get", true) == 0)
            {
                headerName = CsrfTokenName;
                headerValue = string.IsNullOrEmpty(this.csrfToken) ? "Fetch" : headerValue;
            }

#if WINDOWS_UWP
            HttpRequestHeaderCollection headers = client.DefaultRequestHeaders;
#else
            HttpRequestHeaders headers = client.DefaultRequestHeaders;
#endif // WINDOWS_UWP

            headers.Add(headerName, headerValue);
        }

        /// <summary>
        /// Applies any needed headers to the HTTP client.
        /// </summary>
@@ -88,7 +50,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
            HttpClient client,
            HttpMethods method)
        {
            this.ApplyCSRFHeader(client, method);
            this.ApplyUserAgentHeader(client);
        }

@@ -115,57 +76,5 @@ namespace Microsoft.Tools.WindowsDevicePortal

            headers.Add(UserAgentName, userAgentValue);
        }

        /// <summary>
        /// Retrieves the CSRF token from the HTTP response and stores it.
        /// </summary>
        /// <param name="response">The HTTP response from which to retrieve the header.</param>
        private void RetrieveCsrfToken(HttpResponseMessage response)
        {
            // If the response sets a CSRF token, store that for future requests.
#if WINDOWS_UWP
            string cookie;
            if (response.Headers.TryGetValue("Set-Cookie", out cookie))
            {
                string csrfTokenNameWithEquals = CsrfTokenName + "=";
                if (cookie.StartsWith(csrfTokenNameWithEquals))
                {
                    this.csrfToken = cookie.Substring(csrfTokenNameWithEquals.Length);
                }
            }
#else
            IEnumerable<string> cookies;
            if (response.Headers.TryGetValues("Set-Cookie", out cookies))
            {
                foreach (string cookie in cookies)
                {
                    string csrfTokenNameWithEquals = CsrfTokenName + "=";
                    if (cookie.StartsWith(csrfTokenNameWithEquals))
                    {
                        this.csrfToken = cookie.Substring(csrfTokenNameWithEquals.Length);
                    }
                }
            }
#endif
        }

        /// <summary>
        /// Checks a response to see if it failed due to a bad CSRF token.
        /// </summary>
        /// <param name="exception">The DevicePortalException from the REST call.</param>
        /// <returns>Whether the response failed due to the bad CSRF token.</returns>
        private bool IsBadCsrfToken(DevicePortalException exception)
        {
            return exception.StatusCode == HttpStatusCode.Forbidden && exception.Reason.Equals("CSRF Token Invalid");
        }

        /// <summary>
        /// Makes a simple GET call to refresh the CSRF token.
        /// </summary>
        /// <returns>Task tracking completion of the refresh.</returns>
        private async Task RefreshCsrfToken()
        {
            await this.GetDeviceName();
        }
    }
}
+7 −27
Original line number Diff line number Diff line
@@ -37,12 +37,10 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// <typeparam name="T">The type of the data for the HTTP response body (if present).</typeparam>
        /// <param name="apiPath">The relative portion of the uri path that specifies the API to call.</param>
        /// <param name="payload">The query string portion of the uri path that provides the parameterized data.</param>
        /// <param name="allowRetry">Allow the Delete to be retried after refreshing the CSRF token.</param>
        /// <returns>Task tracking the HTTP completion.</returns>
        private async Task<T> Delete<T>(
            string apiPath,
            string payload = null,
            bool allowRetry = true) where T : new()
            string payload = null) where T : new()
        {
            T data = default(T);

@@ -53,8 +51,6 @@ namespace Microsoft.Tools.WindowsDevicePortal

            DataContractJsonSerializer deserializer = new DataContractJsonSerializer(typeof(T));

            try
            {
            using (Stream dataStream = await this.Delete(uri))
            {
                if ((dataStream != null) &&
@@ -66,22 +62,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
                    data = (T)response;
                }
            }
            }
            catch (DevicePortalException e)
            {
                // If this isn't a retry and it failed due to a bad CSRF
                // token, refresh the token and then retry.
                if (allowRetry && this.IsBadCsrfToken(e))
                {
                    await this.RefreshCsrfToken();
                    return await this.Delete<T>(apiPath, payload, false);
                }
                else
                {
                    throw e;
                }
            }


            return data;
        }
+7 −28
Original line number Diff line number Diff line
@@ -80,14 +80,12 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// <param name="payload">The query string portion of the uri path that provides the parameterized data.</param>
        /// <param name="requestStream">Optional stream containing data for the request body.</param>
        /// <param name="requestStreamContentType">The type of that request body data.</param>
        /// <param name="allowRetry">Allow the Post to be retried after issuing a Get call. Currently used for CSRF failures.</param>
        /// <returns>Task tracking the POST completion.</returns>
        private async Task<T> Post<T>(
            string apiPath,
            string payload = null,
            Stream requestStream = null,
            string requestStreamContentType = null,
            bool allowRetry = true) where T : new()
            string requestStreamContentType = null) where T : new()
        {
            T data = default(T);

@@ -98,8 +96,6 @@ namespace Microsoft.Tools.WindowsDevicePortal

            DataContractJsonSerializer deserializer = new DataContractJsonSerializer(typeof(T));

            try
            {
            using (Stream dataStream = await this.Post(uri, requestStream, requestStreamContentType))
            {
                if ((dataStream != null) &&
@@ -111,23 +107,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
                    data = (T)response;
                }
            }
            }
            catch (DevicePortalException e)
            {
                // If this isn't a retry and it failed due to a bad CSRF
                // token, refresh the token and then retry. 
                // Note: due to the stream already being disposed, we can't
                // retry a POST unless a body stream isn't being provided.
                if (allowRetry && this.IsBadCsrfToken(e) && requestStream == null)
                {
                    await this.RefreshCsrfToken();
                    return await this.Post<T>(apiPath, payload, null, null, false);
                }
                else
                {
                    throw e;
                }
            }

            return data;
        }
+7 −26
Original line number Diff line number Diff line
@@ -57,13 +57,11 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// <param name="apiPath">The relative portion of the uri path that specifies the API to call.</param>
        /// <param name="bodyData">The data to be used for the HTTP request body.</param>
        /// <param name="payload">The query string portion of the uri path that provides the parameterized data.</param>
        /// <param name="allowRetry">Allow the Put to be retried after refreshing the CSRF token.</param>
        /// <returns>Task tracking the PUT completion, optional response body.</returns>
        private async Task<T> Put<T, K>(
            string apiPath,
            K bodyData = null,
            string payload = null,
            bool allowRetry = true) where T : new()
            string payload = null) where T : new()
                                   where K : class
        {
            T data = default(T);
@@ -98,8 +96,6 @@ namespace Microsoft.Tools.WindowsDevicePortal

            DataContractJsonSerializer deserializer = new DataContractJsonSerializer(typeof(T));

            try
            {
            using (Stream dataStream = await this.Put(uri, streamContent))
            {
                if ((dataStream != null) &&
@@ -111,21 +107,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
                    data = (T)response;
                }
            }
            }
            catch (DevicePortalException e)
            {
                // If this isn't a retry and it failed due to a bad CSRF
                // token, refresh the token and then retry.
                if (allowRetry && this.IsBadCsrfToken(e))
                {
                    await this.RefreshCsrfToken();
                    return await this.Put<T, K>(apiPath, bodyData, payload, false);
                }
                else
                {
                    throw e;
                }
            }

            return data;
        }
Loading