
Incorrect branch handling during derivative computation for Max function #445

@akhadke-bdai

Description


Describe the bug
I am trying to generate PyTorch code that computes the derivative (Jacobian) of the function $w = \sqrt{\max(0, x^2 - z^2)}$. The expected derivative is:

$$
\frac{\partial w}{\partial x} =
\begin{cases}
0 & \text{if } x^2 < z^2 \\
\text{undefined} & \text{if } x^2 = z^2 \\
\dfrac{x}{\sqrt{x^2 - z^2}} & \text{if } x^2 > z^2
\end{cases}
$$
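A minimal plain-Python sketch of this piecewise derivative, checked against a central finite difference of $w$, helps confirm the expected values. The names `w`, `expected_dwdx`, and `central_difference` are illustrative only, and NaN is used here to stand for the undefined case $x^2 = z^2$.

```python
import math
from typing import Callable


def w(x: float, z: float) -> float:
    """w = sqrt(max(0, x^2 - z^2))."""
    return math.sqrt(max(0.0, x * x - z * z))


def expected_dwdx(x: float, z: float) -> float:
    """Piecewise dw/dx from the description; NaN marks the undefined case x^2 == z^2."""
    u = x * x - z * z
    if u < 0.0:
        return 0.0
    if u == 0.0:
        return math.nan  # derivative is undefined on the boundary
    return x / math.sqrt(u)


def central_difference(f: Callable[[float], float], x: float, h: float = 1e-6) -> float:
    """Numerical derivative of f at x using a central difference with step h."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


if __name__ == "__main__":
    z = 1.0
    for x in (0.5, 2.0, -3.0):
        numeric = central_difference(lambda t: w(t, z), x)
        print(f"x={x:+.1f}  expected={expected_dwdx(x, z):+.6f}  numeric={numeric:+.6f}")
```

In the region $x^2 < z^2$ both the reference and the finite difference give exactly 0, which is the value the generated code should reproduce.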

However, SymForce codegen computes
$\left(\frac{\partial w}{\partial x}\right)_{\text{codegen}} = \frac{x}{2} \left(\frac{\operatorname{sign}(x^2 - z^2) + 1}{\sqrt{\max(0, x^2 - z^2)}}\right)$
which evaluates to NaN (0/0) for any $x$ with $x^2 < z^2$.
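Writing $u = x^2 - z^2$ and assuming the derivative of $\max$ is expressed through the sign function (which is what the generated expression suggests), the chain rule reproduces the codegen result term by term:

$$
\frac{\partial w}{\partial x} = \frac{1}{2\sqrt{\max(0, u)}} \cdot \frac{\partial}{\partial x}\max(0, u),
\qquad
\frac{\partial}{\partial x}\max(0, u) = \frac{\operatorname{sign}(u) + 1}{2} \cdot 2x
$$

$$
\Rightarrow\quad \frac{\partial w}{\partial x} = \frac{x}{2} \cdot \frac{\operatorname{sign}(u) + 1}{\sqrt{\max(0, u)}}
$$

For $u < 0$ the factor $\operatorname{sign}(u) + 1$ is 0 and $\sqrt{\max(0, u)}$ is also 0, so the expression is $0/0$. The chain rule breaks down there because $\sqrt{\cdot}$ is not differentiable at 0; on the region $u < 0$, $w$ is identically 0, so its true derivative is 0.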

To Reproduce
My script to generate the PyTorch code:

```python
import os

import symforce

# Global SymForce configuration, done before importing symforce.symbolic
symforce.set_symbolic_api("symengine")
symforce.set_log_level("warning")
symforce.set_epsilon_to_symbol()

import symforce.symbolic as se
from symforce import codegen
from symforce.codegen.backends.pytorch.pytorch_config import PyTorchConfig
from symforce.values import Values


def gen_py(inputs, outputs, name):
    """Generate a PyTorch function mapping `inputs` to `outputs` and list the files written."""
    gen = codegen.Codegen(
        inputs=inputs,
        outputs=outputs,
        config=PyTorchConfig(),
        name=name,
    )
    data = gen.generate_function()

    # Print what we generated
    print("Files generated in {}:\n".format(data.output_dir))
    for f in data.generated_files:
        print("  |- {}".format(os.path.relpath(f, data.output_dir)))


def main():
    x = se.Symbol("x")
    z = se.Symbol("z")
    # w = sqrt(max(0, x^2 - z^2))
    y = se.Max(se.Scalar(0.0), (x**2 - z**2))
    w = se.sqrt(y)

    # Gradient of w with respect to (x, z)
    dwdx = [se.diff(w, x), se.diff(w, z)]

    inputs = Values(x=se.Matrix([x, z]))
    outputs = Values(dwdx=se.Matrix(dwdx))
    gen_py(inputs=inputs, outputs=outputs, name="dwdx")


if __name__ == "__main__":
    main()
```
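Independently of codegen, the NaN can be reproduced by transcribing the codegen expression by hand and evaluating it with NumPy, which uses the same IEEE-754 float arithmetic as the generated PyTorch code. `codegen_dwdx` is a hypothetical hand transcription of the formula in the description, not the generated file itself.

```python
import numpy as np


def codegen_dwdx(x: np.ndarray, z: np.ndarray) -> np.ndarray:
    """Hand transcription of the dw/dx expression reported for codegen."""
    u = x**2 - z**2
    # For u < 0: (sign(u) + 1) == 0 and sqrt(max(0, u)) == 0, so this is 0/0.
    return 0.5 * x * (np.sign(u) + 1.0) / np.sqrt(np.maximum(0.0, u))


if __name__ == "__main__":
    x = np.array([0.5, 2.0, -3.0])
    z = np.ones_like(x)
    # Silence NumPy's 0/0 RuntimeWarning; the NaN still appears in the result.
    with np.errstate(invalid="ignore", divide="ignore"):
        print(codegen_dwdx(x, z))
```

The first entry ($x = 0.5$, $z = 1$, so $x^2 < z^2$) comes out as NaN, while the other two match $x / \sqrt{x^2 - z^2}$.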

Expected behavior
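$\frac{\partial w}{\partial x}$ should follow the piecewise definition in the description; in particular it should be 0, not NaN, whenever $x^2 < z^2$.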
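For comparison, here is a NaN-free evaluation of both gradient entries the script requests, written as a hypothetical NumPy sketch using the "double where" pattern: the argument of the square root is replaced by a harmless value outside the active region before dividing, and an outer `where` selects 0 there. This is not SymForce's output, and `safe_grad` is an illustrative name. It also assumes 0 on the boundary $x^2 = z^2$, where the true derivative is undefined.

```python
from typing import Tuple

import numpy as np


def safe_grad(x: np.ndarray, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """[dw/dx, dw/dz] for w = sqrt(max(0, x^2 - z^2)), returning 0 where x^2 <= z^2."""
    u = x**2 - z**2
    active = u > 0.0
    # Inner where: use 1.0 outside the active region so no division sees 0/0.
    root = np.sqrt(np.where(active, u, 1.0))
    # Outer where: discard those placeholder results and return 0 instead.
    dwdx = np.where(active, x / root, 0.0)
    dwdz = np.where(active, -z / root, 0.0)
    return dwdx, dwdz


if __name__ == "__main__":
    dx, dz = safe_grad(np.array([0.5, 2.0, 1.0]), np.array([1.0, 1.0, 1.0]))
    print(dx, dz)
```

The same two-level pattern can be written with `torch.where` in PyTorch.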

Environment:

  • OS and version: Ubuntu 22.04.5 LTS
  • Python version: 3.10.12
  • SymForce version: 0.10.1


Labels: bug (Something isn't working)
