Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions test/system/test_quick_fingerprinting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def test_quick_fingerprinting(tmpdir):
]
)

output = f'{tmpdir}/quick_fingerprinting_results.txt'
import sys
output = f'{tmpdir}/quick_fingerprinting_results.csv'
with open(output) as out:
lines = out.readlines()
last_line = lines[-1]
arr = last_line.split(' ')
key, val = arr[-1].split('=')
assert key == 'hit_fraction'
assert float(val) > 0.99, 'hit fraction of HG001 vs itself is less than 0.99'
# Skip header line
data_line = lines[1]
# CSV format: full_path,cram_filename,sample_id,ground_truth_id,hit_fraction,best_match,mean_depth
fields = data_line.strip().split(',')
hit_fraction = float(fields[4])
assert hit_fraction > 0.99, f'hit fraction of HG001 vs itself is {hit_fraction}, less than 0.99'
2 changes: 1 addition & 1 deletion ugbio_utils
Submodule ugbio_utils updated 141 files
129 changes: 74 additions & 55 deletions ugvc/comparison/quick_fingerprinter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from joblib import Parallel, delayed

from simppl.simple_pipeline import SimplePipeline
from ugbio_cloud_utils.cloud_sync import optional_cloud_sync
Expand All @@ -24,7 +25,14 @@ def __init__( # pylint: disable=too-many-arguments
add_aws_auth_command: bool,
out_dir: str,
sp: SimplePipeline,
csv_name: str = "quick_fingerprinting_results.csv",
n_jobs: int = -1,
):
"""
Initialize the QuickFingerprinter with sample CRAMs, ground truth VCFs, HCRs, reference, region,
filtering parameters, output directory, and pipeline object.
Prepares ground truth files for comparison.
"""
self.crams = sample_crams
self.ground_truth_vcfs = ground_truth_vcfs
self.hcrs = hcrs
Expand All @@ -36,23 +44,36 @@ def __init__( # pylint: disable=too-many-arguments
self.min_hit_fraction_target = min_hit_fraction_target
self.sp = sp
self.add_aws_auth_command = add_aws_auth_command
self.n_jobs = n_jobs
# Variant caller for hit fraction calculation
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instance attribute self.vc is now unused dead code

Low Severity

self.vc (a VariantHitFractionCaller instance) is created in __init__ but never referenced anywhere in the new code. The refactored _process_cram method creates its own local vc instance instead. This is leftover dead code from the old check() method that previously used self.vc.

Fix in Cursor Fix in Web

self.vc = VariantHitFractionCaller(self.ref, self.out_dir, self.sp, self.min_af_snps, region)
self.vpu = VcfUtils(self.sp)
# Output file for results (now CSV, with configurable name)
self.output_file = open(os.path.join(self.out_dir, csv_name), "w")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Output file not closed on failures

Low Severity

self.output_file is opened in __init__ and only closed at the end of check(). If any command in prepare_ground_truth() or check() raises, cleanup is skipped and the file handle remains open. This removed the previous context-managed lifecycle and can leak descriptors across repeated runs.

Additional Locations (1)

Fix in Cursor Fix in Web


Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Output file opened before directory is created

Medium Severity

self.output_file = open(...) on line 52 is called before os.makedirs(out_dir, exist_ok=True) on line 54. If the output directory doesn't already exist, the open() call will raise a FileNotFoundError. The os.makedirs call needs to come first. The current single caller happens to pre-create the directory, masking this bug, but the class itself is broken for any other caller.

Fix in Cursor Fix in Web

os.makedirs(out_dir, exist_ok=True)

# Prepare ground truth VCFs for each sample
self.ground_truths_to_check = self.prepare_ground_truth()

def prepare_ground_truth(self):
"""
For each sample, prepare ground truth VCFs restricted to HCR and region.
Returns a dict mapping sample_id to the processed ground truth VCF.
"""
ground_truths_to_check = {}
# Create a BED file for the region of interest
self.sp.print_and_run(f"echo {self.region} | sed 's/:/\t/' | sed 's/-/\t/' > {self.out_dir}/region.bed")

for sample_id in self.ground_truth_vcfs:
# Sync ground truth VCF and HCR files from cloud if needed
ground_truth_vcf = optional_cloud_sync(self.ground_truth_vcfs[sample_id], self.out_dir)
hcr = optional_cloud_sync(self.hcrs[sample_id], self.out_dir)
ground_truth_in_hcr = f"{self.out_dir}/{sample_id}_ground_truth_snps_in_hcr.vcf.gz"
ground_truth_to_check_vcf = f"{self.out_dir}/{sample_id}_ground_truth_snps_to_check.vcf.gz"
hcr_in_region = f"{self.out_dir}/{sample_id}_hcr_in_region.bed"

# Intersect ground truth VCF with HCR, keep only SNPs
self.sp.print_and_run(
f"bedtools intersect -a {ground_truth_vcf} -b {hcr} -header | "
+ f"bcftools view --type snps -Oz -o {ground_truth_in_hcr}"
Expand All @@ -63,6 +84,7 @@ def prepare_ground_truth(self):
)
self.vpu.index_vcf(ground_truth_to_check_vcf)

# Prepare HCR BED file restricted to region
if self.region != "":
self.sp.print_and_run(
f"bedtools intersect -a {hcr} -b {self.out_dir}/region.bed | "
Expand All @@ -75,61 +97,58 @@ def prepare_ground_truth(self):
return ground_truths_to_check

def print(self, msg: str):
    """Append *msg* to the results file, terminated by a newline."""
    self.output_file.write(f"{msg}\n")

def _process_cram(self, sample_id: str, cram: str) -> list[str]:
    """
    Call variants on a single CRAM restricted to the configured region and
    compare the calls against every prepared ground truth VCF.

    Returns one CSV row per ground truth id, in the format:
    full_path,cram_filename,sample_id,ground_truth_id,hit_fraction,best_match,mean_depth
    """
    cram_base_name = os.path.basename(cram)
    called_vcf = f"{self.out_dir}/{cram_base_name}.calls.vcf.gz"
    local_bam = f"{self.out_dir}/{cram_base_name}.bam"
    # A fresh caller per CRAM, so parallel workers do not share mutable state
    vc = VariantHitFractionCaller(self.ref, self.out_dir, self.sp, self.min_af_snps, self.region)

    # Extract the region of interest into a local BAM, optionally prefixed
    # with an AWS credential export for cloud-hosted CRAMs.
    extract_cmd = f"samtools view {cram} -T {self.ref} {self.region} -b -o {local_bam}"
    if self.add_aws_auth_command:
        self.sp.print_and_run(
            f"eval $(aws configure export-credentials --format env-no-export) " + extract_cmd
        )
    else:
        self.sp.print_and_run(extract_cmd)

    vc.call_variants(local_bam, called_vcf, self.region, min_af=self.min_af_snps)
    mean_depth = vc.get_mean_depth(called_vcf)

    # Compare the calls to each ground truth, tracking the best-scoring one.
    best_match = None
    max_hit_fraction = 0
    hit_fractions: dict[str, float] = {}
    for gt_id, gt_vcf in self.ground_truths_to_check.items():
        fraction, _, _ = vc.calc_hit_fraction(called_vcf, gt_vcf)
        hit_fractions[gt_id] = fraction
        if fraction > max_hit_fraction:
            max_hit_fraction = fraction
            best_match = gt_id

    return [
        f"{cram},{cram_base_name},{sample_id},{gt_id},{fraction},{best_match},{mean_depth}"
        for gt_id, fraction in hit_fractions.items()
    ]

def check(self):
errors = []
with open(f"{self.out_dir}/quick_fingerprinting_results.txt", "w", encoding="utf-8") as of:
self.output_file = of
for sample_id in self.crams:
self.print(f"Check consistency for {sample_id}:")
crams = self.crams[sample_id]
self.print(" crams = \n\t" + "\n\t".join(self.crams[sample_id]))
self.print(f" hcrs = {self.hcrs}")
self.print(f" ground_truth_vcfs = {self.ground_truth_vcfs}")

for cram in crams:
# Validate that each cram correlates to the ground-truth
self.print("")
hit_fractions = []
max_hit_fraction = 0
best_match = None
match_to_expected_truth = None
cram_base_name = os.path.basename(cram)

called_vcf = f"{self.out_dir}/{cram_base_name}.calls.vcf.gz"
local_bam = f"{self.out_dir}/{cram_base_name}.bam"
if self.add_aws_auth_command:
self.sp.print_and_run(
f"eval $(aws configure export-credentials --format env-no-export) \
samtools view {cram} -T {self.ref} {self.region} -b -o {local_bam}"
)
else:
self.sp.print_and_run(f"samtools view {cram} -T {self.ref} {self.region} -b -o {local_bam}")

self.vc.call_variants(local_bam, called_vcf, self.region, min_af=self.min_af_snps)

potential_error = f"{cram} - {sample_id} "
for ground_truth_id, ground_truth_to_check_vcf in self.ground_truths_to_check.items():
hit_fraction, _, _ = self.vc.calc_hit_fraction(called_vcf, ground_truth_to_check_vcf)
if hit_fraction > max_hit_fraction:
max_hit_fraction = hit_fraction
best_match = ground_truth_id
hit_fractions.append(hit_fraction)
if sample_id == ground_truth_id and hit_fraction < self.min_hit_fraction_target:
match_to_expected_truth = hit_fraction
potential_error += f"does not match it's ground truth: hit_fraction={hit_fraction} "
elif sample_id != ground_truth_id and hit_fraction > self.min_hit_fraction_target:
potential_error += (
f"matched ground truth of {ground_truth_id}: hit_fraction={hit_fraction} "
)
self.print(f"{cram} - {sample_id} vs. {ground_truth_id} hit_fraction={hit_fraction}")
if best_match != sample_id:
if match_to_expected_truth is None:
self.print(f"{cram} best_match={best_match} hit_fraction={max_hit_fraction}")
else:
potential_error += f"max_hit_fraction = {max(hit_fractions)}"
if potential_error != f"{cram} - {sample_id} ":
errors.append(potential_error)
if len(errors) > 0:
raise RuntimeError("\n".join(errors))
"""
For each sample and each CRAM, call variants and compare to all ground truth VCFs.
Print a CSV table with cram filename, sample_id, ground_truth_id, hit_fraction, and best_match.
"""
# Print CSV header
self.print("full_path,cram_filename,sample_id,ground_truth_id,hit_fraction,best_match,mean_depth")

for sample_id in self.crams:
crams = self.crams[sample_id]
rows_by_cram = Parallel(n_jobs=self.n_jobs, prefer="threads")(
delayed(self._process_cram)(sample_id, cram) for cram in crams
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shared SimplePipeline used unsafely from parallel threads

High Severity

_process_cram is executed in parallel threads via joblib.Parallel, but each thread calls self.sp.print_and_run() on the same shared SimplePipeline instance. SimplePipeline maintains an internal command index for sequential execution with fc/lc range control and is not designed for concurrent thread access. The library provides its own run_parallel method for parallelism. Concurrent print_and_run calls will corrupt the command counter, potentially causing commands to be skipped or misordered.

Additional Locations (1)

Fix in Cursor Fix in Web

for rows in rows_by_cram:
for row in rows:
self.print(row)

self.output_file.close()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fingerprint check no longer fails mismatches

High Severity

check() now only writes CSV rows and never evaluates min_hit_fraction_target or raises errors for incorrect matches. As a result, sample/ground-truth mismatches no longer fail the run, so quick_fingerprinter.py can report successful completion even when fingerprinting is inconsistent.

Fix in Cursor Fix in Web

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is functionality that I want

13 changes: 13 additions & 0 deletions ugvc/comparison/variant_hit_fraction_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def calc_hit_fraction(self, called_vcf: str, ground_truth_vcf: str) -> tuple[flo
gt_base_name = os.path.basename(ground_truth_vcf).replace(".vcf.gz", "")
ground_truth_variants = get_vcf_df(ground_truth_vcf)
called_variants = get_vcf_df(called_vcf)
if "dp" not in called_variants.columns:
called_variants["dp"] = pd.NA
called_variants["dp"] = pd.to_numeric(called_variants["dp"], errors="coerce")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused dp column processing in calc_hit_fraction

Low Severity

The dp column handling (checking for existence, adding if missing, converting to numeric) was added to calc_hit_fraction but the dp column is never used within this method. This logic is only meaningful in get_mean_depth, where an identical block appears. It looks like copy-paste dead code.

Fix in Cursor Fix in Web

# keep only major alt
called_variants['major_alt'] = called_variants.apply(lambda x: x["alleles"][1], axis=1)
ground_truth_variants['major_alt'] = ground_truth_variants.apply(lambda x: x["alleles"][1], axis=1)
Expand All @@ -48,6 +51,16 @@ def calc_hit_fraction(self, called_vcf: str, ground_truth_vcf: str) -> tuple[flo

return hit_fraction, hit_count, ground_truth_count

def get_mean_depth(self, called_vcf: str) -> float:
    """
    Return the mean read depth ("dp" column) over variants in *called_vcf*,
    rounded to one decimal place; 0.0 when no depth values are available.
    """
    variants = get_vcf_df(called_vcf)
    # Guarantee a numeric "dp" column even when the VCF carries no DP values.
    if "dp" not in variants.columns:
        variants["dp"] = pd.NA
    variants["dp"] = pd.to_numeric(variants["dp"], errors="coerce")
    depth = variants["dp"].mean(skipna=True)
    return 0.0 if pd.isna(depth) else round(float(depth), 1)

def count_lines(self, in_file: str, out_file: str):
self.sp.print_and_run(f"wc -l {in_file} " + "| awk '{print $1}' " + f" > {out_file}")

Expand Down
29 changes: 27 additions & 2 deletions ugvc/pipelines/comparison/quick_fingerprinting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@


def __get_parser() -> argparse.ArgumentParser:
"""
Create and return the argument parser for the script.
Adds arguments for configuration, region, AWS auth, output directory, and variant caller options.
"""
parser = argparse.ArgumentParser(
prog="quick fingerprinting (finding sample identity of crams, " "given a known list of ground-truth files)",
description=run.__doc__,
Expand All @@ -32,37 +36,55 @@ def __get_parser() -> argparse.ArgumentParser:
action="store_true",
help="add aws auth command to samtools commands"
)

# Add arguments from VariantHitFractionCaller
VariantHitFractionCaller.add_args_to_parser(parser)
parser.add_argument("--out_dir", type=str, required=True, help="output directory")
parser.add_argument(
"--results_file_name",
type=str,
default="quick_fingerprinting_results.csv",
help="output CSV filename (default: quick_fingerprinting_results.csv)"
)
parser.add_argument(
"--n_jobs",
type=int,
default=-1,
help="number of parallel jobs for CRAM processing (-1 uses all available cores)",
)
return parser


def run(argv):
"""quick fingerprinting to identify known samples in crams"""
parser = __get_parser()
# Add pipeline-specific arguments
SimplePipeline.add_parse_args(parser)
args = parser.parse_args(argv[1:])

# Load configuration from JSON file
with open(args.json_conf, encoding="utf-8") as fh:
conf = json.load(fh)

# Sync reference files from cloud if needed
ref = optional_cloud_sync(conf["references"]["ref_fasta"], args.out_dir)
optional_cloud_sync(conf["references"]["ref_dict"], args.out_dir)
optional_cloud_sync(conf["references"]["ref_fasta_index"], args.out_dir)
cram_files_list = conf["cram_files"]
ground_truth_vcf_files = conf["ground_truth_vcf_files"] # dict sample-id -> bed
hcr_files = conf["ground_truth_hcr_files"] # dict sample-id -> bed

# Extract region and variant calling parameters
region = args.region_str
min_af_snps = args.min_af_snps
min_af_germline_snps = args.min_af_germline_snps
min_hit_fraction_target = args.min_hit_fraction_target

# Initialize the pipeline
sp = SimplePipeline(args.fc, args.lc, debug=args.d)
os.makedirs(args.out_dir, exist_ok=True)
errors = []

# Run the quick fingerprinting check
QuickFingerprinter(
cram_files_list,
ground_truth_vcf_files,
Expand All @@ -74,9 +96,12 @@ def run(argv):
min_hit_fraction_target,
args.add_aws_auth_command,
args.out_dir,
sp
sp,
csv_name=args.results_file_name,
n_jobs=args.n_jobs,
).check()

# Raise errors if any occurred
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dead errors list never populated in run function

Low Severity

The errors list on line 85 is initialized but never appended to. The old check() method internally raised RuntimeError on fingerprinting mismatches, but the new check() only outputs CSV without any error reporting. The if len(errors) > 0 check on line 105 is dead code that gives a false impression that errors are being caught.

Fix in Cursor Fix in Web

if len(errors) > 0:
raise RuntimeError("\n".join(errors))

Expand Down
Loading