Skip to content

Commit fcc90d4

Browse files
committed
Merge branch 'main' into hops
2 parents 91c30a8 + fea4f8f commit fcc90d4

1 file changed

Lines changed: 15 additions & 7 deletions

File tree

accelforge/mapper/FFM/_join_pmappings/join_pmappings.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def prune_with_tolerance(
149149
objective_tolerance: float,
150150
resource_usage_tolerance: float,
151151
print_progress: bool = True,
152+
is_last: bool = False,
152153
):
153154
if objective_tolerance == 0 and resource_usage_tolerance == 0:
154155
return pmappings
@@ -175,6 +176,9 @@ def prune(einsum_name: EinsumName, pg: PmappingGroup):
175176
result[einsum_name].append(pg)
176177

177178
new_n = sum(len(pg.mappings) for p in result.values() for pg in p)
179+
if new_n == prev_n and not is_last:
180+
return None
181+
178182
if print_progress:
179183
print(f"Dirty joining uses {new_n / prev_n * 100:.2f}% of the pmappings")
180184

@@ -215,7 +219,11 @@ def join_strategy_2(
215219
objective_tolerance=threshold,
216220
resource_usage_tolerance=resource_usage_tolerance,
217221
print_progress=print_progress,
222+
is_last=i == len(thresholds) - 1,
218223
)
224+
if cur_compressed is None:
225+
continue
226+
219227
joined = join_pmappings(
220228
cur_compressed,
221229
spec,
@@ -940,13 +948,13 @@ def no_match_lookahead_error(
940948
# f"\tCombining {sum(len(s) for s in left.values())}({len(left)}) x {sum(len(s) for s in right.values())}({len(right)}) -> {len(combined)}"
941949
# )
942950

943-
# nmappings = sum(len(s.mappings.data) for s in combined)
944-
# for_einsum_text = f"for Einsum {right_einsum}"
945-
# logger.info(f"\tNumber of groups {for_einsum_text}: {len(combined)}")
946-
# # for c in combined:
947-
# # print(f"\t\t{c.compatibility}")
948-
# logger.info(f"\tNumber of mappings {for_einsum_text}: {nmappings}")
949-
# logger.info(
951+
nmappings = sum(len(s.mappings.data) for s in combined)
952+
for_einsum_text = f"for Einsum {right_einsum}"
953+
# print(f"\tNumber of groups {for_einsum_text}: {len(combined)}")
954+
# for c in combined:
955+
# print(f"\t\t{c.compatibility}")
956+
# print(f"\tNumber of mappings {for_einsum_text}: {nmappings}")
957+
# print(
950958
# f"\tMappings per group {for_einsum_text}: {nmappings / len(combined)}"
951959
# )
952960
# logger.info(

0 commit comments

Comments
 (0)