@@ -183,21 +183,8 @@ class ChatSession
183183
184184 return nil , err_msg, response
185185
186- -- if we are streaming we need to parse the entire fragmented response
187186 if stream_callback
188- assert type ( response) == " string" ,
189- " Expected string response from streaming output"
190-
191- parts = {}
192- f = @client \ create_stream_filter ( c) ->
193- if c = parse_completion_chunk c
194- table.insert parts, c. content
195-
196- f response
197- message = {
198- role : " assistant"
199- content : table.concat parts
200- }
187+ message = response. choices[ 1 ] . message
201188
202189 if append_response
203190 @append_message message
@@ -274,10 +261,7 @@ class OpenAI
274261 for k, v in pairs opts
275262 payload[ k] = v
276263
277- stream_filter = if payload. stream
278- @create_stream_filter chunk_callback
279-
280- @_request " POST" , " /chat/completions" , payload, nil , stream_filter
264+ @_request " POST" , " /chat/completions" , payload, nil , if payload. stream then chunk_callback else nil
281265
282266 -- call /completions
283267 -- opts: additional parameters as described in https://platform.openai.com/docs/api-reference/completions
@@ -363,7 +347,7 @@ class OpenAI
363347 image_generation : ( params) => -- issues a POST to the /images/generations path, passing params as the request payload
364348 @_request " POST" , " /images/generations" , params
365349
366- _request : ( method, path, payload, more_headers, stream_fn ) =>
350+ _request : ( method, path, payload, more_headers, chunk_callback ) =>
367351 assert path, " missing path"
368352 assert method, " missing method"
369353
@@ -393,7 +377,13 @@ class OpenAI
393377
394378 sink = ltn12. sink. table out
395379
396- if stream_fn
380+ parts = {}
381+ if chunk_callback
382+ stream_fn = @create_stream_filter ( c) ->
383+ if parsed = parse_completion_chunk c
384+ parts[ parsed. index] = parts[ parsed. index] or {}
385+ table.insert parts[ parsed. index] , parsed. content
386+ chunk_callback( c)
397387 sink = ltn12. sink. chain stream_fn, sink
398388
399389 _, status, out_headers = @get_http !. request {
@@ -404,6 +394,22 @@ class OpenAI
404394 : headers
405395 }
406396
397+ if status == 200 and chunk_callback
398+ choices = {}
399+ data = {
400+ object : " chat.completion"
401+ : choices
402+ }
403+ index = 0
404+ while parts[ index]
405+ message = {
406+ role : " assistant"
407+ content : table.concat parts[ index]
408+ }
409+ choices[ index+ 1 ] = { : index, : message }
410+ index += 1
411+ return status, data, out_headers
412+
407413 response = table.concat out
408414 pcall -> response = cjson. decode response
409415 status, response, out_headers
0 commit comments