Skip to content

Commit

Permalink
[task][vm] add remote host to ssh known hosts (#1342)
Browse files Browse the repository at this point in the history
  • Loading branch information
pducolin authored Jan 7, 2025
1 parent e8f27fd commit 4372e33
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ Pulumi.*.yaml
# Build files
dist/main
mem.pprof

# Python cache files
**/__pycache__/
4 changes: 2 additions & 2 deletions integration-tests/invoke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func testAzureInvokeVM(t *testing.T, tmpConfigFile string, workingDirectory stri
stackName = sanitizeStackName(stackName)

t.Log("creating vm")
createCmd := exec.Command("invoke", "az.create-vm", "--no-interactive", "--stack-name", stackName, "--config-path", tmpConfigFile)
createCmd := exec.Command("invoke", "az.create-vm", "--no-interactive", "--stack-name", stackName, "--config-path", tmpConfigFile, "--no-add-known-host")
createCmd.Dir = workingDirectory
createOutput, err := createCmd.Output()
assert.NoError(t, err, "Error found creating vm: %s", string(createOutput))
Expand All @@ -103,7 +103,7 @@ func testAwsInvokeVM(t *testing.T, tmpConfigFile string, workingDirectory string
stackName = sanitizeStackName(stackName)

t.Log("creating vm")
createCmd := exec.Command("invoke", "aws.create-vm", "--no-interactive", "--stack-name", stackName, "--config-path", tmpConfigFile, "--use-fakeintake")
createCmd := exec.Command("invoke", "aws.create-vm", "--no-interactive", "--stack-name", stackName, "--config-path", tmpConfigFile, "--use-fakeintake", "--no-add-known-host")
createCmd.Dir = workingDirectory
createOutput, err := createCmd.Output()
assert.NoError(t, err, "Error found creating vm: %s", string(createOutput))
Expand Down
9 changes: 8 additions & 1 deletion tasks/aws/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from tasks.aws.deploy import deploy
from tasks.destroy import destroy
from tasks.tool import add_known_host as add_known_host_func
from tasks.tool import clean_known_hosts as clean_known_hosts_func
from tasks.tool import get_host, notify, show_connection_message

Expand Down Expand Up @@ -43,6 +44,7 @@
"no_verify": doc.no_verify,
"ssh_user": doc.ssh_user,
"os_version": doc.os_version,
"add_known_host": doc.add_known_host,
}
)
def create_vm(
Expand All @@ -64,6 +66,7 @@ def create_vm(
instance_type: Optional[str] = None,
no_verify: Optional[bool] = False,
ssh_user: Optional[str] = None,
add_known_host: Optional[bool] = True,
) -> None:
"""
Create a new virtual machine on aws.
Expand Down Expand Up @@ -111,6 +114,10 @@ def create_vm(
if interactive:
notify(ctx, "Your VM is now created")

if add_known_host:
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
add_known_host_func(ctx, host)

show_connection_message(ctx, remote_hostname, full_stack_name, interactive)


Expand Down Expand Up @@ -138,7 +145,7 @@ def destroy_vm(
stack=stack_name,
)
if clean_known_hosts:
clean_known_hosts_func(host)
clean_known_hosts_func(ctx, host)


def _get_os_family(os_family: Optional[str]) -> str:
Expand Down
11 changes: 9 additions & 2 deletions tasks/azure/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tasks.config import get_full_profile_path
from tasks.deploy import deploy
from tasks.destroy import destroy
from tasks.tool import add_known_host as add_known_host_func
from tasks.tool import clean_known_hosts as clean_known_hosts_func
from tasks.tool import get_host, show_connection_message

Expand All @@ -38,6 +39,7 @@
"architecture": azure_doc.architecture,
"instance_type": azure_doc.instance_type,
"os_version": doc.os_version,
"add_known_host": doc.add_known_host,
}
)
def create_vm(
Expand All @@ -58,6 +60,7 @@ def create_vm(
deploy_job: Optional[str] = None,
no_verify: Optional[bool] = False,
use_fakeintake: Optional[bool] = False,
add_known_host: Optional[bool] = True,
) -> None:
"""
Create a new virtual machine on azure.
Expand Down Expand Up @@ -105,6 +108,10 @@ def create_vm(
if interactive:
tool.notify(ctx, "Your VM is now created")

if add_known_host:
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
add_known_host_func(ctx, host)

show_connection_message(ctx, remote_hostname, full_stack_name, interactive)


Expand All @@ -124,15 +131,15 @@ def destroy_vm(
"""
Destroy a new virtual machine on azure.
"""
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
destroy(
ctx,
scenario_name=scenario_name,
config_path=config_path,
stack=stack_name,
)
if clean_known_hosts:
clean_known_hosts_func(host)
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
clean_known_hosts_func(ctx, host)


def _get_os_information(os_family: Optional[str], arch: Optional[str]) -> Tuple[str, Optional[str]]:
Expand Down
1 change: 1 addition & 0 deletions tasks/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
os_version: str = "The version of the OS to use (default will be selected depending on the OS family). See https://github.com/DataDog/test-infra-definitions/blob/main/components/os/linux_descriptors.go for a list of version available for a given OS (https://github.com/DataDog/test-infra-definitions/blob/main/components/os/windows_descriptors.go for Windows)"
full_image_path: str = "The full image path (registry:tag) of the Agent image to deploy"
cluster_agent_full_image_path: str = "The full image path (registry:tag) of the Cluster Agent image to deploy"
add_known_host: str = "Add the host to the known_hosts file (default True)"
22 changes: 18 additions & 4 deletions tasks/gcp/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@
get_deploy_job,
get_os_families,
)
from tasks.tool import clean_known_hosts as clean_known_hosts_func
from tasks.tool import get_host, show_connection_message
from tasks.tool import (
add_known_host as add_known_hosts_func,
)
from tasks.tool import (
clean_known_hosts as clean_known_hosts_func,
)
from tasks.tool import (
get_host,
show_connection_message,
)

scenario_name = "gcp/vm"
remote_hostname = "gcp-vm"
Expand All @@ -38,6 +46,7 @@
"architecture": gcp_doc.architecture,
"instance_type": gcp_doc.instance_type,
"os_version": doc.os_version,
"add_known_host": doc.add_known_host,
}
)
def create_vm(
Expand All @@ -58,6 +67,7 @@ def create_vm(
deploy_job: Optional[str] = None,
no_verify: Optional[bool] = False,
use_fakeintake: Optional[bool] = False,
add_known_host: Optional[bool] = True,
) -> None:
"""
Create a new virtual machine on gcp.
Expand Down Expand Up @@ -105,6 +115,10 @@ def create_vm(
if interactive:
tool.notify(ctx, "Your VM is now created")

if add_known_host:
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
add_known_hosts_func(ctx, host)

show_connection_message(ctx, remote_hostname, full_stack_name, interactive)


Expand All @@ -124,15 +138,15 @@ def destroy_vm(
"""
Destroy a new virtual machine on gcp.
"""
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
destroy(
ctx,
scenario_name=scenario_name,
config_path=config_path,
stack=stack_name,
)
if clean_known_hosts:
clean_known_hosts_func(host)
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
clean_known_hosts_func(ctx, host)


def _get_os_information(os_family: Optional[str], arch: Optional[str]) -> Tuple[str, Optional[str]]:
Expand Down
15 changes: 14 additions & 1 deletion tasks/localpodman/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from tasks import config, doc
from tasks.deploy import deploy
from tasks.destroy import destroy
from tasks.tool import notify, show_connection_message
from tasks.tool import add_known_host as add_known_host_func
from tasks.tool import clean_known_hosts as clean_known_hosts_func
from tasks.tool import get_host, notify, show_connection_message

scenario_name = "localpodman/vm"
remote_hostname = "local-podman-vm"
Expand All @@ -24,6 +26,7 @@
"debug": doc.debug,
"use_fakeintake": doc.fakeintake,
"interactive": doc.interactive,
"add_known_host": doc.add_known_host,
}
)
def create_vm(
Expand All @@ -36,6 +39,7 @@ def create_vm(
debug: Optional[bool] = False,
use_fakeintake: Optional[bool] = False,
interactive: Optional[bool] = True,
add_known_host: Optional[bool] = True,
) -> None:
"""
Create a new virtual machine on local podman.
Expand Down Expand Up @@ -69,19 +73,25 @@ def create_vm(
if interactive:
notify(ctx, "Your VM is now created")

if add_known_host:
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
add_known_host_func(ctx, host)

show_connection_message(ctx, remote_hostname, full_stack_name, interactive)


@task(
help={
"config_path": doc.config_path,
"stack_name": doc.stack_name,
"clean_known_hosts": doc.clean_known_hosts,
}
)
def destroy_vm(
ctx: Context,
config_path: Optional[str] = None,
stack_name: Optional[str] = None,
clean_known_hosts: Optional[bool] = True,
):
"""
Destroy a new virtual machine on aws.
Expand All @@ -92,3 +102,6 @@ def destroy_vm(
config_path=config_path,
stack=stack_name,
)
if clean_known_hosts:
host = get_host(ctx, remote_hostname, scenario_name, stack_name)
clean_known_hosts_func(ctx, host)
24 changes: 16 additions & 8 deletions tasks/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,17 +262,25 @@ def show_connection_message(
pyperclip.copy(command)


def clean_known_hosts(host: str) -> None:
def add_known_host(ctx: Context, host: str) -> None:
"""
Remove the host from the known_hosts file.
Add the host to the known_hosts file.
"""
home = os.environ.get("HOME", f"/Users/{getpass.getuser()}")
with open(f"{home}/.ssh/known_hosts") as f:
lines = f.readlines()
# remove the host if it already exists
clean_known_hosts(ctx, host)
result = ctx.run(f"ssh-keyscan {host}", hide=True)
if result and result.ok:
home = pathlib.Path.home()
filtered_hosts = '\n'.join([line for line in result.stdout.splitlines() if not line.startswith("#")])
with open(os.path.join(home, ".ssh", "known_hosts"), "a") as f:
f.write(filtered_hosts)


filtered_lines = [line for line in lines if not line.startswith(host)]
with open(f"{home}/.ssh/known_hosts", "w") as f:
f.writelines(filtered_lines)
def clean_known_hosts(ctx: Context, host: str) -> None:
"""
Remove the host from the known_hosts file.
"""
ctx.run(f"ssh-keygen -R {host}", hide=True)


def get_host(ctx: Context, remote_host_name: str, scenario_name: str, stack_name: Optional[str] = None) -> str:
Expand Down

0 comments on commit 4372e33

Please sign in to comment.