Skip to content

Commit 1e8345a

Browse files
committed
Adding support for hybrid search.
1 parent c620503 commit 1e8345a

File tree

6 files changed

+1468
-6
lines changed

6 files changed

+1468
-6
lines changed

redis/commands/search/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .commands import (
55
AGGREGATE_CMD,
66
CONFIG_CMD,
7+
HYBRID_CMD,
78
INFO_CMD,
89
PROFILE_CMD,
910
SEARCH_CMD,
@@ -102,6 +103,7 @@ def __init__(self, client, index_name="idx"):
102103
self._RESP2_MODULE_CALLBACKS = {
103104
INFO_CMD: self._parse_info,
104105
SEARCH_CMD: self._parse_search,
106+
HYBRID_CMD: self._parse_hybrid_search,
105107
AGGREGATE_CMD: self._parse_aggregate,
106108
PROFILE_CMD: self._parse_profile,
107109
SPELLCHECK_CMD: self._parse_spellcheck,

redis/commands/search/commands.py

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
import itertools
22
import time
3-
from typing import Dict, List, Optional, Union
3+
from typing import Any, Dict, List, Optional, Union
44

5+
from redis._parsers.helpers import pairs_to_dict
56
from redis.client import NEVER_DECODE, Pipeline
7+
from redis.commands.search.dialect import DEFAULT_DIALECT
8+
from redis.commands.search.hybrid_query import (
9+
HybridCursorQuery,
10+
HybridPostProcessingConfig,
11+
HybridQuery,
12+
)
13+
from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult
614
from redis.utils import deprecated_function
715

816
from ..helpers import get_protocol_version
917
from ._util import to_string
10-
from .aggregation import AggregateRequest, AggregateResult, Cursor
18+
from .aggregation import (
19+
AggregateRequest,
20+
AggregateResult,
21+
Cursor,
22+
)
1123
from .document import Document
1224
from .field import Field
1325
from .index_definition import IndexDefinition
@@ -47,6 +59,7 @@
4759
SUGGET_COMMAND = "FT.SUGGET"
4860
SYNUPDATE_CMD = "FT.SYNUPDATE"
4961
SYNDUMP_CMD = "FT.SYNDUMP"
62+
HYBRID_CMD = "FT.HYBRID"
5063

5164
NOOFFSETS = "NOOFFSETS"
5265
NOFIELDS = "NOFIELDS"
@@ -84,6 +97,28 @@ def _parse_search(self, res, **kwargs):
8497
field_encodings=kwargs["query"]._return_fields_decode_as,
8598
)
8699

100+
def _parse_hybrid_search(self, res, **kwargs):
101+
res_dict = pairs_to_dict(res, decode_keys=True)
102+
if "cursor" in kwargs:
103+
return HybridCursorResult(
104+
search_cursor_id=int(res_dict["SEARCH"]),
105+
vsim_cursor_id=int(res_dict["VSIM"]),
106+
)
107+
108+
results: List[Dict[str, Any]] = []
109+
# the original results are a list of lists
110+
# we convert them to a list of dicts
111+
for res_item in res_dict["results"]:
112+
item_dict = pairs_to_dict(res_item, decode_keys=True)
113+
results.append(item_dict)
114+
115+
return HybridResult(
116+
total_results=int(res_dict["total_results"]),
117+
results=results,
118+
warnings=res_dict["warnings"],
119+
execution_time=float(res_dict["execution_time"]),
120+
)
121+
87122
def _parse_aggregate(self, res, **kwargs):
88123
return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"])
89124

@@ -470,7 +505,7 @@ def get_params_args(
470505
return []
471506
args = []
472507
if len(query_params) > 0:
473-
args.append("params")
508+
args.append("PARAMS")
474509
args.append(len(query_params) * 2)
475510
for key, value in query_params.items():
476511
args.append(key)
@@ -525,6 +560,52 @@ def search(
525560
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
526561
)
527562

563+
def hybrid_search(
564+
self,
565+
query: HybridQuery,
566+
post_processing: Optional[HybridPostProcessingConfig] = None,
567+
params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
568+
explain_score: bool = False,
569+
timeout: Optional[int] = None,
570+
cursor: Optional[HybridCursorQuery] = None,
571+
dialect: Optional[int] = DEFAULT_DIALECT,
572+
) -> Union[HybridResult, HybridCursorResult]:
573+
"""
574+
Execute a hybrid search using both text and vector queries
575+
Args:
576+
query: HybridQuery object
577+
578+
"""
579+
index = self.index_name
580+
options = {}
581+
pieces = [HYBRID_CMD, index]
582+
pieces.extend(query.get_args())
583+
if post_processing:
584+
pieces.extend(post_processing.build_args())
585+
if params_substitution:
586+
pieces.extend(self.get_params_args(params_substitution))
587+
if explain_score:
588+
pieces.append("EXPLAINSCORE")
589+
if timeout:
590+
pieces.extend(("TIMEOUT", timeout))
591+
if cursor:
592+
options["cursor"] = True
593+
pieces.extend(cursor.build_args())
594+
if dialect:
595+
pieces.extend(("DIALECT", dialect))
596+
597+
if get_protocol_version(self.client) not in ["3", 3]:
598+
options[NEVER_DECODE] = True
599+
600+
print("")
601+
cli_command = ""
602+
for arg in pieces:
603+
cli_command = f"{cli_command} {arg}"
604+
print(cli_command)
605+
606+
res = self.execute_command(*pieces, **options)
607+
return self._parse_results(HYBRID_CMD, res, **options)
608+
528609
def explain(
529610
self,
530611
query: Union[str, Query],

0 commit comments

Comments
 (0)