Skip to content

Commit 63177d3

Browse files
committed
Add scaling_exp argument to series
1 parent f88854e commit 63177d3

File tree

4 files changed

+100
-84
lines changed

4 files changed

+100
-84
lines changed

functional_algorithms/context.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def constant(self, value, like_expr=UNSPECIFIED):
254254
like_expr = self.default_like
255255
if like_expr is int:
256256
like_expr = self.symbol("_integer_value", int)
257-
print(f"LIKE: {like_expr}")
258257
return make_constant(self, value, like_expr)
259258

260259
def call(self, func, args):
@@ -524,5 +523,10 @@ def downcast(self, x):
524523
def is_finite(self, x):
525524
return Expr(self, "is_finite", (x,))
526525

527-
def series(self, unit_index, *terms):
528-
return make_series(self, unit_index, terms)
526+
def series(self, *terms, **params):
527+
unit_index = params.get("unit_index", 0)
528+
scaling_exp = params.get("scaling_exp", 0)
529+
return make_series(self, unit_index, scaling_exp, terms)
530+
531+
def _series(self, terms, params):
532+
return self.series(*terms, **params)

functional_algorithms/expr.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def make_apply(context, name, args, result):
121121
return Expr(context, "apply", (name, *args, result))
122122

123123

124-
def make_series(context, order, terms):
125-
return Expr(context, "series", (order, *terms))
124+
def make_series(context, unit_index, scaling_exp, terms):
125+
return Expr(context, "series", ((unit_index, scaling_exp), *terms))
126126

127127

128128
def normalize(context, operands):
@@ -219,8 +219,8 @@ def make_ref(expr):
219219
]
220220
if all_operands_have_ref_name:
221221
# for readability
222-
i = expr.operands[0]
223-
lst = [expr.kind] + [f"minus{i}" if i < 0 else str(i)] + list(map(make_ref, expr.operands[1:]))
222+
params = [f"minus{p}" if p < 0 else str(p) for p in expr.operands[0]]
223+
lst = [expr.kind] + params + list(map(make_ref, expr.operands[1:]))
224224
ref = "_".join(lst)
225225
else:
226226
ref = f"{expr.kind}_{expr.intkey}"
@@ -360,7 +360,9 @@ def __new__(cls, context, kind, operands):
360360
kind = "constant"
361361

362362
if kind == "series":
363-
assert isinstance(operands[0], int), type(operands[0])
363+
assert isinstance(operands[0], tuple) and len(operands[0]) == 2, type(operands[0])
364+
assert isinstance(operands[0][0], int)
365+
assert isinstance(operands[0][1], int)
364366
assert False not in [isinstance(operand, Expr) for operand in operands[1:]], operands
365367
else:
366368
assert False not in [isinstance(operand, Expr) for operand in operands], operands

functional_algorithms/rewrite.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def _binary_op(self, expr, op):
480480
return expr.context.constant(r, xlike)
481481

482482
@staticmethod
483-
def _add_terms(index_terms1, index_terms2, op=None):
483+
def _add_terms(params_terms1, params_terms2, op=None):
484484
"""\
485485
486486
Let `^` indicate the location of unity, then
@@ -517,11 +517,12 @@ def op(x, y):
517517
return x + y
518518

519519
swapped = False
520-
if index_terms1[0] < index_terms2[0]:
521-
index_terms1, index_terms2 = index_terms2, index_terms1
520+
if params_terms1[0][0] < params_terms2[0][0]:
521+
params_terms1, params_terms2 = params_terms2, params_terms1
522522
swapped = True
523-
index1, terms1 = index_terms1[0], index_terms1[1:]
524-
index2, terms2 = index_terms2[0], index_terms2[1:]
523+
(index1, sexp1), terms1 = params_terms1[0], params_terms1[1:]
524+
(index2, sexp2), terms2 = params_terms2[0], params_terms2[1:]
525+
assert sexp1 == sexp2, (sexp1, sexp2)
525526

526527
terms = []
527528

@@ -546,10 +547,10 @@ def op(x, y):
546547
else:
547548
terms.append(0)
548549

549-
return (index1, *terms)
550+
return (terms, dict(unit_index=index1, scaling_exp=sexp1))
550551

551552
@staticmethod
552-
def _subtract_terms(index_terms1, index_terms2):
553+
def _subtract_terms(params_terms1, params_terms2):
553554

554555
def op(x, y):
555556
if x is None:
@@ -558,10 +559,10 @@ def op(x, y):
558559
return x
559560
return x - y
560561

561-
return Rewriter._add_terms(index_terms1, index_terms2, op=op)
562+
return Rewriter._add_terms(params_terms1, params_terms2, op=op)
562563

563564
@staticmethod
564-
def _multiply_terms(index_terms1, index_terms2, op=None):
565+
def _multiply_terms(params_terms1, params_terms2, op=None):
565566
"""\
566567
567568
Let `^` indicate the location of unity, then
@@ -585,8 +586,9 @@ def _multiply_terms(index_terms1, index_terms2, op=None):
585586
def op(x, y):
586587
return x * y
587588

