Skip to content

Commit afdd9d5

Browse files
authored
[mypyc] Implement lowering for remaining tagged integer comparisons (#17040)
Support lowering of tagged integer `<`, `<=`, `>` and `>=` operations. Previously we had separate code paths for integer comparisons in values vs conditions. Unify these and remove the duplicate code path. The different code paths produced subtly different code, but now they are identical. The generated code is now sometimes slightly more verbose in the slow path (big integer). I may look into simplifying it in a follow-up PR. This also makes the output of many irbuild test cases significantly more compact. Follow-up to #17027. Work on mypyc/mypyc#854.
1 parent 7d0a8e7 commit afdd9d5

21 files changed

+622
-968
lines changed

mypyc/ir/pprint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def visit_primitive_op(self, op: PrimitiveOp) -> str:
232232
type_arg_index += 1
233233

234234
args_str = ", ".join(args)
235-
return self.format("%r = %s %s ", op, op.desc.name, args_str)
235+
return self.format("%r = %s %s", op, op.desc.name, args_str)
236236

237237
def visit_truncate(self, op: Truncate) -> str:
238238
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)

mypyc/irbuild/ast_helpers.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,9 @@ def maybe_process_conditional_comparison(
9393
self.add_bool_branch(reg, true, false)
9494
else:
9595
# "left op right" for two tagged integers
96-
if op in ("==", "!="):
97-
reg = self.builder.binary_op(left, right, op, e.line)
98-
self.flush_keep_alives()
99-
self.add_bool_branch(reg, true, false)
100-
else:
101-
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
96+
reg = self.builder.binary_op(left, right, op, e.line)
97+
self.flush_keep_alives()
98+
self.add_bool_branch(reg, true, false)
10299
return True
103100

104101

mypyc/irbuild/builder.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,6 @@ def call_c(self, desc: CFunctionDescription, args: list[Value], line: int) -> Va
378378
def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
379379
return self.builder.int_op(type, lhs, rhs, op, line)
380380

381-
def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
382-
return self.builder.compare_tagged(lhs, rhs, op, line)
383-
384381
def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
385382
return self.builder.compare_tuples(lhs, rhs, op, line)
386383

mypyc/irbuild/expression.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -814,12 +814,6 @@ def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Va
814814
def transform_basic_comparison(
815815
builder: IRBuilder, op: str, left: Value, right: Value, line: int
816816
) -> Value:
817-
if (
818-
is_int_rprimitive(left.type)
819-
and is_int_rprimitive(right.type)
820-
and op in int_comparison_op_mapping
821-
):
822-
return builder.compare_tagged(left, right, op, line)
823817
if is_fixed_width_rtype(left.type) and op in int_comparison_op_mapping:
824818
if right.type == left.type:
825819
if left.type.is_signed:

mypyc/irbuild/function.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -889,9 +889,8 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
889889
call_impl, next_impl = BasicBlock(), BasicBlock()
890890

891891
current_id = builder.load_int(i)
892-
builder.builder.compare_tagged_condition(
893-
passed_id, current_id, "==", call_impl, next_impl, line
894-
)
892+
cond = builder.binary_op(passed_id, current_id, "==", line)
893+
builder.add_bool_branch(cond, call_impl, next_impl)
895894

896895
# Call the registered implementation
897896
builder.activate_block(call_impl)

mypyc/irbuild/ll_builder.py

Lines changed: 13 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,13 +1315,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13151315
return self.compare_strings(lreg, rreg, op, line)
13161316
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
13171317
return self.compare_bytes(lreg, rreg, op, line)
1318-
if (
1319-
is_tagged(ltype)
1320-
and is_tagged(rtype)
1321-
and op in int_comparison_op_mapping
1322-
and op not in ("==", "!=")
1323-
):
1324-
return self.compare_tagged(lreg, rreg, op, line)
13251318
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
13261319
if op in ComparisonOp.signed_ops:
13271320
return self.bool_comparison_op(lreg, rreg, op, line)
@@ -1384,16 +1377,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13841377
if is_fixed_width_rtype(lreg.type):
13851378
return self.comparison_op(lreg, rreg, op_id, line)
13861379

1387-
# Mixed int comparisons
1388-
if op in ("==", "!="):
1389-
pass # TODO: Do we need anything here?
1390-
elif op in op in int_comparison_op_mapping:
1391-
if is_tagged(ltype) and is_subtype(rtype, ltype):
1392-
rreg = self.coerce(rreg, short_int_rprimitive, line)
1393-
return self.compare_tagged(lreg, rreg, op, line)
1394-
if is_tagged(rtype) and is_subtype(ltype, rtype):
1395-
lreg = self.coerce(lreg, short_int_rprimitive, line)
1396-
return self.compare_tagged(lreg, rreg, op, line)
13971380
if is_float_rprimitive(ltype) or is_float_rprimitive(rtype):
13981381
if isinstance(lreg, Integer):
13991382
lreg = Float(float(lreg.numeric_value()))
@@ -1445,18 +1428,16 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
14451428
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
14461429
result = Register(bool_rprimitive)
14471430
short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock()
1448-
check_lhs = self.check_tagged_short_int(lhs, line)
1431+
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
14491432
if op in ("==", "!="):
1450-
check = check_lhs
1433+
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
14511434
else:
14521435
# for non-equality logical ops (less/greater than, etc.), need to check both sides
1453-
check_rhs = self.check_tagged_short_int(rhs, line)
1454-
check = self.int_op(bit_rprimitive, check_lhs, check_rhs, IntOp.AND, line)
1455-
self.add(Branch(check, short_int_block, int_block, Branch.BOOL))
1456-
self.activate_block(short_int_block)
1457-
eq = self.comparison_op(lhs, rhs, op_type, line)
1458-
self.add(Assign(result, eq, line))
1459-
self.goto(out)
1436+
short_lhs = BasicBlock()
1437+
self.add(Branch(check_lhs, int_block, short_lhs, Branch.BOOL))
1438+
self.activate_block(short_lhs)
1439+
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
1440+
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
14601441
self.activate_block(int_block)
14611442
if swap_op:
14621443
args = [rhs, lhs]
@@ -1469,62 +1450,12 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
14691450
else:
14701451
call_result = call
14711452
self.add(Assign(result, call_result, line))
1472-
self.goto_and_activate(out)
1473-
return result
1474-
1475-
def compare_tagged_condition(
1476-
self, lhs: Value, rhs: Value, op: str, true: BasicBlock, false: BasicBlock, line: int
1477-
) -> None:
1478-
"""Compare two tagged integers using given operator (conditional context).
1479-
1480-
Assume lhs and rhs are tagged integers.
1481-
1482-
Args:
1483-
lhs: Left operand
1484-
rhs: Right operand
1485-
op: Operation, one of '==', '!=', '<', '<=', '>', '<='
1486-
true: Branch target if comparison is true
1487-
false: Branch target if comparison is false
1488-
"""
1489-
is_eq = op in ("==", "!=")
1490-
if (is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)) or (
1491-
is_eq and (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type))
1492-
):
1493-
# We can skip the tag check
1494-
check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
1495-
self.flush_keep_alives()
1496-
self.add(Branch(check, true, false, Branch.BOOL))
1497-
return
1498-
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
1499-
int_block, short_int_block = BasicBlock(), BasicBlock()
1500-
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
1501-
if is_eq or is_short_int_rprimitive(rhs.type):
1502-
self.flush_keep_alives()
1503-
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
1504-
else:
1505-
# For non-equality logical ops (less/greater than, etc.), need to check both sides
1506-
rhs_block = BasicBlock()
1507-
self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL))
1508-
self.activate_block(rhs_block)
1509-
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
1510-
self.flush_keep_alives()
1511-
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
1512-
# Arbitrary integers (slow path)
1513-
self.activate_block(int_block)
1514-
if swap_op:
1515-
args = [rhs, lhs]
1516-
else:
1517-
args = [lhs, rhs]
1518-
call = self.call_c(c_func_desc, args, line)
1519-
if negate_result:
1520-
self.add(Branch(call, false, true, Branch.BOOL))
1521-
else:
1522-
self.flush_keep_alives()
1523-
self.add(Branch(call, true, false, Branch.BOOL))
1524-
# Short integers (fast path)
1453+
self.goto(out)
15251454
self.activate_block(short_int_block)
15261455
eq = self.comparison_op(lhs, rhs, op_type, line)
1527-
self.add(Branch(eq, true, false, Branch.BOOL))
1456+
self.add(Assign(result, eq, line))
1457+
self.goto_and_activate(out)
1458+
return result
15281459

