Skip to content

Commit 5258116

Browse files
kasiaMarektgodzik
authored andcommitted
refactor: extract apply agrs utils (pc)
1 parent d5ca220 commit 5258116

File tree

2 files changed

+285
-266
lines changed

2 files changed

+285
-266
lines changed
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
package dotty.tools.pc
2+
3+
import scala.util.Try
4+
5+
import dotty.tools.dotc.ast.Trees.ValDef
6+
import dotty.tools.dotc.ast.tpd.*
7+
import dotty.tools.dotc.core.Contexts.Context
8+
import dotty.tools.dotc.core.Flags
9+
import dotty.tools.dotc.core.Flags.Method
10+
import dotty.tools.dotc.core.Names.Name
11+
import dotty.tools.dotc.core.StdNames.*
12+
import dotty.tools.dotc.core.SymDenotations.NoDenotation
13+
import dotty.tools.dotc.core.Symbols.defn
14+
import dotty.tools.dotc.core.Symbols.NoSymbol
15+
import dotty.tools.dotc.core.Symbols.Symbol
16+
import dotty.tools.dotc.core.Types.*
17+
import dotty.tools.pc.IndexedContext
18+
import dotty.tools.pc.utils.InteractiveEnrichments.*
19+
import scala.annotation.tailrec
20+
import dotty.tools.dotc.core.Denotations.SingleDenotation
21+
import dotty.tools.dotc.core.Denotations.MultiDenotation
22+
import dotty.tools.dotc.util.Spans.Span
23+
24+
object ApplyExtractor:
25+
def unapply(path: List[Tree])(using Context): Option[Apply] =
26+
path match
27+
case ValDef(_, _, _) :: Block(_, app: Apply) :: _
28+
if !app.fun.isInfix => Some(app)
29+
case rest =>
30+
def getApplyForContextFunctionParam(path: List[Tree]): Option[Apply] =
31+
path match
32+
// fun(arg@@)
33+
case (app: Apply) :: _ => Some(app)
34+
// fun(arg@@), where fun(argn: Context ?=> SomeType)
35+
// recursively matched for multiple context arguments, e.g. Context1 ?=> Context2 ?=> SomeType
36+
case (_: DefDef) :: Block(List(_), _: Closure) :: rest =>
37+
getApplyForContextFunctionParam(rest)
38+
case _ => None
39+
for
40+
app <- getApplyForContextFunctionParam(rest)
41+
if !app.fun.isInfix
42+
yield app
43+
end match
44+
45+
46+
object ApplyArgsExtractor:
47+
def getArgsAndParams(
48+
indexedContext: IndexedContext,
49+
apply: Apply,
50+
span: Span
51+
)(using Context): List[(List[Tree], List[ParamSymbol])] =
52+
def collectArgss(a: Apply): List[List[Tree]] =
53+
def stripContextFuntionArgument(argument: Tree): List[Tree] =
54+
argument match
55+
case Block(List(d: DefDef), _: Closure) =>
56+
d.rhs match
57+
case app: Apply =>
58+
app.args
59+
case b @ Block(List(_: DefDef), _: Closure) =>
60+
stripContextFuntionArgument(b)
61+
case _ => Nil
62+
case v => List(v)
63+
64+
val args = a.args.flatMap(stripContextFuntionArgument)
65+
a.fun match
66+
case app: Apply => collectArgss(app) :+ args
67+
case _ => List(args)
68+
end collectArgss
69+
70+
val method = apply.fun
71+
72+
val argss = collectArgss(apply)
73+
74+
def fallbackFindApply(sym: Symbol) =
75+
sym.info.member(nme.apply) match
76+
case NoDenotation => Nil
77+
case den => List(den.symbol)
78+
79+
// fallback for when multiple overloaded methods match the supplied args
80+
def fallbackFindMatchingMethods() =
81+
def matchingMethodsSymbols(
82+
method: Tree
83+
): List[Symbol] =
84+
method match
85+
case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil)
86+
case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil)
87+
case sel @ Select(from, name) =>
88+
val symbol = from.symbol
89+
val ownerSymbol =
90+
if symbol.is(Method) && symbol.owner.isClass then
91+
Some(symbol.owner)
92+
else Try(symbol.info.classSymbol).toOption
93+
ownerSymbol.map(sym => sym.info.member(name)).collect{
94+
case single: SingleDenotation => List(single.symbol)
95+
case multi: MultiDenotation => multi.allSymbols
96+
}.getOrElse(Nil)
97+
case Apply(fun, _) => matchingMethodsSymbols(fun)
98+
case _ => Nil
99+
val matchingMethods =
100+
for
101+
potentialMatch <- matchingMethodsSymbols(method)
102+
if potentialMatch.is(Flags.Method) &&
103+
potentialMatch.vparamss.length >= argss.length &&
104+
Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption
105+
.getOrElse(false) &&
106+
potentialMatch.vparamss
107+
.zip(argss)
108+
.reverse
109+
.zipWithIndex
110+
.forall { case (pair, index) =>
111+
FuzzyArgMatcher(potentialMatch.tparams)
112+
.doMatch(allArgsProvided = index != 0, span)
113+
.tupled(pair)
114+
}
115+
yield potentialMatch
116+
matchingMethods
117+
end fallbackFindMatchingMethods
118+
119+
val matchingMethods: List[Symbol] =
120+
if method.symbol.paramSymss.nonEmpty then
121+
val allArgsAreSupplied =
122+
val vparamss = method.symbol.vparamss
123+
vparamss.length == argss.length && vparamss
124+
.zip(argss)
125+
.lastOption
126+
.exists { case (baseParams, baseArgs) =>
127+
baseArgs.length == baseParams.length
128+
}
129+
// ```
130+
// m(arg : Int)
131+
// m(arg : Int, anotherArg : Int)
132+
// m(a@@)
133+
// ```
134+
// complier will choose the first `m`, so we need to manually look for the other one
135+
if allArgsAreSupplied then
136+
val foundPotential = fallbackFindMatchingMethods()
137+
if foundPotential.contains(method.symbol) then foundPotential
138+
else method.symbol :: foundPotential
139+
else List(method.symbol)
140+
else if method.symbol.is(Method) || method.symbol == NoSymbol then
141+
fallbackFindMatchingMethods()
142+
else fallbackFindApply(method.symbol)
143+
end if
144+
end matchingMethods
145+
146+
matchingMethods.map { methodSym =>
147+
val vparamss = methodSym.vparamss
148+
149+
// get params and args we are interested in
150+
// e.g.
151+
// in the following case, the interesting args and params are
152+
// - params: [apple, banana]
153+
// - args: [apple, b]
154+
// ```
155+
// def curry(x: Int)(apple: String, banana: String) = ???
156+
// curry(1)(apple = "test", b@@)
157+
// ```
158+
val (baseParams0, baseArgs) =
159+
vparamss.zip(argss).lastOption.getOrElse((Nil, Nil))
160+
161+
val baseParams: List[ParamSymbol] =
162+
def defaultBaseParams = baseParams0.map(JustSymbol(_))
163+
@tailrec
164+
def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] =
165+
if level > 0 then
166+
val resultTypeOpt =
167+
refinedType match
168+
case RefinedType(AppliedType(_, args), _, _) => args.lastOption
169+
case AppliedType(_, args) => args.lastOption
170+
case _ => None
171+
resultTypeOpt match
172+
case Some(resultType) => getRefinedParams(resultType, level - 1)
173+
case _ => defaultBaseParams
174+
else
175+
refinedType match
176+
case RefinedType(AppliedType(_, args), _, MethodType(ri)) =>
177+
baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
178+
RefinedSymbol(sym, name, arg)
179+
}
180+
case _ => defaultBaseParams
181+
// finds param refinements for lambda expressions
182+
// val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
183+
@tailrec
184+
def refineParams(method: Tree, level: Int): List[ParamSymbol] =
185+
method match
186+
case Select(Apply(f, _), _) => refineParams(f, level + 1)
187+
case Select(h, name) =>
188+
// for Select(foo, name = apply) we want `foo.symbol`
189+
if name == nme.apply then getRefinedParams(h.symbol.info, level)
190+
else getRefinedParams(method.symbol.info, level)
191+
case Apply(f, _) =>
192+
refineParams(f, level + 1)
193+
case _ => getRefinedParams(method.symbol.info, level)
194+
refineParams(method, 0)
195+
end baseParams
196+
(baseArgs, baseParams)
197+
}
198+
199+
extension (method: Symbol)
200+
def vparamss(using Context) = method.filteredParamss(_.isTerm)
201+
def tparams(using Context) = method.filteredParamss(_.isType).flatten
202+
def filteredParamss(f: Symbol => Boolean)(using Context) =
203+
method.paramSymss.filter(params => params.forall(f))
204+
sealed trait ParamSymbol:
205+
def name: Name
206+
def info: Type
207+
def symbol: Symbol
208+
def nameBackticked(using Context) = name.decoded.backticked
209+
210+
case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol:
211+
def name: Name = symbol.name
212+
def info: Type = symbol.info
213+
214+
case class RefinedSymbol(symbol: Symbol, name: Name, info: Type)
215+
extends ParamSymbol
216+
217+
218+
class FuzzyArgMatcher(tparams: List[Symbol])(using Context):
219+
220+
/**
221+
* A heuristic for checking if the passed arguments match the method's arguments' types.
222+
* For non-polymorphic methods we use the subtype relation (`<:<`)
223+
* and for polymorphic methods we use a heuristic.
224+
* We check the args types not the result type.
225+
*/
226+
def doMatch(
227+
allArgsProvided: Boolean,
228+
span: Span
229+
)(expectedArgs: List[Symbol], actualArgs: List[Tree]) =
230+
(expectedArgs.length == actualArgs.length ||
231+
(!allArgsProvided && expectedArgs.length >= actualArgs.length)) &&
232+
actualArgs.zipWithIndex.forall {
233+
case (arg: Ident, _) if arg.span.contains(span) => true
234+
case (NamedArg(name, arg), _) =>
235+
expectedArgs.exists { expected =>
236+
expected.name == name && (!arg.hasType || arg.typeOpt.unfold
237+
.fuzzyArg_<:<(expected.info))
238+
}
239+
case (arg, i) =>
240+
!arg.hasType || arg.typeOpt.unfold.fuzzyArg_<:<(expectedArgs(i).info)
241+
}
242+
243+
extension (arg: Type)
244+
def fuzzyArg_<:<(expected: Type) =
245+
if tparams.isEmpty then arg <:< expected
246+
else arg <:< substituteTypeParams(expected)
247+
def unfold =
248+
arg match
249+
case arg: TermRef => arg.underlying
250+
case e => e
251+
252+
private def substituteTypeParams(t: Type): Type =
253+
t match
254+
case e if tparams.exists(_ == e.typeSymbol) =>
255+
val matchingParam = tparams.find(_ == e.typeSymbol).get
256+
matchingParam.info match
257+
case b @ TypeBounds(_, _) => WildcardType(b)
258+
case _ => WildcardType
259+
case o @ OrType(e1, e2) =>
260+
OrType(substituteTypeParams(e1), substituteTypeParams(e2), o.isSoft)
261+
case AndType(e1, e2) =>
262+
AndType(substituteTypeParams(e1), substituteTypeParams(e2))
263+
case AppliedType(et, eparams) =>
264+
AppliedType(et, eparams.map(substituteTypeParams))
265+
case _ => t
266+
267+
end FuzzyArgMatcher

0 commit comments

Comments
 (0)