From 8a9fed3044a55b3d639f12d3f0db2f3d89bb1294 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 30 Aug 2024 13:58:30 -0400 Subject: [PATCH] feat: start adding interface to override (#834) --- crates/rattler-bin/src/commands/create.rs | 18 +- .../src/commands/virtual_packages.rs | 2 +- crates/rattler_virtual_packages/src/lib.rs | 238 +++++++++++++----- crates/rattler_virtual_packages/src/libc.rs | 2 +- py-rattler/examples/solve_and_install.py | 2 +- py-rattler/rattler/__init__.py | 4 +- .../rattler/virtual_package/__init__.py | 4 +- .../virtual_package/virtual_package.py | 173 ++++++++++++- py-rattler/src/lib.rs | 4 +- py-rattler/src/virtual_package.rs | 135 +++++++++- py-rattler/tests/unit/test_override.py | 30 +++ 11 files changed, 535 insertions(+), 77 deletions(-) create mode 100644 py-rattler/tests/unit/test_override.py diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs index 5b3dcf305..7c840a77a 100644 --- a/crates/rattler-bin/src/commands/create.rs +++ b/crates/rattler-bin/src/commands/create.rs @@ -193,14 +193,16 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { }) .collect::>>()?) } else { - rattler_virtual_packages::VirtualPackage::current() - .map(|vpkgs| { - vpkgs - .iter() - .map(|vpkg| GenericVirtualPackage::from(vpkg.clone())) - .collect::>() - }) - .map_err(anyhow::Error::from) + rattler_virtual_packages::VirtualPackage::detect( + &rattler_virtual_packages::VirtualPackageOverrides::default(), + ) + .map(|vpkgs| { + vpkgs + .iter() + .map(|vpkg| GenericVirtualPackage::from(vpkg.clone())) + .collect::>() + }) + .map_err(anyhow::Error::from) } })?; diff --git a/crates/rattler-bin/src/commands/virtual_packages.rs b/crates/rattler-bin/src/commands/virtual_packages.rs index 719e62970..5e18b3445 100644 --- a/crates/rattler-bin/src/commands/virtual_packages.rs +++ b/crates/rattler-bin/src/commands/virtual_packages.rs @@ -4,7 +4,7 @@ use rattler_conda_types::GenericVirtualPackage; pub struct Opt {} pub fn virtual_packages(_opt: Opt) -> anyhow::Result<()> { - let virtual_packages = rattler_virtual_packages::VirtualPackage::current()?; + let virtual_packages = rattler_virtual_packages::VirtualPackage::detect(&Default::default())?; for package in virtual_packages { println!("{}", GenericVirtualPackage::from(package.clone())); } diff --git a/crates/rattler_virtual_packages/src/lib.rs b/crates/rattler_virtual_packages/src/lib.rs index c9d5219bc..52349a680 100644 --- a/crates/rattler_virtual_packages/src/lib.rs +++ b/crates/rattler_virtual_packages/src/lib.rs @@ -34,7 +34,6 @@ pub mod linux; pub mod osx; use archspec::cpu::Microarchitecture; -use once_cell::sync::OnceCell; use rattler_conda_types::{ GenericVirtualPackage, PackageName, ParseVersionError, Platform, Version, }; @@ -48,35 +47,82 @@ use libc::DetectLibCError; use linux::ParseLinuxVersionError; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +/// Configure the overrides used in in this crate. +#[derive(Clone, Debug, PartialEq, Default)] +pub enum Override { + /// Use the default override env var name + #[default] + DefaultEnvVar, + /// Use custom env var name + EnvVar(String), + /// Use a custom override directly + String(String), + /// Disable overrides + None, +} + /// Traits for overridable virtual packages -/// Use as `Cuda::from_default_env_var.unwrap_or(Cuda::current().into()).unwrap()` +/// Use as `Cuda::detect(override)` pub trait EnvOverride: Sized { /// Parse `env_var_value` - fn from_env_var_name_with_var( - env_var_name: &str, - env_var_value: &str, - ) -> Result; + fn parse_version(value: &str) -> Result; - /// Read the environment variable and if it exists, try to parse it with [`EnvOverride::from_env_var_name_with_var`] + /// Helper to convert the output of `parse_version` and handling empty strings. + fn parse_version_opt(value: &str) -> Result, DetectVirtualPackageError> { + if value.is_empty() { + Ok(None) + } else { + Ok(Some(Self::parse_version(value)?)) + } + } + + /// Read the environment variable and if it exists, try to parse it with [`EnvOverride::parse_version`] /// If the output is: /// - `None`, then the environment variable did not exist, /// - `Some(Err(None))`, then the environment variable exist but was set to zero, so the package should be disabled /// - `Some(Ok(pkg))`, then the override was for the package. - fn from_env_var_name(env_var_name: &str) -> Option>> { - let var = env::var(env_var_name).ok()?; - if var.is_empty() { - Some(Err(None)) - } else { - Some(Self::from_env_var_name_with_var(env_var_name, &var).map_err(Some)) + fn from_env_var_name_or( + env_var_name: &str, + fallback: F, + ) -> Result, DetectVirtualPackageError> + where + F: FnOnce() -> Result, DetectVirtualPackageError>, + { + match env::var(env_var_name) { + Ok(var) => Self::parse_version_opt(&var), + Err(env::VarError::NotPresent) => fallback(), + Err(e) => Err(DetectVirtualPackageError::VarError(e)), } } /// Default name of the environment variable that overrides the virtual package. const DEFAULT_ENV_NAME: &'static str; - /// Shortcut for `EnvOverride::from_env_var_name(EnvOverride::DEFAULT_ENV_NAME)`. - fn from_default_env_var() -> Option>> { - Self::from_env_var_name(Self::DEFAULT_ENV_NAME) + /// Detect the virutal package for the current system. + /// This method is here so that `::current` always returns the same error type. + /// `current` may return different types of errors depending on the virtual package. This one always returns + /// `DetectVirtualPackageError`. + fn detect_from_host() -> Result, DetectVirtualPackageError>; + + /// Apply the override to the current virtual package. If the override is `None` then use the fallback + fn detect_with_fallback( + ov: &Override, + fallback: F, + ) -> Result, DetectVirtualPackageError> + where + F: FnOnce() -> Result, DetectVirtualPackageError>, + { + match ov { + Override::None => fallback(), + Override::String(str) => Self::parse_version_opt(str), + Override::DefaultEnvVar => Self::from_env_var_name_or(Self::DEFAULT_ENV_NAME, fallback), + Override::EnvVar(name) => Self::from_env_var_name_or(name, fallback), + } + } + + /// Shortcut for `Self::detect_with_fallback` with `Self::detect_from_host` as fallback + fn detect(ov: &Override) -> Result, DetectVirtualPackageError> { + Self::detect_with_fallback(ov, Self::detect_from_host) } } @@ -86,19 +132,19 @@ pub enum VirtualPackage { /// Available on windows Win, - /// Available on unix based platforms + /// Available on `Unix` based platforms Unix, - /// Available when running on Linux + /// Available when running on `Linux`` Linux(Linux), - /// Available when running on OSX + /// Available when running on `OSX` Osx(Osx), - /// Available LibC family and version + /// Available `LibC` family and version LibC(LibC), - /// Available Cuda version + /// Available `Cuda` version Cuda(Cuda), /// The CPU architecture @@ -130,11 +176,19 @@ impl From for GenericVirtualPackage { impl VirtualPackage { /// Returns virtual packages detected for the current system or an error if the versions could /// not be properly detected. - pub fn current() -> Result<&'static [Self], DetectVirtualPackageError> { - static DETECTED_VIRTUAL_PACKAGES: OnceCell> = OnceCell::new(); - DETECTED_VIRTUAL_PACKAGES - .get_or_try_init(try_detect_virtual_packages) - .map(Vec::as_slice) + #[deprecated( + since = "1.0.4", + note = "Use `Self::detect(&VirtualPackageOverrides::none())` instead." + )] + pub fn current() -> Result, DetectVirtualPackageError> { + try_detect_virtual_packages_with_overrides(&VirtualPackageOverrides::none()) + } + + /// Detect the virtual packages of the current system with the given overrides. + pub fn detect( + overrides: &VirtualPackageOverrides, + ) -> Result, DetectVirtualPackageError> { + try_detect_virtual_packages_with_overrides(overrides) } } @@ -150,10 +204,39 @@ pub enum DetectVirtualPackageError { #[error(transparent)] DetectLibC(#[from] DetectLibCError), + + #[error(transparent)] + VarError(#[from] env::VarError), + + #[error(transparent)] + VersionParseError(#[from] ParseVersionError), +} +/// Configure the overrides used in this crate. +#[derive(Default, Clone, Debug)] +pub struct VirtualPackageOverrides { + /// The override for the osx virtual package + pub osx: Override, + /// The override for the libc virtual package + pub libc: Override, + /// The override for the cuda virtual package + pub cuda: Override, +} + +impl VirtualPackageOverrides { + /// Disable all overrides + pub fn none() -> Self { + Self { + osx: Override::None, + libc: Override::None, + cuda: Override::None, + } + } } // Detect the available virtual packages on the system -fn try_detect_virtual_packages() -> Result, DetectVirtualPackageError> { +fn try_detect_virtual_packages_with_overrides( + overrides: &VirtualPackageOverrides, +) -> Result, DetectVirtualPackageError> { let mut result = Vec::new(); let platform = Platform::current(); @@ -169,18 +252,18 @@ fn try_detect_virtual_packages() -> Result, DetectVirtualPac if let Some(linux_version) = Linux::current()? { result.push(linux_version.into()); } - if let Some(libc) = LibC::current()? { + if let Some(libc) = LibC::detect(&overrides.libc)? { result.push(libc.into()); } } if platform.is_osx() { - if let Some(osx) = Osx::current()? { + if let Some(osx) = Osx::detect(&overrides.osx)? { result.push(osx.into()); } } - if let Some(cuda) = Cuda::current() { + if let Some(cuda) = Cuda::detect(&overrides.cuda)? { result.push(cuda.into()); } @@ -233,7 +316,7 @@ impl From for Linux { /// `LibC` virtual package description #[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)] pub struct LibC { - /// The family of LibC. This could be glibc for instance. + /// The family of `LibC`. This could be glibc for instance. pub family: String, /// The version of the libc distribution. @@ -274,15 +357,16 @@ impl From for VirtualPackage { impl EnvOverride for LibC { const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_GLIBC"; - fn from_env_var_name_with_var( - _env_var_name: &str, - env_var_value: &str, - ) -> Result { + fn parse_version(env_var_value: &str) -> Result { Version::from_str(env_var_value).map(|version| Self { family: "glibc".into(), version, }) } + + fn detect_from_host() -> Result, DetectVirtualPackageError> { + Ok(Self::current()?) + } } /// Cuda virtual package description @@ -306,13 +390,12 @@ impl From for Cuda { } impl EnvOverride for Cuda { - fn from_env_var_name_with_var( - _env_var_name: &str, - env_var_value: &str, - ) -> Result { + fn parse_version(env_var_value: &str) -> Result { Version::from_str(env_var_value).map(|version| Self { version }) } - + fn detect_from_host() -> Result, DetectVirtualPackageError> { + Ok(Self::current()) + } const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_CUDA"; } @@ -483,13 +566,12 @@ impl From for Osx { } impl EnvOverride for Osx { - fn from_env_var_name_with_var( - _env_var_name: &str, - env_var_value: &str, - ) -> Result { + fn parse_version(env_var_value: &str) -> Result { Version::from_str(env_var_value).map(|version| Self { version }) } - + fn detect_from_host() -> Result, DetectVirtualPackageError> { + Ok(Self::current()?) + } const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_OSX"; } @@ -504,11 +586,12 @@ mod test { use crate::EnvOverride; use crate::LibC; use crate::Osx; + use crate::Override; use crate::VirtualPackage; #[test] fn doesnt_crash() { - let virtual_packages = VirtualPackage::current().unwrap(); + let virtual_packages = VirtualPackage::detect(&Default::default()).unwrap(); println!("{virtual_packages:?}"); } #[test] @@ -518,12 +601,32 @@ mod test { version: Version::from_str(v).unwrap(), family: "glibc".into(), }; - env::set_var(LibC::DEFAULT_ENV_NAME, v); - assert_eq!(LibC::from_default_env_var(), Some(Ok(res))); - env::set_var(LibC::DEFAULT_ENV_NAME, ""); - assert_eq!(LibC::from_default_env_var(), Some(Err(None))); - env::remove_var(LibC::DEFAULT_ENV_NAME); - assert_eq!(LibC::from_default_env_var(), None); + let env_var_name = format!("{}_{}", LibC::DEFAULT_ENV_NAME, "12345511231"); + env::set_var(env_var_name.clone(), v); + assert_eq!( + LibC::detect(&Override::EnvVar(env_var_name.clone())) + .unwrap() + .unwrap(), + res + ); + env::set_var(env_var_name.clone(), ""); + assert_eq!( + LibC::detect(&Override::EnvVar(env_var_name.clone())).unwrap(), + None + ); + env::remove_var(env_var_name.clone()); + assert_eq!( + LibC::detect_with_fallback(&Override::DefaultEnvVar, || Ok(Some(res.clone()))) + .unwrap() + .unwrap(), + res + ); + assert_eq!( + LibC::detect_with_fallback(&Override::String(v.to_string()), || Ok(None)) + .unwrap() + .unwrap(), + res + ); } #[test] @@ -532,8 +635,25 @@ mod test { let res = Cuda { version: Version::from_str(v).unwrap(), }; - env::set_var(Cuda::DEFAULT_ENV_NAME, v); - assert_eq!(Cuda::from_default_env_var(), Some(Ok(res))); + let env_var_name = format!("{}_{}", Cuda::DEFAULT_ENV_NAME, "12345511231"); + env::set_var(env_var_name.clone(), v); + assert_eq!( + Cuda::detect(&Override::EnvVar(env_var_name.clone())) + .unwrap() + .unwrap(), + res + ); + assert_eq!( + Cuda::detect(&Override::None).map_err(|_x| 1), + ::detect_from_host().map_err(|_x| 1) + ); + env::remove_var(env_var_name.clone()); + assert_eq!( + Cuda::detect(&Override::String(v.to_string())) + .unwrap() + .unwrap(), + res + ); } #[test] @@ -542,7 +662,13 @@ mod test { let res = Osx { version: Version::from_str(v).unwrap(), }; - env::set_var(Osx::DEFAULT_ENV_NAME, v); - assert_eq!(Osx::from_default_env_var(), Some(Ok(res))); + let env_var_name = format!("{}_{}", Osx::DEFAULT_ENV_NAME, "12345511231"); + env::set_var(env_var_name.clone(), v); + assert_eq!( + Osx::detect(&Override::EnvVar(env_var_name.clone())) + .unwrap() + .unwrap(), + res + ); } } diff --git a/crates/rattler_virtual_packages/src/libc.rs b/crates/rattler_virtual_packages/src/libc.rs index e3248df34..30b8b4fb2 100644 --- a/crates/rattler_virtual_packages/src/libc.rs +++ b/crates/rattler_virtual_packages/src/libc.rs @@ -37,7 +37,7 @@ pub enum DetectLibCError { #[cfg(unix)] fn try_detect_libc_version() -> Result, DetectLibCError> { // Run `ldd --version` to detect the libc version and family on the system. - // `ldd` is shipped with libc so if an error occured during its execution we + // `ldd` is shipped with libc so if an error occurred during its execution we // can assume no libc is available on the system. let output = match std::process::Command::new("ldd").arg("--version").output() { Err(e) => { diff --git a/py-rattler/examples/solve_and_install.py b/py-rattler/examples/solve_and_install.py index 5cd7640c1..3f84c6f18 100644 --- a/py-rattler/examples/solve_and_install.py +++ b/py-rattler/examples/solve_and_install.py @@ -16,7 +16,7 @@ async def main() -> None: # The specs to solve for specs=["python ~=3.12.0", "pip", "requests 2.31.0"], # Virtual packages define the specifications of the environment - virtual_packages=VirtualPackage.current(), + virtual_packages=VirtualPackage.detect(), ) print("solved required dependencies") diff --git a/py-rattler/rattler/__init__.py b/py-rattler/rattler/__init__.py index 131aeadf0..11fae0431 100644 --- a/py-rattler/rattler/__init__.py +++ b/py-rattler/rattler/__init__.py @@ -11,7 +11,7 @@ ) from rattler.channel import Channel, ChannelConfig, ChannelPriority from rattler.networking import AuthenticatedClient, fetch_repo_data -from rattler.virtual_package import GenericVirtualPackage, VirtualPackage +from rattler.virtual_package import GenericVirtualPackage, VirtualPackage, VirtualPackageOverrides, Override from rattler.package import ( PackageName, AboutJson, @@ -58,6 +58,8 @@ "fetch_repo_data", "GenericVirtualPackage", "VirtualPackage", + "VirtualPackageOverrides", + "Override", "PackageName", "PrefixRecord", "PrefixPaths", diff --git a/py-rattler/rattler/virtual_package/__init__.py b/py-rattler/rattler/virtual_package/__init__.py index 2e96eca51..2ce68f629 100644 --- a/py-rattler/rattler/virtual_package/__init__.py +++ b/py-rattler/rattler/virtual_package/__init__.py @@ -1,4 +1,4 @@ from rattler.virtual_package.generic import GenericVirtualPackage -from rattler.virtual_package.virtual_package import VirtualPackage +from rattler.virtual_package.virtual_package import VirtualPackage, VirtualPackageOverrides, Override -__all__ = ["GenericVirtualPackage", "VirtualPackage"] +__all__ = ["GenericVirtualPackage", "VirtualPackage", "VirtualPackageOverrides", "Override"] diff --git a/py-rattler/rattler/virtual_package/virtual_package.py b/py-rattler/rattler/virtual_package/virtual_package.py index 5d8544438..a37da3452 100644 --- a/py-rattler/rattler/virtual_package/virtual_package.py +++ b/py-rattler/rattler/virtual_package/virtual_package.py @@ -1,11 +1,169 @@ from __future__ import annotations - -from rattler.rattler import PyVirtualPackage from typing import List +import warnings + +from rattler.rattler import PyVirtualPackage, PyOverride, PyVirtualPackageOverrides from rattler.virtual_package.generic import GenericVirtualPackage +class Override: + """ + Represents an override for a virtual package. + An override can be build using + - `Override.default_env_var()` for overriding the detection with the default environment variable, + - `Override.env_var(str)` for overriding the detection with a custom environment variable, + - `Override.string(str)` for passing the version directly, or + - `Override.none()` for disabling the override process all together. + """ + + _override: PyOverride + + @classmethod + def _from_py_override(cls, py_override: PyOverride) -> Override: + """Construct Rattler Override from FFI PyOverride object.""" + override = cls.__new__(cls) + override._override = py_override + return override + + @classmethod + def default_env_var(cls) -> Override: + """ + Returns a new instance to indicate that the default environment variable should overwrite the detected information from the host if specified. + """ + return cls._from_py_override(PyOverride.default_env_var()) + + @classmethod + def env_var(cls, env_var: str) -> Override: + """ + Returns the environment variable override for the given environment variable. + """ + return cls._from_py_override(PyOverride.env_var(env_var)) + + @classmethod + def string(cls, override: str) -> Override: + """ + Returns the override for the given string. + """ + return cls._from_py_override(PyOverride.string(override)) + + @classmethod + def none(cls) -> Override: + """ + Returns the override for None. + """ + return cls._from_py_override(PyOverride.none()) + + def __str__(self) -> str: + """ + Returns string representation of the Override. + """ + return self._override.as_str() + + def __repr__(self) -> str: + """ + Returns a representation of the Override. + """ + return f"Override({self._override.as_str()})" + + def __eq__(self, other: object) -> bool: + """ + Returns True if the Overrides are equal, False otherwise. + """ + if not isinstance(other, Override): + return NotImplemented + return self._override == other._override + + +class VirtualPackageOverrides: + _overrides: PyVirtualPackageOverrides + + @classmethod + def _from_py_virtual_package_overrides( + cls, py_virtual_package_overrides: PyVirtualPackageOverrides + ) -> VirtualPackageOverrides: + """Construct Rattler VirtualPackageOverrides from FFI PyVirtualPackageOverrides object.""" + virtual_package_overrides = cls.__new__(cls) + virtual_package_overrides._overrides = py_virtual_package_overrides + return virtual_package_overrides + + def __init__(self, osx: Override | None = None, libc: Override | None = None, cuda: Override | None = None) -> None: + """ + Returns the default virtual package overrides. + """ + self._overrides = PyVirtualPackageOverrides.default() + if osx is not None: + self.osx = osx + if libc is not None: + self.libc = libc + if cuda is not None: + self.cuda = cuda + + @classmethod + def none(cls) -> VirtualPackageOverrides: + """ + Returns the virtual package overrides for None. + """ + return cls._from_py_virtual_package_overrides(PyVirtualPackageOverrides.none()) + + @property + def osx(self) -> Override: + """ + Returns the OSX override. + """ + return Override._from_py_override(self._overrides.osx) + + @osx.setter + def osx(self, override: Override) -> VirtualPackageOverrides: + """ + Sets the OSX override. + """ + self._overrides.osx = override._override + return self._overrides.osx + + @property + def libc(self) -> Override: + """ + Returns the libc override. + """ + return Override._from_py_override(self._overrides.libc) + + @libc.setter + def libc(self, override: Override) -> VirtualPackageOverrides: + """ + Sets the libc override. + """ + self._overrides.libc = override._override + return self._overrides.libc + + @property + def cuda(self) -> Override: + """ + Returns the CUDA override. + """ + return Override._from_py_override(self._overrides.cuda) + + @cuda.setter + def cuda(self, override: Override) -> VirtualPackageOverrides: + """ + Sets the CUDA override. + """ + self._overrides.cuda = override._override + return self._overrides.cuda + + def __str__(self) -> str: + """ + Returns string representation of the VirtualPackageOverrides. + """ + return self._overrides.as_str() + + def __repr__(self) -> str: + """ + Returns a representation of the VirtualPackageOverrides. + """ + return f"VirtualPackageOverrides({self._overrides.as_str()})" + + class VirtualPackage: _virtual_package: PyVirtualPackage @@ -22,7 +180,16 @@ def current() -> List[VirtualPackage]: Returns virtual packages detected for the current system or an error if the versions could not be properly detected. """ - return [VirtualPackage._from_py_virtual_package(vp) for vp in PyVirtualPackage.current()] + warnings.warn("Use `detect` instead") + return VirtualPackage.detect(VirtualPackageOverrides.none()) + + @staticmethod + def detect(overrides: VirtualPackageOverrides | None = None) -> List[VirtualPackage]: + """ + Returns virtual packages detected for the current system with the given overrides. + """ + _overrides: VirtualPackageOverrides = overrides or VirtualPackageOverrides() + return [VirtualPackage._from_py_virtual_package(vp) for vp in PyVirtualPackage.detect(_overrides._overrides)] def into_generic(self) -> GenericVirtualPackage: """ diff --git a/py-rattler/src/lib.rs b/py-rattler/src/lib.rs index 570cf8cc4..3cfe48f3b 100644 --- a/py-rattler/src/lib.rs +++ b/py-rattler/src/lib.rs @@ -65,7 +65,7 @@ use run_exports_json::PyRunExportsJson; use shell::{PyActivationResult, PyActivationVariables, PyActivator, PyShellEnum}; use solver::{py_solve, py_solve_with_sparse_repodata}; use version::PyVersion; -use virtual_package::PyVirtualPackage; +use virtual_package::{PyOverride, PyVirtualPackage, PyVirtualPackageOverrides}; use crate::error::GatewayException; @@ -115,6 +115,8 @@ fn rattler(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(py_fetch_repo_data, m).unwrap()) .unwrap(); m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); diff --git a/py-rattler/src/virtual_package.rs b/py-rattler/src/virtual_package.rs index 6f2b2d59c..51975c142 100644 --- a/py-rattler/src/virtual_package.rs +++ b/py-rattler/src/virtual_package.rs @@ -1,8 +1,131 @@ use pyo3::{pyclass, pymethods, PyResult}; -use rattler_virtual_packages::VirtualPackage; +use rattler_virtual_packages::{Override, VirtualPackage, VirtualPackageOverrides}; use crate::{error::PyRattlerError, generic_virtual_package::PyGenericVirtualPackage}; +#[pyclass] +#[repr(transparent)] +#[derive(Clone, Default, PartialEq)] +pub struct PyOverride { + pub(crate) inner: Override, +} + +impl From for PyOverride { + fn from(value: Override) -> Self { + Self { inner: value } + } +} + +impl From for Override { + fn from(value: PyOverride) -> Self { + value.inner + } +} + +#[pymethods] +impl PyOverride { + #[staticmethod] + pub fn none() -> Self { + Self { + inner: Override::None, + } + } + + #[staticmethod] + pub fn default_env_var() -> Self { + Self { + inner: Override::DefaultEnvVar, + } + } + + #[staticmethod] + pub fn env_var(name: &str) -> Self { + Self { + inner: Override::EnvVar(name.to_string()), + } + } + + #[staticmethod] + pub fn string(value: &str) -> Self { + Self { + inner: Override::String(value.to_string()), + } + } + + pub fn as_str(&self) -> String { + format!("{:?}", self.inner) + } + + pub fn __eq__(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyVirtualPackageOverrides { + pub(crate) inner: VirtualPackageOverrides, +} + +impl From for PyVirtualPackageOverrides { + fn from(value: VirtualPackageOverrides) -> Self { + Self { inner: value } + } +} + +impl From for VirtualPackageOverrides { + fn from(value: PyVirtualPackageOverrides) -> Self { + value.inner + } +} + +#[pymethods] +impl PyVirtualPackageOverrides { + #[staticmethod] + pub fn default() -> Self { + Self { + inner: VirtualPackageOverrides::default(), + } + } + + #[staticmethod] + pub fn none() -> Self { + Self { + inner: VirtualPackageOverrides::none(), + } + } + + pub fn as_str(&self) -> String { + format!("{:?}", self.inner) + } + + #[getter] + pub fn get_osx(&self) -> PyOverride { + self.inner.osx.clone().into() + } + #[setter] + pub fn set_osx(&mut self, value: PyOverride) { + self.inner.osx = value.into(); + } + #[getter] + pub fn get_cuda(&self) -> PyOverride { + self.inner.cuda.clone().into() + } + #[setter] + pub fn set_cuda(&mut self, value: PyOverride) { + self.inner.cuda = value.into(); + } + #[getter] + pub fn get_libc(&self) -> PyOverride { + self.inner.libc.clone().into() + } + #[setter] + pub fn set_libc(&mut self, value: PyOverride) { + self.inner.libc = value.into(); + } +} + #[pyclass] #[repr(transparent)] #[derive(Clone)] @@ -21,14 +144,20 @@ impl From for VirtualPackage { value.inner } } - #[pymethods] impl PyVirtualPackage { /// Returns virtual packages detected for the current system or an error if the versions could /// not be properly detected. + // marking this as depreacted causes a warning when building the code, + // we just warn directly from python. #[staticmethod] pub fn current() -> PyResult> { - Ok(VirtualPackage::current() + Self::detect(&PyVirtualPackageOverrides::none()) + } + + #[staticmethod] + pub fn detect(overrides: &PyVirtualPackageOverrides) -> PyResult> { + Ok(VirtualPackage::detect(&overrides.clone().into()) .map(|vp| vp.iter().map(|v| v.clone().into()).collect::>()) .map_err(PyRattlerError::from)?) } diff --git a/py-rattler/tests/unit/test_override.py b/py-rattler/tests/unit/test_override.py new file mode 100644 index 000000000..731daa295 --- /dev/null +++ b/py-rattler/tests/unit/test_override.py @@ -0,0 +1,30 @@ +from rattler import VirtualPackage, VirtualPackageOverrides, Override, Version, PackageName + + +def test_overrides() -> None: + overrides = VirtualPackageOverrides.none() + print(overrides.osx, Override.none()) + assert overrides.osx == Override.none() + assert overrides.libc == Override.none() + assert overrides.cuda == Override.none() + overrides = VirtualPackageOverrides() + assert overrides.osx == Override.default_env_var() + assert overrides.libc == Override.default_env_var() + assert overrides.cuda == Override.default_env_var() + + overrides.osx = Override.string("123.45") + overrides.libc = Override.string("123.457") + overrides.cuda = Override.string("123.4578") + + r = [i.into_generic() for i in VirtualPackage.detect(overrides)] + + def find(name: str, ver: str, must_find: bool = True) -> None: + for i in r: + if i.name == PackageName(name): + assert i.version == Version(ver) + return + assert not must_find + + find("__cuda", "123.4578") + find("__libc", "123.4578", False) + find("__osx", "123.45", False)