588-
index1, terms1 = index_terms1[0], index_terms1[1:]
589-
index2, terms2 = index_terms2[0], index_terms2[1:]
589+
(index1, sexp1), terms1 = params_terms1[0], params_terms1[1:]
590+
(index2, sexp2), terms2 = params_terms2[0], params_terms2[1:]
591+
assert sexp1 == sexp2, (sexp1, sexp2)
590592

591593
terms = []
592594
for n in range(len(terms1) + len(terms2) - 1):
@@ -600,7 +602,7 @@ def op(x, y):
600602
xy += op(x, y)
601603
assert xy is not None
602604
terms.append(xy)
603-
return (index1 + index2, *terms)
605+
return terms, dict(unit_index=index1 + index2, scaling_exp=sexp1)
604606

605607
def add(self, expr):
606608
result = self._binary_op(expr, lambda x, y: x + y)
@@ -616,10 +618,10 @@ def add(self, expr):
616618

617619
if x.kind == "series":
618620
if y.kind == "series":
619-
return expr.context.series(*self._add_terms(x.operands, y.operands))
620-
return expr.context.series(*self._add_terms(x.operands, (0, y)))
621+
return expr.context._series(*self._add_terms(x.operands, y.operands))
622+
return expr.context._series(*self._add_terms(x.operands, ((0, x.operands[0][1]), y)))
621623
elif y.kind == "series":
622-
return expr.context.series(*self._add_terms((0, x), y.operands))
624+
return expr.context._series(*self._add_terms(((0, y.operands[0][1]), x), y.operands))
623625

624626
def subtract(self, expr):
625627
result = self._binary_op(expr, lambda x, y: x - y)
@@ -636,10 +638,10 @@ def subtract(self, expr):
636638

637639
if x.kind == "series":
638640
if y.kind == "series":
639-
return expr.context.series(*self._subtract_terms(x.operands, y.operands))
640-
return expr.context.series(*self._subtract_terms(x.operands, (0, y)))
641+
return expr.context._series(*self._subtract_terms(x.operands, y.operands))
642+
return expr.context._series(*self._subtract_terms(x.operands, ((0, x.operands[0][1]), y)))
641643
elif y.kind == "series":
642-
return expr.context.series(*self._subtract_terms((0, x), y.operands))
644+
return expr.context._series(*self._subtract_terms(((0, y.operands[0][1]), x), y.operands))
643645

644646
def multiply(self, expr):
645647
result = self._binary_op(expr, lambda x, y: x * y)
@@ -656,10 +658,10 @@ def multiply(self, expr):
656658

657659
if x.kind == "series":
658660
if y.kind == "series":
659-
return expr.context.series(*self._multiply_terms(x.operands, y.operands))
660-
return expr.context.series(*self._multiply_terms(x.operands, (0, y)))
661+
return expr.context._series(*self._multiply_terms(x.operands, y.operands))
662+
return expr.context._series(*self._multiply_terms(x.operands, ((0, x.operands[0][1]), y)))
661663
elif y.kind == "series":
662-
return expr.context.series(*self._multiply_terms((0, x), y.operands))
664+
return expr.context._series(*self._multiply_terms(((0, y.operands[0][1]), x), y.operands))
663665

664666
def minimum(self, expr):
665667
return self._binary_op(expr, lambda x, y: min(x, y))
@@ -679,7 +681,7 @@ def divide(self, expr):
679681
def op(x, y):
680682
return x / y
681683

