aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAntonio Sanchez <cantonios@google.com>2023-01-03 14:22:31 -0800
committerTensorFlow Release Automation <jenkins@tensorflow.org>2023-02-03 20:10:00 +0000
commit4c49ff8d07cb0851e63d6ee4a1e344a0922cee50 (patch)
treed82d04800f949615bbf683a50235e38f99448d77
parent5968b6b37ee986ad563d9bae2a995aae8c9f6bea (diff)
downloadtensorflow-upstream-r2.11-0bf8d115393.tar.gz
Fix sparse tensor to CSR batch index OOB error.upstream-r2.11-0bf8d115393
Added a check for the batch index. PiperOrigin-RevId: 499315915
-rw-r--r--tensorflow/core/kernels/sparse/BUILD1
-rw-r--r--tensorflow/core/kernels/sparse/kernels.cc6
-rw-r--r--tensorflow/core/kernels/sparse/kernels_test.cc25
3 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sparse/BUILD b/tensorflow/core/kernels/sparse/BUILD
index 8be732152fa..96aba2faef9 100644
--- a/tensorflow/core/kernels/sparse/BUILD
+++ b/tensorflow/core/kernels/sparse/BUILD
@@ -106,6 +106,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:testlib",
+ "//tensorflow/core/platform:status_matchers",
"//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/sparse/kernels.cc b/tensorflow/core/kernels/sparse/kernels.cc
index 4cf375cc317..5be8a0dba3f 100644
--- a/tensorflow/core/kernels/sparse/kernels.cc
+++ b/tensorflow/core/kernels/sparse/kernels.cc
@@ -75,6 +75,12 @@ Status SparseTensorToCSRSparseMatrixCPUFunctor::operator()(
} else { // rank == 3
for (int64_t i = 0; i < total_nnz; ++i) {
const int cur_batch = indices(i, 0);
+ if (cur_batch < 0 || cur_batch >= batch_size) {
+ return errors::InvalidArgument("Batch index ", cur_batch,
+ " is outside of valid batch range [", 0,
+ ", ", batch_size, ")");
+ }
+
// For now, the rows pointers store the corresponding row counts.
csr_row_ptr(cur_batch * (num_rows + 1) + indices(i, 1) + 1) += 1;
csr_col_ind(i) = indices(i, 2);
diff --git a/tensorflow/core/kernels/sparse/kernels_test.cc b/tensorflow/core/kernels/sparse/kernels_test.cc
index 39ab6fa011b..d7aee439215 100644
--- a/tensorflow/core/kernels/sparse/kernels_test.cc
+++ b/tensorflow/core/kernels/sparse/kernels_test.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tsl/platform/errors.h"
+#include "tensorflow/tsl/platform/status_matchers.h"
namespace tensorflow {
namespace {
@@ -73,6 +75,29 @@ TEST(SparseTensorToCSRSparseMatrix, BatchConversion) {
test::ExpectTensorEqual<int32>(csr_col_ind, test::AsTensor<int32>({0, 3, 1}));
}
+TEST(SparseTensorToCSRSparseMatrix, InvalidBatchThrowsIllegalArgument) {
+ const auto indices =
+ test::AsTensor<int64_t>({0, 0, 0, //
+ 4, 2, 3, // Batch out of bounds.
+ 2, 0, 1},
+ TensorShape({3, 3}));
+ Tensor batch_ptr(DT_INT32, {4});
+ Tensor csr_col_ind(DT_INT32, {3});
+ // row pointers have size = batch_size * (num_rows + 1) = 3 * 4 = 12
+ Tensor csr_row_ptr(DT_INT32, {12});
+ test::FillFn<int32>(&csr_row_ptr, [](int unused) { return 0; });
+
+ functor::SparseTensorToCSRSparseMatrixCPUFunctor coo_to_csr;
+ EXPECT_THAT(
+ coo_to_csr(3 /* batch_size */, 3 /* num_rows */,
+ indices.template matrix<int64_t>(), batch_ptr.vec<int32>(),
+ csr_row_ptr.vec<int32>(), csr_col_ind.vec<int32>()),
+ tsl::testing::StatusIs(
+ tsl::error::Code::INVALID_ARGUMENT,
+ ::testing::ContainsRegex(
+ "Batch index .* is outside of valid batch range")));
+}
+
} // namespace
} // namespace tensorflow