Skip to content

Commit c591c89

Browse files
authored
[mypyc] Implement lowering pass and add primitives for int (in)equality (#17027)
Add a new `PrimitiveOp` op which can be transformed into lower-level ops in a lowering pass after reference counting op insertion pass. Higher-level ops in IR make it easier to implement various optimizations, and the output of irbuild test cases will be more compact and readable. Implement the lowering pass. Currently it's pretty minimal, and I will add additional primitives and the direct transformation of various primitives to `CallC` ops in follow-up PRs. Currently primitives that map to C calls generate `CallC` ops in the main irbuild pass, but the long-term goal is to only/mostly generate `PrimitiveOp`s instead of `CallC` ops during the main irbuild pass. Also implement primitives for tagged integer equality and inequality as examples. Lowering of primitives is implemented using decorated handler functions in `mypyc.lower` that are found based on the name of the primitive. The name has no other significance, though it's also used in pretty-printed IR output. Work on mypyc/mypyc#854. The issue describes the motivation in more detail.
1 parent 31dc503 commit c591c89

32 files changed

+772
-483
lines changed

mypyc/analysis/dataflow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
MethodCall,
3939
Op,
4040
OpVisitor,
41+
PrimitiveOp,
4142
RaiseStandardError,
4243
RegisterOp,
4344
Return,
@@ -234,6 +235,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]:
234235
def visit_call_c(self, op: CallC) -> GenAndKill[T]:
235236
return self.visit_register_op(op)
236237

238+
def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill[T]:
239+
return self.visit_register_op(op)
240+
237241
def visit_truncate(self, op: Truncate) -> GenAndKill[T]:
238242
return self.visit_register_op(op)
239243

mypyc/analysis/ircheck.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
MethodCall,
3838
Op,
3939
OpVisitor,
40+
PrimitiveOp,
4041
RaiseStandardError,
4142
Register,
4243
Return,
@@ -381,6 +382,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
381382
def visit_call_c(self, op: CallC) -> None:
382383
pass
383384

385+
def visit_primitive_op(self, op: PrimitiveOp) -> None:
386+
pass
387+
384388
def visit_truncate(self, op: Truncate) -> None:
385389
pass
386390

mypyc/analysis/selfleaks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
LoadStatic,
3232
MethodCall,
3333
OpVisitor,
34+
PrimitiveOp,
3435
RaiseStandardError,
3536
Register,
3637
RegisterOp,
@@ -149,6 +150,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
149150
def visit_call_c(self, op: CallC) -> GenAndKill:
150151
return self.check_register_op(op)
151152

153+
def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill:
154+
return self.check_register_op(op)
155+
152156
def visit_truncate(self, op: Truncate) -> GenAndKill:
153157
return CLEAN
154158

mypyc/codegen/emitfunc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
MethodCall,
4848
Op,
4949
OpVisitor,
50+
PrimitiveOp,
5051
RaiseStandardError,
5152
Register,
5253
Return,
@@ -629,6 +630,11 @@ def visit_call_c(self, op: CallC) -> None:
629630
args = ", ".join(self.reg(arg) for arg in op.args)
630631
self.emitter.emit_line(f"{dest}{op.function_name}({args});")
631632

633+
def visit_primitive_op(self, op: PrimitiveOp) -> None:
634+
raise RuntimeError(
635+
f"unexpected PrimitiveOp {op.desc.name}: they must be lowered before codegen"
636+
)
637+
632638
def visit_truncate(self, op: Truncate) -> None:
633639
dest = self.reg(op)
634640
value = self.reg(op.src)

mypyc/codegen/emitmodule.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from mypyc.transform.copy_propagation import do_copy_propagation
6060
from mypyc.transform.exceptions import insert_exception_handling
6161
from mypyc.transform.flag_elimination import do_flag_elimination
62+
from mypyc.transform.lower import lower_ir
6263
from mypyc.transform.refcount import insert_ref_count_opcodes
6364
from mypyc.transform.uninit import insert_uninit_checks
6465

@@ -235,6 +236,8 @@ def compile_scc_to_ir(
235236
insert_exception_handling(fn)
236237
# Insert refcount handling.
237238
insert_ref_count_opcodes(fn)
239+
# Switch to lower abstraction level IR.
240+
lower_ir(fn, compiler_options)
238241
# Perform optimizations.
239242
do_copy_propagation(fn, compiler_options)
240243
do_flag_elimination(fn, compiler_options)
@@ -423,10 +426,11 @@ def compile_modules_to_c(
423426
)
424427

425428
modules = compile_modules_to_ir(result, mapper, compiler_options, errors)
426-
ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options)
429+
if errors.num_errors > 0:
430+
return {}, []
427431

428-
if errors.num_errors == 0:
429-
write_cache(modules, result, group_map, ctext)
432+
ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options)
433+
write_cache(modules, result, group_map, ctext)
430434

431435
return modules, [ctext[name] for _, name in groups]
432436

mypyc/ir/ops.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,78 @@ def accept(self, visitor: OpVisitor[T]) -> T:
576576
return visitor.visit_method_call(self)
577577

578578

