Skip to content

Commit 6670c18

Browse files
committed
ref
1 parent 880fc0f commit 6670c18

File tree

5 files changed

+103
-5
lines changed

5 files changed

+103
-5
lines changed

src/ai-bundle/config/options.php

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,15 @@
775775
->end()
776776
->end()
777777
->end()
778+
->arrayNode('chat')
779+
->useAttributeAsKey('name')
780+
->arrayPrototype()
781+
->children()
782+
->stringNode('agent')->cannotBeEmpty()->end()
783+
->stringNode('message_store')->cannotBeEmpty()->end()
784+
->end()
785+
->end()
786+
->end()
778787
->arrayNode('vectorizer')
779788
->info('Vectorizers for converting strings to Vector objects and transforming TextDocument arrays to VectorDocument arrays')
780789
->useAttributeAsKey('name')

src/ai-bundle/src/AiBundle.php

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@
3131
use Symfony\AI\Agent\Toolbox\ToolFactory\MemoryToolFactory;
3232
use Symfony\AI\AiBundle\DependencyInjection\ProcessorCompilerPass;
3333
use Symfony\AI\AiBundle\Exception\InvalidArgumentException;
34+
use Symfony\AI\AiBundle\Profiler\TraceableChat;
3435
use Symfony\AI\AiBundle\Profiler\TraceableMessageStore;
3536
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
3637
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
3738
use Symfony\AI\AiBundle\Security\Attribute\IsGrantedTool;
3839
use Symfony\AI\Chat\Bridge\HttpFoundation\SessionStore;
3940
use Symfony\AI\Chat\Bridge\Meilisearch\MessageStore as MeilisearchMessageStore;
4041
use Symfony\AI\Chat\Bridge\Pogocache\MessageStore as PogocacheMessageStore;
42+
use Symfony\AI\Chat\Chat;
43+
use Symfony\AI\Chat\ChatInterface;
4144
use Symfony\AI\Chat\MessageStoreInterface;
4245
use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory;
4346
use Symfony\AI\Platform\Bridge\Azure\OpenAi\PlatformFactory as AzureOpenAiPlatformFactory;
@@ -182,12 +185,12 @@ public function loadExtension(array $config, ContainerConfigurator $container, C
182185

183186
if ($builder->getParameter('kernel.debug')) {
184187
foreach ($messageStores as $messageStore) {
185-
$traceablePlatformDefinition = (new Definition(TraceableMessageStore::class))
188+
$traceableMessageStoreDefinition = (new Definition(TraceableMessageStore::class))
186189
->setDecoratedService($messageStore)
187190
->setArguments([new Reference('.inner')])
188191
->addTag('ai.traceable_message_store');
189192
$suffix = u($messageStore)->afterLast('.')->toString();
190-
$builder->setDefinition('ai.traceable_message_store.'.$suffix, $traceablePlatformDefinition);
193+
$builder->setDefinition('ai.traceable_message_store.'.$suffix, $traceableMessageStoreDefinition);
191194
}
192195
}
193196

@@ -196,6 +199,27 @@ public function loadExtension(array $config, ContainerConfigurator $container, C
196199
$builder->removeDefinition('ai.command.drop_message_store');
197200
}
198201

202+
foreach ($config['chat'] ?? [] as $name => $chat) {
203+
$this->processChatConfig($name, $chat, $builder);
204+
}
205+
206+
$chats = array_keys($builder->findTaggedServiceIds('ai.chat'));
207+
208+
if (1 === \count($chats)) {
209+
$builder->setAlias(ChatInterface::class, reset($chats));
210+
}
211+
212+
if ($builder->getParameter('kernel.debug')) {
213+
foreach ($chats as $chat) {
214+
$traceableChatDefinition = (new Definition(TraceableChat::class))
215+
->setDecoratedService($chat)
216+
->setArguments([new Reference('.inner')])
217+
->addTag('ai.traceable_chat');
218+
$suffix = u($chat)->afterLast('.')->toString();
219+
$builder->setDefinition('ai.traceable_chat.'.$suffix, $traceableChatDefinition);
220+
}
221+
}
222+
199223
foreach ($config['vectorizer'] ?? [] as $vectorizerName => $vectorizer) {
200224
$this->processVectorizerConfig($vectorizerName, $vectorizer, $builder);
201225
}
@@ -1399,6 +1423,26 @@ private function processMessageStoreConfig(string $type, array $messageStores, C
13991423
}
14001424
}
14011425

