Skip to content

Commit 38ad0c3

Browse files
authored
Merge pull request #85 from Jorghi12/transpile_device_math
Automatically handle transpilations inside device code only
2 parents 347b277 + 7bd98d2 commit 38ad0c3

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@
1212
supported in ROCm/HIP yet.
1313
"""
1414

15+
# Math functions that should be replaced inside device code only.
# Maps each std::-qualified function name to its replacement spelling.
# NOTE: unlike the tuple-valued mapping tables below, every value here is a
# plain string, so it can be passed directly as replacement text.
MATH_TRANSPILATIONS = {
    "std::max": "::max",
    "std::min": "::min",
    "std::ceil": "::ceil",
    "std::floor": "::floor",
    "std::exp": "::exp",
    "std::log": "::log",
    "std::pow": "::pow",
    "std::fabs": "::fabs",
    "std::fmod": "::fmod",
    "std::remainder": "::remainder",
}
28+
29+
1530
CUDA_TYPE_NAME_MAP = {
1631
"CUresult": ("hipError_t", CONV_TYPE, API_DRIVER),
1732
"cudaError_t": ("hipError_t", CONV_TYPE, API_RUNTIME),
@@ -2138,5 +2153,5 @@
21382153
"hipblasGetStream": ("rocblas_get_stream", API_CAFFE2),
21392154
}
21402155

2141-
CUDA_TO_HIP_MAPPINGS = [CUDA_IDENTIFIER_MAP, CUDA_TYPE_NAME_MAP,
2156+
CUDA_TO_HIP_MAPPINGS = [CUDA_IDENTIFIER_MAP, CUDA_TYPE_NAME_MAP,
21422157
CUDA_INCLUDE_MAP, CUDA_SPARSE_MAP, PYTORCH_SPECIFIC_MAPPINGS, CAFFE2_SPECIFIC_MAPPINGS]

tools/amd_build/pyHIPIFY/hipify-python.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from functools import reduce
3838
from enum import Enum
3939
from cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
40+
from cuda_to_hip_mappings import MATH_TRANSPILATIONS
4041

4142
# Hardcode the PyTorch template map
4243
"""This dictionary provides the mapping from PyTorch kernel template types
@@ -426,21 +427,27 @@ def find_kernel_bounds(string):
426427
return output_string
427428

428429

429-
def find_parenthesis_end(input_string, start):
430+
def find_closure_group(input_string, start, group):
431+
"""Generalization for finding a balancing closure group
432+
433+
e.g. if group = ["(", ")"], then finds the first balanced parentheses.
434+
if group = ["{", "}"], then finds the first balanced bracket.
435+
"""
436+
430437
inside_parenthesis = False
431438
parens = 0
432439
pos = start
433440
p_start, p_end = -1, -1
434441

435442
while pos < len(input_string):
436-
if input_string[pos] == "(":
443+
if input_string[pos] == group[0]:
437444
if inside_parenthesis is False:
438445
inside_parenthesis = True
439446
parens = 1
440447
p_start = pos
441448
else:
442449
parens += 1
443-
elif input_string[pos] == ")" and inside_parenthesis:
450+
elif input_string[pos] == group[1] and inside_parenthesis:
444451
parens -= 1
445452

446453
if parens == 0:
@@ -451,14 +458,24 @@ def find_parenthesis_end(input_string, start):
451458
return None, None
452459

453460

461+
def find_bracket_group(input_string, start):
    """Find the first balanced curly-brace group, scanning from `start`.

    Thin wrapper around find_closure_group with group=["{", "}"].

    Args:
        input_string: Text to scan.
        start: Index in `input_string` at which scanning begins.

    Returns:
        Whatever find_closure_group returns: a two-element result that
        callers unpack as (group_start, group_end).
    """
    return find_closure_group(input_string, start, group=["{", "}"])
464+
465+
466+
def find_parentheses_group(input_string, start):
    """Find the first balanced parentheses group, scanning from `start`.

    Thin wrapper around find_closure_group with group=["(", ")"].

    Args:
        input_string: Text to scan.
        start: Index in `input_string` at which scanning begins.

    Returns:
        Whatever find_closure_group returns: a two-element result that
        callers unpack as (group_start, group_end).
    """
    return find_closure_group(input_string, start, group=["(", ")"])
