Skip to content

Commit

Permalink
Merge pull request #184395 from uri-canva/jaxlib-darwin
Browse files Browse the repository at this point in the history
python3Packages.jaxlib: fix darwin build
  • Loading branch information
samuela authored Aug 18, 2022
2 parents 802ea45 + 633d5cb commit 5b71d23
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
54 changes: 42 additions & 12 deletions pkgs/development/python-modules/jaxlib/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
, binutils
, buildBazelPackage
, buildPythonPackage
, cctools
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, pybind11
, setuptools
Expand Down Expand Up @@ -55,8 +57,11 @@ let
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ ndl ];
platforms = [ "x86_64-linux" "aarch64-darwin" "x86_64-darwin"];
hydraPlatforms = ["x86_64-linux" ]; # Don't think anybody is checking the darwin builds
platforms = platforms.unix;
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
# however even with that fix applied, it doesn't work for everyone:
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
broken = stdenv.isAarch64;
};

cudatoolkit_joined = symlinkJoin {
Expand Down Expand Up @@ -117,6 +122,8 @@ let
] ++ lib.optionals cudaSupport [
cudatoolkit
cudnn
] ++ lib.optionals stdenv.isDarwin [
IOKit
];

postPatch = ''
Expand Down Expand Up @@ -201,33 +208,36 @@ let

# Make sure Bazel knows about our configuration flags during fetching so that the
# relevant dependencies can be downloaded.
bazelFetchFlags = bazel-build.bazelBuildFlags;

bazelBuildFlags = [
bazelFlags = [
"-c opt"
] ++ lib.optional (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
"--config=avx_posix"
] ++ lib.optional cudaSupport [
"--config=cuda"
] ++ lib.optional mklSupport [
"--config=mkl_open_source_only"
] ++ lib.optionals stdenv.cc.isClang [
# bazel depends on the compiler frontend automatically selecting these flags based on file
# extension but our clang doesn't.
# https://github.com/NixOS/nixpkgs/issues/150655
"--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
];

fetchAttrs = {
sha256 =
if cudaSupport then
"sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo="
else
"sha256-6acSbBNcUBw177HMVOmpV7pUfP1aFSe5cP6/zWFdGFo=";
"sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo=";
};

buildAttrs = {
outputs = [ "out" ];

# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
# 1) Fix pybind11 include paths.
# 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
# in the same python program due to duplicate protobuf DBs.
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
# 4) Patch tensorflow sources to work with later versions of protobuf. See
# https://github.com/google/jax/issues/9534. Note that this should be
Expand All @@ -236,13 +246,25 @@ let
for src in ./jaxlib/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
--replace "status.message()" "std::string{status.message()}"
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
--replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
'' + lib.optionalString cudaSupport ''
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'';
'' + lib.optionalString stdenv.isDarwin ''
# Framework search paths aren't added by bintools hook
# https://github.com/NixOS/nixpkgs/pull/41914
export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
'' + (if stdenv.cc.isGNU then ''
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
'' else if stdenv.cc.isClang then ''
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

installPhase = ''
./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
Expand All @@ -251,13 +273,21 @@ let

inherit meta;
};
platformTag =
if stdenv.targetPlatform.isLinux then
"manylinux2010_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "x86_64-darwin" then
"macosx_10_9_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "aarch64-darwin" then
"macosx_11_0_${stdenv.targetPlatform.linuxArch}"
else throw "Unsupported target platform: ${stdenv.targetPlatform}";

in
buildPythonPackage {
inherit meta pname version;
format = "wheel";

src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl";

# Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
Expand Down
7 changes: 6 additions & 1 deletion pkgs/top-level/python-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -4594,12 +4594,17 @@ in {
cudaPackages = pkgs.cudaPackages_11_6;
};

jaxlib-build = callPackage ../development/python-modules/jaxlib {
jaxlib-build = callPackage ../development/python-modules/jaxlib rec {
inherit (pkgs.darwin) cctools;
buildBazelPackage = pkgs.buildBazelPackage.override {
stdenv = if stdenv.isDarwin then pkgs.darwin.apple_sdk_11_0.stdenv else stdenv;
};
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
cudaSupport = pkgs.config.cudaSupport or false;
# At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
# pin to `cudaPackages_11_6` instead.
cudaPackages = pkgs.cudaPackages_11_6;
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
};

jaxlib = self.jaxlib-build;
Expand Down

0 comments on commit 5b71d23

Please sign in to comment.