
Commit

vcf updates
kdm9 committed Nov 5, 2024
1 parent 62453ed · commit 15eec1b
Showing 3 changed files with 61 additions and 27 deletions.
blsl/vcfparallel.py (19 changes: 11 additions & 8 deletions)
@@ -5,6 +5,7 @@
 
 from tqdm import tqdm
 from cyvcf2 import VCF
+from natsort import natsorted
 
 from sys import stdin, stdout, stderr
 import shutil
@@ -18,12 +19,10 @@
 import re
 
 
-def parallel_regions(vcf, cores=1, npercore=10):
+def parallel_regions(vcf, chunks=10):
     V = VCF(vcf)
-    # 10 chunks per chrom per core
-    chunks = len(V.seqlens)*npercore*cores
     for cname, clen in zip(V.seqnames, V.seqlens):
-        chunk = int(max(min(clen, 1000), math.ceil(clen/chunks)))
+        chunk = int(max(min(clen, 1_000_000), math.ceil(clen/chunks)))
         for start in range(0, clen, chunk):
             s = start+1
             e = min(clen, start+chunk+1)
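For intuition: each region is now sized as the larger of 1 Mbp and clen/chunks (capped at the whole chromosome), so the floor rises from 1 kbp to 1 Mbp and the chunk count per chromosome becomes an explicit parameter instead of being derived from core count. A standalone sketch of the arithmetic, with assumed chromosome lengths:

import math

def chunk_length(clen, chunks=10):
    # mirrors the expression above: at least 1 Mbp (or the whole
    # chromosome, if shorter), and at least clen/chunks
    return int(max(min(clen, 1_000_000), math.ceil(clen / chunks)))

for clen in (500_000, 5_000_000, 50_000_000):
    chunk = chunk_length(clen)
    print(f"{clen:>12,} bp -> {math.ceil(clen / chunk)} region(s) of {chunk:,} bp")
# 500,000 bp: 1 region; 5 Mbp: 5 regions of 1 Mbp; 50 Mbp: 10 regions of 5 Mbp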
@@ -47,7 +46,7 @@ def chunkwise_pipeline(args):
     with ProcessPoolExecutor(args.threads) as exc:
         jobs = []
         if args.regions is None:
-            regions = parallel_regions(args.vcf)
+            regions = parallel_regions(args.vcf, chunks=args.chunks)
         else:
             regions = set()
             for line in args.regions:
@@ -67,7 +66,7 @@ def merge_one(files, prefix, threads=1, merge_type="fast"):
     fofn = f"{prefix}fofn.txt"
     output = f"{prefix}output.bcf"
     with open(fofn, "w") as fh:
-        for file in sorted(files):
+        for file in natsorted(files):
             print(file, file=fh)
     merge = "--allow-overlaps --rm-dup all" if merge_type == "slow" else ""
     cmd = f"(bcftools concat --file-list {fofn} {merge} --threads {threads} -Ob0 --write-index -o {output}) >{output}.log 2>&1"
@@ -97,7 +96,7 @@ def merge_results(args, filestomerge):
     with ProcessPoolExecutor(args.threads) as exc:
         jobs = []
         for i, files in enumerate(groups):
-            jobs.append(exc.submit(merge_one, files, f"{args.temp_prefix}merge_group_{i}_", args.merge_type))
+            jobs.append(exc.submit(merge_one, files, f"{args.temp_prefix}merge_group_{i:09d}_", args.merge_type))
         for job in tqdm(as_completed(jobs), total=len(jobs), unit="group"):
             ofile = job.result()
             final_merge.append(ofile)
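Zero-padding the group index ({i:09d}) complements the natural sort above: with fixed-width numbers, lexicographic and numeric order agree, so the merge-group outputs line up correctly under either sorting scheme.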
Expand All @@ -106,7 +105,7 @@ def merge_results(args, filestomerge):

fofn = f"{args.temp_prefix}final_fofn.txt"
with open(fofn, "w") as fh:
for file in sorted(final_merge):
for file in natsorted(final_merge):
print(file, file=fh)
index = "--write-index" if re.match(r"[zb]", args.outformat) else ""
merge = "--allow-overlaps --rm-dup all" if args.merge_type == "slow" else ""
@@ -131,6 +130,8 @@ def main(argv=None):
                     help="Use slow bcftools merging (with --allow-overlaps and --remove-duplicates)")
     ap.add_argument("-f", "--filter", default="", type=str,
                     help="bcftools view arguments for variant filtering")
+    ap.add_argument("-C", "--chunks", type=int,
+                    help="Number of chunks each chromosome is broken into. Default: 10*--threads")
     ap.add_argument("-c", "--commands", default="", type=str,
                     help="command(s) to operate. Must take uncompressed bcf on stdin and yield bcf (i.e -Ob0) on stdout. Can use | and other shell features.")
     ap.add_argument("-M", "--merge-with-cat", action="store_true",
@@ -140,6 +141,8 @@
     ap.add_argument("vcf")
     args = ap.parse_args(argv)
 
+    if args.chunks is None:
+        args.chunks = args.threads * 10
     if args.temp_prefix is None:
         args.temp_prefix = tempfile.gettempdir() + "/bcffilter"
     tp = Path(args.temp_prefix)
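Net effect on the CLI: with the new -C/--chunks option left unset, a run with 16 threads now splits each chromosome into 160 regions (10 per thread), replacing the old hard-coded 10-chunks-per-chromosome-per-core heuristic.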
blsl/vcfreport.py (68 changes: 49 additions & 19 deletions)
@@ -65,10 +65,26 @@ def variant2dict(v, fields=None):
     if "FORMAT_GT" in fields:
         dat["FORMAT_GT"] = [v[0] + v[1] if v[0] != -1 and v[1] != -1 else float('nan') for v in v.genotypes ]
     if "FORMAT_DP" in fields:
-        dat["FORMAT_DP"] = v.format('DP')
+        dat["FORMAT_DP"] = [x[0] for x in v.format('DP').tolist()]
     return dat
 
 
+def res2json(gres, base=True):
+    res = {}
+    for key, val in gres.items():
+        if isinstance(val, dict):
+            val = res2json(val, base=False)
+        elif isinstance(val, physt.histogram1d.Histogram1D):
+            val={ "bins": val.bins.tolist(), "frequencies": val.frequencies.tolist()}
+        elif isinstance(val, np.ndarray):
+            val=val.tolist()
+        if base and key == "subset_variants":
+            continue # skip subset variants
+            #val = {x: [v[x] for v in val] for x in ["FORMAT_DP", "FORMAT_GT"]}
+        res[key] = val
+    return res
+
+
 def one_chunk_stats(vcf, chunk, fill=True, fields=None, min_maf=0.01, min_call=0.7, subsample=0.01):
     cmd=f"bcftools view -r {str(chunk)} {vcf} -Ou"
     if fill:
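res2json flattens the nested results so json.dump can serialize them: physt histograms become {bins, frequencies} pairs, numpy arrays become lists, and the bulky top-level subset_variants matrices are skipped. A minimal round-trip sketch with toy data (not from this repo):

import json
import numpy as np
from physt import histogram

h = histogram(np.random.rand(1000), 30)
payload = {"bins": h.bins.tolist(), "frequencies": h.frequencies.tolist()}
assert json.loads(json.dumps(payload)) == payload  # survives the round trip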
@@ -147,41 +163,48 @@ def chunkwise_bcftools_stats(vcf, threads=8, chunksize=1_000_000):
         for job in as_completed(jobs):
             pbar.update(1)
             update_result(global_res, job.result())
+
+    genomat = np.vstack([x["FORMAT_GT"] for x in global_res["subset_variants"]])
+    dpmat = np.vstack([x["FORMAT_DP"] for x in global_res["subset_variants"]])
+    global_res["subset_variants"] = {
+        "FORMAT_GT": genomat,
+        "FORMAT_DP": dpmat,
+    }
+    global_res["sample_missingness"] = sample_missingness(global_res)
+    global_res["sample_depth"] = sample_depth(global_res)
+    x, pc, varex = genotype_pca(global_res)
+    global_res["pca"] = {"pc": pc, "varex": varex}
     return global_res
 
-def sample_missingness(variants):
-    genomat = np.vstack([x["FORMAT_GT"] for x in variants])
+def sample_missingness(res):
+    genomat = res["subset_variants"]["FORMAT_GT"]
     missing = np.sum(np.isnan(genomat), axis=0) / np.shape(genomat)[0]
-    #return {s:m for s, m in zip(res["samples"], missing)}
-    return histogram(missing, 30)
+    return {s:m for s, m in zip(res["samples"], missing)}
 
-def sample_depth(variants):
-    dpmat = np.hstack([x["FORMAT_DP"] for x in variants])
-    meandepth = np.nansum(dpmat, axis=1) / np.shape(dpmat)[1]
-    #return {s:m for s, m in zip(samples, meandepth)}
-    return histogram(meandepth, 30)
+def sample_depth(res):
+    dpmat = res["subset_variants"]["FORMAT_DP"]
+    meandepth = np.nansum(dpmat, axis=0) / np.shape(dpmat)[0]
+    return {s:m for s, m in zip(res["samples"], meandepth)}
 
 def genotype_pca(res):
-    genomat = np.vstack([x["FORMAT_GT"] for x in res])
+    genomat = res["subset_variants"]["FORMAT_GT"]
     imp = sklearn.impute.SimpleImputer()
     genomat = imp.fit_transform(genomat)
     pc = sklearn.decomposition.PCA()
     gpc = pc.fit_transform(genomat.T)
     return genomat, gpc, pc.explained_variance_ratio_
 
 
-def generate_report(vcf, threads, chunksize=1_000_000):
-    gres = chunkwise_bcftools_stats(vcf, threads=threads, chunksize=chunksize)
+def generate_report(gres, args):
     figwidth=1000
     figheight=750
 
-    fig_smis = sample_missingness(gres["subset_variants"]).plot()
+    fig_smis = histogram(list(gres["sample_missingness"].values()), 30).plot()
     fig_smis.update_xaxes(title_text="Missing Rate (sample)")
     fig_smis.update_yaxes(title_text="# Samples")
     fig_smis.update_layout(title_text="Sample Missing Rate", width=figwidth, height=figheight)
     SMIS_CODE=fig_smis.to_html(full_html=False)
 
-    fig_sdp = sample_depth(gres["subset_variants"]).plot()
+    fig_sdp = histogram(list(gres["sample_depth"].values()), 30).plot()
     fig_sdp.update_xaxes(title_text="Mean Depth (sample)")
     fig_sdp.update_yaxes(title_text="# Samples")
     fig_sdp.update_layout(title_text="Sample Mean Depths", width=figwidth, height=figheight)
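The refactor moves all aggregation into chunkwise_bcftools_stats, so generate_report is now a pure consumer of the result dict. The approximate shape, inferred from this diff (matrix orientation follows the new axis=0 reductions):

# gres["samples"]            list of sample names
# gres["n_snps"]             total variant count
# gres["subset_variants"]    {"FORMAT_GT": ndarray, "FORMAT_DP": ndarray},
#                            both variants x samples, nan where missing
# gres["sample_missingness"] {sample: missing_rate}, histogrammed only at plot time
# gres["sample_depth"]       {sample: mean_depth}
# gres["pca"]                {"pc": PCA coordinates, "varex": explained variance ratios}

Note that sample_depth now reduces the (variants x samples) matrix along axis 0, matching this layout; the old code hstack-ed per-variant columns and reduced along axis 1.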
@@ -224,7 +247,8 @@ def generate_report(vcf, threads, chunksize=1_000_000):
     QUAL_CODE = fig_qual.to_html(full_html=False)
 
 
-    x, pc, varex = genotype_pca(gres["subset_variants"])
+    pc = gres["pca"]["pc"]
+    varex = gres["pca"]["varex"]
     axtitle={"xy"[i]: f"PC {i+1} ({varex[i]*100:.1f}%)" for i in range(2)}
     pc_fig = px.scatter(
         x=pc[:,0],
@@ -268,7 +292,7 @@ def generate_report(vcf, threads, chunksize=1_000_000):
     <body>
     <h1>VCF Statistics</h1>
     <table>
-    <tr><td>File</td> <td align="right">{vcf}</td> </tr>
+    <tr><td>File</td> <td align="right">{args.vcf}</td> </tr>
     <tr><td>Total # SNPs</td> <td align="right">{gres['n_snps']:,}</td> </tr>
     <tr><td># Samples</td> <td align="right">{len(gres['samples']):,}</td> </tr>
     </table>
@@ -318,6 +342,8 @@ def main(argv=None):
     ap = argparse.ArgumentParser("blsl vcfreport")
     ap.add_argument("--output", "-o", type=argparse.FileType("w"), required=True,
                     help="Output html file")
+    ap.add_argument("--json", "-O", type=argparse.FileType("w"),
+                    help="Output JSON data dump file")
     ap.add_argument("--threads", "-t", type=int, default=2,
                     help="Parallel threads")
     ap.add_argument("--chunksize", "-c", type=int, default=1_000_000,
@@ -326,7 +352,11 @@
                     help="VCF input file (must be indexed)")
     args=ap.parse_args(argv)
 
-    html = generate_report(args.vcf, threads=args.threads, chunksize=args.chunksize)
+    gres = chunkwise_bcftools_stats(args.vcf, threads=args.threads, chunksize=args.chunksize)
+    if args.json:
+        jres = res2json(gres)
+        json.dump(jres, args.json)
+    html = generate_report(gres, args)
     args.output.write(html)
     args.output.flush()
 
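The new --json/-O dump makes the aggregated statistics reusable without re-scanning the VCF. A hypothetical downstream read (file names assumed), after something like blsl vcfreport -o report.html -O stats.json in.vcf.gz:

import json

with open("stats.json") as fh:
    stats = json.load(fh)
print(stats["n_snps"], len(stats["samples"]))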
pyproject.toml (1 change: 1 addition & 0 deletions)
@@ -13,6 +13,7 @@ dependencies = [
     "scikit-learn",
     "physt",
     "plotly",
+    "natsort",
 ]
 dynamic = [
     "version",
