Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plan command to the CLI #31

Merged
merged 7 commits into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/changeset/change_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Optional
from pydantic import BaseModel
import string


class Change(BaseModel):
    """Describes what a given change is and how to apply it."""

    identifier: str  # unique name of the object being changed (e.g. job identifier)
    type: str  # kind of object changed, e.g. "job" or "env var overwrite"
    action: str  # one of "create", "update", "delete"
    sync_function: object  # callable invoked by apply() to perform the change
    parameters: dict  # keyword arguments passed to sync_function

    def __str__(self):
        # Render a one-line, plan-style summary, e.g. "UPDATE Job my_job".
        return f"{self.action.upper()} {string.capwords(self.type)} {self.identifier}"

    def apply(self):
        """Perform the change by calling the stored sync function with its parameters."""
        self.sync_function(**self.parameters)


class ChangeSet(BaseModel):
    """Store the set of changes to be displayed or applied."""

    __root__: Optional[list[Change]] = []

    def __iter__(self):
        """Iterate over the individual changes, in insertion order."""
        return iter(self.__root__)

    def append(self, change: Change):
        """Add one change to the set."""
        self.__root__.append(change)

    def __str__(self):
        """Render the set as one summary line per change."""
        return "\n".join(str(item) for item in self.__root__)

    def __len__(self):
        """Number of changes currently in the set."""
        return len(self.__root__)

    def apply(self):
        """Apply every change in the set, in order."""
        for item in self.__root__:
            item.apply()
48 changes: 12 additions & 36 deletions src/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CustomEnvironmentVariablePayload,
)
from schemas.job import JobDefinition
from schemas import check_env_var_same


class DBTCloud:
Expand Down Expand Up @@ -65,7 +66,7 @@ def update_job(self, job: JobDefinition) -> JobDefinition:
if response.status_code >= 400:
logger.error(response.json())

logger.success("Updated successfully.")
logger.success("Job updated successfully.")

return JobDefinition(**(response.json()["data"]), identifier=job.identifier)

Expand All @@ -83,7 +84,7 @@ def create_job(self, job: JobDefinition) -> JobDefinition:
if response.status_code >= 400:
logger.error(response.json())

logger.success("Created successfully.")
logger.success("Job created successfully.")

return JobDefinition(**(response.json()["data"]), identifier=job.identifier)

Expand All @@ -100,7 +101,7 @@ def delete_job(self, job: JobDefinition) -> None:
if response.status_code >= 400:
logger.error(response.json())

logger.warning("Deleted successfully.")
logger.success("Job deleted successfully.")

def get_jobs(self) -> List[JobDefinition]:
"""Return a list of Jobs for all the dbt Cloud jobs in an environment."""
Expand Down Expand Up @@ -206,41 +207,16 @@ def create_env_var(
return response.json()["data"]

def update_env_var(
self, custom_env_var: CustomEnvironmentVariable, project_id: int, job_id: int
) -> Optional[CustomEnvironmentVariablePayload]:
self, custom_env_var: CustomEnvironmentVariable, project_id: int, job_id: int, env_var_id: int, yml_job_identifier: str = None) -> Optional[CustomEnvironmentVariablePayload]:
"""Update env vars job overwrite in dbt Cloud."""

self._check_for_creds()

all_env_vars = self.get_env_vars(project_id, job_id)

if custom_env_var.name not in all_env_vars:
raise Exception(
f"Custom environment variable {custom_env_var.name} not found in dbt Cloud, "
f"you need to create it first."
)

env_var_id: Optional[int]

# TODO: Move this logic out of the client layer, and move it into
# at least one layer higher up. We want the dbt Cloud client to be
# as naive as possible.
if custom_env_var.name not in all_env_vars:
return self.create_env_var(
CustomEnvironmentVariablePayload(
account_id=self.account_id,
project_id=project_id,
**custom_env_var.dict(),
)
)

if all_env_vars[custom_env_var.name].value == custom_env_var.value:
logger.debug(
f"The env var {custom_env_var.name} is already up to date for the job {job_id}."
)
return None

env_var_id: int = all_env_vars[custom_env_var.name].id
# handle the case where the job was not created when we queued the function call
if yml_job_identifier and not job_id:
mapping_job_identifier_job_id = self.build_mapping_job_identifier_job_id()
job_id = mapping_job_identifier_job_id[yml_job_identifier]
custom_env_var.job_definition_id = job_id

