diff options
author | Ehsan Nasiri <ehsann@google.com> | 2016-11-24 15:37:22 -0500 |
---|---|---|
committer | David Neto <dneto@google.com> | 2016-11-28 11:44:10 -0500 |
commit | bd5b0bfca1bf067563c9e03a7c30c3590189b5c0 (patch) | |
tree | 999983b1cfece234b7ed8c504bcbf086605183cf | |
parent | f72189c249ba143c6a89a4cf1e7d53337b2ddd40 (diff) | |
download | spirv-tools-bd5b0bfca1bf067563c9e03a7c30c3590189b5c0.tar.gz |
Checks that result IDs are within the ID bound specified in the SPIR-V header
This is described in Section 2.17 of the SPIR-V Spec.
* Updated existing unit test 'SemanticsIdIsAnIdNotALiteral' to pass by
manipulating the ID bound in its binary header.
* Fixed boundary check in the code.
* Added unit test to check the case that the largest ID is equal to the
ID bound.
-rw-r--r-- | source/val/validation_state.cpp | 3 | ||||
-rw-r--r-- | source/val/validation_state.h | 9 | ||||
-rw-r--r-- | source/validate.cpp | 7 | ||||
-rw-r--r-- | source/validate_instruction.cpp | 18 | ||||
-rw-r--r-- | test/val/CMakeLists.txt | 6 | ||||
-rw-r--r-- | test/val/val_capability_test.cpp | 6 | ||||
-rw-r--r-- | test/val/val_fixtures.cpp | 8 | ||||
-rw-r--r-- | test/val/val_fixtures.h | 6 | ||||
-rw-r--r-- | test/val/val_limits_test.cpp | 78 |
9 files changed, 137 insertions, 4 deletions
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 16622ff6..c69749cb 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -425,4 +425,7 @@ void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id, sampled_image_consumers_[sampled_image_id].push_back(consumer_id); } +uint32_t ValidationState_t::getIdBound() const { return id_bound_; } + +void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; } } /// namespace libspirv diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 0647e7ef..7e69a9d0 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -78,6 +78,12 @@ class ValidationState_t { /// the OpName instruction std::string getIdName(uint32_t id) const; + /// Accessor function for ID bound. + uint32_t getIdBound() const; + + /// Mutator function for ID bound. + void setIdBound(uint32_t bound); + /// Like getIdName but does not display the id if the \p id has a name std::string getIdOrName(uint32_t id) const; @@ -227,6 +233,9 @@ class ValidationState_t { /// IDs that are entry points, ie, arguments to OpEntryPoint. std::vector<uint32_t> entry_points_; + /// ID Bound from the Header + uint32_t id_bound_; + AssemblyGrammar grammar_; SpvAddressingModel addressing_model_; diff --git a/source/validate.cpp b/source/validate.cpp index 402bb109..45ae5a45 100644 --- a/source/validate.cpp +++ b/source/validate.cpp @@ -69,14 +69,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInsts, namespace { // TODO(umar): Validate header -// TODO(umar): The Id bound should be validated also. But you can only do that -// after you've seen all the instructions in the module. // TODO(umar): The binary parser validates the magic word, and the length of the // header, but nothing else. spv_result_t setHeader(void* user_data, spv_endianness_t endian, uint32_t magic, uint32_t version, uint32_t generator, uint32_t id_bound, uint32_t reserved) { - (void)user_data; + // Record the ID bound so that the validator can ensure no ID is out of bound. + ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data)); + _.setIdBound(id_bound); + (void)endian; (void)magic; (void)version; diff --git a/source/validate_instruction.cpp b/source/validate_instruction.cpp index bcab49ad..976980eb 100644 --- a/source/validate_instruction.cpp +++ b/source/validate_instruction.cpp @@ -131,6 +131,17 @@ spv_result_t CapCheck(ValidationState_t& _, return SPV_SUCCESS; } +// Checks that the Resuld <id> is within the valid bound. +spv_result_t LimitCheckIdBound(ValidationState_t& _, + const spv_parsed_instruction_t* inst) { + if (inst->result_id >= _.getIdBound()) { + return _.diag(SPV_ERROR_INVALID_BINARY) + << "Result <id> '" << inst->result_id + << "' must be less than the ID bound '" << _.getIdBound() << "'."; + } + return SPV_SUCCESS; +} + spv_result_t InstructionPass(ValidationState_t& _, const spv_parsed_instruction_t* inst) { const SpvOp opcode = static_cast<SpvOp>(inst->opcode); @@ -169,6 +180,11 @@ spv_result_t InstructionPass(ValidationState_t& _, } } } - return CapCheck(_, inst); + + if (auto error = CapCheck(_, inst)) return error; + if (auto error = LimitCheckIdBound(_, inst)) return error; + + // All instruction checks have passed. + return SPV_SUCCESS; } } // namespace libspirv diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt index b00abc35..5ac13025 100644 --- a/test/val/CMakeLists.txt +++ b/test/val/CMakeLists.txt @@ -69,3 +69,9 @@ add_spvtools_unittest(TARGET val_data LIBS ${SPIRV_TOOLS} ) +add_spvtools_unittest(TARGET val_limits + SRCS val_limits_test.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} +) + diff --git a/test/val/val_capability_test.cpp b/test/val/val_capability_test.cpp index f590ba86..2604187e 100644 --- a/test/val/val_capability_test.cpp +++ b/test/val/val_capability_test.cpp @@ -1227,6 +1227,12 @@ OpFunctionEnd )"; CompileSuccessfully(str); + + // Since we are forcing usage of <id> 64, the "id bound" in the binary header + // must be overwritten so that <id> 64 is considered within bound. + // ID Bound is at index 3 of the binary. Set it to 65. + OverwriteAssembledBinary(3, 65); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } diff --git a/test/val/val_fixtures.cpp b/test/val/val_fixtures.cpp index 3845b30e..2db99a89 100644 --- a/test/val/val_fixtures.cpp +++ b/test/val/val_fixtures.cpp @@ -54,6 +54,14 @@ void ValidateBase<T>::CompileSuccessfully(std::string code, } template <typename T> +void ValidateBase<T>::OverwriteAssembledBinary(uint32_t index, uint32_t word) { + ASSERT_TRUE(index < binary_->wordCount) + << "OverwriteAssembledBinary: The given index is larger than the binary " + "word count."; + binary_->code[index] = word; +} + +template <typename T> spv_result_t ValidateBase<T>::ValidateInstructions(spv_target_env env) { return spvValidate(ScopedContext(env).context, get_const_binary(), &diagnostic_); diff --git a/test/val/val_fixtures.h b/test/val/val_fixtures.h index faab41a7..bb4fe187 100644 --- a/test/val/val_fixtures.h +++ b/test/val/val_fixtures.h @@ -35,6 +35,12 @@ class ValidateBase : public ::testing::Test, void CompileSuccessfully(std::string code, spv_target_env env = SPV_ENV_UNIVERSAL_1_0); + // Overwrites the word at index 'index' with the given word. + // For testing purposes, it is often useful to be able to manipulate the + // assembled binary before running the validator on it. + // This function overwrites the word at the given index with a new word. + void OverwriteAssembledBinary(uint32_t index, uint32_t word); + // Performs validation on the SPIR-V code and compares the result of the // spvValidate function spv_result_t ValidateInstructions(spv_target_env env = SPV_ENV_UNIVERSAL_1_0); diff --git a/test/val/val_limits_test.cpp b/test/val/val_limits_test.cpp new file mode 100644 index 00000000..97bca1df --- /dev/null +++ b/test/val/val_limits_test.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for Universal Limits. (Section 2.17 of the SPIR-V Spec) + +#include <sstream> +#include <string> +#include <utility> + +#include "gmock/gmock.h" +#include "unit_spirv.h" +#include "val_fixtures.h" + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; + +using std::string; + +using ValidateLimits = spvtest::ValidateBase<bool>; + +string header = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 +)"; + +TEST_F(ValidateLimits, idLargerThanBoundBad) { + string str = header + R"( +; %i32 has ID 1 +%i32 = OpTypeInt 32 1 +%c = OpConstant %i32 100 + +; Fake an instruction with 64 as the result id. +; !64 = OpConstantNull %i32 +!0x3002e !1 !64 +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Result <id> '64' must be less than the ID bound '3'.")); +} + +TEST_F(ValidateLimits, idEqualToBoundBad) { + string str = header + R"( +; %i32 has ID 1 +%i32 = OpTypeInt 32 1 +%c = OpConstant %i32 100 + +; Fake an instruction with 64 as the result id. +; !64 = OpConstantNull %i32 +!0x3002e !1 !64 +)"; + + CompileSuccessfully(str); + + // The largest ID used in this program is 64. Let's overwrite the ID bound in + // the header to be 64. This should result in an error because all IDs must + // satisfy: 0 < id < bound. + OverwriteAssembledBinary(3, 64); + + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Result <id> '64' must be less than the ID bound '64'.")); +} + |