diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30b23feddc..38c7d1e3f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -72,6 +72,7 @@ repos: src/eve/datamodels.py | # until datamodels is upgraded to 3.10 src/functional/ffront/func_to_foast.py | # until mypy supports match src/functional/ffront/symbol_makers.py | # until mypy supports match + src/functional/ffront/common_types.py | # until mypy supports match src/functional/ffront/foast_passes/type_deduction.py | # until mypy supports match docs/eve/conf.py | tests/eve_tests/definitions.py | diff --git a/src/functional/ffront/common_types.py b/src/functional/ffront/common_types.py index c72e856af4..d05ea2ca37 100644 --- a/src/functional/ffront/common_types.py +++ b/src/functional/ffront/common_types.py @@ -1,9 +1,10 @@ import typing from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Literal, Optional, Type, TypeGuard, Union from eve.type_definitions import IntEnum from functional import common as func_common +from functional.ffront import common_types class ScalarKind(IntEnum): @@ -93,3 +94,150 @@ def __str__(self): kwarg_strs = [f"{key}: {value}" for key, value in self.kwargs] args_str = ", ".join(*arg_strs, *kwarg_strs) return f"({args_str}) -> {self.returns}" + + +def is_complete_symbol_type(fo_type: common_types.SymbolType) -> TypeGuard[common_types.SymbolType]: + """Figure out if the foast type is completely deduced.""" + match fo_type: + case None: + return False + case common_types.DeferredSymbolType(): + return False + case common_types.SymbolType(): + return True + return False + + +@dataclass +class TypeInfo: + """ + Wrapper around foast types for type deduction and compatibility checks. + + Examples: + --------- + >>> type_a = common_types.ScalarType(kind=common_types.ScalarKind.FLOAT64) + >>> typeinfo_a = TypeInfo(type_a) + >>> typeinfo_a.is_complete + True + >>> typeinfo_a.is_arithmetic_compatible + True + >>> typeinfo_a.is_logics_compatible + False + >>> typeinfo_b = TypeInfo(None) + >>> typeinfo_b.is_any_type + True + >>> typeinfo_b.is_arithmetic_compatible + False + >>> typeinfo_b.can_be_refined_to(typeinfo_a) + True + + """ + + type: common_types.SymbolType # noqa: A003 + + @property + def is_complete(self) -> bool: + return is_complete_symbol_type(self.type) + + @property + def is_any_type(self) -> bool: + return (not self.is_complete) and ((self.type is None) or (self.constraint is None)) + + @property + def constraint(self) -> Optional[Type[common_types.SymbolType]]: + """Find the constraint of a deferred type or the class of a complete type.""" + if self.is_complete: + return self.type.__class__ + elif self.type: + return self.type.constraint + return None + + @property + def is_field_type(self) -> bool: + return issubclass(self.constraint, common_types.FieldType) if self.constraint else False + + @property + def is_scalar(self) -> bool: + return issubclass(self.constraint, common_types.ScalarType) if self.constraint else False + + @property + def is_arithmetic_compatible(self) -> bool: + match self.type: + case common_types.FieldType( + dtype=common_types.ScalarType(kind=dtype_kind) + ) | common_types.ScalarType(kind=dtype_kind): + if dtype_kind is not common_types.ScalarKind.BOOL: + return True + return False + + @property + def is_logics_compatible(self) -> bool: + match self.type: + case common_types.FieldType( + dtype=common_types.ScalarType(kind=dtype_kind) + ) | common_types.ScalarType(kind=dtype_kind): + if dtype_kind is common_types.ScalarKind.BOOL: + return True + return False + + def can_be_refined_to(self, other: "TypeInfo") -> bool: + if self.is_any_type: + return True + if self.is_complete: + return self.type == other.type + if self.constraint: + if other.is_complete: + return isinstance(other.type, self.constraint) + elif other.constraint: + return self.constraint is other.constraint + return False + + +def are_broadcast_compatible(left: TypeInfo, right: TypeInfo) -> bool: + """ + Check if ``left`` and ``right`` types are compatible after optional broadcasting. + + A binary field operation between two arguments can proceed and the result is a field. + on top of the dimensions, also the dtypes must match. + + Examples: + --------- + >>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64)) + >>> are_broadcast_compatible(int_scalar_t, int_scalar_t) + True + >>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64), + ... dims=...)) + >>> are_broadcast_compatible(int_field_t, int_scalar_t) + True + + """ + if left.is_field_type and right.is_field_type: + return left.type.dims == right.type.dims + elif left.is_field_type and right.is_scalar: + return left.type.dtype == right.type + elif left.is_scalar and left.is_field_type: + return left.type == right.type.dtype + elif left.is_scalar and right.is_scalar: + return left.type == right.type + return False + + +def broadcast_typeinfos(left: TypeInfo, right: TypeInfo) -> TypeInfo: + """ + Decide the result type of a binary operation between arguments of ``left`` and ``right`` type. + + Return None if the two types are not compatible even after broadcasting. + + Examples: + --------- + >>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64)) + >>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64), + ... dims=...)) + >>> assert broadcast_typeinfos(int_field_t, int_scalar_t).type == int_field_t.type + + """ + if not are_broadcast_compatible(left, right): + return None + if left.is_scalar and right.is_field_type: + return right + return left diff --git a/src/functional/ffront/foast_passes/type_deduction.py b/src/functional/ffront/foast_passes/type_deduction.py index 4ec6ac0cd2..65299c349b 100644 --- a/src/functional/ffront/foast_passes/type_deduction.py +++ b/src/functional/ffront/foast_passes/type_deduction.py @@ -11,8 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass -from typing import Optional, Type, TypeGuard +from typing import Optional import functional.ffront.field_operator_ast as foast from eve import NodeTranslator, SymbolTableTrait @@ -20,153 +19,6 @@ from functional.ffront import common_types -def is_complete_symbol_type(fo_type: common_types.SymbolType) -> TypeGuard[common_types.SymbolType]: - """Figure out if the foast type is completely deduced.""" - match fo_type: - case None: - return False - case common_types.DeferredSymbolType(): - return False - case common_types.SymbolType(): - return True - return False - - -@dataclass -class TypeInfo: - """ - Wrapper around foast types for type deduction and compatibility checks. - - Examples: - --------- - >>> type_a = common_types.ScalarType(kind=common_types.ScalarKind.FLOAT64) - >>> typeinfo_a = TypeInfo(type_a) - >>> typeinfo_a.is_complete - True - >>> typeinfo_a.is_arithmetic_compatible - True - >>> typeinfo_a.is_logics_compatible - False - >>> typeinfo_b = TypeInfo(None) - >>> typeinfo_b.is_any_type - True - >>> typeinfo_b.is_arithmetic_compatible - False - >>> typeinfo_b.can_be_refined_to(typeinfo_a) - True - - """ - - type: common_types.SymbolType # noqa: A003 - - @property - def is_complete(self) -> bool: - return is_complete_symbol_type(self.type) - - @property - def is_any_type(self) -> bool: - return (not self.is_complete) and ((self.type is None) or (self.constraint is None)) - - @property - def constraint(self) -> Optional[Type[common_types.SymbolType]]: - """Find the constraint of a deferred type or the class of a complete type.""" - if self.is_complete: - return self.type.__class__ - elif self.type: - return self.type.constraint - return None - - @property - def is_field_type(self) -> bool: - return issubclass(self.constraint, common_types.FieldType) if self.constraint else False - - @property - def is_scalar(self) -> bool: - return issubclass(self.constraint, common_types.ScalarType) if self.constraint else False - - @property - def is_arithmetic_compatible(self) -> bool: - match self.type: - case common_types.FieldType( - dtype=common_types.ScalarType(kind=dtype_kind) - ) | common_types.ScalarType(kind=dtype_kind): - if dtype_kind is not common_types.ScalarKind.BOOL: - return True - return False - - @property - def is_logics_compatible(self) -> bool: - match self.type: - case common_types.FieldType( - dtype=common_types.ScalarType(kind=dtype_kind) - ) | common_types.ScalarType(kind=dtype_kind): - if dtype_kind is common_types.ScalarKind.BOOL: - return True - return False - - def can_be_refined_to(self, other: "TypeInfo") -> bool: - if self.is_any_type: - return True - if self.is_complete: - return self.type == other.type - if self.constraint: - if other.is_complete: - return isinstance(other.type, self.constraint) - elif other.constraint: - return self.constraint is other.constraint - return False - - -def are_broadcast_compatible(left: TypeInfo, right: TypeInfo) -> bool: - """ - Check if ``left`` and ``right`` types are compatible after optional broadcasting. - - A binary field operation between two arguments can proceed and the result is a field. - on top of the dimensions, also the dtypes must match. - - Examples: - --------- - >>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64)) - >>> are_broadcast_compatible(int_scalar_t, int_scalar_t) - True - >>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64), - ... dims=...)) - >>> are_broadcast_compatible(int_field_t, int_scalar_t) - True - - """ - if left.is_field_type and right.is_field_type: - return left.type.dims == right.type.dims - elif left.is_field_type and right.is_scalar: - return left.type.dtype == right.type - elif left.is_scalar and left.is_field_type: - return left.type == right.type.dtype - elif left.is_scalar and right.is_scalar: - return left.type == right.type - return False - - -def broadcast_typeinfos(left: TypeInfo, right: TypeInfo) -> TypeInfo: - """ - Decide the result type of a binary operation between arguments of ``left`` and ``right`` type. - - Return None if the two types are not compatible even after broadcasting. - - Examples: - --------- - >>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64)) - >>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64), - ... dims=...)) - >>> assert broadcast_typeinfos(int_field_t, int_scalar_t).type == int_field_t.type - - """ - if not are_broadcast_compatible(left, right): - return None - if left.is_scalar and right.is_field_type: - return right - return left - - class FieldOperatorTypeDeduction(NodeTranslator): """ Deduce and check types of FOAST expressions and symbols. @@ -212,14 +64,13 @@ def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: raise FieldOperatorTypeDeductionError.from_foast_node( node, msg=f"Undeclared symbol {node.id}" ) - return node symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) def visit_Assign(self, node: foast.Assign, **kwargs) -> foast.Assign: new_value = node.value - if not is_complete_symbol_type(node.value.type): + if not common_types.is_complete_symbol_type(node.value.type): new_value = self.visit(node.value, **kwargs) new_target = self.visit(node.target, refine_type=new_value.type, **kwargs) return foast.Assign(target=new_target, value=new_value, location=node.location) @@ -232,7 +83,7 @@ def visit_Symbol( ) -> foast.Symbol: symtable = kwargs["symtable"] if refine_type: - if not TypeInfo(node.type).can_be_refined_to(TypeInfo(refine_type)): + if not common_types.TypeInfo(node.type).can_be_refined_to(common_types.TypeInfo(refine_type)): raise FieldOperatorTypeDeductionError.from_foast_node( node, msg=( @@ -311,13 +162,13 @@ def _deduce_arithmetic_binop_type( right_type: common_types.SymbolType, **kwargs, ) -> common_types.SymbolType: - left, right = TypeInfo(left_type), TypeInfo(right_type) + left, right = common_types.TypeInfo(left_type), common_types.TypeInfo(right_type) if ( left.is_arithmetic_compatible and right.is_arithmetic_compatible - and are_broadcast_compatible(left, right) + and common_types.are_broadcast_compatible(left, right) ): - return broadcast_typeinfos(left, right).type + return common_types.broadcast_typeinfos(left, right).type else: raise FieldOperatorTypeDeductionError.from_foast_node( parent, @@ -333,13 +184,13 @@ def _deduce_logical_binop_type( right_type: common_types.SymbolType, **kwargs, ) -> common_types.SymbolType: - left, right = TypeInfo(left_type), TypeInfo(right_type) + left, right = common_types.TypeInfo(left_type), common_types.TypeInfo(right_type) if ( left.is_logics_compatible and right.is_logics_compatible - and are_broadcast_compatible(left, right) + and common_types.are_broadcast_compatible(left, right) ): - return broadcast_typeinfos(left, right).type + return common_types.broadcast_typeinfos(left, right).type else: raise FieldOperatorTypeDeductionError.from_foast_node( parent, @@ -360,7 +211,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: def _is_unaryop_type_compatible( self, op: foast.UnaryOperator, operand_type: common_types.FieldType ) -> bool: - operand_ti = TypeInfo(operand_type) + operand_ti = common_types.TypeInfo(operand_type) if op in [foast.UnaryOperator.UADD, foast.UnaryOperator.USUB]: return operand_ti.is_arithmetic_compatible elif op is foast.UnaryOperator.NOT: diff --git a/tests/functional_tests/ffront_tests/test_type_deduction.py b/tests/functional_tests/ffront_tests/test_type_deduction.py index c2a85456f0..008acf7bab 100644 --- a/tests/functional_tests/ffront_tests/test_type_deduction.py +++ b/tests/functional_tests/ffront_tests/test_type_deduction.py @@ -17,8 +17,9 @@ from functional.common import Dimension from functional.ffront import common_types from functional.ffront import field_operator_ast as foast +from functional.ffront.common_types import TypeInfo from functional.ffront.fbuiltins import Field, float64, int64 -from functional.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError, TypeInfo +from functional.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError from functional.ffront.func_to_foast import FieldOperatorParser