Skip to content

Commit 551315b

Browse files
committed
Adding support for hybrid search.
1 parent 1c73570 commit 551315b

File tree

7 files changed

+7979
-5219
lines changed

7 files changed

+7979
-5219
lines changed

redis/commands/search/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .commands import (
55
AGGREGATE_CMD,
66
CONFIG_CMD,
7+
HYBRID_CMD,
78
INFO_CMD,
89
PROFILE_CMD,
910
SEARCH_CMD,
@@ -102,6 +103,7 @@ def __init__(self, client, index_name="idx"):
102103
self._RESP2_MODULE_CALLBACKS = {
103104
INFO_CMD: self._parse_info,
104105
SEARCH_CMD: self._parse_search,
106+
HYBRID_CMD: self._parse_hybrid_search,
105107
AGGREGATE_CMD: self._parse_aggregate,
106108
PROFILE_CMD: self._parse_profile,
107109
SPELLCHECK_CMD: self._parse_spellcheck,

redis/commands/search/commands.py

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
import itertools
22
import time
3-
from typing import Dict, List, Optional, Union
3+
from typing import Any, Dict, List, Optional, Union
44

5+
from redis._parsers.helpers import pairs_to_dict
56
from redis.client import NEVER_DECODE, Pipeline
7+
from redis.commands.search.hybrid_query import (
8+
HybridCursorQuery,
9+
HybridPostProcessingConfig,
10+
HybridQuery,
11+
)
12+
from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult
613
from redis.utils import deprecated_function
714

815
from ..helpers import get_protocol_version
916
from ._util import to_string
10-
from .aggregation import AggregateRequest, AggregateResult, Cursor
17+
from .aggregation import (
18+
AggregateRequest,
19+
AggregateResult,
20+
Cursor,
21+
)
1122
from .document import Document
1223
from .field import Field
1324
from .index_definition import IndexDefinition
@@ -47,6 +58,7 @@
4758
SUGGET_COMMAND = "FT.SUGGET"
4859
SYNUPDATE_CMD = "FT.SYNUPDATE"
4960
SYNDUMP_CMD = "FT.SYNDUMP"
61+
HYBRID_CMD = "FT.HYBRID"
5062

5163
NOOFFSETS = "NOOFFSETS"
5264
NOFIELDS = "NOFIELDS"
@@ -84,6 +96,28 @@ def _parse_search(self, res, **kwargs):
8496
field_encodings=kwargs["query"]._return_fields_decode_as,
8597
)
8698

99+
def _parse_hybrid_search(self, res, **kwargs):
100+
res_dict = pairs_to_dict(res, decode_keys=True)
101+
if "cursor" in kwargs:
102+
return HybridCursorResult(
103+
search_cursor_id=int(res_dict["SEARCH"]),
104+
vsim_cursor_id=int(res_dict["VSIM"]),
105+
)
106+
107+
results: List[Dict[str, Any]] = []
108+
# the original results are a list of lists
109+
# we convert them to a list of dicts
110+
for res_item in res_dict["results"]:
111+
item_dict = pairs_to_dict(res_item, decode_keys=True)
112+
results.append(item_dict)
113+
114+
return HybridResult(
115+
total_results=int(res_dict["total_results"]),
116+
results=results,
117+
warnings=res_dict["warnings"],
118+
execution_time=float(res_dict["execution_time"]),
119+
)
120+
87121
def _parse_aggregate(self, res, **kwargs):
88122
return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"])
89123

@@ -470,7 +504,7 @@ def get_params_args(
470504
return []
471505
args = []
472506
if len(query_params) > 0:
473-
args.append("params")
507+
args.append("PARAMS")
474508
args.append(len(query_params) * 2)
475509
for key, value in query_params.items():
476510
args.append(key)
@@ -525,6 +559,44 @@ def search(
525559
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
526560
)
527561

562+
def hybrid_search(
563+
self,
564+
query: HybridQuery,
565+
post_processing: Optional[HybridPostProcessingConfig] = None,
566+
params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
567+
timeout: Optional[int] = None,
568+
cursor: Optional[HybridCursorQuery] = None,
569+
) -> Union[HybridResult, HybridCursorResult, Pipeline]:
570+
"""
571+
Execute a hybrid search using both text and vector queries
572+
Args:
573+
query: HybridQuery object
574+
575+
"""
576+
index = self.index_name
577+
options = {}
578+
pieces = [HYBRID_CMD, index]
579+
pieces.extend(query.get_args())
580+
if post_processing:
581+
pieces.extend(post_processing.build_args())
582+
if params_substitution:
583+
pieces.extend(self.get_params_args(params_substitution))
584+
if timeout:
585+
pieces.extend(("TIMEOUT", timeout))
586+
if cursor:
587+
options["cursor"] = True
588+
pieces.extend(cursor.build_args())
589+
590+
if get_protocol_version(self.client) not in ["3", 3]:
591+
options[NEVER_DECODE] = True
592+
593+
res = self.execute_command(*pieces, **options)
594+
595+
if isinstance(res, Pipeline):
596+
return res
597+
598+
return self._parse_results(HYBRID_CMD, res, **options)
599+
528600
def explain(
529601
self,
530602
query: Union[str, Query],
@@ -965,6 +1037,44 @@ async def search(
9651037
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
9661038
)
9671039

1040+
async def hybrid_search(
1041+
self,
1042+
query: HybridQuery,
1043+
post_processing: Optional[HybridPostProcessingConfig] = None,
1044+
params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
1045+
timeout: Optional[int] = None,
1046+
cursor: Optional[HybridCursorQuery] = None,
1047+
) -> Union[HybridResult, HybridCursorResult, Pipeline]:
1048+
"""
1049+
Execute a hybrid search using both text and vector queries
1050+
Args:
1051+
query: HybridQuery object
1052+
1053+
"""
1054+
index = self.index_name
1055+
options = {}
1056+
pieces = [HYBRID_CMD, index]
1057+
pieces.extend(query.get_args())
1058+
if post_processing:
1059+
pieces.extend(post_processing.build_args())
1060+
if params_substitution:
1061+
pieces.extend(self.get_params_args(params_substitution))
1062+
if timeout:
1063+
pieces.extend(("TIMEOUT", timeout))
1064+
if cursor:
1065+
options["cursor"] = True
1066+
pieces.extend(cursor.build_args())
1067+
1068+
if get_protocol_version(self.client) not in ["3", 3]:
1069+
options[NEVER_DECODE] = True
1070+
1071+
res = await self.execute_command(*pieces, **options)
1072+
1073+
if isinstance(res, Pipeline):
1074+
return res
1075+
1076+
return self._parse_results(HYBRID_CMD, res, **options)
1077+
9681078
async def aggregate(
9691079
self,
9701080
query: Union[AggregateResult, Cursor],

0 commit comments

Comments
 (0)