1426+
/**
1427+
* @param array{
1428+
* agent: string,
1429+
* message_store: string,
1430+
* } $configuration
1431+
*/
1432+
private function processChatConfig(string $name, array $configuration, ContainerBuilder $container): void
1433+
{
1434+
$definition = new Definition(Chat::class);
1435+
$definition
1436+
->setArguments([
1437+
new Reference('ai.agent.'.$configuration['agent']),
1438+
new Reference('ai.message_store.'.$configuration['message_store']),
1439+
])
1440+
->addTag('ai.chat');
1441+
1442+
$container->setDefinition('ai.chat.'.$name, $definition);
1443+
$container->registerAliasForArgument('ai.chat.'.$name, ChatInterface::class, $name);
1444+
}
1445+
14021446
/**
14031447
* @param array<string, mixed> $config
14041448
*/

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
use Symfony\AI\Agent\MultiAgent\Handoff;
2222
use Symfony\AI\Agent\MultiAgent\MultiAgent;
2323
use Symfony\AI\AiBundle\AiBundle;
24+
use Symfony\AI\Chat\ChatInterface;
2425
use Symfony\AI\Chat\MessageStoreInterface;
2526
use Symfony\AI\Store\Document\Filter\TextContainsFilter;
2627
use Symfony\AI\Store\Document\Loader\InMemoryLoader;
@@ -192,6 +193,41 @@ public function testInjectionMessageStoreAliasIsRegistered()
192193
$this->assertTrue($container->hasAlias('.'.MessageStoreInterface::class.' $session_session'));
193194
}
194195

196+
public function testInjectionChatAliasIsRegistered()
197+
{
198+
$container = $this->buildContainer([
199+
'ai' => [
200+
'agent' => [
201+
'my_agent' => [
202+
'model' => 'gpt-4',
203+
],
204+
],
205+
'message_store' => [
206+
'memory' => [
207+
'main' => [
208+
'identifier' => '_memory',
209+
],
210+
],
211+
],
212+
'chat' => [
213+
'main' => [
214+
'agent' => 'my_agent',
215+
'message_store' => 'memory.main',
216+
],
217+
],
218+
],
219+
]);
220+
221+
$this->assertCount(1, $container->findTaggedServiceIds('ai.chat'));
222+
223+
$this->assertTrue($container->hasAlias(ChatInterface::class.' $main'));
224+
225+
$chatDefinition = $container->getDefinition('ai.chat.main');
226+
$this->assertCount(2, $chatDefinition->getArguments());
227+
$this->assertInstanceOf(Reference::class, $chatDefinition->getArgument(0));
228+
$this->assertInstanceOf(Reference::class, $chatDefinition->getArgument(1));
229+
}
230+
195231
public function testAgentHasTag()
196232
{
197233
$container = $this->buildContainer([
@@ -3055,6 +3091,12 @@ private function getFullConfig(): array
30553091
],
30563092
],
30573093
],
3094+
'chat' => [
3095+
'main' => [
3096+
'agent' => 'my_chat_agent',
3097+
'message_store' => 'cache',
3098+
],
3099+
],
30583100
'vectorizer' => [
30593101
'test_vectorizer' => [
30603102
'platform' => 'mistral_platform_service_id',

src/ai-bundle/tests/Profiler/DataCollectorTest.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public function testCollectsDataForNonStreamingResponse()
3939
$this->assertSame('Assistant response', $result->asText());
4040

4141
$dataCollector = new DataCollector([$traceablePlatform], []);
42+
$dataCollector = new DataCollector([$traceablePlatform], $this->createStub(ToolboxInterface::class), [], [], []);
4243
$dataCollector->lateCollect();
4344

4445
$this->assertCount(1, $dataCollector->getPlatformCalls());
@@ -63,6 +64,7 @@ public function testCollectsDataForStreamingResponse()
6364
$this->assertSame('Assistant response', implode('', iterator_to_array($result->asStream())));
6465

6566
$dataCollector = new DataCollector([$traceablePlatform], []);
67+
$dataCollector = new DataCollector([$traceablePlatform], $this->createStub(ToolboxInterface::class), [], [], []);
6668
$dataCollector->lateCollect();
6769

6870
$this->assertCount(1, $dataCollector->getPlatformCalls());

src/platform/src/ModelCatalog/FallbackModelCatalog.php

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@
2525
*/
2626
class FallbackModelCatalog extends AbstractModelCatalog
2727
{
28-
public function __construct()
29-
{
28+
public function __construct(
29+
private readonly ?string $expectedModel = Model::class,
30+
) {
3031
$this->models = [];
3132
}
3233

3334
public function getModel(string $modelName): Model
3435
{
3536
$parsed = self::parseModelName($modelName);
3637

37-
return new Model($parsed['name'], Capability::cases(), $parsed['options']);
38+
return new $this->expectedModel($parsed['name'], Capability::cases(), $parsed['options']);
3839
}
3940
}

0 commit comments

Comments
 (0)