Use new workspace_url format + fix azure test #114

Merged Jun 24, 2020 · 23 commits
Changes from 1 commit
21 changes: 12 additions & 9 deletions .devcontainer/devcontainer.json
@@ -3,6 +3,14 @@
{
"name": "databricks-terraform",
"dockerFile": "Dockerfile",
"mounts": [
// Mount go mod cache
"source=databricks-terraform-gomodcache,target=/go/pkg,type=volume",
// Keep command history
"source=databricks-terraform-bashhistory,target=/commandhistory,type=volume",
// Mount docker socket for docker builds
"source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind",
],
"runArgs": [
// Uncomment the next line to use a non-root user. On Linux, this will prevent
// new files getting created as root, but you may need to update the USER_UID
@@ -12,14 +20,8 @@
"--security-opt",
"seccomp=unconfined",
"--privileged",

"--name", "databricks-terraform-devcontainer",
// Mount go mod cache
"-v", "databricks-terraform-gomodcache:/go/pkg",
// Keep command history
"-v", "databricks-terraform-bashhistory:/root/commandhistory",
// Mount docker socket for docker builds
"-v", "/var/run/docker.sock:/var/run/docker.sock",
"--name",
"databricks-terraform-devcontainer",
// Use host network
"--network=host",
],
@@ -51,8 +53,9 @@
// "postCreateCommand": "go version",
// Add the IDs of extensions you want installed when the container is created in the array below.
"extensions": [
"ms-vscode.go",
"golang.go",
"mauve.terraform",
"hashicorp.terraform",
"ms-vsliveshare.vsliveshare"
]
}
2 changes: 1 addition & 1 deletion Makefile
@@ -46,7 +46,7 @@ vendor:
# INTEGRATION TESTING WITH AZURE
terraform-acc-azure: fmt
@echo "==> Running Terraform Acceptance Tests for Azure..."
@CLOUD_ENV="azure" TF_ACC=1 gotestsum --format short-verbose --raw-command go test -v -json -tags=azure -short -coverprofile=coverage.out ./...
@/bin/bash integration-environment-azure/run.sh

# INTEGRATION TESTING WITH AWS
terraform-acc-aws: fmt
52 changes: 10 additions & 42 deletions databricks/azure_auth.go
@@ -3,12 +3,12 @@ package databricks
import (
"encoding/json"
"fmt"
"log"
"net/http"

"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/databrickslabs/databricks-terraform/client/service"
"log"
"net/http"
urlParse "net/url"
)

// List of management information
@@ -23,6 +23,7 @@ type AzureAuth struct {
AdbWorkspaceResourceID string
AdbAccessToken string
AdbPlatformToken string
AdbWorkspaceUrl string
}

// TokenPayload contains all the auth information for azure sp authentication
@@ -73,38 +74,6 @@ func (a *AzureAuth) getManagementToken(config *service.DBApiClientConfig) error
return nil
}

func (a *AzureAuth) getWorkspaceID(config *service.DBApiClientConfig) error {
log.Println("[DEBUG] Getting Workspace ID via management token.")
// Escape all the ids
url := fmt.Sprintf("https://management.azure.com/subscriptions/%s/resourceGroups/%s"+
"/providers/Microsoft.Databricks/workspaces/%s",
urlParse.PathEscape(a.TokenPayload.SubscriptionID),
urlParse.PathEscape(a.TokenPayload.ResourceGroup),
urlParse.PathEscape(a.TokenPayload.WorkspaceName))
headers := map[string]string{
"Content-Type": "application/json",
"cache-control": "no-cache",
"Authorization": "Bearer " + a.ManagementToken,
}
type apiVersion struct {
ApiVersion string `url:"api-version"`
}
uriPayload := apiVersion{
ApiVersion: "2018-04-01",
}
var responseMap map[string]interface{}
resp, err := service.PerformQuery(config, http.MethodGet, url, "2.0", headers, false, true, uriPayload, nil)
if err != nil {
return err
}
err = json.Unmarshal(resp, &responseMap)
if err != nil {
return err
}
a.AdbWorkspaceResourceID = responseMap["id"].(string)
return err
}

func (a *AzureAuth) getADBPlatformToken(clientConfig *service.DBApiClientConfig) error {
log.Println("[DEBUG] Creating Azure Databricks management OAuth token.")
platformTokenOAuthCfg, err := adal.NewOAuthConfigWithAPIVersion(azure.PublicCloud.ActiveDirectoryEndpoint,
@@ -177,11 +146,11 @@ func (a *AzureAuth) initWorkspaceAndGetClient(config *service.DBApiClientConfig)
return err
}

// Get workspace access token
err = a.getWorkspaceID(config)
if err != nil {
return err
}
a.AdbWorkspaceResourceID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s"+
"/providers/Microsoft.Databricks/workspaces/%s",
a.TokenPayload.SubscriptionID,
a.TokenPayload.ResourceGroup,
a.TokenPayload.WorkspaceName)

// Get platform token
err = a.getADBPlatformToken(config)
@@ -195,8 +164,7 @@ func (a *AzureAuth) initWorkspaceAndGetClient(config *service.DBApiClientConfig)
return err
}

//// TODO: Eventually change this to include new Databricks domain names. May have to add new vars and/or deprecate existing args.
config.Host = "https://" + a.TokenPayload.AzureRegion + ".azuredatabricks.net"
config.Host = a.AdbWorkspaceUrl
config.Token = a.AdbAccessToken

return nil
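Reviewer note: dropping getWorkspaceID removes a round trip to the Azure management API; the workspace resource ID is now assembled locally from values the caller already supplied. One behavioral difference worth flagging: the removed code path-escaped each segment with urlParse.PathEscape, while the inline fmt.Sprintf does not, which is only safe if the subscription ID, resource group, and workspace name never contain reserved URL characters. A minimal sketch of the new shape (the helper name and example values are hypothetical, not part of this PR):

    // buildWorkspaceResourceID mirrors the inline fmt.Sprintf in
    // initWorkspaceAndGetClient above; it performs no escaping.
    func buildWorkspaceResourceID(subscriptionID, resourceGroup, workspaceName string) string {
        return fmt.Sprintf("/subscriptions/%s/resourceGroups/%s"+
            "/providers/Microsoft.Databricks/workspaces/%s",
            subscriptionID, resourceGroup, workspaceName)
    }

    // buildWorkspaceResourceID("00000000-0000-0000-0000-000000000000", "my-rg", "my-ws")
    // => "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg/providers/Microsoft.Databricks/workspaces/my-ws"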
15 changes: 15 additions & 0 deletions databricks/provider.go
@@ -119,6 +119,11 @@ func Provider(version string) terraform.ResourceProvider {
Required: true,
DefaultFunc: schema.EnvDefaultFunc("DATABRICKS_AZURE_WORKSPACE_NAME", nil),
},
"workspace_url": {
Type: schema.TypeString,
Required: true,
DefaultFunc: schema.EnvDefaultFunc("DATABRICKS_AZURE_WORKSPACE_URL", nil),
},
"resource_group": {
Type: schema.TypeString,
Required: true,
@@ -169,6 +174,8 @@ func providerConfigureAzureClient(d *schema.ResourceData, providerVersion string
azureAuthMap := azureAuth.(map[string]interface{})
//azureAuth AzureAuth{}
tokenPayload := TokenPayload{}
adbWorkspaceURL := ""

// The if/else is required because the "azure_auth" schema object is a map, not a block.
// Maps do not automatically populate defaults from environment variables unless we assign values explicitly,
// which makes this very difficult to test.
@@ -192,6 +199,12 @@
} else if os.Getenv("DATABRICKS_AZURE_WORKSPACE_NAME") != "" {
tokenPayload.WorkspaceName = os.Getenv("DATABRICKS_AZURE_WORKSPACE_NAME")
}
// TODO: Can required field not be set?
if workspaceURL, ok := azureAuthMap["workspace_url"].(string); ok {
adbWorkspaceURL = workspaceURL
} else if os.Getenv("DATABRICKS_AZURE_WORKSPACE_URL") != "" {
adbWorkspaceURL = os.Getenv("DATABRICKS_AZURE_WORKSPACE_URL")
}

// This provider takes DATABRICKS_AZURE_* for client ID etc
// The azurerm provider uses ARM_* for the same values
@@ -228,12 +241,14 @@ func providerConfigureAzureClient(d *schema.ResourceData, providerVersion string
tokenPayload.TenantID = os.Getenv("ARM_TENANT_ID")
}

adbWorkspaceURL = "https://" + adbWorkspaceURL
azureAuthSetup := AzureAuth{
TokenPayload: &tokenPayload,
ManagementToken: "",
AdbWorkspaceResourceID: "",
AdbAccessToken: "",
AdbPlatformToken: "",
AdbWorkspaceUrl: adbWorkspaceURL,
}

// Setup the CustomAuthorizer Function to be called at API invoke rather than client invoke
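Reviewer note: the new workspace_url resolves in the same order as the other azure_auth fields: the explicit map entry wins, then the DATABRICKS_AZURE_WORKSPACE_URL environment variable, and the https:// scheme is prefixed unconditionally afterwards. A condensed sketch of that flow (the resolveWorkspaceURL helper is hypothetical; the provider inlines this logic above):

    // resolveWorkspaceURL condenses the map-then-env fallback used for
    // workspace_url; a failed type assertion falls through to the env var.
    func resolveWorkspaceURL(azureAuthMap map[string]interface{}) string {
        url := ""
        if workspaceURL, ok := azureAuthMap["workspace_url"].(string); ok {
            url = workspaceURL
        } else if os.Getenv("DATABRICKS_AZURE_WORKSPACE_URL") != "" {
            url = os.Getenv("DATABRICKS_AZURE_WORKSPACE_URL")
        }
        return "https://" + url
    }

Note that if both sources are empty this yields the invalid host "https://", which is presumably what the TODO above is probing at; validating here would fail faster.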
26 changes: 5 additions & 21 deletions databricks/resource_databricks_azure_adls_gen2_mount_test.go
@@ -42,34 +42,18 @@ func testAccAzureAdlsGen2Mount_correctly_mounts() string {
clientID := os.Getenv("ARM_CLIENT_ID")
clientSecret := os.Getenv("ARM_CLIENT_SECRET")
tenantID := os.Getenv("ARM_TENANT_ID")
subscriptionID := os.Getenv("ARM_SUBSCRIPTION_ID")
workspaceName := os.Getenv("TEST_WORKSPACE_NAME")
resourceGroupName := os.Getenv("TEST_RESOURCE_GROUP")
managedResourceGroupName := os.Getenv("TEST_MANAGED_RESOURCE_GROUP")
location := os.Getenv("TEST_LOCATION")
gen2AdalName := os.Getenv("TEST_GEN2_ADAL_NAME")

definition := fmt.Sprintf(`
provider "databricks" {
azure_auth = {
client_id = "%[1]s"
client_secret = "%[2]s"
tenant_id = "%[3]s"
subscription_id = "%[4]s"

workspace_name = "%[5]s"
resource_group = "%[6]s"
managed_resource_group = "%[7]s"
azure_region = "%[8]s"
}
}

resource "databricks_cluster" "cluster" {
num_workers = 1
spark_version = "6.4.x-scala2.11"
node_type_id = "Standard_D3_v2"
# Don't spend too much, turn off cluster after 15mins
autotermination_minutes = 15
spark_conf = {
"spark.databricks.delta.preview.enabled": "false"
}
}

resource "databricks_secret_scope" "terraform" {
@@ -88,7 +72,7 @@ func testAccAzureAdlsGen2Mount_correctly_mounts() string {
resource "databricks_azure_adls_gen2_mount" "mount" {
cluster_id = databricks_cluster.cluster.id
container_name = "dev" # Created by prereqs.tf
storage_account_name = "%[9]s"
storage_account_name = "%[4]s"
directory = ""
mount_name = "localdir${databricks_cluster.cluster.cluster_id}"
tenant_id = "%[3]s"
@@ -98,7 +82,7 @@
initialize_file_system = true
}

`, clientID, clientSecret, tenantID, subscriptionID, workspaceName, resourceGroupName, managedResourceGroupName, location, gen2AdalName)
`, clientID, clientSecret, tenantID, gen2AdalName)
return definition
}

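Reviewer note: with the provider block gone from the fixture, the template keeps only four Sprintf arguments, so every later indexed verb had to be renumbered (storage_account_name moves from %[9]s to %[4]s). A toy illustration of the mechanic, with hypothetical values:

    // Indexed verbs name argument positions explicitly, so removing
    // arguments renumbers every verb that follows:
    fmt.Sprintf(`tenant_id = "%[3]s", storage_account_name = "%[4]s"`,
        clientID, clientSecret, tenantID, gen2AdalName)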
28 changes: 3 additions & 25 deletions databricks/resource_databricks_azure_blob_mount_test.go
@@ -97,32 +97,10 @@ func testAccAzureBlobMount_mount_exists(n string, azureBlobMount *AzureBlobMount
}

func testAccAzureBlobMount_correctly_mounts() string {
clientID := os.Getenv("ARM_CLIENT_ID")
clientSecret := os.Getenv("ARM_CLIENT_SECRET")
tenantID := os.Getenv("ARM_TENANT_ID")
subscriptionID := os.Getenv("ARM_SUBSCRIPTION_ID")
workspaceName := os.Getenv("TEST_WORKSPACE_NAME")
resourceGroupName := os.Getenv("TEST_RESOURCE_GROUP")
managedResourceGroupName := os.Getenv("TEST_MANAGED_RESOURCE_GROUP")
location := os.Getenv("TEST_LOCATION")
blobAccountKey := os.Getenv("TEST_STORAGE_ACCOUNT_KEY")
blobAccountName := os.Getenv("TEST_STORAGE_ACCOUNT_NAME")

definition := fmt.Sprintf(`
provider "databricks" {
azure_auth = {
client_id = "%[1]s"
client_secret = "%[2]s"
tenant_id = "%[3]s"
subscription_id = "%[4]s"

workspace_name = "%[5]s"
resource_group = "%[6]s"
managed_resource_group = "%[7]s"
azure_region = "%[8]s"
}
}

resource "databricks_cluster" "cluster" {
num_workers = 1
spark_version = "6.4.x-scala2.11"
@@ -140,20 +118,20 @@ func testAccAzureBlobMount_correctly_mounts() string {

resource "databricks_secret" "storage_key" {
key = "blob_storage_key"
string_value = "%[10]s"
string_value = "%[1]s"
scope = databricks_secret_scope.terraform.name
}

resource "databricks_azure_blob_mount" "mount" {
cluster_id = databricks_cluster.cluster.id
container_name = "dev" # Created by prereqs.tf
storage_account_name = "%[9]s"
storage_account_name = "%[2]s"
mount_name = "dev"
auth_type = "ACCESS_KEY"
token_secret_scope = databricks_secret_scope.terraform.name
token_secret_key = databricks_secret.storage_key.key
}

`, clientID, clientSecret, tenantID, subscriptionID, workspaceName, resourceGroupName, managedResourceGroupName, location, blobAccountName, blobAccountKey)
`, blobAccountKey, blobAccountName)
return definition
}
80 changes: 40 additions & 40 deletions databricks/resource_databricks_job_aws_test.go
@@ -95,49 +95,49 @@ func testAwsJobValuesNewCluster(t *testing.T, job *model.Job) resource.TestCheck
assert.NotNil(t, job.Settings.NotebookTask)
assert.Equal(t, 2, int(job.Settings.NewCluster.Autoscale.MinWorkers))
assert.Equal(t, 3, int(job.Settings.NewCluster.Autoscale.MaxWorkers))
assert.Equal(t, "6.4.x-scala2.11", job.Settings.NewCluster.SparkVersion)
assert.Equal(t, model.AwsAvailability(model.AwsAvailabilitySpot), job.Settings.NewCluster.AwsAttributes.Availability)
assert.Equal(t, "us-east-1a", job.Settings.NewCluster.AwsAttributes.ZoneID)
assert.Equal(t, 100, int(job.Settings.NewCluster.AwsAttributes.SpotBidPricePercent))
assert.Equal(t, 1, int(job.Settings.NewCluster.AwsAttributes.FirstOnDemand))
assert.Equal(t, model.EbsVolumeType(model.EbsVolumeTypeGeneralPurposeSsd), job.Settings.NewCluster.AwsAttributes.EbsVolumeType)
assert.Equal(t, 1, int(job.Settings.NewCluster.AwsAttributes.EbsVolumeCount))
assert.Equal(t, 32, int(job.Settings.NewCluster.AwsAttributes.EbsVolumeSize))
assert.Equal(t, "r3.xlarge", job.Settings.NewCluster.NodeTypeID)
assert.Equal(t, "/Users/[email protected]/my-demo-notebook", job.Settings.NotebookTask.NotebookPath)
assert.Equal(t, "my-demo-notebook", job.Settings.Name)
assert.Equal(t, 3600, int(job.Settings.TimeoutSeconds))
assert.Equal(t, 1, int(job.Settings.MaxRetries))
assert.Equal(t, 1, int(job.Settings.MaxConcurrentRuns))
assert.Equal(t, "6.4.x-scala2.11", job.Settings.NewCluster.SparkVersion)
assert.Equal(t, model.AwsAvailability(model.AwsAvailabilitySpot), job.Settings.NewCluster.AwsAttributes.Availability)
assert.Equal(t, "us-east-1a", job.Settings.NewCluster.AwsAttributes.ZoneID)
assert.Equal(t, 100, int(job.Settings.NewCluster.AwsAttributes.SpotBidPricePercent))
assert.Equal(t, 1, int(job.Settings.NewCluster.AwsAttributes.FirstOnDemand))
assert.Equal(t, model.EbsVolumeType(model.EbsVolumeTypeGeneralPurposeSsd), job.Settings.NewCluster.AwsAttributes.EbsVolumeType)
assert.Equal(t, 1, int(job.Settings.NewCluster.AwsAttributes.EbsVolumeCount))
assert.Equal(t, 32, int(job.Settings.NewCluster.AwsAttributes.EbsVolumeSize))
assert.Equal(t, "r3.xlarge", job.Settings.NewCluster.NodeTypeID)
assert.Equal(t, "/Users/[email protected]/my-demo-notebook", job.Settings.NotebookTask.NotebookPath)
assert.Equal(t, "my-demo-notebook", job.Settings.Name)
assert.Equal(t, 3600, int(job.Settings.TimeoutSeconds))
assert.Equal(t, 1, int(job.Settings.MaxRetries))
assert.Equal(t, 1, int(job.Settings.MaxConcurrentRuns))
return nil
}
}

func testAwsJobResourceNewCluster() string {
return fmt.Sprintf(`
resource "databricks_job" "my_job" {
new_cluster {
autoscale {
min_workers = 2
max_workers = 3
}
spark_version = "6.4.x-scala2.11"
aws_attributes {
availability = "SPOT"
zone_id = "us-east-1a"
spot_bid_price_percent = "100"
first_on_demand = 1
ebs_volume_type = "GENERAL_PURPOSE_SSD"
ebs_volume_count = 1
ebs_volume_size = 32
}
node_type_id = "r3.xlarge"
}
notebook_path = "/Users/[email protected]/my-demo-notebook"
name = "my-demo-notebook"
timeout_seconds = 3600
max_retries = 1
max_concurrent_runs = 1
}
`)
return `
resource "databricks_job" "my_job" {
new_cluster {
autoscale {
min_workers = 2
max_workers = 3
}
spark_version = "6.4.x-scala2.11"
aws_attributes {
availability = "SPOT"
zone_id = "us-east-1a"
spot_bid_price_percent = "100"
first_on_demand = 1
ebs_volume_type = "GENERAL_PURPOSE_SSD"
ebs_volume_count = 1
ebs_volume_size = 32
}
node_type_id = "r3.xlarge"
}
notebook_path = "/Users/[email protected]/my-demo-notebook"
name = "my-demo-notebook"
timeout_seconds = 3600
max_retries = 1
max_concurrent_runs = 1
}
`
}
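Reviewer note: the fixture contains no formatting verbs, so the fmt.Sprintf wrapper was a no-op; returning the raw string literal is equivalent and, assuming a linter such as staticcheck runs on this repo, also silences its "unnecessary use of fmt.Sprintf" warning (S1039). In miniature:

    before := fmt.Sprintf(`no verbs here`) // needless call, flagged by S1039
    after := `no verbs here`               // identical result, no call
    fmt.Println(before == after)           // true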