diff options
author | Michael Delorimier <mdel@google.com> | 2024-02-08 13:02:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2024-02-08 13:10:22 -0800 |
commit | 080a510e51891aa3058b6b260ad549d44e8c55a0 (patch) | |
tree | 1d00c2fb7676f73e56d4b701a31fcd870f8b11fb | |
parent | e3e054ec763eca896d732c9499fb5d1e1d9194d4 (diff) | |
download | tensorflow-080a510e51891aa3058b6b260ad549d44e8c55a0.tar.gz |
Hoist reads that are broadcast into replicas. This pass moves ReadVariable above tf_device.replicate. Only ReadVariables that have device type CPU are hoisted, because we know these are broadcasts. Reads after writes are not hoisted.
PiperOrigin-RevId: 605403556
7 files changed, 305 insertions, 1 deletions
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir new file mode 100644 index 00000000000..c88310443ed --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir @@ -0,0 +1,68 @@ +// RUN: tf-opt %s -split-input-file -tf-hoist-broadcast-read | FileCheck %s + +// The read should be hoisted. + +// CHECK-LABEL: func @hoist_cpu +func.func @hoist_cpu(%arg0: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp" + // CHECK-NEXT: tf_device.replicate + // CHECK-NEXT: "tf.OpA"(%[[READ]]) + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + "tf.OpA"(%0) : (tensor<f32>) -> () + } + func.return +} + +// ----- + +// The read should not be hoisted because the resource does not have device type CPU. + +// CHECK-LABEL: func @only_hoist_cpu +func.func @only_hoist_cpu(%arg0: tensor<*x!tf_type.resource<tensor<f32>>>) -> () { + // CHECK: tf_device.replicate + // CHECK-NEXT: "tf.ReadVariableOp" + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + "tf.OpA"(%0) : (tensor<f32>) -> () + } + func.return +} + +// ----- + +// The read should not be hoisted because it follows a write. 
+ +// CHECK-LABEL: func @skip_read_after_write +func.func @skip_read_after_write(%arg0: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: tf_device.replicate + // CHECK: "tf.AssignVariableOp" + // CHECK-NEXT: "tf.ReadVariableOp" + tf_device.replicate {n = 2 : i32} { + %0 = "tf.OpA"() : () -> tensor<f32> + "tf.AssignVariableOp"(%arg0, %0) : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> () + %1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + "tf.OpB"(%1) : (tensor<f32>) -> () + } + func.return +} + +// ----- + +// Check that hoisting preserves read order. + +// CHECK-LABEL: func @order_preserved +func.func @order_preserved(%arg0: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource<tensor<f32>>>, %arg2: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: %[[READ0:.*]] = "tf.ReadVariableOp"(%arg0) + // CHECK-NEXT: %[[READ2:.*]] = "tf.ReadVariableOp"(%arg2) + // CHECK-NEXT: tf_device.replicate + // CHECK-NEXT: %[[READ1:.*]] = "tf.ReadVariableOp"(%arg1) + // CHECK-NEXT: "tf.OpA"(%[[READ0]], %[[READ1]], %[[READ2]]) + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + %2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + "tf.OpA"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> () + } + func.return +} diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index ed04adc6a39..f00df151321 100644 --- 
a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -143,6 +143,8 @@ void AddReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addNestedPass<FuncOp>(mlir::TFDevice::CreateClusterConstantSinkingPass()); pm.addPass(mlir::TF::CreateResourceDeviceInferencePass()); + pm.addNestedPass<FuncOp>( + tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass()); pm.addPass(mlir::TFDevice::CreateClusterOutliningPass()); pm.addPass(mlir::TFTPU::CreateTPUResourceReadForWritePass()); pm.addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass()); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc index c9cc5a4d1df..756ec42f312 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc @@ -28,7 +28,7 @@ TEST(ClusteringBridgePassesTest, AddsBridgePasses) { OpPassManager pass_manager; AddReplicatedBridgeClusteringPipelinePasses(pass_manager); - EXPECT_EQ(pass_manager.size(), 43); + EXPECT_EQ(pass_manager.size(), 44); } TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 6250f2cf0ca..df25fdd9ffe 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -26,6 +26,7 @@ cc_library( deps = [ ":extract_head_tail_outside_compilation", ":extract_outside_compilation", + ":hoist_broadcast_read", ":mark_ops_for_outside_compilation", ":tpu_cluster_formation", ":verify_clustering_pass", @@ -351,6 +352,41 @@ cc_library( ], ) +cc_library( + name = "hoist_broadcast_read", + srcs = ["hoist_broadcast_read.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = 
[ + ":clustering_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + tf_cc_test( name = "tpu_cluster_formation_test", srcs = ["tpu_cluster_formation_test.cc"], diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h index ea6187a2309..3ccf990a4b5 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -56,6 +56,11 @@ CreateXlaOutlineEntryFunctionsPass(); std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateMarkOpsForOutsideCompilationPass(); +// Creates a pass that hoists reads out of a replicate that are on a variable +// whose value is broadcast to all replicas. 
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> +CreateHoistBroadcastReadPass(); + #define GEN_PASS_REGISTRATION #define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td index c219c35842c..90d2e962bc9 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td @@ -349,3 +349,42 @@ def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation let constructor = "tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()"; } + +def HoistBroadcastReadPass : Pass<"tf-hoist-broadcast-read", "mlir::func::FuncOp"> { + let summary = "Hoist reads out of a replicate that are on a resource that is broadcast to all replicas."; + + let description = [{ + Some `ReadVariableOp`s that are within a `tf_device.replicate` read the same + value across all replicas. These reads can be hoisted out of the + `tf_device.replicate` so there's one read for all replicas, and each replica + depends on the result of the read. This transform enables the + xla-broadcast-pass to optimize the broadcast value. + + For example, the following: + + ```mlir + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + "tf.OpA"(%0) : (tensor<f32>) -> () + } + ``` + + will be transformed into: + + ```mlir + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> + tf_device.replicate {n = 2 : i32} { + "tf.OpA"(%0) : (tensor<f32>) -> () + } + ``` + + We must ensure that there is a single underlying resource that is not + distributed across replicas. There is a single underlying resource when the + resource device type is CPU, so we cautiously only apply in this case. 
+ + To be cautious we never hoist a read that comes after a write to the same + resource. + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass()"; +} diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc new file mode 100644 index 00000000000..732bae8c67b --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc @@ -0,0 +1,154 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <memory> +#include <string> + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +using mlir::BlockArgument; +using mlir::failure; +using mlir::LogicalResult; +using mlir::Operation; +using mlir::OperationPass; +using mlir::OpOperand; +using mlir::StringAttr; +using mlir::success; +using mlir::Value; +using mlir::WalkResult; +using mlir::func::FuncOp; +using mlir::TF::ReadVariableOp; +using mlir::tf_device::ReplicateOp; + +#define GEN_PASS_DEF_HOISTBROADCASTREADPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" + +constexpr char kFuncDeviceAttr[] = "tf.device"; +constexpr char kCpuDeviceType[] = "CPU"; + +struct HoistBroadcastRead + : public impl::HoistBroadcastReadPassBase<HoistBroadcastRead> { + void runOnOperation() override; +}; + +// Get the ancestor of `descendant` that is a direct child of `ancestor`. 
+Operation* GetAncestorBelow(Operation* descendant, Operation* ancestor) { + Operation* parent = descendant->getParentOp(); + if (!parent) return nullptr; + if (parent == ancestor) return descendant; + return GetAncestorBelow(parent, ancestor); +} + +// `is_cpu_read` is set to `true` iff `read` is on a resource with device type +// CPU. +LogicalResult IsCpuRead(FuncOp func, ReadVariableOp read, bool& is_cpu_read) { + if (auto arg = read->getOperand(0).dyn_cast<BlockArgument>()) { + if (arg.getOwner() != &(func.front())) { + is_cpu_read = false; + return success(); + } + if (auto attr = func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), + kFuncDeviceAttr)) { + std::string device = attr.getValue().str(); + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(device, &parsed_name)) { + return read->emitOpError() << "invalid device '" << device << "'"; + } + is_cpu_read = parsed_name.type == kCpuDeviceType; + return success(); + } + } + is_cpu_read = false; + return success(); +} + +// Get the reads to hoist in the `replicate`. +LogicalResult GetReads(FuncOp func, ReplicateOp replicate, + llvm::SmallVector<ReadVariableOp, 4>& reads) { + for (Operation& op : replicate.getBody().front()) { + if (auto read = llvm::dyn_cast<ReadVariableOp>(&op)) { + bool is_cpu_read; + if (failed(IsCpuRead(func, read, is_cpu_read))) return failure(); + if (is_cpu_read) reads.push_back(read); + } + } + return success(); +} + +// Move reads above the `replicate`. Skip reads that come after a write to the +// same resource. 
+void MoveReads(ReplicateOp replicate, + llvm::SmallVector<ReadVariableOp, 4>& reads) { + for (ReadVariableOp read : reads) { + Value res = read.getResource(); + Operation* scope = res.getParentBlock()->getParentOp(); + if (!scope->isProperAncestor(replicate)) continue; + bool has_conflicting_write = false; + for (OpOperand& use : res.getUses()) { + Operation* using_op = use.getOwner(); + if (using_op == read) continue; + if (!replicate->isProperAncestor(using_op)) continue; + Operation* peer = GetAncestorBelow(using_op, replicate); + if (read->isBeforeInBlock(peer)) continue; + if (llvm::isa<ReadVariableOp>(peer)) continue; + has_conflicting_write = true; + } + if (has_conflicting_write) continue; + read->moveBefore(replicate); + } +} + +// Hoist `ReadVariableOp`s above the `tf_device.replicate`s. +void HoistBroadcastRead::runOnOperation() { + FuncOp func = getOperation(); + + auto result = func.walk([&](ReplicateOp replicate) { + llvm::SmallVector<ReadVariableOp, 4> reads; + if (failed(GetReads(func, replicate, reads))) + return WalkResult::interrupt(); + MoveReads(replicate, reads); + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // namespace + +std::unique_ptr<OperationPass<FuncOp>> CreateHoistBroadcastReadPass() { + return std::make_unique<HoistBroadcastRead>(); +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow |