aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Delorimier <mdel@google.com>2024-02-08 13:02:44 -0800
committerTensorFlower Gardener <gardener@tensorflow.org>2024-02-08 13:10:22 -0800
commit080a510e51891aa3058b6b260ad549d44e8c55a0 (patch)
tree1d00c2fb7676f73e56d4b701a31fcd870f8b11fb
parente3e054ec763eca896d732c9499fb5d1e1d9194d4 (diff)
downloadtensorflow-080a510e51891aa3058b6b260ad549d44e8c55a0.tar.gz
Hoist reads that are broadcast into replicas. This pass moves ReadVariable above tf_device.replicate. Only ReadVariables that have device type CPU are hoisted, because we know these are broadcasts. Reads after writes are not hoisted.
PiperOrigin-RevId: 605403556
-rw-r--r--tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir68
-rw-r--r--tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc2
-rw-r--r--tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc2
-rw-r--r--tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD36
-rw-r--r--tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h5
-rw-r--r--tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td39
-rw-r--r--tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc154
7 files changed, 305 insertions, 1 deletions
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir
new file mode 100644
index 00000000000..c88310443ed
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir
@@ -0,0 +1,68 @@
+// RUN: tf-opt %s -split-input-file -tf-hoist-broadcast-read | FileCheck %s
+
+// The read should be hoisted above the replicate because the resource argument
+// has device type CPU.
+
+// CHECK-LABEL: func @hoist_cpu
+func.func @hoist_cpu(%arg0: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () {
+ // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"
+ // CHECK-NEXT: tf_device.replicate
+ // CHECK-NEXT: "tf.OpA"(%[[READ]])
+ tf_device.replicate {n = 2 : i32} {
+ %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ "tf.OpA"(%0) : (tensor<f32>) -> ()
+ }
+ func.return
+}
+
+// -----
+
+// The read should not be hoisted because the resource does not have device
+// type CPU: the resource argument carries no `tf.device` attribute at all.
+
+// CHECK-LABEL: func @only_hoist_cpu
+func.func @only_hoist_cpu(%arg0: tensor<*x!tf_type.resource<tensor<f32>>>) -> () {
+ // CHECK: tf_device.replicate
+ // CHECK-NEXT: "tf.ReadVariableOp"
+ tf_device.replicate {n = 2 : i32} {
+ %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ "tf.OpA"(%0) : (tensor<f32>) -> ()
+ }
+ func.return
+}
+
+// -----
+
+// The read should not be hoisted because it follows a write (the
+// `tf.AssignVariableOp`) to the same resource inside the replicate, even
+// though the resource has device type CPU.
+
+// CHECK-LABEL: func @skip_read_after_write
+func.func @skip_read_after_write(%arg0: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () {
+ // CHECK: tf_device.replicate
+ // CHECK: "tf.AssignVariableOp"
+ // CHECK-NEXT: "tf.ReadVariableOp"
+ tf_device.replicate {n = 2 : i32} {
+ %0 = "tf.OpA"() : () -> tensor<f32>
+ "tf.AssignVariableOp"(%arg0, %0) : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
+ %1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ "tf.OpB"(%1) : (tensor<f32>) -> ()
+ }
+ func.return
+}
+
+// -----
+
+// Check that hoisting preserves read order: the reads of %arg0 and %arg2
+// (device type CPU) are hoisted and keep their original relative order, while
+// the read of %arg1 (no `tf.device` attribute) stays inside the replicate.
+
+// CHECK-LABEL: func @order_preserved
+func.func @order_preserved(%arg0: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource<tensor<f32>>>, %arg2: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () {
+ // CHECK: %[[READ0:.*]] = "tf.ReadVariableOp"(%arg0)
+ // CHECK-NEXT: %[[READ2:.*]] = "tf.ReadVariableOp"(%arg2)
+ // CHECK-NEXT: tf_device.replicate
+ // CHECK-NEXT: %[[READ1:.*]] = "tf.ReadVariableOp"(%arg1)
+ // CHECK-NEXT: "tf.OpA"(%[[READ0]], %[[READ1]], %[[READ2]])
+ tf_device.replicate {n = 2 : i32} {
+ %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ %2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ "tf.OpA"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ }
+ func.return
+}
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc
index ed04adc6a39..f00df151321 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc
@@ -143,6 +143,8 @@ void AddReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm,
pm.addNestedPass<FuncOp>(mlir::TFDevice::CreateClusterConstantSinkingPass());
pm.addPass(mlir::TF::CreateResourceDeviceInferencePass());
+ pm.addNestedPass<FuncOp>(
+ tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass());
pm.addPass(mlir::TFDevice::CreateClusterOutliningPass());
pm.addPass(mlir::TFTPU::CreateTPUResourceReadForWritePass());
pm.addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass());
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc
index c9cc5a4d1df..756ec42f312 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc
@@ -28,7 +28,7 @@ TEST(ClusteringBridgePassesTest, AddsBridgePasses) {
OpPassManager pass_manager;
AddReplicatedBridgeClusteringPipelinePasses(pass_manager);
- EXPECT_EQ(pass_manager.size(), 43);
+ EXPECT_EQ(pass_manager.size(), 44);
}
TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) {
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD
index 6250f2cf0ca..df25fdd9ffe 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD
@@ -26,6 +26,7 @@ cc_library(
deps = [
":extract_head_tail_outside_compilation",
":extract_outside_compilation",
+ ":hoist_broadcast_read",
":mark_ops_for_outside_compilation",
":tpu_cluster_formation",
":verify_clustering_pass",
@@ -351,6 +352,41 @@ cc_library(
],
)
+# Library for the pass implemented in hoist_broadcast_read.cc, which hoists
+# broadcast `ReadVariableOp`s above `tf_device.replicate`.
+# NOTE(review): hoist_broadcast_read.cc only includes LLVM/MLIR headers, the
+# TensorFlow dialect headers (tf_device.h, tf_ops.h) and device_name_utils.h;
+# several deps below (e.g. the absl targets and the tf2xla/transforms targets)
+# look unused -- confirm and trim.
+cc_library(
+ name = "hoist_broadcast_read",
+ srcs = ["hoist_broadcast_read.cc"],
+ textual_hdrs = [
+ "clustering_passes.h.inc",
+ ],
+ deps = [
+ ":clustering_passes_inc_gen",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
+ "//tensorflow/compiler/mlir/tensorflow:string_util",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+ "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util",
+ "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib",
+ "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen",
+ "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
+ "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config",
+ "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Rewrite",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+)
+
tf_cc_test(
name = "tpu_cluster_formation_test",
srcs = ["tpu_cluster_formation_test.cc"],
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h
index ea6187a2309..3ccf990a4b5 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h
@@ -56,6 +56,11 @@ CreateXlaOutlineEntryFunctionsPass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateMarkOpsForOutsideCompilationPass();
+// Creates a pass that hoists reads out of a replicate that are on a variable
+// whose value is broadcast to all replicas.
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateHoistBroadcastReadPass();
+
#define GEN_PASS_REGISTRATION
#define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS
#define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
index c219c35842c..90d2e962bc9 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
@@ -349,3 +349,42 @@ def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation
let constructor = "tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()";
}
+
+def HoistBroadcastReadPass : Pass<"tf-hoist-broadcast-read", "mlir::func::FuncOp"> {
+ let summary = "Hoist reads out of a replicate that are on a resource that is broadcast to all replicas.";
+
+ let description = [{
+ Some `ReadVariableOp`s that are within a `tf_device.replicate` read the same
+ value across all replicas. These reads can be hoisted out of the
+ `tf_device.replicate` so there's one read for all replicas, and each replica
+ depends on the result of the read. This transform enables the
+ xla-broadcast-pass to optimize the broadcast value.
+
+ For example, the following:
+
+ ```mlir
+ tf_device.replicate {n = 2 : i32} {
+ %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ "tf.OpA"(%0) : (tensor<f32>) -> ()
+ }
+ ```
+
+ will be transformed into:
+
+ ```mlir
+ %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
+ tf_device.replicate {n = 2 : i32} {
+ "tf.OpA"(%0) : (tensor<f32>) -> ()
+ }
+ ```
+
+ We must ensure that there is a single underlying resource that is not
+ distributed across replicas. There is a single underlying resource when the
+ resource device type is CPU, so we cautiously only apply in this case.
+
+ To be cautious we never hoist a read that comes after a write to the same
+ resource.
+ }];
+
+ let constructor = "tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass()";
+}
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc
new file mode 100644
index 00000000000..732bae8c67b
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc
@@ -0,0 +1,154 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/Value.h" // from @llvm-project
+#include "mlir/IR/Visitors.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
+
+namespace {
+
+using mlir::BlockArgument;
+using mlir::failure;
+using mlir::LogicalResult;
+using mlir::Operation;
+using mlir::OperationPass;
+using mlir::OpOperand;
+using mlir::StringAttr;
+using mlir::success;
+using mlir::Value;
+using mlir::WalkResult;
+using mlir::func::FuncOp;
+using mlir::TF::ReadVariableOp;
+using mlir::tf_device::ReplicateOp;
+
+#define GEN_PASS_DEF_HOISTBROADCASTREADPASS
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc"
+
+// Name of the function-argument attribute that `IsCpuRead` looks up and parses
+// as a device name.
+constexpr char kFuncDeviceAttr[] = "tf.device";
+// Device type that `IsCpuRead` compares the parsed device type against.
+constexpr char kCpuDeviceType[] = "CPU";
+
+// Function pass that hoists `ReadVariableOp`s above `tf_device.replicate`s.
+// The transformation itself lives in `runOnOperation`.
+class HoistBroadcastRead
+    : public impl::HoistBroadcastReadPassBase<HoistBroadcastRead> {
+ public:
+  void runOnOperation() override;
+};
+
+// Returns the op on the parent chain of `descendant` (possibly `descendant`
+// itself) whose immediate parent is `ancestor`. Returns nullptr when the chain
+// runs out of parents without reaching `ancestor`.
+Operation* GetAncestorBelow(Operation* descendant, Operation* ancestor) {
+  Operation* current = descendant;
+  // Climb one nesting level per iteration until `ancestor` is the parent.
+  while (Operation* enclosing = current->getParentOp()) {
+    if (enclosing == ancestor) return current;
+    current = enclosing;
+  }
+  return nullptr;
+}
+
+// Determines whether `read` reads a resource with device type CPU.
+//
+// `is_cpu_read` is set to `true` iff the resource operand of `read` is an
+// argument of the entry block of `func` and that argument carries a
+// `tf.device` string attribute whose parsed device type is CPU. In every other
+// case, including the failure case, it is set to `false`.
+//
+// Returns failure (after emitting an error on `read`) when the `tf.device`
+// attribute cannot be parsed as a full device name; returns success otherwise.
+LogicalResult IsCpuRead(FuncOp func, ReadVariableOp read, bool& is_cpu_read) {
+  // Give the out-parameter a defined value on every path, errors included.
+  is_cpu_read = false;
+
+  Value resource = read.getResource();
+  // Use the free-function cast; the member form `Value::dyn_cast` is
+  // deprecated in MLIR.
+  auto arg = llvm::dyn_cast<BlockArgument>(resource);
+  // Only arguments of the function's entry block can carry `tf.device`.
+  if (!arg || arg.getOwner() != &(func.front())) return success();
+
+  auto attr =
+      func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
+  if (!attr) return success();
+
+  std::string device = attr.getValue().str();
+  tensorflow::DeviceNameUtils::ParsedName parsed_name;
+  if (!tensorflow::DeviceNameUtils::ParseFullName(device, &parsed_name)) {
+    return read->emitOpError() << "invalid device '" << device << "'";
+  }
+  is_cpu_read = parsed_name.type == kCpuDeviceType;
+  return success();
+}
+
+// Collects into `reads`, in program order, every `ReadVariableOp` directly in
+// the body block of `replicate` for which `IsCpuRead` reports a CPU resource.
+// Returns failure as soon as `IsCpuRead` fails for any read.
+LogicalResult GetReads(FuncOp func, ReplicateOp replicate,
+                       llvm::SmallVector<ReadVariableOp, 4>& reads) {
+  for (Operation& candidate : replicate.getBody().front()) {
+    auto read = llvm::dyn_cast<ReadVariableOp>(&candidate);
+    if (!read) continue;
+
+    bool on_cpu = false;
+    if (failed(IsCpuRead(func, read, on_cpu))) return failure();
+    if (!on_cpu) continue;
+    reads.push_back(read);
+  }
+  return success();
+}
+
+// Moves each op in `reads` to just before `replicate`. A read is left in place
+// when its resource is not defined in a scope that properly encloses
+// `replicate`, or when some non-read use of that resource inside `replicate`
+// is not positioned after the read (i.e. the read may follow a write).
+void MoveReads(ReplicateOp replicate,
+               llvm::SmallVector<ReadVariableOp, 4>& reads) {
+  // Returns true if `user`, an op using the resource that `read` reads,
+  // prevents `read` from being hoisted.
+  auto blocks_hoisting = [&](ReadVariableOp read, Operation* user) -> bool {
+    if (user == read) return false;
+    // Uses outside of `replicate` are not considered.
+    if (!replicate->isProperAncestor(user)) return false;
+    // Compare positions at the level of the replicate body's direct children.
+    Operation* sibling = GetAncestorBelow(user, replicate);
+    if (read->isBeforeInBlock(sibling)) return false;
+    // Other reads of the same resource never conflict.
+    return !llvm::isa<ReadVariableOp>(sibling);
+  };
+
+  for (ReadVariableOp read : reads) {
+    Value resource = read.getResource();
+    Operation* defining_scope = resource.getParentBlock()->getParentOp();
+    if (!defining_scope->isProperAncestor(replicate)) continue;
+
+    bool blocked = false;
+    for (OpOperand& use : resource.getUses()) {
+      if (blocks_hoisting(read, use.getOwner())) blocked = true;
+    }
+    if (!blocked) read->moveBefore(replicate);
+  }
+}
+
+// Visits every `tf_device.replicate` in the function, collects its hoistable
+// reads with `GetReads` and moves them above the replicate with `MoveReads`.
+// The pass is marked as failed if collecting the reads fails for any replicate.
+void HoistBroadcastRead::runOnOperation() {
+  FuncOp func = getOperation();
+
+  WalkResult walk_status = func.walk([&](ReplicateOp replicate) -> WalkResult {
+    llvm::SmallVector<ReadVariableOp, 4> hoistable_reads;
+    if (failed(GetReads(func, replicate, hoistable_reads))) {
+      // Stop at the first replicate whose reads could not be collected.
+      return WalkResult::interrupt();
+    }
+    MoveReads(replicate, hoistable_reads);
+    return WalkResult::advance();
+  });
+
+  if (walk_status.wasInterrupted()) {
+    signalPassFailure();
+  }
+}
+
+} // namespace
+
+// Returns a newly allocated `HoistBroadcastRead` pass, which operates on
+// `FuncOp`s.
+std::unique_ptr<OperationPass<FuncOp>> CreateHoistBroadcastReadPass() {
+ return std::make_unique<HoistBroadcastRead>();
+}
+
+} // namespace internal
+} // namespace tf2xla
+} // namespace tensorflow