Skip to content

Commit

Permalink
go: analyze import paths by module to enable multiple go_mod targets (
Browse files Browse the repository at this point in the history
…pantsbuild#16386)

The package mapping from import path to addresses (`ImportPathToPackages`) was global to the entire repository and not split by Go module. This prevented using multiple Go modules in a single repository. This PR solves the issue by introducing `GoImportPathMappingRequest` which allows the package mapping to be requested per module.

That per-module mapping relies on a repository-wide mapping available as `AllGoModuleImportPathsMappings`. The `GoModuleImportPathsMappingsHook` union allows plugins to provide their own import path mappings. For example, this support is now used by the protobuf/go codegen backend to supply import paths from generated protobuf code, meaning that the Go backend is able to infer dependencies between Go code and protobuf code automatically.

Fixes pantsbuild#13114.

[ci skip-rust]
  • Loading branch information
Tom Dyas committed Sep 8, 2022
1 parent 7270296 commit 857371b
Show file tree
Hide file tree
Showing 13 changed files with 787 additions and 278 deletions.
180 changes: 85 additions & 95 deletions src/python/pants/backend/codegen/protobuf/go/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,18 @@
AllProtobufTargets,
ProtobufGrpcToggleField,
ProtobufSourceField,
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.go import target_type_rules
from pants.backend.go.target_type_rules import ImportPathToPackages
from pants.backend.go.target_types import GoPackageSourcesField
from pants.backend.go.dependency_inference import (
GoImportPathsMappingAddressSet,
GoModuleImportPathsMapping,
GoModuleImportPathsMappings,
GoModuleImportPathsMappingsHook,
)
from pants.backend.go.target_type_rules import GoImportPathMappingRequest
from pants.backend.go.target_types import GoOwningGoModAddressField, GoPackageSourcesField
from pants.backend.go.util_rules import (
assembly,
build_pkg,
Expand All @@ -36,10 +44,8 @@
BuildGoPackageTargetRequest,
GoCodegenBuildRequest,
)
from pants.backend.go.util_rules.first_party_pkg import (
FallibleFirstPartyPkgAnalysis,
FirstPartyPkgAnalysisRequest,
)
from pants.backend.go.util_rules.first_party_pkg import FallibleFirstPartyPkgAnalysis
from pants.backend.go.util_rules.go_mod import OwningGoMod, OwningGoModRequest
from pants.backend.go.util_rules.pkg_analyzer import PackageAnalyzerSetup
from pants.backend.go.util_rules.sdk import GoSdkProcess
from pants.backend.python.util_rules import pex
Expand All @@ -65,13 +71,10 @@
from pants.engine.process import FallibleProcessResult, Process, ProcessResult
from pants.engine.rules import collect_rules, rule
from pants.engine.target import (
FieldSet,
GeneratedSources,
GenerateSourcesRequest,
HydratedSources,
HydrateSourcesRequest,
InferDependenciesRequest,
InferredDependencies,
SourcesPaths,
SourcesPathsRequest,
TransitiveTargets,
Expand Down Expand Up @@ -118,17 +121,15 @@ def parse_go_package_option(content_raw: bytes) -> str | None:
return None


@dataclass(frozen=True)
class GoProtobufImportPathMapping:
"""Maps import paths of Go Protobuf packages to the addresses."""

mapping: FrozenDict[str, tuple[Address, ...]]
class ProtobufGoModuleImportPathsMappingsHook(GoModuleImportPathsMappingsHook):
pass


@rule(desc="Map import paths for all Go Protobuf targets.", level=LogLevel.DEBUG)
async def map_import_paths_of_all_go_protobuf_targets(
targets: AllProtobufTargets,
) -> GoProtobufImportPathMapping:
_request: ProtobufGoModuleImportPathsMappingsHook,
all_protobuf_targets: AllProtobufTargets,
) -> GoModuleImportPathsMappings:
sources = await MultiGet(
Get(
HydratedSources,
Expand All @@ -138,28 +139,57 @@ async def map_import_paths_of_all_go_protobuf_targets(
enable_codegen=True,
),
)
for tgt in targets
for tgt in all_protobuf_targets
)

all_contents = await MultiGet(
Get(DigestContents, Digest, source.snapshot.digest) for source in sources
)

go_protobuf_targets: dict[str, set[Address]] = defaultdict(set)
for tgt, contents in zip(targets, all_contents):
go_protobuf_mapping_metadata = []
owning_go_mod_gets = []
for tgt, contents in zip(all_protobuf_targets, all_contents):
if not contents:
continue
if len(contents) > 1:
raise AssertionError(
f"Protobuf target `{tgt.address}` mapped to more than one source file."
)

import_path = parse_go_package_option(contents[0].content)
if not import_path:
continue
go_protobuf_targets[import_path].add(tgt.address)

return GoProtobufImportPathMapping(
FrozenDict({ip: tuple(addrs) for ip, addrs in go_protobuf_targets.items()})
owning_go_mod_gets.append(Get(OwningGoMod, OwningGoModRequest(tgt.address)))
go_protobuf_mapping_metadata.append((import_path, tgt.address))

owning_go_mod_targets = await MultiGet(owning_go_mod_gets)

import_paths_by_module: dict[Address, dict[str, set[Address]]] = defaultdict(
lambda: defaultdict(set)
)

for owning_go_mod, (import_path, address) in zip(
owning_go_mod_targets, go_protobuf_mapping_metadata
):
import_paths_by_module[owning_go_mod.address][import_path].add(address)

return GoModuleImportPathsMappings(
FrozenDict(
{
go_mod_addr: GoModuleImportPathsMapping(
mapping=FrozenDict(
{
import_path: GoImportPathsMappingAddressSet(
addresses=tuple(sorted(addresses)), infer_all=True
)
for import_path, addresses in import_path_mapping.items()
}
),
)
for go_mod_addr, import_path_mapping in import_paths_by_module.items()
}
)
)


Expand All @@ -182,8 +212,6 @@ async def setup_full_package_build_request(
request: _SetupGoProtobufPackageBuildRequest,
protoc: Protoc,
go_protoc_plugin: _SetupGoProtocPlugin,
package_mapping: ImportPathToPackages,
go_protobuf_mapping: GoProtobufImportPathMapping,
analyzer: PackageAnalyzerSetup,
) -> FallibleBuildGoPackageRequest:
output_dir = "_generated_files"
Expand All @@ -196,6 +224,11 @@ async def setup_full_package_build_request(
Get(Digest, CreateDigest([Directory(output_dir)])),
)

go_mod_addr = await Get(OwningGoMod, OwningGoModRequest(transitive_targets.roots[0].address))
package_mapping = await Get(
GoModuleImportPathsMapping, GoImportPathMappingRequest(go_mod_addr.address)
)

all_sources = await Get(
SourceFiles,
SourceFilesRequest(
Expand Down Expand Up @@ -317,25 +350,23 @@ async def setup_full_package_build_request(
candidate_addresses = package_mapping.mapping.get(dep_import_path)
if candidate_addresses:
# TODO: Use explicit dependencies to disambiguate? This should never happen with Go backend though.
if len(candidate_addresses) > 1:
return FallibleBuildGoPackageRequest(
request=None,
import_path=request.import_path,
exit_code=result.exit_code,
stderr=textwrap.dedent(
f"""
Multiple addresses match import of `{dep_import_path}`.
addresses: {', '.join(str(a) for a in candidate_addresses)}
"""
).strip(),
)
dep_build_request_addrs.extend(candidate_addresses)

# Infer dependencies on other generated Go sources.
go_protobuf_candidate_addresses = go_protobuf_mapping.mapping.get(dep_import_path)
if go_protobuf_candidate_addresses:
dep_build_request_addrs.extend(go_protobuf_candidate_addresses)
if candidate_addresses.infer_all:
dep_build_request_addrs.extend(candidate_addresses.addresses)
else:
if len(candidate_addresses.addresses) > 1:
return FallibleBuildGoPackageRequest(
request=None,
import_path=request.import_path,
exit_code=result.exit_code,
stderr=textwrap.dedent(
f"""
Multiple addresses match import of `{dep_import_path}`.
addresses: {', '.join(str(a) for a in candidate_addresses.addresses)}
"""
).strip(),
)
dep_build_request_addrs.extend(candidate_addresses.addresses)

dep_build_requests = await MultiGet(
Get(BuildGoPackageRequest, BuildGoPackageTargetRequest(addr))
Expand All @@ -359,7 +390,6 @@ async def setup_full_package_build_request(
@rule
async def setup_build_go_package_request_for_protobuf(
request: GoCodegenBuildProtobufRequest,
protobuf_package_mapping: GoProtobufImportPathMapping,
) -> FallibleBuildGoPackageRequest:
# Hydrate the protobuf source to parse for the Go import path.
sources = await Get(HydratedSources, HydrateSourcesRequest(request.target[ProtobufSourceField]))
Expand All @@ -374,10 +404,15 @@ async def setup_build_go_package_request_for_protobuf(
stderr=f"No import path was set in Protobuf file via `option go_package` directive for {request.target.address}.",
)

go_mod_addr = await Get(OwningGoMod, OwningGoModRequest(request.target.address))
package_mapping = await Get(
GoModuleImportPathsMapping, GoImportPathMappingRequest(go_mod_addr.address)
)

# Request the full build of the package. This indirection is necessary so that requests for two or more
# Protobuf files in the same Go package result in a single cacheable rule invocation.
protobuf_target_addrs_for_import_path = protobuf_package_mapping.mapping.get(import_path)
if not protobuf_target_addrs_for_import_path:
protobuf_target_addrs_set_for_import_path = package_mapping.mapping.get(import_path)
if not protobuf_target_addrs_set_for_import_path:
return FallibleBuildGoPackageRequest(
request=None,
import_path=import_path,
Expand All @@ -393,7 +428,7 @@ async def setup_build_go_package_request_for_protobuf(
return await Get(
FallibleBuildGoPackageRequest,
_SetupGoProtobufPackageBuildRequest(
addresses=protobuf_target_addrs_for_import_path,
addresses=protobuf_target_addrs_set_for_import_path.addresses,
import_path=import_path,
),
)
Expand Down Expand Up @@ -589,59 +624,14 @@ async def setup_go_protoc_plugin(platform: Platform) -> _SetupGoProtocPlugin:
return _SetupGoProtocPlugin(plugin_digest)


@dataclass(frozen=True)
class GoProtobufDependenciesInferenceFieldSet(FieldSet):
    """Field set selecting targets with `GoPackageSourcesField` for Protobuf dependency inference.

    Used as the `infer_from` field set of `InferGoProtobufDependenciesRequest`.
    """

    # Fields a target must have in order to match this field set.
    required_fields = (GoPackageSourcesField,)

    # The Go package sources field of the matched target.
    sources: GoPackageSourcesField


class InferGoProtobufDependenciesRequest(InferDependenciesRequest):
    """Dependency inference request handled by the `infer_go_dependencies` rule."""

    # The field set describing which targets this inference request applies to.
    infer_from = GoProtobufDependenciesInferenceFieldSet


@rule(
    desc="Infer dependencies on Protobuf sources for first-party Go packages", level=LogLevel.DEBUG
)
async def infer_go_dependencies(
    request: InferGoProtobufDependenciesRequest,
    go_protobuf_mapping: GoProtobufImportPathMapping,
) -> InferredDependencies:
    """Infer dependencies from a first-party Go package onto Go Protobuf targets.

    The package at the request's address is analyzed, and each of its import paths (`imports`,
    `test_imports` and `xtest_imports`) is looked up in `go_protobuf_mapping`; every address
    mapped to a matching import path becomes an inferred dependency.

    If the package analysis is unavailable, the failure is logged as an error and an empty
    `InferredDependencies` is returned rather than raising.
    """
    maybe_pkg_analysis = await Get(
        FallibleFirstPartyPkgAnalysis, FirstPartyPkgAnalysisRequest(request.field_set.address)
    )
    analysis = maybe_pkg_analysis.analysis
    if analysis is None:
        _logger.error(
            softwrap(
                f"""
                Failed to analyze {maybe_pkg_analysis.import_path} for dependency inference:
                {maybe_pkg_analysis.stderr}
                """
            )
        )
        return InferredDependencies([])

    # Preserve the original lookup order: regular imports, then test, then xtest imports.
    all_import_paths = (*analysis.imports, *analysis.test_imports, *analysis.xtest_imports)
    import_path_to_addresses = go_protobuf_mapping.mapping
    return InferredDependencies(
        [
            address
            for import_path in all_import_paths
            for address in import_path_to_addresses.get(import_path, ())
        ]
    )


def rules():
return (
*collect_rules(),
UnionRule(GenerateSourcesRequest, GenerateGoFromProtobufRequest),
UnionRule(GoCodegenBuildRequest, GoCodegenBuildProtobufRequest),
UnionRule(InferDependenciesRequest, InferGoProtobufDependenciesRequest),
UnionRule(GoModuleImportPathsMappingsHook, ProtobufGoModuleImportPathsMappingsHook),
ProtobufSourcesGeneratorTarget.register_plugin_field(GoOwningGoModAddressField),
ProtobufSourceTarget.register_plugin_field(GoOwningGoModAddressField),
# Rules needed for this to pass src/python/pants/init/load_backends_integration_test.py:
*assembly.rules(),
*build_pkg.rules(),
Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/backend/experimental/go/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from pants.backend.go.util_rules import (
assembly,
binary,
build_pkg,
build_pkg_target,
coverage,
Expand All @@ -38,6 +39,7 @@ def target_types():
def rules():
return [
*assembly.rules(),
*binary.rules(),
*build_pkg.rules(),
*build_pkg_target.rules(),
*check.rules(),
Expand Down
49 changes: 49 additions & 0 deletions src/python/pants/backend/go/dependency_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from __future__ import annotations

from dataclasses import dataclass

from pants.build_graph.address import Address
from pants.engine.environment import EnvironmentName
from pants.engine.unions import union
from pants.util.frozendict import FrozenDict


@union(in_scope_types=[EnvironmentName])
@dataclass(frozen=True)
class GoModuleImportPathsMappingsHook:
    """An entry point for a specific implementation of mapping Go import paths to owning targets.

    All implementations will be merged together. The core Go dependency inference rules will
    request the `GoModuleImportPathsMappings` type using implementations of this union.

    An implementation subclasses this type, registers it with
    `UnionRule(GoModuleImportPathsMappingsHook, <subclass>)`, and provides a rule that takes the
    subclass and returns `GoModuleImportPathsMappings`.
    """


@dataclass(frozen=True)
class GoImportPathsMappingAddressSet:
    """The addresses of the targets that provide a single Go import path."""

    # Addresses of the targets providing the import path.
    addresses: tuple[Address, ...]
    # Whether a consumer should depend on every address in `addresses`. The protobuf Go codegen
    # rules add all addresses as dependencies when this is True, and otherwise report an error
    # if more than one address matches the import path.
    infer_all: bool


@dataclass(frozen=True)
class GoModuleImportPathsMapping:
    """Maps import paths (as strings) to one or more addresses of targets providing those import
    path(s) for a single Go module."""

    # Import path -> the addresses (plus the `infer_all` flag) of the targets providing it.
    mapping: FrozenDict[str, GoImportPathsMappingAddressSet]


@dataclass(frozen=True)
class GoModuleImportPathsMappings:
    """Import path mappings for all Go modules in the repository.

    This type is requested from plugins which provide implementations for the
    `GoModuleImportPathsMappingsHook` union and then merged.
    """

    # Address of the owning Go module target -> the import path mapping for that module.
    modules: FrozenDict[Address, GoModuleImportPathsMapping]


class AllGoModuleImportPathsMappings(GoModuleImportPathsMappings):
    """The repository-wide import path mappings, which per-module mapping requests rely on.

    NOTE(review): presumably the merged result of every `GoModuleImportPathsMappingsHook`
    implementation -- confirm against the rule that produces this type.
    """

    pass
Loading

0 comments on commit 857371b

Please sign in to comment.