# the endpoint is different for updating an overwrite vs creating one
if env_var_id:
Expand All @@ -266,7 +242,7 @@ def update_env_var(

self._clear_env_var_cache(job_definition_id=payload.job_definition_id)

logger.info(f"Updated the env_var {custom_env_var.name} for job {job_id}")
logger.success(f"Updated the env_var {custom_env_var.name} for job {job_id}")
return CustomEnvironmentVariablePayload(**(response.json()["data"]))

def delete_env_var(self, project_id: int, env_var_id: int) -> None:
Expand All @@ -282,7 +258,7 @@ def delete_env_var(self, project_id: int, env_var_id: int) -> None:
if response.status_code >= 400:
logger.error(response.json())

logger.warning("Deleted successfully.")
logger.success("Env Var Job Overwrite deleted successfully.")

def get_environments(self) -> Dict:
"""Return a list of Environments for all the dbt Cloud jobs in an account"""
Expand Down
180 changes: 142 additions & 38 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,13 @@
from client import DBTCloud
from loader.load import load_job_configuration
from schemas import check_job_mapping_same
from changeset.change_set import Change, ChangeSet
from schemas import check_env_var_same


@click.group()
def cli():
pass


@cli.command()
@click.argument("config", type=click.File("r"))
def sync(config):
"""Synchronize a dbt Cloud job config file against dbt Cloud.
def build_change_set(config):
"""Compares the config of YML files versus dbt Cloud.
Depending on the value of no_update, it will either update the dbt Cloud config or not.

CONFIG is the path to your jobs.yml config file.
"""
Expand All @@ -35,6 +31,8 @@ def sync(config):
job.identifier: job for job in cloud_jobs if job.identifier is not None
}

dbt_cloud_change_set = ChangeSet()

# Use sets to find jobs for different operations
shared_jobs = set(defined_jobs.keys()).intersection(set(tracked_jobs.keys()))
created_jobs = set(defined_jobs.keys()) - set(tracked_jobs.keys())
Expand All @@ -47,18 +45,39 @@ def sync(config):
if not check_job_mapping_same(
source_job=defined_jobs[identifier], dest_job=tracked_jobs[identifier]
):
dbt_cloud_change = Change(
identifier=identifier,
type="job",
action="update",
sync_function=dbt_cloud.update_job,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a really clever approach! I appreciate how configurable this callback function approach will be 🚀

parameters={"job": defined_jobs[identifier]},
)
dbt_cloud_change_set.append(dbt_cloud_change)
defined_jobs[identifier].id = tracked_jobs[identifier].id
dbt_cloud.update_job(job=defined_jobs[identifier])

# Create new jobs
logger.info("Detected {count} new jobs.", count=len(created_jobs))
for identifier in created_jobs:
dbt_cloud.create_job(job=defined_jobs[identifier])
dbt_cloud_change = Change(
identifier=identifier,
type="job",
action="create",
sync_function=dbt_cloud.create_job,
parameters={"job": defined_jobs[identifier]},
)
dbt_cloud_change_set.append(dbt_cloud_change)

# Remove Deleted Jobs
logger.warning("Detected {count} deleted jobs.", count=len(deleted_jobs))
for identifier in deleted_jobs:
dbt_cloud.delete_job(job=tracked_jobs[identifier])
dbt_cloud_change = Change(
identifier=identifier,
type="job",
action="delete",
sync_function=dbt_cloud.delete_job,
parameters={"job": tracked_jobs[identifier]},
)
dbt_cloud_change_set.append(dbt_cloud_change)

# -- ENV VARS --
# Now that we have replicated all jobs we can get their IDs for further API calls
Expand All @@ -67,37 +86,122 @@ def sync(config):

# Replicate the env vars from the YML to dbt Cloud
for job in defined_jobs.values():
job_id = mapping_job_identifier_job_id[job.identifier]
for env_var_yml in job.custom_environment_variables:
env_var_yml.job_definition_id = job_id
updated_env_vars = dbt_cloud.update_env_var(
project_id=job.project_id, job_id=job_id, custom_env_var=env_var_yml

if job.identifier in mapping_job_identifier_job_id: # the job already exists
job_id = mapping_job_identifier_job_id[job.identifier]
all_env_vars_for_job = dbt_cloud.get_env_vars(
project_id=job.project_id, job_id=job_id
)
for env_var_yml in job.custom_environment_variables:
env_var_yml.job_definition_id = job_id
same_env_var, env_var_id = check_env_var_same(
source_env_var=env_var_yml, dest_env_vars=all_env_vars_for_job
)
if not same_env_var:
dbt_cloud_change = Change(
identifier=f"{job.identifier}:{env_var_yml.name}",
type="env var overwrite",
action="update",
sync_function=dbt_cloud.update_env_var,
parameters={
"project_id": job.project_id,
"job_id": job_id,
"custom_env_var": env_var_yml,
"env_var_id": env_var_id,
},
)
dbt_cloud_change_set.append(dbt_cloud_change)

else: # the job doesn't exist yet so it doesn't have an ID
for env_var_yml in job.custom_environment_variables:
dbt_cloud_change = Change(
identifier=f"{job.identifier}:{env_var_yml.name}",
type="env var overwrite",
action="create",
sync_function=dbt_cloud.update_env_var,
parameters={
"project_id": job.project_id,
"job_id": None,
"custom_env_var": env_var_yml,
"env_var_id": None,
"yml_job_identifier": job.identifier,
},
)
dbt_cloud_change_set.append(dbt_cloud_change)

# Delete the env vars from dbt Cloud that are not in the yml
for job in defined_jobs.values():
job_id = mapping_job_identifier_job_id[job.identifier]

# We get the env vars from dbt Cloud, now that the YML ones have been replicated
env_var_dbt_cloud = dbt_cloud.get_env_vars(
project_id=job.project_id, job_id=job_id
)
# we only delete env var overwrite if the job already exists
if job.identifier in mapping_job_identifier_job_id:
job_id = mapping_job_identifier_job_id[job.identifier]

# And we get the list of env vars defined for a given job in the YML
env_vars_for_job = [
env_var.name for env_var in job.custom_environment_variables
]
# We get the env vars from dbt Cloud, now that the YML ones have been replicated
env_var_dbt_cloud = dbt_cloud.get_env_vars(
project_id=job.project_id, job_id=job_id
)

for env_var, env_var_val in env_var_dbt_cloud.items():
# If the env var is not in the YML but is defined at the "job" level in dbt Cloud, we delete it
if env_var not in env_vars_for_job and env_var_val.id:
logger.info(f"{env_var} not in the YML file but in the dbt Cloud job")
dbt_cloud.delete_env_var(
project_id=job.project_id, env_var_id=env_var_val.id
)
logger.info(
f"Deleted the env_var {env_var} for the job {job.identifier}"
)
# And we get the list of env vars defined for a given job in the YML
env_vars_for_job = [
env_var.name for env_var in job.custom_environment_variables
]

for env_var, env_var_val in env_var_dbt_cloud.items():
# If the env var is not in the YML but is defined at the "job" level in dbt Cloud, we delete it
if env_var not in env_vars_for_job and env_var_val.id:
logger.info(
f"{env_var} not in the YML file but in the dbt Cloud job"
)
dbt_cloud_change = Change(
identifier=f"{job.identifier}:{env_var_yml.name}",
type="env var overwrite",
action="delete",
sync_function=dbt_cloud.delete_env_var,
parameters={
"project_id": job.project_id,
"env_var_id": env_var_val.id,
},
)
dbt_cloud_change_set.append(dbt_cloud_change)

return dbt_cloud_change_set


@click.group()
def cli():
pass


@cli.command()
@click.argument("config", type=click.File("r"))
def sync(config):
    """Synchronize a dbt Cloud job config file against dbt Cloud.

    CONFIG is the path to your jobs.yml config file.
    """
    change_set = build_change_set(config)
    num_changes = len(change_set)
    # Guard clause: nothing to plan or apply.
    if not num_changes:
        logger.success("-- PLAN -- No changes detected.")
        return
    logger.warning("-- PLAN -- {count} changes detected.", count=num_changes)
    print(change_set)
    logger.info("-- SYNC --")
    change_set.apply()


@cli.command()
@click.argument("config", type=click.File("r"))
def plan(config):
    """Check the difference between a local file and dbt Cloud without updating dbt Cloud.

    CONFIG is the path to your jobs.yml config file.
    """
    change_set = build_change_set(config)
    num_changes = len(change_set)
    # Guard clause: report a clean plan and stop.
    if not num_changes:
        logger.success("-- PLAN -- No changes detected.")
        return
    logger.warning("-- PLAN -- {count} changes detected.", count=num_changes)
    print(change_set)


@cli.command()
Expand All @@ -121,11 +225,11 @@ def validate(config, online):
if not online:
return

# Retrive the list of Project IDs and Environment IDs from the config file
# Retrieve the list of Project IDs and Environment IDs from the config file
config_project_ids = set([job.project_id for job in defined_jobs])
config_environment_ids = set([job.environment_id for job in defined_jobs])

# Retrieve the list of Project IDs and Environment IDs from dbt Cloudby calling the environment API endpoint
# Retrieve the list of Project IDs and Environment IDs from dbt Cloud by calling the environment API endpoint
dbt_cloud = DBTCloud(
account_id=list(defined_jobs)[0].account_id,
api_key=os.environ.get("API_KEY"),
Expand Down
Loading