682-
return expr.context.series(*self._multiply_terms(x.operands, (0, y), op=op))
684+
return expr.context._series(*self._multiply_terms(x.operands, ((0, 0), y), op=op))
683685

684686
def complex(self, expr):
685687
pass
@@ -805,7 +807,7 @@ def negative(self, expr):
805807
return x.operands[0]
806808

807809
if x.kind == "series":
808-
return expr.context.series(x.operands[0], *(-x_ for x_ in x.operands[1:]))
810+
return _expr.make_series(expr.context, *x.operands[0], tuple(-x_ for x_ in x.operands[1:]))
809811

810812
def conjugate(self, expr):
811813

@@ -826,7 +828,7 @@ def conjugate(self, expr):
826828
return x
827829

828830
if x.kind == "series":
829-
return expr.context.series(x.operands[0], *(expr.context.conjugate(x_) for x_ in x.operands[1:]))
831+
return _expr.make_series(expr.context, *x.operands[0], tuple(expr.context.conjugate(x_) for x_ in x.operands[1:]))
830832

831833
def real(self, expr):
832834

@@ -839,7 +841,7 @@ def real(self, expr):
839841
return x.operands[0]
840842

841843
if x.kind == "series":
842-
return expr.context.series(x.operands[0], *(expr.context.real(x_) for x_ in x.operands[1:]))
844+
return _expr.make_series(expr.context, *x.operands[0], tuple(expr.context.real(x_) for x_ in x.operands[1:]))
843845

844846
def imag(self, expr):
845847

@@ -852,7 +854,7 @@ def imag(self, expr):
852854
return x.operands[1]
853855

854856
if x.kind == "series":
855-
return expr.context.series(x.operands[0], *(expr.context.imag(x_) for x_ in x.operands[1:]))
857+
return _expr.make_series(expr.context, *x.operands[0], tuple(expr.context.imag(x_) for x_ in x.operands[1:]))
856858

857859
def _compare(self, expr, relop, relop_index, swap_relop_index):
858860
x, y = expr.operands
@@ -974,7 +976,7 @@ def square(self, expr):
974976
return self._eval(like, "square", value)
975977

976978
if x.kind == "series":
977-
return expr.context.series(*self._multiply_terms(x.operands, x.operands))
979+
return expr.context._series(*self._multiply_terms(x.operands, x.operands))
978980

979981
def pow(self, expr):
980982
base, exp = expr.operands
@@ -1091,10 +1093,18 @@ class ReplaceSeries:
10911093
def __rewrite_modifier__(self, expr):
10921094
if expr.kind == "series":
10931095
s = None
1094-
for t in reversed(expr.operands[1:]):
1095-
if s is None:
1096-
s = t
1096+
index, sexp = expr.operands[0]
1097+
for i, t in enumerate(reversed(expr.operands[1:])):
1098+
i = len(expr.operands[1:]) - 1 - i + index
1099+
if sexp == 0 or i == 0:
1100+
if s is None:
1101+
s = t
1102+
else:
1103+
s += t
10971104
else:
1098-
s += t
1105+
if s is None:
1106+
s = t * expr.context.constant(2 ** (-i * sexp), t)
1107+
else:
1108+
s += t * expr.context.constant(2 ** (-i * sexp), t)
10991109
return s
11001110
return expr

functional_algorithms/tests/test_expr.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -760,61 +760,61 @@ def test_series():
760760
z = ctx.symbol("z")
761761
w = ctx.symbol("w")
762762

763-
s1 = ctx.series(0, x, y)
764-
s2 = ctx.series(0, z, w)
765-
s3 = ctx.series(1, z, w)
766-
s4 = ctx.series(-1, z, w)
767-
s5 = ctx.series(-3, z, w)
763+
s1 = ctx.series(x, y, unit_index=0)
764+
s2 = ctx.series(z, w, unit_index=0)
765+
s3 = ctx.series(z, w, unit_index=1)
766+
s4 = ctx.series(z, w, unit_index=-1)
767+
s5 = ctx.series(z, w, unit_index=-3)
768768

769769
def rewrite(expr):
770770
return expr.rewrite(fa.rewrite)
771771

772772
assert_equal(s1 + s2, ctx.add(s1, s2))
773-
assert_equal(rewrite(s1 + s2), ctx.series(0, x + z, y + w))
774-
assert_equal(rewrite(s1 + s3), ctx.series(1, z, x + w, y))
775-
assert_equal(rewrite(s3 + s1), ctx.series(1, z, w + x, y))
776-
assert_equal(rewrite(s1 + s4), ctx.series(0, x, y + z, w))
777-
assert_equal(rewrite(s4 + s1), ctx.series(0, x, z + y, w))
778-
assert_equal(rewrite(s1 + s5), ctx.series(0, x, y, 0, z, w))
779-
assert_equal(rewrite(s5 + s1), ctx.series(0, x, y, 0, z, w))
780-
assert_equal(rewrite(s4 + s5), ctx.series(-1, z, w, z, w))
773+
assert_equal(rewrite(s1 + s2), ctx.series(x + z, y + w, unit_index=0))
774+
assert_equal(rewrite(s1 + s3), ctx.series(z, x + w, y, unit_index=1))
775+
assert_equal(rewrite(s3 + s1), ctx.series(z, w + x, y, unit_index=1))
776+
assert_equal(rewrite(s1 + s4), ctx.series(x, y + z, w, unit_index=0))
777+
assert_equal(rewrite(s4 + s1), ctx.series(x, z + y, w, unit_index=0))
778+
assert_equal(rewrite(s1 + s5), ctx.series(x, y, 0, z, w, unit_index=0))
779+
assert_equal(rewrite(s5 + s1), ctx.series(x, y, 0, z, w, unit_index=0))
780+
assert_equal(rewrite(s4 + s5), ctx.series(z, w, z, w, unit_index=-1))
781781

782782
assert_equal(s1 - s2, ctx.subtract(s1, s2))
783-
assert_equal(rewrite(s1 - s2), ctx.series(0, x - z, y - w))
784-
assert_equal(rewrite(s1 - s3), ctx.series(1, -z, x - w, y))
785-
assert_equal(rewrite(s3 - s1), ctx.series(1, z, w - x, -y))
786-
assert_equal(rewrite(s1 - s4), ctx.series(0, x, y - z, -w))
787-
assert_equal(rewrite(s4 - s1), ctx.series(0, -x, z - y, w))
788-
assert_equal(rewrite(s1 - s5), ctx.series(0, x, y, 0, -z, -w))
789-
assert_equal(rewrite(s5 - s1), ctx.series(0, -x, -y, 0, z, w))
790-
assert_equal(rewrite(s4 - s5), ctx.series(-1, z, w, -z, -w))
791-
assert_equal(rewrite(s5 - s4), ctx.series(-1, -z, -w, z, w))
792-
793-
assert_equal(rewrite(s1 + z), ctx.series(0, x + z, y))
794-
assert_equal(rewrite(s1 - z), ctx.series(0, x - z, y))
795-
assert_equal(rewrite(z + s1), ctx.series(0, z + x, y))
796-
assert_equal(rewrite(z - s1), ctx.series(0, z - x, -y))
797-
assert_equal(rewrite(s2 + x), ctx.series(0, z + x, w))
798-
assert_equal(rewrite(s3 + x), ctx.series(1, z, w + x))
799-
assert_equal(rewrite(s4 + x), ctx.series(0, x, z, w))
800-
assert_equal(rewrite(s5 + x), ctx.series(0, x, 0, 0, z, w))
783+
assert_equal(rewrite(s1 - s2), ctx.series(x - z, y - w, unit_index=0))
784+
assert_equal(rewrite(s1 - s3), ctx.series(-z, x - w, y, unit_index=1))
785+
assert_equal(rewrite(s3 - s1), ctx.series(z, w - x, -y, unit_index=1))
786+
assert_equal(rewrite(s1 - s4), ctx.series(x, y - z, -w, unit_index=0))
787+
assert_equal(rewrite(s4 - s1), ctx.series(-x, z - y, w, unit_index=0))
788+
assert_equal(rewrite(s1 - s5), ctx.series(x, y, 0, -z, -w, unit_index=0))
789+
assert_equal(rewrite(s5 - s1), ctx.series(-x, -y, 0, z, w, unit_index=0))
790+
assert_equal(rewrite(s4 - s5), ctx.series(z, w, -z, -w, unit_index=-1))
791+
assert_equal(rewrite(s5 - s4), ctx.series(-z, -w, z, w, unit_index=-1))
792+
793+
assert_equal(rewrite(s1 + z), ctx.series(x + z, y, unit_index=0))
794+
assert_equal(rewrite(s1 - z), ctx.series(x - z, y, unit_index=0))
795+
assert_equal(rewrite(z + s1), ctx.series(z + x, y, unit_index=0))
796+
assert_equal(rewrite(z - s1), ctx.series(z - x, -y, unit_index=0))
797+
assert_equal(rewrite(s2 + x), ctx.series(z + x, w, unit_index=0))
798+
assert_equal(rewrite(s3 + x), ctx.series(z, w + x, unit_index=1))
799+
assert_equal(rewrite(s4 + x), ctx.series(x, z, w, unit_index=0))
800+
assert_equal(rewrite(s5 + x), ctx.series(x, 0, 0, z, w, unit_index=0))
801801

