From 633d5cb10fba4d1cddaa462f25c45a14e3726296 Mon Sep 17 00:00:00 2001 From: Uri Baghin Date: Mon, 1 Aug 2022 11:24:40 +1000 Subject: [PATCH] python3Packages.jaxlib: fix darwin build --- .../python-modules/jaxlib/default.nix | 54 ++++++++++++++----- pkgs/top-level/python-packages.nix | 7 ++- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index 0fb0a183ebe43..54b4c36dc5a26 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -8,9 +8,11 @@ , binutils , buildBazelPackage , buildPythonPackage +, cctools , cython , fetchFromGitHub , git +, IOKit , jsoncpp , pybind11 , setuptools @@ -55,8 +57,11 @@ let homepage = "https://github.com/google/jax"; license = licenses.asl20; maintainers = with maintainers; [ ndl ]; - platforms = [ "x86_64-linux" "aarch64-darwin" "x86_64-darwin"]; - hydraPlatforms = ["x86_64-linux" ]; # Don't think anybody is checking the darwin builds + platforms = platforms.unix; + # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136 + # however even with that fix applied, it doesn't work for everyone: + # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129 + broken = stdenv.isAarch64; }; cudatoolkit_joined = symlinkJoin { @@ -117,6 +122,8 @@ let ] ++ lib.optionals cudaSupport [ cudatoolkit cudnn + ] ++ lib.optionals stdenv.isDarwin [ + IOKit ]; postPatch = '' @@ -201,9 +208,7 @@ let # Make sure Bazel knows about our configuration flags during fetching so that the # relevant dependencies can be downloaded. - bazelFetchFlags = bazel-build.bazelBuildFlags; - - bazelBuildFlags = [ + bazelFlags = [ "-c opt" ] ++ lib.optional (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [ "--config=avx_posix" @@ -211,6 +216,11 @@ let "--config=cuda" ] ++ lib.optional mklSupport [ "--config=mkl_open_source_only" + ] ++ lib.optionals stdenv.cc.isClang [ + # bazel depends on the compiler frontend automatically selecting these flags based on file + # extension but our clang doesn't. + # https://github.com/NixOS/nixpkgs/issues/150655 + "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++" ]; fetchAttrs = { @@ -218,7 +228,7 @@ let if cudaSupport then "sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo=" else - "sha256-6acSbBNcUBw177HMVOmpV7pUfP1aFSe5cP6/zWFdGFo="; + "sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo="; }; buildAttrs = { @@ -226,8 +236,8 @@ let # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. # 1) Fix pybind11 include paths. - # 2) Force static protobuf linkage to prevent crashes on loading multiple extensions - # in the same python program due to duplicate protobuf DBs. + # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on + # loading multiple extensions in the same python program due to duplicate protobuf DBs. # 3) Patch python path in the compiler driver. # 4) Patch tensorflow sources to work with later versions of protobuf. See # https://github.com/google/jax/issues/9534. Note that this should be @@ -236,13 +246,25 @@ let for src in ./jaxlib/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done - sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD - sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \ --replace "status.message()" "std::string{status.message()}" + substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \ + --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool" + substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ + --replace "/usr/bin/libtool" "${cctools}/bin/libtool" '' + lib.optionalString cudaSupport '' patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl - ''; + '' + lib.optionalString stdenv.isDarwin '' + # Framework search paths aren't added by bintools hook + # https://github.com/NixOS/nixpkgs/pull/41914 + export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks" + '' + (if stdenv.cc.isGNU then '' + sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + '' else if stdenv.cc.isClang then '' + sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + '' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); installPhase = '' ./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch} @@ -251,13 +273,21 @@ let inherit meta; }; + platformTag = + if stdenv.targetPlatform.isLinux then + "manylinux2010_${stdenv.targetPlatform.linuxArch}" + else if stdenv.system == "x86_64-darwin" then + "macosx_10_9_${stdenv.targetPlatform.linuxArch}" + else if stdenv.system == "aarch64-darwin" then + "macosx_11_0_${stdenv.targetPlatform.linuxArch}" + else throw "Unsupported target platform: ${stdenv.targetPlatform}"; in buildPythonPackage { inherit meta pname version; format = "wheel"; - src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl"; + src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl"; # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index e104a8d576869..8c7c4935d2646 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -4558,12 +4558,17 @@ in { cudaPackages = pkgs.cudaPackages_11_6; }; - jaxlib-build = callPackage ../development/python-modules/jaxlib { + jaxlib-build = callPackage ../development/python-modules/jaxlib rec { + inherit (pkgs.darwin) cctools; + buildBazelPackage = pkgs.buildBazelPackage.override { + stdenv = if stdenv.isDarwin then pkgs.darwin.apple_sdk_11_0.stdenv else stdenv; + }; # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. cudaSupport = pkgs.config.cudaSupport or false; # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we # pin to `cudaPackages_11_6` instead. cudaPackages = pkgs.cudaPackages_11_6; + IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; }; jaxlib = self.jaxlib-build;