Skip to content

Commit a3b3df3

Browse files
committed
pytorch 2.6 weights_only compatibility
1 parent 039593b commit a3b3df3

3 files changed

Lines changed: 10 additions & 6 deletions

File tree

src/simba_plus/model_prox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def training_step(self, batch, batch_idx):
622622
t0 = time.time()
623623
if "peak" in batch.node_types:
624624
pid = batch["peak"].n_id.cpu()
625-
herit_loss_value = self.herit_loss_lam * self.herit_loss(
625+
herit_loss_value = self.herit_loss(
626626
mu_dict["peak"],
627627
pid,
628628
)

src/simba_plus/post_training/factors.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,10 @@ def summarize_enrichments(gene_enrichment, gene_bot_enrichment):
144144

145145
def main(args, logger=None):
146146
if not logger:
147-
logger = setup_logging(
148-
"simba+factor", log_dir=os.path.dirname(args.adata_prefix)
149-
)
147+
logger = setup_logging("simba+factor", log_dir=args.adata_prefix)
150148
# Loading pretrained results as `simba+ heritability` takes long time to run
151149
if args.output_dir is None:
152-
args.output_dir = f"{os.path.dirname(args.adata_prefix)}/factors/"
150+
args.output_dir = f"{args.adata_prefix}/factors/"
153151
os.makedirs(args.output_dir, exist_ok=True)
154152

155153
adata_C = sc.read_h5ad(f"{args.adata_prefix}/adata_C{args.version_suffix}.h5ad")
@@ -174,6 +172,7 @@ def main(args, logger=None):
174172
plt.close(fig_bot)
175173
sc.set_figure_params(vector_friendly=True)
176174
fig = sc.pl.umap(adata_C, color=args.cell_type_label, return_fig=True)
175+
fig.axes[0].set_box_aspect(1)
177176
pdf.savefig(fig, bbox_inches="tight")
178177
plt.close(fig)
179178
factor_labels = adata_G.uns["factor_enrichments_summary"]
@@ -182,6 +181,9 @@ def main(args, logger=None):
182181
return_fig=True,
183182
factor_labels=factor_labels,
184183
)
184+
for i, ax in enumerate(factor_umap.axes):
185+
if i % 2 == 0:
186+
ax.set_box_aspect(1)
185187
pdf.savefig(factor_umap, bbox_inches="tight")
186188
plt.close(factor_umap)
187189
logger.info(f"Generated factor report in {output_filename}.")

src/simba_plus/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def setup_logging(checkpoint_dir):
5454

5555
# Create timestamped log file
5656
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
57-
log_file = os.path.join(checkpoint_dir, f"train_{timestamp}.log")
57+
log_file = os.path.join(checkpoint_dir, f"simba+train_{timestamp}.log")
5858

5959
# Configure root logger
6060
logger = logging.getLogger()
@@ -329,6 +329,8 @@ def train(
329329
if load_checkpoint
330330
else None
331331
),
332+
weights_only=False,
333+
332334
)
333335
return checkpoint_callback.last_model_path
334336

0 commit comments

Comments
 (0)