From 22d77f160bc0d0f15f967f163b85fd08e2df1602 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sun, 26 Oct 2025 10:40:07 +0800 Subject: [PATCH 1/2] . --- .../src/dotty/tools/repl/ReplDriver.scala | 14 +++- compiler/src/dotty/tools/repl/ReplMain.scala | 60 +++++++++++++++ .../test/dotty/tools/repl/ReplMainTest.scala | 73 +++++++++++++++++++ sbt-bridge/src/xsbt/ConsoleInterface.java | 2 +- 4 files changed, 145 insertions(+), 4 deletions(-) create mode 100644 compiler/src/dotty/tools/repl/ReplMain.scala create mode 100644 compiler/test/dotty/tools/repl/ReplMainTest.scala diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index ffa1b648446d..441ece19262e 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -168,9 +168,12 @@ class ReplDriver(settings: Array[String], * observable outside of the CLI, for this reason, most helper methods are * `protected final` to facilitate testing. */ - def runUntilQuit(using initialState: State = initialState)(): State = { + def runUntilQuit(using initialState: State = initialState)(hardcodedInput: java.io.InputStream = null): State = { val terminal = new JLineTerminal + val hardcodedInputLines = + if (hardcodedInput == null) null + else new java.io.BufferedReader(new java.io.InputStreamReader(hardcodedInput)) out.println( s"""Welcome to Scala $simpleVersionString ($javaVersion, Java $javaVmName). |Type in expressions for evaluation. Or try :help.""".stripMargin) @@ -208,8 +211,13 @@ class ReplDriver(settings: Array[String], } try { - val line = terminal.readLine(completer) - ParseResult(line) + println("hardcodedInputLines " + hardcodedInputLines) + val line = + if (hardcodedInputLines != null) hardcodedInputLines.readLine() + else terminal.readLine(completer) + println("line " + line) + if (line == null) Quit + else ParseResult(line) } catch { case _: EndOfFileException => // Ctrl+D Quit diff --git a/compiler/src/dotty/tools/repl/ReplMain.scala b/compiler/src/dotty/tools/repl/ReplMain.scala new file mode 100644 index 000000000000..3cf514502392 --- /dev/null +++ b/compiler/src/dotty/tools/repl/ReplMain.scala @@ -0,0 +1,60 @@ +package dotty.tools.repl + +import java.io.PrintStream + +class ReplMain( + settings: Array[String] = Array.empty, + out: PrintStream = Console.out, + classLoader: Option[ClassLoader] = Some(getClass.getClassLoader), + predefCode: String = "", + testCode: String = "" +): + def run(bindings: ReplMain.Bind[_]*): Any = + try + ReplMain.currentBindings.set(bindings.map{bind => bind.name -> bind.value}.toMap) + + val bindingsPredef = bindings + .map { case bind => + s"def ${bind.name}: ${bind.typeName.value} = dotty.tools.repl.ReplMain.currentBinding[${bind.typeName.value}](\"${bind.name}\")" + } + .mkString("\n") + + val fullPredef = + ReplDriver.pprintImport + + (if bindingsPredef.nonEmpty then s"\n$bindingsPredef\n" else "") + + (if predefCode.nonEmpty then s"\n$predefCode\n" else "") + + val driver = new ReplDriver(settings, out, classLoader, fullPredef) + + if (testCode == "") driver.tryRunning + else { + driver.runUntilQuit(using driver.initialState)(new java.io.ByteArrayInputStream(testCode.getBytes())) + } + () + finally + ReplMain.currentBindings.set(null) + + +object ReplMain: + final case class TypeName[A](value: String) + object TypeName extends TypeNamePlatform + + import scala.quoted._ + + trait TypeNamePlatform: + inline given [A]: TypeName[A] = ${TypeNamePlatform.impl[A]} + + object TypeNamePlatform: + def impl[A](using t: Type[A], ctx: Quotes): Expr[TypeName[A]] = + '{TypeName[A](${Expr(Type.show[A])})} + + + case class Bind[T](name: String, value: T)(implicit val typeName: TypeName[T]) + object Bind: + implicit def ammoniteReplArrowBinder[T](t: (String, T))(implicit typeName: TypeName[T]): Bind[T] = { + Bind(t._1, t._2)(typeName) + } + + def currentBinding[T](s: String): T = currentBindings.get().apply(s).asInstanceOf[T] + + private val currentBindings = new ThreadLocal[Map[String, Any]]() diff --git a/compiler/test/dotty/tools/repl/ReplMainTest.scala b/compiler/test/dotty/tools/repl/ReplMainTest.scala new file mode 100644 index 000000000000..495d688f035d --- /dev/null +++ b/compiler/test/dotty/tools/repl/ReplMainTest.scala @@ -0,0 +1,73 @@ +package dotty.tools +package repl + +import scala.language.unsafeNulls + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.charset.StandardCharsets + +import vulpix.TestConfiguration +import org.junit.Test +import org.junit.Assert._ + +/** Tests for the programmatic REPL API (ReplMain) */ +class ReplMainTest: + + private val defaultOptions = Array("-classpath", TestConfiguration.withCompilerClasspath) + + private def captureOutput(body: PrintStream => Unit): String = + val out = new ByteArrayOutputStream() + val ps = new PrintStream(out, true, StandardCharsets.UTF_8.name) + body(ps) + dotty.shaded.fansi.Str(out.toString(StandardCharsets.UTF_8.name)).plainText + + @Test def basicBinding(): Unit = + val output = captureOutput { out => + val replMain = new ReplMain( + settings = defaultOptions, + out = out, + testCode = "test" + ) + + replMain.run("test" -> 42) + } + + assertTrue(output.contains("val res0: Int = 42")) + + @Test def multipleBindings(): Unit = + val output = captureOutput { out => + val replMain = new ReplMain( + settings = defaultOptions, + out = out, + testCode = "x\ny\nz" + ) + + replMain.run( + "x" -> 1, + "y" -> "hello", + "z" -> true + ) + } + + assertTrue(output.contains("val res0: Int = 1")) + assertTrue(output.contains("val res1: String = \"hello\"")) + assertTrue(output.contains("val res2: Boolean = true")) + + @Test def bindingTypes(): Unit = + val output = captureOutput { out => + val replMain = new ReplMain( + settings = defaultOptions ++ Array("-repl-quit-after-init"), + out = out, + testCode = "list\nmap" + ) + + replMain.run( + "list" -> List(1, 2, 3), + "map" -> Map(1 -> "hello") + ) + } + + assertTrue(output.contains("val res0: List[Int] = List(1, 2, 3)")) + assertTrue(output.contains("val res1: Map[Int, String] = Map(1 -> \"hello\")")) + +end ReplMainTest diff --git a/sbt-bridge/src/xsbt/ConsoleInterface.java b/sbt-bridge/src/xsbt/ConsoleInterface.java index 3ba4e011c8e3..2f9ac33098d5 100644 --- a/sbt-bridge/src/xsbt/ConsoleInterface.java +++ b/sbt-bridge/src/xsbt/ConsoleInterface.java @@ -49,7 +49,7 @@ public void run( state = driver.run(initialCommands, state); // TODO handle failure during initialisation - state = driver.runUntilQuit(state); + state = driver.runUntilQuit(state, null); driver.run(cleanupCommands, state); } } From 8ec198ffa324691b81a598749acd3288dccf4ff9 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sun, 26 Oct 2025 10:44:31 +0800 Subject: [PATCH 2/2] wip --- compiler/src/dotty/tools/repl/ReplDriver.scala | 3 +-- compiler/src/dotty/tools/repl/ReplMain.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index 441ece19262e..dd67b43ee0a9 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -211,11 +211,10 @@ class ReplDriver(settings: Array[String], } try { - println("hardcodedInputLines " + hardcodedInputLines) val line = if (hardcodedInputLines != null) hardcodedInputLines.readLine() else terminal.readLine(completer) - println("line " + line) + if (line == null) Quit else ParseResult(line) } catch { diff --git a/compiler/src/dotty/tools/repl/ReplMain.scala b/compiler/src/dotty/tools/repl/ReplMain.scala index 3cf514502392..9d93ee9a6be4 100644 --- a/compiler/src/dotty/tools/repl/ReplMain.scala +++ b/compiler/src/dotty/tools/repl/ReplMain.scala @@ -27,9 +27,9 @@ class ReplMain( val driver = new ReplDriver(settings, out, classLoader, fullPredef) if (testCode == "") driver.tryRunning - else { - driver.runUntilQuit(using driver.initialState)(new java.io.ByteArrayInputStream(testCode.getBytes())) - } + else driver.runUntilQuit(using driver.initialState)( + new java.io.ByteArrayInputStream(testCode.getBytes()) + ) () finally ReplMain.currentBindings.set(null) @@ -49,12 +49,12 @@ object ReplMain: '{TypeName[A](${Expr(Type.show[A])})} - case class Bind[T](name: String, value: T)(implicit val typeName: TypeName[T]) + case class Bind[T](name: String, value: () => T)(implicit val typeName: TypeName[T]) object Bind: implicit def ammoniteReplArrowBinder[T](t: (String, T))(implicit typeName: TypeName[T]): Bind[T] = { - Bind(t._1, t._2)(typeName) + Bind(t._1, () => t._2)(typeName) } - def currentBinding[T](s: String): T = currentBindings.get().apply(s).asInstanceOf[T] + def currentBinding[T](s: String): T = currentBindings.get().apply(s).apply().asInstanceOf[T] - private val currentBindings = new ThreadLocal[Map[String, Any]]() + private val currentBindings = new ThreadLocal[Map[String, () => Any]]()