WIP: Add normalize parameter to sliced_wasserstein_distance#808

Open
Harguna wants to merge 1 commit into
PythonOT:masterfrom
Harguna:feature/normalize-sliced-wasserstein

Conversation


Harguna commented Apr 29, 2026

Types of changes

  • New Feature

Motivation and context / Related issue

Addresses #807.

Sliced Wasserstein Distance is sensitive to feature scale: features with larger numerical ranges dominate the random projections, drowning out meaningful differences in smaller-scale features. Users often don't realize this is happening and, when they do, the manual fix (preprocessing inputs with a scaler) is verbose and easy to get wrong — fitting each distribution independently silently corrupts the distance.
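A quick toy illustration of this failure mode (plain NumPy, not POT's implementation; the Monte Carlo sliced distance below is a minimal stand-in for `sliced_wasserstein_distance`). Feature 0 has a huge numerical range and is identically distributed in both clouds; feature 1 is small-scale but genuinely shifted. The raw sliced distance is dominated by sampling noise in feature 0, while a jointly standardized version cleanly exposes the shift:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500

# Two clouds that differ ONLY in the small-scale feature 1 (mean 0 vs mean 5).
X_s = np.column_stack([rng.normal(0, 1000, n), rng.normal(0.0, 1.0, n)])
X_t = np.column_stack([rng.normal(0, 1000, n), rng.normal(5.0, 1.0, n)])

def sliced_w2(X, Y, n_proj=200, seed=1):
    """Monte Carlo sliced 2-Wasserstein distance via sorted 1D projections."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(n_proj):
        theta = rng.normal(size=X.shape[1])
        theta /= np.linalg.norm(theta)
        # 1D W2^2 between equal-size empirical measures = mean squared
        # difference of the sorted projected samples
        px, py = np.sort(X @ theta), np.sort(Y @ theta)
        total += np.mean((px - py) ** 2)
    return np.sqrt(total / n_proj)

raw = sliced_w2(X_s, X_t)

# Joint standardization: fit statistics on the concatenation, apply to both.
Z = np.vstack([X_s, X_t])
mu, sd = Z.mean(axis=0), Z.std(axis=0)
norm = sliced_w2((X_s - mu) / sd, (X_t - mu) / sd)

print(raw, norm)  # raw is orders of magnitude larger, driven by feature 0 noise
```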

This PR adds optional normalize and normalize_mode parameters to sliced_wasserstein_distance and max_sliced_wasserstein_distance to handle this cleanly inside the function. Default behavior (normalize=None) is unchanged, so the change is fully backward-compatible.

This is a [WIP] skeleton PR: it establishes the API surface, signatures, docstrings, and a helper function so the design can be reviewed before the full implementation lands. The actual normalization math, edge-case handling, behavioral tests, and example script will follow in subsequent commits on this same branch.

How has this been tested (if it applies)

In this skeleton:

  • The existing test/test_sliced.py suite continues to pass (verifying that the new keyword parameters don't break anything).
  • pre-commit run --all-files passes locally.

Tests related to the new feature will be added with the full implementation in the subsequent commits.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary
Collaborator

Thanks for this PR, we are a bit busy at the moment and will have more time to give some feedback after the NeurIPS deadline in two weeks.

@Harguna
Author

Harguna commented May 3, 2026

Sounds good, best of luck with the NeurIPS deadline.

@rflamary
Collaborator

rflamary commented May 12, 2026

Hello @Harguna, thanks for the PR.

We had a look with @clbonet and we are not really comfortable with having a normalization inside the sliced Wasserstein function. While this might make sense in some applications, it also means that, for instance when optimizing the SWD, the loss between two optimization steps or minibatches is not comparable (since it is normalized locally). This poses a practical problem because it is an unintuitive behavior and leads to different minimizers.

But we agree with you that normalization should be easier to handle, so we propose to handle it in a slightly different way, as follows:

scaler = ot.utils.DataScaler(norm='standard').fit([X_s,X_t]) # can take a tensor or a list for joint normalization
swd = ot.sliced_wasserstein_distance(X_s, X_t, scaler=scaler)

This means that the normalization is fitted outside the distance function, on a class that is sklearn-compatible (with fit and transform methods) but handles POT's backends. The scaler parameter should also accept a plain function (detected via __call__) and apply it, which would allow PyTorch preprocessing pipelines or models. I think we need a helper function

def apply_scaler(X_s, X_t, scaler=None)

that handles the preprocessing of the data (or does nothing if scaler=None), so that we can add this API to other functions in POT such as ot.solve_sample.
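A minimal sketch of what such a helper could look like, under the duck-typing rules described above (DataScaler does not exist yet; the class below is a hypothetical NumPy stand-in used only to exercise the helper):

```python
import numpy as np

def apply_scaler(X_s, X_t, scaler=None):
    """Preprocess both inputs with `scaler`, or pass them through unchanged.

    `scaler` may be a fitted object exposing transform() (sklearn-style),
    or any plain callable (detected via __call__), e.g. a torch module.
    """
    if scaler is None:
        return X_s, X_t
    if hasattr(scaler, "transform"):
        return scaler.transform(X_s), scaler.transform(X_t)
    if callable(scaler):
        return scaler(X_s), scaler(X_t)
    raise TypeError("scaler must be None, expose .transform(), or be callable")

class StandardScalerStandIn:
    """Hypothetical joint-fit standard scaler (stand-in for the proposed
    ot.utils.DataScaler; backend handling omitted)."""
    def fit(self, Xs):
        Z = np.vstack(Xs)                    # list of arrays -> joint statistics
        self.mu_, sd = Z.mean(axis=0), Z.std(axis=0)
        self.sd_ = np.where(sd > 0, sd, 1.0) # guard constant features
        return self
    def transform(self, X):
        return (X - self.mu_) / self.sd_
```

Because fitting happens once, outside the distance call, every minibatch is transformed with the same statistics and the loss stays comparable across optimization steps.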

Would you be OK with implementing our suggestions?

@Harguna
Author

Harguna commented May 13, 2026

Hello @rflamary,

Thanks for the detailed feedback; this makes sense. I had accounted for the relative displacement between X_s and X_t within a single batch by fitting the normalization statistics jointly on concat(X_s, X_t), but you're right that this doesn't address inter-batch variation, which would destabilize the objective during gradient-based training.

I agree with your suggested design which decouples the fitting step from the distance computation. I'm happy to implement DataScaler with backend compatibility and apply_scaler as a standalone helper so it can be reused across other POT functions. I will get started on that.

