diff --git a/src/Authentication.Abstractions/Models/ConfigKeysForCommon.cs b/src/Authentication.Abstractions/Models/ConfigKeysForCommon.cs index 99412af60e..c38c72f75b 100644 --- a/src/Authentication.Abstractions/Models/ConfigKeysForCommon.cs +++ b/src/Authentication.Abstractions/Models/ConfigKeysForCommon.cs @@ -31,5 +31,6 @@ public static class ConfigKeysForCommon public const string CheckForUpgrade = "CheckForUpgrade"; public const string EnableErrorRecordsPersistence = "EnableErrorRecordsPersistence"; public const string DisplaySecretsWarning = "DisplaySecretsWarning"; + public const string EnablePolicyToken = "EnablePolicyToken"; } } diff --git a/src/Common/AcquirePolicyTokenHandler.cs b/src/Common/AcquirePolicyTokenHandler.cs new file mode 100644 index 0000000000..df8d1f07df --- /dev/null +++ b/src/Common/AcquirePolicyTokenHandler.cs @@ -0,0 +1,234 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using System.Text; +using System.Text.RegularExpressions; +using System.Collections.Generic; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System.Net; +using Microsoft.WindowsAzure.Commands.Utilities.Common; + +namespace Microsoft.WindowsAzure.Commands.Common +{ + /// + /// Delegating handler to acquire an Azure Policy token for change safety feature and attach to outgoing request. + /// Activated when user specifies -AcquirePolicyToken. (ChangeReference deferred to Phase 2.) + /// + public class AcquirePolicyTokenHandler : DelegatingHandler, ICloneable + { + private readonly AzurePSCmdlet _cmdlet; + private const string TokenApiVersion = "2025-03-01"; + private static readonly Regex SubscriptionIdRegex = new Regex(@"/subscriptions/([0-9a-fA-F-]{36})", RegexOptions.IgnoreCase | RegexOptions.Compiled); + private static readonly HashSet _allowedWriteMethods = new HashSet(StringComparer.OrdinalIgnoreCase) + { + HttpMethod.Put.Method, + HttpMethod.Post.Method, + HttpMethod.Delete.Method, + "PATCH" + }; + private const string LogPrefix = "[AcquirePolicyTokenHandler]"; + + public AcquirePolicyTokenHandler(AzurePSCmdlet cmdlet) + { + _cmdlet = cmdlet; + } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + EnqueueDebug($"Intercept {request.Method} {request.RequestUri}"); + + if (!(_cmdlet?.IsPolicyTokenFeatureEnabled() ?? false)) + { + EnqueueDebug("Skip: feature disabled (EnableAcquirePolicyToken config set to false)."); + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + + bool allowedVerb = _allowedWriteMethods.Contains(request.Method.Method); + if (!allowedVerb) + { + EnqueueDebug("Skip: verb not allowed for token acquisition."); + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + + bool hasCmdlet = _cmdlet != null; + bool userRequested = hasCmdlet && _cmdlet.ShouldAcquirePolicyToken; + if (!userRequested) + { + EnqueueDebug("Skip: user did not request token (no -AcquirePolicyToken)."); + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + + var isWhatIf = _cmdlet.MyInvocation?.BoundParameters?.ContainsKey("WhatIf") == true; + if (isWhatIf) + { + EnqueueDebug("Skip: -WhatIf present (dry run)."); + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + + try + { + var token = await AcquirePolicyTokenAsync(request, cancellationToken).ConfigureAwait(false); + + //Debug token, as is + // EnqueueDebug($"Token: {token}"); + + + if (!string.IsNullOrEmpty(token)) + { + if (request.Headers.Contains("x-ms-policy-external-evaluations")) + { + request.Headers.Remove("x-ms-policy-external-evaluations"); + } + request.Headers.Add("x-ms-policy-external-evaluations", token); + EnqueueDebug("Token acquired and header added."); + } + else + { + EnqueueDebug("No token returned (null/empty)."); + } + } + catch (Exception ex) + { + EnqueueDebug($"Exception: {ex.GetType().Name}: {ex.Message}"); + throw new InvalidOperationException($"Failed to acquire policy token: {ex.Message}", ex); + } + + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task AcquirePolicyTokenAsync(HttpRequestMessage originalRequest, CancellationToken cancellationToken) + { + var subscriptionId = ExtractSubscriptionId(originalRequest.RequestUri); + if (string.IsNullOrEmpty(subscriptionId)) + { + EnqueueDebug("Failed: subscription id not found in URI."); + throw new InvalidOperationException("Unable to determine subscription ID for policy token acquisition."); + } + + var authority = originalRequest.RequestUri.GetLeftPart(UriPartial.Authority); + var relativePath = $"/subscriptions/{subscriptionId}/providers/Microsoft.Authorization/acquirePolicyToken?api-version={TokenApiVersion}"; + var tokenUri = new Uri(authority + relativePath); + + object contentObj = null; + if (originalRequest.Content != null) + { + var body = await originalRequest.Content.ReadAsStringAsync().ConfigureAwait(false); + if (!string.IsNullOrWhiteSpace(body)) + { + try + { + contentObj = JsonConvert.DeserializeObject(body); + } + catch + { + contentObj = body; // leave as raw string if not JSON + } + } + } + + var payload = new + { + operation = new + { + uri = originalRequest.RequestUri.ToString(), + httpMethod = originalRequest.Method.Method, + content = contentObj + } + // Phase 2: reintroduce when ChangeReference parameter is enabled + // ,changeReference = _cmdlet?.CurrentChangeReference + }; + EnqueueDebug("Payload prepared."); + + var payloadJson = JsonConvert.SerializeObject(payload); + var tokenRequest = new HttpRequestMessage(HttpMethod.Post, tokenUri) + { + Content = new StringContent(payloadJson, Encoding.UTF8, "application/json") + }; + tokenRequest.Headers.Add("x-ms-force-sync", "true"); + + // Forward auth headers if present (minimal parity with original request auth context) + if (originalRequest.Headers.Authorization != null) + { + tokenRequest.Headers.Authorization = originalRequest.Headers.Authorization; + } + if (originalRequest.Headers.TryGetValues("x-ms-authorization-auxiliary", out var auxValues)) + { + tokenRequest.Headers.TryAddWithoutValidation("x-ms-authorization-auxiliary", auxValues); + } + + using (var http = new HttpClient()) + { + EnqueueDebug($"POST acquirePolicyToken {tokenUri}"); + var response = await http.SendAsync(tokenRequest, cancellationToken).ConfigureAwait(false); + EnqueueDebug($"Response {(int)response.StatusCode} {response.StatusCode}"); + var responseContent = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + if (response.StatusCode == HttpStatusCode.OK) + { + if (!string.IsNullOrWhiteSpace(responseContent)) + { + var obj = JsonConvert.DeserializeObject(responseContent); + var token = obj?["token"]?.ToString(); + if (string.IsNullOrEmpty(token)) + { + EnqueueDebug("Response OK but token missing."); + throw new InvalidOperationException($"No token returned. Response:{responseContent}"); + } + return token; + } + throw new InvalidOperationException("Empty response body when acquiring policy token."); + } + else if (response.StatusCode == HttpStatusCode.Accepted) + { + EnqueueDebug("202 Accepted received (async not supported)." ); + throw new InvalidOperationException("Asynchronous policy token acquisition (202 Accepted) is not supported."); + } + else + { + EnqueueDebug("Non-success status; will throw."); + throw new InvalidOperationException($"Policy token acquisition failed with {(int)response.StatusCode} {response.StatusCode}: {responseContent}"); + } + } + } + + private static string ExtractSubscriptionId(Uri uri) + { + if (uri == null) return null; + var match = SubscriptionIdRegex.Match(uri.AbsolutePath); + if (match.Success && match.Groups.Count > 1) + { + return match.Groups[1].Value; + } + return null; + } + + public object Clone() + { + return new AcquirePolicyTokenHandler(_cmdlet); + } + + private void EnqueueDebug(string message) + { + try + { + _cmdlet?.DebugMessages?.Enqueue($"{LogPrefix} {message}"); + } + catch { } + } + } +} diff --git a/src/Common/AzurePSCmdlet.cs b/src/Common/AzurePSCmdlet.cs index 646ef40402..ec5a476af7 100644 --- a/src/Common/AzurePSCmdlet.cs +++ b/src/Common/AzurePSCmdlet.cs @@ -26,6 +26,7 @@ using Microsoft.WindowsAzure.Commands.Common.Utilities; using System; using System.Collections.Concurrent; +using System.Collections.ObjectModel; using System.Collections.Generic; using System.Diagnostics; using System.Globalization; @@ -41,7 +42,7 @@ namespace Microsoft.WindowsAzure.Commands.Utilities.Common /// /// Represents base class for all Azure cmdlets. /// - public abstract class AzurePSCmdlet : PSCmdlet, IDisposable + public abstract class AzurePSCmdlet : PSCmdlet, IDisposable, IDynamicParameters { private const string PSVERSION = "PSVersion"; private const string DEFAULT_PSVERSION = "3.0.0.0"; @@ -180,6 +181,91 @@ private IOutputSanitizer OutputSanitizer } } + // PHASE2: ChangeReference temporarily disabled, always null + // internal string CurrentChangeReference + // { + // get + // { + // var bp = this.MyInvocation?.BoundParameters; + // if (bp != null && bp.ContainsKey("ChangeReference")) + // { + // return bp["ChangeReference"] as string; + // } + + // return null; + // } + // } + + /// + /// Determines whether the policy token feature is enabled. + /// Priority 1: environment variable AZ_ENABLE_POLICY_TOKEN overrides (1/0, true/false, yes/no). + /// Priority 2: Config key (EnablePolicyToken) set by Set-AzConfig. + /// Default: disabled (false). + /// + internal bool IsPolicyTokenFeatureEnabled() + { + try + { + var env = Environment.GetEnvironmentVariable("AZ_ENABLE_POLICY_TOKEN"); + if (!string.IsNullOrEmpty(env)) + { + var trimmed = env.Trim(); + + if (bool.TryParse(trimmed, out var b)) return b; + if (string.Equals(trimmed, "1", StringComparison.Ordinal)) return true; + if (string.Equals(trimmed, "0", StringComparison.Ordinal)) return false; + + switch (trimmed.ToLowerInvariant()) + { + case "yes": + case "on": + case "enable": + case "enabled": + return true; + case "no": + case "off": + case "disable": + case "disabled": + return false; + } + } + + // Config fallback (Set-AzConfig -EnablePolicyToken true) + if (AzureSession.Instance.TryGetComponent(nameof(IConfigManager), out var configManager)) + { + try + { + return configManager.GetConfigValue(ConfigKeysForCommon.EnablePolicyToken, MyInvocation); + } + catch { } + } + } + catch { } + + // Default: disabled + return false; + } + + internal bool ShouldAcquirePolicyToken + { + get + { + var bp = this.MyInvocation?.BoundParameters; + if (bp == null) + { + return false; + } + if (!IsPolicyTokenFeatureEnabled()) + { + return false; + } + var acquire = bp.ContainsKey("AcquirePolicyToken") && ((SwitchParameter)bp["AcquirePolicyToken"]).IsPresent; + // PHASE2: ChangeReference disabled; ignore for now + // var changeRef = bp.ContainsKey("ChangeReference") && !string.IsNullOrEmpty(bp["ChangeReference"] as string); + return acquire; // || changeRef; + } + } + /// /// Resolve user submitted paths correctly on all platforms /// @@ -324,7 +410,9 @@ protected virtual void SetupHttpClientPipeline() AzureSession.Instance.ClientFactory.AddUserAgent("AzurePowershell", string.Format("v{0}", AzVersion)); AzureSession.Instance.ClientFactory.AddUserAgent(PSVERSION, string.Format("v{0}", PowerShellVersion)); AzureSession.Instance.ClientFactory.AddUserAgent(ModuleName, this.ModuleVersion); - try { + + try + { string hostEnv = AzurePSCmdlet.getEnvUserAgent(); if (!String.IsNullOrWhiteSpace(hostEnv)) { @@ -335,11 +423,14 @@ protected virtual void SetupHttpClientPipeline() { // ignore if it failed. } + + // Always add the acquire policy token handler; it will internally decide whether to act. + AzureSession.Instance.ClientFactory.AddHandler(new AcquirePolicyTokenHandler(this)); AzureSession.Instance.ClientFactory.AddHandler( new CmdletInfoHandler(this.CommandRuntime.ToString(), this.ParameterSetName, this._clientRequestId)); - + } protected virtual void TearDownHttpClientPipeline() @@ -356,10 +447,82 @@ protected virtual void TearDownHttpClientPipeline() { // ignore if it failed. } + AzureSession.Instance.ClientFactory.RemoveUserAgent(ModuleName); + AzureSession.Instance.ClientFactory.RemoveHandler(typeof(AcquirePolicyTokenHandler)); AzureSession.Instance.ClientFactory.RemoveHandler(typeof(CmdletInfoHandler)); } + /// + /// Dynamic parameters for policy token acquisition and any previously registered dynamic parameters (e.g. AsJob). + /// + public virtual object GetDynamicParameters() + { + var dict = new RuntimeDefinedParameterDictionary(); + + // Preserve existing dynamic parameters if any were registered via RegisterDynamicParameters + if (AsJobDynamicParameters != null) + { + foreach (var kv in AsJobDynamicParameters) + { + dict[kv.Key] = kv.Value; + } + } + + // FEATURE FLAG: only surface parameters when feature enabled + if (!IsPolicyTokenFeatureEnabled()) + { + return dict; + } + + // Do not add parameters for read-only cmdlets (Get-*/List-*/Show-*) + var commandName = this.MyInvocation?.MyCommand?.Name ?? string.Empty; + + if (commandName.StartsWith("Get", StringComparison.OrdinalIgnoreCase) + || commandName.EndsWith("List", StringComparison.OrdinalIgnoreCase) + || commandName.EndsWith("Show", StringComparison.OrdinalIgnoreCase)) + { + return dict; + } + + try + { + var acquireParam = new RuntimeDefinedParameter( + "AcquirePolicyToken", + typeof(SwitchParameter), + new Collection + { + new ParameterAttribute + { + HelpMessage = "Acquire an Azure Policy token automatically for this resource operation.", + ParameterSetName = ParameterAttribute.AllParameterSets + } + }); + dict.Add("AcquirePolicyToken", acquireParam); + + /* PHASE2 (behind same feature flag): ChangeReference dynamic parameter + var changeRefParam = new RuntimeDefinedParameter( + "ChangeReference", + typeof(string), + new Collection + { + new ParameterAttribute + { + HelpMessage = "The related change reference ID for this resource operation.", + ParameterSetName = ParameterAttribute.AllParameterSets + } + }); + dict.Add("ChangeReference", changeRefParam); + */ + } + catch + { + // Ignore dynamic parameter creation issues. + } + + return dict; + } + /// /// Cmdlet begin process. Write to logs, setup Http Tracing and initialize profile /// diff --git a/src/Common/Utilities/GeneralUtilities.cs b/src/Common/Utilities/GeneralUtilities.cs index ee0ca51f3e..41e7e4365d 100644 --- a/src/Common/Utilities/GeneralUtilities.cs +++ b/src/Common/Utilities/GeneralUtilities.cs @@ -37,7 +37,7 @@ public static class GeneralUtilities { private static Assembly assembly = Assembly.GetExecutingAssembly(); - private static List AuthorizationHeaderNames = new List() { "Authorization" }; + private static List AuthorizationHeaderNames = new List() { "Authorization", "x-ms-policy-external-evaluations" }; // this is only used to determine cutoff for streams (not xml or json). private const int StreamCutOffSize = 10 * 1024; //10KB diff --git a/src/ResourceManager/Version2016_09_01/AzureRMCmdlet.cs b/src/ResourceManager/Version2016_09_01/AzureRMCmdlet.cs index 710c52c8f8..44fe7413ef 100644 --- a/src/ResourceManager/Version2016_09_01/AzureRMCmdlet.cs +++ b/src/ResourceManager/Version2016_09_01/AzureRMCmdlet.cs @@ -651,9 +651,9 @@ private void EnqueueDebugSender(object sender, StreamEventArgs args) DebugMessages.Enqueue(args.Message); } - public object GetDynamicParameters() + public override object GetDynamicParameters() { - var parameters = new RuntimeDefinedParameterDictionary(); + var parameters = base.GetDynamicParameters() as RuntimeDefinedParameterDictionary ?? new RuntimeDefinedParameterDictionary(); // add `-SubscriptionId` if the cmdlet has [SupportsSubscriptionId] attribute if (GetType().IsDefined(typeof(SupportsSubscriptionIdAttribute), true))