15291460
def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
15301461
"""Compare two strings"""
@@ -2309,7 +2240,8 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val
23092240
length = self.gen_method_call(val, "__len__", [], int_rprimitive, line)
23102241
length = self.coerce(length, int_rprimitive, line)
23112242
ok, fail = BasicBlock(), BasicBlock()
2312-
self.compare_tagged_condition(length, Integer(0), ">=", ok, fail, line)
2243+
cond = self.binary_op(length, Integer(0), ">=", line)
2244+
self.add_bool_branch(cond, ok, fail)
23132245
self.activate_block(fail)
23142246
self.add(
23152247
RaiseStandardError(

mypyc/lower/int_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,23 @@ def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Va
1313
@lower_binary_op("int_ne")
1414
def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
1515
return builder.compare_tagged(args[0], args[1], "!=", line)
16+
17+
18+
@lower_binary_op("int_lt")
19+
def lower_int_lt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
20+
return builder.compare_tagged(args[0], args[1], "<", line)
21+
22+
23+
@lower_binary_op("int_le")
24+
def lower_int_le(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
25+
return builder.compare_tagged(args[0], args[1], "<=", line)
26+
27+
28+
@lower_binary_op("int_gt")
29+
def lower_int_gt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
30+
return builder.compare_tagged(args[0], args[1], ">", line)
31+
32+
33+
@lower_binary_op("int_ge")
34+
def lower_int_ge(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
35+
return builder.compare_tagged(args[0], args[1], ">=", line)

mypyc/primitives/int_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ def int_binary_primitive(
122122

123123
int_eq = int_binary_primitive(op="==", primitive_name="int_eq", return_type=bit_rprimitive)
124124
int_ne = int_binary_primitive(op="!=", primitive_name="int_ne", return_type=bit_rprimitive)
125+
int_lt = int_binary_primitive(op="<", primitive_name="int_lt", return_type=bit_rprimitive)
126+
int_le = int_binary_primitive(op="<=", primitive_name="int_le", return_type=bit_rprimitive)
127+
int_gt = int_binary_primitive(op=">", primitive_name="int_gt", return_type=bit_rprimitive)
128+
int_ge = int_binary_primitive(op=">=", primitive_name="int_ge", return_type=bit_rprimitive)
125129

126130

127131
def int_binary_op(

0 commit comments

Comments
 (0)