Commit 7f096085 authored by David Kline's avatar David Kline Committed by GitHub
Browse files

Merge pull request #176 from WilliamsJason/master

Simplify CSRF logic by appending username with -auto
parents ee6bd445 46046129
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -379,6 +379,19 @@ namespace XboxWdpDriver
                    }
                }
            }
            catch (AggregateException e)
            {
                if (e.InnerException is DevicePortalException)
                {
                    DevicePortalException innerException = e.InnerException as DevicePortalException;

                    Console.WriteLine(string.Format("Exception encountered: {0}, hr = 0x{1:X} : {2}", innerException.StatusCode, innerException.HResult, innerException.Reason));
                }
                else
                {
                    Console.WriteLine(string.Format("Unexpected exception encountered: {0}", e.Message));
                }
            }
            catch (Exception e)
            {
                Console.WriteLine(e.Message);
+0 −2
Original line number Diff line number Diff line
@@ -47,8 +47,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
                        throw new DevicePortalException(response);
                    }

                    this.RetrieveCsrfToken(response);

                    using (HttpContent content = response.Content)
                    {
                        dataStream = new MemoryStream();
+0 −91
Original line number Diff line number Diff line
@@ -31,11 +31,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// </summary>
        private static readonly string ContentTypeHeaderName = "Content-Type";

        /// <summary>
        /// Header name for a CSRF-Token.
        /// </summary>
        private static readonly string CsrfTokenName = "CSRF-Token";

        /// <summary>
        /// Header name for a User-Agent.
        /// </summary>
@@ -46,39 +41,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// </summary>
        private static readonly string UserAgentValue = "WindowsDevicePortalWrapper";

        /// <summary>
        /// CSRF token retrieved by GET calls and used on subsequent POST/DELETE/PUT calls.
        /// This token is intended to prevent a security vulnerability from cross site forgery.
        /// </summary>
        private string csrfToken = string.Empty;

        /// <summary>
        /// Applies the CSRF token to the HTTP client.
        /// </summary>
        /// <param name="client">The HTTP client on which to have the header set.</param>
        /// <param name="method">The HTTP method (ex: POST) that will be called on the client.</param>
        private void ApplyCSRFHeader(
            HttpClient client, 
            HttpMethods method)
        {
            string headerName = "X-" + CsrfTokenName;
            string headerValue = this.csrfToken;

            if (string.Compare(method.ToString(), "get", true) == 0)
            {
                headerName = CsrfTokenName;
                headerValue = string.IsNullOrEmpty(this.csrfToken) ? "Fetch" : headerValue;
            }

#if WINDOWS_UWP
            HttpRequestHeaderCollection headers = client.DefaultRequestHeaders;
#else
            HttpRequestHeaders headers = client.DefaultRequestHeaders;
#endif // WINDOWS_UWP

            headers.Add(headerName, headerValue);
        }

        /// <summary>
        /// Applies any needed headers to the HTTP client.
        /// </summary>
@@ -88,7 +50,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
            HttpClient client,
            HttpMethods method)
        {
            this.ApplyCSRFHeader(client, method);
            this.ApplyUserAgentHeader(client);
        }

@@ -115,57 +76,5 @@ namespace Microsoft.Tools.WindowsDevicePortal

            headers.Add(UserAgentName, userAgentValue);
        }

        /// <summary>
        /// Retrieves the CSRF token from the HTTP response and stores it.
        /// </summary>
        /// <param name="response">The HTTP response from which to retrieve the header.</param>
        private void RetrieveCsrfToken(HttpResponseMessage response)
        {
            // If the response sets a CSRF token, store that for future requests.
#if WINDOWS_UWP
            string cookie;
            if (response.Headers.TryGetValue("Set-Cookie", out cookie))
            {
                string csrfTokenNameWithEquals = CsrfTokenName + "=";
                if (cookie.StartsWith(csrfTokenNameWithEquals))
                {
                    this.csrfToken = cookie.Substring(csrfTokenNameWithEquals.Length);
                }
            }
#else
            IEnumerable<string> cookies;
            if (response.Headers.TryGetValues("Set-Cookie", out cookies))
            {
                foreach (string cookie in cookies)
                {
                    string csrfTokenNameWithEquals = CsrfTokenName + "=";
                    if (cookie.StartsWith(csrfTokenNameWithEquals))
                    {
                        this.csrfToken = cookie.Substring(csrfTokenNameWithEquals.Length);
                    }
                }
            }
#endif
        }

        /// <summary>
        /// Checks a response to see if it failed due to a bad CSRF token.
        /// </summary>
        /// <param name="response">The response from the REST call.</param>
        /// <returns>Whether the response failed due to the bad CSRF token.</returns>
        private bool IsBadCsrfToken(HttpResponseMessage response)
        {
            return response.StatusCode == HttpStatusCode.Forbidden && response.ReasonPhrase.Equals("CSRF Token Invalid");
        }

        /// <summary>
        /// Makes a simple GET call to refresh the CSRF token.
        /// </summary>
        /// <returns>Task tracking completion of the refresh.</returns>
        private async Task RefreshCsrfToken()
        {
            await this.GetDeviceName();
        }
    }
}
+0 −2
Original line number Diff line number Diff line
@@ -52,8 +52,6 @@ namespace Microsoft.Tools.WindowsDevicePortal

                using (HttpResponseMessage response = responseOperation.GetResults())
                {
                    this.RetrieveCsrfToken(response);

                    using (IHttpContent messageContent = response.Content)
                    {
                        IAsyncOperationWithProgress<IBuffer, ulong> bufferOperation = messageContent.ReadAsBufferAsync();
+2 −1
Original line number Diff line number Diff line
@@ -32,7 +32,8 @@ namespace Microsoft.Tools.WindowsDevicePortal
            string password)
        {
            this.Connection = new Uri(address);
            this.Credentials = new NetworkCredential(userName, password);
            // append auto- to the credentials to bypass CSRF token requirement on non-Get requests.
            this.Credentials = new NetworkCredential(string.Format("auto-{0}", userName), password);
        }

        /// <summary>
Loading