diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index 57ad2b17fcc5..9babf860fb31 100644 --- a/mypyc/analysis/dataflow.py +++ b/mypyc/analysis/dataflow.py @@ -38,6 +38,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, RegisterOp, Return, @@ -234,6 +235,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]: def visit_call_c(self, op: CallC) -> GenAndKill[T]: return self.visit_register_op(op) + def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill[T]: + return self.visit_register_op(op) + def visit_truncate(self, op: Truncate) -> GenAndKill[T]: return self.visit_register_op(op) diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index 127047e02ff5..88737ac208de 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -37,6 +37,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, Return, @@ -381,6 +382,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None: def visit_call_c(self, op: CallC) -> None: pass + def visit_primitive_op(self, op: PrimitiveOp) -> None: + pass + def visit_truncate(self, op: Truncate) -> None: pass diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py index 80c2bc348bc2..5d89a9bfc7c6 100644 --- a/mypyc/analysis/selfleaks.py +++ b/mypyc/analysis/selfleaks.py @@ -31,6 +31,7 @@ LoadStatic, MethodCall, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, RegisterOp, @@ -149,6 +150,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill: def visit_call_c(self, op: CallC) -> GenAndKill: return self.check_register_op(op) + def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill: + return self.check_register_op(op) + def visit_truncate(self, op: Truncate) -> GenAndKill: return CLEAN diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index c08f1f840fa4..12f57b9cee6f 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -47,6 +47,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, Return, @@ -629,6 +630,11 @@ def visit_call_c(self, op: CallC) -> None: args = ", ".join(self.reg(arg) for arg in op.args) self.emitter.emit_line(f"{dest}{op.function_name}({args});") + def visit_primitive_op(self, op: PrimitiveOp) -> None: + raise RuntimeError( + f"unexpected PrimitiveOp {op.desc.name}: they must be lowered before codegen" + ) + def visit_truncate(self, op: Truncate) -> None: dest = self.reg(op) value = self.reg(op.src) diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 9466bc2cea79..6c8f5ac91335 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -59,6 +59,7 @@ from mypyc.transform.copy_propagation import do_copy_propagation from mypyc.transform.exceptions import insert_exception_handling from mypyc.transform.flag_elimination import do_flag_elimination +from mypyc.transform.lower import lower_ir from mypyc.transform.refcount import insert_ref_count_opcodes from mypyc.transform.uninit import insert_uninit_checks @@ -235,6 +236,8 @@ def compile_scc_to_ir( insert_exception_handling(fn) # Insert refcount handling. insert_ref_count_opcodes(fn) + # Switch to lower abstraction level IR. + lower_ir(fn, compiler_options) # Perform optimizations. do_copy_propagation(fn, compiler_options) do_flag_elimination(fn, compiler_options) @@ -423,10 +426,11 @@ def compile_modules_to_c( ) modules = compile_modules_to_ir(result, mapper, compiler_options, errors) - ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options) + if errors.num_errors > 0: + return {}, [] - if errors.num_errors == 0: - write_cache(modules, result, group_map, ctext) + ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options) + write_cache(modules, result, group_map, ctext) return modules, [ctext[name] for _, name in groups] diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 04c50d1e2841..3acfb0933e5a 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -576,6 +576,78 @@ def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_method_call(self) +class PrimitiveDescription: + """Description of a primitive op. + + Primitives get lowered into lower-level ops before code generation. + + If c_function_name is provided, a primitive will be lowered into a CallC op. + Otherwise custom logic will need to be implemented to transform the + primitive into lower-level ops. + """ + + def __init__( + self, + name: str, + arg_types: list[RType], + return_type: RType, # TODO: What about generic? + var_arg_type: RType | None, + truncated_type: RType | None, + c_function_name: str | None, + error_kind: int, + steals: StealsDescription, + is_borrowed: bool, + ordering: list[int] | None, + extra_int_constants: list[tuple[int, RType]], + priority: int, + ) -> None: + # Each primitive much have a distinct name, but otherwise they are arbitrary. + self.name: Final = name + self.arg_types: Final = arg_types + self.return_type: Final = return_type + self.var_arg_type: Final = var_arg_type + self.truncated_type: Final = truncated_type + # If non-None, this will map to a call of a C helper function; if None, + # there must be a custom handler function that gets invoked during the lowering + # pass to generate low-level IR for the primitive (in the mypyc.lower package) + self.c_function_name: Final = c_function_name + self.error_kind: Final = error_kind + self.steals: Final = steals + self.is_borrowed: Final = is_borrowed + self.ordering: Final = ordering + self.extra_int_constants: Final = extra_int_constants + self.priority: Final = priority + + def __repr__(self) -> str: + return f"" + + +class PrimitiveOp(RegisterOp): + """A higher-level primitive operation. + + Some of these have special compiler support. These will be lowered + (transformed) into lower-level IR ops before code generation, and after + reference counting op insertion. Others will be transformed into CallC + ops. + + Tagged integer equality is a typical primitive op with non-trivial + lowering. It gets transformed into a tag check, followed by different + code paths for short and long representations. + """ + + def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None: + self.args = args + self.type = desc.return_type + self.error_kind = desc.error_kind + self.desc = desc + + def sources(self) -> list[Value]: + return self.args + + def accept(self, visitor: OpVisitor[T]) -> T: + return visitor.visit_primitive_op(self) + + class LoadErrorValue(RegisterOp): """Load an error value. @@ -1446,7 +1518,8 @@ class Unborrow(RegisterOp): error_kind = ERR_NEVER - def __init__(self, src: Value) -> None: + def __init__(self, src: Value, line: int = -1) -> None: + super().__init__(line) assert src.is_borrowed self.src = src self.type = src.type @@ -1555,6 +1628,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> T: def visit_call_c(self, op: CallC) -> T: raise NotImplementedError + @abstractmethod + def visit_primitive_op(self, op: PrimitiveOp) -> T: + raise NotImplementedError + @abstractmethod def visit_truncate(self, op: Truncate) -> T: raise NotImplementedError diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index 5578049256f1..8d6723917ea0 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -43,6 +43,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, Return, @@ -217,6 +218,22 @@ def visit_call_c(self, op: CallC) -> str: else: return self.format("%r = %s(%s)", op, op.function_name, args_str) + def visit_primitive_op(self, op: PrimitiveOp) -> str: + args = [] + arg_index = 0 + type_arg_index = 0 + for arg_type in zip(op.desc.arg_types): + if arg_type: + args.append(self.format("%r", op.args[arg_index])) + arg_index += 1 + else: + assert op.type_args + args.append(self.format("%r", op.type_args[type_arg_index])) + type_arg_index += 1 + + args_str = ", ".join(args) + return self.format("%r = %s %s ", op, op.desc.name, args_str) + def visit_truncate(self, op: Truncate) -> str: return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type) diff --git a/mypyc/irbuild/ast_helpers.py b/mypyc/irbuild/ast_helpers.py index 1af1ad611a89..8490eaa03477 100644 --- a/mypyc/irbuild/ast_helpers.py +++ b/mypyc/irbuild/ast_helpers.py @@ -93,7 +93,12 @@ def maybe_process_conditional_comparison( self.add_bool_branch(reg, true, false) else: # "left op right" for two tagged integers - self.builder.compare_tagged_condition(left, right, op, true, false, e.line) + if op in ("==", "!="): + reg = self.builder.binary_op(left, right, op, e.line) + self.flush_keep_alives() + self.add_bool_branch(reg, true, false) + else: + self.builder.compare_tagged_condition(left, right, op, true, false, e.line) return True diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 81e37953809f..021b7a1dbe90 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -756,7 +756,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: set_literal = precompute_set_literal(builder, e.operands[1]) if set_literal is not None: lhs = e.operands[0] - result = builder.builder.call_c( + result = builder.builder.primitive_op( set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive ) if first_op == "not in": @@ -778,7 +778,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: borrow_left = is_borrow_friendly_expr(builder, right_expr) left = builder.accept(left_expr, can_borrow=borrow_left) right = builder.accept(right_expr, can_borrow=True) - return builder.compare_tagged(left, right, first_op, e.line) + return builder.binary_op(left, right, first_op, e.line) # TODO: Don't produce an expression when used in conditional context # All of the trickiness here is due to support for chained conditionals diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 45c06e11befd..f9bacb43bc3e 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -63,6 +63,8 @@ LoadStatic, MethodCall, Op, + PrimitiveDescription, + PrimitiveOp, RaiseStandardError, Register, SetMem, @@ -1313,7 +1315,12 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: return self.compare_strings(lreg, rreg, op, line) if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="): return self.compare_bytes(lreg, rreg, op, line) - if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping: + if ( + is_tagged(ltype) + and is_tagged(rtype) + and op in int_comparison_op_mapping + and op not in ("==", "!=") + ): return self.compare_tagged(lreg, rreg, op, line) if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS: if op in ComparisonOp.signed_ops: @@ -1379,13 +1386,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: # Mixed int comparisons if op in ("==", "!="): - op_id = ComparisonOp.signed_ops[op] - if is_tagged(ltype) and is_subtype(rtype, ltype): - rreg = self.coerce(rreg, int_rprimitive, line) - return self.comparison_op(lreg, rreg, op_id, line) - if is_tagged(rtype) and is_subtype(ltype, rtype): - lreg = self.coerce(lreg, int_rprimitive, line) - return self.comparison_op(lreg, rreg, op_id, line) + pass # TODO: Do we need anything here? elif op in op in int_comparison_op_mapping: if is_tagged(ltype) and is_subtype(rtype, ltype): rreg = self.coerce(rreg, short_int_rprimitive, line) @@ -1412,8 +1413,8 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: if base_op in float_op_to_id: return self.float_op(lreg, rreg, base_op, line) - call_c_ops_candidates = binary_ops.get(op, []) - target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line) + primitive_ops_candidates = binary_ops.get(op, []) + target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line) assert target, "Unsupported binary operation: %s" % op return target @@ -1432,7 +1433,14 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) - def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: """Compare two tagged integers using given operator (value context).""" # generate fast binary logic ops on short ints - if is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type): + if (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type)) and op in ( + "==", + "!=", + ): + quick = True + else: + quick = is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type) + if quick: return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line) op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op] result = Register(bool_rprimitive) @@ -1986,6 +1994,102 @@ def matching_call_c( return target return None + def primitive_op( + self, + desc: PrimitiveDescription, + args: list[Value], + line: int, + result_type: RType | None = None, + ) -> Value: + """Add a primitive op.""" + # Does this primitive map into calling a Python C API + # or an internal mypyc C API function? + if desc.c_function_name: + # TODO: Generate PrimitiOps here and transform them into CallC + # ops only later in the lowering pass + c_desc = CFunctionDescription( + desc.name, + desc.arg_types, + desc.return_type, + desc.var_arg_type, + desc.truncated_type, + desc.c_function_name, + desc.error_kind, + desc.steals, + desc.is_borrowed, + desc.ordering, + desc.extra_int_constants, + desc.priority, + ) + return self.call_c(c_desc, args, line, result_type) + + # This primitve gets transformed in a lowering pass to + # lower-level IR ops using a custom transform function. + + coerced = [] + # Coerce fixed number arguments + for i in range(min(len(args), len(desc.arg_types))): + formal_type = desc.arg_types[i] + arg = args[i] + assert formal_type is not None # TODO + arg = self.coerce(arg, formal_type, line) + coerced.append(arg) + assert desc.ordering is None + assert desc.var_arg_type is None + assert not desc.extra_int_constants + target = self.add(PrimitiveOp(coerced, desc, line=line)) + if desc.is_borrowed: + # If the result is borrowed, force the arguments to be + # kept alive afterwards, as otherwise the result might be + # immediately freed, at the risk of a dangling pointer. + for arg in coerced: + if not isinstance(arg, (Integer, LoadLiteral)): + self.keep_alives.append(arg) + if desc.error_kind == ERR_NEG_INT: + comp = ComparisonOp(target, Integer(0, desc.return_type, line), ComparisonOp.SGE, line) + comp.error_kind = ERR_FALSE + self.add(comp) + + assert desc.truncated_type is None + result = target + if result_type and not is_runtime_subtype(result.type, result_type): + if is_none_rprimitive(result_type): + # Special case None return. The actual result may actually be a bool + # and so we can't just coerce it. + result = self.none() + else: + result = self.coerce(result, result_type, line, can_borrow=desc.is_borrowed) + return result + + def matching_primitive_op( + self, + candidates: list[PrimitiveDescription], + args: list[Value], + line: int, + result_type: RType | None = None, + can_borrow: bool = False, + ) -> Value | None: + matching: PrimitiveDescription | None = None + for desc in candidates: + if len(desc.arg_types) != len(args): + continue + if all( + # formal is not None and # TODO + is_subtype(actual.type, formal) + for actual, formal in zip(args, desc.arg_types) + ) and (not desc.is_borrowed or can_borrow): + if matching: + assert matching.priority != desc.priority, "Ambiguous:\n1) {}\n2) {}".format( + matching, desc + ) + if desc.priority > matching.priority: + matching = desc + else: + matching = desc + if matching: + return self.primitive_op(matching, args, line=line) + return None + def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) -> Value: """Generate a native integer binary op. diff --git a/mypyc/lower/__init__.py b/mypyc/lower/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/mypyc/lower/int_ops.py b/mypyc/lower/int_ops.py new file mode 100644 index 000000000000..40fba7af4f4d --- /dev/null +++ b/mypyc/lower/int_ops.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from mypyc.ir.ops import Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.lower.registry import lower_binary_op + + +@lower_binary_op("int_eq") +def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value: + return builder.compare_tagged(args[0], args[1], "==", line) + + +@lower_binary_op("int_ne") +def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value: + return builder.compare_tagged(args[0], args[1], "!=", line) diff --git a/mypyc/lower/registry.py b/mypyc/lower/registry.py new file mode 100644 index 000000000000..cc53eb93f4dd --- /dev/null +++ b/mypyc/lower/registry.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Callable, Final, List + +from mypyc.ir.ops import Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder + +LowerFunc = Callable[[LowLevelIRBuilder, List[Value], int], Value] + + +lowering_registry: Final[dict[str, LowerFunc]] = {} + + +def lower_binary_op(name: str) -> Callable[[LowerFunc], LowerFunc]: + """Register a handler that generates low-level IR for a primitive binary op.""" + + def wrapper(f: LowerFunc) -> LowerFunc: + assert name not in lowering_registry + lowering_registry[name] = f + return f + + return wrapper + + +# Import various modules that set up global state. +import mypyc.lower.int_ops # noqa: F401 diff --git a/mypyc/primitives/int_ops.py b/mypyc/primitives/int_ops.py index 95f9cc5ff43f..4103fe349a74 100644 --- a/mypyc/primitives/int_ops.py +++ b/mypyc/primitives/int_ops.py @@ -12,7 +12,14 @@ from typing import NamedTuple -from mypyc.ir.ops import ERR_ALWAYS, ERR_MAGIC, ERR_MAGIC_OVERLAPPING, ERR_NEVER, ComparisonOp +from mypyc.ir.ops import ( + ERR_ALWAYS, + ERR_MAGIC, + ERR_MAGIC_OVERLAPPING, + ERR_NEVER, + ComparisonOp, + PrimitiveDescription, +) from mypyc.ir.rtypes import ( RType, bit_rprimitive, @@ -101,6 +108,22 @@ ) +def int_binary_primitive( + op: str, primitive_name: str, return_type: RType = int_rprimitive, error_kind: int = ERR_NEVER +) -> PrimitiveDescription: + return binary_op( + name=op, + arg_types=[int_rprimitive, int_rprimitive], + return_type=return_type, + primitive_name=primitive_name, + error_kind=error_kind, + ) + + +int_eq = int_binary_primitive(op="==", primitive_name="int_eq", return_type=bit_rprimitive) +int_ne = int_binary_primitive(op="!=", primitive_name="int_ne", return_type=bit_rprimitive) + + def int_binary_op( name: str, c_function_name: str, diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index 11fca7dc2c70..d4768b4df532 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -39,7 +39,7 @@ from typing import Final, NamedTuple -from mypyc.ir.ops import StealsDescription +from mypyc.ir.ops import PrimitiveDescription, StealsDescription from mypyc.ir.rtypes import RType # Error kind for functions that return negative integer on exception. This @@ -76,7 +76,7 @@ class LoadAddressDescription(NamedTuple): function_ops: dict[str, list[CFunctionDescription]] = {} # CallC op for binary ops -binary_ops: dict[str, list[CFunctionDescription]] = {} +binary_ops: dict[str, list[PrimitiveDescription]] = {} # CallC op for unary ops unary_ops: dict[str, list[CFunctionDescription]] = {} @@ -192,8 +192,9 @@ def binary_op( name: str, arg_types: list[RType], return_type: RType, - c_function_name: str, error_kind: int, + c_function_name: str | None = None, + primitive_name: str | None = None, var_arg_type: RType | None = None, truncated_type: RType | None = None, ordering: list[int] | None = None, @@ -201,7 +202,7 @@ def binary_op( steals: StealsDescription = False, is_borrowed: bool = False, priority: int = 1, -) -> CFunctionDescription: +) -> PrimitiveDescription: """Define a c function call op for a binary operation. This will be automatically generated by matching against the AST. @@ -209,22 +210,24 @@ def binary_op( Most arguments are similar to method_op(), but exactly two argument types are expected. """ + assert c_function_name is not None or primitive_name is not None + assert not (c_function_name is not None and primitive_name is not None) if extra_int_constants is None: extra_int_constants = [] ops = binary_ops.setdefault(name, []) - desc = CFunctionDescription( - name, - arg_types, - return_type, - var_arg_type, - truncated_type, - c_function_name, - error_kind, - steals, - is_borrowed, - ordering, - extra_int_constants, - priority, + desc = PrimitiveDescription( + name=primitive_name or name, + arg_types=arg_types, + return_type=return_type, + var_arg_type=var_arg_type, + truncated_type=truncated_type, + c_function_name=c_function_name, + error_kind=error_kind, + steals=steals, + is_borrowed=is_borrowed, + ordering=ordering, + extra_int_constants=extra_int_constants, + priority=priority, ) ops.append(desc) return desc @@ -311,11 +314,10 @@ def load_address_op(name: str, type: RType, src: str) -> LoadAddressDescription: return LoadAddressDescription(name, type, src) +# Import various modules that set up global state. import mypyc.primitives.bytes_ops import mypyc.primitives.dict_ops import mypyc.primitives.float_ops - -# Import various modules that set up global state. import mypyc.primitives.int_ops import mypyc.primitives.list_ops import mypyc.primitives.misc_ops diff --git a/mypyc/test-data/analysis.test b/mypyc/test-data/analysis.test index efd219cc222a..8e067aed4d79 100644 --- a/mypyc/test-data/analysis.test +++ b/mypyc/test-data/analysis.test @@ -10,40 +10,27 @@ def f(a: int) -> None: [out] def f(a): a, x :: int - r0 :: native_int - r1, r2, r3 :: bit + r0 :: bit y, z :: int L0: x = 2 - r0 = x & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq x, a + if r0 goto L1 else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(x, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = x == a - if r3 goto L3 else goto L4 :: bool -L3: y = 2 - goto L5 -L4: + goto L3 +L2: z = 2 -L5: +L3: return 1 (0, 0) {a} {a, x} (0, 1) {a, x} {a, x} (0, 2) {a, x} {a, x} -(0, 3) {a, x} {a, x} -(1, 0) {a, x} {a, x} -(1, 1) {a, x} {a, x} -(2, 0) {a, x} {a, x} -(2, 1) {a, x} {a, x} -(3, 0) {a, x} {a, x, y} -(3, 1) {a, x, y} {a, x, y} -(4, 0) {a, x} {a, x, z} -(4, 1) {a, x, z} {a, x, z} -(5, 0) {a, x, y, z} {a, x, y, z} +(1, 0) {a, x} {a, x, y} +(1, 1) {a, x, y} {a, x, y} +(2, 0) {a, x} {a, x, z} +(2, 1) {a, x, z} {a, x, z} +(3, 0) {a, x, y, z} {a, x, y, z} [case testSimple_Liveness] def f(a: int) -> int: @@ -58,7 +45,7 @@ def f(a): r0 :: bit L0: x = 2 - r0 = x == 2 + r0 = int_eq x, 2 if r0 goto L1 else goto L2 :: bool L1: return a @@ -124,7 +111,7 @@ def f(a): r0 :: bit y, x :: int L0: - r0 = a == 2 + r0 = int_eq a, 2 if r0 goto L1 else goto L2 :: bool L1: y = 2 @@ -421,40 +408,27 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit + r0 :: bit x :: int L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L4 :: bool -L3: x = 4 a = 2 - goto L5 -L4: + goto L3 +L2: x = 2 -L5: +L3: return x (0, 0) {a} {a} (0, 1) {a} {a} -(0, 2) {a} {a} (1, 0) {a} {a} -(1, 1) {a} {a} +(1, 1) {a} {} +(1, 2) {} {} (2, 0) {a} {a} (2, 1) {a} {a} -(3, 0) {a} {a} -(3, 1) {a} {} -(3, 2) {} {} -(4, 0) {a} {a} -(4, 1) {a} {a} -(5, 0) {} {} +(3, 0) {} {} [case testLoop_BorrowedArgument] def f(a: int) -> int: diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index cd952ef2ebfd..981460dae371 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -568,7 +568,7 @@ L3: x = 2 goto L8 L4: - r4 = n == 0 + r4 = int_eq n, 0 if r4 goto L5 else goto L6 :: bool L5: x = 2 @@ -598,7 +598,7 @@ def f(n): r0 :: bit r1 :: int L0: - r0 = n == 0 + r0 = int_eq n, 0 if r0 goto L1 else goto L2 :: bool L1: r1 = 0 @@ -1462,7 +1462,7 @@ L0: r1 = load_mem r0 :: native_int* keep_alive x r2 = r1 << 1 - r3 = r2 != 0 + r3 = int_ne r2, 0 if r3 goto L1 else goto L2 :: bool L1: return 2 @@ -2052,19 +2052,12 @@ def f(): r13 :: bit r14 :: object r15, x :: int - r16 :: native_int - r17, r18 :: bit - r19 :: bool - r20, r21 :: bit - r22 :: native_int - r23, r24 :: bit - r25 :: bool - r26, r27 :: bit - r28 :: int - r29 :: object - r30 :: i32 - r31 :: bit - r32 :: short_int + r16, r17 :: bit + r18 :: int + r19 :: object + r20 :: i32 + r21 :: bit + r22 :: short_int L0: r0 = PyList_New(0) r1 = PyList_New(3) @@ -2086,52 +2079,30 @@ L1: keep_alive r1 r12 = r11 << 1 r13 = r9 < r12 :: signed - if r13 goto L2 else goto L14 :: bool + if r13 goto L2 else goto L8 :: bool L2: r14 = CPyList_GetItemUnsafe(r1, r9) r15 = unbox(int, r14) x = r15 - r16 = x & 1 - r17 = r16 == 0 - if r17 goto L3 else goto L4 :: bool + r16 = int_ne x, 4 + if r16 goto L4 else goto L3 :: bool L3: - r18 = x != 4 - r19 = r18 - goto L5 + goto L7 L4: - r20 = CPyTagged_IsEq_(x, 4) - r21 = r20 ^ 1 - r19 = r21 + r17 = int_ne x, 6 + if r17 goto L6 else goto L5 :: bool L5: - if r19 goto L7 else goto L6 :: bool + goto L7 L6: - goto L13 + r18 = CPyTagged_Multiply(x, x) + r19 = box(int, r18) + r20 = PyList_Append(r0, r19) + r21 = r20 >= 0 :: signed L7: - r22 = x & 1 - r23 = r22 == 0 - if r23 goto L8 else goto L9 :: bool -L8: - r24 = x != 6 - r25 = r24 - goto L10 -L9: - r26 = CPyTagged_IsEq_(x, 6) - r27 = r26 ^ 1 - r25 = r27 -L10: - if r25 goto L12 else goto L11 :: bool -L11: - goto L13 -L12: - r28 = CPyTagged_Multiply(x, x) - r29 = box(int, r28) - r30 = PyList_Append(r0, r29) - r31 = r30 >= 0 :: signed -L13: - r32 = r9 + 2 - r9 = r32 + r22 = r9 + 2 + r9 = r22 goto L1 -L14: +L8: return r0 [case testDictComprehension] @@ -2151,19 +2122,12 @@ def f(): r13 :: bit r14 :: object r15, x :: int - r16 :: native_int - r17, r18 :: bit - r19 :: bool - r20, r21 :: bit - r22 :: native_int - r23, r24 :: bit - r25 :: bool - r26, r27 :: bit - r28 :: int - r29, r30 :: object - r31 :: i32 - r32 :: bit - r33 :: short_int + r16, r17 :: bit + r18 :: int + r19, r20 :: object + r21 :: i32 + r22 :: bit + r23 :: short_int L0: r0 = PyDict_New() r1 = PyList_New(3) @@ -2185,53 +2149,31 @@ L1: keep_alive r1 r12 = r11 << 1 r13 = r9 < r12 :: signed - if r13 goto L2 else goto L14 :: bool + if r13 goto L2 else goto L8 :: bool L2: r14 = CPyList_GetItemUnsafe(r1, r9) r15 = unbox(int, r14) x = r15 - r16 = x & 1 - r17 = r16 == 0 - if r17 goto L3 else goto L4 :: bool + r16 = int_ne x, 4 + if r16 goto L4 else goto L3 :: bool L3: - r18 = x != 4 - r19 = r18 - goto L5 + goto L7 L4: - r20 = CPyTagged_IsEq_(x, 4) - r21 = r20 ^ 1 - r19 = r21 + r17 = int_ne x, 6 + if r17 goto L6 else goto L5 :: bool L5: - if r19 goto L7 else goto L6 :: bool + goto L7 L6: - goto L13 + r18 = CPyTagged_Multiply(x, x) + r19 = box(int, x) + r20 = box(int, r18) + r21 = CPyDict_SetItem(r0, r19, r20) + r22 = r21 >= 0 :: signed L7: - r22 = x & 1 - r23 = r22 == 0 - if r23 goto L8 else goto L9 :: bool -L8: - r24 = x != 6 - r25 = r24 - goto L10 -L9: - r26 = CPyTagged_IsEq_(x, 6) - r27 = r26 ^ 1 - r25 = r27 -L10: - if r25 goto L12 else goto L11 :: bool -L11: - goto L13 -L12: - r28 = CPyTagged_Multiply(x, x) - r29 = box(int, x) - r30 = box(int, r28) - r31 = CPyDict_SetItem(r0, r29, r30) - r32 = r31 >= 0 :: signed -L13: - r33 = r9 + 2 - r9 = r33 + r23 = r9 + 2 + r9 = r23 goto L1 -L14: +L8: return r0 [case testLoopsMultipleAssign] @@ -3011,85 +2953,57 @@ def call_any(l): r0 :: bool r1, r2 :: object r3, i :: int - r4 :: native_int - r5, r6 :: bit - r7 :: bool - r8, r9 :: bit + r4, r5 :: bit L0: r0 = 0 r1 = PyObject_GetIter(l) L1: r2 = PyIter_Next(r1) - if is_error(r2) goto L9 else goto L2 + if is_error(r2) goto L6 else goto L2 L2: r3 = unbox(int, r2) i = r3 - r4 = i & 1 - r5 = r4 == 0 - if r5 goto L3 else goto L4 :: bool + r4 = int_eq i, 0 + if r4 goto L3 else goto L4 :: bool L3: - r6 = i == 0 - r7 = r6 - goto L5 + r0 = 1 + goto L8 L4: - r8 = CPyTagged_IsEq_(i, 0) - r7 = r8 L5: - if r7 goto L6 else goto L7 :: bool + goto L1 L6: - r0 = 1 - goto L11 + r5 = CPy_NoErrOccured() L7: L8: - goto L1 -L9: - r9 = CPy_NoErrOccured() -L10: -L11: return r0 def call_all(l): l :: object r0 :: bool r1, r2 :: object r3, i :: int - r4 :: native_int - r5, r6 :: bit - r7 :: bool - r8 :: bit - r9 :: bool - r10 :: bit + r4, r5, r6 :: bit L0: r0 = 1 r1 = PyObject_GetIter(l) L1: r2 = PyIter_Next(r1) - if is_error(r2) goto L9 else goto L2 + if is_error(r2) goto L6 else goto L2 L2: r3 = unbox(int, r2) i = r3 - r4 = i & 1 - r5 = r4 == 0 + r4 = int_eq i, 0 + r5 = r4 ^ 1 if r5 goto L3 else goto L4 :: bool L3: - r6 = i == 0 - r7 = r6 - goto L5 + r0 = 0 + goto L8 L4: - r8 = CPyTagged_IsEq_(i, 0) - r7 = r8 L5: - r9 = r7 ^ 1 - if r9 goto L6 else goto L7 :: bool + goto L1 L6: - r0 = 0 - goto L11 + r6 = CPy_NoErrOccured() L7: L8: - goto L1 -L9: - r10 = CPy_NoErrOccured() -L10: -L11: return r0 [case testSum] diff --git a/mypyc/test-data/irbuild-bool.test b/mypyc/test-data/irbuild-bool.test index 731d393d69ab..f0b0b480bc0d 100644 --- a/mypyc/test-data/irbuild-bool.test +++ b/mypyc/test-data/irbuild-bool.test @@ -96,7 +96,7 @@ L0: r1 = load_mem r0 :: native_int* keep_alive l r2 = r1 << 1 - r3 = r2 != 0 + r3 = int_ne r2, 0 return r3 def always_truthy_instance_to_bool(o): o :: __main__.C @@ -222,7 +222,7 @@ def eq1(x, y): L0: r0 = y << 1 r1 = extend r0: builtins.bool to builtins.int - r2 = x == r1 + r2 = int_eq x, r1 return r2 def eq2(x, y): x :: bool @@ -233,7 +233,7 @@ def eq2(x, y): L0: r0 = x << 1 r1 = extend r0: builtins.bool to builtins.int - r2 = r1 == y + r2 = int_eq r1, y return r2 def neq1(x, y): x :: i64 diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 55e55dbf3286..8c4743c6a47f 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1249,7 +1249,7 @@ L0: r0 = x.__getitem__(2) r1 = CPyList_GetItemShortBorrow(r0, 0) r2 = unbox(int, r1) - r3 = r2 == 4 + r3 = int_eq r2, 4 keep_alive r0 if r3 goto L1 else goto L2 :: bool L1: diff --git a/mypyc/test-data/irbuild-int.test b/mypyc/test-data/irbuild-int.test index fbe00aff4040..1489f2f470dd 100644 --- a/mypyc/test-data/irbuild-int.test +++ b/mypyc/test-data/irbuild-int.test @@ -4,24 +4,10 @@ def f(x: int, y: int) -> bool: [out] def f(x, y): x, y :: int - r0 :: native_int - r1, r2 :: bit - r3 :: bool - r4, r5 :: bit + r0 :: bit L0: - r0 = x & 1 - r1 = r0 == 0 - if r1 goto L1 else goto L2 :: bool -L1: - r2 = x != y - r3 = r2 - goto L3 -L2: - r4 = CPyTagged_IsEq_(x, y) - r5 = r4 ^ 1 - r3 = r5 -L3: - return r3 + r0 = int_ne x, y + return r0 [case testShortIntComparisons] def f(x: int) -> int: @@ -43,22 +29,22 @@ def f(x): r4 :: native_int r5, r6, r7 :: bit L0: - r0 = x == 6 + r0 = int_eq x, 6 if r0 goto L1 else goto L2 :: bool L1: return 2 L2: - r1 = x != 8 + r1 = int_ne x, 8 if r1 goto L3 else goto L4 :: bool L3: return 4 L4: - r2 = 10 == x + r2 = int_eq 10, x if r2 goto L5 else goto L6 :: bool L5: return 6 L6: - r3 = 12 != x + r3 = int_ne 12, x if r3 goto L7 else goto L8 :: bool L7: return 8 diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index a078ae0defdb..ab5a19624ba6 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -14,7 +14,7 @@ def f(): r6 :: object_ptr r7, r8 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L1 else goto L2 :: bool L1: r1 = 'matched' @@ -30,6 +30,7 @@ L2: L3: r8 = box(None, 1) return r8 + [case testMatchOrPattern_python3_10] def f(): match 123: @@ -46,10 +47,10 @@ def f(): r7 :: object_ptr r8, r9 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L3 else goto L1 :: bool L1: - r1 = 246 == 912 + r1 = int_eq 246, 912 if r1 goto L3 else goto L2 :: bool L2: goto L4 @@ -67,6 +68,7 @@ L4: L5: r9 = box(None, 1) return r9 + [case testMatchOrPatternManyPatterns_python3_10] def f(): match 1: @@ -83,16 +85,16 @@ def f(): r9 :: object_ptr r10, r11 :: object L0: - r0 = 2 == 2 + r0 = int_eq 2, 2 if r0 goto L5 else goto L1 :: bool L1: - r1 = 2 == 4 + r1 = int_eq 2, 4 if r1 goto L5 else goto L2 :: bool L2: - r2 = 2 == 6 + r2 = int_eq 2, 6 if r2 goto L5 else goto L3 :: bool L3: - r3 = 2 == 8 + r3 = int_eq 2, 8 if r3 goto L5 else goto L4 :: bool L4: goto L6 @@ -110,6 +112,7 @@ L6: L7: r11 = box(None, 1) return r11 + [case testMatchClassPattern_python3_10] def f(): match 123: @@ -200,7 +203,7 @@ def f(): r14 :: object_ptr r15, r16 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L1 else goto L2 :: bool L1: r1 = 'matched' @@ -213,7 +216,7 @@ L1: keep_alive r1 goto L5 L2: - r8 = 246 == 912 + r8 = int_eq 246, 912 if r8 goto L3 else goto L4 :: bool L3: r9 = 'no match' @@ -229,6 +232,7 @@ L4: L5: r16 = box(None, 1) return r16 + [case testMatchMultiBodyAndComplexOr_python3_10] def f(): match 123: @@ -265,7 +269,7 @@ def f(): r23 :: object_ptr r24, r25 :: object L0: - r0 = 246 == 2 + r0 = int_eq 246, 2 if r0 goto L1 else goto L2 :: bool L1: r1 = 'here 1' @@ -278,10 +282,10 @@ L1: keep_alive r1 goto L9 L2: - r8 = 246 == 4 + r8 = int_eq 246, 4 if r8 goto L5 else goto L3 :: bool L3: - r9 = 246 == 6 + r9 = int_eq 246, 6 if r9 goto L5 else goto L4 :: bool L4: goto L6 @@ -296,7 +300,7 @@ L5: keep_alive r10 goto L9 L6: - r17 = 246 == 246 + r17 = int_eq 246, 246 if r17 goto L7 else goto L8 :: bool L7: r18 = 'here 123' @@ -312,6 +316,7 @@ L8: L9: r25 = box(None, 1) return r25 + [case testMatchWithGuard_python3_10] def f(): match 123: @@ -328,7 +333,7 @@ def f(): r6 :: object_ptr r7, r8 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L1 else goto L3 :: bool L1: if 1 goto L2 else goto L3 :: bool @@ -346,6 +351,7 @@ L3: L4: r8 = box(None, 1) return r8 + [case testMatchSingleton_python3_10] def f(): match 123: @@ -449,7 +455,7 @@ def f(): r9 :: object_ptr r10, r11 :: object L0: - r0 = 2 == 2 + r0 = int_eq 2, 2 if r0 goto L3 else goto L1 :: bool L1: r1 = load_address PyLong_Type @@ -472,6 +478,7 @@ L4: L5: r11 = box(None, 1) return r11 + [case testMatchAsPattern_python3_10] def f(): match 123: @@ -487,7 +494,7 @@ def f(): r6 :: object_ptr r7, r8 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 r1 = object 123 x = r1 if r0 goto L1 else goto L2 :: bool @@ -504,6 +511,7 @@ L2: L3: r8 = box(None, 1) return r8 + [case testMatchAsPatternOnOrPattern_python3_10] def f(): match 1: @@ -521,12 +529,12 @@ def f(): r8 :: object_ptr r9, r10 :: object L0: - r0 = 2 == 2 + r0 = int_eq 2, 2 r1 = object 1 x = r1 if r0 goto L3 else goto L1 :: bool L1: - r2 = 2 == 4 + r2 = int_eq 2, 4 r3 = object 2 x = r3 if r2 goto L3 else goto L2 :: bool @@ -545,6 +553,7 @@ L4: L5: r10 = box(None, 1) return r10 + [case testMatchAsPatternOnClassPattern_python3_10] def f(): match 123: diff --git a/mypyc/test-data/irbuild-nested.test b/mypyc/test-data/irbuild-nested.test index b2b884705366..62ae6eb9ee35 100644 --- a/mypyc/test-data/irbuild-nested.test +++ b/mypyc/test-data/irbuild-nested.test @@ -658,7 +658,7 @@ def baz_f_obj.__call__(__mypyc_self__, n): r6, r7 :: int L0: r0 = __mypyc_self__.__mypyc_env__ - r1 = n == 0 + r1 = int_eq n, 0 if r1 goto L1 else goto L2 :: bool L1: return 0 @@ -796,7 +796,7 @@ def baz(n): r0 :: bit r1, r2, r3 :: int L0: - r0 = n == 0 + r0 = int_eq n, 0 if r0 goto L1 else goto L2 :: bool L1: return 0 diff --git a/mypyc/test-data/irbuild-optional.test b/mypyc/test-data/irbuild-optional.test index e89018a727da..75c008586999 100644 --- a/mypyc/test-data/irbuild-optional.test +++ b/mypyc/test-data/irbuild-optional.test @@ -222,7 +222,7 @@ def f(y): L0: r0 = box(None, 1) x = r0 - r1 = y == 2 + r1 = int_eq y, 2 if r1 goto L1 else goto L2 :: bool L1: r2 = box(int, y) diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index a47f3db6a725..ab0e2fa09a9d 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -195,70 +195,30 @@ def f(i: int) -> bool: [out] def f(i): i :: int - r0 :: native_int - r1, r2 :: bit + r0 :: bit + r1 :: bool + r2 :: bit r3 :: bool r4 :: bit - r5 :: bool - r6 :: native_int - r7, r8 :: bit - r9 :: bool - r10 :: bit - r11 :: bool - r12 :: native_int - r13, r14 :: bit - r15 :: bool - r16 :: bit L0: - r0 = i & 1 - r1 = r0 == 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq i, 2 + if r0 goto L1 else goto L2 :: bool L1: - r2 = i == 2 - r3 = r2 + r1 = r0 goto L3 L2: - r4 = CPyTagged_IsEq_(i, 2) - r3 = r4 + r2 = int_eq i, 4 + r1 = r2 L3: - if r3 goto L4 else goto L5 :: bool + if r1 goto L4 else goto L5 :: bool L4: - r5 = r3 - goto L9 + r3 = r1 + goto L6 L5: - r6 = i & 1 - r7 = r6 == 0 - if r7 goto L6 else goto L7 :: bool + r4 = int_eq i, 6 + r3 = r4 L6: - r8 = i == 4 - r9 = r8 - goto L8 -L7: - r10 = CPyTagged_IsEq_(i, 4) - r9 = r10 -L8: - r5 = r9 -L9: - if r5 goto L10 else goto L11 :: bool -L10: - r11 = r5 - goto L15 -L11: - r12 = i & 1 - r13 = r12 == 0 - if r13 goto L12 else goto L13 :: bool -L12: - r14 = i == 6 - r15 = r14 - goto L14 -L13: - r16 = CPyTagged_IsEq_(i, 6) - r15 = r16 -L14: - r11 = r15 -L15: - return r11 - + return r3 [case testTupleBuiltFromList] def f(val: int) -> bool: @@ -270,24 +230,11 @@ def test() -> None: [out] def f(val): val, r0 :: int - r1 :: native_int - r2, r3 :: bit - r4 :: bool - r5 :: bit + r1 :: bit L0: r0 = CPyTagged_Remainder(val, 4) - r1 = r0 & 1 - r2 = r1 == 0 - if r2 goto L1 else goto L2 :: bool -L1: - r3 = r0 == 0 - r4 = r3 - goto L3 -L2: - r5 = CPyTagged_IsEq_(r0, 0) - r4 = r5 -L3: - return r4 + r1 = int_eq r0, 0 + return r1 def test(): r0 :: list r1, r2, r3 :: object diff --git a/mypyc/test-data/lowering-int.test b/mypyc/test-data/lowering-int.test new file mode 100644 index 000000000000..8c813563d0e6 --- /dev/null +++ b/mypyc/test-data/lowering-int.test @@ -0,0 +1,126 @@ +-- Test cases for converting high-level IR to lower-level IR (lowering). + +[case testLowerIntEq] +def f(x: int, y: int) -> int: + if x == y: + return 1 + else: + return 2 +[out] +def f(x, y): + x, y :: int + r0 :: native_int + r1, r2, r3 :: bit +L0: + r0 = x & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = x == y + if r2 goto L3 else goto L4 :: bool +L2: + r3 = CPyTagged_IsEq_(x, y) + if r3 goto L3 else goto L4 :: bool +L3: + return 2 +L4: + return 4 + +[case testLowerIntNe] +def f(x: int, y: int) -> int: + if x != y: + return 1 + else: + return 2 +[out] +def f(x, y): + x, y :: int + r0 :: native_int + r1, r2, r3, r4 :: bit +L0: + r0 = x & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = x != y + if r2 goto L3 else goto L4 :: bool +L2: + r3 = CPyTagged_IsEq_(x, y) + r4 = r3 ^ 1 + if r4 goto L3 else goto L4 :: bool +L3: + return 2 +L4: + return 4 + +[case testLowerIntEqWithConstant] +def f(x: int, y: int) -> int: + if x == 2: + return 1 + elif -1 == x: + return 2 + return 3 +[out] +def f(x, y): + x, y :: int + r0, r1 :: bit +L0: + r0 = x == 4 + if r0 goto L1 else goto L2 :: bool +L1: + return 2 +L2: + r1 = -2 == x + if r1 goto L3 else goto L4 :: bool +L3: + return 4 +L4: + return 6 + +[case testLowerIntNeWithConstant] +def f(x: int, y: int) -> int: + if x != 2: + return 1 + elif -1 != x: + return 2 + return 3 +[out] +def f(x, y): + x, y :: int + r0, r1 :: bit +L0: + r0 = x != 4 + if r0 goto L1 else goto L2 :: bool +L1: + return 2 +L2: + r1 = -2 != x + if r1 goto L3 else goto L4 :: bool +L3: + return 4 +L4: + return 6 + +[case testLowerIntEqValueContext] +def f(x: int, y: int) -> bool: + return x == y +[out] +def f(x, y): + x, y :: int + r0 :: native_int + r1, r2 :: bit + r3 :: bool + r4 :: bit +L0: + r0 = x & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = x == y + r3 = r2 + goto L3 +L2: + r4 = CPyTagged_IsEq_(x, y) + r3 = r4 +L3: + return r3 diff --git a/mypyc/test-data/opt-flag-elimination.test b/mypyc/test-data/opt-flag-elimination.test index f047a87dc3fa..337ced70a355 100644 --- a/mypyc/test-data/opt-flag-elimination.test +++ b/mypyc/test-data/opt-flag-elimination.test @@ -29,15 +29,13 @@ L0: if x goto L1 else goto L2 :: bool L1: r0 = c() - if r0 goto L4 else goto L5 :: bool + if r0 goto L3 else goto L4 :: bool L2: r1 = d() - if r1 goto L4 else goto L5 :: bool + if r1 goto L3 else goto L4 :: bool L3: - unreachable -L4: return 2 -L5: +L4: return 4 [case testFlagEliminationOneAssignment] @@ -92,20 +90,18 @@ L0: if x goto L1 else goto L2 :: bool L1: r0 = c(2) - if r0 goto L6 else goto L7 :: bool + if r0 goto L5 else goto L6 :: bool L2: if y goto L3 else goto L4 :: bool L3: r1 = c(4) - if r1 goto L6 else goto L7 :: bool + if r1 goto L5 else goto L6 :: bool L4: r2 = c(6) - if r2 goto L6 else goto L7 :: bool + if r2 goto L5 else goto L6 :: bool L5: - unreachable -L6: return 2 -L7: +L6: return 4 [case testFlagEliminationAssignmentNotLastOp] diff --git a/mypyc/test-data/refcount.test b/mypyc/test-data/refcount.test index 0f2c134ae21e..df980af8a7c7 100644 --- a/mypyc/test-data/refcount.test +++ b/mypyc/test-data/refcount.test @@ -67,7 +67,7 @@ def f(): L0: x = 2 y = 4 - r0 = x == 2 + r0 = int_eq x, 2 if r0 goto L3 else goto L4 :: bool L1: return x @@ -185,34 +185,26 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit - x, r4, y :: int + r0 :: bit + x, r1, y :: int L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L4 :: bool -L3: a = 2 - goto L5 -L4: + goto L3 +L2: x = 4 dec_ref x :: int - goto L6 -L5: - r4 = CPyTagged_Add(a, 2) + goto L4 +L3: + r1 = CPyTagged_Add(a, 2) dec_ref a :: int - y = r4 + y = r1 return y -L6: +L4: inc_ref a :: int - goto L5 + goto L3 [case testConditionalAssignToArgument2] def f(a: int) -> int: @@ -225,33 +217,25 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit - x, r4, y :: int + r0 :: bit + x, r1, y :: int L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L4 :: bool -L3: x = 4 dec_ref x :: int - goto L6 -L4: + goto L4 +L2: a = 2 -L5: - r4 = CPyTagged_Add(a, 2) +L3: + r1 = CPyTagged_Add(a, 2) dec_ref a :: int - y = r4 + y = r1 return y -L6: +L4: inc_ref a :: int - goto L5 + goto L3 [case testConditionalAssignToArgument3] def f(a: int) -> int: @@ -261,25 +245,17 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit + r0 :: bit L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 else goto L3 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L5 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L5 :: bool -L3: a = 2 -L4: +L2: return a -L5: +L3: inc_ref a :: int - goto L4 + goto L2 [case testAssignRegisterToItself] def f(a: int) -> int: @@ -438,40 +414,32 @@ def f() -> int: [out] def f(): x, y, z :: int - r0 :: native_int - r1, r2, r3 :: bit - a, r4, r5 :: int + r0 :: bit + a, r1, r2 :: int L0: x = 2 y = 4 z = 6 - r0 = z & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq z, z + if r0 goto L3 else goto L4 :: bool L1: - r2 = CPyTagged_IsEq_(z, z) - if r2 goto L5 else goto L6 :: bool -L2: - r3 = z == z - if r3 goto L5 else goto L6 :: bool -L3: return z -L4: +L2: a = 2 - r4 = CPyTagged_Add(x, y) + r1 = CPyTagged_Add(x, y) dec_ref x :: int dec_ref y :: int - r5 = CPyTagged_Subtract(r4, a) - dec_ref r4 :: int + r2 = CPyTagged_Subtract(r1, a) + dec_ref r1 :: int dec_ref a :: int - return r5 -L5: + return r2 +L3: dec_ref x :: int dec_ref y :: int - goto L3 -L6: + goto L1 +L4: dec_ref z :: int - goto L4 + goto L2 [case testLoop] def f(a: int) -> int: @@ -1371,25 +1339,12 @@ class C: def add(c): c :: __main__.C r0, r1 :: int - r2 :: native_int - r3, r4 :: bit - r5 :: bool - r6 :: bit + r2 :: bit L0: r0 = borrow c.x r1 = borrow c.y - r2 = r0 & 1 - r3 = r2 == 0 - if r3 goto L1 else goto L2 :: bool -L1: - r4 = r0 == r1 - r5 = r4 - goto L3 -L2: - r6 = CPyTagged_IsEq_(r0, r1) - r5 = r6 -L3: - return r5 + r2 = int_eq r0, r1 + return r2 [case testBorrowIntLessThan] def add(c: C) -> bool: @@ -1441,24 +1396,11 @@ class C: def add(c): c :: __main__.C r0 :: int - r1 :: native_int - r2, r3 :: bit - r4 :: bool - r5 :: bit + r1 :: bit L0: r0 = borrow c.x - r1 = r0 & 1 - r2 = r1 == 0 - if r2 goto L1 else goto L2 :: bool -L1: - r3 = r0 == 20 - r4 = r3 - goto L3 -L2: - r5 = CPyTagged_IsEq_(r0, 20) - r4 = r5 -L3: - return r4 + r1 = int_eq r0, 20 + return r1 [case testBorrowIntArithmetic] def add(c: C) -> int: @@ -1501,23 +1443,15 @@ class C: def add(c, n): c :: __main__.C n, r0, r1 :: int - r2 :: native_int - r3, r4, r5 :: bit + r2 :: bit L0: r0 = borrow c.x r1 = borrow c.y - r2 = r0 & 1 - r3 = r2 != 0 - if r3 goto L1 else goto L2 :: bool + r2 = int_eq r0, r1 + if r2 goto L1 else goto L2 :: bool L1: - r4 = CPyTagged_IsEq_(r0, r1) - if r4 goto L3 else goto L4 :: bool -L2: - r5 = r0 == r1 - if r5 goto L3 else goto L4 :: bool -L3: return 1 -L4: +L2: return 0 [case testBorrowIntInPlaceOp] diff --git a/mypyc/test/test_cheader.py b/mypyc/test/test_cheader.py index cc0fd9df2b34..f2af41c22ea9 100644 --- a/mypyc/test/test_cheader.py +++ b/mypyc/test/test_cheader.py @@ -7,6 +7,7 @@ import re import unittest +from mypyc.ir.ops import PrimitiveDescription from mypyc.primitives import registry from mypyc.primitives.registry import CFunctionDescription @@ -25,17 +26,24 @@ def check_name(name: str) -> None: rf"\b{name}\b", header ), f'"{name}" is used in mypyc.primitives but not declared in CPy.h' - for values in [ + for old_values in [ registry.method_call_ops.values(), registry.function_ops.values(), - registry.binary_ops.values(), registry.unary_ops.values(), ]: + for old_ops in old_values: + if isinstance(old_ops, CFunctionDescription): + old_ops = [old_ops] + for old_op in old_ops: + check_name(old_op.c_function_name) + + for values in [registry.binary_ops.values()]: for ops in values: - if isinstance(ops, CFunctionDescription): + if isinstance(ops, PrimitiveDescription): ops = [ops] for op in ops: - check_name(op.c_function_name) + if op.c_function_name is not None: + check_name(op.c_function_name) primitives_path = os.path.join(os.path.dirname(__file__), "..", "primitives") for fnam in glob.glob(f"{primitives_path}/*.py"): diff --git a/mypyc/test/test_emitfunc.py b/mypyc/test/test_emitfunc.py index ab1586bb22a8..b16387aa40af 100644 --- a/mypyc/test/test_emitfunc.py +++ b/mypyc/test/test_emitfunc.py @@ -859,6 +859,8 @@ def assert_emit_binary_op( args = [left, right] if desc.ordering is not None: args = [args[i] for i in desc.ordering] + # This only supports primitives that map to C calls + assert desc.c_function_name is not None self.assert_emit( CallC( desc.c_function_name, diff --git a/mypyc/test/test_lowering.py b/mypyc/test/test_lowering.py new file mode 100644 index 000000000000..e32dba2e1021 --- /dev/null +++ b/mypyc/test/test_lowering.py @@ -0,0 +1,54 @@ +"""Runner for lowering transform tests.""" + +from __future__ import annotations + +import os.path + +from mypy.errors import CompileError +from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase +from mypyc.common import TOP_LEVEL_NAME +from mypyc.ir.pprint import format_func +from mypyc.options import CompilerOptions +from mypyc.test.testutil import ( + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file, + remove_comment_lines, + use_custom_builtins, +) +from mypyc.transform.exceptions import insert_exception_handling +from mypyc.transform.flag_elimination import do_flag_elimination +from mypyc.transform.lower import lower_ir +from mypyc.transform.refcount import insert_ref_count_opcodes +from mypyc.transform.uninit import insert_uninit_checks + + +class TestLowering(MypycDataSuite): + files = ["lowering-int.test"] + base_path = test_temp_dir + + def run_case(self, testcase: DataDrivenTestCase) -> None: + with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase): + expected_output = remove_comment_lines(testcase.output) + try: + ir = build_ir_for_single_file(testcase.input) + except CompileError as e: + actual = e.messages + else: + actual = [] + for fn in ir: + if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"): + continue + options = CompilerOptions() + # Lowering happens after exception handling and ref count opcodes have + # been added. Any changes must maintain reference counting semantics. + insert_uninit_checks(fn) + insert_exception_handling(fn) + insert_ref_count_opcodes(fn) + lower_ir(fn, options) + do_flag_elimination(fn, options) + actual.extend(format_func(fn)) + + assert_test_output(testcase, actual, "Invalid source code output", expected_output) diff --git a/mypyc/transform/ir_transform.py b/mypyc/transform/ir_transform.py index 254fe3f7771d..a631bd7352b5 100644 --- a/mypyc/transform/ir_transform.py +++ b/mypyc/transform/ir_transform.py @@ -35,6 +35,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Return, SetAttr, @@ -80,6 +81,7 @@ def transform_blocks(self, blocks: list[BasicBlock]) -> None: """ block_map: dict[BasicBlock, BasicBlock] = {} op_map = self.op_map + empties = set() for block in blocks: new_block = BasicBlock() block_map[block] = new_block @@ -89,7 +91,10 @@ def transform_blocks(self, blocks: list[BasicBlock]) -> None: new_op = op.accept(self) if new_op is not op: op_map[op] = new_op - + # A transform can produce empty blocks which can be removed. + if is_empty_block(new_block) and not is_empty_block(block): + empties.add(new_block) + self.builder.blocks = [block for block in self.builder.blocks if block not in empties] # Update all op/block references to point to the transformed ones. patcher = PatchVisitor(op_map, block_map) for block in self.builder.blocks: @@ -170,6 +175,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None: def visit_call_c(self, op: CallC) -> Value | None: return self.add(op) + def visit_primitive_op(self, op: PrimitiveOp) -> Value | None: + return self.add(op) + def visit_truncate(self, op: Truncate) -> Value | None: return self.add(op) @@ -302,6 +310,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None: def visit_call_c(self, op: CallC) -> None: op.args = [self.fix_op(arg) for arg in op.args] + def visit_primitive_op(self, op: PrimitiveOp) -> None: + op.args = [self.fix_op(arg) for arg in op.args] + def visit_truncate(self, op: Truncate) -> None: op.src = self.fix_op(op.src) @@ -351,3 +362,7 @@ def visit_keep_alive(self, op: KeepAlive) -> None: def visit_unborrow(self, op: Unborrow) -> None: op.src = self.fix_op(op.src) + + +def is_empty_block(block: BasicBlock) -> bool: + return len(block.ops) == 1 and isinstance(block.ops[0], Unreachable) diff --git a/mypyc/transform/lower.py b/mypyc/transform/lower.py new file mode 100644 index 000000000000..b717657095f9 --- /dev/null +++ b/mypyc/transform/lower.py @@ -0,0 +1,33 @@ +"""Transform IR to lower-level ops. + +Higher-level ops are used in earlier compiler passes, as they make +various analyses, optimizations and transforms easier to implement. +Later passes use lower-level ops, as they are easier to generate code +from, and they help with lower-level optimizations. + +Lowering of various primitive ops is implemented in the mypyc.lower +package. +""" + +from mypyc.ir.func_ir import FuncIR +from mypyc.ir.ops import PrimitiveOp, Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.lower.registry import lowering_registry +from mypyc.options import CompilerOptions +from mypyc.transform.ir_transform import IRTransform + + +def lower_ir(ir: FuncIR, options: CompilerOptions) -> None: + builder = LowLevelIRBuilder(None, options) + visitor = LoweringVisitor(builder) + visitor.transform_blocks(ir.blocks) + ir.blocks = builder.blocks + + +class LoweringVisitor(IRTransform): + def visit_primitive_op(self, op: PrimitiveOp) -> Value: + # The lowering implementation functions of various primitive ops are stored + # in a registry, which is populated using function decorators. The name + # of op (such as "int_eq") is used as the key. + lower_fn = lowering_registry[op.desc.name] + return lower_fn(self.builder, op.args, op.line)