
Adding statistical metric #1959

Draft
wants to merge 21 commits into base: main
Conversation

dralston78
Contributor

Checklist

  • My pull request has a clear and explanatory title.
  • If necessary, my code is vectorized.
  • I added appropriate unit tests.
  • I made sure the code passes all unit tests. (refer to comment below)
  • My PR follows PEP8 guidelines. (refer to comment below)
  • My PR follows geomstats coding style and API.
  • My code is properly documented and I made sure the documentation renders properly. (Link)
  • I linked to issues and PRs that are relevant to this PR.

Description

I added a first attempt at a StatisticalMetric class. Given a divergence, a metric tensor, a connection, a dual connection, and an Amari cubic tensor can be induced, turning the space into a statistical manifold: a special type of Riemannian manifold with extra structure.
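For intuition (not the actual geomstats API; the helper name here is hypothetical), the divergence-induced metric is the negative mixed second derivative of the divergence on the diagonal, g_ij(p) = -∂_{x_i} ∂_{y_j} D(x, y) at x = y = p. A standalone numpy sketch, with central finite differences standing in for the autodiff the PR uses:

```python
import numpy as np

def induced_metric_matrix(divergence, base_point, eps=1e-5):
    """Hypothetical helper: g_ij(p) = -d/dx_i d/dy_j D(x, y) at x = y = p,
    approximated with central finite differences (the PR uses autodiff)."""
    p = np.asarray(base_point, dtype=float)
    dim = p.size
    metric = np.empty((dim, dim))
    for i in range(dim):
        for j in range(dim):
            ei = np.zeros(dim); ei[i] = eps
            ej = np.zeros(dim); ej[j] = eps
            # central mixed second difference in (x_i, y_j)
            metric[i, j] = -(
                divergence(p + ei, p + ej)
                - divergence(p + ei, p - ej)
                - divergence(p - ei, p + ej)
                + divergence(p - ei, p - ej)
            ) / (4 * eps**2)
    return metric

# Sanity check: the squared Euclidean divergence induces the identity metric.
sq_dist = lambda x, y: 0.5 * np.sum((x - y) ** 2)
g = induced_metric_matrix(sq_dist, [0.3, -1.2])
```

Recovering the identity metric from the squared Euclidean divergence is a handy unit-test target for the class.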

Additional context

This is just a first stab at how to code this structure within the geomstats library, to share with collaborators. Because of this, the PR is marked as a draft. The documentation needs to be reformatted.

Also, for some reason the StatisticalMetric class only works with the autograd backend; I had a lot of trouble trying to get it to work with the pytorch backend.

Looking forward to discussing here/on Slack!


codecov bot commented Feb 17, 2024

Codecov Report

Attention: Patch coverage is 31.81818%, with 30 lines in your changes missing coverage. Please review.

Project coverage is 91.12%. Comparing base (ef340b9) to head (3d35ab2).
Report is 211 commits behind head on main.

❗ Current head 3d35ab2 differs from pull request most recent head d9e22a3. Consider uploading reports for the commit d9e22a3 to get more accurate results

Files                                                   Patch %   Lines
...omstats/information_geometry/statistical_metric.py   31.82%    30 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1959      +/-   ##
==========================================
- Coverage   91.51%   91.12%   -0.38%     
==========================================
  Files         149      147       -2     
  Lines       13624    13733     +109     
==========================================
+ Hits        12466    12513      +47     
- Misses       1158     1220      +62     
Flag       Coverage Δ
autograd   ?
numpy      89.11% <31.82%> (-1.09%) ⬇️
pytorch    85.59% <31.82%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.


Collaborator

@ninamiolane ninamiolane left a comment


Very nice! I made some cosmetic comments on the statistical_metric.py, but I think the main ideas are perfectly in place.

@luisfpereira can you have a look at the tests' code?


class StatisticalMetric(RiemannianMetric):
"""
Defines statistical metric and connection induced from a divergence.
Collaborator


Respect docstring conventions laid out here: https://geomstats.github.io/contributing/index.html#writing-docstrings

E.g. imperative form for verbs --> Define, not Defines
Put verb directly after """

Defines statistical metric and connection induced from a divergence.

Uses definitions provided in Nielsen's An Elementary Introduction to
Information Geometry, Theorem 4 on page 15 (https://arxiv.org/abs/1808.08271)
Collaborator


docstring missing parameter description (see other classes' docstring in geomstats)

self.dim = dim
self.divergence = self._unpack_tensor(divergence)

def _unpack_tensor(self, func):
Collaborator


please write docstring even for private functions


return wrapper

def metric_matrix(self, base_point):
Collaborator


docstring missing parameter descriptions

Collaborator


Give reference of the paper you mention in each docstring, using

References

(see in other places in geomstats for how References is used)

Collaborator


Also, instead of giving the page where we should look for the equation, just write the equation in the docstring, and start the docstring by r"""

--> on the website, the docstring will show the equation in latex!

Eg. https://geomstats.github.io/api/geomstats.information_geometry.html#geomstats.information_geometry.fisher_rao_metric.FisherRaoMetric

Collaborator


Can the docstring also explain the difference between this metric and the Fisher-Rao metric?

And maybe add a section in the docstring:

See Also

FisherRaoMetric

base_point_pair = gs.concatenate([base_point, base_point])
return -1 * hess(base_point_pair)[: self.dim, self.dim :]

def divergence_christoffels(self, base_point):
Collaborator


Ditto

Collaborator

@ninamiolane ninamiolane Feb 22, 2024


Are these the Christoffel symbols associated with the metric_matrix?

  • If yes, then they should just be called christoffels.
  • If not, then let's create class DivergenceConnection(Connection) in this file, with a christoffels method using the code that you have here, and have self.divergence_connection = DivergenceConnection(...) be an attribute of the StatisticalMetric class.

Doing it like this will give us direct access to the geodesic equation associated with these Christoffel symbols! See the code below:

https://github.com/geomstats/geomstats/blob/main/geomstats/geometry/connection.py
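To illustrate the payoff of this design, here is a minimal standalone mock (an assumption for illustration, not geomstats' actual Connection class): once christoffels is implemented, geodesics come for free from the generic geodesic equation. The placeholder subclass uses vanishing Christoffels, i.e. a flat connection, just to exercise the machinery:

```python
import numpy as np

class Connection:
    """Minimal stand-in for geomstats.geometry.connection.Connection
    (an assumption for illustration, not the real class)."""

    def christoffels(self, point):
        raise NotImplementedError

    def geodesic(self, point, velocity, n_steps=100, dt=0.01):
        # Euler-integrate the geodesic equation  x''^k = -Gamma^k_ij x'^i x'^j.
        x = np.asarray(point, dtype=float)
        v = np.asarray(velocity, dtype=float)
        path = [x.copy()]
        for _ in range(n_steps):
            gamma = self.christoffels(x)
            acc = -np.einsum("kij,i,j->k", gamma, v, v)
            x = x + dt * v
            v = v + dt * acc
            path.append(x.copy())
        return np.array(path)

class DivergenceConnection(Connection):
    """Placeholder with vanishing Christoffels (a flat connection),
    just to show the machinery a subclass inherits."""

    def christoffels(self, point):
        dim = len(point)
        return np.zeros((dim, dim, dim))

# Flat connection => geodesics are straight lines.
path = DivergenceConnection().geodesic([0.0, 0.0], [1.0, 2.0])
```

Any subclass that implements christoffels would inherit the same geodesic method, which is the point of the suggested refactor.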


I believe they are not; they are the Riemannian Christoffel symbols of the two conjugate connections. I agree with the suggested naming convention of class DivergenceConnection(Connection), to be changed from DivergenceConjugate(Connection).

Collaborator


So, I think here the main point is that any of these connections is also an (affine) connection when alone, so we definitely need to break this class into two (e.g. DivergenceConnection, DualDivergenceConnection). Then, by using christoffels as method name, we will have access to all the machinery already implemented in connection (e.g. geodesic_equation, exp, log, etc), as @ninamiolane suggested.

Then, we can always create appropriate objects that combine these (affine) connections (not by inheritance, but by composition). For example, the AlphaConnection can be created by:

class AlphaConnection(Connection):
    def __init__(
        self,
        space,
        primal_connection,
        dual_connection,
        alpha=0,
    ):
        super().__init__(space)
        self.primal_connection = primal_connection
        self.dual_connection = dual_connection
        self.alpha = alpha 

    def christoffels(self, point):
        # see eq. 53
        first_term = (1 + self.alpha) / 2 * self.primal_connection.christoffels(point)
        second_term = (1 - self.alpha) / 2 * self.dual_connection.christoffels(point)

        return first_term - second_term

Alternatively, we could create the AlphaConnection by using the concepts of mean connection (which is itself a connection resulting from a combination of connections) and Amari-Chentsov tensor (eq. 52).

We probably still need to iterate on this, but I believe this AlphaConnection or some kind of ConjugateConnectionPair is the central object here, as e.g. the Amari-Chentsov tensor or other quantities required to define the statistical manifold can be computed from this pair.

For the divergence case we can always create a DivergenceAlphaConnection/DivergenceInducedConnection later (for usability) that creates the necessary connections given a divergence.

The main message is that there's a lot of ways of defining all of this and we probably need to keep the idea of composing objects in the back of our mind all the time.

Contributor Author

@dralston78 dralston78 Mar 18, 2024


The reason why I thought it might be important for the divergence induced connection and the dual divergence induced connection to be in the same class is that they both need to be induced by the same divergence, in order for them to have the property of being conjugate connections.

But I also see what you mean -- with each of these connections the user should be able to create geodesics and use the other functionality of connections, so I agree it makes sense to split them up. Is there some way of ensuring that when one is instantiated, the other is automatically instantiated with the same divergence, or that they are linked in some way? Maybe have the AlphaConnection take in a divergence, and then instantiate DivergenceConnection and DualDivergenceConnection in the __init__ method of the AlphaConnection class?
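The linking pattern proposed in the question above can be sketched as follows (all class names are placeholders mirroring the suggestions in this thread, not the geomstats implementation):

```python
class DivergenceConnection:
    """Placeholder: connection induced by a divergence."""
    def __init__(self, divergence):
        self.divergence = divergence


class DualDivergenceConnection:
    """Placeholder: dual connection induced by the same divergence."""
    def __init__(self, divergence):
        self.divergence = divergence


class AlphaConnection:
    """Build both conjugate connections from one divergence in __init__,
    so they are guaranteed to be induced by the same divergence."""

    def __init__(self, divergence, alpha=0.0):
        self.alpha = alpha
        self.primal_connection = DivergenceConnection(divergence)
        self.dual_connection = DualDivergenceConnection(divergence)


kl_like = lambda x, y: 0.0  # any divergence callable
pair = AlphaConnection(kl_like, alpha=1.0)
```

Because both attributes hold the very same divergence object, the conjugacy property cannot be broken by instantiating them with mismatched divergences.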

base_point_pair = gs.concatenate([base_point, base_point])
return -1 * jac_hess(base_point_pair)[:2, :2, 2:]

def dual_divergence_christoffels(self, base_point):
Collaborator


Ditto


Agreed, creating a class DualDivergenceConnection(Connection) for this will make it consistent.


@psuarezserrato psuarezserrato left a comment


Great job @dralston78, it's coming together nicely. Needs some help from @luisfpereira to check that the tests pass.



Delegate attribute access to the divergence conjugate connection.

Instanciated to avoid dimond inheritance problem with RiemannianMetric


Typo: Instantiated.

tensor : array-like, shape=[..., dim, dim, dim]
Amari divergence tensor.
"""
divergence_christoffels = self.divergence_christoffels(base_point)


update to new naming conventions, using classes

Collaborator

@luisfpereira luisfpereira left a comment


Please check my comment regarding AlphaConnection (we can talk about tests after the design is closed). @ninamiolane, @psuarezserrato, please also take a look at this and let me know what you think.

self.dim = space.dim
self.divergence = self._unpack_inputs(divergence)

def _unpack_inputs(self, func):
Collaborator


create it as a function (which also receives dim) in order to reuse it in StatisticalMetric.

Contributor Author


Just to clarify, this should just be a function outside of any class that is used throughout the script?

Collaborator

@luisfpereira luisfpereira left a comment


Let's do a new iteration on this @dralston78.

Collaborator

@luisfpereira luisfpereira left a comment


After this iteration, let's tackle this more involved one @dralston78. (Please let me know if you disagree with something.) Names are not important for now and the ones I'm suggesting need to be improved.

The goal is to tackle the case of the statistical manifold induced by a Bregman generator (I'll be focusing on the potential function, and not on the divergence, though later we can look at things from the divergence perspective - so no need to delete existing objects relying on the divergence).

1 Potential function

I suggest creating a class called PotentialFunction with methods __call__, hessian (eq 63), and third_derivatives (better naming?) (eq 65).

Reasoning:

  • allow users to hard-code derivatives (e.g. easy to do for gs.sum(point**4))
  • avoid making direct calls to autodiff when coding within connections/metrics (derivatives handled in one place)
  • easier to test and ensure correctness

By default, we can implement those methods using autodiff. At init we pass the potential_function callable so that we don't have to inherit every time we want to create a new potential function for which we don't want to hard-code derivatives.
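A standalone sketch of how such a class could look (the names and the finite-difference fallback are assumptions; geomstats would use gs.autodiff for the default implementation):

```python
import numpy as np

class PotentialFunction:
    """Hypothetical sketch of the class suggested above (not the
    geomstats implementation): wrap a potential F and expose its
    derivatives, with a numeric default (finite differences stand in
    for gs.autodiff here) and optional hard-coded overrides."""

    def __init__(self, potential, hessian=None):
        self._potential = potential
        self._hessian = hessian  # optional hard-coded Hessian

    def __call__(self, point):
        return self._potential(point)

    def hessian(self, point, eps=1e-4):
        """Return the Hessian evaluated at point (a matrix, not a function)."""
        if self._hessian is not None:
            return self._hessian(point)
        p = np.asarray(point, dtype=float)
        dim = p.size
        hess = np.empty((dim, dim))
        for i in range(dim):
            for j in range(dim):
                ei = np.zeros(dim); ei[i] = eps
                ej = np.zeros(dim); ej[j] = eps
                # central mixed second difference
                hess[i, j] = (
                    self._potential(p + ei + ej)
                    - self._potential(p + ei - ej)
                    - self._potential(p - ei + ej)
                    + self._potential(p - ei - ej)
                ) / (4 * eps**2)
        return hess

# F(x) = sum(x^4), the gs.sum(point**4) example above: the Hessian is
# diag(12 x_i^2), which is easy to hard-code.
quartic = PotentialFunction(lambda x: np.sum(np.asarray(x) ** 4))
```

A hard-coded variant would pass hessian=lambda x: np.diag(12 * np.asarray(x) ** 2) and skip the numerical differentiation entirely; comparing the two is an easy correctness test, which is exactly the "easier to test" point above.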

2 Dummy flat-connection

Create a dummy flat connection with null (first-kind?) Christoffels (eq 64). (Maybe we don't need to explicitly create this object; it lives here as a reminder that the potential function-connection is flat.)

3 Levi-Civita connection coming from potential function-metric tensor

Similarly to the already-existing StatisticalMetric, create StatisticalConnectionFromPotentialFunction inheriting from RiemannianMetric, calling the potential function's hessian to define metric_matrix, and implementing the totally symmetric cubic tensor $C$ given by the third derivatives of the potential function (eq 65) (as in the metric tensor case, simply call the method from the potential function class). Should we call it cubic_tensor or amari_chentsov_tensor? If mathematically correct, I would call it amari_chentsov_tensor; also change it in StatisticalMetric.

Notice that after this implementation, we have (more than) a fully-fledged Levi-Civita connection with geodesics. We also have flat geodesics from the "dummy" connection (2).

(I guess now is a good time to create first_kind_christoffels in RiemannianMetric)

4 Alpha-connection from statistical metric

Create an alpha-connection class that takes a statistical metric (3) and computes the alpha-connection Christoffels using equation (50). Notice that though this is generic, it can be later used to get the conjugate connection of the potential function-connection (thinking about this result $\left({ }^F \nabla\right)^*={ }^F \nabla^1$). Since it is flat, we can use it to validate our implementation.
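If the equation meant here matches the standard first-kind relation Gamma^(alpha)_{ij,k} = Gamma^(0)_{ij,k} - (alpha/2) T_{ijk} from Amari-Nagaoka-style treatments (an assumption; I have not checked the paper's eq. (50) numbering), the computational core is one line, and the alpha = +1/-1 pair averaging back to the Levi-Civita connection gives a built-in validation, which also ties into the mean-connection idea discussed earlier:

```python
import numpy as np

def alpha_christoffels_first_kind(lc_christoffels, amari_chentsov, alpha):
    """Gamma^(alpha)_{ij,k} = Gamma^(0)_{ij,k} - (alpha / 2) * T_{ijk}.

    Standard first-kind relation (assumed to be the referenced equation);
    inputs are arrays of shape (dim, dim, dim).
    """
    return lc_christoffels - 0.5 * alpha * amari_chentsov

# Stand-in tensors just to exercise the identities.
rng = np.random.default_rng(0)
gamma0 = rng.normal(size=(2, 2, 2))  # Levi-Civita, first kind
t = rng.normal(size=(2, 2, 2))       # stand-in cubic tensor

gamma_plus = alpha_christoffels_first_kind(gamma0, t, 1.0)
gamma_minus = alpha_christoffels_first_kind(gamma0, t, -1.0)
# The +1/-1 pair averages back to the Levi-Civita connection, and their
# difference recovers the cubic tensor.
```

These two identities (mean connection and tensor recovery) are cheap invariants to assert in the tests regardless of which class ends up owning the computation.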

Remarks

All these objects will 1) provide users with different ways of defining similar objects, so that they can choose the best for their use case, 2) give us different independent ways of defining things, which is fantastic for testing.

Don't worry much about interfaces: for the inits of these new objects, assume users always pass the same potential function. We will iterate on this next time.

P.S. I may be missing something, please be critical with the suggestions.

@dralston78
Copy link
Contributor Author

Hi Luis, as I start working on this next iteration, here are some comments/questions I have on your suggestions.

Overall comments:

  • When I start to make these changes, should they be in a separate pull request and a separate branch? Do we want to merge the current PR and then start a fresh one for the next iteration?

  • In this next iteration we will only consider Bregman-divergence-generated metrics/connections. Does this mean that we will return to a more general structure where we consider the geometric structure induced by any divergence? Or will we only consider the Bregman divergence for the rest of the project? I think that the structure from a general divergence is much more powerful: for example, I have a test showing that we can recover the Fisher information metric from the KL divergence. Also, I was just reading about something called a log-divergence; the structure induced by this divergence can have geometric consequences in portfolio theory.
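The KL-recovers-Fisher check mentioned above can be sketched in one dimension (standalone numpy with finite differences in place of autodiff; the helper name is hypothetical). For Gaussians with fixed variance sigma^2, KL(N(x, sigma^2) || N(y, sigma^2)) = (x - y)^2 / (2 sigma^2) in the mean parameter, and the Fisher information is 1/sigma^2:

```python
import numpy as np

def induced_metric_1d(divergence, p, eps=1e-5):
    """g(p) = -d^2/(dx dy) D(x, y) at x = y = p, via central differences
    (hypothetical helper; the PR's test would use autodiff)."""
    return -(
        divergence(p + eps, p + eps)
        - divergence(p + eps, p - eps)
        - divergence(p - eps, p + eps)
        + divergence(p - eps, p - eps)
    ) / (4 * eps**2)

# KL between N(x, sigma^2) and N(y, sigma^2), parametrized by the mean.
sigma = 2.0
kl = lambda x, y: (x - y) ** 2 / (2 * sigma**2)

fisher = induced_metric_1d(kl, 0.7)  # expect 1 / sigma^2
```

The same recipe extends to multi-parameter families, where it recovers the full Fisher information matrix.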

1 Potential Function

  • Should this class inherit from any base class?

3 LC connection from potential function

  • Should we also include a dual_connection method in this class? Note, this dual_connection might not be flat (with the same coordinate system on the manifold)!

4 Alpha Connection from statistical metric

  • Very minor, but it seems that really this will be not only a connection structure, but more broadly a metric structure. What if we call it AlphaMetric?

@luisfpereira
Copy link
Collaborator

@dralston78,

When I start to make these changes, should they be in a separate pull request and a separate branch?

In this next iteration we will only consider Bregman-divergence generated metrics/connections. Does this mean that we will return to a more general structure where we consider geometric structure from any divergence? Or for the rest of the project will we only consider Bregman-divergence?

Before merging this one, what I would like to have is the (general) geometric structure induced by any divergence AND a particular implementation for the Bregman-divergence that takes advantage of the fact it results from a potential function.

The idea is 2-fold. On one hand, we show we can make computations with general divergences, and, if they have particular structures we can take advantage of, specialized implementations. On the other hand, we can use the specialized implementations to test the general ones. Specialized implementations are also expected to be faster.
In this particular case, I would like to test the statistical manifold implemented using the potential function against the statistical manifold implemented using the divergence.

Should this class [Potential function] inherit from any base class?

At some point our goal is to create an abstract class for it to inherit from. We can start with just a particular class to check whether the idea is worth pursuing. (If this works nicely, I think I'll ask you for a similar object for the divergence.)

Should we also include a dual_connection method in this class [LC connection from potential function]?

Probably. Though we should note that the dual connection in this case comes from the alpha-connection, so we need to be careful when instantiating the objects. (But for compatibility with the other StatisticalManifold, we should probably have it; we can add it later though.)

Very minor, but it seems that really this will be not only a connection structure, but more broadly a metric structure. What if we call it AlphaMetric?

It is a good point. Though I would not say it has a metric structure; it is more that it requires the space to be endowed with a metric structure (i.e. it makes use of an existing structure). It is very subtle, but I will soon ask you for a change in the computation of the Christoffels that is justified by this. Do you agree with the distinction?

Collaborator

@luisfpereira luisfpereira left a comment


Quick review of the PotentialFunction.


class PotentialFunction:
def __init__(self, potential_function):
self.potential_function = potential_function
Collaborator


I would make this private, as this is just a way of creating this object without having to inherit. An alternative would be to implement PotentialFunction directly as a class and override the __call__ method.

self.hessian_function = self.hessian()
self.third_derivative_function = self.third_derivative()
Collaborator


Delete? Having the methods is enough.

self.hessian_function = self.hessian()
self.third_derivative_function = self.third_derivative()

def __call__(self, x):
Collaborator


will need to find a better name for x.

def __call__(self, x):
return self.potential_function(x)

def hessian(self):
Collaborator


Receive x and return the result, not a function. (Imagine we make a non-autodiff-based implementation of the hessian.)

def hessian(self):
return gs.autodiff.hessian(self.potential_function)

def third_derivative(self):
Collaborator


Same as for hessian: don't mind repeating gs.autodiff.hessian(self.potential_function) again.

dralston78 and others added 2 commits April 23, 2024 09:43
…l__ method to be defined. Derivative methods are now functions of points, rather than operators.
4 participants