37
37
from functools import reduce
38
38
from enum import Enum
39
39
from cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
40
+ from cuda_to_hip_mappings import MATH_TRANSPILATIONS
40
41
41
42
# Hardcode the PyTorch template map
42
43
"""This dictionary provides the mapping from PyTorch kernel template types
@@ -426,21 +427,27 @@ def find_kernel_bounds(string):
426
427
return output_string
427
428
428
429
429
- def find_parenthesis_end (input_string , start ):
430
+ def find_closure_group (input_string , start , group ):
431
+ """Generalization for finding a balancing closure group
432
+
433
+ e.g. if group = ["(", ")"], then finds the first balanced parentheses.
434
+ if group = ["{", "}"], then finds the first balanced curly braces.
435
+ """
436
+
430
437
inside_parenthesis = False
431
438
parens = 0
432
439
pos = start
433
440
p_start , p_end = - 1 , - 1
434
441
435
442
while pos < len (input_string ):
436
- if input_string [pos ] == "(" :
443
+ if input_string [pos ] == group [ 0 ] :
437
444
if inside_parenthesis is False :
438
445
inside_parenthesis = True
439
446
parens = 1
440
447
p_start = pos
441
448
else :
442
449
parens += 1
443
- elif input_string [pos ] == ")" and inside_parenthesis :
450
+ elif input_string [pos ] == group [ 1 ] and inside_parenthesis :
444
451
parens -= 1
445
452
446
453
if parens == 0 :
@@ -451,14 +458,24 @@ def find_parenthesis_end(input_string, start):
451
458
return None , None
452
459
453
460
461
def find_bracket_group(input_string, start):
    """Find the first balanced ``{ ... }`` group scanning from ``start``.

    Thin wrapper: delegates to ``find_closure_group`` with curly braces as the
    opening/closing delimiters and returns its result unchanged.
    """
    curly_braces = ["{", "}"]
    return find_closure_group(input_string, start, group=curly_braces)
464
+
465
+
466
def find_parentheses_group(input_string, start):
    """Find the first balanced ``( ... )`` group scanning from ``start``.

    Thin wrapper: delegates to ``find_closure_group`` with round parentheses
    as the opening/closing delimiters and returns its result unchanged.
    """
    round_parens = ["(", ")"]
    return find_closure_group(input_string, start, group=round_parens)
469
+
470
+
454
471
def disable_asserts(input_string):
    """Strip regular ``assert(...)`` calls out of the source text.

    Every occurrence of the word ``assert`` (optionally followed by spaces)
    together with its balanced, parenthesised argument list is replaced by an
    empty string, e.g. ``"assert(x > 0);"`` becomes ``";"``. Identifiers that
    merely end in ``assert`` after a word character (such as
    ``static_assert``) do not match the ``\\bassert`` pattern.

    Args:
        input_string: Source text to scan.

    Returns:
        A copy of ``input_string`` with each matched assert call removed. An
        ``assert(`` whose parentheses never balance is left untouched.
    """
    output_string = input_string
    for assert_item in re.finditer(r"\bassert[ ]*\(", input_string):
        # The regex consumes the opening "(", so end() - 1 is its index.
        _, p_end = find_parentheses_group(input_string, assert_item.end() - 1)
        if p_end is None:
            # No balanced closing parenthesis was found; slicing with
            # ``p_end + 1`` would raise TypeError, so skip this match.
            continue
        assert_call = input_string[assert_item.start():p_end + 1]
        output_string = output_string.replace(assert_call, "")
    return output_string
@@ -712,6 +729,9 @@ def preprocessor(filepath, stats, hipify_caffe2):
712
729
# Replace std:: with non-std:: versions
713
730
output_source = replace_math_functions (output_source )
714
731
732
+ # Apply MATH_TRANSPILATIONS inside __global__/__device__ void function definitions
733
+ output_source = transpile_device_math (output_source )
734
+
715
735
# Replace __forceinline__ with inline
716
736
output_source = replace_forceinline (output_source )
717
737
@@ -888,6 +908,31 @@ def disable_module(input_file):
888
908
f .truncate ()
889
909
890
910
911
def transpile_device_math(input_string):
    """Apply ``MATH_TRANSPILATIONS`` inside device function definitions.

    Device code is located with a regex that matches ``__global__ void`` and
    ``__device__ void`` function headers, optionally preceded by a
    ``template<...>`` line. For each match, the text from the start of the
    match through the end of the first balanced ``{ ... }`` group that follows
    is rewritten by replacing every key of ``MATH_TRANSPILATIONS`` with its
    mapped value.

    Args:
        input_string: Source text to process.

    Returns:
        A copy of ``input_string`` with the replacements applied to the
        matched regions. Matches for which no balanced brace group can be
        found are skipped.
    """
    # Same pattern as before, only split across lines for readability.
    # NOTE(review): ``[\n| ]`` is a character class, so it also accepts a
    # literal "|" between ``void`` and the function name - confirm intent.
    kernel_pattern = (
        r"(template[ ]*<(.*)>\n.*\n?)?"
        r"(__global__|__device__) void[\n| ]"
        r"(\w+(\(.*\))?)\("
    )

    output_string = input_string

    for kernel in re.finditer(kernel_pattern, input_string):
        # The pattern ends on the "(" opening the parameter list; the first
        # balanced brace group found from there is taken as the function body.
        _, body_end = find_bracket_group(input_string, kernel.end() - 1)
        if body_end is None:
            # No balanced brace group was found; ``body_end + 1`` would raise
            # TypeError, so leave this match untouched.
            continue

        selection = input_string[kernel.start():body_end + 1]
        selection_transpiled = selection
        for func, replacement in MATH_TRANSPILATIONS.items():
            selection_transpiled = selection_transpiled.replace(func, replacement)

        # Swap the original region for its transpiled version in the output.
        output_string = output_string.replace(selection, selection_transpiled)

    return output_string
934
+
935
+
891
936
def extract_arguments (start , string ):
892
937
""" Return the list of arguments in the upcoming function parameter closure.
893
938
Example:
@@ -1212,7 +1257,8 @@ def main():
1212
1257
f .truncate ()
1213
1258
1214
1259
all_files = list (matched_files_iter (args .output_directory , includes = args .includes ,
1215
- ignores = args .ignores , extensions = args .extensions , hipify_caffe2 = args .hipify_caffe2 ))
1260
+ ignores = args .ignores , extensions = args .extensions ,
1261
+ hipify_caffe2 = args .hipify_caffe2 ))
1216
1262
1217
1263
# Start Preprocessor
1218
1264
preprocess (
0 commit comments