Skip to content

Commit

Permalink
Implemented comparison operators for pyauditor types
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Aug 30, 2022
1 parent 174a74a commit f207e5d
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 49 deletions.
121 changes: 79 additions & 42 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion auditor/src/domain/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use sqlx::{
Postgres, Type,
};

#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, sqlx::Encode, Clone)]
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, sqlx::Encode, Clone, PartialOrd, Ord)]
#[sqlx(type_name = "component")]
pub struct Component {
pub name: ValidName,
Expand Down
2 changes: 1 addition & 1 deletion auditor/src/domain/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub struct RecordUpdate {
pub stop_time: DateTime<Utc>,
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Record {
pub record_id: String,
pub site_id: Option<String>,
Expand Down
2 changes: 1 addition & 1 deletion auditor/src/domain/validamount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::fmt;
// possible to create this type outside of this module, hence enforcing the use of `parse`. This
// ensures that every string stored in this type satisfies the validation criteria checked by
// `parse`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Decode, sqlx::Encode)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, sqlx::Decode, sqlx::Encode)]
pub struct ValidAmount(i64);

impl ValidAmount {
Expand Down
2 changes: 1 addition & 1 deletion auditor/src/domain/validname.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use unicode_segmentation::UnicodeSegmentation;
// possible to create this type outside of this module, hence enforcing the use of `parse`. This
// ensures that every string stored in this type satisfies the validation criteria checked by
// `parse`.
#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)]
#[sqlx(transparent)]
pub struct ValidName(String);

Expand Down
54 changes: 54 additions & 0 deletions pyauditor/scripts/test_eq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

from pyauditor import Record, Component, Score
import datetime
import pytz
from tzlocal import get_localzone


def main():
local_tz = get_localzone()
print("LOCAL TIMEZONE: " + str(local_tz))

record_id = "record-1"
site_id = "site-1"
user_id = "user-1"
group_id = "group-1"
score = "score-1"
factor = 12.0
component = "component-1"
amount = 21

# datetimes sent to auditor MUST BE in UTC.
start = datetime.datetime(
2021, 12, 6, 16, 29, 43, 79043, tzinfo=local_tz
).astimezone(pytz.utc)

score1 = Score(score, factor)
score2 = Score(score, factor)
assert score1 == score2

comp1 = Component(component, amount)
comp2 = Component(component, amount)
assert comp1 == comp2

comp1 = Component(component, amount).with_score(score1)
comp2 = Component(component, amount).with_score(score2)
assert comp1 == comp2

record1 = Record(record_id, site_id, user_id, group_id, start)
record2 = Record(record_id, site_id, user_id, group_id, start)
assert record1 == record2

record1 = Record(record_id, site_id, user_id, group_id, start).with_component(comp1)
record2 = Record(record_id, site_id, user_id, group_id, start).with_component(comp1)
assert record1 == record2


if __name__ == "__main__":
import time

s = time.perf_counter()
main()
elapsed = time.perf_counter() - s
print(f"{__file__} executed in {elapsed:0.2f} seconds.")
15 changes: 14 additions & 1 deletion pyauditor/src/domain/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// copied, modified, or distributed except according to those terms.

use crate::domain::Score;
use pyo3::class::basic::{CompareOp, PyObjectProtocol};
use pyo3::prelude::*;

/// Component(name: str, amount: int)
Expand All @@ -22,7 +23,7 @@ use pyo3::prelude::*;
/// :param amount: Amount
/// :type amount: int
#[pyclass]
#[derive(Clone)]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Component {
pub(crate) inner: auditor::domain::Component,
}
Expand Down Expand Up @@ -62,6 +63,18 @@ impl Component {
}
}

#[pyproto]
impl PyObjectProtocol for Component {
fn __richcmp__(&self, other: PyRef<Component>, op: CompareOp) -> Py<PyAny> {
let py = other.py();
match op {
CompareOp::Eq => (self.inner == other.inner).into_py(py),
CompareOp::Ne => (self.inner != other.inner).into_py(py),
_ => py.NotImplemented(),
}
}
}

impl From<auditor::domain::Component> for Component {
fn from(component: auditor::domain::Component) -> Component {
Component { inner: component }
Expand Down
Loading

0 comments on commit f207e5d

Please sign in to comment.