Commit ee6bd445 authored by Jason Williams's avatar Jason Williams Committed by GitHub
Browse files

Merge pull request #170 from WilliamsJason/master

Refresh CSRF tokens and retry when we have issues 
parents 4699d1fb 0e30f50b
Loading
Loading
Loading
Loading
+13 −6
Original line number Diff line number Diff line
@@ -106,6 +106,8 @@ namespace SampleWdpClient.UniversalWindows
                        }
                    };

                    try
                    {
                        // If the user wants to allow untrusted connections, make a call to GetRootDeviceCertificate
                        // with acceptUntrustedCerts set to true. This will enable untrusted connections for the
                        // remainder of this session.
@@ -114,6 +116,11 @@ namespace SampleWdpClient.UniversalWindows
                            await portal.GetRootDeviceCertificate(true);
                        }
                        await portal.Connect(manualCertificate: this.certificate);
                    }
                    catch (Exception exception)
                    {
                        sb.AppendLine(exception.Message);
                    }

                    this.MarshalUpdateCommandOutput(sb.ToString());
                });
+26 −4
Original line number Diff line number Diff line
@@ -6,12 +6,15 @@

#if !WINDOWS_UWP
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Reflection;
using System.Threading.Tasks;
#else
using System.Reflection;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Windows.Web.Http;
using Windows.Web.Http.Headers;
#endif // !WINDOWS_UWP
@@ -54,7 +57,7 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// </summary>
        /// <param name="client">The HTTP client on which to have the header set.</param>
        /// <param name="method">The HTTP method (ex: POST) that will be called on the client.</param>
        public void ApplyCSRFHeader(
        private void ApplyCSRFHeader(
            HttpClient client, 
            HttpMethods method)
        {
@@ -81,7 +84,7 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// </summary>
        /// <param name="client">The HTTP client on which to have the headers set.</param>
        /// <param name="method">The HTTP method (ex: POST) that will be called on the client.</param>
        public void ApplyHttpHeaders(
        private void ApplyHttpHeaders(
            HttpClient client,
            HttpMethods method)
        {
@@ -94,7 +97,7 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// as coming from the WDPW Open Source project.
        /// </summary>
        /// <param name="client">The HTTP client on which to have the header set.</param>
        public void ApplyUserAgentHeader(HttpClient client)
        private void ApplyUserAgentHeader(HttpClient client)
        {
            string userAgentValue = UserAgentValue;

@@ -117,7 +120,7 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// Retrieves the CSRF token from the HTTP response and stores it.
        /// </summary>
        /// <param name="response">The HTTP response from which to retrieve the header.</param>
        public void RetrieveCsrfToken(HttpResponseMessage response)
        private void RetrieveCsrfToken(HttpResponseMessage response)
        {
            // If the response sets a CSRF token, store that for future requests.
#if WINDOWS_UWP
@@ -145,5 +148,24 @@ namespace Microsoft.Tools.WindowsDevicePortal
            }
#endif
        }

        /// <summary>
        /// Checks a response to see if it failed due to a bad CSRF token.
        /// </summary>
        /// <param name="response">The response from the REST call.</param>
        /// <returns>Whether the response failed due to the bad CSRF token.</returns>
        private bool IsBadCsrfToken(HttpResponseMessage response)
        {
            return response.StatusCode == HttpStatusCode.Forbidden && response.ReasonPhrase.Equals("CSRF Token Invalid");
        }

        /// <summary>
        /// Makes a simple GET call to refresh the CSRF token.
        /// </summary>
        /// <returns>Task tracking completion of the refresh.</returns>
        private async Task RefreshCsrfToken()
        {
            await this.GetDeviceName();
        }
    }
}
+10 −10
Original line number Diff line number Diff line
@@ -20,16 +20,6 @@ namespace Microsoft.Tools.WindowsDevicePortal
    /// </content>
    public partial class DevicePortal
    {
        /// <summary>
        /// Sets the manual certificate.
        /// </summary>
        /// <param name="cert">Manual certificate</param>
        private void SetManualCertificate(Certificate cert)
        {
            CertificateStore store = CertificateStores.TrustedRootCertificationAuthorities;
            store.Add(cert);
        }

        /// <summary>
        /// Gets the root certificate from the device.
        /// </summary>
@@ -80,5 +70,15 @@ namespace Microsoft.Tools.WindowsDevicePortal
            return certificate;
        }
#pragma warning restore 1998

        /// <summary>
        /// Sets the manual certificate.
        /// </summary>
        /// <param name="cert">Manual certificate</param>
        private void SetManualCertificate(Certificate cert)
        {
            CertificateStore store = CertificateStores.TrustedRootCertificationAuthorities;
            store.Add(cert);
        }
    }
}
+12 −1
Original line number Diff line number Diff line
@@ -26,9 +26,10 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// Submits the http delete request to the specified uri.
        /// </summary>
        /// <param name="uri">The uri to which the delete request will be issued.</param>
        /// <param name="allowRetry">Allow the Post to be retried after issuing a Get call. Currently used for CSRF failures.</param>
        /// <returns>Task tracking HTTP completion</returns>
#pragma warning disable 1998
        private async Task<Stream> Delete(Uri uri)
        private async Task<Stream> Delete(Uri uri, bool allowRetry = true)
        {
            IBuffer dataBuffer = null;

@@ -56,9 +57,19 @@ namespace Microsoft.Tools.WindowsDevicePortal
                {
                    if (!response.IsSuccessStatusCode)
                    {
                        // If this isn't a retry and it failed due to a bad CSRF
                        // token, issue a GET to refresh the token and then retry.
                        if (allowRetry && this.IsBadCsrfToken(response))
                        {
                            await this.RefreshCsrfToken();
                            return await this.Delete(uri, false);
                        }

                        throw new DevicePortalException(response);
                    }

                    this.RetrieveCsrfToken(response);

                    if (response.Content != null)
                    {
                        using (IHttpContent messageContent = response.Content)
+13 −1
Original line number Diff line number Diff line
@@ -28,12 +28,14 @@ namespace Microsoft.Tools.WindowsDevicePortal
        /// <param name="uri">The uri to which the post request will be issued.</param>
        /// <param name="requestStream">Optional stream containing data for the request body.</param>
        /// <param name="requestStreamContentType">The type of that request body data.</param>
        /// <param name="allowRetry">Allow the Post to be retried after issuing a Get call. Currently used for CSRF failures.</param>
        /// <returns>Task tracking the completion of the POST request</returns>
#pragma warning disable 1998
        private async Task<Stream> Post(
            Uri uri,
            Stream requestStream = null,
            string requestStreamContentType = null)
            string requestStreamContentType = null,
            bool allowRetry = true)
        {
            HttpStreamContent requestContent = null;
            IBuffer dataBuffer = null;
@@ -69,9 +71,19 @@ namespace Microsoft.Tools.WindowsDevicePortal
                {
                    if (!response.IsSuccessStatusCode)
                    {
                        // If this isn't a retry and it failed due to a bad CSRF
                        // token, issue a GET to refresh the token and then retry.
                        if (allowRetry && this.IsBadCsrfToken(response))
                        {
                            await this.RefreshCsrfToken();
                            return await this.Post(uri, requestStream, requestStreamContentType, false);
                        }

                        throw new DevicePortalException(response);
                    }

                    this.RetrieveCsrfToken(response);

                    if (response.Content != null)
                    {
                        using (IHttpContent messageContent = response.Content)
Loading