|
1 | 1 | import itertools |
2 | 2 | import time |
3 | | -from typing import Dict, List, Optional, Union |
| 3 | +from typing import Any, Dict, List, Optional, Union |
4 | 4 |
|
| 5 | +from redis._parsers.helpers import pairs_to_dict |
5 | 6 | from redis.client import NEVER_DECODE, Pipeline |
| 7 | +from redis.commands.search.dialect import DEFAULT_DIALECT |
| 8 | +from redis.commands.search.hybrid_query import ( |
| 9 | + HybridCursorQuery, |
| 10 | + HybridPostProcessingConfig, |
| 11 | + HybridQuery, |
| 12 | +) |
| 13 | +from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult |
6 | 14 | from redis.utils import deprecated_function |
7 | 15 |
|
8 | 16 | from ..helpers import get_protocol_version |
9 | 17 | from ._util import to_string |
10 | | -from .aggregation import AggregateRequest, AggregateResult, Cursor |
| 18 | +from .aggregation import ( |
| 19 | + AggregateRequest, |
| 20 | + AggregateResult, |
| 21 | + Cursor, |
| 22 | +) |
11 | 23 | from .document import Document |
12 | 24 | from .field import Field |
13 | 25 | from .index_definition import IndexDefinition |
|
47 | 59 | SUGGET_COMMAND = "FT.SUGGET" |
48 | 60 | SYNUPDATE_CMD = "FT.SYNUPDATE" |
49 | 61 | SYNDUMP_CMD = "FT.SYNDUMP" |
| 62 | +HYBRID_CMD = "FT.HYBRID" |
50 | 63 |
|
51 | 64 | NOOFFSETS = "NOOFFSETS" |
52 | 65 | NOFIELDS = "NOFIELDS" |
@@ -84,6 +97,28 @@ def _parse_search(self, res, **kwargs): |
84 | 97 | field_encodings=kwargs["query"]._return_fields_decode_as, |
85 | 98 | ) |
86 | 99 |
|
| 100 | + def _parse_hybrid_search(self, res, **kwargs): |
| 101 | + res_dict = pairs_to_dict(res, decode_keys=True) |
| 102 | + if "cursor" in kwargs: |
| 103 | + return HybridCursorResult( |
| 104 | + search_cursor_id=int(res_dict["SEARCH"]), |
| 105 | + vsim_cursor_id=int(res_dict["VSIM"]), |
| 106 | + ) |
| 107 | + |
| 108 | + results: List[Dict[str, Any]] = [] |
| 109 | + # the original results are a list of lists |
| 110 | + # we convert them to a list of dicts |
| 111 | + for res_item in res_dict["results"]: |
| 112 | + item_dict = pairs_to_dict(res_item, decode_keys=True) |
| 113 | + results.append(item_dict) |
| 114 | + |
| 115 | + return HybridResult( |
| 116 | + total_results=int(res_dict["total_results"]), |
| 117 | + results=results, |
| 118 | + warnings=res_dict["warnings"], |
| 119 | + execution_time=float(res_dict["execution_time"]), |
| 120 | + ) |
| 121 | + |
87 | 122 | def _parse_aggregate(self, res, **kwargs): |
88 | 123 | return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) |
89 | 124 |
|
@@ -470,7 +505,7 @@ def get_params_args( |
470 | 505 | return [] |
471 | 506 | args = [] |
472 | 507 | if len(query_params) > 0: |
473 | | - args.append("params") |
| 508 | + args.append("PARAMS") |
474 | 509 | args.append(len(query_params) * 2) |
475 | 510 | for key, value in query_params.items(): |
476 | 511 | args.append(key) |
@@ -525,6 +560,52 @@ def search( |
525 | 560 | SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 |
526 | 561 | ) |
527 | 562 |
|
| 563 | + def hybrid_search( |
| 564 | + self, |
| 565 | + query: HybridQuery, |
| 566 | + post_processing: Optional[HybridPostProcessingConfig] = None, |
| 567 | + params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None, |
| 568 | + explain_score: bool = False, |
| 569 | + timeout: Optional[int] = None, |
| 570 | + cursor: Optional[HybridCursorQuery] = None, |
| 571 | + dialect: Optional[int] = DEFAULT_DIALECT, |
| 572 | + ) -> Union[HybridResult, HybridCursorResult]: |
| 573 | + """ |
| 574 | + Execute a hybrid search using both text and vector queries |
| 575 | + Args: |
| 576 | + query: HybridQuery object |
| 577 | +
|
| 578 | + """ |
| 579 | + index = self.index_name |
| 580 | + options = {} |
| 581 | + pieces = [HYBRID_CMD, index] |
| 582 | + pieces.extend(query.get_args()) |
| 583 | + if post_processing: |
| 584 | + pieces.extend(post_processing.build_args()) |
| 585 | + if params_substitution: |
| 586 | + pieces.extend(self.get_params_args(params_substitution)) |
| 587 | + if explain_score: |
| 588 | + pieces.append("EXPLAINSCORE") |
| 589 | + if timeout: |
| 590 | + pieces.extend(("TIMEOUT", timeout)) |
| 591 | + if cursor: |
| 592 | + options["cursor"] = True |
| 593 | + pieces.extend(cursor.build_args()) |
| 594 | + if dialect: |
| 595 | + pieces.extend(("DIALECT", dialect)) |
| 596 | + |
| 597 | + if get_protocol_version(self.client) not in ["3", 3]: |
| 598 | + options[NEVER_DECODE] = True |
| 599 | + |
| 600 | + print("") |
| 601 | + cli_command = "" |
| 602 | + for arg in pieces: |
| 603 | + cli_command = f"{cli_command} {arg}" |
| 604 | + print(cli_command) |
| 605 | + |
| 606 | + res = self.execute_command(*pieces, **options) |
| 607 | + return self._parse_results(HYBRID_CMD, res, **options) |
| 608 | + |
528 | 609 | def explain( |
529 | 610 | self, |
530 | 611 | query: Union[str, Query], |
|
0 commit comments