diff --git a/test/system/test_quick_fingerprinting.py b/test/system/test_quick_fingerprinting.py index ad4e7be3..fb7d61d0 100644 --- a/test/system/test_quick_fingerprinting.py +++ b/test/system/test_quick_fingerprinting.py @@ -40,12 +40,12 @@ def test_quick_fingerprinting(tmpdir): ] ) - output = f'{tmpdir}/quick_fingerprinting_results.txt' - import sys + output = f'{tmpdir}/quick_fingerprinting_results.csv' with open(output) as out: lines = out.readlines() - last_line = lines[-1] - arr = last_line.split(' ') - key, val = arr[-1].split('=') - assert key == 'hit_fraction' - assert float(val) > 0.99, 'hit fraction of HG001 vs itself is less than 0.99' \ No newline at end of file + # Skip header line + data_line = lines[1] + # CSV format: full_path,cram_filename,sample_id,ground_truth_id,hit_fraction,best_match,mean_depth + fields = data_line.strip().split(',') + hit_fraction = float(fields[4]) + assert hit_fraction > 0.99, f'hit fraction of HG001 vs itself is {hit_fraction}, less than 0.99' \ No newline at end of file diff --git a/ugbio_utils b/ugbio_utils index d01f4462..4b708880 160000 --- a/ugbio_utils +++ b/ugbio_utils @@ -1 +1 @@ -Subproject commit d01f4462342bc8f9a24d4f3e898b9799103433dc +Subproject commit 4b708880a03159f2985f6a43b62140449f67b327 diff --git a/ugvc/comparison/quick_fingerprinter.py b/ugvc/comparison/quick_fingerprinter.py index 3492ba82..f53de300 100644 --- a/ugvc/comparison/quick_fingerprinter.py +++ b/ugvc/comparison/quick_fingerprinter.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +from joblib import Parallel, delayed from simppl.simple_pipeline import SimplePipeline from ugbio_cloud_utils.cloud_sync import optional_cloud_sync @@ -24,7 +25,14 @@ def __init__( # pylint: disable=too-many-arguments add_aws_auth_command: bool, out_dir: str, sp: SimplePipeline, + csv_name: str = "quick_fingerprinting_results.csv", + n_jobs: int = -1, ): + """ + Initialize the QuickFingerprinter with sample CRAMs, ground truth VCFs, HCRs, reference, region, + filtering parameters, output directory, and pipeline object. + Prepares ground truth files for comparison. + """ self.crams = sample_crams self.ground_truth_vcfs = ground_truth_vcfs self.hcrs = hcrs @@ -36,23 +44,36 @@ def __init__( # pylint: disable=too-many-arguments self.min_hit_fraction_target = min_hit_fraction_target self.sp = sp self.add_aws_auth_command = add_aws_auth_command + self.n_jobs = n_jobs + # Variant caller for hit fraction calculation self.vc = VariantHitFractionCaller(self.ref, self.out_dir, self.sp, self.min_af_snps, region) self.vpu = VcfUtils(self.sp) + # Output file for results (now CSV, with configurable name) + self.output_file = open(os.path.join(self.out_dir, csv_name), "w") + os.makedirs(out_dir, exist_ok=True) + # Prepare ground truth VCFs for each sample self.ground_truths_to_check = self.prepare_ground_truth() def prepare_ground_truth(self): + """ + For each sample, prepare ground truth VCFs restricted to HCR and region. + Returns a dict mapping sample_id to the processed ground truth VCF. + """ ground_truths_to_check = {} + # Create a BED file for the region of interest self.sp.print_and_run(f"echo {self.region} | sed 's/:/\t/' | sed 's/-/\t/' > {self.out_dir}/region.bed") for sample_id in self.ground_truth_vcfs: + # Sync ground truth VCF and HCR files from cloud if needed ground_truth_vcf = optional_cloud_sync(self.ground_truth_vcfs[sample_id], self.out_dir) hcr = optional_cloud_sync(self.hcrs[sample_id], self.out_dir) ground_truth_in_hcr = f"{self.out_dir}/{sample_id}_ground_truth_snps_in_hcr.vcf.gz" ground_truth_to_check_vcf = f"{self.out_dir}/{sample_id}_ground_truth_snps_to_check.vcf.gz" hcr_in_region = f"{self.out_dir}/{sample_id}_hcr_in_region.bed" + # Intersect ground truth VCF with HCR, keep only SNPs self.sp.print_and_run( f"bedtools intersect -a {ground_truth_vcf} -b {hcr} -header | " + f"bcftools view --type snps -Oz -o {ground_truth_in_hcr}" @@ -63,6 +84,7 @@ def prepare_ground_truth(self): ) self.vpu.index_vcf(ground_truth_to_check_vcf) + # Prepare HCR BED file restricted to region if self.region != "": self.sp.print_and_run( f"bedtools intersect -a {hcr} -b {self.out_dir}/region.bed | " @@ -75,61 +97,58 @@ def prepare_ground_truth(self): return ground_truths_to_check def print(self, msg: str): + """ + Write a message to the output file. + """ self.output_file.write(msg + "\n") + def _process_cram(self, sample_id: str, cram: str) -> list[str]: + max_hit_fraction = 0 + best_match = None + cram_base_name = os.path.basename(cram) + called_vcf = f"{self.out_dir}/{cram_base_name}.calls.vcf.gz" + local_bam = f"{self.out_dir}/{cram_base_name}.bam" + vc = VariantHitFractionCaller(self.ref, self.out_dir, self.sp, self.min_af_snps, self.region) + + if self.add_aws_auth_command: + self.sp.print_and_run( + f"eval $(aws configure export-credentials --format env-no-export) " + f"samtools view {cram} -T {self.ref} {self.region} -b -o {local_bam}" + ) + else: + self.sp.print_and_run(f"samtools view {cram} -T {self.ref} {self.region} -b -o {local_bam}") + + vc.call_variants(local_bam, called_vcf, self.region, min_af=self.min_af_snps) + mean_depth = vc.get_mean_depth(called_vcf) + + hit_fraction_dict = {} + for ground_truth_id, ground_truth_to_check_vcf in self.ground_truths_to_check.items(): + hit_fraction, _, _ = vc.calc_hit_fraction(called_vcf, ground_truth_to_check_vcf) + hit_fraction_dict[ground_truth_id] = hit_fraction + if hit_fraction > max_hit_fraction: + max_hit_fraction = hit_fraction + best_match = ground_truth_id + + return [ + f"{cram},{cram_base_name},{sample_id},{ground_truth_id},{hit_fraction},{best_match},{mean_depth}" + for ground_truth_id, hit_fraction in hit_fraction_dict.items() + ] + def check(self): - errors = [] - with open(f"{self.out_dir}/quick_fingerprinting_results.txt", "w", encoding="utf-8") as of: - self.output_file = of - for sample_id in self.crams: - self.print(f"Check consistency for {sample_id}:") - crams = self.crams[sample_id] - self.print(" crams = \n\t" + "\n\t".join(self.crams[sample_id])) - self.print(f" hcrs = {self.hcrs}") - self.print(f" ground_truth_vcfs = {self.ground_truth_vcfs}") - - for cram in crams: - # Validate that each cram correlates to the ground-truth - self.print("") - hit_fractions = [] - max_hit_fraction = 0 - best_match = None - match_to_expected_truth = None - cram_base_name = os.path.basename(cram) - - called_vcf = f"{self.out_dir}/{cram_base_name}.calls.vcf.gz" - local_bam = f"{self.out_dir}/{cram_base_name}.bam" - if self.add_aws_auth_command: - self.sp.print_and_run( - f"eval $(aws configure export-credentials --format env-no-export) \ - samtools view {cram} -T {self.ref} {self.region} -b -o {local_bam}" - ) - else: - self.sp.print_and_run(f"samtools view {cram} -T {self.ref} {self.region} -b -o {local_bam}") - - self.vc.call_variants(local_bam, called_vcf, self.region, min_af=self.min_af_snps) - - potential_error = f"{cram} - {sample_id} " - for ground_truth_id, ground_truth_to_check_vcf in self.ground_truths_to_check.items(): - hit_fraction, _, _ = self.vc.calc_hit_fraction(called_vcf, ground_truth_to_check_vcf) - if hit_fraction > max_hit_fraction: - max_hit_fraction = hit_fraction - best_match = ground_truth_id - hit_fractions.append(hit_fraction) - if sample_id == ground_truth_id and hit_fraction < self.min_hit_fraction_target: - match_to_expected_truth = hit_fraction - potential_error += f"does not match it's ground truth: hit_fraction={hit_fraction} " - elif sample_id != ground_truth_id and hit_fraction > self.min_hit_fraction_target: - potential_error += ( - f"matched ground truth of {ground_truth_id}: hit_fraction={hit_fraction} " - ) - self.print(f"{cram} - {sample_id} vs. {ground_truth_id} hit_fraction={hit_fraction}") - if best_match != sample_id: - if match_to_expected_truth is None: - self.print(f"{cram} best_match={best_match} hit_fraction={max_hit_fraction}") - else: - potential_error += f"max_hit_fraction = {max(hit_fractions)}" - if potential_error != f"{cram} - {sample_id} ": - errors.append(potential_error) - if len(errors) > 0: - raise RuntimeError("\n".join(errors)) + """ + For each sample and each CRAM, call variants and compare to all ground truth VCFs. + Print a CSV table with cram filename, sample_id, ground_truth_id, hit_fraction, and best_match. + """ + # Print CSV header + self.print("full_path,cram_filename,sample_id,ground_truth_id,hit_fraction,best_match,mean_depth") + + for sample_id in self.crams: + crams = self.crams[sample_id] + rows_by_cram = Parallel(n_jobs=self.n_jobs, prefer="threads")( + delayed(self._process_cram)(sample_id, cram) for cram in crams + ) + for rows in rows_by_cram: + for row in rows: + self.print(row) + + self.output_file.close() diff --git a/ugvc/comparison/variant_hit_fraction_caller.py b/ugvc/comparison/variant_hit_fraction_caller.py index 813c010c..418f9982 100644 --- a/ugvc/comparison/variant_hit_fraction_caller.py +++ b/ugvc/comparison/variant_hit_fraction_caller.py @@ -32,6 +32,9 @@ def calc_hit_fraction(self, called_vcf: str, ground_truth_vcf: str) -> tuple[flo gt_base_name = os.path.basename(ground_truth_vcf).replace(".vcf.gz", "") ground_truth_variants = get_vcf_df(ground_truth_vcf) called_variants = get_vcf_df(called_vcf) + if "dp" not in called_variants.columns: + called_variants["dp"] = pd.NA + called_variants["dp"] = pd.to_numeric(called_variants["dp"], errors="coerce") # keep only major alt called_variants['major_alt'] = called_variants.apply(lambda x: x["alleles"][1], axis=1) ground_truth_variants['major_alt'] = ground_truth_variants.apply(lambda x: x["alleles"][1], axis=1) @@ -48,6 +51,16 @@ def calc_hit_fraction(self, called_vcf: str, ground_truth_vcf: str) -> tuple[flo return hit_fraction, hit_count, ground_truth_count + def get_mean_depth(self, called_vcf: str) -> float: + called_variants = get_vcf_df(called_vcf) + if "dp" not in called_variants.columns: + called_variants["dp"] = pd.NA + called_variants["dp"] = pd.to_numeric(called_variants["dp"], errors="coerce") + mean_depth = called_variants["dp"].mean(skipna=True) + if pd.isna(mean_depth): + return 0.0 + return round(float(mean_depth), 1) + def count_lines(self, in_file: str, out_file: str): self.sp.print_and_run(f"wc -l {in_file} " + "| awk '{print $1}' " + f" > {out_file}") diff --git a/ugvc/pipelines/comparison/quick_fingerprinting.py b/ugvc/pipelines/comparison/quick_fingerprinting.py index 5213453a..8b42411c 100644 --- a/ugvc/pipelines/comparison/quick_fingerprinting.py +++ b/ugvc/pipelines/comparison/quick_fingerprinting.py @@ -12,6 +12,10 @@ def __get_parser() -> argparse.ArgumentParser: + """ + Create and return the argument parser for the script. + Adds arguments for configuration, region, AWS auth, output directory, and variant caller options. + """ parser = argparse.ArgumentParser( prog="quick fingerprinting (finding sample identity of crams, " "given a known list of ground-truth files)", description=run.__doc__, @@ -32,21 +36,36 @@ def __get_parser() -> argparse.ArgumentParser: action="store_true", help="add aws auth command to samtools commands" ) - + # Add arguments from VariantHitFractionCaller VariantHitFractionCaller.add_args_to_parser(parser) parser.add_argument("--out_dir", type=str, required=True, help="output directory") + parser.add_argument( + "--results_file_name", + type=str, + default="quick_fingerprinting_results.csv", + help="output CSV filename (default: quick_fingerprinting_results.csv)" + ) + parser.add_argument( + "--n_jobs", + type=int, + default=-1, + help="number of parallel jobs for CRAM processing (-1 uses all available cores)", + ) return parser def run(argv): """quick fingerprinting to identify known samples in crams""" parser = __get_parser() + # Add pipeline-specific arguments SimplePipeline.add_parse_args(parser) args = parser.parse_args(argv[1:]) + # Load configuration from JSON file with open(args.json_conf, encoding="utf-8") as fh: conf = json.load(fh) + # Sync reference files from cloud if needed ref = optional_cloud_sync(conf["references"]["ref_fasta"], args.out_dir) optional_cloud_sync(conf["references"]["ref_dict"], args.out_dir) optional_cloud_sync(conf["references"]["ref_fasta_index"], args.out_dir) @@ -54,15 +73,18 @@ def run(argv): ground_truth_vcf_files = conf["ground_truth_vcf_files"] # dict sample-id -> bed hcr_files = conf["ground_truth_hcr_files"] # dict sample-id -> bed + # Extract region and variant calling parameters region = args.region_str min_af_snps = args.min_af_snps min_af_germline_snps = args.min_af_germline_snps min_hit_fraction_target = args.min_hit_fraction_target + # Initialize the pipeline sp = SimplePipeline(args.fc, args.lc, debug=args.d) os.makedirs(args.out_dir, exist_ok=True) errors = [] + # Run the quick fingerprinting check QuickFingerprinter( cram_files_list, ground_truth_vcf_files, @@ -74,9 +96,12 @@ def run(argv): min_hit_fraction_target, args.add_aws_auth_command, args.out_dir, - sp + sp, + csv_name=args.results_file_name, + n_jobs=args.n_jobs, ).check() + # Raise errors if any occurred if len(errors) > 0: raise RuntimeError("\n".join(errors))