|
1 | 1 | import itertools |
2 | 2 | import time |
3 | | -from typing import Dict, List, Optional, Union |
| 3 | +from typing import Any, Dict, List, Optional, Union |
4 | 4 |
|
| 5 | +from redis._parsers.helpers import pairs_to_dict |
5 | 6 | from redis.client import NEVER_DECODE, Pipeline |
| 7 | +from redis.commands.search.hybrid_query import ( |
| 8 | + HybridCursorQuery, |
| 9 | + HybridPostProcessingConfig, |
| 10 | + HybridQuery, |
| 11 | +) |
| 12 | +from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult |
6 | 13 | from redis.utils import deprecated_function |
7 | 14 |
|
8 | 15 | from ..helpers import get_protocol_version |
9 | 16 | from ._util import to_string |
10 | | -from .aggregation import AggregateRequest, AggregateResult, Cursor |
| 17 | +from .aggregation import ( |
| 18 | + AggregateRequest, |
| 19 | + AggregateResult, |
| 20 | + Cursor, |
| 21 | +) |
11 | 22 | from .document import Document |
12 | 23 | from .field import Field |
13 | 24 | from .index_definition import IndexDefinition |
|
47 | 58 | SUGGET_COMMAND = "FT.SUGGET" |
48 | 59 | SYNUPDATE_CMD = "FT.SYNUPDATE" |
49 | 60 | SYNDUMP_CMD = "FT.SYNDUMP" |
| 61 | +HYBRID_CMD = "FT.HYBRID" |
50 | 62 |
|
51 | 63 | NOOFFSETS = "NOOFFSETS" |
52 | 64 | NOFIELDS = "NOFIELDS" |
@@ -84,6 +96,28 @@ def _parse_search(self, res, **kwargs): |
84 | 96 | field_encodings=kwargs["query"]._return_fields_decode_as, |
85 | 97 | ) |
86 | 98 |
|
| 99 | + def _parse_hybrid_search(self, res, **kwargs): |
| 100 | + res_dict = pairs_to_dict(res, decode_keys=True) |
| 101 | + if "cursor" in kwargs: |
| 102 | + return HybridCursorResult( |
| 103 | + search_cursor_id=int(res_dict["SEARCH"]), |
| 104 | + vsim_cursor_id=int(res_dict["VSIM"]), |
| 105 | + ) |
| 106 | + |
| 107 | + results: List[Dict[str, Any]] = [] |
| 108 | + # the original results are a list of lists |
| 109 | + # we convert them to a list of dicts |
| 110 | + for res_item in res_dict["results"]: |
| 111 | + item_dict = pairs_to_dict(res_item, decode_keys=True) |
| 112 | + results.append(item_dict) |
| 113 | + |
| 114 | + return HybridResult( |
| 115 | + total_results=int(res_dict["total_results"]), |
| 116 | + results=results, |
| 117 | + warnings=res_dict["warnings"], |
| 118 | + execution_time=float(res_dict["execution_time"]), |
| 119 | + ) |
| 120 | + |
87 | 121 | def _parse_aggregate(self, res, **kwargs): |
88 | 122 | return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) |
89 | 123 |
|
@@ -470,7 +504,7 @@ def get_params_args( |
470 | 504 | return [] |
471 | 505 | args = [] |
472 | 506 | if len(query_params) > 0: |
473 | | - args.append("params") |
| 507 | + args.append("PARAMS") |
474 | 508 | args.append(len(query_params) * 2) |
475 | 509 | for key, value in query_params.items(): |
476 | 510 | args.append(key) |
@@ -525,6 +559,44 @@ def search( |
525 | 559 | SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 |
526 | 560 | ) |
527 | 561 |
|
| 562 | + def hybrid_search( |
| 563 | + self, |
| 564 | + query: HybridQuery, |
| 565 | + post_processing: Optional[HybridPostProcessingConfig] = None, |
| 566 | + params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None, |
| 567 | + timeout: Optional[int] = None, |
| 568 | + cursor: Optional[HybridCursorQuery] = None, |
| 569 | + ) -> Union[HybridResult, HybridCursorResult, Pipeline]: |
| 570 | + """ |
| 571 | + Execute a hybrid search using both text and vector queries |
| 572 | + Args: |
| 573 | + query: HybridQuery object |
| 574 | +
|
| 575 | + """ |
| 576 | + index = self.index_name |
| 577 | + options = {} |
| 578 | + pieces = [HYBRID_CMD, index] |
| 579 | + pieces.extend(query.get_args()) |
| 580 | + if post_processing: |
| 581 | + pieces.extend(post_processing.build_args()) |
| 582 | + if params_substitution: |
| 583 | + pieces.extend(self.get_params_args(params_substitution)) |
| 584 | + if timeout: |
| 585 | + pieces.extend(("TIMEOUT", timeout)) |
| 586 | + if cursor: |
| 587 | + options["cursor"] = True |
| 588 | + pieces.extend(cursor.build_args()) |
| 589 | + |
| 590 | + if get_protocol_version(self.client) not in ["3", 3]: |
| 591 | + options[NEVER_DECODE] = True |
| 592 | + |
| 593 | + res = self.execute_command(*pieces, **options) |
| 594 | + |
| 595 | + if isinstance(res, Pipeline): |
| 596 | + return res |
| 597 | + |
| 598 | + return self._parse_results(HYBRID_CMD, res, **options) |
| 599 | + |
528 | 600 | def explain( |
529 | 601 | self, |
530 | 602 | query: Union[str, Query], |
@@ -965,6 +1037,44 @@ async def search( |
965 | 1037 | SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 |
966 | 1038 | ) |
967 | 1039 |
|
| 1040 | + async def hybrid_search( |
| 1041 | + self, |
| 1042 | + query: HybridQuery, |
| 1043 | + post_processing: Optional[HybridPostProcessingConfig] = None, |
| 1044 | + params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None, |
| 1045 | + timeout: Optional[int] = None, |
| 1046 | + cursor: Optional[HybridCursorQuery] = None, |
| 1047 | + ) -> Union[HybridResult, HybridCursorResult, Pipeline]: |
| 1048 | + """ |
| 1049 | + Execute a hybrid search using both text and vector queries |
| 1050 | + Args: |
| 1051 | + query: HybridQuery object |
| 1052 | +
|
| 1053 | + """ |
| 1054 | + index = self.index_name |
| 1055 | + options = {} |
| 1056 | + pieces = [HYBRID_CMD, index] |
| 1057 | + pieces.extend(query.get_args()) |
| 1058 | + if post_processing: |
| 1059 | + pieces.extend(post_processing.build_args()) |
| 1060 | + if params_substitution: |
| 1061 | + pieces.extend(self.get_params_args(params_substitution)) |
| 1062 | + if timeout: |
| 1063 | + pieces.extend(("TIMEOUT", timeout)) |
| 1064 | + if cursor: |
| 1065 | + options["cursor"] = True |
| 1066 | + pieces.extend(cursor.build_args()) |
| 1067 | + |
| 1068 | + if get_protocol_version(self.client) not in ["3", 3]: |
| 1069 | + options[NEVER_DECODE] = True |
| 1070 | + |
| 1071 | + res = await self.execute_command(*pieces, **options) |
| 1072 | + |
| 1073 | + if isinstance(res, Pipeline): |
| 1074 | + return res |
| 1075 | + |
| 1076 | + return self._parse_results(HYBRID_CMD, res, **options) |
| 1077 | + |
968 | 1078 | async def aggregate( |
969 | 1079 | self, |
970 | 1080 | query: Union[AggregateResult, Cursor], |
|
0 commit comments