802802
assert_equal(s1 * s2, ctx.multiply(s1, s2))
803-
assert_equal(rewrite(s1 * s2), ctx.series(0, x * z, x * w + y * z, y * w))
804-
assert_equal(rewrite(s2 * s2), ctx.series(0, z * z, z * w + w * z, w * w))
805-
assert_equal(rewrite(s3 * s3), ctx.series(2, z * z, z * w + w * z, w * w))
806-
assert_equal(rewrite(s4 * s4), ctx.series(-2, z * z, z * w + w * z, w * w))
807-
assert_equal(rewrite(s4 * s5), ctx.series(-4, z * z, z * w + w * z, w * w))
803+
assert_equal(rewrite(s1 * s2), ctx.series(x * z, x * w + y * z, y * w, unit_index=0))
804+
assert_equal(rewrite(s2 * s2), ctx.series(z * z, z * w + w * z, w * w, unit_index=0))
805+
assert_equal(rewrite(s3 * s3), ctx.series(z * z, z * w + w * z, w * w, unit_index=2))
806+
assert_equal(rewrite(s4 * s4), ctx.series(z * z, z * w + w * z, w * w, unit_index=-2))
807+
assert_equal(rewrite(s4 * s5), ctx.series(z * z, z * w + w * z, w * w, unit_index=-4))
808808

809-
assert_equal(rewrite(s3**2), ctx.series(2, z * z, z * w + w * z, w * w))
809+
assert_equal(rewrite(s3**2), ctx.series(z * z, z * w + w * z, w * w, unit_index=2))
810810

811-
assert_equal(rewrite(s1 * z), ctx.series(0, x * z, y * z))
812-
assert_equal(rewrite(z * s1), ctx.series(0, z * x, z * y))
813-
assert_equal(rewrite(s2 * x), ctx.series(0, z * x, w * x))
814-
assert_equal(rewrite(s3 * x), ctx.series(1, z * x, w * x))
815-
assert_equal(rewrite(s5 * x), ctx.series(-3, z * x, w * x))
811+
assert_equal(rewrite(s1 * z), ctx.series(x * z, y * z, unit_index=0))
812+
assert_equal(rewrite(z * s1), ctx.series(z * x, z * y, unit_index=0))
813+
assert_equal(rewrite(s2 * x), ctx.series(z * x, w * x, unit_index=0))
814+
assert_equal(rewrite(s3 * x), ctx.series(z * x, w * x, unit_index=1))
815+
assert_equal(rewrite(s5 * x), ctx.series(z * x, w * x, unit_index=-3))
816816

817-
assert_equal(rewrite(s1 / z), ctx.series(0, x / z, y / z))
817+
assert_equal(rewrite(s1 / z), ctx.series(x / z, y / z, unit_index=0))
818818

819-
assert_equal(rewrite(-s1), ctx.series(0, -x, -y))
820-
assert_equal(rewrite(-s3), ctx.series(1, -z, -w))
819+
assert_equal(rewrite(-s1), ctx.series(-x, -y, unit_index=0))
820+
assert_equal(rewrite(-s3), ctx.series(-z, -w, unit_index=1))

0 commit comments

Comments
 (0)