Skip to content

Commit

Permalink
update cli
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouyiqi91 committed Aug 22, 2024
1 parent 702e45f commit ba4a59a
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 1 deletion.
141 changes: 141 additions & 0 deletions sccore/cli/filter_gtf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#!/usr/bin/env python

import collections
import csv
import gzip
import os
import re
import sys

PATTERN = re.compile(r'(\S+?)\s*"(.*?)"')
gtf_row = collections.namedtuple("gtf_row", "seqname source feature start end score strand frame attributes")


def generic_open(file_name, *args, **kwargs):
if file_name.endswith(".gz"):
file_obj = gzip.open(file_name, *args, **kwargs)
else:
file_obj = open(file_name, *args, **kwargs)
return file_obj


class GtfParser:
def __init__(self, gtf_fn):
self.gtf_fn = gtf_fn
self.gene_id = []
self.gene_name = []
self.id_name = {}
self.id_strand = {}

def get_properties_dict(self, properties_str):
"""
allow no space after semicolon
"""

if isinstance(properties_str, dict):
return properties_str

properties = collections.OrderedDict()
attrs = properties_str.split(";")
for attr in attrs:
if attr:
m = re.search(PATTERN, attr)
if m:
key = m.group(1).strip()
value = m.group(2).strip()
properties[key] = value

return properties

def gtf_reader_iter(self):
"""
Yield:
row: list
gtf_row
"""
with generic_open(self.gtf_fn, mode="rt") as f:
reader = csv.reader(f, delimiter="\t")
for i, row in enumerate(reader, start=1):
if len(row) == 0:
continue
if row[0].startswith("#"):
yield row, None
continue

if len(row) != 9:
sys.exit(f"Invalid number of columns in GTF line {i}: {row}\n")

if row[6] not in ["+", "-"]:
sys.exit(f"Invalid strand in GTF line {i}: {row}\n")

seqname = row[0]
source = row[1]
feature = row[2]
# gff/gtf is 1-based, end-inclusive
start = int(row[3])
end = int(row[4])
score = row[5]
strand = row[6]
frame = row[7]
attributes = self.get_properties_dict(row[8])

yield row, gtf_row(seqname, source, feature, start, end, score, strand, frame, attributes)


def filter_gtf(gtf_fn, out_fn, allow):
"""
Filter attributes
Args:
allow: {
"gene_biotype": set("protein_coding", "lncRNA")
}
"""
sys.stderr.write("Writing GTF file...\n")
gp = GtfParser(gtf_fn)
n_filter = 0
no_writer = csv.writer(open("no.gtf", "w"), delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=None)

with open(out_fn, "w") as f:
# quotechar='' is not allowed since python3.11
writer = csv.writer(f, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=None)
for row, grow in gp.gtf_reader_iter():
if not grow:
writer.writerow(row)
continue

remove = False
if allow:
for key, value in grow.attributes.items():
if key in allow and value not in allow[key]:
remove = True
break

if not remove:
writer.writerow(row)
else:
n_filter += 1
no_writer.writerow(row)
return n_filter


if __name__ == "__main__":
# args: gtf, attributes
gtf_fn = sys.argv[1]
attributes = sys.argv[2]
out_fn = os.path.basename(gtf_fn).replace(".gtf", ".filtered.gtf")

allow = {}
for attr_str in attributes.split(";"):
if attr_str:
attr, val = attr_str.split("=")
val = set(val.split(","))
allow[attr] = val

n_filter = filter_gtf(gtf_fn, out_fn, allow)
sys.stdout.write(f"Filtered {n_filter} lines\n")
log_file = "gtf_filter.log"
with open(log_file, "w") as f:
f.write(f"Filtered lines: {n_filter}\n")
f.write(f"Attributes: {attributes}\n")
f.write(f"Output file: {out_fn}\n")
68 changes: 68 additions & 0 deletions sccore/cli/invalid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!usr/bin/env python3

import argparse
import pyfastx
import pysam
from sccore import parse_protocol, utils


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--fq1")
parser.add_argument("--bam")
parser.add_argument("--assets_dir")
parser.add_argument("--num", default=10**6, type=int)
args = parser.parse_args()
if not (args.fq1 or args.bam):
raise ValueError("Please provide fq1 or bam")

if args.bam:
cnt = 0
read_names = set()
fh = pysam.AlignmentFile(args.bam, "rb")
for read in fh:
bc = read.get_tag("CB")
if bc == "-" and read.query_name not in read_names:
cnt += 1
read_names.add(read.query_name)
if cnt == args.num:
break
print(f"{cnt} Read names extracted")

fq1 = pyfastx.Fastx(args.fq1)
found = 0
invalid_fastq = open("invalid.txt", "wt")
for name1, seq1, qual1 in fq1:
if name1 in read_names:
found += 1
invalid_fastq.write(seq1 + "\n")
if found == 10000:
break

elif args.fq1:
protocol_dict = parse_protocol.get_protocol_dict(args.assets_dir)
v2_dict = protocol_dict["GEXSCOPE-V2"]
v2_raw, v2_mismatch = parse_protocol.get_raw_mismatch(v2_dict["bc"], 1)

invalid_fastq = open("invalid.fastq", "wt")
fq1 = pyfastx.Fastx(args.fq1)
raw = valid_reads = invalid_reads = 0
for name1, seq1, qual1 in fq1:
raw += 1
bc_list = [seq1[x] for x in v2_dict["pattern_dict"]["C"]]
valid, _corrected, res = parse_protocol.check_seq_mismatch(bc_list, v2_raw, v2_mismatch)
if not valid:
invalid_fastq.write(utils.fastq_str(name1, seq1, qual1))
invalid_reads += 1
else:
valid_reads += 1

if raw == args.num:
break
print(f"Total reads: {raw}")
print(f"Valid reads: {valid_reads}")
print(f"Invalid reads: {invalid_reads}")


if __name__ == "__main__":
main()
27 changes: 27 additions & 0 deletions sccore/cli/starsolo_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python

import sys
import pysam


def main():
bam = sys.argv[1]
total = 0
diff_fh = open("diff_gene.tsv", "w")
fh = pysam.AlignmentFile(bam, "rb")
for read in fh:
total += 1
if total % 100000 == 0:
print(f"Processed {total} reads")
if not read.has_tag("XT"):
continue
gene = read.get_tag("XT")
starsolo_gene = read.get_tag("GX")
if gene != starsolo_gene:
diff_fh.write(str(read) + "\n")

print(f"Total reads: {total}")


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion sccore/parse_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def check_seq_mismatch(seq_list, raw_list, mismatch_list):
if seq not in raw_list[index]:
if seq not in mismatch_list[index]:
valid = False
res = []
res.append("")
else:
corrected = True
res.append(mismatch_list[index][seq])
Expand Down

0 comments on commit ba4a59a

Please sign in to comment.