-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
北词你好
committed
Dec 17, 2024
1 parent
46f9488
commit c733d67
Showing
11 changed files
with
785 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
name: coveralls | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
branches: | ||
- main | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v3 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: '3.9' | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r requirements.txt | ||
|
||
- name: Run tests and generate coverage report | ||
run: | | ||
pip install pytest pytest-cov | ||
pytest --cov=my_project tests/ | ||
|
||
- name: Upload coverage to Coveralls | ||
uses: coverallsapp/github-action@v2 | ||
with: | ||
github-token: ${{ secrets.GITHUB_TOKEN }} |
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .metric import eval_nmi | ||
from .metric import eval_ami | ||
from .metric import eval_ari | ||
from .metric import eval_f1 | ||
from .metric import eval_acc | ||
|
||
__all__ = [ | ||
'eval_nmi', | ||
'eval_ami', | ||
'eval_ari', | ||
'eval_f1', | ||
'eval_acc' | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Metrics used to evaluate the outlier detection performance | ||
""" | ||
# Author: Yingtong Dou <[email protected]>, Kay Liu <[email protected]> | ||
# License: BSD 2 clause | ||
|
||
|
||
from sklearn import metrics | ||
|
||
|
||
|
||
def eval_nmi(ground_truths, predictions): | ||
""" | ||
Normalized Mutual Information (NMI) score for clustering evaluation. | ||
|
||
Parameters | ||
---------- | ||
ground_truths : array-like | ||
Ground truth labels. | ||
predictions : array-like | ||
Predicted cluster labels. | ||
|
||
Returns | ||
------- | ||
nmi : float | ||
Normalized Mutual Information score. | ||
""" | ||
nmi = metrics.normalized_mutual_info_score(ground_truths, predictions) | ||
return nmi | ||
|
||
|
||
def eval_ami(ground_truths, predictions): | ||
""" | ||
Adjusted Mutual Information (AMI) score for clustering evaluation. | ||
|
||
Parameters | ||
---------- | ||
ground_truths : array-like | ||
Ground truth labels. | ||
predictions : array-like | ||
Predicted cluster labels. | ||
|
||
Returns | ||
------- | ||
ami : float | ||
Adjusted Mutual Information score. | ||
""" | ||
ami = metrics.adjusted_mutual_info_score(ground_truths, predictions) | ||
return ami | ||
|
||
|
||
def eval_ari(ground_truths, predictions): | ||
""" | ||
Adjusted Rand Index (ARI) score for clustering evaluation. | ||
|
||
Parameters | ||
---------- | ||
ground_truths : array-like | ||
Ground truth labels. | ||
predictions : array-like | ||
Predicted cluster labels. | ||
|
||
Returns | ||
------- | ||
ari : float | ||
Adjusted Rand Index score. | ||
""" | ||
ari = metrics.adjusted_rand_score(ground_truths, predictions) | ||
return ari | ||
|
||
|
||
def eval_f1(ground_truths, predictions): | ||
""" | ||
F1 score for classification evaluation. | ||
|
||
Parameters | ||
---------- | ||
ground_truths : array-like | ||
Ground truth labels. | ||
predictions : array-like | ||
Predicted labels. | ||
|
||
Returns | ||
------- | ||
f1 : float | ||
F1 score. | ||
""" | ||
f1 = metrics.f1_score(ground_truths, predictions, average='macro') | ||
return f1 | ||
|
||
|
||
def eval_acc(ground_truths, predictions): | ||
""" | ||
Accuracy score for classification evaluation. | ||
|
||
Parameters | ||
---------- | ||
ground_truths : array-like | ||
Ground truth labels. | ||
predictions : array-like | ||
Predicted labels. | ||
|
||
Returns | ||
------- | ||
acc : float | ||
Accuracy score. | ||
""" | ||
acc = metrics.accuracy_score(ground_truths, predictions) | ||
return acc | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import sys | ||
|
||
# Add parent directory to path so we can import modules | ||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
|
||
# Import test modules | ||
from .testBERT import * | ||
from .testBiLSTM import * | ||
from .testEventX import * | ||
from .testKPGNN import * | ||
from .testRPLMSED import * | ||
from .testword2vec import * | ||
from .testRPLMSED import * | ||
|
||
__all__ = [ | ||
"testBERT", | ||
"testBiLSTM", | ||
"testEventX", | ||
"testKPGNN", | ||
"testRPLMSED", | ||
"testword2vec" | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .utility import * | ||
from .score_converter import to_edge_score, to_graph_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Outlier Score Converters | ||
""" | ||
# Author: Kay Liu <[email protected]> | ||
# License: BSD 2 clause | ||
|
||
|
||
def to_edge_score(score, edge_index): | ||
"""Convert outlier node score to outlier edge score by averaging the | ||
scores of two nodes connected by an edge. | ||
|
||
Parameters | ||
---------- | ||
score : torch.Tensor | ||
The node score. | ||
edge_index : torch.Tensor | ||
The edge index. | ||
|
||
Returns | ||
------- | ||
score : torch.Tensor | ||
The edge score. | ||
""" | ||
score = (score[edge_index[0]] + score[edge_index[1]]) / 2 | ||
return score | ||
|
||
|
||
def to_graph_score(score): | ||
"""Convert outlier node score to outlier graph score by averaging | ||
the scores of all nodes in a graph. | ||
|
||
Parameters | ||
---------- | ||
score : torch.Tensor | ||
The node score. | ||
|
||
Returns | ||
------- | ||
score : torch.Tensor | ||
The graph score. | ||
""" | ||
|
||
return score.mean(dim=-1) |
Oops, something went wrong.