579+
class PrimitiveDescription:
580+
"""Description of a primitive op.
581+
582+
Primitives get lowered into lower-level ops before code generation.
583+
584+
If c_function_name is provided, a primitive will be lowered into a CallC op.
585+
Otherwise custom logic will need to be implemented to transform the
586+
primitive into lower-level ops.
587+
"""
588+
589+
def __init__(
590+
self,
591+
name: str,
592+
arg_types: list[RType],
593+
return_type: RType, # TODO: What about generic?
594+
var_arg_type: RType | None,
595+
truncated_type: RType | None,
596+
c_function_name: str | None,
597+
error_kind: int,
598+
steals: StealsDescription,
599+
is_borrowed: bool,
600+
ordering: list[int] | None,
601+
extra_int_constants: list[tuple[int, RType]],
602+
priority: int,
603+
) -> None:
604+
# Each primitive much have a distinct name, but otherwise they are arbitrary.
605+
self.name: Final = name
606+
self.arg_types: Final = arg_types
607+
self.return_type: Final = return_type
608+
self.var_arg_type: Final = var_arg_type
609+
self.truncated_type: Final = truncated_type
610+
# If non-None, this will map to a call of a C helper function; if None,
611+
# there must be a custom handler function that gets invoked during the lowering
612+
# pass to generate low-level IR for the primitive (in the mypyc.lower package)
613+
self.c_function_name: Final = c_function_name
614+
self.error_kind: Final = error_kind
615+
self.steals: Final = steals
616+
self.is_borrowed: Final = is_borrowed
617+
self.ordering: Final = ordering
618+
self.extra_int_constants: Final = extra_int_constants
619+
self.priority: Final = priority
620+
621+
def __repr__(self) -> str:
622+
return f"<PrimitiveDescription {self.name}>"
623+
624+
625+
class PrimitiveOp(RegisterOp):
626+
"""A higher-level primitive operation.
627+
628+
Some of these have special compiler support. These will be lowered
629+
(transformed) into lower-level IR ops before code generation, and after
630+
reference counting op insertion. Others will be transformed into CallC
631+
ops.
632+
633+
Tagged integer equality is a typical primitive op with non-trivial
634+
lowering. It gets transformed into a tag check, followed by different
635+
code paths for short and long representations.
636+
"""
637+
638+
def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None:
639+
self.args = args
640+
self.type = desc.return_type
641+
self.error_kind = desc.error_kind
642+
self.desc = desc
643+
644+
def sources(self) -> list[Value]:
645+
return self.args
646+
647+
def accept(self, visitor: OpVisitor[T]) -> T:
648+
return visitor.visit_primitive_op(self)
649+
650+
579651
class LoadErrorValue(RegisterOp):
580652
"""Load an error value.
581653
@@ -1446,7 +1518,8 @@ class Unborrow(RegisterOp):
14461518

14471519
error_kind = ERR_NEVER
14481520

1449-
def __init__(self, src: Value) -> None:
1521+
def __init__(self, src: Value, line: int = -1) -> None:
1522+
super().__init__(line)
14501523
assert src.is_borrowed
14511524
self.src = src
14521525
self.type = src.type
@@ -1555,6 +1628,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> T:
15551628
def visit_call_c(self, op: CallC) -> T:
15561629
raise NotImplementedError
15571630

1631+
@abstractmethod
1632+
def visit_primitive_op(self, op: PrimitiveOp) -> T:
1633+
raise NotImplementedError
1634+
15581635
@abstractmethod
15591636
def visit_truncate(self, op: Truncate) -> T:
15601637
raise NotImplementedError

mypyc/ir/pprint.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
MethodCall,
4444
Op,
4545
OpVisitor,
46+
PrimitiveOp,
4647
RaiseStandardError,
4748
Register,
4849
Return,
@@ -217,6 +218,22 @@ def visit_call_c(self, op: CallC) -> str:
217218
else:
218219
return self.format("%r = %s(%s)", op, op.function_name, args_str)
219220

221+
def visit_primitive_op(self, op: PrimitiveOp) -> str:
222+
args = []
223+
arg_index = 0
224+
type_arg_index = 0
225+
for arg_type in zip(op.desc.arg_types):
226+
if arg_type:
227+
args.append(self.format("%r", op.args[arg_index]))
228+
arg_index += 1
229+
else:
230+
assert op.type_args
231+
args.append(self.format("%r", op.type_args[type_arg_index]))
232+
type_arg_index += 1
233+
234+
args_str = ", ".join(args)
235+
return self.format("%r = %s %s ", op, op.desc.name, args_str)
236+
220237
def visit_truncate(self, op: Truncate) -> str:
221238
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
222239

mypyc/irbuild/ast_helpers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,12 @@ def maybe_process_conditional_comparison(
9393
self.add_bool_branch(reg, true, false)
9494
else:
9595
# "left op right" for two tagged integers
96-
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
96+
if op in ("==", "!="):
97+
reg = self.builder.binary_op(left, right, op, e.line)
98+
self.flush_keep_alives()
99+
self.add_bool_branch(reg, true, false)
100+
else:
101+
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
97102
return True
98103

99104

mypyc/irbuild/expression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
756756
set_literal = precompute_set_literal(builder, e.operands[1])
757757
if set_literal is not None:
758758
lhs = e.operands[0]
759-
result = builder.builder.call_c(
759+
result = builder.builder.primitive_op(
760760
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
761761
)
762762
if first_op == "not in":
@@ -778,7 +778,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
778778
borrow_left = is_borrow_friendly_expr(builder, right_expr)
779779
left = builder.accept(left_expr, can_borrow=borrow_left)
780780
right = builder.accept(right_expr, can_borrow=True)
781-
return builder.compare_tagged(left, right, first_op, e.line)
781+
return builder.binary_op(left, right, first_op, e.line)
782782

783783
# TODO: Don't produce an expression when used in conditional context
784784
# All of the trickiness here is due to support for chained conditionals

0 commit comments

Comments
 (0)