@@ -4,7 +4,7 @@ import com.cjcrafter.openai.chat.*
44import com.cjcrafter.openai.completions.CompletionRequest
55import com.cjcrafter.openai.completions.CompletionResponse
66import com.cjcrafter.openai.completions.CompletionResponseChunk
7- import com.cjcrafter.openai.completions.CompletionUsage
7+ import com.fasterxml.jackson.databind.JavaType
88import com.fasterxml.jackson.databind.node.ObjectNode
99import okhttp3.*
1010import okhttp3.MediaType.Companion.toMediaType
@@ -32,62 +32,40 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
3232 .post(body).build()
3333 }
3434
35- override fun createCompletion (request : CompletionRequest ): CompletionResponse {
36- @Suppress(" DEPRECATION" )
37- request.stream = false // use streamCompletion for stream=true
38- val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
39-
40- val httpResponse = client.newCall(httpRequest).execute()
41- println (httpResponse)
42-
43- return CompletionResponse (" 1" , 1 , " 1" , listOf (), CompletionUsage (1 , 1 , 1 ))
44- }
45-
46- override fun streamCompletion (request : CompletionRequest ): Iterable <CompletionResponseChunk > {
47- @Suppress(" DEPRECATION" )
48- request.stream = true // use createCompletion for stream=false
49- val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
50-
51- return listOf ()
52- }
53-
54- override fun createChatCompletion (request : ChatRequest ): ChatResponse {
55- @Suppress(" DEPRECATION" )
56- request.stream = false // use streamChatCompletion for stream=true
57- val httpRequest = buildRequest(request, CHAT_ENDPOINT )
58-
35+ protected open fun <T > executeRequest (httpRequest : Request , responseType : Class <T >): T {
5936 val httpResponse = client.newCall(httpRequest).execute()
6037 if (! httpResponse.isSuccessful) {
6138 val json = httpResponse.body?.byteStream()?.bufferedReader()?.readText()
6239 httpResponse.close()
63- throw IOException (" Unexpected code $httpResponse , recieved : $json " )
40+ throw IOException (" Unexpected code $httpResponse , received : $json " )
6441 }
6542
66- val json = httpResponse.body?.byteStream()?.bufferedReader() ? : throw IOException (" Response body is null" )
67- val str = json.readText()
68- return objectMapper.readValue(str, ChatResponse ::class .java)
43+ val jsonReader = httpResponse.body?.byteStream()?.bufferedReader()
44+ ? : throw IOException (" Response body is null" )
45+ val responseStr = jsonReader.readText()
46+ return objectMapper.readValue(responseStr, responseType)
6947 }
7048
71- override fun streamChatCompletion (request : ChatRequest ): Iterable <ChatResponseChunk > {
72- request.stream = true // Set streaming to true
73- val httpRequest = buildRequest(request, CHAT_ENDPOINT )
74-
75- return object : Iterable <ChatResponseChunk > {
76- override fun iterator (): Iterator <ChatResponseChunk > {
77- val httpResponse = client.newCall(httpRequest).execute()
49+ private fun <T > streamResponses (
50+ request : Request ,
51+ responseType : JavaType ,
52+ updateResponse : (T , String ) -> T
53+ ): Iterable <T > {
54+ return object : Iterable <T > {
55+ override fun iterator (): Iterator <T > {
56+ val httpResponse = client.newCall(request).execute()
7857
7958 if (! httpResponse.isSuccessful) {
8059 httpResponse.close()
8160 throw IOException (" Unexpected code $httpResponse " )
8261 }
8362
84- val reader = httpResponse.body?.byteStream()?.bufferedReader() ? : throw IOException (" Response body is null" )
63+ val reader = httpResponse.body?.byteStream()?.bufferedReader()
64+ ? : throw IOException (" Response body is null" )
8565
86- // Only instantiate 1 ChatResponseChunk, otherwise simply update
87- // the existing one. This lets us accumulate the message.
88- var chunk: ChatResponseChunk ? = null
66+ var currentResponse: T ? = null
8967
90- return object : Iterator <ChatResponseChunk > {
68+ return object : Iterator <T > {
9169 private var nextLine: String? = readNextLine(reader)
9270
9371 private fun readNextLine (reader : BufferedReader ): String? {
@@ -98,8 +76,6 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
9876 reader.close()
9977 return null
10078 }
101-
102- // Check if the line starts with 'data:' and skip empty lines
10379 } while (line != null && (line.isEmpty() || ! line.startsWith(" data: " )))
10480 return line?.removePrefix(" data: " )
10581 }
@@ -108,24 +84,57 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
10884 return nextLine != null
10985 }
11086
111- override fun next (): ChatResponseChunk {
112- val currentLine = nextLine ? : throw NoSuchElementException (" No more lines" )
113- // println(" $currentLine")
114- chunk = chunk?.apply { update(objectMapper.readTree(currentLine) as ObjectNode ) } ? : objectMapper.readValue(currentLine, ChatResponseChunk ::class .java)
115- nextLine = readNextLine(reader) // Prepare the next line
116- return chunk!!
117- // return ChatResponseChunk("1", 1, listOf())
87+ override fun next (): T {
88+ val line = nextLine ? : throw NoSuchElementException (" No more lines" )
89+ currentResponse = if (currentResponse == null ) {
90+ objectMapper.readValue(line, responseType)
91+ } else {
92+ updateResponse(currentResponse!! , line)
93+ }
94+ nextLine = readNextLine(reader)
95+ return currentResponse!!
11896 }
11997 }
12098 }
12199 }
122100 }
123101
102+ override fun createCompletion (request : CompletionRequest ): CompletionResponse {
103+ @Suppress(" DEPRECATION" )
104+ request.stream = false // use streamCompletion for stream=true
105+ val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
106+ return executeRequest(httpRequest, CompletionResponse ::class .java)
107+ }
108+
109+ override fun streamCompletion (request : CompletionRequest ): Iterable <CompletionResponseChunk > {
110+ @Suppress(" DEPRECATION" )
111+ request.stream = true
112+ val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
113+ return streamResponses(httpRequest, objectMapper.typeFactory.constructType(CompletionResponseChunk ::class .java)) { response, newLine ->
114+ // We don't have any update logic, so we should ignore the old response and just return a new one
115+ objectMapper.readValue(newLine, CompletionResponseChunk ::class .java)
116+ }
117+ }
118+
119+ override fun createChatCompletion (request : ChatRequest ): ChatResponse {
120+ @Suppress(" DEPRECATION" )
121+ request.stream = false // use streamChatCompletion for stream=true
122+ val httpRequest = buildRequest(request, CHAT_ENDPOINT )
123+ return executeRequest(httpRequest, ChatResponse ::class .java)
124+ }
125+
126+ override fun streamChatCompletion (request : ChatRequest ): Iterable <ChatResponseChunk > {
127+ @Suppress(" DEPRECATION" )
128+ request.stream = true
129+ val httpRequest = buildRequest(request, CHAT_ENDPOINT )
130+ return streamResponses(httpRequest, objectMapper.typeFactory.constructType(ChatResponseChunk ::class .java)) { response, newLine ->
131+ response.update(objectMapper.readTree(newLine) as ObjectNode )
132+ response
133+ }
134+ }
135+
124136 companion object {
125137 const val COMPLETIONS_ENDPOINT = " v1/completions"
126138 const val CHAT_ENDPOINT = " v1/chat/completions"
127- const val IMAGE_CREATE_ENDPOINT = " v1/images/generations"
128- const val IMAGE_EDIT_ENDPOINT = " v1/images/edits"
129- const val IMAGE_VARIATION_ENDPOINT = " v1/images/variations"
130139 }
131140}
0 commit comments