Skip to content

Commit

Permalink
Merge pull request #14 from TaskeHAMANO/feature/median_filter
Browse files Browse the repository at this point in the history
[DEV] Changed median filter that uses large window size
  • Loading branch information
Shinya SUZUKI authored May 8, 2018
2 parents dc76e16 + 04f1018 commit 65cbb69
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 15 deletions.
4 changes: 2 additions & 2 deletions sphere/sphere_dplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse
import numpy as np
from sphere.sphere_utils import load_depth_file
from sphere.sphere_utils import compress_depth
from sphere.sphere_utils import segment_depth
from sphere.sphere_utils import get_logger
from matplotlib import gridspec
try:
Expand Down Expand Up @@ -48,7 +48,7 @@ def main(args, logger):
y = df["depth"].values
t1 = np.arange(0, 2*np.pi, 2*np.pi/args["np"])
t2 = np.arange(0, 2*np.pi, 2*np.pi/I)
y_f = compress_depth(y, args["np"])
y_f = segment_depth(y, args["np"])
width = 2 * np.pi / (args["np"]+10)

fig = plt.figure(figsize=(20, 20))
Expand Down
18 changes: 13 additions & 5 deletions sphere/sphere_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Created: 2017-11-01

from sphere.sphere_utils import compress_depth
from sphere.sphere_utils import compress_length
from sphere.sphere_utils import load_depth_file
from sphere.sphere_utils import get_logger
import argparse
Expand All @@ -19,22 +20,29 @@ def argument_parse(argv=None):
parser.add_argument("output_dest",
type=str,
help="destination of output tsv file")
parser.add_argument("-cl", "--compressedlength",
dest="cl",
parser.add_argument("-s", "--stride_length",
dest="s",
nargs="?",
default=100,
type=int,
help="Stride length of filter (default: 100)")
parser.add_argument("-w", "--window_length",
dest="w",
nargs="?",
default=10000,
type=int,
help="Compressed length of genome (default: 10000)")
help="Window length of filter (default: 10000)")
args = parser.parse_args(argv)
return vars(args)


def main(args, logger):
df = load_depth_file(args["depth_file_path"])
cl = compress_length(df["depth"].size, s=args["s"], w=args["w"])

genome_name = df["genome"].unique()[0]
position = np.arange(1, args["cl"]+1, 1)
c_depth = compress_depth(df["depth"], args["cl"])
position = np.arange(1, cl + 1, 1)
c_depth = compress_depth(df["depth"], s=args["s"], w=args["w"])
c_df = pd.DataFrame({"position": position, "depth": c_depth})
c_df["genome"] = genome_name
c_df = c_df[["genome", "position", "depth"]]
Expand Down
15 changes: 14 additions & 1 deletion sphere/sphere_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,20 @@ def load_multiple_depth_file(depth_file_path: list):
return c_df


def compress_depth(v: np.ndarray, cl: int):
def compress_depth(d: pd.Series, s: int = None, w: int = None) -> pd.Series:
    """Down-sample a depth series with a strided rolling median.

    A median is taken over every full window of ``w`` consecutive values,
    every ``s``-th median is kept (starting with the first full window),
    and the kept medians are rounded to integers.

    Args:
        d: Depth values, one per position.
        s: Stride length; must be a positive integer.
        w: Window length; must be a positive integer.

    Returns:
        Integer series of the kept medians, re-indexed from 0. It is empty
        when ``d`` holds fewer than ``w`` values.

    Raises:
        ValueError: If ``s`` or ``w`` is omitted or is not positive.
    """
    # The ``None`` defaults only exist for signature compatibility; fail
    # early with a clear message instead of deep inside pandas/range().
    if s is None or w is None:
        raise ValueError("stride length (s) and window length (w) "
                         "must both be given")
    if s < 1 or w < 1:
        raise ValueError("stride length (s) and window length (w) "
                         "must be positive integers")

    # The first w - 1 rolling results have no full window and are NaN;
    # dropping them leaves exactly one median per full window.
    medians = d.rolling(window=w).median().dropna()
    # Positional strided slice: keep window 0, s, 2s, ...
    kept = medians.iloc[::s].reset_index(drop=True)
    return kept.round().astype(int)