469+
470+
454471
def disable_asserts(input_string):
    """Remove regular assert statements from the source text.

    Every occurrence of ``assert(`` (optionally with spaces before the
    parenthesis) is located, and the text from the ``assert`` keyword through
    the end index reported by find_parentheses_group is deleted,
    e.g. "assert(....)" -> "".

    Args:
        input_string: Source text to process.

    Returns:
        The source text with the matched assert calls removed.
    """
    output_string = input_string
    for assert_item in re.finditer(r"\bassert[ ]*\(", input_string):
        # Start scanning at the opening parenthesis matched by the regex.
        _, p_end = find_parentheses_group(input_string, assert_item.end() - 1)
        if p_end is None:
            # No balanced group was found; leave this occurrence untouched
            # instead of failing on `None + 1`.
            continue
        start = assert_item.start()
        # Positions always refer to the untouched input; the removal itself is
        # done by text so earlier deletions do not shift later ones.
        output_string = output_string.replace(input_string[start:p_end + 1], "")
    return output_string
@@ -712,6 +729,9 @@ def preprocessor(filepath, stats, hipify_caffe2):
712729
# Replace std:: with non-std:: versions
713730
output_source = replace_math_functions(output_source)
714731

732+
# Replace std:: math functions with non-std:: versions inside device code only
733+
output_source = transpile_device_math(output_source)
734+
715735
# Replace __forceinline__ with inline
716736
output_source = replace_forceinline(output_source)
717737

@@ -888,6 +908,31 @@ def disable_module(input_file):
888908
f.truncate()
889909

890910

911+
def transpile_device_math(input_string):
    """Replace std:: math functions with non-std:: versions inside device code.

    Each ``__global__``/``__device__`` ``void`` function header matched by the
    kernel regex is located; the text from that header through the end of the
    first balanced ``{...}`` group after it is selected, and every
    MATH_TRANSPILATIONS key found in the selection is replaced by its mapped
    value.

    Only whole identifiers are replaced, so longer names that merely start
    with a key (e.g. ``std::max_element``, ``std::minmax``) are left alone.

    Args:
        input_string: Source text to process.

    Returns:
        The source text with the replacements applied.
    """
    if not MATH_TRANSPILATIONS:
        return input_string

    # One pattern for all functions. The lookbehind and the trailing \b make
    # sure a key is matched as a complete qualified name, not as a substring
    # of a longer identifier; longest keys come first in the alternation.
    names = sorted(MATH_TRANSPILATIONS, key=len, reverse=True)
    math_pattern = re.compile(
        r"(?<![\w:])(?:" + "|".join(re.escape(name) for name in names) + r")\b")

    # Extract device code positions
    kernel_pattern = r"(template[ ]*<(.*)>\n.*\n?)?(__global__|__device__) void[\n| ](\w+(\(.*\))?)\("

    # Prepare output
    output_string = input_string

    # Iterate through each kernel definition
    for kernel in re.finditer(kernel_pattern, input_string):
        # Find the end of the first balanced brace group after the header,
        # scanning from the opening parenthesis matched by the regex.
        _, body_end = find_bracket_group(input_string, kernel.end() - 1)
        if body_end is None:
            # No balanced brace group follows (e.g. nothing but a
            # declaration); skip instead of failing on `None + 1`.
            continue

        # Replace all std:: math functions within range [start...ending]
        selection = input_string[kernel.start():body_end + 1]
        selection_transpiled = math_pattern.sub(
            lambda match: MATH_TRANSPILATIONS[match.group(0)], selection)

        # Perform replacements inside the output_string
        if selection_transpiled != selection:
            output_string = output_string.replace(selection, selection_transpiled)

    return output_string
934+
935+
891936
def extract_arguments(start, string):
892937
""" Return the list of arguments in the upcoming function parameter closure.
893938
Example:
@@ -1212,7 +1257,8 @@ def main():
12121257
f.truncate()
12131258

12141259
all_files = list(matched_files_iter(args.output_directory, includes=args.includes,
1215-
ignores=args.ignores, extensions=args.extensions, hipify_caffe2=args.hipify_caffe2))
1260+
ignores=args.ignores, extensions=args.extensions,
1261+
hipify_caffe2=args.hipify_caffe2))
12161262

12171263
# Start Preprocessor
12181264
preprocess(

0 commit comments

Comments
 (0)