[XLA] Different JIT compile behavior from TF2.7 #55610

@wenscarl

Description

With the custom code below, I see the following error at runtime when XLA is turned on. It does NOT appear in TF 2.7.
2022-04-13 19:49:36.873241: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at xla_ops.cc:436 : INVALID_ARGUMENT: Fail to proof the equality of two dimensions at compile time: %multiply.144 = s32[] multiply(s32[] %constant.142, s32[] %add.1), metadata={op_type="Reshape" op_name="Reshape_3"} vs %add = s32[] add(s32[] %reduce.109, s32[] %constant.17)

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution: Ubuntu 20.04.4 LTS
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: 2.8.0
  • Python version: 3.8
  • GCC/Compiler version (if compiling from source): gcc 10
  • CUDA/cuDNN version: 11.6
  • GPU model and memory: V100, 32G

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np
display_id_counter = tf.Variable(0, trainable=False, dtype=tf.float64)

@tf.function
def evaluation_step(x, y, predictions):
    dummy_loss = 0.9
    predictions = tf.reshape(predictions, [-1])
    predictions = tf.cast(predictions, tf.float64)
    display_ids = x
    display_ids = tf.reshape(display_ids, [-1])
    labels = tf.reshape(y, [-1])
    sorted_ids = tf.argsort(display_ids)
    display_ids = tf.gather(display_ids, indices=sorted_ids)
    predictions = tf.gather(predictions, indices=sorted_ids)
    labels = tf.gather(labels, indices=sorted_ids)
    _, display_ids_idx, display_ids_ads_count = tf.unique_with_counts(
        display_ids, out_idx=tf.int64)
    pad_length = 30 - tf.reduce_max(display_ids_ads_count)
    preds = tf.RaggedTensor.from_value_rowids(
        predictions, display_ids_idx).to_tensor()
    labels = tf.RaggedTensor.from_value_rowids(
        labels, display_ids_idx).to_tensor()
    labels_mask = tf.math.reduce_max(labels, 1)
    preds_masked = tf.boolean_mask(preds, labels_mask)
    labels_masked = tf.boolean_mask(labels, labels_mask)
    labels_masked = tf.argmax(labels_masked, axis=1, output_type=tf.int32)
    labels_masked = tf.reshape(labels_masked, [-1, 1])

    preds_masked = tf.pad(preds_masked, [(0, 0), (0, pad_length)])
    _, predictions_idx = tf.math.top_k(preds_masked, 12)
    indices = tf.math.equal(predictions_idx, labels_masked)

    shape = tf.cast(tf.shape(indices)[0], tf.float64)
    display_id_counter.assign_add(shape)

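DIM = 102400
# Enable XLA auto-clustering globally; the error only appears with this on.
tf.config.optimizer.set_jit(True)
for step in range(200):
    # Random predictions and a random number of positive labels per step.
    pre = np.random.random((DIM, 1))
    y_tmp = np.zeros((DIM, 1), dtype=float)

    num_ones = np.random.randint(1, DIM+1, 1)
    id_one = np.random.randint(0, DIM, num_ones)
    for i in id_one:
        y_tmp[i][0] = 1.
    # Random display ids, so the number of unique ids can change every step.
    x_tmp = np.random.randint(0, DIM, (DIM, 1), dtype=np.int64)
    evaluation_step(x_tmp, y_tmp, pre)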

I tracked the change in behavior down to commit ac4575.
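A note on the shapes in the repro: the dense tensors built from `tf.unique_with_counts` and `tf.RaggedTensor.from_value_rowids(...).to_tensor()` have dimensions that depend on the input values, not just on the input length. The pure-Python sketch below illustrates that point only. It is not TensorFlow code, the helper names `unique_with_counts` and `dense_shape` are hypothetical stand-ins, and whether these data-dependent dimensions are what triggers the compile-time equality check in the error above is an assumption the issue does not confirm.

```python
from collections import Counter


def unique_with_counts(values):
    """Pure-Python stand-in for tf.unique_with_counts on a 1-D list."""
    uniques, idx, position = [], [], {}
    for v in values:
        if v not in position:
            position[v] = len(uniques)
            uniques.append(v)
        idx.append(position[v])
    counts = Counter(idx)
    return uniques, idx, [counts[i] for i in range(len(uniques))]


def dense_shape(sorted_ids):
    """Shape of the padded dense tensor: (number of unique ids, largest group)."""
    uniques, _, counts = unique_with_counts(sorted_ids)
    return (len(uniques), max(counts))


# Same input length (6), different contents, different dense shapes.
print(dense_shape(sorted([3, 3, 3, 7, 7, 9])))  # (3, 3)
print(dense_shape(sorted([1, 2, 3, 4, 5, 5])))  # (5, 2)
```

In the repro, `x_tmp` is drawn at random on every step, so both dimensions can differ from one call of `evaluation_step` to the next.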
