diff options
author | Antonio Sanchez <cantonios@google.com> | 2023-01-03 14:22:31 -0800 |
---|---|---|
committer | TensorFlow Release Automation <jenkins@tensorflow.org> | 2023-02-03 20:10:00 +0000 |
commit | 4c49ff8d07cb0851e63d6ee4a1e344a0922cee50 (patch) | |
tree | d82d04800f949615bbf683a50235e38f99448d77 | |
parent | 5968b6b37ee986ad563d9bae2a995aae8c9f6bea (diff) | |
download | tensorflow-upstream-r2.11-0bf8d115393.tar.gz |
Fix sparse tensor to CSR batch index OOB error.upstream-r2.11-0bf8d115393
Added a check for the batch index.
PiperOrigin-RevId: 499315915
-rw-r--r-- | tensorflow/core/kernels/sparse/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/sparse/kernels.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/sparse/kernels_test.cc | 25 |
3 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sparse/BUILD b/tensorflow/core/kernels/sparse/BUILD index 8be732152fa..96aba2faef9 100644 --- a/tensorflow/core/kernels/sparse/BUILD +++ b/tensorflow/core/kernels/sparse/BUILD @@ -106,6 +106,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:testlib", + "//tensorflow/core/platform:status_matchers", "//third_party/eigen3", ], ) diff --git a/tensorflow/core/kernels/sparse/kernels.cc b/tensorflow/core/kernels/sparse/kernels.cc index 4cf375cc317..5be8a0dba3f 100644 --- a/tensorflow/core/kernels/sparse/kernels.cc +++ b/tensorflow/core/kernels/sparse/kernels.cc @@ -75,6 +75,12 @@ Status SparseTensorToCSRSparseMatrixCPUFunctor::operator()( } else { // rank == 3 for (int64_t i = 0; i < total_nnz; ++i) { const int cur_batch = indices(i, 0); + if (cur_batch < 0 || cur_batch >= batch_size) { + return errors::InvalidArgument("Batch index ", cur_batch, + " is outside of valid batch range [", 0, + ", ", batch_size, ")"); + } + // For now, the rows pointers store the corresponding row counts. csr_row_ptr(cur_batch * (num_rows + 1) + indices(i, 1) + 1) += 1; csr_col_ind(i) = indices(i, 2); diff --git a/tensorflow/core/kernels/sparse/kernels_test.cc b/tensorflow/core/kernels/sparse/kernels_test.cc index 39ab6fa011b..d7aee439215 100644 --- a/tensorflow/core/kernels/sparse/kernels_test.cc +++ b/tensorflow/core/kernels/sparse/kernels_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status_matchers.h" namespace tensorflow { namespace { @@ -73,6 +75,29 @@ TEST(SparseTensorToCSRSparseMatrix, BatchConversion) { test::ExpectTensorEqual<int32>(csr_col_ind, test::AsTensor<int32>({0, 3, 1})); } +TEST(SparseTensorToCSRSparseMatrix, InvalidBatchThrowsIllegalArgument) { + const auto indices = + test::AsTensor<int64_t>({0, 0, 0, // + 4, 2, 3, // Batch out of bounds. + 2, 0, 1}, + TensorShape({3, 3})); + Tensor batch_ptr(DT_INT32, {4}); + Tensor csr_col_ind(DT_INT32, {3}); + // row pointers have size = batch_size * (num_rows + 1) = 3 * 4 = 12 + Tensor csr_row_ptr(DT_INT32, {12}); + test::FillFn<int32>(&csr_row_ptr, [](int unused) { return 0; }); + + functor::SparseTensorToCSRSparseMatrixCPUFunctor coo_to_csr; + EXPECT_THAT( + coo_to_csr(3 /* batch_size */, 3 /* num_rows */, + indices.template matrix<int64_t>(), batch_ptr.vec<int32>(), + csr_row_ptr.vec<int32>(), csr_col_ind.vec<int32>()), + tsl::testing::StatusIs( + tsl::error::Code::INVALID_ARGUMENT, + ::testing::ContainsRegex( + "Batch index .* is outside of valid batch range"))); +} + } // namespace } // namespace tensorflow |