
Conversation

@ymwangg (Contributor) commented on Dec 10, 2022:

We recently found that the torch_xla lowering of torch.sigmoid is not numerically stable on GPU. A common use case of torch.sigmoid is to force the output to lie within [0, 1].
For example, the following code fails with a nan loss because x evaluates to -5.9604645e-08, slightly outside that range.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.sigmoid(torch.tensor([-16.740633], device=device))
y = torch.tensor([1.0], device=device)
print(torch.nn.functional.binary_cross_entropy(x, y))  # prints tensor(nan, device='xla:1')

Is there any special reason for torch_xla to use sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x) instead of sigmoid(x) = 1 / (1 + exp(-x))?
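For illustration only (this is plain Python, not the actual XLA GPU lowering; the tanh value below is a hypothetical one-ulp overshoot chosen to reproduce the reported x), here is a minimal sketch of the failure mode: if the fast tanh approximation returns a result even slightly below -1, the tanh-based sigmoid goes negative and the BCE loss with target 1 becomes nan, while the exp-based form stays inside (0, 1).

import math

# Hypothetical tanh output one ulp below -1, standing in for what a fast
# GPU approximation of tanh(0.5 * -16.740633) might return (not measured from XLA).
t = -1.0 - 2**-23

sig_tanh = 0.5 + 0.5 * t  # tanh-based sigmoid
print(sig_tanh)           # ~ -5.96e-08, outside [0, 1]

# binary_cross_entropy(x, y=1) reduces to -log(x), which is undefined for x <= 0
print(-math.log(sig_tanh) if sig_tanh > 0 else float("nan"))  # nan

sig_exp = 1.0 / (1.0 + math.exp(16.740633))  # exp-based sigmoid at x = -16.740633
print(-math.log(sig_exp))                    # ~16.74, finite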

@ymwangg changed the title from "[Draft] Improve numerical stability of torch.sigmoid" to "Improve numerical stability of torch.sigmoid" on Dec 12, 2022.
@JackCaoG (Collaborator) commented:
I have a feeling that it might be because sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x) is faster... let me double check.

@ymwangg (Contributor, Author) commented on Dec 13, 2022:

Yes, the tanh implementation is slightly faster on GPU.
Using the following script:

import time

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x = torch.rand(1000000000, device=device)
xm.mark_step()  # materialize x before timing
t0 = time.time()
for _ in range(100):
    for _ in range(100):
        y = torch.sigmoid(x)
    xm.mark_step()  # cut the graph and execute once per outer iteration
t1 = time.time()
print(t1 - t0)

I'm getting 1.2621409893035889 s with the tanh implementation (with clamp) and 1.301847219467163 s with the exp-based implementation.

If we want to keep the tanh implementation, one way is to wrap it with xla::Clamp(zero, half + half * xla::Tanh(half * input), one).
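The actual change would go into the C++ lowering, but the effect of the clamp can be sketched at the Python level (illustrative only; sigmoid_tanh_clamped is a hypothetical helper, not the torch_xla code path):

import torch

def sigmoid_tanh_clamped(x):
    # tanh-based sigmoid, clamped so that rounding/approximation error in tanh
    # can never push the result outside [0, 1]
    return torch.clamp(0.5 + 0.5 * torch.tanh(0.5 * x), 0.0, 1.0)

x = torch.tensor([-16.740633])
y = torch.tensor([1.0])
p = sigmoid_tanh_clamped(x)
print(torch.nn.functional.binary_cross_entropy(p, y))  # finite, no nan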

@JackCaoG (Collaborator) commented:
I talked with Blake. Speed was the main reason we used tanh, and TPU does not have this numerical-instability issue. He suggested we lower sigmoid using XlaOp Logistic(XlaOp operand), which has different TPU and GPU implementations in the backend to handle the subtle differences between accelerators.

@ymwangg (Contributor, Author) commented on Dec 14, 2022:

Updated, and thanks for the info. I just realized xla::Logistic is equivalent to torch.sigmoid.
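As a quick sanity check of the mathematical equivalence only (it does not exercise the xla::Logistic lowering itself), the two formulations agree with torch.sigmoid to within float32 rounding:

import torch

x = torch.linspace(-20.0, 20.0, steps=101)
ref = torch.sigmoid(x)
exp_form = 1.0 / (1.0 + torch.exp(-x))       # 1 / (1 + exp(-x))
tanh_form = 0.5 + 0.5 * torch.tanh(0.5 * x)  # 0.5 + 0.5 * tanh(x / 2)

print((ref - exp_form).abs().max())   # at the level of float32 rounding
print((ref - tanh_form).abs().max())  # at the level of float32 rounding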

@JackCaoG (Collaborator) left a review comment:


Thanks!

@JackCaoG added the lowering (ATen Operation lowering) label on Dec 14, 2022.
@JackCaoG merged commit 453aa65 into pytorch:master on Dec 21, 2022.
Labels: lowering (ATen Operation lowering), xla:gpu