@@ -52,6 +52,10 @@ object desugar {
5252 */
5353 val ContextBoundParam : Property .Key [Unit ] = Property .StickyKey ()
5454
55+ /** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
56+ */
57+ val PolyFunctionApply : Property .Key [Unit ] = Property .StickyKey ()
58+
5559 /** What static check should be applied to a Match? */
5660 enum MatchCheck {
5761 case None , Exhaustive , IrrefutablePatDef , IrrefutableGenFrom
@@ -242,7 +246,7 @@ object desugar {
242246 * def f$default$2[T](x: Int) = x + "m"
243247 */
244248 private def defDef (meth : DefDef , isPrimaryConstructor : Boolean = false )(using Context ): Tree =
245- addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
249+ addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor). asInstanceOf [ DefDef ] )
246250
247251 /** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
248252 * get added to a buffer.
@@ -304,10 +308,8 @@ object desugar {
304308 tdef1
305309 end desugarContextBounds
306310
307- private def elimContextBounds (meth : DefDef , isPrimaryConstructor : Boolean )(using Context ): DefDef =
308- val DefDef (_, paramss, tpt, rhs) = meth
311+ def elimContextBounds (meth : Tree , isPrimaryConstructor : Boolean = false )(using Context ): Tree =
309312 val evidenceParamBuf = mutable.ListBuffer [ValDef ]()
310-
311313 var seenContextBounds : Int = 0
312314 def freshName (unused : Tree ) =
313315 seenContextBounds += 1 // Start at 1 like FreshNameCreator.
@@ -317,7 +319,7 @@ object desugar {
317319 // parameters of the method since shadowing does not affect
318320 // implicit resolution in Scala 3.
319321
320- val paramssNoContextBounds =
322+ def paramssNoContextBounds ( paramss : List [ ParamClause ]) : List [ ParamClause ] =
321323 val iflag = paramss.lastOption.flatMap(_.headOption) match
322324 case Some (param) if param.mods.isOneOf(GivenOrImplicit ) =>
323325 param.mods.flags & GivenOrImplicit
@@ -329,15 +331,32 @@ object desugar {
329331 tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
330332 }(identity)
331333
332- rhs match
333- case MacroTree (call) =>
334- cpy.DefDef (meth)(rhs = call).withMods(meth.mods | Macro | Erased )
335- case _ =>
336- addEvidenceParams(
337- cpy.DefDef (meth)(
338- name = normalizeName(meth, tpt).asTermName,
339- paramss = paramssNoContextBounds),
340- evidenceParamBuf.toList)
334+ meth match
335+ case meth @ DefDef (_, paramss, tpt, rhs) =>
336+ val newParamss = paramssNoContextBounds(paramss)
337+ rhs match
338+ case MacroTree (call) =>
339+ cpy.DefDef (meth)(rhs = call).withMods(meth.mods | Macro | Erased )
340+ case _ =>
341+ addEvidenceParams(
342+ cpy.DefDef (meth)(
343+ name = normalizeName(meth, tpt).asTermName,
344+ paramss = newParamss
345+ ),
346+ evidenceParamBuf.toList
347+ )
348+ case meth @ PolyFunction (tparams, fun) =>
349+ val PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun) = meth : @ unchecked
350+ val Function (vparams : List [untpd.ValDef ] @ unchecked, rhs) = fun : @ unchecked
351+ val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil )
352+ val params = evidenceParamBuf.toList
353+ if params.isEmpty then
354+ meth
355+ else
356+ val boundNames = getBoundNames(params, newParamss)
357+ val recur = fitEvidenceParams(params, nme.apply, boundNames)
358+ val (paramsFst, paramsSnd) = recur(newParamss)
359+ functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
341360 end elimContextBounds
342361
343362 def addDefaultGetters (meth : DefDef )(using Context ): Tree =
@@ -465,6 +484,74 @@ object desugar {
465484 case _ =>
466485 (Nil , tree)
467486
487+ private def referencesName (vdef : ValDef , names : Set [TermName ])(using Context ): Boolean =
488+ vdef.tpt.existsSubTree:
489+ case Ident (name : TermName ) => names.contains(name)
490+ case _ => false
491+
492+ /** Fit evidence `params` into the `mparamss` parameter lists, making sure
493+ * that all parameters referencing `params` are after them.
494+ * - for methods the final parameter lists are := result._1 ++ result._2
495+ * - for poly functions, each element of the pair contains at most one term
496+ * parameter list
497+ *
498+ * @param params the evidence parameters list that should fit into `mparamss`
499+ * @param methName the name of the method that `mparamss` belongs to
500+ * @param boundNames the names of the evidence parameters
501+ * @param mparamss the original parameter lists of the method
502+ * @return a pair of parameter lists containing all parameter lists in a
503+ * reference-correct order; make sure that `params` is always at the
504+ * intersection of the pair elements; this is relevant, for poly functions
505+ * where `mparamss` is guaranteed to have exectly one term parameter list,
506+ * then each pair element will have at most one term parameter list
507+ */
508+ private def fitEvidenceParams (
509+ params : List [ValDef ],
510+ methName : Name ,
511+ boundNames : Set [TermName ]
512+ )(mparamss : List [ParamClause ])(using Context ): (List [ParamClause ], List [ParamClause ]) = mparamss match
513+ case ValDefs (mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
514+ (params :: Nil ) -> mparamss
515+ case ValDefs (mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit ) =>
516+ val normParams =
517+ if params.head.mods.flags.is(Given ) != mparam.mods.flags.is(Given ) then
518+ params.map: param =>
519+ val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit ))
520+ param.withMods(param.mods.withFlags(normFlags))
521+ .showing(i " adapted param $result ${result.mods.flags} for ${methName}" , Printers .desugar)
522+ else params
523+ ((normParams ++ mparams) :: Nil ) -> Nil
524+ case mparams :: mparamss1 =>
525+ val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
526+ (mparams :: fst) -> snd
527+ case Nil =>
528+ Nil -> (params :: Nil )
529+
530+ /** Create a chain of possibly contextual functions from the parameter lists */
531+ private def functionsOf (paramss : List [ParamClause ], rhs : Tree )(using Context ): Tree = paramss match
532+ case Nil => rhs
533+ case ValDefs (head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit ) =>
534+ val paramTpts = head.map(_.tpt)
535+ val paramNames = head.map(_.name)
536+ val paramsErased = head.map(_.mods.flags.is(Erased ))
537+ makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
538+ case ValDefs (head) :: rest =>
539+ Function (head, functionsOf(rest, rhs))
540+ case TypeDefs (head) :: rest =>
541+ PolyFunction (head, functionsOf(rest, rhs))
542+ case _ =>
543+ assert(false , i " unexpected paramss $paramss" )
544+ EmptyTree
545+
546+ private def getBoundNames (params : List [ValDef ], paramss : List [ParamClause ])(using Context ): Set [TermName ] =
547+ var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
548+ for mparams <- paramss; mparam <- mparams do
549+ mparam match
550+ case tparam : TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot .unapply(_).isDefined) =>
551+ boundNames += tparam.name.toTermName
552+ case _ =>
553+ boundNames
554+
468555 /** Add all evidence parameters in `params` as implicit parameters to `meth`.
469556 * The position of the added parameters is determined as follows:
470557 *
@@ -479,36 +566,23 @@ object desugar {
479566 private def addEvidenceParams (meth : DefDef , params : List [ValDef ])(using Context ): DefDef =
480567 if params.isEmpty then return meth
481568
482- var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
483- for mparams <- meth.paramss; mparam <- mparams do
484- mparam match
485- case tparam : TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot .unapply(_).isDefined) =>
486- boundNames += tparam.name.toTermName
487- case _ =>
569+ val boundNames = getBoundNames(params, meth.paramss)
488570
489- def referencesBoundName (vdef : ValDef ): Boolean =
490- vdef.tpt.existsSubTree:
491- case Ident (name : TermName ) => boundNames.contains(name)
492- case _ => false
571+ val fitParams = fitEvidenceParams(params, meth.name, boundNames)
493572
494- def recur (mparamss : List [ParamClause ]): List [ParamClause ] = mparamss match
495- case ValDefs (mparams) :: _ if mparams.exists(referencesBoundName) =>
496- params :: mparamss
497- case ValDefs (mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit ) =>
498- val normParams =
499- if params.head.mods.flags.is(Given ) != mparam.mods.flags.is(Given ) then
500- params.map: param =>
501- val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit ))
502- param.withMods(param.mods.withFlags(normFlags))
503- .showing(i " adapted param $result ${result.mods.flags} for ${meth.name}" , Printers .desugar)
504- else params
505- (normParams ++ mparams) :: Nil
506- case mparams :: mparamss1 =>
507- mparams :: recur(mparamss1)
508- case Nil =>
509- params :: Nil
510-
511- cpy.DefDef (meth)(paramss = recur(meth.paramss))
573+ if meth.removeAttachment(PolyFunctionApply ).isDefined then
574+ // for PolyFunctions we are limited to a single term param list, so we
575+ // reuse the fitEvidenceParams logic to compute the new parameter lists
576+ // and then we add the other parameter lists as function types to the
577+ // return type
578+ val (paramsFst, paramsSnd) = fitParams(meth.paramss)
579+ if ctx.mode.is(Mode .Type ) then
580+ cpy.DefDef (meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
581+ else
582+ cpy.DefDef (meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
583+ else
584+ val (paramsFst, paramsSnd) = fitParams(meth.paramss)
585+ cpy.DefDef (meth)(paramss = paramsFst ++ paramsSnd)
512586 end addEvidenceParams
513587
514588 /** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
@@ -1224,27 +1298,29 @@ object desugar {
12241298 /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12251299 * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12261300 */
1227- def makePolyFunctionType (tree : PolyFunction )(using Context ): RefinedTypeTree =
1228- val PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun @ untpd.Function (vparamTypes, res)) = tree : @ unchecked
1229- val paramFlags = fun match
1230- case fun : FunctionWithMods =>
1231- // TODO: make use of this in the desugaring when pureFuns is enabled.
1232- // val isImpure = funFlags.is(Impure)
1233-
1234- // Function flags to be propagated to each parameter in the desugared method type.
1235- val givenFlag = fun.mods.flags.toTermFlags & Given
1236- fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1237- case _ =>
1238- vparamTypes.map(_ => EmptyFlags )
1239-
1240- val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1241- case ((p : ValDef , paramFlags), n) => p.withAddedFlags(paramFlags)
1242- case ((p, paramFlags), n) => makeSyntheticParameter(n + 1 , p).withAddedFlags(paramFlags)
1243- }.toList
1244-
1245- RefinedTypeTree (ref(defn.PolyFunctionType ), List (
1246- DefDef (nme.apply, tparams :: vparams :: Nil , res, EmptyTree ).withFlags(Synthetic )
1247- )).withSpan(tree.span)
1301+ def makePolyFunctionType (tree : PolyFunction )(using Context ): RefinedTypeTree = (tree : @ unchecked) match
1302+ case PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun @ untpd.Function (vparamTypes, res)) =>
1303+ val paramFlags = fun match
1304+ case fun : FunctionWithMods =>
1305+ // TODO: make use of this in the desugaring when pureFuns is enabled.
1306+ // val isImpure = funFlags.is(Impure)
1307+
1308+ // Function flags to be propagated to each parameter in the desugared method type.
1309+ val givenFlag = fun.mods.flags.toTermFlags & Given
1310+ fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1311+ case _ =>
1312+ vparamTypes.map(_ => EmptyFlags )
1313+
1314+ val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1315+ case ((p : ValDef , paramFlags), n) => p.withAddedFlags(paramFlags)
1316+ case ((p, paramFlags), n) => makeSyntheticParameter(n + 1 , p).withAddedFlags(paramFlags)
1317+ }.toList
1318+
1319+ RefinedTypeTree (ref(defn.PolyFunctionType ), List (
1320+ DefDef (nme.apply, tparams :: vparams :: Nil , res, EmptyTree )
1321+ .withFlags(Synthetic )
1322+ .withAttachment(PolyFunctionApply , ())
1323+ )).withSpan(tree.span)
12481324 end makePolyFunctionType
12491325
12501326 /** Invent a name for an anonympus given of type or template `impl`. */
0 commit comments