Adding statistical metric #1959
base: main
Conversation
Codecov Report
Attention: Patch coverage is

@@            Coverage Diff             @@
##             main    #1959      +/-   ##
==========================================
- Coverage   91.51%   91.12%   -0.38%
==========================================
  Files         149      147       -2
  Lines       13624    13733     +109
==========================================
+ Hits        12466    12513      +47
- Misses       1158     1220      +62

Flags with carried forward coverage won't be shown.
☔ View full report in Codecov by Sentry.
Very nice! I made some cosmetic comments on statistical_metric.py, but I think the main ideas are perfectly in place.
@luisfpereira can you have a look at the tests' code?
class StatisticalMetric(RiemannianMetric):
    """
    Defines statistical metric and connection induced from a divergence.
Respect the docstring conventions laid out here: https://geomstats.github.io/contributing/index.html#writing-docstrings
E.g.: imperative form for verbs ("Define", not "Defines"), and put the verb directly after the opening """.
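A minimal sketch of the requested convention (the base class is stubbed here only to make the snippet self-contained):

```python
class RiemannianMetric:
    """Stub standing in for geomstats' RiemannianMetric base class."""


class StatisticalMetric(RiemannianMetric):
    """Define statistical metric and connection induced from a divergence."""
```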
    Defines statistical metric and connection induced from a divergence.

    Uses definitions provided in Nielsen's An Elementary Introduction to
    Information Geometry, Theorem 4 on page 15 (https://arxiv.org/abs/1808.08271)
Docstring missing parameter descriptions (see other classes' docstrings in geomstats).
        self.dim = dim
        self.divergence = self._unpack_tensor(divergence)

    def _unpack_tensor(self, func):
Please write docstrings even for private functions.
        return wrapper

    def metric_matrix(self, base_point):
Docstring missing parameter descriptions.
Give the reference for the paper you mention in each docstring, using a References section (see elsewhere in geomstats for how References is used).
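For example, a numpydoc-style References section could look like this (the entry wording is my own, but the paper is the one already cited in the PR's docstring):

```python
def metric_matrix(self, base_point):
    """Compute the divergence-induced metric matrix at base_point.

    References
    ----------
    .. [N2018] Nielsen, F. "An Elementary Introduction to Information
        Geometry." arXiv:1808.08271. https://arxiv.org/abs/1808.08271
    """
```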
Also, instead of giving the page where we should look for the equation, just write the equation in the docstring, and start the docstring with r""" -- on the website, the docstring will then show the equation in LaTeX!
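A sketch of what that could look like (the equation is my transcription of the divergence-induced metric formula consistent with the code's `-hess(...)` slice, so double-check the signs against Theorem 4 of the paper):

```python
def metric_matrix(self, base_point):
    r"""Compute the metric matrix induced by the divergence.

    .. math::

        g_{ij}(\theta) = -\partial_{i} \partial'_{j}
            D(\theta, \theta') \big|_{\theta' = \theta}
    """
```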
Can the docstring also explain the difference between this metric and the Fisher-Rao metric?
And maybe add a section in the docstring:

See Also
--------
FisherRaoMetric
        base_point_pair = gs.concatenate([base_point, base_point])
        return -1 * hess(base_point_pair)[: self.dim, self.dim :]

    def divergence_christoffels(self, base_point):
Ditto
Are these the Christoffel symbols associated with the metric_matrix?
- If yes, then they should just be called christoffels.
- If not, then let's create class DivergenceConnection(Connection) in this file, with a christoffels method using the code that you have here, and have self.divergence_connection = DivergenceConnection(...) be an attribute of the StatisticalMetric class.
Doing it like this will allow us to have direct access to the geodesic equation associated with these Christoffel symbols! See the code here:
https://github.com/geomstats/geomstats/blob/main/geomstats/geometry/connection.py
I believe they are not; they are the Riemannian Christoffel symbols of the two conjugate connections. I agree with the suggested naming convention of class DivergenceConnection(Connection), to be changed from DivergenceConjugate(Connection).
So, I think the main point here is that any of these connections is also an (affine) connection on its own, so we definitely need to break this class into two (e.g. DivergenceConnection, DualDivergenceConnection). Then, by using christoffels as the method name, we will have access to all the machinery already implemented in Connection (e.g. geodesic_equation, exp, log, etc.), as @ninamiolane suggested.
Then, we can always create appropriate objects that combine these (affine) connections (not by inheritance, but by composition). For example, the AlphaConnection can be created by:
class AlphaConnection(Connection):
    def __init__(
        self,
        space,
        primal_connection,
        dual_connection,
        alpha=0,
    ):
        super().__init__(space)
        self.primal_connection = primal_connection
        self.dual_connection = dual_connection
        self.alpha = alpha

    def christoffels(self, point):
        # see eq. 53
        first_term = (1 + self.alpha) / 2 * self.primal_connection.christoffels(point)
        second_term = (1 - self.alpha) / 2 * self.dual_connection.christoffels(point)
        return first_term - second_term
Alternatively, we could create the AlphaConnection by using the concepts of the mean connection (which is itself a connection resulting from a combination of connections) and the Amari-Chentsov tensor (eq. 52).
We probably still need to iterate on this, but I believe this AlphaConnection, or some kind of ConjugateConnectionPair, is the central object here, as e.g. the Amari-Chentsov tensor and the other quantities required to define the statistical manifold can be computed from this pair.
For the divergence case we can always create a DivergenceAlphaConnection/DivergenceInducedConnection later (for usability) that creates the necessary connections given a divergence.
The main message is that there are a lot of ways of defining all of this, and we probably need to keep the idea of composing objects in the back of our minds all the time.
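To make the proposed split concrete, here is a toy, self-contained sketch (names as suggested in the discussion; Connection is stubbed, and the first-kind Christoffels are approximated with finite differences rather than autodiff, so treat it as an illustration, not the geomstats implementation). For a Bregman divergence, the primal connection's first-kind Christoffels should vanish (the primal connection is flat in these coordinates), which gives a quick sanity check:

```python
import numpy as np


class Connection:
    """Minimal stub for geomstats.geometry.connection.Connection."""

    def __init__(self, space=None):
        self.space = space


def _mixed_third_derivative(divergence, point, i, j, k, dual, h=1e-3):
    """Central finite difference of d_i d_j d'_k D(x, x') at x' = x.

    If dual is True, compute d'_i d'_j d_k D instead (primes act on the
    second argument of the divergence).
    """
    total = 0.0
    for s1 in (1.0, -1.0):
        for s2 in (1.0, -1.0):
            for s3 in (1.0, -1.0):
                first, second = point.copy(), point.copy()
                if dual:
                    second[i] += s1 * h
                    second[j] += s2 * h
                    first[k] += s3 * h
                else:
                    first[i] += s1 * h
                    first[j] += s2 * h
                    second[k] += s3 * h
                total += s1 * s2 * s3 * divergence(first, second)
    return total / (8.0 * h**3)


class DivergenceConnection(Connection):
    """Primal connection induced by a divergence."""

    dual = False

    def __init__(self, space, divergence):
        super().__init__(space)
        self.divergence = divergence

    def christoffels(self, base_point):
        """First-kind Christoffels: Gamma_{ij,k} = -d_i d_j d'_k D."""
        dim = base_point.shape[-1]
        gamma = np.zeros((dim, dim, dim))
        for i in range(dim):
            for j in range(dim):
                for k in range(dim):
                    gamma[i, j, k] = -_mixed_third_derivative(
                        self.divergence, base_point.astype(float), i, j, k, self.dual
                    )
        return gamma


class DualDivergenceConnection(DivergenceConnection):
    """Conjugate connection: Gamma*_{ij,k} = -d'_i d'_j d_k D."""

    dual = True


# Sanity check with a Bregman divergence built from F(x) = sum(x**4):
def potential(x):
    return np.sum(x**4)


def grad_potential(x):
    return 4.0 * x**3


def bregman(first, second):
    return potential(first) - potential(second) - grad_potential(second) @ (first - second)


point = np.array([0.5, 1.5])
primal = DivergenceConnection(space=None, divergence=bregman)
dual_conn = DualDivergenceConnection(space=None, divergence=bregman)
# The primal Bregman connection is flat (null first-kind Christoffels),
# while the dual one recovers the third derivatives of the potential.
```

With both connections in hand, the AlphaConnection above can combine them by composition exactly as written.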
The reason why I thought it might be important for the divergence-induced connection and the dual divergence-induced connection to be in the same class is that they both need to be induced by the same divergence in order to have the property of being conjugate connections.
But I also see what you mean: with each of these connections the user should be able to create geodesics and use the other functionality of connections, so I agree it makes sense to split them up. Is there some way of ensuring that when one is instantiated, the other is automatically instantiated with the same divergence, or that they are linked in some way? Maybe have the AlphaConnection take in a divergence, and then instantiate DivergenceConnection and DualDivergenceConnection in the __init__ method of the AlphaConnection class?
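One possible shape for that idea, as a hypothetical sketch (stubbed classes, not the geomstats API): AlphaConnection owns the instantiation, so both connections are guaranteed to share the same divergence.

```python
class Connection:
    """Minimal stub for geomstats' Connection base class."""

    def __init__(self, space=None):
        self.space = space


class DivergenceConnection(Connection):
    def __init__(self, space, divergence):
        super().__init__(space)
        self.divergence = divergence


class DualDivergenceConnection(Connection):
    def __init__(self, space, divergence):
        super().__init__(space)
        self.divergence = divergence


class AlphaConnection(Connection):
    """Build both conjugate connections from a single divergence.

    Because both connections are created here from the same callable,
    they are conjugate with respect to the same divergence by construction.
    """

    def __init__(self, space, divergence, alpha=0.0):
        super().__init__(space)
        self.alpha = alpha
        self.primal_connection = DivergenceConnection(space, divergence)
        self.dual_connection = DualDivergenceConnection(space, divergence)


pair = AlphaConnection(space=None, divergence=lambda first, second: 0.0)
```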
        base_point_pair = gs.concatenate([base_point, base_point])
        return -1 * jac_hess(base_point_pair)[:2, :2, 2:]

    def dual_divergence_christoffels(self, base_point):
Ditto
Agreed, creating a class DualDivergenceConnection(Connection) for this will make it consistent then.
tests/tests_geomstats/test_information_geometry/test_statistical_metric.py
…all methods in StatisticalMetric and DivergenceConjugateConnection classes
Great job @dralston78, it's coming together nicely. Needs some help from @luisfpereira to check that the tests pass.
    Delegate attribute access to the divergence conjugate connection.

    Instanciated to avoid dimond inheritance problem with RiemannianMetric
Typo: "Instanciated" should be "Instantiated".
        tensor : array-like, shape=[..., dim, dim, dim]
            Amari divergence tensor.
        """
        divergence_christoffels = self.divergence_christoffels(base_point)
Update to the new naming conventions, using the classes.
tests/tests_geomstats/test_information_geometry/test_statistical_metric.py
Please check my comment regarding AlphaConnection (we can speak about tests after the design is closed). @ninamiolane, @psuarezserrato, please also take a look at this and let me know what you think.
        self.dim = space.dim
        self.divergence = self._unpack_inputs(divergence)

    def _unpack_inputs(self, func):
Create it as a function (which also receives dim) in order to reuse it in StatisticalMetric.
Just to clarify, this should just be a function outside of any class that is used throughout the script?
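For illustration, such a module-level helper might look like this (a hypothetical sketch: I'm assuming the wrapper stacks both arguments into one array of shape (2 * dim,), consistent with the gs.concatenate([base_point, base_point]) calls quoted above):

```python
import numpy as np


def _unpack_inputs(func, dim):
    """Wrap a two-argument divergence so it takes one stacked array.

    The wrapper accepts a single array of shape (2 * dim,) holding both
    points, which makes it easy to take cross-derivatives with autodiff.
    As a module-level function, both StatisticalMetric and the connection
    classes can reuse it.
    """

    def wrapper(stacked_point):
        point_a = stacked_point[..., :dim]
        point_b = stacked_point[..., dim:]
        return func(point_a, point_b)

    return wrapper


# usage: wrap a toy squared-distance divergence on a 2-dimensional space
squared_distance = _unpack_inputs(lambda a, b: np.sum((a - b) ** 2), dim=2)
```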
…herit from Connection. Also corrected confusion about first kind vs second kind christoffel symbols.
Let's do a new iteration on this @dralston78.
After this iteration, let's tackle this more involved one @dralston78. (Please let me know if you disagree with something.) Names are not important for now, and the ones I'm suggesting need to be improved.
The goal is to tackle the case of the statistical manifold induced by a Bregman generator (I'll be focusing on the potential function, not on the divergence, though later we can look at things from the divergence perspective, so no need to delete existing objects relying on the divergence).
1 Potential function
I suggest creating a class called PotentialFunction with methods __call__, hessian (eq. 63), and third_derivatives (better naming?) (eq. 65).
Reasoning:
- allow users to hard-code derivatives (e.g. easy to do for gs.sum(point**4))
- avoid making direct calls to autodiff when coding within connections/metrics (derivatives handled in one place)
- easier to test and ensure correctness
By default, we can implement those methods using autodiff. At init we pass the potential_function callable so that we don't have to inherit every time we want to create a new potential function for which we don't want to hard-code derivatives.
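A self-contained sketch of that design (toy code: the autodiff default is replaced by NotImplementedError to keep the example dependency-free, and the derivatives are hard-coded for the gs.sum(point**4) example mentioned above):

```python
import numpy as np


class PotentialFunction:
    """Potential function whose derivatives can be hard-coded or autodiffed.

    In geomstats the default hessian/third_derivative would fall back to
    gs.autodiff; here they are left abstract to keep the sketch autodiff-free.
    """

    def __init__(self, potential_function):
        self._potential_function = potential_function

    def __call__(self, base_point):
        return self._potential_function(base_point)

    def hessian(self, base_point):  # candidate metric tensor (eq. 63)
        raise NotImplementedError

    def third_derivative(self, base_point):  # candidate cubic tensor (eq. 65)
        raise NotImplementedError


class QuarticPotential(PotentialFunction):
    """F(x) = sum(x**4), with derivatives hard-coded instead of autodiffed."""

    def __init__(self):
        super().__init__(lambda point: np.sum(point**4))

    def hessian(self, base_point):
        # Hessian of sum(x**4) is diagonal: 12 * x_i**2
        return np.diag(12.0 * base_point**2)

    def third_derivative(self, base_point):
        # Third derivatives are nonzero only on the diagonal: 24 * x_i
        dim = base_point.shape[-1]
        third = np.zeros((dim, dim, dim))
        idx = np.arange(dim)
        third[idx, idx, idx] = 24.0 * base_point
        return third


potential = QuarticPotential()
```

Subclassing here stands in for the "hard-code derivatives" path; the plain PotentialFunction path would rely on the autodiff defaults.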
2 Dummy flat connection
Create a dummy flat connection with null (first-kind?) Christoffels (eq. 64). (Maybe we don't need to explicitly create this object; it lives here to recall that the potential-function connection is flat.)
3 Levi-Civita connection coming from the potential-function metric tensor
Similarly to the already-existing StatisticalMetric, create StatisticalConnectionFromPotentialFunction inheriting from RiemannianMetric, calling the potential function's hessian to define metric_matrix, and implementing the totally symmetric cubic tensor (cubic_tensor or amari_chentsov_tensor? If mathematically correct, I would call it amari_chentsov_tensor; also change it in StatisticalMetric) given by the third derivatives of the potential function (eq. 65) (as in the metric tensor case, simply call the method from the potential function class).
Notice that after this implementation, we have (more than) a fully-fledged Levi-Civita connection with geodesics. We also have flat geodesics from the "dummy" connection (2).
(I guess now is a good time to create first_kind_christoffels in RiemannianMetric.)
4 Alpha-connection from statistical metric
Create an alpha-connection class that takes a statistical metric (3) and computes the alpha-connection Christoffels using equation (50). Notice that though this is generic, it can later be used to get the conjugate connection of the potential-function connection (thinking about this result
Remarks
All these objects will 1) provide users with different ways of defining similar objects, so that they can choose the best for their use case, 2) give us different independent ways of defining things, which is fantastic for testing.
Do not care much about interfaces: for the inits of these new objects, assume users are passing always the same potential function. We will iterate on this next time.
P.S. I may be missing something, please be critical with the suggestions.
Hi Luis, as I start working on this next iteration, here are some comments/questions I have on your suggestions. Overall comments:
1 Potential Function
3 LC connection from potential function
4 Alpha Connection from statistical metric
Before merging this one, what I would like to have is the (general) geometric structure induced by any divergence AND a particular implementation for the Bregman divergence that takes advantage of the fact that it results from a potential function. The idea is two-fold. On one hand, we show we can make computations with general divergences and, when they have particular structure we can take advantage of, specialized implementations. On the other hand, we can use the specialized implementations to test the general ones. Specialized implementations are also expected to be faster.
At some point our goal is to create an abstract class from which it inherits. We can start with just a particular class to check if the idea is worth pursuing. (If this works nicely, I think I'll ask you for a similar object for the divergence.)
Probably. Though we should notice the dual connection in this case comes from the alpha-connection, so we need to be careful instantiating the objects. (But for compatibility with the other
It is a good point. Though I would not say it has a metric structure, but rather that it requires the space to be endowed with a metric structure (i.e. it makes use of an existing structure). It is very subtle, but I will soon ask you for a change in the computation of the Christoffels that is justified by this. Do you agree with the distinction?
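A tiny numerical check of the "specialized tests the general" idea (toy code, not the geomstats API): the metric induced by a Bregman divergence should match the Hessian of its potential, so the two routes can validate each other. The finite-difference route stands in for the generic autodiff computation.

```python
import numpy as np


def potential(x):
    return np.sum(x**4)


def grad_potential(x):
    return 4.0 * x**3


def bregman_divergence(first, second):
    """Bregman divergence of the potential F(x) = sum(x**4)."""
    return potential(first) - potential(second) - grad_potential(second) @ (first - second)


def divergence_metric(divergence, point, h=1e-4):
    """General route: g_ij = -d_i d'_j D(x, x')|_{x'=x}, via finite differences."""
    dim = point.size
    metric = np.zeros((dim, dim))
    for i in range(dim):
        for j in range(dim):
            for s1 in (1.0, -1.0):
                for s2 in (1.0, -1.0):
                    first, second = point.copy(), point.copy()
                    first[i] += s1 * h
                    second[j] += s2 * h
                    metric[i, j] -= s1 * s2 * divergence(first, second) / (4.0 * h**2)
    return metric


def potential_metric(point):
    """Specialized route: the Hessian of the potential, hard-coded."""
    return np.diag(12.0 * point**2)


point = np.array([0.5, 1.5])
# Both routes should agree up to finite-difference error.
```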
…gence is Bregman divergence.
Quick review of the PotentialFunction.
class PotentialFunction:
    def __init__(self, potential_function):
        self.potential_function = potential_function
I would make this private, as this is just a way of creating this object without having to inherit. An alternative would be to directly implement PotentialFunction as a class and override the __call__ method.
        self.hessian_function = self.hessian()
        self.third_derivative_function = self.third_derivative()
Delete? Having the methods is enough.
        self.hessian_function = self.hessian()
        self.third_derivative_function = self.third_derivative()

    def __call__(self, x):
We will need to find a better name for x.
    def __call__(self, x):
        return self.potential_function(x)

    def hessian(self):
Receive x and return the result, not a function. (Imagine we make a non-autodiff-based implementation of the hessian.)
    def hessian(self):
        return gs.autodiff.hessian(self.potential_function)

    def third_derivative(self):
Same as for hessian: don't mind repeating gs.autodiff.hessian(self.potential_function) again.
…l__ method to be defined. Derivative methods are now functions of points, rather than operators.
Checklist
Description
I added a first attempt at a StatisticalMetric class. Given a divergence, a metric tensor, connection, dual connection, and Amari cubic tensor can be induced to create a statistical manifold, a special type of Riemannian manifold with extra structure.

Additional context
This is just a first stab at how to code this structure within the geomstats library, to share with collaborators. Because of this, the PR is marked as a draft. The documentation needs to be reformatted.
Also, for some reason the StatisticalMetric class only works with the autograd backend -- I had a lot of trouble trying to get it to work with the pytorch backend.

Looking forward to discussing here/on Slack!