From 47078b56c0492dfbb6aeb0cbae7dc30b8760630b Mon Sep 17 00:00:00 2001
From: iupaikov-amd
Date: Thu, 21 Aug 2025 18:42:16 +0000
Subject: [PATCH 01/19] Added triton perf improvement changes

---
 torch/_inductor/codegen/triton.py            |  6 ++--
 torch/_inductor/runtime/hints.py             |  2 +-
 torch/_inductor/runtime/triton_heuristics.py | 37 +++++++++++++++++---
 3 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0e763772911c..3bf5e2414494 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1083,11 +1083,11 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        return f"tl.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        return f"tl.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1273,7 +1273,7 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        return f"tl.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 2732b9cecfb2..899b2f2d6d77 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index bbe9b04243e6..963bc596bf9a 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2071,6 +2071,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -2127,9 +2130,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not explicitly passed to the config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small.
However to workaround some ptx bugs we still # want at least 4 warps if there's enough elements per thread @@ -2159,7 +2164,15 @@ def triton_config( cfg["ZBLOCK"] = z check_max_block(cfg) check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + + if torch.version.hip: + if matrix_instr is not None: + config.kwargs["matrix_instr_nonkdim"] = matrix_instr + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + + return config def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: @@ -2207,6 +2220,7 @@ def triton_config_reduction( num_stages=1, num_warps=None, register_intensive=False, + waves_per_eu=None ) -> Config: """ Construct a reduction triton config with some adjustment heuristics @@ -2250,7 +2264,13 @@ def total_numel() -> int: cfg = _get_config({"x": x, **rnumels}) check_max_block(cfg) check_config(cfg, xnumel=size_hints["x"]) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + + if torch.version.hip: + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + + return config def _get_config(numels: dict[str, int]) -> dict[str, int]: @@ -2388,6 +2408,12 @@ def pointwise( triton_config_with_settings( size_hints, bs // 2, num_elements_per_warp=64 ), + # triton_config_with_settings( + # size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2 + # ), + triton_config_with_settings( + size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 + ), *hinted_configs, ] if len(size_hints) == 2: @@ -2624,6 +2650,7 @@ def reduction( configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) + return cached_autotune( size_hints, configs=configs, From 44e07d7256fd72dc637d2c94e6df8534e26b81bf Mon Sep 17 00:00:00 2001 From: iupaikov-amd Date: Fri, 22 Aug 2025 10:08:10 +0000 Subject: [PATCH 02/19] Added a place to put more reduction configs --- torch/_inductor/runtime/triton_heuristics.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 963bc596bf9a..a74ff5e09962 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2537,14 +2537,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): ): pass # skip all these cases elif reduction_hint == ReductionHint.INNER: - return [contiguous_config] + result_configs = [contiguous_config] elif reduction_hint == ReductionHint.OUTER: - return [outer_config] + result_configs = [outer_config] elif reduction_hint == ReductionHint.OUTER_TINY: - return [tiny_config] + result_configs = [tiny_config] if disable_pointwise_autotuning(inductor_meta): - return [make_config(32, 128)] - return [ + result_configs = [make_config(32, 128)] + result_configs = [ contiguous_config, outer_config, tiny_config, @@ -2556,6 +2556,12 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): make_config(64, 4, num_warps=8), ] + # Additional reduction configs appended for ROCm builds + if torch.version.hip: + pass + + return result_configs + def match_target_block_product( size_hints, tiling_scores, target_block_product, min_block_size=1 From 
1f90e3f063303191aa8fc52df89650f55a0c88cc Mon Sep 17 00:00:00 2001
From: iupaikov-amd
Date: Fri, 22 Aug 2025 10:42:56 +0000
Subject: [PATCH 03/19] added a dummy config for reduction kernels

---
 torch/_inductor/runtime/triton_heuristics.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index a74ff5e09962..ac419f61221f 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2558,7 +2558,15 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
 
     # Additional reduction configs appended for ROCm builds
     if torch.version.hip:
-        pass
+        # New config
+        result_configs.append(triton_config_reduction(
+            size_hints,
+            8192,
+            2048,
+            num_warps=4,
+            num_stages=1,
+            waves_per_eu=2
+        ))
 
     return result_configs

From b41f0e493e8fbf3f8dce77cfc1b8e98af4af3305 Mon Sep 17 00:00:00 2001
From: iupaikov-amd
Date: Fri, 22 Aug 2025 18:20:44 +0200
Subject: [PATCH 04/19] Increased the register spill threshold for kernel configs

---
 torch/_inductor/config.py                    | 2 +-
 torch/_inductor/runtime/triton_heuristics.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index e5b5fe224cc8..9f172d8acaa3 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1304,7 +1304,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index ac419f61221f..140910d088a2 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -800,7 +800,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( - "spill_threshold", 16 + "spill_threshold", 32 ): log.debug( "Skip config %s because of register spilling: %d", From 5515cea2829c40e1ee7dfc309f95fa9e63015853 Mon Sep 17 00:00:00 2001 From: iupaikov-amd Date: Fri, 22 Aug 2025 18:22:21 +0200 Subject: [PATCH 05/19] added a better reduction config --- torch/_inductor/runtime/triton_heuristics.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 140910d088a2..a8e7db869445 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2408,9 +2408,6 @@ def pointwise( triton_config_with_settings( size_hints, bs // 2, num_elements_per_warp=64 ), - # triton_config_with_settings( - # size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2 - # ), triton_config_with_settings( size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 ), @@ -2561,8 +2558,8 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): # New config result_configs.append(triton_config_reduction( size_hints, - 8192, - 2048, + 1024, + 8, num_warps=4, num_stages=1, waves_per_eu=2 From 4b623331719c28b2233385cb399b2bf46729a022 Mon Sep 17 00:00:00 2001 From: Xiaohu Guo Date: Fri, 22 Aug 2025 19:36:05 +0100 Subject: [PATCH 06/19] enable pipeline --- torch/_inductor/codegen/triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 3bf5e2414494..72fda25781df 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -3599,7 +3599,7 @@ def codegen_body(self): "rsplit_end" if self.cooperative_reduction else f"{prefix}numel" ) self.body.writeline( - f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):" + f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages = 2):" ) with self.body.indent(offset=level + 1): self.iteration_ranges_codegen_header(tree, self.body) From 2d423a39026c4d78554418926271d59e8f45d90a Mon Sep 17 00:00:00 2001 From: Xiaohu Guo Date: Thu, 28 Aug 2025 16:45:00 +0100 Subject: [PATCH 07/19] add one more reduction config to avoid large vgpr spills --- torch/_inductor/runtime/triton_heuristics.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a8e7db869445..3dc49f6275fa 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2564,6 +2564,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): num_stages=1, waves_per_eu=2 )) + result_configs.append(triton_config_reduction( + size_hints, + 512, + 8, + num_warps=4, + num_stages=1, + waves_per_eu=1 + )) return result_configs From 8c5880597c5c2ff7d7017d66136fb70497a4563b Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 3 Sep 2025 16:38:26 +0100 Subject: [PATCH 08/19] Bug fix and optimisation for persistent reduction kernel tuning Original PR had incorrect indentation. Updated PR such that autotune will always add tiny configs, otherwise use the hinted configs only. 
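
In short, the intended selection behaviour is (an illustrative sketch that only
restates the hunk below, not an additional change):

    if max_autotune_enabled:
        # always consider the tiny configs on top of whatever was already picked
        for conf in tiny_configs:
            if conf not in configs:
                configs.append(conf)
    elif reduction_hint == ReductionHint.OUTER_TINY:
        # without max-autotune, OUTER_TINY uses the tiny configs exclusively
        configs = tiny_configs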
--- torch/_inductor/runtime/triton_heuristics.py | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 3dc49f6275fa..d1d47ed44768 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2780,20 +2780,20 @@ def _persistent_reduction_configs( elif reduction_hint == ReductionHint.OUTER: configs = configs[-1:] - if reduction_hint == ReductionHint.OUTER_TINY: - tiny_configs = [ - triton_config_reduction( - size_hints, - 2 * (256 // rnumel) if rnumel <= 256 else 1, - rnumel, - ) - ] - if max_autotune_enabled: - for tconfig in tiny_configs: - if tconfig not in configs: - configs.append(tconfig) - else: - configs = tiny_configs + tiny_configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + + if max_autotune_enabled: + for conf in tiny_configs: + if conf not in configs: + configs.append(conf) + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = tiny_configs for c in configs: # we don't need Rn_BLOCK for persistent reduction From c8f6b0280a8d9439507d1d9d92010ea86c894556 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Fri, 5 Sep 2025 12:28:59 +0100 Subject: [PATCH 09/19] Duplicate config block --- torch/_inductor/runtime/triton_heuristics.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index d1d47ed44768..723b6275304f 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2740,18 +2740,12 @@ def _persistent_reduction_configs( or inductor_meta.get("max_autotune_pointwise") ) - configs = [ - triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) - for xblock in (1, 8, 32, 128) - if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096)) - ] - if "y" not in size_hints: configs = [ triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) for xblock in (1, 8, 32, 128) if xblock == 1 - or (rnumel * xblock <= 4096 and xblock <= xnumel) + or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096)) ] else: configs = [] From 3b996ff706586fb98516bae15d96b556f8a9d366 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Mon, 8 Sep 2025 10:31:20 +0000 Subject: [PATCH 10/19] Fix 1D config being applied to 2D configs --- torch/_inductor/runtime/triton_heuristics.py | 36 +++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 723b6275304f..3aedff0a3f7d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2532,7 +2532,25 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): elif inductor_meta.get("max_autotune") or inductor_meta.get( "max_autotune_pointwise" ): - pass # skip all these cases + # Extra ROCm tuning + if torch.version.hip: + result_configs.append(triton_config_reduction( + size_hints, + 1024, + 8, + num_warps=4, + num_stages=1, + waves_per_eu=2 + )) + result_configs.append(triton_config_reduction( + size_hints, + 512, + 8, + num_warps=4, + num_stages=1, + waves_per_eu=1 + )) + elif reduction_hint == ReductionHint.INNER: 
result_configs = [contiguous_config] elif reduction_hint == ReductionHint.OUTER: @@ -2556,22 +2574,6 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): # Additional reduction configs appended for ROCm builds if torch.version.hip: # New config - result_configs.append(triton_config_reduction( - size_hints, - 1024, - 8, - num_warps=4, - num_stages=1, - waves_per_eu=2 - )) - result_configs.append(triton_config_reduction( - size_hints, - 512, - 8, - num_warps=4, - num_stages=1, - waves_per_eu=1 - )) return result_configs From a6e5fe6dbdf9c771ed28765f150be02a2b4fdf65 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Mon, 8 Sep 2025 10:34:18 +0000 Subject: [PATCH 11/19] Remove outdated condition --- torch/_inductor/runtime/triton_heuristics.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 3aedff0a3f7d..8024c07320ef 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2571,10 +2571,6 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): make_config(64, 4, num_warps=8), ] - # Additional reduction configs appended for ROCm builds - if torch.version.hip: - # New config - return result_configs From 7c7ad787780095cf41d5cc4eb41fc5eba582e0d6 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Mon, 8 Sep 2025 17:12:02 +0100 Subject: [PATCH 12/19] Initialise result configs to fix bug --- torch/_inductor/runtime/triton_heuristics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 8024c07320ef..a28ca8593718 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2532,6 +2532,9 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): elif inductor_meta.get("max_autotune") or inductor_meta.get( "max_autotune_pointwise" ): + + result_configs = [] + # Extra ROCm tuning if torch.version.hip: result_configs.append(triton_config_reduction( From d81b7e9b1e59110ab41a4217bfcd3ba985a0228b Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Thu, 11 Sep 2025 10:29:26 +0100 Subject: [PATCH 13/19] [Perf branch] Add support for 2d reduction and bug fix (#2629) --- torch/_inductor/runtime/triton_heuristics.py | 93 +++++++++----------- 1 file changed, 44 insertions(+), 49 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a28ca8593718..d4df02efc9c0 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2282,7 +2282,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]: def triton_config_tiled_reduction( - size_hints, x, y, r, num_stages=1, register_intensive=False + size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None ): """ Construct a tile reduction triton config with some adjustment @@ -2319,7 +2319,11 @@ def total_numel() -> int: ) check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) check_max_block(cfg) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + if torch.version.hip: + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + return config def 
_maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]): @@ -2469,6 +2473,9 @@ def _reduction_configs( # Convert reductions to 1D, to simplify heuristics. rnumel = get_total_reduction_numel(size_hints) + # Is max autotune enabled + max_autotune = inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + register_intensive = False MAX_R0_BLOCK = 2048 if ( @@ -2491,7 +2498,7 @@ def _reduction_configs( MAX_R0_BLOCK = 1024 register_intensive = True - def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): + def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None): # For 3D case with tiling scores, create an adapted version if "y" in size_hints: assert "tiling_scores" in inductor_meta @@ -2503,6 +2510,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): num_warps=num_warps, num_stages=num_stages, register_intensive=register_intensive, + waves_per_eu=waves_per_eu ) else: # For other cases, use the original function @@ -2513,6 +2521,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): num_warps=num_warps, num_stages=num_stages, register_intensive=register_intensive, + waves_per_eu=waves_per_eu ) contiguous_config = make_config( @@ -2526,54 +2535,38 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): min(rnumel, MAX_R0_BLOCK), register_intensive=register_intensive, ) - # For 3d tiling, default to more autotuning initially - if "y" in size_hints: - pass - elif inductor_meta.get("max_autotune") or inductor_meta.get( - "max_autotune_pointwise" - ): - result_configs = [] - - # Extra ROCm tuning - if torch.version.hip: - result_configs.append(triton_config_reduction( - size_hints, - 1024, - 8, - num_warps=4, - num_stages=1, - waves_per_eu=2 - )) - result_configs.append(triton_config_reduction( - size_hints, - 512, - 8, - num_warps=4, - num_stages=1, - waves_per_eu=1 - )) - - elif reduction_hint == ReductionHint.INNER: - result_configs = [contiguous_config] - elif reduction_hint == ReductionHint.OUTER: - result_configs = [outer_config] - elif reduction_hint == ReductionHint.OUTER_TINY: - result_configs = [tiny_config] - if disable_pointwise_autotuning(inductor_meta): - result_configs = [make_config(32, 128)] - result_configs = [ - contiguous_config, - outer_config, - tiny_config, - make_config(64, 64), - make_config(8, 512), - # halve the XBLOCK/Rn_BLOCK compared to outer_config - # TODO: this may only be beneficial when each iteration of the reduction - # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 - make_config(64, 4, num_warps=8), - ] + result_configs = [] + + if not (max_autotune or "y" in size_hints): + if reduction_hint == ReductionHint.INNER: + result_configs = [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + result_configs = [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + result_configs = [tiny_config] + else: + result_configs = [make_config(32, 128)] + else: + result_configs = [ + contiguous_config, + outer_config, + tiny_config, + make_config(64, 64), + make_config(8, 512), + # halve the XBLOCK/Rn_BLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. 
https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + make_config(64, 4, num_warps=8), + ] + # Add ROCm-specific configs when autotuning + if torch.version.hip: + result_configs.extend([ + make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), + make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1) + ]) + return result_configs @@ -2632,6 +2625,7 @@ def adapt_config_for_tiling( num_stages=1, register_intensive=False, persistent_reduction=False, + waves_per_eu=None ) -> Config: """ Create an adapted configuration based on tiling scores, @@ -2650,6 +2644,7 @@ def adapt_config_for_tiling( block_sizes["r0_"], num_stages=num_stages, register_intensive=register_intensive, + waves_per_eu=waves_per_eu ) From 4377d1fa5282de878eddbdea969de0809352ef54 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Thu, 11 Sep 2025 11:15:45 +0100 Subject: [PATCH 14/19] Update triton_heuristics.py --- torch/_inductor/runtime/triton_heuristics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index d4df02efc9c0..3a39bc677a1c 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2741,7 +2741,7 @@ def _persistent_reduction_configs( triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) for xblock in (1, 8, 32, 128) if xblock == 1 - or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096)) + or (xblock <= xnumel and rnumel * xblock <= 4096) ] else: configs = [] From 247c218bfdb96eaf68c88d767fca8213259c6645 Mon Sep 17 00:00:00 2001 From: Sampsa Riikonen Date: Fri, 12 Sep 2025 15:54:24 +0300 Subject: [PATCH 15/19] pointwise autotuning returnz (#2636) removed the (erroneous?) 
check that disables autotuning for pointwise kernels --- torch/_inductor/runtime/triton_heuristics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 3a39bc677a1c..80efb609a13d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2419,7 +2419,7 @@ def pointwise( ] if len(size_hints) == 2: if ( - disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE + disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE ) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") From 7a77bc4142d254ffaa2a2a64153b949417e94d4d Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:56:01 +0100 Subject: [PATCH 16/19] Expand persistent reduction tuning space --- torch/_inductor/runtime/triton_heuristics.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 80efb609a13d..440aecf5163e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2739,7 +2739,7 @@ def _persistent_reduction_configs( if "y" not in size_hints: configs = [ triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) - for xblock in (1, 8, 32, 128) + for xblock in (1, 4, 8, 16, 32, 64, 128, 256, 512) if xblock == 1 or (xblock <= xnumel and rnumel * xblock <= 4096) ] @@ -2747,7 +2747,7 @@ def _persistent_reduction_configs( configs = [] assert "tiling_scores" in inductor_meta x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} - for target_block_size in (1, 8, 32, 64, 128): + for target_block_size in (1, 4, 8, 16, 32, 64, 128, 256, 512): if target_block_size * rnumel > 4096: continue @@ -2782,6 +2782,22 @@ def _persistent_reduction_configs( for conf in tiny_configs: if conf not in configs: configs.append(conf) + + # Expand configs to try additional warps + expanded_configs = [] + for conf in configs: + num_warps = conf.num_warps + max_warps = 8 if torch.version.hip else 16 + small_conf = copy.deepcopy(conf) + large_conf = copy.deepcopy(conf) + small_conf.num_warps = max(small_conf.num_warps // 2, 1) + large_conf.num_warps = min(large_conf.num_warps * 2, max_warps) + expanded_configs.append(conf) + expanded_configs.append(small_conf) + expanded_configs.append(large_conf) + + configs = expanded_configs + elif reduction_hint == ReductionHint.OUTER_TINY: configs = tiny_configs From 0e2ed1d7267c12d384dca1ec86db4abd9cd959d0 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:48:02 +0100 Subject: [PATCH 17/19] Fix bug to renable reduction tuning --- torch/_inductor/runtime/triton_heuristics.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 440aecf5163e..e547499bb21f 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2474,7 +2474,10 @@ def _reduction_configs( rnumel = get_total_reduction_numel(size_hints) # Is max autotune enabled - max_autotune = inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + max_autotune_enabled = not 
disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
 
     register_intensive = False
     MAX_R0_BLOCK = 2048

From fbb2d857b585a2e1810c31be75a29f438ac7d2fc Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Wed, 17 Sep 2025 11:29:51 +0100
Subject: [PATCH 18/19] Fix bug to re-enable reduction tuning (PART 2)

---
 torch/_inductor/runtime/triton_heuristics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index e547499bb21f..6d898a4dbf45 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2541,7 +2541,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, wa
 
     result_configs = []
 
-    if not (max_autotune or "y" in size_hints):
+    if not (max_autotune_enabled or "y" in size_hints):
         if reduction_hint == ReductionHint.INNER:
             result_configs = [contiguous_config]

From a7bac0ac90970ac28175dc76ded5bd937a3a2606 Mon Sep 17 00:00:00 2001
From: "Nichols A. Romero" <165712832+naromero77amd@users.noreply.github.com>
Date: Wed, 17 Sep 2025 15:30:17 -0500
Subject: [PATCH 19/19] [ROCm][inductor] Additional pointwise tunings (#2642)

This config improves the performance of a 1D pointwise kernel by 20% as measured on MI350.
---
 torch/_inductor/runtime/triton_heuristics.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 6d898a4dbf45..f50fe20256c6 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2417,6 +2417,15 @@ def pointwise(
         ),
         *hinted_configs,
     ]
+    # Additional pointwise config appended for ROCm builds
+    if torch.version.hip:
+        configs.append(triton_config_with_settings(
+            size_hints,
+            2048,
+            num_warps=8,
+            num_stages=2,
+            waves_per_eu=1
+        ))  # 20% improvement
     if len(size_hints) == 2:
         if (
             disable_pointwise_autotuning(inductor_meta)  # or tile_hint == TileHint.SQUARE
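
A note on the recurring ROCm-specific pattern in this series: waves_per_eu and
matrix_instr_nonkdim are attached to the generated configs through Config.kwargs
after construction (rather than as named Config arguments), and only on HIP
builds. Below is a minimal, self-contained sketch of that pattern, assuming a
hypothetical helper name; it is illustrative only and not part of the patches
themselves (the real triton_config() also applies block-size heuristics).

    import torch
    import triton

    def _make_rocm_pointwise_config(xblock, num_warps=4, num_stages=1,
                                    waves_per_eu=None, matrix_instr=None):
        # Build a base config, then attach the ROCm-only compiler hints as
        # extra kernel kwargs, mirroring the patched triton_config() /
        # triton_config_reduction() helpers.
        cfg = triton.Config({"XBLOCK": xblock}, num_warps=num_warps, num_stages=num_stages)
        if torch.version.hip:
            if waves_per_eu is not None:
                cfg.kwargs["waves_per_eu"] = waves_per_eu
            if matrix_instr is not None:
                cfg.kwargs["matrix_instr_nonkdim"] = matrix_instr
        return cfg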