Skip to content

Commit

Permalink
[Bastion] adding rdp file to temp location and adding auth-type for r…
Browse files Browse the repository at this point in the history
…dp (#7006)

* adding rdp file to temp location and adding auth-type for rdp

* fixing some pylint issues

* fixing some pylint issues
  • Loading branch information
aavalang authored Nov 20, 2023
1 parent 38182e8 commit 1d9de52
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 34 deletions.
5 changes: 5 additions & 0 deletions src/bastion/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
Release History
===============
0.2.6
++++++
* Adding auth type aad for RDP to mimic the enable-mfa flag.
* Fixing issue where if powershell is opened in system32 directory, file generation throws error. Files are now dumped in temp folder.

0.2.5
++++++
* Fixing the command `az network bastion rdp` to avoid the `java.lang.NullPointerException` while calling `get_auth_token` function
Expand Down
2 changes: 2 additions & 0 deletions src/bastion/azext_bastion/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def load_arguments(self, _): # pylint: disable=unused-argument
c.argument("configure", help="Flag to configure RDP session.", action="store_true")
c.argument("disable_gateway", help="Flag to disable access through RD gateway.",
arg_type=get_three_state_flag())
c.argument("auth_type", help="Auth type to use for RDP connections.", required=False,
options_list=["--auth-type"])
c.argument('enable_mfa', help='Enable RDS auth for MFA if supported by the target machine.',
arg_type=get_three_state_flag())
with self.argument_context("network bastion tunnel") as c:
Expand Down
68 changes: 44 additions & 24 deletions src/bastion/azext_bastion/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import threading
import time
import json
import uuid

import requests
from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \
Expand Down Expand Up @@ -148,17 +149,18 @@ def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, reso
if not resource_port:
resource_port = 22

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \
bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)
if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
if int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}"
f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port)
Expand Down Expand Up @@ -227,7 +229,7 @@ def _get_rdp_path(rdp_command="mstsc"):


def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name,
resource_port=None, disable_gateway=False, configure=False, enable_mfa=False):
auth_type=None, resource_port=None, disable_gateway=False, configure=False, enable_mfa=False):
import os
from azure.cli.core._profile import Profile
from ._process_helper import launch_and_wait
Expand All @@ -241,17 +243,31 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
if not resource_port:
resource_port = 3389

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \
bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)

if auth_type is None:
# do nothing
pass
elif auth_type.lower() == "aad":
enable_mfa = True

if disable_gateway or ip_connect:
raise UnrecognizedArgumentError("AAD login is not supported for Disable Gateway & IP Connect scenarios.")
else:
raise UnrecognizedArgumentError("Unknown auth type, support auth-types: aad. For non aad login, you dont need to provide auth-type flag.")

if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
if int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}"
f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

if platform.system() == "Windows":
Expand All @@ -269,7 +285,8 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
profile = Profile(cli_ctx=cmd.cli_ctx)
access_token = profile.get_raw_token()[0][2].get("accessToken")
logger.debug("Response %s", access_token)
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}"
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp"
f"&rdpport={resource_port}&enablerdsaad={enable_mfa}"

headers = {
"Authorization": f"Bearer {access_token}",
Expand All @@ -285,8 +302,10 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
raise ClientRequestError("Request failed with error: " + errorMessage)
raise ClientRequestError("Request to EncodingReservedUnitTypes v2 API endpoint failed.")

_write_to_file(response)
rdpfilepath = os.getcwd() + "/conn.rdp"
tempdir = os.path.realpath(tempfile.gettempdir())
rdpfilepath = os.path.join(tempdir, 'conn_{}.rdp'.format(uuid.uuid4().hex))
_write_to_file(response, rdpfilepath)

command = [_get_rdp_path()]
if configure:
command.append("/edit")
Expand All @@ -296,14 +315,14 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows")


def _is_ipconnect_request(cmd, bastion, target_ip_address):
def _is_ipconnect_request(bastion, target_ip_address):
if 'enableIpConnect' in bastion and bastion['enableIpConnect'] is True and target_ip_address:
return True

return False


def _validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address):
def _validate_resourceid(target_resource_id):
if not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
Expand All @@ -319,8 +338,8 @@ def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id):
return bastion['dnsName']


def _write_to_file(response):
with open("conn.rdp", "w", encoding="utf-8") as f:
def _write_to_file(response, file_path):
with open(file_path, "w", encoding="utf-8") as f:
for line in response.text.splitlines():
f.write(line + "\n")

Expand Down Expand Up @@ -358,14 +377,15 @@ def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_g
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)
if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/"
f"{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port)
Expand Down
2 changes: 1 addition & 1 deletion src/bastion/azext_bastion/developer_sku_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def _get_data_pod(cmd, resource_port, target_resource_id, bastion):

web_address = f"https://{bastion['dnsName']}/api/connection"
response = requests.post(web_address, json=content, headers=headers,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())

return response.content.decode("utf-8")
17 changes: 9 additions & 8 deletions src/bastion/azext_bastion/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@
from contextlib import closing
from datetime import datetime
from threading import Thread
import requests
import urllib3

import websocket
from websocket import create_connection, WebSocket

from msrestazure.azure_exceptions import CloudError
from .BastionServiceConstants import BastionSku

from azure.cli.core._profile import Profile
from azure.cli.core.util import should_disable_connection_verify

import requests
import urllib3

from knack.util import CLIError
from knack.log import get_logger

from .BastionServiceConstants import BastionSku

logger = get_logger(__name__)


Expand Down Expand Up @@ -96,7 +96,7 @@ def _get_auth_token(self):
logger.debug("Content: %s", str(content))
web_address = f"https://{self.bastion_endpoint}/api/tokens"
response = requests.post(web_address, data=content, headers=custom_header,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())
response_json = None

if response.content is not None:
Expand All @@ -121,7 +121,8 @@ def _listen(self):
self.client, _address = self.sock.accept()

auth_token = self._get_auth_token()
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or self.bastion['sku']['name'] == BastionSku.Developer.name:
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or \
self.bastion['sku']['name'] == BastionSku.Developer.name:
host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}"
else:
host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}"
Expand Down Expand Up @@ -204,7 +205,7 @@ def cleanup(self):

web_address = f"https://{self.bastion_endpoint}/api/tokens/{self.last_token}"
response = requests.delete(web_address, headers=custom_header,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())
if response.status_code == 404:
logger.info('Session already deleted')
elif response.status_code not in [200, 204]:
Expand Down
2 changes: 1 addition & 1 deletion src/bastion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


# HISTORY.rst entry.
VERSION = '0.2.5'
VERSION = '0.2.6'

# The full list of classifiers is available at
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
Expand Down

0 comments on commit 1d9de52

Please sign in to comment.