From 1d9de52a7e3069ce0423ff13a23304e55755c7ed Mon Sep 17 00:00:00 2001 From: aavalang <56377848+aavalang@users.noreply.github.com> Date: Sun, 19 Nov 2023 18:29:14 -0800 Subject: [PATCH] [Bastion] adding rdp file to temp location and adding auth-type for rdp (#7006) * adding rdp file to temp location and adding auth-type for rdp * fixing some pylint issues * fixing some pylint issues --- src/bastion/HISTORY.rst | 5 ++ src/bastion/azext_bastion/_params.py | 2 + src/bastion/azext_bastion/custom.py | 68 ++++++++++++------- .../azext_bastion/developer_sku_helper.py | 2 +- src/bastion/azext_bastion/tunnel.py | 17 ++--- src/bastion/setup.py | 2 +- 6 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/bastion/HISTORY.rst b/src/bastion/HISTORY.rst index e804dc77772..c99ab5ca3ad 100644 --- a/src/bastion/HISTORY.rst +++ b/src/bastion/HISTORY.rst @@ -2,6 +2,11 @@ Release History =============== +0.2.6 +++++++ +* Adding auth type aad for RDP to mimic the enable-mfa flag. +* Fixing issue where if powershell is opened in system32 directory, file generation throws error. Files are now dumped in temp folder. + 0.2.5 ++++++ * Fixing the command `az network bastion rdp` to avoid the `java.lang.NullPointerException` while calling `get_auth_token` function diff --git a/src/bastion/azext_bastion/_params.py b/src/bastion/azext_bastion/_params.py index 565bf887d25..3ccf7f6010f 100644 --- a/src/bastion/azext_bastion/_params.py +++ b/src/bastion/azext_bastion/_params.py @@ -38,6 +38,8 @@ def load_arguments(self, _): # pylint: disable=unused-argument c.argument("configure", help="Flag to configure RDP session.", action="store_true") c.argument("disable_gateway", help="Flag to disable access through RD gateway.", arg_type=get_three_state_flag()) + c.argument("auth_type", help="Auth type to use for RDP connections.", required=False, + options_list=["--auth-type"]) c.argument('enable_mfa', help='Enable RDS auth for MFA if supported by the target machine.', arg_type=get_three_state_flag()) with self.argument_context("network bastion tunnel") as c: diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index b9644f9cca6..068bd54f662 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -15,6 +15,7 @@ import threading import time import json +import uuid import requests from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \ @@ -148,17 +149,18 @@ def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, reso if not resource_port: resource_port = 22 - if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \ + bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') - ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) + ip_connect = _is_ipconnect_request(bastion, target_ip_address) if ip_connect: - target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" - - if ip_connect and int(resource_port) not in [22, 3389]: - raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.") + if int(resource_port) not in [22, 3389]: + raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.") + target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}" + f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" - _validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address) + _validate_resourceid(target_resource_id) bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port) @@ -227,7 +229,7 @@ def _get_rdp_path(rdp_command="mstsc"): def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, - resource_port=None, disable_gateway=False, configure=False, enable_mfa=False): + auth_type=None, resource_port=None, disable_gateway=False, configure=False, enable_mfa=False): import os from azure.cli.core._profile import Profile from ._process_helper import launch_and_wait @@ -241,17 +243,31 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ if not resource_port: resource_port = 3389 - if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \ + bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') - ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) + ip_connect = _is_ipconnect_request(bastion, target_ip_address) + + if auth_type is None: + # do nothing + pass + elif auth_type.lower() == "aad": + enable_mfa = True + + if disable_gateway or ip_connect: + raise UnrecognizedArgumentError("AAD login is not supported for Disable Gateway & IP Connect scenarios.") + else: + raise UnrecognizedArgumentError("Unknown auth type, support auth-types: aad. For non aad login, you dont need to provide auth-type flag.") + if ip_connect: - target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" + if int(resource_port) not in [22, 3389]: + raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.") - if ip_connect and int(resource_port) not in [22, 3389]: - raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.") + target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}" + f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" - _validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address) + _validate_resourceid(target_resource_id) bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) if platform.system() == "Windows": @@ -269,7 +285,8 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ profile = Profile(cli_ctx=cmd.cli_ctx) access_token = profile.get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) - web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" + web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp" + f"&rdpport={resource_port}&enablerdsaad={enable_mfa}" headers = { "Authorization": f"Bearer {access_token}", @@ -285,8 +302,10 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ raise ClientRequestError("Request failed with error: " + errorMessage) raise ClientRequestError("Request to EncodingReservedUnitTypes v2 API endpoint failed.") - _write_to_file(response) - rdpfilepath = os.getcwd() + "/conn.rdp" + tempdir = os.path.realpath(tempfile.gettempdir()) + rdpfilepath = os.path.join(tempdir, 'conn_{}.rdp'.format(uuid.uuid4().hex)) + _write_to_file(response, rdpfilepath) + command = [_get_rdp_path()] if configure: command.append("/edit") @@ -296,14 +315,14 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows") -def _is_ipconnect_request(cmd, bastion, target_ip_address): +def _is_ipconnect_request(bastion, target_ip_address): if 'enableIpConnect' in bastion and bastion['enableIpConnect'] is True and target_ip_address: return True return False -def _validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address): +def _validate_resourceid(target_resource_id): if not is_valid_resource_id(target_resource_id): err_msg = "Please enter a valid resource ID. If this is not working, " \ "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." @@ -319,8 +338,8 @@ def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id): return bastion['dnsName'] -def _write_to_file(response): - with open("conn.rdp", "w", encoding="utf-8") as f: +def _write_to_file(response, file_path): + with open(file_path, "w", encoding="utf-8") as f: for line in response.text.splitlines(): f.write(line + "\n") @@ -358,14 +377,15 @@ def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_g if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') - ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) + ip_connect = _is_ipconnect_request(bastion, target_ip_address) if ip_connect: - target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" + target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/" + f"{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" if ip_connect and int(resource_port) not in [22, 3389]: raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.") - _validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address) + _validate_resourceid(target_resource_id) bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port) diff --git a/src/bastion/azext_bastion/developer_sku_helper.py b/src/bastion/azext_bastion/developer_sku_helper.py index e35124ea135..3407681b28a 100644 --- a/src/bastion/azext_bastion/developer_sku_helper.py +++ b/src/bastion/azext_bastion/developer_sku_helper.py @@ -24,6 +24,6 @@ def _get_data_pod(cmd, resource_port, target_resource_id, bastion): web_address = f"https://{bastion['dnsName']}/api/connection" response = requests.post(web_address, json=content, headers=headers, - verify=(not should_disable_connection_verify())) + verify=not should_disable_connection_verify()) return response.content.decode("utf-8") diff --git a/src/bastion/azext_bastion/tunnel.py b/src/bastion/azext_bastion/tunnel.py index 1e86f74837c..9106526499c 100644 --- a/src/bastion/azext_bastion/tunnel.py +++ b/src/bastion/azext_bastion/tunnel.py @@ -19,21 +19,21 @@ from contextlib import closing from datetime import datetime from threading import Thread +import requests +import urllib3 import websocket from websocket import create_connection, WebSocket from msrestazure.azure_exceptions import CloudError -from .BastionServiceConstants import BastionSku - from azure.cli.core._profile import Profile from azure.cli.core.util import should_disable_connection_verify -import requests -import urllib3 - from knack.util import CLIError from knack.log import get_logger + +from .BastionServiceConstants import BastionSku + logger = get_logger(__name__) @@ -96,7 +96,7 @@ def _get_auth_token(self): logger.debug("Content: %s", str(content)) web_address = f"https://{self.bastion_endpoint}/api/tokens" response = requests.post(web_address, data=content, headers=custom_header, - verify=(not should_disable_connection_verify())) + verify=not should_disable_connection_verify()) response_json = None if response.content is not None: @@ -121,7 +121,8 @@ def _listen(self): self.client, _address = self.sock.accept() auth_token = self._get_auth_token() - if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or self.bastion['sku']['name'] == BastionSku.Developer.name: + if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or \ + self.bastion['sku']['name'] == BastionSku.Developer.name: host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}" else: host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" @@ -204,7 +205,7 @@ def cleanup(self): web_address = f"https://{self.bastion_endpoint}/api/tokens/{self.last_token}" response = requests.delete(web_address, headers=custom_header, - verify=(not should_disable_connection_verify())) + verify=not should_disable_connection_verify()) if response.status_code == 404: logger.info('Session already deleted') elif response.status_code not in [200, 204]: diff --git a/src/bastion/setup.py b/src/bastion/setup.py index e7923cd974f..6972629f468 100644 --- a/src/bastion/setup.py +++ b/src/bastion/setup.py @@ -10,7 +10,7 @@ # HISTORY.rst entry. -VERSION = '0.2.5' +VERSION = '0.2.6' # The full list of classifiers is available at # https://pypi.python.org/pypi?%3Aaction=list_classifiers