def compress_length(dl: int, s: int, w: int) -> int:
    """Return the number of values a strided window filter produces.

    One value is produced for each full window of length ``w`` whose start
    advances by ``s`` positions over ``dl`` data points, i.e.
    ``(dl - w) // s + 1`` windows.

    Args:
        dl: Length of the input depth vector.
        s: Stride length of the filter.
        w: Window length of the filter.

    Returns:
        The window count; 0 when the data is shorter than one window.

    Raises:
        ZeroDivisionError: If ``s`` is 0.
    """
    # Integer floor division avoids a float round-trip, and clamping keeps
    # the count from going negative when w exceeds dl by more than s.
    return max(0, (dl - w) // s + 1)


def segment_depth(v: np.ndarray, cl: int) -> np.ndarray:
I = v.size
w1 = window_length(I, cl)
w2 = w1 + 1
Expand Down
5 changes: 5 additions & 0 deletions tests/data/test_sphere_filter/answer2.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
genome_name 1 4
genome_name 2 4
genome_name 3 6
genome_name 4 6
genome_name 5 8
4 changes: 4 additions & 0 deletions tests/data/test_sphere_filter/answer3.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
genome_name 1 2
genome_name 2 4
genome_name 3 6
genome_name 4 8
File renamed without changes.
10 changes: 10 additions & 0 deletions tests/data/test_sphere_filter/input2.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
genome_name 1 1
genome_name 2 2
genome_name 3 3
genome_name 4 4
genome_name 5 5
genome_name 6 6
genome_name 7 7
genome_name 8 8
genome_name 9 9
genome_name 10 10
10 changes: 10 additions & 0 deletions tests/data/test_sphere_filter/input3.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
genome_name 1 1
genome_name 2 2
genome_name 3 3
genome_name 4 4
genome_name 5 5
genome_name 6 6
genome_name 7 7
genome_name 8 8
genome_name 9 9
genome_name 10 10
37 changes: 30 additions & 7 deletions tests/test_sphere_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import unittest
import os
import filecmp
from sphere import sphere_filter
from sphere.sphere_utils import get_logger

Expand All @@ -16,35 +17,57 @@ class SphereFilterTest(unittest.TestCase):
def setUp(self):
self.maxDiff = None
d_dir = os.path.dirname(__file__) + "/data/test_sphere_filter"
self.__input = d_dir + "/input.tsv"
self.__input1 = d_dir + "/input1.tsv"
self.__input2 = d_dir + "/input2.tsv"
self.__input3 = d_dir + "/input3.tsv"
self.__output = d_dir + "/output.tsv"
self.__answer2 = d_dir + "/answer2.tsv"
self.__answer3 = d_dir + "/answer3.tsv"

def tearDown(self):
    """Delete the output file left behind by a test, if there is one."""
    output_path = self.__output
    if not os.path.exists(output_path):
        return
    os.remove(output_path)

def test_sphere_filter_main(self):
args = {
"depth_file_path": self.__input,
"depth_file_path": self.__input1,
"output_dest": self.__output,
"cl": 100,
"s": 10,
"w": 10
}
sphere_filter.main(args, SphereFilterTest.logger)

def test_sphere_filter_main_non_devidable(self):
args = {
"depth_file_path": self.__input,
"depth_file_path": self.__input1,
"output_dest": self.__output,
"cl": 19,
"s": 7,
"w": 11
}
sphere_filter.main(args, SphereFilterTest.logger)

def test_sphere_filter_command(self):
argv_str = "{0} {1} -cl 100".format(self.__input, self.__output)
def test_sphere_filter_command1(self):
argv_str = "{0} {1} -s 10 -w 10".format(self.__input1, self.__output)
argv = argv_str.split()
args = sphere_filter.argument_parse(argv)
sphere_filter.main(args, SphereFilterTest.logger)

def test_sphere_filter_command2(self):
    # Stride 1, window 6 on input2: the written file must be identical
    # to answer2.
    command_line = "{0} {1} -s 1 -w 6".format(self.__input2, self.__output)
    parsed = sphere_filter.argument_parse(command_line.split())
    sphere_filter.main(parsed, SphereFilterTest.logger)
    self.assertTrue(filecmp.cmp(self.__output, self.__answer2))

def test_sphere_filter_command3(self):
    # Stride 2, window 3 on input3: the written file must be identical
    # to answer3.
    command_line = "{0} {1} -s 2 -w 3".format(self.__input3, self.__output)
    parsed = sphere_filter.argument_parse(command_line.split())
    sphere_filter.main(parsed, SphereFilterTest.logger)
    self.assertTrue(filecmp.cmp(self.__output, self.__answer3))


if __name__ == '__main__':
unittest.main()

0 comments on commit 65cbb69

Please sign in to comment.