From 0ac1da2d25e3a0b7dedeb5f2ec9df7f1ae6c983d Mon Sep 17 00:00:00 2001 From: Alessandro Pasqui Date: Wed, 13 May 2026 15:30:50 +0200 Subject: [PATCH] Add Sinkhorn implementation for vertex relabeling after T1s --- src/vertax/opt.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/vertax/opt.py b/src/vertax/opt.py index cecf1f4..d28fbd6 100644 --- a/src/vertax/opt.py +++ b/src/vertax/opt.py @@ -383,6 +383,54 @@ def inner_opt( return (vt_f, ht_f, ft_f), final_L_list +# ----------------------- # +# Sinkhorn implementation # +# ----------------------- # + +# def _periodic_sq_cost_matrix( +# vertTable_after: Array, +# vertTable_target: Array, +# width: float, +# height: float, +# ) -> Array: +# """Pairwise periodic squared distance, shape (n, n).""" +# sim = vertTable_after[:, :2] +# tgt = vertTable_target[:, :2] +# diff = sim[:, None, :] - tgt[None, :, :] # (n, n, 2) +# shifts = jnp.array( +# [ +# [0.0, 0.0], +# [-width, 0.0], [width, 0.0], +# [0.0, -height], [0.0, height], +# [-width, -height], [-width, height], +# [width, -height], [width, height], +# ] +# ) # (9, 2) +# shifted = diff[:, :, None, :] - shifts[None, None, :, :] # (n, n, 9, 2) +# sq = jnp.sum(shifted * shifted, axis=-1) # (n, n, 9) +# return jnp.min(sq, axis=-1) # (n, n) +# +# from ott.geometry import geometry +# from ott.problems.linear import linear_problem +# from ott.solvers.linear import sinkhorn +# def _build_t1_repair_perm_sinkhorn( +# vertTable_after: Array, +# vertTable_target: Array, +# width: float, +# height: float, +# epsilon: float = 1e-3, +# threshold: float = 1e-3, +# max_iterations: int = 1000, +# ) -> Array: +# """Vertex relabeling by entropic OT (ott-jax Sinkhorn).""" +# C = _periodic_sq_cost_matrix(vertTable_after, vertTable_target, width, height) +# geom = geometry.Geometry(cost_matrix=C, epsilon=epsilon) +# prob = linear_problem.LinearProblem(geom) +# solver = sinkhorn.Sinkhorn(threshold=threshold, max_iterations=max_iterations) +# out = solver(prob) +# return jnp.argmax(out.matrix, axis=1).astype(jnp.int32) + + def _periodic_sq_dist(p: Array, q: Array, width: float, height: float) -> Array: """Squared distance under PBC between two 2D points (cols 0,1).