From be6dfd55d992375cd1fa029a5961e0d77a92b433 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Wed, 16 Jul 2025 23:06:47 +0300 Subject: [PATCH 01/62] Add search module builders and tests (#1) * Add search module builders and tests * Add tests --- search_builders.go | 759 ++++++++++++++++++++++++++++++++++++++++ search_builders_test.go | 680 +++++++++++++++++++++++++++++++++++ 2 files changed, 1439 insertions(+) create mode 100644 search_builders.go create mode 100644 search_builders_test.go diff --git a/search_builders.go b/search_builders.go new file mode 100644 index 0000000000..964b26878d --- /dev/null +++ b/search_builders.go @@ -0,0 +1,759 @@ +package redis + +import ( + "context" +) + +// ---------------------- +// Search Module Builders +// ---------------------- + +// SearchBuilder provides a fluent API for FT.SEARCH +// (see original FTSearchOptions for all options). +type SearchBuilder struct { + c *Client + ctx context.Context + index string + query string + options *FTSearchOptions +} + +// Search starts building an FT.SEARCH command. +func (c *Client) Search(ctx context.Context, index, query string) *SearchBuilder { + b := &SearchBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSearchOptions{LimitOffset: -1}} + return b +} + +// WithScores includes WITHSCORES. +func (b *SearchBuilder) WithScores() *SearchBuilder { + b.options.WithScores = true + return b +} + +// NoContent includes NOCONTENT. +func (b *SearchBuilder) NoContent() *SearchBuilder { b.options.NoContent = true; return b } + +// Verbatim includes VERBATIM. +func (b *SearchBuilder) Verbatim() *SearchBuilder { b.options.Verbatim = true; return b } + +// NoStopWords includes NOSTOPWORDS. +func (b *SearchBuilder) NoStopWords() *SearchBuilder { b.options.NoStopWords = true; return b } + +// WithPayloads includes WITHPAYLOADS. +func (b *SearchBuilder) WithPayloads() *SearchBuilder { + b.options.WithPayloads = true + return b +} + +// WithSortKeys includes WITHSORTKEYS. +func (b *SearchBuilder) WithSortKeys() *SearchBuilder { + b.options.WithSortKeys = true + return b +} + +// Filter adds a FILTER clause: FILTER . +func (b *SearchBuilder) Filter(field string, min, max interface{}) *SearchBuilder { + b.options.Filters = append(b.options.Filters, FTSearchFilter{ + FieldName: field, + Min: min, + Max: max, + }) + return b +} + +// GeoFilter adds a GEOFILTER clause: GEOFILTER . +func (b *SearchBuilder) GeoFilter(field string, lon, lat, radius float64, unit string) *SearchBuilder { + b.options.GeoFilter = append(b.options.GeoFilter, FTSearchGeoFilter{ + FieldName: field, + Longitude: lon, + Latitude: lat, + Radius: radius, + Unit: unit, + }) + return b +} + +// InKeys restricts the search to the given keys. +func (b *SearchBuilder) InKeys(keys ...interface{}) *SearchBuilder { + b.options.InKeys = append(b.options.InKeys, keys...) + return b +} + +// InFields restricts the search to the given fields. +func (b *SearchBuilder) InFields(fields ...interface{}) *SearchBuilder { + b.options.InFields = append(b.options.InFields, fields...) + return b +} + +// ReturnFields adds simple RETURN ... +func (b *SearchBuilder) ReturnFields(fields ...string) *SearchBuilder { + for _, f := range fields { + b.options.Return = append(b.options.Return, FTSearchReturn{FieldName: f}) + } + return b +} + +// ReturnAs adds RETURN AS . 
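+// The generated clause is RETURN <field> AS <alias>. Illustrative usage
+// (editorial sketch, not part of the original commit; assumes an index
+// "idx" with a TEXT field "title"):
+//
+//	res, err := client.Search(ctx, "idx", "hello").
+//		ReturnAs("title", "doc_title").
+//		Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	// matching docs expose the field under the "doc_title" alias
+//	_ = res.Docs[0].Fields["doc_title"]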
+func (b *SearchBuilder) ReturnAs(field, alias string) *SearchBuilder { + b.options.Return = append(b.options.Return, FTSearchReturn{FieldName: field, As: alias}) + return b +} + +// Slop adds SLOP . +func (b *SearchBuilder) Slop(slop int) *SearchBuilder { + b.options.Slop = slop + return b +} + +// Timeout adds TIMEOUT . +func (b *SearchBuilder) Timeout(timeout int) *SearchBuilder { + b.options.Timeout = timeout + return b +} + +// InOrder includes INORDER. +func (b *SearchBuilder) InOrder() *SearchBuilder { + b.options.InOrder = true + return b +} + +// Language sets LANGUAGE . +func (b *SearchBuilder) Language(lang string) *SearchBuilder { + b.options.Language = lang + return b +} + +// Expander sets EXPANDER . +func (b *SearchBuilder) Expander(expander string) *SearchBuilder { + b.options.Expander = expander + return b +} + +// Scorer sets SCORER . +func (b *SearchBuilder) Scorer(scorer string) *SearchBuilder { + b.options.Scorer = scorer + return b +} + +// ExplainScore includes EXPLAINSCORE. +func (b *SearchBuilder) ExplainScore() *SearchBuilder { + b.options.ExplainScore = true + return b +} + +// Payload sets PAYLOAD . +func (b *SearchBuilder) Payload(payload string) *SearchBuilder { + b.options.Payload = payload + return b +} + +// SortBy adds SORTBY ASC|DESC. +func (b *SearchBuilder) SortBy(field string, asc bool) *SearchBuilder { + b.options.SortBy = append(b.options.SortBy, FTSearchSortBy{ + FieldName: field, + Asc: asc, + Desc: !asc, + }) + return b +} + +// WithSortByCount includes WITHCOUNT (when used with SortBy). +func (b *SearchBuilder) WithSortByCount() *SearchBuilder { + b.options.SortByWithCount = true + return b +} + +// Param adds a single PARAMS . +func (b *SearchBuilder) Param(key string, value interface{}) *SearchBuilder { + if b.options.Params == nil { + b.options.Params = make(map[string]interface{}, 1) + } + b.options.Params[key] = value + return b +} + +// ParamsMap adds multiple PARAMS at once. +func (b *SearchBuilder) ParamsMap(p map[string]interface{}) *SearchBuilder { + if b.options.Params == nil { + b.options.Params = make(map[string]interface{}, len(p)) + } + for k, v := range p { + b.options.Params[k] = v + } + return b +} + +// Dialect sets DIALECT . +func (b *SearchBuilder) Dialect(version int) *SearchBuilder { + b.options.DialectVersion = version + return b +} + +// Limit sets OFFSET and COUNT. CountOnly uses LIMIT 0 0. +func (b *SearchBuilder) Limit(offset, count int) *SearchBuilder { + b.options.LimitOffset = offset + b.options.Limit = count + return b +} +func (b *SearchBuilder) CountOnly() *SearchBuilder { b.options.CountOnly = true; return b } + +// Run executes FT.SEARCH and returns a typed result. +func (b *SearchBuilder) Run() (FTSearchResult, error) { + cmd := b.c.FTSearchWithArgs(b.ctx, b.index, b.query, b.options) + return cmd.Result() +} + +// ---------------------- +// AggregateBuilder for FT.AGGREGATE +// ---------------------- + +type AggregateBuilder struct { + c *Client + ctx context.Context + index string + query string + options *FTAggregateOptions +} + +// Aggregate starts building an FT.AGGREGATE command. +func (c *Client) Aggregate(ctx context.Context, index, query string) *AggregateBuilder { + return &AggregateBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTAggregateOptions{LimitOffset: -1}} +} + +// Verbatim includes VERBATIM. +func (b *AggregateBuilder) Verbatim() *AggregateBuilder { b.options.Verbatim = true; return b } + +// AddScores includes ADDSCORES. 
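+// Illustrative usage (editorial sketch, not part of the original commit;
+// ADDSCORES exposes the full-text score to later pipeline stages, which
+// RediSearch commonly surfaces as @__score):
+//
+//	agg, err := client.Aggregate(ctx, "idx", "hello").
+//		AddScores().
+//		SortBy("@__score", false).
+//		Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	_ = agg.Rows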
+func (b *AggregateBuilder) AddScores() *AggregateBuilder { b.options.AddScores = true; return b } + +// Scorer sets SCORER . +func (b *AggregateBuilder) Scorer(s string) *AggregateBuilder { + b.options.Scorer = s + return b +} + +// LoadAll includes LOAD * (mutually exclusive with Load). +func (b *AggregateBuilder) LoadAll() *AggregateBuilder { + b.options.LoadAll = true + return b +} + +// Load adds LOAD [AS alias]... +// You can call it multiple times for multiple fields. +func (b *AggregateBuilder) Load(field string, alias ...string) *AggregateBuilder { + // each Load entry becomes one element in options.Load + l := FTAggregateLoad{Field: field} + if len(alias) > 0 { + l.As = alias[0] + } + b.options.Load = append(b.options.Load, l) + return b +} + +// Timeout sets TIMEOUT . +func (b *AggregateBuilder) Timeout(ms int) *AggregateBuilder { + b.options.Timeout = ms + return b +} + +// Apply adds APPLY [AS alias]. +func (b *AggregateBuilder) Apply(field string, alias ...string) *AggregateBuilder { + a := FTAggregateApply{Field: field} + if len(alias) > 0 { + a.As = alias[0] + } + b.options.Apply = append(b.options.Apply, a) + return b +} + +// GroupBy starts a new GROUPBY clause. +func (b *AggregateBuilder) GroupBy(fields ...interface{}) *AggregateBuilder { + b.options.GroupBy = append(b.options.GroupBy, FTAggregateGroupBy{ + Fields: fields, + }) + return b +} + +// Reduce adds a REDUCE [<#args> ] clause to the *last* GROUPBY. +func (b *AggregateBuilder) Reduce(fn SearchAggregator, args ...interface{}) *AggregateBuilder { + if len(b.options.GroupBy) == 0 { + // no GROUPBY yet — nothing to attach to + return b + } + idx := len(b.options.GroupBy) - 1 + b.options.GroupBy[idx].Reduce = append(b.options.GroupBy[idx].Reduce, FTAggregateReducer{ + Reducer: fn, + Args: args, + }) + return b +} + +// ReduceAs does the same but also sets an alias: REDUCE … AS +func (b *AggregateBuilder) ReduceAs(fn SearchAggregator, alias string, args ...interface{}) *AggregateBuilder { + if len(b.options.GroupBy) == 0 { + return b + } + idx := len(b.options.GroupBy) - 1 + b.options.GroupBy[idx].Reduce = append(b.options.GroupBy[idx].Reduce, FTAggregateReducer{ + Reducer: fn, + Args: args, + As: alias, + }) + return b +} + +// SortBy adds SORTBY ASC|DESC. +func (b *AggregateBuilder) SortBy(field string, asc bool) *AggregateBuilder { + sb := FTAggregateSortBy{FieldName: field, Asc: asc, Desc: !asc} + b.options.SortBy = append(b.options.SortBy, sb) + return b +} + +// SortByMax sets MAX (only if SortBy was called). +func (b *AggregateBuilder) SortByMax(max int) *AggregateBuilder { + b.options.SortByMax = max + return b +} + +// Filter sets FILTER . +func (b *AggregateBuilder) Filter(expr string) *AggregateBuilder { + b.options.Filter = expr + return b +} + +// WithCursor enables WITHCURSOR [COUNT ] [MAXIDLE ]. +func (b *AggregateBuilder) WithCursor(count, maxIdle int) *AggregateBuilder { + b.options.WithCursor = true + if b.options.WithCursorOptions == nil { + b.options.WithCursorOptions = &FTAggregateWithCursor{} + } + b.options.WithCursorOptions.Count = count + b.options.WithCursorOptions.MaxIdle = maxIdle + return b +} + +// Params adds PARAMS pairs. +func (b *AggregateBuilder) Params(p map[string]interface{}) *AggregateBuilder { + if b.options.Params == nil { + b.options.Params = make(map[string]interface{}, len(p)) + } + for k, v := range p { + b.options.Params[k] = v + } + return b +} + +// Dialect sets DIALECT . 
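+// The generated clause is DIALECT <version>. A fuller pipeline, for
+// illustration only (editorial sketch, not part of the original commit;
+// assumes a NUMERIC field "n"):
+//
+//	agg, err := client.Aggregate(ctx, "idx", "*").
+//		GroupBy("@n").
+//		ReduceAs(redis.SearchCount, "count").
+//		SortBy("@count", false).
+//		Dialect(2).
+//		Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	for _, row := range agg.Rows {
+//		_ = row.Fields["count"]
+//	}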
+func (b *AggregateBuilder) Dialect(version int) *AggregateBuilder { + b.options.DialectVersion = version + return b +} + +// Run executes FT.AGGREGATE and returns a typed result. +func (b *AggregateBuilder) Run() (*FTAggregateResult, error) { + cmd := b.c.FTAggregateWithArgs(b.ctx, b.index, b.query, b.options) + return cmd.Result() +} + +// ---------------------- +// CreateIndexBuilder for FT.CREATE +// ---------------------- + +type CreateIndexBuilder struct { + c *Client + ctx context.Context + index string + options *FTCreateOptions + schema []*FieldSchema +} + +// CreateIndex starts building an FT.CREATE command. +func (c *Client) CreateIndex(ctx context.Context, index string) *CreateIndexBuilder { + return &CreateIndexBuilder{c: c, ctx: ctx, index: index, options: &FTCreateOptions{}} +} + +// OnHash sets ON HASH. +func (b *CreateIndexBuilder) OnHash() *CreateIndexBuilder { b.options.OnHash = true; return b } + +// OnJSON sets ON JSON. +func (b *CreateIndexBuilder) OnJSON() *CreateIndexBuilder { b.options.OnJSON = true; return b } + +// Prefix sets PREFIX. +func (b *CreateIndexBuilder) Prefix(prefixes ...interface{}) *CreateIndexBuilder { + b.options.Prefix = prefixes + return b +} + +// Filter sets FILTER. +func (b *CreateIndexBuilder) Filter(filter string) *CreateIndexBuilder { + b.options.Filter = filter + return b +} + +// DefaultLanguage sets LANGUAGE. +func (b *CreateIndexBuilder) DefaultLanguage(lang string) *CreateIndexBuilder { + b.options.DefaultLanguage = lang + return b +} + +// LanguageField sets LANGUAGE_FIELD. +func (b *CreateIndexBuilder) LanguageField(field string) *CreateIndexBuilder { + b.options.LanguageField = field + return b +} + +// Score sets SCORE. +func (b *CreateIndexBuilder) Score(score float64) *CreateIndexBuilder { + b.options.Score = score + return b +} + +// ScoreField sets SCORE_FIELD. +func (b *CreateIndexBuilder) ScoreField(field string) *CreateIndexBuilder { + b.options.ScoreField = field + return b +} + +// PayloadField sets PAYLOAD_FIELD. +func (b *CreateIndexBuilder) PayloadField(field string) *CreateIndexBuilder { + b.options.PayloadField = field + return b +} + +// NoOffsets includes NOOFFSETS. +func (b *CreateIndexBuilder) NoOffsets() *CreateIndexBuilder { b.options.NoOffsets = true; return b } + +// Temporary sets TEMPORARY seconds. +func (b *CreateIndexBuilder) Temporary(sec int) *CreateIndexBuilder { + b.options.Temporary = sec + return b +} + +// NoHL includes NOHL. +func (b *CreateIndexBuilder) NoHL() *CreateIndexBuilder { b.options.NoHL = true; return b } + +// NoFields includes NOFIELDS. +func (b *CreateIndexBuilder) NoFields() *CreateIndexBuilder { b.options.NoFields = true; return b } + +// NoFreqs includes NOFREQS. +func (b *CreateIndexBuilder) NoFreqs() *CreateIndexBuilder { b.options.NoFreqs = true; return b } + +// StopWords sets STOPWORDS. +func (b *CreateIndexBuilder) StopWords(words ...interface{}) *CreateIndexBuilder { + b.options.StopWords = words + return b +} + +// SkipInitialScan includes SKIPINITIALSCAN. +func (b *CreateIndexBuilder) SkipInitialScan() *CreateIndexBuilder { + b.options.SkipInitialScan = true + return b +} + +// Schema adds a FieldSchema. +func (b *CreateIndexBuilder) Schema(field *FieldSchema) *CreateIndexBuilder { + b.schema = append(b.schema, field) + return b +} + +// Run executes FT.CREATE and returns the status. +func (b *CreateIndexBuilder) Run() (string, error) { + cmd := b.c.FTCreate(b.ctx, b.index, b.options, b.schema...) 
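+	// Editorial note: the builder is a thin wrapper. Option validation and
+	// argument encoding happen in the underlying FTCreate command, so Run
+	// returns exactly what a direct FTCreate call would ("OK" on success).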
+	return cmd.Result()
+}
+
+// ----------------------
+// DropIndexBuilder for FT.DROPINDEX
+// ----------------------
+
+type DropIndexBuilder struct {
+	c       *Client
+	ctx     context.Context
+	index   string
+	options *FTDropIndexOptions
+}
+
+// DropIndex starts FT.DROPINDEX builder.
+func (c *Client) DropIndex(ctx context.Context, index string) *DropIndexBuilder {
+	return &DropIndexBuilder{c: c, ctx: ctx, index: index}
+}
+
+// DeleteDocs includes DD (also delete the indexed documents).
+func (b *DropIndexBuilder) DeleteDocs() *DropIndexBuilder {
+	if b.options == nil {
+		b.options = &FTDropIndexOptions{}
+	}
+	b.options.DeleteDocs = true
+	return b
+}
+
+// Run executes FT.DROPINDEX.
+func (b *DropIndexBuilder) Run() (string, error) {
+	cmd := b.c.FTDropIndexWithArgs(b.ctx, b.index, b.options)
+	return cmd.Result()
+}
+
+// ----------------------
+// AliasBuilder for FT.ALIAS* commands
+// ----------------------
+
+type AliasBuilder struct {
+	c      *Client
+	ctx    context.Context
+	alias  string
+	index  string
+	action string // add|del|update
+}
+
+// AliasAdd starts FT.ALIASADD builder.
+func (c *Client) AliasAdd(ctx context.Context, alias, index string) *AliasBuilder {
+	return &AliasBuilder{c: c, ctx: ctx, alias: alias, index: index, action: "add"}
+}
+
+// AliasDel starts FT.ALIASDEL builder.
+func (c *Client) AliasDel(ctx context.Context, alias string) *AliasBuilder {
+	return &AliasBuilder{c: c, ctx: ctx, alias: alias, action: "del"}
+}
+
+// AliasUpdate starts FT.ALIASUPDATE builder.
+func (c *Client) AliasUpdate(ctx context.Context, alias, index string) *AliasBuilder {
+	return &AliasBuilder{c: c, ctx: ctx, alias: alias, index: index, action: "update"}
+}
+
+// Run executes the configured alias command.
+func (b *AliasBuilder) Run() (string, error) {
+	switch b.action {
+	case "add":
+		cmd := b.c.FTAliasAdd(b.ctx, b.index, b.alias)
+		return cmd.Result()
+	case "del":
+		cmd := b.c.FTAliasDel(b.ctx, b.alias)
+		return cmd.Result()
+	case "update":
+		cmd := b.c.FTAliasUpdate(b.ctx, b.index, b.alias)
+		return cmd.Result()
+	}
+	return "", nil
+}
+
+// ----------------------
+// ExplainBuilder for FT.EXPLAIN
+// ----------------------
+
+type ExplainBuilder struct {
+	c       *Client
+	ctx     context.Context
+	index   string
+	query   string
+	options *FTExplainOptions
+}
+
+// Explain starts FT.EXPLAIN builder.
+func (c *Client) Explain(ctx context.Context, index, query string) *ExplainBuilder {
+	return &ExplainBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTExplainOptions{}}
+}
+
+// Dialect sets DIALECT for FT.EXPLAIN.
+func (b *ExplainBuilder) Dialect(d string) *ExplainBuilder { b.options.Dialect = d; return b }
+
+// Run executes FT.EXPLAIN and returns the plan.
+func (b *ExplainBuilder) Run() (string, error) {
+	cmd := b.c.FTExplainWithArgs(b.ctx, b.index, b.query, b.options)
+	return cmd.Result()
+}
+
+// ----------------------
+// FTInfoBuilder for FT.INFO
+// ----------------------
+
+type FTInfoBuilder struct {
+	c     *Client
+	ctx   context.Context
+	index string
+}
+
+// SearchInfo starts building an FT.INFO command for RediSearch.
+func (c *Client) SearchInfo(ctx context.Context, index string) *FTInfoBuilder {
+	return &FTInfoBuilder{c: c, ctx: ctx, index: index}
+}
+
+// Run executes FT.INFO and returns detailed info.
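+// Illustrative usage (editorial sketch, not part of the original commit):
+//
+//	info, err := client.SearchInfo(ctx, "idx").Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	_ = info.IndexName // e.g. "idx"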
+func (b *FTInfoBuilder) Run() (FTInfoResult, error) { + cmd := b.c.FTInfo(b.ctx, b.index) + return cmd.Result() +} + +// ---------------------- +// SpellCheckBuilder for FT.SPELLCHECK +// ---------------------- + +type SpellCheckBuilder struct { + c *Client + ctx context.Context + index string + query string + options *FTSpellCheckOptions +} + +// SpellCheck starts FT.SPELLCHECK builder. +func (c *Client) SpellCheck(ctx context.Context, index, query string) *SpellCheckBuilder { + return &SpellCheckBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSpellCheckOptions{}} +} + +// Distance sets MAXDISTANCE. +func (b *SpellCheckBuilder) Distance(d int) *SpellCheckBuilder { b.options.Distance = d; return b } + +// Terms sets INCLUDE or EXCLUDE terms. +func (b *SpellCheckBuilder) Terms(include bool, dictionary string, terms ...interface{}) *SpellCheckBuilder { + if b.options.Terms == nil { + b.options.Terms = &FTSpellCheckTerms{} + } + if include { + b.options.Terms.Inclusion = "INCLUDE" + } else { + b.options.Terms.Inclusion = "EXCLUDE" + } + b.options.Terms.Dictionary = dictionary + b.options.Terms.Terms = terms + return b +} + +// Dialect sets dialect version. +func (b *SpellCheckBuilder) Dialect(d int) *SpellCheckBuilder { b.options.Dialect = d; return b } + +// Run executes FT.SPELLCHECK and returns suggestions. +func (b *SpellCheckBuilder) Run() ([]SpellCheckResult, error) { + cmd := b.c.FTSpellCheckWithArgs(b.ctx, b.index, b.query, b.options) + return cmd.Result() +} + +// ---------------------- +// DictBuilder for FT.DICT* commands +// ---------------------- + +type DictBuilder struct { + c *Client + ctx context.Context + dict string + terms []interface{} + action string // add|del|dump +} + +// DictAdd starts FT.DICTADD builder. +func (c *Client) DictAdd(ctx context.Context, dict string, terms ...interface{}) *DictBuilder { + return &DictBuilder{c: c, ctx: ctx, dict: dict, terms: terms, action: "add"} +} + +// DictDel starts FT.DICTDEL builder. +func (c *Client) DictDel(ctx context.Context, dict string, terms ...interface{}) *DictBuilder { + return &DictBuilder{c: c, ctx: ctx, dict: dict, terms: terms, action: "del"} +} + +// DictDump starts FT.DICTDUMP builder. +func (c *Client) DictDump(ctx context.Context, dict string) *DictBuilder { + return &DictBuilder{c: c, ctx: ctx, dict: dict, action: "dump"} +} + +// Run executes the configured dictionary command. +func (b *DictBuilder) Run() (interface{}, error) { + switch b.action { + case "add": + cmd := b.c.FTDictAdd(b.ctx, b.dict, b.terms...) + return cmd.Result() + case "del": + cmd := b.c.FTDictDel(b.ctx, b.dict, b.terms...) + return cmd.Result() + case "dump": + cmd := b.c.FTDictDump(b.ctx, b.dict) + return cmd.Result() + } + return nil, nil +} + +// ---------------------- +// TagValsBuilder for FT.TAGVALS +// ---------------------- + +type TagValsBuilder struct { + c *Client + ctx context.Context + index string + field string +} + +// TagVals starts FT.TAGVALS builder. +func (c *Client) TagVals(ctx context.Context, index, field string) *TagValsBuilder { + return &TagValsBuilder{c: c, ctx: ctx, index: index, field: field} +} + +// Run executes FT.TAGVALS and returns tag values. 
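+// Illustrative usage (editorial sketch, not part of the original commit;
+// assumes a TAG field "tags"):
+//
+//	vals, err := client.TagVals(ctx, "idx", "tags").Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	// vals is a []string of distinct tag values, e.g. ["blue", "green", "red"]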
+func (b *TagValsBuilder) Run() ([]string, error) { + cmd := b.c.FTTagVals(b.ctx, b.index, b.field) + return cmd.Result() +} + +// ---------------------- +// CursorBuilder for FT.CURSOR* +// ---------------------- + +type CursorBuilder struct { + c *Client + ctx context.Context + index string + cursorId int64 + count int + action string // read|del +} + +// CursorRead starts FT.CURSOR READ builder. +func (c *Client) CursorRead(ctx context.Context, index string, cursorId int64) *CursorBuilder { + return &CursorBuilder{c: c, ctx: ctx, index: index, cursorId: cursorId, action: "read"} +} + +// CursorDel starts FT.CURSOR DEL builder. +func (c *Client) CursorDel(ctx context.Context, index string, cursorId int64) *CursorBuilder { + return &CursorBuilder{c: c, ctx: ctx, index: index, cursorId: cursorId, action: "del"} +} + +// Count for READ. +func (b *CursorBuilder) Count(count int) *CursorBuilder { b.count = count; return b } + +// Run executes the cursor command. +func (b *CursorBuilder) Run() (interface{}, error) { + switch b.action { + case "read": + cmd := b.c.FTCursorRead(b.ctx, b.index, int(b.cursorId), b.count) + return cmd.Result() + case "del": + cmd := b.c.FTCursorDel(b.ctx, b.index, int(b.cursorId)) + return cmd.Result() + } + return nil, nil +} + +// ---------------------- +// SynUpdateBuilder for FT.SYNUPDATE +// ---------------------- + +type SynUpdateBuilder struct { + c *Client + ctx context.Context + index string + groupId interface{} + options *FTSynUpdateOptions + terms []interface{} +} + +// SynUpdate starts FT.SYNUPDATE builder. +func (c *Client) SynUpdate(ctx context.Context, index string, groupId interface{}) *SynUpdateBuilder { + return &SynUpdateBuilder{c: c, ctx: ctx, index: index, groupId: groupId, options: &FTSynUpdateOptions{}} +} + +// SkipInitialScan includes SKIPINITIALSCAN. +func (b *SynUpdateBuilder) SkipInitialScan() *SynUpdateBuilder { + b.options.SkipInitialScan = true + return b +} + +// Terms adds synonyms to the group. +func (b *SynUpdateBuilder) Terms(terms ...interface{}) *SynUpdateBuilder { b.terms = terms; return b } + +// Run executes FT.SYNUPDATE. +func (b *SynUpdateBuilder) Run() (string, error) { + cmd := b.c.FTSynUpdateWithArgs(b.ctx, b.index, b.groupId, b.options, b.terms) + return cmd.Result() +} diff --git a/search_builders_test.go b/search_builders_test.go new file mode 100644 index 0000000000..0fedf83a96 --- /dev/null +++ b/search_builders_test.go @@ -0,0 +1,680 @@ +package redis_test + +import ( + "context" + "fmt" + + . "github.com/bsm/ginkgo/v2" + . "github.com/bsm/gomega" + "github.com/redis/go-redis/v9" +) + +var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { + ctx := context.Background() + var client *redis.Client + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{Addr: ":6379", Protocol: 2}) + Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + expectCloseErr := client.Close() + Expect(expectCloseErr).NotTo(HaveOccurred()) + }) + + It("should create index and search with scores using builders", Label("search", "ftcreate", "ftsearch"), func() { + createVal, err := client.CreateIndex(ctx, "idx1"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). 
+ Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + + WaitForIndexing(client, "idx1") + + client.HSet(ctx, "doc1", "foo", "hello world") + client.HSet(ctx, "doc2", "foo", "hello redis") + + res, err := client.Search(ctx, "idx1", "hello").WithScores().Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(Equal(2)) + for _, doc := range res.Docs { + Expect(*doc.Score).To(BeNumerically(">", 0)) + } + }) + + It("should aggregate using builders", Label("search", "ftaggregate"), func() { + _, err := client.CreateIndex(ctx, "idx2"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "n", FieldType: redis.SearchFieldTypeNumeric}). + Run() + Expect(err).NotTo(HaveOccurred()) + WaitForIndexing(client, "idx2") + + client.HSet(ctx, "d1", "n", 1) + client.HSet(ctx, "d2", "n", 2) + + agg, err := client.Aggregate(ctx, "idx2", "*"). + GroupBy("@n"). + ReduceAs(redis.SearchCount, "count"). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(len(agg.Rows)).To(Equal(2)) + }) + + It("should drop index using builder", Label("search", "ftdropindex"), func() { + Expect(client.CreateIndex(ctx, "idx3"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "x", FieldType: redis.SearchFieldTypeText}). + Run()).To(Equal("OK")) + WaitForIndexing(client, "idx3") + + dropVal, err := client.DropIndex(ctx, "idx3").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(dropVal).To(Equal("OK")) + }) + + It("should manage aliases using builder", Label("search", "ftalias"), func() { + Expect(client.CreateIndex(ctx, "idx4"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "t", FieldType: redis.SearchFieldTypeText}). + Run()).To(Equal("OK")) + WaitForIndexing(client, "idx4") + + addVal, err := client.AliasAdd(ctx, "alias1", "idx4").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(addVal).To(Equal("OK")) + + _, err = client.Search(ctx, "alias1", "*").Run() + Expect(err).NotTo(HaveOccurred()) + + delVal, err := client.AliasDel(ctx, "alias1").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(delVal).To(Equal("OK")) + }) + + It("should explain query using ExplainBuilder", Label("search", "builders", "ftexplain"), func() { + createVal, err := client.CreateIndex(ctx, "idx_explain"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_explain") + + expl, err := client.Explain(ctx, "idx_explain", "foo").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(expl).To(ContainSubstring("UNION")) + }) + + It("should retrieve info using SearchInfo builder", Label("search", "builders", "ftinfo"), func() { + createVal, err := client.CreateIndex(ctx, "idx_info"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_info") + + i, err := client.SearchInfo(ctx, "idx_info").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(i.IndexName).To(Equal("idx_info")) + }) + + It("should spellcheck using builder", Label("search", "builders", "ftspellcheck"), func() { + createVal, err := client.CreateIndex(ctx, "idx_spell"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). 
+ Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_spell") + + client.HSet(ctx, "doc1", "foo", "bar") + + _, err = client.SpellCheck(ctx, "idx_spell", "ba").Distance(1).Run() + Expect(err).NotTo(HaveOccurred()) + }) + + It("should manage dictionary using DictBuilder", Label("search", "ftdict"), func() { + addCount, err := client.DictAdd(ctx, "dict1", "a", "b").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(addCount).To(Equal(int64(2))) + + dump, err := client.DictDump(ctx, "dict1").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(dump).To(ContainElements("a", "b")) + + delCount, err := client.DictDel(ctx, "dict1", "a").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(delCount).To(Equal(int64(1))) + }) + + It("should tag values using TagValsBuilder", Label("search", "builders", "fttagvals"), func() { + createVal, err := client.CreateIndex(ctx, "idx_tag"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "tags", FieldType: redis.SearchFieldTypeTag}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_tag") + + client.HSet(ctx, "doc1", "tags", "red,blue") + client.HSet(ctx, "doc2", "tags", "green,blue") + + vals, err := client.TagVals(ctx, "idx_tag", "tags").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(vals).To(BeAssignableToTypeOf([]string{})) + }) + + It("should cursor read and delete using CursorBuilder", Label("search", "builders", "ftcursor"), func() { + Expect(client.CreateIndex(ctx, "idx5"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "f", FieldType: redis.SearchFieldTypeText}). + Run()).To(Equal("OK")) + WaitForIndexing(client, "idx5") + client.HSet(ctx, "doc1", "f", "hello") + client.HSet(ctx, "doc2", "f", "world") + + cursorBuilder := client.CursorRead(ctx, "idx5", 1) + Expect(cursorBuilder).NotTo(BeNil()) + + cursorBuilder = cursorBuilder.Count(10) + Expect(cursorBuilder).NotTo(BeNil()) + + delBuilder := client.CursorDel(ctx, "idx5", 1) + Expect(delBuilder).NotTo(BeNil()) + }) + + It("should update synonyms using SynUpdateBuilder", Label("search", "builders", "ftsynupdate"), func() { + createVal, err := client.CreateIndex(ctx, "idx_syn"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_syn") + + syn, err := client.SynUpdate(ctx, "idx_syn", "grp1").Terms("a", "b").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(syn).To(Equal("OK")) + }) + + It("should test SearchBuilder with NoContent and Verbatim", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_nocontent"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText, Weight: 5}). + Schema(&redis.FieldSchema{FieldName: "body", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_nocontent") + + client.HSet(ctx, "doc1", "title", "RediSearch", "body", "Redisearch implements a search engine on top of redis") + + res, err := client.Search(ctx, "idx_nocontent", "search engine"). + NoContent(). + Verbatim(). + Limit(0, 5). 
+ Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(Equal(1)) + Expect(res.Docs[0].ID).To(Equal("doc1")) + // NoContent means no fields should be returned + Expect(res.Docs[0].Fields).To(BeEmpty()) + }) + + It("should test SearchBuilder with NoStopWords", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_nostop"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_nostop") + + client.HSet(ctx, "doc1", "txt", "hello world") + client.HSet(ctx, "doc2", "txt", "test document") + + // Test that NoStopWords method can be called and search works + res, err := client.Search(ctx, "idx_nostop", "hello").NoContent().NoStopWords().Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(Equal(1)) + }) + + It("should test SearchBuilder with filters", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_filters"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). + Schema(&redis.FieldSchema{FieldName: "num", FieldType: redis.SearchFieldTypeNumeric}). + Schema(&redis.FieldSchema{FieldName: "loc", FieldType: redis.SearchFieldTypeGeo}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_filters") + + client.HSet(ctx, "doc1", "txt", "foo bar", "num", 3.141, "loc", "-0.441,51.458") + client.HSet(ctx, "doc2", "txt", "foo baz", "num", 2, "loc", "-0.1,51.2") + + // Test numeric filter + res1, err := client.Search(ctx, "idx_filters", "foo"). + Filter("num", 2, 4). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(2)) + + // Test geo filter + res2, err := client.Search(ctx, "idx_filters", "foo"). + GeoFilter("loc", -0.44, 51.45, 10, "km"). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(1)) + }) + + It("should test SearchBuilder with sorting", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_sort"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). + Schema(&redis.FieldSchema{FieldName: "num", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_sort") + + client.HSet(ctx, "doc1", "txt", "foo bar", "num", 1) + client.HSet(ctx, "doc2", "txt", "foo baz", "num", 2) + client.HSet(ctx, "doc3", "txt", "foo qux", "num", 3) + + // Test ascending sort + res1, err := client.Search(ctx, "idx_sort", "foo"). + SortBy("num", true). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(3)) + Expect(res1.Docs[0].ID).To(Equal("doc1")) + Expect(res1.Docs[1].ID).To(Equal("doc2")) + Expect(res1.Docs[2].ID).To(Equal("doc3")) + + // Test descending sort + res2, err := client.Search(ctx, "idx_sort", "foo"). + SortBy("num", false). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(3)) + Expect(res2.Docs[0].ID).To(Equal("doc3")) + Expect(res2.Docs[1].ID).To(Equal("doc2")) + Expect(res2.Docs[2].ID).To(Equal("doc1")) + }) + + It("should test SearchBuilder with InKeys and InFields", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_in"). + OnHash(). 
+ Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText}). + Schema(&redis.FieldSchema{FieldName: "body", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_in") + + client.HSet(ctx, "doc1", "title", "hello world", "body", "lorem ipsum") + client.HSet(ctx, "doc2", "title", "foo bar", "body", "hello world") + client.HSet(ctx, "doc3", "title", "baz qux", "body", "dolor sit") + + // Test InKeys + res1, err := client.Search(ctx, "idx_in", "hello"). + InKeys("doc1", "doc2"). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(2)) + + // Test InFields + res2, err := client.Search(ctx, "idx_in", "hello"). + InFields("title"). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(1)) + Expect(res2.Docs[0].ID).To(Equal("doc1")) + }) + + It("should test SearchBuilder with Return fields", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_return"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText}). + Schema(&redis.FieldSchema{FieldName: "body", FieldType: redis.SearchFieldTypeText}). + Schema(&redis.FieldSchema{FieldName: "num", FieldType: redis.SearchFieldTypeNumeric}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_return") + + client.HSet(ctx, "doc1", "title", "hello", "body", "world", "num", 42) + + // Test ReturnFields + res1, err := client.Search(ctx, "idx_return", "hello"). + ReturnFields("title", "num"). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(1)) + Expect(res1.Docs[0].Fields).To(HaveKey("title")) + Expect(res1.Docs[0].Fields).To(HaveKey("num")) + Expect(res1.Docs[0].Fields).NotTo(HaveKey("body")) + + // Test ReturnAs + res2, err := client.Search(ctx, "idx_return", "hello"). + ReturnAs("title", "doc_title"). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(1)) + Expect(res2.Docs[0].Fields).To(HaveKey("doc_title")) + Expect(res2.Docs[0].Fields).NotTo(HaveKey("title")) + }) + + It("should test SearchBuilder with advanced options", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_advanced"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "description", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_advanced") + + client.HSet(ctx, "doc1", "description", "The quick brown fox jumps over the lazy dog") + client.HSet(ctx, "doc2", "description", "Quick alice was beginning to get very tired of sitting by her quick sister on the bank") + + // Test with scores and different scorers + res1, err := client.Search(ctx, "idx_advanced", "quick"). + WithScores(). + Scorer("TFIDF"). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(2)) + for _, doc := range res1.Docs { + Expect(*doc.Score).To(BeNumerically(">", 0)) + } + + res2, err := client.Search(ctx, "idx_advanced", "quick"). + WithScores(). + Payload("test_payload"). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(2)) + + // Test with Slop and InOrder + res3, err := client.Search(ctx, "idx_advanced", "quick brown"). + Slop(1). + InOrder(). + NoContent(). 
+ Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res3.Total).To(Equal(1)) + + // Test with Language and Expander + res4, err := client.Search(ctx, "idx_advanced", "quick"). + Language("english"). + Expander("SYNONYM"). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res4.Total).To(BeNumerically(">=", 0)) + + // Test with Timeout + res5, err := client.Search(ctx, "idx_advanced", "quick"). + Timeout(1000). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res5.Total).To(Equal(2)) + }) + + It("should test SearchBuilder with Params and Dialect", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_params"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "name", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_params") + + client.HSet(ctx, "doc1", "name", "Alice") + client.HSet(ctx, "doc2", "name", "Bob") + client.HSet(ctx, "doc3", "name", "Carol") + + // Test with single param + res1, err := client.Search(ctx, "idx_params", "@name:$name"). + Param("name", "Alice"). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(1)) + Expect(res1.Docs[0].ID).To(Equal("doc1")) + + // Test with multiple params using ParamsMap + params := map[string]interface{}{ + "name1": "Bob", + "name2": "Carol", + } + res2, err := client.Search(ctx, "idx_params", "@name:($name1|$name2)"). + ParamsMap(params). + Dialect(2). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(2)) + }) + + It("should test SearchBuilder with Limit and CountOnly", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_limit"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_limit") + + for i := 1; i <= 10; i++ { + client.HSet(ctx, fmt.Sprintf("doc%d", i), "txt", "test document") + } + + // Test with Limit + res1, err := client.Search(ctx, "idx_limit", "test"). + Limit(2, 3). + NoContent(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(10)) + Expect(len(res1.Docs)).To(Equal(3)) + + // Test with CountOnly + res2, err := client.Search(ctx, "idx_limit", "test"). + CountOnly(). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(10)) + Expect(len(res2.Docs)).To(Equal(0)) + }) + + It("should test SearchBuilder with WithSortByCount and SortBy", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_payloads"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). + Schema(&redis.FieldSchema{FieldName: "num", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_payloads") + + client.HSet(ctx, "doc1", "txt", "hello", "num", 1) + client.HSet(ctx, "doc2", "txt", "world", "num", 2) + + // Test WithSortByCount and SortBy + res, err := client.Search(ctx, "idx_payloads", "*"). + SortBy("num", true). + WithSortByCount(). + NoContent(). 
+ Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(Equal(2)) + }) + + It("should test SearchBuilder with JSON", Label("search", "ftsearch", "builders", "json"), func() { + createVal, err := client.CreateIndex(ctx, "idx_json"). + OnJSON(). + Prefix("king:"). + Schema(&redis.FieldSchema{FieldName: "$.name", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_json") + + client.JSONSet(ctx, "king:1", "$", `{"name": "henry"}`) + client.JSONSet(ctx, "king:2", "$", `{"name": "james"}`) + + res, err := client.Search(ctx, "idx_json", "henry").Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(Equal(1)) + Expect(res.Docs[0].ID).To(Equal("king:1")) + Expect(res.Docs[0].Fields["$"]).To(Equal(`{"name":"henry"}`)) + }) + + It("should test SearchBuilder with vector search", Label("search", "ftsearch", "builders", "vector"), func() { + hnswOptions := &redis.FTHNSWOptions{Type: "FLOAT32", Dim: 2, DistanceMetric: "L2"} + createVal, err := client.CreateIndex(ctx, "idx_vector"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "v", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{HNSWOptions: hnswOptions}}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_vector") + + client.HSet(ctx, "a", "v", "aaaaaaaa") + client.HSet(ctx, "b", "v", "aaaabaaa") + client.HSet(ctx, "c", "v", "aaaaabaa") + + res, err := client.Search(ctx, "idx_vector", "*=>[KNN 2 @v $vec]"). + ReturnFields("__v_score"). + SortBy("__v_score", true). + Dialect(2). + Param("vec", "aaaaaaaa"). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Docs[0].ID).To(Equal("a")) + Expect(res.Docs[0].Fields["__v_score"]).To(Equal("0")) + }) + + It("should test SearchBuilder with complex filtering and aggregation", Label("search", "ftsearch", "builders"), func() { + createVal, err := client.CreateIndex(ctx, "idx_complex"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "category", FieldType: redis.SearchFieldTypeTag}). + Schema(&redis.FieldSchema{FieldName: "price", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}). + Schema(&redis.FieldSchema{FieldName: "location", FieldType: redis.SearchFieldTypeGeo}). + Schema(&redis.FieldSchema{FieldName: "description", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_complex") + + client.HSet(ctx, "product1", "category", "electronics", "price", 100, "location", "-0.1,51.5", "description", "smartphone device") + client.HSet(ctx, "product2", "category", "electronics", "price", 200, "location", "-0.2,51.6", "description", "laptop computer") + client.HSet(ctx, "product3", "category", "books", "price", 20, "location", "-0.3,51.7", "description", "programming guide") + + res, err := client.Search(ctx, "idx_complex", "@category:{electronics} @description:(device|computer)"). + Filter("price", 50, 250). + GeoFilter("location", -0.15, 51.55, 50, "km"). + SortBy("price", true). + ReturnFields("category", "price", "description"). + Limit(0, 10). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(BeNumerically(">=", 1)) + + res2, err := client.Search(ctx, "idx_complex", "@category:{$cat} @price:[$min $max]"). + ParamsMap(map[string]interface{}{ + "cat": "electronics", + "min": 150, + "max": 300, + }). + Dialect(2). + WithScores(). 
+ Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(1)) + Expect(res2.Docs[0].ID).To(Equal("product2")) + }) + + It("should test SearchBuilder error handling and edge cases", Label("search", "ftsearch", "builders", "edge-cases"), func() { + createVal, err := client.CreateIndex(ctx, "idx_edge"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_edge") + + client.HSet(ctx, "doc1", "txt", "hello world") + + // Test empty query + res1, err := client.Search(ctx, "idx_edge", "*").NoContent().Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res1.Total).To(Equal(1)) + + // Test query with no results + res2, err := client.Search(ctx, "idx_edge", "nonexistent").NoContent().Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res2.Total).To(Equal(0)) + + // Test with multiple chained methods + res3, err := client.Search(ctx, "idx_edge", "hello"). + WithScores(). + NoContent(). + Verbatim(). + InOrder(). + Slop(0). + Timeout(5000). + Language("english"). + Dialect(2). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res3.Total).To(Equal(1)) + }) + + It("should test SearchBuilder method chaining", Label("search", "ftsearch", "builders", "fluent"), func() { + createVal, err := client.CreateIndex(ctx, "idx_fluent"). + OnHash(). + Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText}). + Schema(&redis.FieldSchema{FieldName: "tags", FieldType: redis.SearchFieldTypeTag}). + Schema(&redis.FieldSchema{FieldName: "score", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}). + Run() + Expect(err).NotTo(HaveOccurred()) + Expect(createVal).To(Equal("OK")) + WaitForIndexing(client, "idx_fluent") + + client.HSet(ctx, "doc1", "title", "Redis Search Tutorial", "tags", "redis,search,tutorial", "score", 95) + client.HSet(ctx, "doc2", "title", "Advanced Redis", "tags", "redis,advanced", "score", 88) + + builder := client.Search(ctx, "idx_fluent", "@title:(redis) @tags:{search}") + result := builder. + WithScores(). + Filter("score", 90, 100). + SortBy("score", false). + ReturnFields("title", "score"). + Limit(0, 5). + Dialect(2). + Timeout(1000). + Language("english") + + res, err := result.Run() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(Equal(1)) + Expect(res.Docs[0].ID).To(Equal("doc1")) + Expect(res.Docs[0].Fields["title"]).To(Equal("Redis Search Tutorial")) + Expect(*res.Docs[0].Score).To(BeNumerically(">", 0)) + }) +}) From 3dc30512d274062ee11bc928475bd6a52c09f5e1 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 3 Aug 2025 17:55:57 +0300 Subject: [PATCH 02/62] Use builders and Actions in more clean way --- search_builders.go | 129 +++++++++++++++++++++++++------------ search_builders_test.go | 138 ++++++++++++++++++++-------------------- 2 files changed, 156 insertions(+), 111 deletions(-) diff --git a/search_builders.go b/search_builders.go index 964b26878d..a35d6e9281 100644 --- a/search_builders.go +++ b/search_builders.go @@ -18,8 +18,8 @@ type SearchBuilder struct { options *FTSearchOptions } -// Search starts building an FT.SEARCH command. -func (c *Client) Search(ctx context.Context, index, query string) *SearchBuilder { +// NewSearchBuilder creates a new SearchBuilder for FT.SEARCH commands. 
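+// Illustrative usage of the renamed entry point (editorial sketch, not part
+// of the original commit):
+//
+//	res, err := client.NewSearchBuilder(ctx, "idx", "hello").
+//		WithScores().
+//		Limit(0, 10).
+//		Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	_ = res.Total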
+func (c *Client) NewSearchBuilder(ctx context.Context, index, query string) *SearchBuilder { b := &SearchBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSearchOptions{LimitOffset: -1}} return b } @@ -215,8 +215,8 @@ type AggregateBuilder struct { options *FTAggregateOptions } -// Aggregate starts building an FT.AGGREGATE command. -func (c *Client) Aggregate(ctx context.Context, index, query string) *AggregateBuilder { +// NewAggregateBuilder creates a new AggregateBuilder for FT.AGGREGATE commands. +func (c *Client) NewAggregateBuilder(ctx context.Context, index, query string) *AggregateBuilder { return &AggregateBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTAggregateOptions{LimitOffset: -1}} } @@ -367,8 +367,8 @@ type CreateIndexBuilder struct { schema []*FieldSchema } -// CreateIndex starts building an FT.CREATE command. -func (c *Client) CreateIndex(ctx context.Context, index string) *CreateIndexBuilder { +// NewCreateIndexBuilder creates a new CreateIndexBuilder for FT.CREATE commands. +func (c *Client) NewCreateIndexBuilder(ctx context.Context, index string) *CreateIndexBuilder { return &CreateIndexBuilder{c: c, ctx: ctx, index: index, options: &FTCreateOptions{}} } @@ -473,8 +473,8 @@ type DropIndexBuilder struct { options *FTDropIndexOptions } -// DropIndex starts FT.DROPINDEX builder. -func (c *Client) DropIndex(ctx context.Context, index string) *DropIndexBuilder { +// NewDropIndexBuilder creates a new DropIndexBuilder for FT.DROPINDEX commands. +func (c *Client) NewDropIndexBuilder(ctx context.Context, index string) *DropIndexBuilder { return &DropIndexBuilder{c: c, ctx: ctx, index: index} } @@ -499,19 +499,35 @@ type AliasBuilder struct { action string // add|del|update } -// AliasAdd starts FT.ALIASADD builder. -func (c *Client) AliasAdd(ctx context.Context, alias, index string) *AliasBuilder { - return &AliasBuilder{c: c, ctx: ctx, alias: alias, index: index, action: "add"} +// NewAliasBuilder creates a new AliasBuilder for FT.ALIAS* commands. +func (c *Client) NewAliasBuilder(ctx context.Context, alias string) *AliasBuilder { + return &AliasBuilder{c: c, ctx: ctx, alias: alias} } -// AliasDel starts FT.ALIASDEL builder. -func (c *Client) AliasDel(ctx context.Context, alias string) *AliasBuilder { - return &AliasBuilder{c: c, ctx: ctx, alias: alias, action: "del"} +// Action sets the action for the alias builder. +func (b *AliasBuilder) Action(action string) *AliasBuilder { + b.action = action + return b +} + +// Add sets the action to "add" and requires an index. +func (b *AliasBuilder) Add(index string) *AliasBuilder { + b.action = "add" + b.index = index + return b +} + +// Del sets the action to "del". +func (b *AliasBuilder) Del() *AliasBuilder { + b.action = "del" + return b } -// AliasUpdate starts FT.ALIASUPDATE builder. -func (c *Client) AliasUpdate(ctx context.Context, alias, index string) *AliasBuilder { - return &AliasBuilder{c: c, ctx: ctx, alias: alias, index: index, action: "update"} +// Update sets the action to "update" and requires an index. +func (b *AliasBuilder) Update(index string) *AliasBuilder { + b.action = "update" + b.index = index + return b } // Run executes the configured alias command. @@ -542,8 +558,8 @@ type ExplainBuilder struct { options *FTExplainOptions } -// Explain starts FT.EXPLAIN builder. -func (c *Client) Explain(ctx context.Context, index, query string) *ExplainBuilder { +// NewExplainBuilder creates a new ExplainBuilder for FT.EXPLAIN commands. 
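+// Illustrative usage (editorial sketch, not part of the original commit):
+//
+//	plan, err := client.NewExplainBuilder(ctx, "idx", "foo").Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	// plan is the textual query execution plan, e.g. a tree containing
+//	// a UNION node for stemmed terms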
+func (c *Client) NewExplainBuilder(ctx context.Context, index, query string) *ExplainBuilder { return &ExplainBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTExplainOptions{}} } @@ -566,8 +582,8 @@ type FTInfoBuilder struct { index string } -// SearchInfo starts building an FT.INFO command for RediSearch. -func (c *Client) SearchInfo(ctx context.Context, index string) *FTInfoBuilder { +// NewSearchInfoBuilder creates a new FTInfoBuilder for FT.INFO commands. +func (c *Client) NewSearchInfoBuilder(ctx context.Context, index string) *FTInfoBuilder { return &FTInfoBuilder{c: c, ctx: ctx, index: index} } @@ -589,8 +605,8 @@ type SpellCheckBuilder struct { options *FTSpellCheckOptions } -// SpellCheck starts FT.SPELLCHECK builder. -func (c *Client) SpellCheck(ctx context.Context, index, query string) *SpellCheckBuilder { +// NewSpellCheckBuilder creates a new SpellCheckBuilder for FT.SPELLCHECK commands. +func (c *Client) NewSpellCheckBuilder(ctx context.Context, index, query string) *SpellCheckBuilder { return &SpellCheckBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSpellCheckOptions{}} } @@ -633,19 +649,35 @@ type DictBuilder struct { action string // add|del|dump } -// DictAdd starts FT.DICTADD builder. -func (c *Client) DictAdd(ctx context.Context, dict string, terms ...interface{}) *DictBuilder { - return &DictBuilder{c: c, ctx: ctx, dict: dict, terms: terms, action: "add"} +// NewDictBuilder creates a new DictBuilder for FT.DICT* commands. +func (c *Client) NewDictBuilder(ctx context.Context, dict string) *DictBuilder { + return &DictBuilder{c: c, ctx: ctx, dict: dict} } -// DictDel starts FT.DICTDEL builder. -func (c *Client) DictDel(ctx context.Context, dict string, terms ...interface{}) *DictBuilder { - return &DictBuilder{c: c, ctx: ctx, dict: dict, terms: terms, action: "del"} +// Action sets the action for the dictionary builder. +func (b *DictBuilder) Action(action string) *DictBuilder { + b.action = action + return b } -// DictDump starts FT.DICTDUMP builder. -func (c *Client) DictDump(ctx context.Context, dict string) *DictBuilder { - return &DictBuilder{c: c, ctx: ctx, dict: dict, action: "dump"} +// Add sets the action to "add" and requires terms. +func (b *DictBuilder) Add(terms ...interface{}) *DictBuilder { + b.action = "add" + b.terms = terms + return b +} + +// Del sets the action to "del" and requires terms. +func (b *DictBuilder) Del(terms ...interface{}) *DictBuilder { + b.action = "del" + b.terms = terms + return b +} + +// Dump sets the action to "dump". +func (b *DictBuilder) Dump() *DictBuilder { + b.action = "dump" + return b } // Run executes the configured dictionary command. @@ -675,8 +707,8 @@ type TagValsBuilder struct { field string } -// TagVals starts FT.TAGVALS builder. -func (c *Client) TagVals(ctx context.Context, index, field string) *TagValsBuilder { +// NewTagValsBuilder creates a new TagValsBuilder for FT.TAGVALS commands. +func (c *Client) NewTagValsBuilder(ctx context.Context, index, field string) *TagValsBuilder { return &TagValsBuilder{c: c, ctx: ctx, index: index, field: field} } @@ -699,14 +731,27 @@ type CursorBuilder struct { action string // read|del } -// CursorRead starts FT.CURSOR READ builder. -func (c *Client) CursorRead(ctx context.Context, index string, cursorId int64) *CursorBuilder { - return &CursorBuilder{c: c, ctx: ctx, index: index, cursorId: cursorId, action: "read"} +// NewCursorBuilder creates a new CursorBuilder for FT.CURSOR* commands. 
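+// Illustrative flow (editorial sketch, not part of the original commit; the
+// cursor id comes from a prior FT.AGGREGATE ... WITHCURSOR reply):
+//
+//	raw, err := client.NewCursorBuilder(ctx, "idx", cursorID).
+//		Read().
+//		Count(50).
+//		Run()
+//	if err != nil {
+//		// handle error
+//	}
+//	_ = raw // reply shape matches FTCursorRead
+//
+//	// When finished with the cursor:
+//	_, _ = client.NewCursorBuilder(ctx, "idx", cursorID).Del().Run()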
+func (c *Client) NewCursorBuilder(ctx context.Context, index string, cursorId int64) *CursorBuilder { + return &CursorBuilder{c: c, ctx: ctx, index: index, cursorId: cursorId} } -// CursorDel starts FT.CURSOR DEL builder. -func (c *Client) CursorDel(ctx context.Context, index string, cursorId int64) *CursorBuilder { - return &CursorBuilder{c: c, ctx: ctx, index: index, cursorId: cursorId, action: "del"} +// Action sets the action for the cursor builder. +func (b *CursorBuilder) Action(action string) *CursorBuilder { + b.action = action + return b +} + +// Read sets the action to "read". +func (b *CursorBuilder) Read() *CursorBuilder { + b.action = "read" + return b +} + +// Del sets the action to "del". +func (b *CursorBuilder) Del() *CursorBuilder { + b.action = "del" + return b } // Count for READ. @@ -738,8 +783,8 @@ type SynUpdateBuilder struct { terms []interface{} } -// SynUpdate starts FT.SYNUPDATE builder. -func (c *Client) SynUpdate(ctx context.Context, index string, groupId interface{}) *SynUpdateBuilder { +// NewSynUpdateBuilder creates a new SynUpdateBuilder for FT.SYNUPDATE commands. +func (c *Client) NewSynUpdateBuilder(ctx context.Context, index string, groupId interface{}) *SynUpdateBuilder { return &SynUpdateBuilder{c: c, ctx: ctx, index: index, groupId: groupId, options: &FTSynUpdateOptions{}} } diff --git a/search_builders_test.go b/search_builders_test.go index 0fedf83a96..bd8b6ff7c4 100644 --- a/search_builders_test.go +++ b/search_builders_test.go @@ -24,7 +24,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should create index and search with scores using builders", Label("search", "ftcreate", "ftsearch"), func() { - createVal, err := client.CreateIndex(ctx, "idx1"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx1"). OnHash(). Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). Run() @@ -36,7 +36,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "foo", "hello world") client.HSet(ctx, "doc2", "foo", "hello redis") - res, err := client.Search(ctx, "idx1", "hello").WithScores().Run() + res, err := client.NewSearchBuilder(ctx, "idx1", "hello").WithScores().Run() Expect(err).NotTo(HaveOccurred()) Expect(res.Total).To(Equal(2)) for _, doc := range res.Docs { @@ -45,7 +45,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should aggregate using builders", Label("search", "ftaggregate"), func() { - _, err := client.CreateIndex(ctx, "idx2"). + _, err := client.NewCreateIndexBuilder(ctx, "idx2"). OnHash(). Schema(&redis.FieldSchema{FieldName: "n", FieldType: redis.SearchFieldTypeNumeric}). Run() @@ -55,7 +55,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "d1", "n", 1) client.HSet(ctx, "d2", "n", 2) - agg, err := client.Aggregate(ctx, "idx2", "*"). + agg, err := client.NewAggregateBuilder(ctx, "idx2", "*"). GroupBy("@n"). ReduceAs(redis.SearchCount, "count"). Run() @@ -64,38 +64,38 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should drop index using builder", Label("search", "ftdropindex"), func() { - Expect(client.CreateIndex(ctx, "idx3"). + Expect(client.NewCreateIndexBuilder(ctx, "idx3"). OnHash(). Schema(&redis.FieldSchema{FieldName: "x", FieldType: redis.SearchFieldTypeText}). 
Run()).To(Equal("OK")) WaitForIndexing(client, "idx3") - dropVal, err := client.DropIndex(ctx, "idx3").Run() + dropVal, err := client.NewDropIndexBuilder(ctx, "idx3").Run() Expect(err).NotTo(HaveOccurred()) Expect(dropVal).To(Equal("OK")) }) It("should manage aliases using builder", Label("search", "ftalias"), func() { - Expect(client.CreateIndex(ctx, "idx4"). + Expect(client.NewCreateIndexBuilder(ctx, "idx4"). OnHash(). Schema(&redis.FieldSchema{FieldName: "t", FieldType: redis.SearchFieldTypeText}). Run()).To(Equal("OK")) WaitForIndexing(client, "idx4") - addVal, err := client.AliasAdd(ctx, "alias1", "idx4").Run() + addVal, err := client.NewAliasBuilder(ctx, "alias1").Add("idx4").Run() Expect(err).NotTo(HaveOccurred()) Expect(addVal).To(Equal("OK")) - _, err = client.Search(ctx, "alias1", "*").Run() + _, err = client.NewSearchBuilder(ctx, "alias1", "*").Run() Expect(err).NotTo(HaveOccurred()) - delVal, err := client.AliasDel(ctx, "alias1").Run() + delVal, err := client.NewAliasBuilder(ctx, "alias1").Del().Run() Expect(err).NotTo(HaveOccurred()) Expect(delVal).To(Equal("OK")) }) It("should explain query using ExplainBuilder", Label("search", "builders", "ftexplain"), func() { - createVal, err := client.CreateIndex(ctx, "idx_explain"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_explain"). OnHash(). Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). Run() @@ -103,13 +103,13 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(createVal).To(Equal("OK")) WaitForIndexing(client, "idx_explain") - expl, err := client.Explain(ctx, "idx_explain", "foo").Run() + expl, err := client.NewExplainBuilder(ctx, "idx_explain", "foo").Run() Expect(err).NotTo(HaveOccurred()) Expect(expl).To(ContainSubstring("UNION")) }) It("should retrieve info using SearchInfo builder", Label("search", "builders", "ftinfo"), func() { - createVal, err := client.CreateIndex(ctx, "idx_info"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_info"). OnHash(). Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). Run() @@ -117,13 +117,13 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(createVal).To(Equal("OK")) WaitForIndexing(client, "idx_info") - i, err := client.SearchInfo(ctx, "idx_info").Run() + i, err := client.NewSearchInfoBuilder(ctx, "idx_info").Run() Expect(err).NotTo(HaveOccurred()) Expect(i.IndexName).To(Equal("idx_info")) }) It("should spellcheck using builder", Label("search", "builders", "ftspellcheck"), func() { - createVal, err := client.CreateIndex(ctx, "idx_spell"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_spell"). OnHash(). Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). 
Run() @@ -133,26 +133,26 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "foo", "bar") - _, err = client.SpellCheck(ctx, "idx_spell", "ba").Distance(1).Run() + _, err = client.NewSpellCheckBuilder(ctx, "idx_spell", "ba").Distance(1).Run() Expect(err).NotTo(HaveOccurred()) }) It("should manage dictionary using DictBuilder", Label("search", "ftdict"), func() { - addCount, err := client.DictAdd(ctx, "dict1", "a", "b").Run() + addCount, err := client.NewDictBuilder(ctx, "dict1").Add("a", "b").Run() Expect(err).NotTo(HaveOccurred()) Expect(addCount).To(Equal(int64(2))) - dump, err := client.DictDump(ctx, "dict1").Run() + dump, err := client.NewDictBuilder(ctx, "dict1").Dump().Run() Expect(err).NotTo(HaveOccurred()) Expect(dump).To(ContainElements("a", "b")) - delCount, err := client.DictDel(ctx, "dict1", "a").Run() + delCount, err := client.NewDictBuilder(ctx, "dict1").Del("a").Run() Expect(err).NotTo(HaveOccurred()) Expect(delCount).To(Equal(int64(1))) }) It("should tag values using TagValsBuilder", Label("search", "builders", "fttagvals"), func() { - createVal, err := client.CreateIndex(ctx, "idx_tag"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_tag"). OnHash(). Schema(&redis.FieldSchema{FieldName: "tags", FieldType: redis.SearchFieldTypeTag}). Run() @@ -163,13 +163,13 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "tags", "red,blue") client.HSet(ctx, "doc2", "tags", "green,blue") - vals, err := client.TagVals(ctx, "idx_tag", "tags").Run() + vals, err := client.NewTagValsBuilder(ctx, "idx_tag", "tags").Run() Expect(err).NotTo(HaveOccurred()) Expect(vals).To(BeAssignableToTypeOf([]string{})) }) It("should cursor read and delete using CursorBuilder", Label("search", "builders", "ftcursor"), func() { - Expect(client.CreateIndex(ctx, "idx5"). + Expect(client.NewCreateIndexBuilder(ctx, "idx5"). OnHash(). Schema(&redis.FieldSchema{FieldName: "f", FieldType: redis.SearchFieldTypeText}). Run()).To(Equal("OK")) @@ -177,18 +177,18 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "f", "hello") client.HSet(ctx, "doc2", "f", "world") - cursorBuilder := client.CursorRead(ctx, "idx5", 1) + cursorBuilder := client.NewCursorBuilder(ctx, "idx5", 1) Expect(cursorBuilder).NotTo(BeNil()) cursorBuilder = cursorBuilder.Count(10) Expect(cursorBuilder).NotTo(BeNil()) - delBuilder := client.CursorDel(ctx, "idx5", 1) + delBuilder := client.NewCursorBuilder(ctx, "idx5", 1) Expect(delBuilder).NotTo(BeNil()) }) It("should update synonyms using SynUpdateBuilder", Label("search", "builders", "ftsynupdate"), func() { - createVal, err := client.CreateIndex(ctx, "idx_syn"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_syn"). OnHash(). Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}). Run() @@ -196,13 +196,13 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(createVal).To(Equal("OK")) WaitForIndexing(client, "idx_syn") - syn, err := client.SynUpdate(ctx, "idx_syn", "grp1").Terms("a", "b").Run() + syn, err := client.NewSynUpdateBuilder(ctx, "idx_syn", "grp1").Terms("a", "b").Run() Expect(err).NotTo(HaveOccurred()) Expect(syn).To(Equal("OK")) }) It("should test SearchBuilder with NoContent and Verbatim", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_nocontent"). 
+ createVal, err := client.NewCreateIndexBuilder(ctx, "idx_nocontent"). OnHash(). Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText, Weight: 5}). Schema(&redis.FieldSchema{FieldName: "body", FieldType: redis.SearchFieldTypeText}). @@ -213,7 +213,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "title", "RediSearch", "body", "Redisearch implements a search engine on top of redis") - res, err := client.Search(ctx, "idx_nocontent", "search engine"). + res, err := client.NewSearchBuilder(ctx, "idx_nocontent", "search engine"). NoContent(). Verbatim(). Limit(0, 5). @@ -226,7 +226,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with NoStopWords", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_nostop"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_nostop"). OnHash(). Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). Run() @@ -238,13 +238,13 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc2", "txt", "test document") // Test that NoStopWords method can be called and search works - res, err := client.Search(ctx, "idx_nostop", "hello").NoContent().NoStopWords().Run() + res, err := client.NewSearchBuilder(ctx, "idx_nostop", "hello").NoContent().NoStopWords().Run() Expect(err).NotTo(HaveOccurred()) Expect(res.Total).To(Equal(1)) }) It("should test SearchBuilder with filters", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_filters"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_filters"). OnHash(). Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). Schema(&redis.FieldSchema{FieldName: "num", FieldType: redis.SearchFieldTypeNumeric}). @@ -258,7 +258,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc2", "txt", "foo baz", "num", 2, "loc", "-0.1,51.2") // Test numeric filter - res1, err := client.Search(ctx, "idx_filters", "foo"). + res1, err := client.NewSearchBuilder(ctx, "idx_filters", "foo"). Filter("num", 2, 4). NoContent(). Run() @@ -266,7 +266,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(res1.Total).To(Equal(2)) // Test geo filter - res2, err := client.Search(ctx, "idx_filters", "foo"). + res2, err := client.NewSearchBuilder(ctx, "idx_filters", "foo"). GeoFilter("loc", -0.44, 51.45, 10, "km"). NoContent(). Run() @@ -275,7 +275,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with sorting", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_sort"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_sort"). OnHash(). Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). Schema(&redis.FieldSchema{FieldName: "num", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}). @@ -289,7 +289,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc3", "txt", "foo qux", "num", 3) // Test ascending sort - res1, err := client.Search(ctx, "idx_sort", "foo"). + res1, err := client.NewSearchBuilder(ctx, "idx_sort", "foo"). SortBy("num", true). NoContent(). 
Run() @@ -300,7 +300,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(res1.Docs[2].ID).To(Equal("doc3")) // Test descending sort - res2, err := client.Search(ctx, "idx_sort", "foo"). + res2, err := client.NewSearchBuilder(ctx, "idx_sort", "foo"). SortBy("num", false). NoContent(). Run() @@ -312,7 +312,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with InKeys and InFields", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_in"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_in"). OnHash(). Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText}). Schema(&redis.FieldSchema{FieldName: "body", FieldType: redis.SearchFieldTypeText}). @@ -326,7 +326,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc3", "title", "baz qux", "body", "dolor sit") // Test InKeys - res1, err := client.Search(ctx, "idx_in", "hello"). + res1, err := client.NewSearchBuilder(ctx, "idx_in", "hello"). InKeys("doc1", "doc2"). NoContent(). Run() @@ -334,7 +334,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(res1.Total).To(Equal(2)) // Test InFields - res2, err := client.Search(ctx, "idx_in", "hello"). + res2, err := client.NewSearchBuilder(ctx, "idx_in", "hello"). InFields("title"). NoContent(). Run() @@ -344,7 +344,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with Return fields", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_return"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_return"). OnHash(). Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText}). Schema(&redis.FieldSchema{FieldName: "body", FieldType: redis.SearchFieldTypeText}). @@ -357,7 +357,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "title", "hello", "body", "world", "num", 42) // Test ReturnFields - res1, err := client.Search(ctx, "idx_return", "hello"). + res1, err := client.NewSearchBuilder(ctx, "idx_return", "hello"). ReturnFields("title", "num"). Run() Expect(err).NotTo(HaveOccurred()) @@ -367,7 +367,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(res1.Docs[0].Fields).NotTo(HaveKey("body")) // Test ReturnAs - res2, err := client.Search(ctx, "idx_return", "hello"). + res2, err := client.NewSearchBuilder(ctx, "idx_return", "hello"). ReturnAs("title", "doc_title"). Run() Expect(err).NotTo(HaveOccurred()) @@ -377,7 +377,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with advanced options", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_advanced"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_advanced"). OnHash(). Schema(&redis.FieldSchema{FieldName: "description", FieldType: redis.SearchFieldTypeText}). Run() @@ -389,7 +389,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc2", "description", "Quick alice was beginning to get very tired of sitting by her quick sister on the bank") // Test with scores and different scorers - res1, err := client.Search(ctx, "idx_advanced", "quick"). 
+ res1, err := client.NewSearchBuilder(ctx, "idx_advanced", "quick"). WithScores(). Scorer("TFIDF"). Run() @@ -399,7 +399,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(*doc.Score).To(BeNumerically(">", 0)) } - res2, err := client.Search(ctx, "idx_advanced", "quick"). + res2, err := client.NewSearchBuilder(ctx, "idx_advanced", "quick"). WithScores(). Payload("test_payload"). NoContent(). @@ -408,7 +408,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(res2.Total).To(Equal(2)) // Test with Slop and InOrder - res3, err := client.Search(ctx, "idx_advanced", "quick brown"). + res3, err := client.NewSearchBuilder(ctx, "idx_advanced", "quick brown"). Slop(1). InOrder(). NoContent(). @@ -417,7 +417,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(res3.Total).To(Equal(1)) // Test with Language and Expander - res4, err := client.Search(ctx, "idx_advanced", "quick"). + res4, err := client.NewSearchBuilder(ctx, "idx_advanced", "quick"). Language("english"). Expander("SYNONYM"). NoContent(). @@ -426,7 +426,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(res4.Total).To(BeNumerically(">=", 0)) // Test with Timeout - res5, err := client.Search(ctx, "idx_advanced", "quick"). + res5, err := client.NewSearchBuilder(ctx, "idx_advanced", "quick"). Timeout(1000). NoContent(). Run() @@ -435,7 +435,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with Params and Dialect", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_params"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_params"). OnHash(). Schema(&redis.FieldSchema{FieldName: "name", FieldType: redis.SearchFieldTypeText}). Run() @@ -448,7 +448,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc3", "name", "Carol") // Test with single param - res1, err := client.Search(ctx, "idx_params", "@name:$name"). + res1, err := client.NewSearchBuilder(ctx, "idx_params", "@name:$name"). Param("name", "Alice"). NoContent(). Run() @@ -461,7 +461,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { "name1": "Bob", "name2": "Carol", } - res2, err := client.Search(ctx, "idx_params", "@name:($name1|$name2)"). + res2, err := client.NewSearchBuilder(ctx, "idx_params", "@name:($name1|$name2)"). ParamsMap(params). Dialect(2). NoContent(). @@ -471,7 +471,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with Limit and CountOnly", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_limit"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_limit"). OnHash(). Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). Run() @@ -484,7 +484,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { } // Test with Limit - res1, err := client.Search(ctx, "idx_limit", "test"). + res1, err := client.NewSearchBuilder(ctx, "idx_limit", "test"). Limit(2, 3). NoContent(). Run() @@ -493,7 +493,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(len(res1.Docs)).To(Equal(3)) // Test with CountOnly - res2, err := client.Search(ctx, "idx_limit", "test"). + res2, err := client.NewSearchBuilder(ctx, "idx_limit", "test"). 
CountOnly(). Run() Expect(err).NotTo(HaveOccurred()) @@ -502,7 +502,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with WithSortByCount and SortBy", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_payloads"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_payloads"). OnHash(). Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). Schema(&redis.FieldSchema{FieldName: "num", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}). @@ -515,7 +515,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc2", "txt", "world", "num", 2) // Test WithSortByCount and SortBy - res, err := client.Search(ctx, "idx_payloads", "*"). + res, err := client.NewSearchBuilder(ctx, "idx_payloads", "*"). SortBy("num", true). WithSortByCount(). NoContent(). @@ -525,7 +525,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with JSON", Label("search", "ftsearch", "builders", "json"), func() { - createVal, err := client.CreateIndex(ctx, "idx_json"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_json"). OnJSON(). Prefix("king:"). Schema(&redis.FieldSchema{FieldName: "$.name", FieldType: redis.SearchFieldTypeText}). @@ -537,7 +537,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.JSONSet(ctx, "king:1", "$", `{"name": "henry"}`) client.JSONSet(ctx, "king:2", "$", `{"name": "james"}`) - res, err := client.Search(ctx, "idx_json", "henry").Run() + res, err := client.NewSearchBuilder(ctx, "idx_json", "henry").Run() Expect(err).NotTo(HaveOccurred()) Expect(res.Total).To(Equal(1)) Expect(res.Docs[0].ID).To(Equal("king:1")) @@ -546,7 +546,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { It("should test SearchBuilder with vector search", Label("search", "ftsearch", "builders", "vector"), func() { hnswOptions := &redis.FTHNSWOptions{Type: "FLOAT32", Dim: 2, DistanceMetric: "L2"} - createVal, err := client.CreateIndex(ctx, "idx_vector"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_vector"). OnHash(). Schema(&redis.FieldSchema{FieldName: "v", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{HNSWOptions: hnswOptions}}). Run() @@ -558,7 +558,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "b", "v", "aaaabaaa") client.HSet(ctx, "c", "v", "aaaaabaa") - res, err := client.Search(ctx, "idx_vector", "*=>[KNN 2 @v $vec]"). + res, err := client.NewSearchBuilder(ctx, "idx_vector", "*=>[KNN 2 @v $vec]"). ReturnFields("__v_score"). SortBy("__v_score", true). Dialect(2). @@ -570,7 +570,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder with complex filtering and aggregation", Label("search", "ftsearch", "builders"), func() { - createVal, err := client.CreateIndex(ctx, "idx_complex"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_complex"). OnHash(). Schema(&redis.FieldSchema{FieldName: "category", FieldType: redis.SearchFieldTypeTag}). Schema(&redis.FieldSchema{FieldName: "price", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}). 
@@ -585,7 +585,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "product2", "category", "electronics", "price", 200, "location", "-0.2,51.6", "description", "laptop computer") client.HSet(ctx, "product3", "category", "books", "price", 20, "location", "-0.3,51.7", "description", "programming guide") - res, err := client.Search(ctx, "idx_complex", "@category:{electronics} @description:(device|computer)"). + res, err := client.NewSearchBuilder(ctx, "idx_complex", "@category:{electronics} @description:(device|computer)"). Filter("price", 50, 250). GeoFilter("location", -0.15, 51.55, 50, "km"). SortBy("price", true). @@ -595,7 +595,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Total).To(BeNumerically(">=", 1)) - res2, err := client.Search(ctx, "idx_complex", "@category:{$cat} @price:[$min $max]"). + res2, err := client.NewSearchBuilder(ctx, "idx_complex", "@category:{$cat} @price:[$min $max]"). ParamsMap(map[string]interface{}{ "cat": "electronics", "min": 150, @@ -610,7 +610,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder error handling and edge cases", Label("search", "ftsearch", "builders", "edge-cases"), func() { - createVal, err := client.CreateIndex(ctx, "idx_edge"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_edge"). OnHash(). Schema(&redis.FieldSchema{FieldName: "txt", FieldType: redis.SearchFieldTypeText}). Run() @@ -621,17 +621,17 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "txt", "hello world") // Test empty query - res1, err := client.Search(ctx, "idx_edge", "*").NoContent().Run() + res1, err := client.NewSearchBuilder(ctx, "idx_edge", "*").NoContent().Run() Expect(err).NotTo(HaveOccurred()) Expect(res1.Total).To(Equal(1)) // Test query with no results - res2, err := client.Search(ctx, "idx_edge", "nonexistent").NoContent().Run() + res2, err := client.NewSearchBuilder(ctx, "idx_edge", "nonexistent").NoContent().Run() Expect(err).NotTo(HaveOccurred()) Expect(res2.Total).To(Equal(0)) // Test with multiple chained methods - res3, err := client.Search(ctx, "idx_edge", "hello"). + res3, err := client.NewSearchBuilder(ctx, "idx_edge", "hello"). WithScores(). NoContent(). Verbatim(). @@ -646,7 +646,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { }) It("should test SearchBuilder method chaining", Label("search", "ftsearch", "builders", "fluent"), func() { - createVal, err := client.CreateIndex(ctx, "idx_fluent"). + createVal, err := client.NewCreateIndexBuilder(ctx, "idx_fluent"). OnHash(). Schema(&redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText}). Schema(&redis.FieldSchema{FieldName: "tags", FieldType: redis.SearchFieldTypeTag}). @@ -659,7 +659,7 @@ var _ = Describe("RediSearch Builders", Label("search", "builders"), func() { client.HSet(ctx, "doc1", "title", "Redis Search Tutorial", "tags", "redis,search,tutorial", "score", 95) client.HSet(ctx, "doc2", "title", "Advanced Redis", "tags", "redis,advanced", "score", 88) - builder := client.Search(ctx, "idx_fluent", "@title:(redis) @tags:{search}") + builder := client.NewSearchBuilder(ctx, "idx_fluent", "@title:(redis) @tags:{search}") result := builder. WithScores(). Filter("score", 90, 100). 
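For orientation before the follow-up patches below, here is a minimal usage sketch of the fluent API these tests exercise, written against the renamed New*Builder constructors shown above. It is not part of the patch: it assumes a reachable Redis server with the search module loaded, and the address, index name, and document key are illustrative only.

	package main

	import (
		"context"
		"fmt"

		"github.com/redis/go-redis/v9"
	)

	func main() {
		ctx := context.Background()
		// Illustrative address; any Redis server with RediSearch will do.
		client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

		// FT.CREATE via the builder: a hash index with one text field.
		if _, err := client.NewCreateIndexBuilder(ctx, "idx_demo").
			OnHash().
			Schema(&redis.FieldSchema{FieldName: "foo", FieldType: redis.SearchFieldTypeText}).
			Run(); err != nil {
			panic(err)
		}

		client.HSet(ctx, "doc1", "foo", "hello world")

		// FT.SEARCH via the builder: WITHSCORES plus LIMIT 0 5.
		res, err := client.NewSearchBuilder(ctx, "idx_demo", "hello").
			WithScores().
			Limit(0, 5).
			Run()
		if err != nil {
			panic(err)
		}
		fmt.Println("total:", res.Total)
	}

Each setter returns the builder, so calls chain until Run sends the command and yields a typed result.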
From e4c4833c4cffb459a980ec0e7fe3dfbbb6378c44 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Mon, 4 Aug 2025 11:51:08 +0300 Subject: [PATCH 03/62] Update search_builders.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> --- search_builders.go | 1 + 1 file changed, 1 insertion(+) diff --git a/search_builders.go b/search_builders.go index a35d6e9281..2e86aa7caa 100644 --- a/search_builders.go +++ b/search_builders.go @@ -10,6 +10,7 @@ import ( // SearchBuilder provides a fluent API for FT.SEARCH // (see original FTSearchOptions for all options). +// EXPERIMENTAL: this API is subject to change, use with caution. type SearchBuilder struct { c *Client ctx context.Context From 1928265b9fdd8f6a57c925da6ffa01a9660eb166 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Mon, 4 Aug 2025 11:51:21 +0300 Subject: [PATCH 04/62] Update search_builders.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> --- search_builders.go | 1 + 1 file changed, 1 insertion(+) diff --git a/search_builders.go b/search_builders.go index 2e86aa7caa..5e6e760f7d 100644 --- a/search_builders.go +++ b/search_builders.go @@ -20,6 +20,7 @@ type SearchBuilder struct { } // NewSearchBuilder creates a new SearchBuilder for FT.SEARCH commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewSearchBuilder(ctx context.Context, index, query string) *SearchBuilder { b := &SearchBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSearchOptions{LimitOffset: -1}} return b } From fcf645ecb6d0f575d74c65af7de19e32ad2a1d59 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Mon, 4 Aug 2025 11:53:16 +0300 Subject: [PATCH 05/62] Apply suggestions from code review Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> --- search_builders.go | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/search_builders.go b/search_builders.go index 5e6e760f7d..91f0634041 100644 --- a/search_builders.go +++ b/search_builders.go @@ -218,6 +218,7 @@ type AggregateBuilder struct { } // NewAggregateBuilder creates a new AggregateBuilder for FT.AGGREGATE commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewAggregateBuilder(ctx context.Context, index, query string) *AggregateBuilder { return &AggregateBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTAggregateOptions{LimitOffset: -1}} } @@ -360,7 +361,8 @@ func (b *AggregateBuilder) Run() (*FTAggregateResult, error) { // ---------------------- // CreateIndexBuilder for FT.CREATE // ---------------------- - +// CreateIndexBuilder is a builder for FT.CREATE +// EXPERIMENTAL: this API is subject to change, use with caution. type CreateIndexBuilder struct { c *Client ctx context.Context @@ -370,6 +372,7 @@ type CreateIndexBuilder struct { } // NewCreateIndexBuilder creates a new CreateIndexBuilder for FT.CREATE commands. +// EXPERIMENTAL: this API is subject to change, use with caution. 
func (c *Client) NewCreateIndexBuilder(ctx context.Context, index string) *CreateIndexBuilder { return &CreateIndexBuilder{c: c, ctx: ctx, index: index, options: &FTCreateOptions{}} } @@ -467,7 +470,8 @@ func (b *CreateIndexBuilder) Run() (string, error) { // ---------------------- // DropIndexBuilder for FT.DROPINDEX // ---------------------- - +// DropIndexBuilder is a builder for FT.DROPINDEX +// EXPERIMENTAL: this API is subject to change, use with caution. type DropIndexBuilder struct { c *Client ctx context.Context @@ -476,6 +480,7 @@ type DropIndexBuilder struct { } // NewDropIndexBuilder creates a new DropIndexBuilder for FT.DROPINDEX commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewDropIndexBuilder(ctx context.Context, index string) *DropIndexBuilder { return &DropIndexBuilder{c: c, ctx: ctx, index: index} } @@ -492,7 +497,8 @@ func (b *DropIndexBuilder) Run() (string, error) { // ---------------------- // AliasBuilder for FT.ALIAS* commands // ---------------------- - +// AliasBuilder is a builder for FT.ALIAS* commands +// EXPERIMENTAL: this API is subject to change, use with caution. type AliasBuilder struct { c *Client ctx context.Context @@ -502,6 +508,7 @@ type AliasBuilder struct { } // NewAliasBuilder creates a new AliasBuilder for FT.ALIAS* commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewAliasBuilder(ctx context.Context, alias string) *AliasBuilder { return &AliasBuilder{c: c, ctx: ctx, alias: alias} } @@ -551,7 +558,8 @@ func (b *AliasBuilder) Run() (string, error) { // ---------------------- // ExplainBuilder for FT.EXPLAIN // ---------------------- - +// ExplainBuilder is a builder for FT.EXPLAIN +// EXPERIMENTAL: this API is subject to change, use with caution. type ExplainBuilder struct { c *Client ctx context.Context @@ -561,6 +569,7 @@ type ExplainBuilder struct { } // NewExplainBuilder creates a new ExplainBuilder for FT.EXPLAIN commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewExplainBuilder(ctx context.Context, index, query string) *ExplainBuilder { return &ExplainBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTExplainOptions{}} } @@ -598,7 +607,8 @@ func (b *FTInfoBuilder) Run() (FTInfoResult, error) { // ---------------------- // SpellCheckBuilder for FT.SPELLCHECK // ---------------------- - +// SpellCheckBuilder is a builder for FT.SPELLCHECK +// EXPERIMENTAL: this API is subject to change, use with caution. type SpellCheckBuilder struct { c *Client ctx context.Context @@ -608,6 +618,7 @@ type SpellCheckBuilder struct { } // NewSpellCheckBuilder creates a new SpellCheckBuilder for FT.SPELLCHECK commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewSpellCheckBuilder(ctx context.Context, index, query string) *SpellCheckBuilder { return &SpellCheckBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSpellCheckOptions{}} } @@ -642,7 +653,8 @@ func (b *SpellCheckBuilder) Run() ([]SpellCheckResult, error) { // ---------------------- // DictBuilder for FT.DICT* commands // ---------------------- - +// DictBuilder is a builder for FT.DICT* commands +// EXPERIMENTAL: this API is subject to change, use with caution. type DictBuilder struct { c *Client ctx context.Context @@ -652,6 +664,7 @@ type DictBuilder struct { } // NewDictBuilder creates a new DictBuilder for FT.DICT* commands. +// EXPERIMENTAL: this API is subject to change, use with caution. 
func (c *Client) NewDictBuilder(ctx context.Context, dict string) *DictBuilder { return &DictBuilder{c: c, ctx: ctx, dict: dict} } @@ -701,7 +714,8 @@ func (b *DictBuilder) Run() (interface{}, error) { // ---------------------- // TagValsBuilder for FT.TAGVALS // ---------------------- - +// TagValsBuilder is a builder for FT.TAGVALS +// EXPERIMENTAL: this API is subject to change, use with caution. type TagValsBuilder struct { c *Client ctx context.Context @@ -710,6 +724,7 @@ type TagValsBuilder struct { } // NewTagValsBuilder creates a new TagValsBuilder for FT.TAGVALS commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewTagValsBuilder(ctx context.Context, index, field string) *TagValsBuilder { return &TagValsBuilder{c: c, ctx: ctx, index: index, field: field} } @@ -723,7 +738,8 @@ func (b *TagValsBuilder) Run() ([]string, error) { // ---------------------- // CursorBuilder for FT.CURSOR* // ---------------------- - +// CursorBuilder is a builder for FT.CURSOR* commands +// EXPERIMENTAL: this API is subject to change, use with caution. type CursorBuilder struct { c *Client ctx context.Context @@ -734,6 +750,7 @@ type CursorBuilder struct { } // NewCursorBuilder creates a new CursorBuilder for FT.CURSOR* commands. +// EXPERIMENTAL: this API is subject to change, use with caution. func (c *Client) NewCursorBuilder(ctx context.Context, index string, cursorId int64) *CursorBuilder { return &CursorBuilder{c: c, ctx: ctx, index: index, cursorId: cursorId} } @@ -775,7 +792,8 @@ func (b *CursorBuilder) Run() (interface{}, error) { // ---------------------- // SynUpdateBuilder for FT.SYNUPDATE // ---------------------- - +// SynUpdateBuilder is a builder for FT.SYNUPDATE +// EXPERIMENTAL: this API is subject to change, use with caution. type SynUpdateBuilder struct { c *Client ctx context.Context @@ -786,6 +804,7 @@ type SynUpdateBuilder struct { } // NewSynUpdateBuilder creates a new SynUpdateBuilder for FT.SYNUPDATE commands. +// EXPERIMENTAL: this API is subject to change, use with caution. 
func (c *Client) NewSynUpdateBuilder(ctx context.Context, index string, groupId interface{}) *SynUpdateBuilder { return &SynUpdateBuilder{c: c, ctx: ctx, index: index, groupId: groupId, options: &FTSynUpdateOptions{}} } From b2d2d9143179eb5711c0de383a4b9acd38b33749 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Wed, 14 May 2025 21:35:04 +0300 Subject: [PATCH 06/62] feat(routing): add internal request/response policy enums --- internal/routing/policy.go | 118 +++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 internal/routing/policy.go diff --git a/internal/routing/policy.go b/internal/routing/policy.go new file mode 100644 index 0000000000..e403067c0e --- /dev/null +++ b/internal/routing/policy.go @@ -0,0 +1,118 @@ +package routing + +import ( + "fmt" + "strings" +) + +type RequestPolicy uint8 + +const ( + ReqDefault RequestPolicy = iota + + ReqAllNodes + + ReqAllShards + + ReqMultiShard + + ReqSpecial +) + +func (p RequestPolicy) String() string { + switch p { + case ReqDefault: + return "default" + case ReqAllNodes: + return "all_nodes" + case ReqAllShards: + return "all_shards" + case ReqMultiShard: + return "multi_shard" + case ReqSpecial: + return "special" + default: + return fmt.Sprintf("unknown_request_policy(%d)", p) + } +} + +func ParseRequestPolicy(raw string) (RequestPolicy, error) { + switch strings.ToLower(raw) { + case "", "default", "none": + return ReqDefault, nil + case "all_nodes": + return ReqAllNodes, nil + case "all_shards": + return ReqAllShards, nil + case "multi_shard": + return ReqMultiShard, nil + case "special": + return ReqSpecial, nil + default: + return ReqDefault, fmt.Errorf("routing: unknown request_policy %q", raw) + } +} + +type ResponsePolicy uint8 + +const ( + RespAllSucceeded ResponsePolicy = iota + RespOneSucceeded + RespAggSum + RespAggMin + RespAggMax + RespAggLogicalAnd + RespAggLogicalOr + RespSpecial +) + +func (p ResponsePolicy) String() string { + switch p { + case RespAllSucceeded: + return "all_succeeded" + case RespOneSucceeded: + return "one_succeeded" + case RespAggSum: + return "agg_sum" + case RespAggMin: + return "agg_min" + case RespAggMax: + return "agg_max" + case RespAggLogicalAnd: + return "agg_logical_and" + case RespAggLogicalOr: + return "agg_logical_or" + case RespSpecial: + return "special" + default: + return fmt.Sprintf("unknown_response_policy(%d)", p) + } +} + +func ParseResponsePolicy(raw string) (ResponsePolicy, error) { + switch strings.ToLower(raw) { + case "all_succeeded": + return RespAllSucceeded, nil + case "one_succeeded": + return RespOneSucceeded, nil + case "agg_sum": + return RespAggSum, nil + case "agg_min": + return RespAggMin, nil + case "agg_max": + return RespAggMax, nil + case "agg_logical_and": + return RespAggLogicalAnd, nil + case "agg_logical_or": + return RespAggLogicalOr, nil + case "special": + return RespSpecial, nil + default: + return RespAllSucceeded, fmt.Errorf("routing: unknown response_policy %q", raw) + } +} + +type CommandPolicy struct { + Request RequestPolicy + Response ResponsePolicy +} From b943692d640941c0094fa78ac8d0d4fc96ceb844 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Tue, 20 May 2025 11:37:03 +0300 Subject: [PATCH 07/62] feat: load the policy table in cluster client (#4) * feat: load the policy table in cluster client * Remove comments --- command.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++-- commands_test.go | 16 +++++++++++ osscluster.go | 12 ++++---- 3 files changed, 96 
insertions(+), 7 deletions(-) diff --git a/command.go b/command.go index d3fb231b5e..0fcc7a5559 100644 --- a/command.go +++ b/command.go @@ -14,6 +14,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/routing" "github.com/redis/go-redis/v9/internal/util" ) @@ -3478,6 +3479,7 @@ type CommandInfo struct { LastKeyPos int8 StepCount int8 ReadOnly bool + Tips map[string]string } type CommandsInfoCmd struct { @@ -3516,7 +3518,7 @@ func (cmd *CommandsInfoCmd) String() string { func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { const numArgRedis5 = 6 const numArgRedis6 = 7 - const numArgRedis7 = 10 + const numArgRedis7 = 10 // Also matches redis 8 n, err := rd.ReadArrayLen() if err != nil { @@ -3604,9 +3606,34 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { } if nn >= numArgRedis7 { - if err := rd.DiscardNext(); err != nil { + // The 8th argument is an array of tips. + tipsLen, err := rd.ReadArrayLen() + if err != nil { return err } + + cmdInfo.Tips = make(map[string]string, tipsLen) + + for f := 0; f < tipsLen; f++ { + tip, err := rd.ReadString() + if err != nil { + return err + } + + // Handle tips that don't have a colon (like "nondeterministic_output") + if !strings.Contains(tip, ":") { + cmdInfo.Tips[tip] = "" + continue + } + + // Handle normal key:value tips + k, v, ok := strings.Cut(tip, ":") + if !ok { + return fmt.Errorf("redis: unexpected tip %q in COMMAND reply", tip) + } + cmdInfo.Tips[k] = v + } + if err := rd.DiscardNext(); err != nil { return err } @@ -3656,6 +3683,50 @@ func (c *cmdsInfoCache) Get(ctx context.Context) (map[string]*CommandInfo, error return c.cmds, err } +// ------------------------------------------------------------------------------ +var BuiltinPolicies = map[string]routing.CommandPolicy{ + "ft.create": {Request: routing.ReqSpecial, Response: routing.RespAllSucceeded}, + "ft.alter": {Request: routing.ReqSpecial, Response: routing.RespAllSucceeded}, + "ft.drop": {Request: routing.ReqSpecial, Response: routing.RespAllSucceeded}, + + "mset": {Request: routing.ReqMultiShard, Response: routing.RespAllSucceeded}, + "mget": {Request: routing.ReqMultiShard, Response: routing.RespSpecial}, + "del": {Request: routing.ReqMultiShard, Response: routing.RespAggSum}, +} + +func newCommandPolicies(commandInfo map[string]*CommandInfo) map[string]routing.CommandPolicy { + + table := make(map[string]routing.CommandPolicy, len(commandInfo)) + + for name, info := range commandInfo { + req := routing.ReqDefault + resp := routing.RespAllSucceeded + + if tips := info.Tips; tips != nil { + if v, ok := tips["request_policy"]; ok { + if p, err := routing.ParseRequestPolicy(v); err == nil { + req = p + } + } + if v, ok := tips["response_policy"]; ok { + if p, err := routing.ParseResponsePolicy(v); err == nil { + resp = p + } + } + } else { + return BuiltinPolicies + } + table[name] = routing.CommandPolicy{Request: req, Response: resp} + } + + if len(table) == 0 { + for k, v := range BuiltinPolicies { + table[k] = v + } + } + return table +} + //------------------------------------------------------------------------------ type SlowLog struct { diff --git a/commands_test.go b/commands_test.go index 17b4dd0306..ef597079ec 100644 --- a/commands_test.go +++ b/commands_test.go @@ -657,6 +657,22 @@ var _ = Describe("Commands", func() { Expect(cmd.StepCount).To(Equal(int8(0))) }) + It("should Command Tips", 
Label("NonRedisEnterprise"), func() { + SkipAfterRedisVersion(7.9, "Redis 8 changed the COMMAND reply format") + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + cmd := cmds["touch"] + Expect(cmd.Name).To(Equal("touch")) + Expect(cmd.Tips["request_policy"]).To(Equal("multi_shard")) + Expect(cmd.Tips["response_policy"]).To(Equal("agg_sum")) + + cmd = cmds["flushall"] + Expect(cmd.Name).To(Equal("flushall")) + Expect(cmd.Tips["request_policy"]).To(Equal("all_shards")) + Expect(cmd.Tips["response_policy"]).To(Equal("all_succeeded")) + }) + It("should return all command names", func() { cmdList := client.CommandList(ctx, nil) Expect(cmdList.Err()).NotTo(HaveOccurred()) diff --git a/osscluster.go b/osscluster.go index 7925d2c603..94f2d1f677 100644 --- a/osscluster.go +++ b/osscluster.go @@ -20,6 +20,7 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/internal/routing" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -1006,10 +1007,11 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er // or more underlying connections. It's safe for concurrent use by // multiple goroutines. type ClusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder - cmdsInfoCache *cmdsInfoCache + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache + commandPolicies map[string]routing.CommandPolicy cmdable hooksMixin } @@ -1029,8 +1031,8 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) + c.commandPolicies = newCommandPolicies(c.cmdsInfoCache.cmds) c.cmdable = c.Process - c.initHooks(hooks{ dial: nil, process: c.process, From 5375c51e6db39303189ba0d409f365de6572de65 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Tue, 20 May 2025 13:53:49 +0300 Subject: [PATCH 08/62] modify Tips and command pplicy in commandInfo (#5) --- command.go | 62 +++++++++++++++----------------------- commands_test.go | 9 +++--- internal/routing/policy.go | 3 ++ osscluster.go | 11 +++---- 4 files changed, 36 insertions(+), 49 deletions(-) diff --git a/command.go b/command.go index 0fcc7a5559..f2a070c90f 100644 --- a/command.go +++ b/command.go @@ -3479,7 +3479,7 @@ type CommandInfo struct { LastKeyPos int8 StepCount int8 ReadOnly bool - Tips map[string]string + Tips *routing.CommandPolicy } type CommandsInfoCmd struct { @@ -3612,8 +3612,7 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { return err } - cmdInfo.Tips = make(map[string]string, tipsLen) - + rawTips := make(map[string]string, tipsLen) for f := 0; f < tipsLen; f++ { tip, err := rd.ReadString() if err != nil { @@ -3622,7 +3621,7 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { // Handle tips that don't have a colon (like "nondeterministic_output") if !strings.Contains(tip, ":") { - cmdInfo.Tips[tip] = "" + rawTips[tip] = "" continue } @@ -3631,8 +3630,9 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { if !ok { return fmt.Errorf("redis: unexpected tip %q in COMMAND reply", tip) } - cmdInfo.Tips[k] = v + rawTips[k] = v } + cmdInfo.Tips = parseCommandPolicies(rawTips) if err := rd.DiscardNext(); err != nil { return err @@ -3684,47 +3684,33 @@ func (c *cmdsInfoCache) 
Get(ctx context.Context) (map[string]*CommandInfo, error } // ------------------------------------------------------------------------------ -var BuiltinPolicies = map[string]routing.CommandPolicy{ - "ft.create": {Request: routing.ReqSpecial, Response: routing.RespAllSucceeded}, - "ft.alter": {Request: routing.ReqSpecial, Response: routing.RespAllSucceeded}, - "ft.drop": {Request: routing.ReqSpecial, Response: routing.RespAllSucceeded}, - - "mset": {Request: routing.ReqMultiShard, Response: routing.RespAllSucceeded}, - "mget": {Request: routing.ReqMultiShard, Response: routing.RespSpecial}, - "del": {Request: routing.ReqMultiShard, Response: routing.RespAggSum}, -} - -func newCommandPolicies(commandInfo map[string]*CommandInfo) map[string]routing.CommandPolicy { - - table := make(map[string]routing.CommandPolicy, len(commandInfo)) +const requestPolicy = "request_policy" +const responsePolicy = "response_policy" - for name, info := range commandInfo { - req := routing.ReqDefault - resp := routing.RespAllSucceeded +func parseCommandPolicies(commandInfoTips map[string]string) *routing.CommandPolicy { + req := routing.ReqDefault + resp := routing.RespAllSucceeded - if tips := info.Tips; tips != nil { - if v, ok := tips["request_policy"]; ok { - if p, err := routing.ParseRequestPolicy(v); err == nil { - req = p - } + if commandInfoTips != nil { + if v, ok := commandInfoTips[requestPolicy]; ok { + if p, err := routing.ParseRequestPolicy(v); err == nil { + req = p } - if v, ok := tips["response_policy"]; ok { - if p, err := routing.ParseResponsePolicy(v); err == nil { - resp = p - } + } + if v, ok := commandInfoTips[responsePolicy]; ok { + if p, err := routing.ParseResponsePolicy(v); err == nil { + resp = p } - } else { - return BuiltinPolicies } - table[name] = routing.CommandPolicy{Request: req, Response: resp} } - - if len(table) == 0 { - for k, v := range BuiltinPolicies { - table[k] = v + tips := make(map[string]string, len(commandInfoTips)) + for k, v := range commandInfoTips { + if k == requestPolicy || k == responsePolicy { + continue } + tips[k] = v } - return table + return &routing.CommandPolicy{Request: req, Response: resp, Tips: tips} } //------------------------------------------------------------------------------ diff --git a/commands_test.go b/commands_test.go index ef597079ec..476a08a302 100644 --- a/commands_test.go +++ b/commands_test.go @@ -13,6 +13,7 @@ import ( "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/routing" ) type TimeValue struct { @@ -664,13 +665,13 @@ var _ = Describe("Commands", func() { cmd := cmds["touch"] Expect(cmd.Name).To(Equal("touch")) - Expect(cmd.Tips["request_policy"]).To(Equal("multi_shard")) - Expect(cmd.Tips["response_policy"]).To(Equal("agg_sum")) + Expect(cmd.Tips.Request).To(Equal(routing.ReqMultiShard)) + Expect(cmd.Tips.Response).To(Equal(routing.RespAggSum)) cmd = cmds["flushall"] Expect(cmd.Name).To(Equal("flushall")) - Expect(cmd.Tips["request_policy"]).To(Equal("all_shards")) - Expect(cmd.Tips["response_policy"]).To(Equal("all_succeeded")) + Expect(cmd.Tips.Request).To(Equal(routing.ReqAllShards)) + Expect(cmd.Tips.Response).To(Equal(routing.RespAllSucceeded)) }) It("should return all command names", func() { diff --git a/internal/routing/policy.go b/internal/routing/policy.go index e403067c0e..18c03cd2dd 100644 --- a/internal/routing/policy.go +++ b/internal/routing/policy.go @@ -115,4 +115,7 @@ func ParseResponsePolicy(raw string) (ResponsePolicy, error) { type 
CommandPolicy struct { Request RequestPolicy Response ResponsePolicy + // Tips that are not request_policy or response_policy + // e.g. nondeterministic_output, nondeterministic_output_order. Tips map[string]string } diff --git a/osscluster.go b/osscluster.go index 94f2d1f677..0db0699f8a 100644 --- a/osscluster.go +++ b/osscluster.go @@ -20,7 +20,6 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" - "github.com/redis/go-redis/v9/internal/routing" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -1007,11 +1006,10 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er // or more underlying connections. It's safe for concurrent use by // multiple goroutines. type ClusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder - cmdsInfoCache *cmdsInfoCache - commandPolicies map[string]routing.CommandPolicy + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache cmdable hooksMixin } @@ -1031,7 +1029,6 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) - c.commandPolicies = newCommandPolicies(c.cmdsInfoCache.cmds) c.cmdable = c.Process c.initHooks(hooks{ dial: nil, process: c.process, From 294be252b0d28e160a62669e9f0a6421c227fba9 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Mon, 30 Jun 2025 11:34:11 +0300 Subject: [PATCH 09/62] centralize cluster command routing in osscluster_router.go and refactor osscluster.go (#6) * centralize cluster command routing in osscluster_router.go and refactor osscluster.go * enable CI on all branches * Add debug prints * Add debug prints * FIX: deal with nil policy * FIX: fixing clusterClient process * chore(osscluster): simplify switch case * wip(command): ai generated clone method for commands * feat: implement response aggregator for Redis cluster commands * feat: implement response aggregator for Redis cluster commands * fix: solve concurrency errors * fix: solve concurrency errors * return MaxRedirects settings * remove locks from getCommandPolicy * Handle MOVED errors more robustly, remove cluster reloading at executions, ensure better routing * Fix: support Process hook test * Fix: remove response aggregation for single shard commands * Add more performant type conversion for Cmd type * Add router logic into processPipeline --------- Co-authored-by: Nedyalko Dyakov --- command.go | 1369 ++++++++++++++++++++++++--- go.mod | 1 + go.sum | 2 + internal/routing/aggregator.go | 933 ++++++++++++++++++ internal/routing/aggregator_test.go | 427 +++++++++ internal/routing/policy.go | 16 +- internal/routing/shard_picker.go | 41 + json.go | 52 +- main_test.go | 13 + osscluster.go | 262 +++-- osscluster_router.go | 847 +++++++++++++++++ osscluster_router_test.go | 379 ++++++++ osscluster_test.go | 83 +- probabilistic.go | 72 +- search_commands.go | 235 ++++- timeseries_commands.go | 29 +- 16 files changed, 4519 insertions(+), 242 deletions(-) create mode 100644 internal/routing/aggregator.go create mode 100644 internal/routing/aggregator_test.go create mode 100644 internal/routing/shard_picker.go create mode 100644 osscluster_router.go create mode 100644 osscluster_router_test.go diff --git a/command.go b/command.go index f2a070c90f..c7acf2237c 100644 --- a/command.go +++ b/command.go @@ 
-18,6 +18,7 @@ import ( "github.com/redis/go-redis/v9/internal/util" ) +<<<<<<< HEAD // keylessCommands contains Redis commands that have empty key specifications (9th slot empty) // Only includes core Redis commands, excludes FT.*, ts.*, timeseries.*, search.* and subcommands var keylessCommands = map[string]struct{}{ @@ -66,6 +67,81 @@ var keylessCommands = map[string]struct{}{ "unsubscribe": {}, "unwatch": {}, } +======= +type CmdType = routing.CmdType + +const ( + CmdTypeGeneric = routing.CmdTypeGeneric + CmdTypeString = routing.CmdTypeString + CmdTypeInt = routing.CmdTypeInt + CmdTypeBool = routing.CmdTypeBool + CmdTypeFloat = routing.CmdTypeFloat + CmdTypeStringSlice = routing.CmdTypeStringSlice + CmdTypeIntSlice = routing.CmdTypeIntSlice + CmdTypeFloatSlice = routing.CmdTypeFloatSlice + CmdTypeBoolSlice = routing.CmdTypeBoolSlice + CmdTypeMapStringString = routing.CmdTypeMapStringString + CmdTypeMapStringInt = routing.CmdTypeMapStringInt + CmdTypeMapStringInterface = routing.CmdTypeMapStringInterface + CmdTypeMapStringInterfaceSlice = routing.CmdTypeMapStringInterfaceSlice + CmdTypeSlice = routing.CmdTypeSlice + CmdTypeStatus = routing.CmdTypeStatus + CmdTypeDuration = routing.CmdTypeDuration + CmdTypeTime = routing.CmdTypeTime + CmdTypeKeyValueSlice = routing.CmdTypeKeyValueSlice + CmdTypeStringStructMap = routing.CmdTypeStringStructMap + CmdTypeXMessageSlice = routing.CmdTypeXMessageSlice + CmdTypeXStreamSlice = routing.CmdTypeXStreamSlice + CmdTypeXPending = routing.CmdTypeXPending + CmdTypeXPendingExt = routing.CmdTypeXPendingExt + CmdTypeXAutoClaim = routing.CmdTypeXAutoClaim + CmdTypeXAutoClaimJustID = routing.CmdTypeXAutoClaimJustID + CmdTypeXInfoConsumers = routing.CmdTypeXInfoConsumers + CmdTypeXInfoGroups = routing.CmdTypeXInfoGroups + CmdTypeXInfoStream = routing.CmdTypeXInfoStream + CmdTypeXInfoStreamFull = routing.CmdTypeXInfoStreamFull + CmdTypeZSlice = routing.CmdTypeZSlice + CmdTypeZWithKey = routing.CmdTypeZWithKey + CmdTypeScan = routing.CmdTypeScan + CmdTypeClusterSlots = routing.CmdTypeClusterSlots + CmdTypeGeoLocation = routing.CmdTypeGeoLocation + CmdTypeGeoSearchLocation = routing.CmdTypeGeoSearchLocation + CmdTypeGeoPos = routing.CmdTypeGeoPos + CmdTypeCommandsInfo = routing.CmdTypeCommandsInfo + CmdTypeSlowLog = routing.CmdTypeSlowLog + CmdTypeMapStringStringSlice = routing.CmdTypeMapStringStringSlice + CmdTypeMapMapStringInterface = routing.CmdTypeMapMapStringInterface + CmdTypeKeyValues = routing.CmdTypeKeyValues + CmdTypeZSliceWithKey = routing.CmdTypeZSliceWithKey + CmdTypeFunctionList = routing.CmdTypeFunctionList + CmdTypeFunctionStats = routing.CmdTypeFunctionStats + CmdTypeLCS = routing.CmdTypeLCS + CmdTypeKeyFlags = routing.CmdTypeKeyFlags + CmdTypeClusterLinks = routing.CmdTypeClusterLinks + CmdTypeClusterShards = routing.CmdTypeClusterShards + CmdTypeRankWithScore = routing.CmdTypeRankWithScore + CmdTypeClientInfo = routing.CmdTypeClientInfo + CmdTypeACLLog = routing.CmdTypeACLLog + CmdTypeInfo = routing.CmdTypeInfo + CmdTypeMonitor = routing.CmdTypeMonitor + CmdTypeJSON = routing.CmdTypeJSON + CmdTypeJSONSlice = routing.CmdTypeJSONSlice + CmdTypeIntPointerSlice = routing.CmdTypeIntPointerSlice + CmdTypeScanDump = routing.CmdTypeScanDump + CmdTypeBFInfo = routing.CmdTypeBFInfo + CmdTypeCFInfo = routing.CmdTypeCFInfo + CmdTypeCMSInfo = routing.CmdTypeCMSInfo + CmdTypeTopKInfo = routing.CmdTypeTopKInfo + CmdTypeTDigestInfo = routing.CmdTypeTDigestInfo + CmdTypeFTSynDump = routing.CmdTypeFTSynDump + CmdTypeAggregate = routing.CmdTypeAggregate + 
CmdTypeFTInfo = routing.CmdTypeFTInfo + CmdTypeFTSpellCheck = routing.CmdTypeFTSpellCheck + CmdTypeFTSearch = routing.CmdTypeFTSearch + CmdTypeTSTimestampValue = routing.CmdTypeTSTimestampValue + CmdTypeTSTimestampValueSlice = routing.CmdTypeTSTimestampValueSlice +) +>>>>>>> b6633bf9 (centralize cluster command routing in osscluster_router.go and refactor osscluster.go (#6)) type Cmder interface { // command name. @@ -84,6 +160,9 @@ type Cmder interface { // e.g. "set k v ex 10" -> "set k v ex 10: OK", "get k" -> "get k: v". String() string + // Clone creates a copy of the command. + Clone() Cmder + stringArg(int) string firstKeyPos() int8 SetFirstKeyPos(int8) @@ -93,6 +172,9 @@ type Cmder interface { readRawReply(rd *proto.Reader) error SetErr(error) Err() error + + // GetCmdType returns the command type for fast value extraction + GetCmdType() CmdType } func setCmdsErr(cmds []Cmder, e error) { @@ -188,6 +270,7 @@ type baseCmd struct { keyPos int8 rawVal interface{} _readTimeout *time.Duration + cmdType CmdType } var _ Cmder = (*Cmd)(nil) @@ -264,6 +347,32 @@ func (cmd *baseCmd) readRawReply(rd *proto.Reader) (err error) { return err } +func (cmd *baseCmd) GetCmdType() CmdType { + return cmd.cmdType +} + +func (cmd *baseCmd) cloneBaseCmd() baseCmd { + var readTimeout *time.Duration + if cmd._readTimeout != nil { + timeout := *cmd._readTimeout + readTimeout = &timeout + } + + // Create a copy of args slice + args := make([]interface{}, len(cmd.args)) + copy(args, cmd.args) + + return baseCmd{ + ctx: cmd.ctx, + args: args, + err: cmd.err, + keyPos: cmd.keyPos, + rawVal: cmd.rawVal, + _readTimeout: readTimeout, + cmdType: cmd.cmdType, + } +} + //------------------------------------------------------------------------------ type Cmd struct { @@ -275,8 +384,9 @@ type Cmd struct { func NewCmd(ctx context.Context, args ...interface{}) *Cmd { return &Cmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeGeneric, }, } } @@ -549,6 +659,13 @@ func (cmd *Cmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *Cmd) Clone() Cmder { + return &Cmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type SliceCmd struct { @@ -562,8 +679,9 @@ var _ Cmder = (*SliceCmd)(nil) func NewSliceCmd(ctx context.Context, args ...interface{}) *SliceCmd { return &SliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeSlice, }, } } @@ -609,6 +727,18 @@ func (cmd *SliceCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *SliceCmd) Clone() Cmder { + var val []interface{} + if cmd.val != nil { + val = make([]interface{}, len(cmd.val)) + copy(val, cmd.val) + } + return &SliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StatusCmd struct { @@ -622,8 +752,9 @@ var _ Cmder = (*StatusCmd)(nil) func NewStatusCmd(ctx context.Context, args ...interface{}) *StatusCmd { return &StatusCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStatus, }, } } @@ -653,6 +784,13 @@ func (cmd *StatusCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *StatusCmd) Clone() Cmder { + return &StatusCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type IntCmd struct { @@ -666,8 +804,9 @@ var _ Cmder = 
(*IntCmd)(nil) func NewIntCmd(ctx context.Context, args ...interface{}) *IntCmd { return &IntCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeInt, }, } } @@ -697,6 +836,13 @@ func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *IntCmd) Clone() Cmder { + return &IntCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type IntSliceCmd struct { @@ -710,8 +856,9 @@ var _ Cmder = (*IntSliceCmd)(nil) func NewIntSliceCmd(ctx context.Context, args ...interface{}) *IntSliceCmd { return &IntSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeIntSlice, }, } } @@ -746,6 +893,18 @@ func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *IntSliceCmd) Clone() Cmder { + var val []int64 + if cmd.val != nil { + val = make([]int64, len(cmd.val)) + copy(val, cmd.val) + } + return &IntSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type DurationCmd struct { @@ -760,8 +919,9 @@ var _ Cmder = (*DurationCmd)(nil) func NewDurationCmd(ctx context.Context, precision time.Duration, args ...interface{}) *DurationCmd { return &DurationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeDuration, }, precision: precision, } @@ -799,6 +959,14 @@ func (cmd *DurationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *DurationCmd) Clone() Cmder { + return &DurationCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + precision: cmd.precision, + } +} + //------------------------------------------------------------------------------ type TimeCmd struct { @@ -812,8 +980,9 @@ var _ Cmder = (*TimeCmd)(nil) func NewTimeCmd(ctx context.Context, args ...interface{}) *TimeCmd { return &TimeCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTime, }, } } @@ -850,6 +1019,13 @@ func (cmd *TimeCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *TimeCmd) Clone() Cmder { + return &TimeCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type BoolCmd struct { @@ -863,8 +1039,9 @@ var _ Cmder = (*BoolCmd)(nil) func NewBoolCmd(ctx context.Context, args ...interface{}) *BoolCmd { return &BoolCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBool, }, } } @@ -897,6 +1074,13 @@ func (cmd *BoolCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *BoolCmd) Clone() Cmder { + return &BoolCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type StringCmd struct { @@ -910,8 +1094,9 @@ var _ Cmder = (*StringCmd)(nil) func NewStringCmd(ctx context.Context, args ...interface{}) *StringCmd { return &StringCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeString, }, } } @@ -936,7 +1121,7 @@ func (cmd *StringCmd) Bool() (bool, error) { if cmd.err != nil { return false, cmd.err } - return strconv.ParseBool(cmd.val) + return strconv.ParseBool(cmd.Val()) } func (cmd *StringCmd) Int() (int, error) { @@ -1001,6 +1186,13 @@ func (cmd *StringCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *StringCmd) Clone() Cmder { + return 
&StringCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type FloatCmd struct { @@ -1014,8 +1206,9 @@ var _ Cmder = (*FloatCmd)(nil) func NewFloatCmd(ctx context.Context, args ...interface{}) *FloatCmd { return &FloatCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFloat, }, } } @@ -1041,6 +1234,13 @@ func (cmd *FloatCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *FloatCmd) Clone() Cmder { + return &FloatCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type FloatSliceCmd struct { @@ -1054,8 +1254,9 @@ var _ Cmder = (*FloatSliceCmd)(nil) func NewFloatSliceCmd(ctx context.Context, args ...interface{}) *FloatSliceCmd { return &FloatSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFloatSlice, }, } } @@ -1096,6 +1297,18 @@ func (cmd *FloatSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *FloatSliceCmd) Clone() Cmder { + var val []float64 + if cmd.val != nil { + val = make([]float64, len(cmd.val)) + copy(val, cmd.val) + } + return &FloatSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StringSliceCmd struct { @@ -1109,8 +1322,9 @@ var _ Cmder = (*StringSliceCmd)(nil) func NewStringSliceCmd(ctx context.Context, args ...interface{}) *StringSliceCmd { return &StringSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStringSlice, }, } } @@ -1154,6 +1368,18 @@ func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *StringSliceCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &StringSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type KeyValue struct { @@ -1172,8 +1398,9 @@ var _ Cmder = (*KeyValueSliceCmd)(nil) func NewKeyValueSliceCmd(ctx context.Context, args ...interface{}) *KeyValueSliceCmd { return &KeyValueSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyValueSlice, }, } } @@ -1248,6 +1475,18 @@ func (cmd *KeyValueSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl return nil } +func (cmd *KeyValueSliceCmd) Clone() Cmder { + var val []KeyValue + if cmd.val != nil { + val = make([]KeyValue, len(cmd.val)) + copy(val, cmd.val) + } + return &KeyValueSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type BoolSliceCmd struct { @@ -1261,8 +1500,9 @@ var _ Cmder = (*BoolSliceCmd)(nil) func NewBoolSliceCmd(ctx context.Context, args ...interface{}) *BoolSliceCmd { return &BoolSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBoolSlice, }, } } @@ -1297,6 +1537,18 @@ func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *BoolSliceCmd) Clone() Cmder { + var val []bool + if cmd.val != nil { + val = make([]bool, len(cmd.val)) + copy(val, cmd.val) + } + return &BoolSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type 
MapStringStringCmd struct { @@ -1310,8 +1562,9 @@ var _ Cmder = (*MapStringStringCmd)(nil) func NewMapStringStringCmd(ctx context.Context, args ...interface{}) *MapStringStringCmd { return &MapStringStringCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringString, }, } } @@ -1376,6 +1629,20 @@ func (cmd *MapStringStringCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringStringCmd) Clone() Cmder { + var val map[string]string + if cmd.val != nil { + val = make(map[string]string, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringStringCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type MapStringIntCmd struct { @@ -1389,8 +1656,9 @@ var _ Cmder = (*MapStringIntCmd)(nil) func NewMapStringIntCmd(ctx context.Context, args ...interface{}) *MapStringIntCmd { return &MapStringIntCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInt, }, } } @@ -1433,6 +1701,20 @@ func (cmd *MapStringIntCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringIntCmd) Clone() Cmder { + var val map[string]int64 + if cmd.val != nil { + val = make(map[string]int64, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringIntCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------------------------------------------ type MapStringSliceInterfaceCmd struct { baseCmd @@ -1442,8 +1724,9 @@ type MapStringSliceInterfaceCmd struct { func NewMapStringSliceInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringSliceInterfaceCmd { return &MapStringSliceInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterfaceSlice, }, } } @@ -1529,6 +1812,24 @@ func (cmd *MapStringSliceInterfaceCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *MapStringSliceInterfaceCmd) Clone() Cmder { + var val map[string][]interface{} + if cmd.val != nil { + val = make(map[string][]interface{}, len(cmd.val)) + for k, v := range cmd.val { + if v != nil { + newSlice := make([]interface{}, len(v)) + copy(newSlice, v) + val[k] = newSlice + } + } + } + return &MapStringSliceInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StringStructMapCmd struct { @@ -1542,8 +1843,9 @@ var _ Cmder = (*StringStructMapCmd)(nil) func NewStringStructMapCmd(ctx context.Context, args ...interface{}) *StringStructMapCmd { return &StringStructMapCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStringStructMap, }, } } @@ -1581,6 +1883,20 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *StringStructMapCmd) Clone() Cmder { + var val map[string]struct{} + if cmd.val != nil { + val = make(map[string]struct{}, len(cmd.val)) + for k := range cmd.val { + val[k] = struct{}{} + } + } + return &StringStructMapCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XMessage struct { @@ -1599,8 +1915,9 @@ var _ Cmder = (*XMessageSliceCmd)(nil) func NewXMessageSliceCmd(ctx context.Context, args ...interface{}) *XMessageSliceCmd { return &XMessageSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + 
ctx: ctx, + args: args, + cmdType: CmdTypeXMessageSlice, }, } } @@ -1626,6 +1943,28 @@ func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *XMessageSliceCmd) Clone() Cmder { + var val []XMessage + if cmd.val != nil { + val = make([]XMessage, len(cmd.val)) + for i, msg := range cmd.val { + val[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Values[k] = v + } + } + } + } + return &XMessageSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { n, err := rd.ReadArrayLen() if err != nil { @@ -1705,8 +2044,9 @@ var _ Cmder = (*XStreamSliceCmd)(nil) func NewXStreamSliceCmd(ctx context.Context, args ...interface{}) *XStreamSliceCmd { return &XStreamSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXStreamSlice, }, } } @@ -1759,6 +2099,36 @@ func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XStreamSliceCmd) Clone() Cmder { + var val []XStream + if cmd.val != nil { + val = make([]XStream, len(cmd.val)) + for i, stream := range cmd.val { + val[i] = XStream{ + Stream: stream.Stream, + } + if stream.Messages != nil { + val[i].Messages = make([]XMessage, len(stream.Messages)) + for j, msg := range stream.Messages { + val[i].Messages[j] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Messages[j].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Messages[j].Values[k] = v + } + } + } + } + } + } + return &XStreamSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XPending struct { @@ -1778,8 +2148,9 @@ var _ Cmder = (*XPendingCmd)(nil) func NewXPendingCmd(ctx context.Context, args ...interface{}) *XPendingCmd { return &XPendingCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXPending, }, } } @@ -1842,6 +2213,27 @@ func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XPendingCmd) Clone() Cmder { + var val *XPending + if cmd.val != nil { + val = &XPending{ + Count: cmd.val.Count, + Lower: cmd.val.Lower, + Higher: cmd.val.Higher, + } + if cmd.val.Consumers != nil { + val.Consumers = make(map[string]int64, len(cmd.val.Consumers)) + for k, v := range cmd.val.Consumers { + val.Consumers[k] = v + } + } + } + return &XPendingCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XPendingExt struct { @@ -1861,8 +2253,9 @@ var _ Cmder = (*XPendingExtCmd)(nil) func NewXPendingExtCmd(ctx context.Context, args ...interface{}) *XPendingExtCmd { return &XPendingExtCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXPendingExt, }, } } @@ -1917,6 +2310,18 @@ func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XPendingExtCmd) Clone() Cmder { + var val []XPendingExt + if cmd.val != nil { + val = make([]XPendingExt, len(cmd.val)) + copy(val, cmd.val) + } + return &XPendingExtCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XAutoClaimCmd struct { @@ -1931,8 +2336,9 @@ var _ Cmder = (*XAutoClaimCmd)(nil) func NewXAutoClaimCmd(ctx 
context.Context, args ...interface{}) *XAutoClaimCmd { return &XAutoClaimCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXAutoClaim, }, } } @@ -1987,6 +2393,29 @@ func (cmd *XAutoClaimCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XAutoClaimCmd) Clone() Cmder { + var val []XMessage + if cmd.val != nil { + val = make([]XMessage, len(cmd.val)) + for i, msg := range cmd.val { + val[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Values[k] = v + } + } + } + } + return &XAutoClaimCmd{ + baseCmd: cmd.cloneBaseCmd(), + start: cmd.start, + val: val, + } +} + //------------------------------------------------------------------------------ type XAutoClaimJustIDCmd struct { @@ -2001,8 +2430,9 @@ var _ Cmder = (*XAutoClaimJustIDCmd)(nil) func NewXAutoClaimJustIDCmd(ctx context.Context, args ...interface{}) *XAutoClaimJustIDCmd { return &XAutoClaimJustIDCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXAutoClaimJustID, }, } } @@ -2065,6 +2495,19 @@ func (cmd *XAutoClaimJustIDCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XAutoClaimJustIDCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &XAutoClaimJustIDCmd{ + baseCmd: cmd.cloneBaseCmd(), + start: cmd.start, + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoConsumersCmd struct { @@ -2084,8 +2527,9 @@ var _ Cmder = (*XInfoConsumersCmd)(nil) func NewXInfoConsumersCmd(ctx context.Context, stream string, group string) *XInfoConsumersCmd { return &XInfoConsumersCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "consumers", stream, group}, + ctx: ctx, + args: []interface{}{"xinfo", "consumers", stream, group}, + cmdType: CmdTypeXInfoConsumers, }, } } @@ -2151,6 +2595,18 @@ func (cmd *XInfoConsumersCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoConsumersCmd) Clone() Cmder { + var val []XInfoConsumer + if cmd.val != nil { + val = make([]XInfoConsumer, len(cmd.val)) + copy(val, cmd.val) + } + return &XInfoConsumersCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoGroupsCmd struct { @@ -2174,8 +2630,9 @@ var _ Cmder = (*XInfoGroupsCmd)(nil) func NewXInfoGroupsCmd(ctx context.Context, stream string) *XInfoGroupsCmd { return &XInfoGroupsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "groups", stream}, + ctx: ctx, + args: []interface{}{"xinfo", "groups", stream}, + cmdType: CmdTypeXInfoGroups, }, } } @@ -2264,6 +2721,18 @@ func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoGroupsCmd) Clone() Cmder { + var val []XInfoGroup + if cmd.val != nil { + val = make([]XInfoGroup, len(cmd.val)) + copy(val, cmd.val) + } + return &XInfoGroupsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoStreamCmd struct { @@ -2289,8 +2758,9 @@ var _ Cmder = (*XInfoStreamCmd)(nil) func NewXInfoStreamCmd(ctx context.Context, stream string) *XInfoStreamCmd { return &XInfoStreamCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "stream", stream}, + ctx: ctx, + args: []interface{}{"xinfo", "stream", stream}, + 
cmdType: CmdTypeXInfoStream, }, } } @@ -2381,6 +2851,45 @@ func (cmd *XInfoStreamCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoStreamCmd) Clone() Cmder { + var val *XInfoStream + if cmd.val != nil { + val = &XInfoStream{ + Length: cmd.val.Length, + RadixTreeKeys: cmd.val.RadixTreeKeys, + RadixTreeNodes: cmd.val.RadixTreeNodes, + Groups: cmd.val.Groups, + LastGeneratedID: cmd.val.LastGeneratedID, + MaxDeletedEntryID: cmd.val.MaxDeletedEntryID, + EntriesAdded: cmd.val.EntriesAdded, + RecordedFirstEntryID: cmd.val.RecordedFirstEntryID, + } + // Clone XMessage fields + val.FirstEntry = XMessage{ + ID: cmd.val.FirstEntry.ID, + } + if cmd.val.FirstEntry.Values != nil { + val.FirstEntry.Values = make(map[string]interface{}, len(cmd.val.FirstEntry.Values)) + for k, v := range cmd.val.FirstEntry.Values { + val.FirstEntry.Values[k] = v + } + } + val.LastEntry = XMessage{ + ID: cmd.val.LastEntry.ID, + } + if cmd.val.LastEntry.Values != nil { + val.LastEntry.Values = make(map[string]interface{}, len(cmd.val.LastEntry.Values)) + for k, v := range cmd.val.LastEntry.Values { + val.LastEntry.Values[k] = v + } + } + } + return &XInfoStreamCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoStreamFullCmd struct { @@ -2436,8 +2945,9 @@ var _ Cmder = (*XInfoStreamFullCmd)(nil) func NewXInfoStreamFullCmd(ctx context.Context, args ...interface{}) *XInfoStreamFullCmd { return &XInfoStreamFullCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXInfoStreamFull, }, } } @@ -2722,6 +3232,45 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { return consumers, nil } +func (cmd *XInfoStreamFullCmd) Clone() Cmder { + var val *XInfoStreamFull + if cmd.val != nil { + val = &XInfoStreamFull{ + Length: cmd.val.Length, + RadixTreeKeys: cmd.val.RadixTreeKeys, + RadixTreeNodes: cmd.val.RadixTreeNodes, + LastGeneratedID: cmd.val.LastGeneratedID, + MaxDeletedEntryID: cmd.val.MaxDeletedEntryID, + EntriesAdded: cmd.val.EntriesAdded, + RecordedFirstEntryID: cmd.val.RecordedFirstEntryID, + } + // Clone Entries + if cmd.val.Entries != nil { + val.Entries = make([]XMessage, len(cmd.val.Entries)) + for i, msg := range cmd.val.Entries { + val.Entries[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val.Entries[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val.Entries[i].Values[k] = v + } + } + } + } + // Clone Groups - simplified copy for now due to complexity + if cmd.val.Groups != nil { + val.Groups = make([]XInfoStreamGroup, len(cmd.val.Groups)) + copy(val.Groups, cmd.val.Groups) + } + } + return &XInfoStreamFullCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ZSliceCmd struct { @@ -2735,8 +3284,9 @@ var _ Cmder = (*ZSliceCmd)(nil) func NewZSliceCmd(ctx context.Context, args ...interface{}) *ZSliceCmd { return &ZSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZSlice, }, } } @@ -2800,6 +3350,18 @@ func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl return nil } +func (cmd *ZSliceCmd) Clone() Cmder { + var val []Z + if cmd.val != nil { + val = make([]Z, len(cmd.val)) + copy(val, cmd.val) + } + return &ZSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + 
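The Clone implementations above copy every backing slice and map, so a cloned command can be retried or re-routed without sharing mutable state with the original. A minimal sketch of that contract, assuming this patch is applied; "localhost:6379" and the key "mylist" are placeholders:

```go
package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

	orig := rdb.LRange(ctx, "mylist", 0, -1)
	// Clone returns a Cmder; the concrete command type is preserved.
	cloned := orig.Clone().(*redis.StringSliceCmd)

	// Args, error state, and the result slice are copied, so mutating
	// one command's value cannot leak into the other.
	fmt.Println(cloned.Args(), cloned.Val(), cloned.Err())
}
```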
//------------------------------------------------------------------------------ type ZWithKeyCmd struct { @@ -2813,8 +3375,9 @@ var _ Cmder = (*ZWithKeyCmd)(nil) func NewZWithKeyCmd(ctx context.Context, args ...interface{}) *ZWithKeyCmd { return &ZWithKeyCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZWithKey, }, } } @@ -2854,6 +3417,23 @@ func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ZWithKeyCmd) Clone() Cmder { + var val *ZWithKey + if cmd.val != nil { + val = &ZWithKey{ + Z: Z{ + Score: cmd.val.Score, + Member: cmd.val.Member, + }, + Key: cmd.val.Key, + } + } + return &ZWithKeyCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ScanCmd struct { @@ -2870,8 +3450,9 @@ var _ Cmder = (*ScanCmd)(nil) func NewScanCmd(ctx context.Context, process cmdable, args ...interface{}) *ScanCmd { return &ScanCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeScan, }, process: process, } @@ -2919,6 +3500,20 @@ func (cmd *ScanCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ScanCmd) Clone() Cmder { + var page []string + if cmd.page != nil { + page = make([]string, len(cmd.page)) + copy(page, cmd.page) + } + return &ScanCmd{ + baseCmd: cmd.cloneBaseCmd(), + page: page, + cursor: cmd.cursor, + process: cmd.process, + } +} + // Iterator creates a new ScanIterator. func (cmd *ScanCmd) Iterator() *ScanIterator { return &ScanIterator{ @@ -2951,8 +3546,9 @@ var _ Cmder = (*ClusterSlotsCmd)(nil) func NewClusterSlotsCmd(ctx context.Context, args ...interface{}) *ClusterSlotsCmd { return &ClusterSlotsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterSlots, }, } } @@ -3065,6 +3661,38 @@ func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterSlotsCmd) Clone() Cmder { + var val []ClusterSlot + if cmd.val != nil { + val = make([]ClusterSlot, len(cmd.val)) + for i, slot := range cmd.val { + val[i] = ClusterSlot{ + Start: slot.Start, + End: slot.End, + } + if slot.Nodes != nil { + val[i].Nodes = make([]ClusterNode, len(slot.Nodes)) + for j, node := range slot.Nodes { + val[i].Nodes[j] = ClusterNode{ + ID: node.ID, + Addr: node.Addr, + } + if node.NetworkingMetadata != nil { + val[i].Nodes[j].NetworkingMetadata = make(map[string]string, len(node.NetworkingMetadata)) + for k, v := range node.NetworkingMetadata { + val[i].Nodes[j].NetworkingMetadata[k] = v + } + } + } + } + } + } + return &ClusterSlotsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // GeoLocation is used with GeoAdd to add geospatial location. 
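Because every constructor now stamps cmdType at creation, routing code can branch on GetCmdType() instead of a reflection-heavy type switch over concrete command types. A sketch of that pattern, assuming the exported CmdType aliases added earlier in this patch; valueOf is a hypothetical helper, not part of the patch:

```go
// valueOf extracts a command's typed value by branching on the enum
// rather than on the concrete Go type. Unhandled types report false.
func valueOf(cmd redis.Cmder) (interface{}, bool) {
	switch cmd.GetCmdType() {
	case redis.CmdTypeInt:
		return cmd.(*redis.IntCmd).Val(), true
	case redis.CmdTypeString:
		return cmd.(*redis.StringCmd).Val(), true
	case redis.CmdTypeStringSlice:
		return cmd.(*redis.StringSliceCmd).Val(), true
	default:
		// Fall back to slower, generic handling for other types.
		return nil, false
	}
}
```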
@@ -3104,8 +3732,9 @@ var _ Cmder = (*GeoLocationCmd)(nil) func NewGeoLocationCmd(ctx context.Context, q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd { return &GeoLocationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: geoLocationArgs(q, args...), + ctx: ctx, + args: geoLocationArgs(q, args...), + cmdType: CmdTypeGeoLocation, }, q: q, } @@ -3213,6 +3842,34 @@ func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoLocationCmd) Clone() Cmder { + var q *GeoRadiusQuery + if cmd.q != nil { + q = &GeoRadiusQuery{ + Radius: cmd.q.Radius, + Unit: cmd.q.Unit, + WithCoord: cmd.q.WithCoord, + WithDist: cmd.q.WithDist, + WithGeoHash: cmd.q.WithGeoHash, + Count: cmd.q.Count, + Sort: cmd.q.Sort, + Store: cmd.q.Store, + StoreDist: cmd.q.StoreDist, + withLen: cmd.q.withLen, + } + } + var locations []GeoLocation + if cmd.locations != nil { + locations = make([]GeoLocation, len(cmd.locations)) + copy(locations, cmd.locations) + } + return &GeoLocationCmd{ + baseCmd: cmd.cloneBaseCmd(), + q: q, + locations: locations, + } +} + //------------------------------------------------------------------------------ // GeoSearchQuery is used for GEOSearch/GEOSearchStore command query. @@ -3320,8 +3977,9 @@ func NewGeoSearchLocationCmd( ) *GeoSearchLocationCmd { return &GeoSearchLocationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: geoSearchLocationArgs(opt, args), + cmdType: CmdTypeGeoSearchLocation, }, opt: opt, } @@ -3394,6 +4052,40 @@ func (cmd *GeoSearchLocationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoSearchLocationCmd) Clone() Cmder { + var opt *GeoSearchLocationQuery + if cmd.opt != nil { + opt = &GeoSearchLocationQuery{ + GeoSearchQuery: GeoSearchQuery{ + Member: cmd.opt.Member, + Longitude: cmd.opt.Longitude, + Latitude: cmd.opt.Latitude, + Radius: cmd.opt.Radius, + RadiusUnit: cmd.opt.RadiusUnit, + BoxWidth: cmd.opt.BoxWidth, + BoxHeight: cmd.opt.BoxHeight, + BoxUnit: cmd.opt.BoxUnit, + Sort: cmd.opt.Sort, + Count: cmd.opt.Count, + CountAny: cmd.opt.CountAny, + }, + WithCoord: cmd.opt.WithCoord, + WithDist: cmd.opt.WithDist, + WithHash: cmd.opt.WithHash, + } + } + var val []GeoLocation + if cmd.val != nil { + val = make([]GeoLocation, len(cmd.val)) + copy(val, cmd.val) + } + return &GeoSearchLocationCmd{ + baseCmd: cmd.cloneBaseCmd(), + opt: opt, + val: val, + } +} + //------------------------------------------------------------------------------ type GeoPos struct { @@ -3411,8 +4103,9 @@ var _ Cmder = (*GeoPosCmd)(nil) func NewGeoPosCmd(ctx context.Context, args ...interface{}) *GeoPosCmd { return &GeoPosCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeGeoPos, }, } } @@ -3468,6 +4161,25 @@ func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoPosCmd) Clone() Cmder { + var val []*GeoPos + if cmd.val != nil { + val = make([]*GeoPos, len(cmd.val)) + for i, pos := range cmd.val { + if pos != nil { + val[i] = &GeoPos{ + Longitude: pos.Longitude, + Latitude: pos.Latitude, + } + } + } + } + return &GeoPosCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type CommandInfo struct { @@ -3493,8 +4205,9 @@ var _ Cmder = (*CommandsInfoCmd)(nil) func NewCommandsInfoCmd(ctx context.Context, args ...interface{}) *CommandsInfoCmd { return &CommandsInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCommandsInfo, }, } } @@ -3648,6 
+4361,39 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *CommandsInfoCmd) Clone() Cmder { + var val map[string]*CommandInfo + if cmd.val != nil { + val = make(map[string]*CommandInfo, len(cmd.val)) + for k, v := range cmd.val { + if v != nil { + newInfo := &CommandInfo{ + Name: v.Name, + Arity: v.Arity, + FirstKeyPos: v.FirstKeyPos, + LastKeyPos: v.LastKeyPos, + StepCount: v.StepCount, + ReadOnly: v.ReadOnly, + Tips: v.Tips, // CommandPolicy can be shared as it's immutable + } + if v.Flags != nil { + newInfo.Flags = make([]string, len(v.Flags)) + copy(newInfo.Flags, v.Flags) + } + if v.ACLFlags != nil { + newInfo.ACLFlags = make([]string, len(v.ACLFlags)) + copy(newInfo.ACLFlags, v.ACLFlags) + } + val[k] = newInfo + } + } + } + return &CommandsInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type cmdsInfoCache struct { @@ -3737,8 +4483,9 @@ var _ Cmder = (*SlowLogCmd)(nil) func NewSlowLogCmd(ctx context.Context, args ...interface{}) *SlowLogCmd { return &SlowLogCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeSlowLog, }, } } @@ -3823,6 +4570,30 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *SlowLogCmd) Clone() Cmder { + var val []SlowLog + if cmd.val != nil { + val = make([]SlowLog, len(cmd.val)) + for i, log := range cmd.val { + val[i] = SlowLog{ + ID: log.ID, + Time: log.Time, + Duration: log.Duration, + ClientAddr: log.ClientAddr, + ClientName: log.ClientName, + } + if log.Args != nil { + val[i].Args = make([]string, len(log.Args)) + copy(val[i].Args, log.Args) + } + } + } + return &SlowLogCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringInterfaceCmd struct { @@ -3836,8 +4607,9 @@ var _ Cmder = (*MapStringInterfaceCmd)(nil) func NewMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceCmd { return &MapStringInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterface, }, } } @@ -3887,6 +4659,20 @@ func (cmd *MapStringInterfaceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringInterfaceCmd) Clone() Cmder { + var val map[string]interface{} + if cmd.val != nil { + val = make(map[string]interface{}, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringStringSliceCmd struct { @@ -3900,8 +4686,9 @@ var _ Cmder = (*MapStringStringSliceCmd)(nil) func NewMapStringStringSliceCmd(ctx context.Context, args ...interface{}) *MapStringStringSliceCmd { return &MapStringStringSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringStringSlice, }, } } @@ -3951,6 +4738,25 @@ func (cmd *MapStringStringSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringStringSliceCmd) Clone() Cmder { + var val []map[string]string + if cmd.val != nil { + val = make([]map[string]string, len(cmd.val)) + for i, m := range cmd.val { + if m != nil { + val[i] = make(map[string]string, len(m)) + for k, v := range m { + val[i][k] = v + } + } + } + } + return &MapStringStringSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // 
----------------------------------------------------------------------- // MapMapStringInterfaceCmd represents a command that returns a map of strings to interface{}. @@ -3962,8 +4768,9 @@ type MapMapStringInterfaceCmd struct { func NewMapMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapMapStringInterfaceCmd { return &MapMapStringInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapMapStringInterface, }, } } @@ -4029,6 +4836,20 @@ func (cmd *MapMapStringInterfaceCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *MapMapStringInterfaceCmd) Clone() Cmder { + var val map[string]interface{} + if cmd.val != nil { + val = make(map[string]interface{}, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapMapStringInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringInterfaceSliceCmd struct { @@ -4042,8 +4863,9 @@ var _ Cmder = (*MapStringInterfaceSliceCmd)(nil) func NewMapStringInterfaceSliceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceSliceCmd { return &MapStringInterfaceSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterfaceSlice, }, } } @@ -4094,6 +4916,25 @@ func (cmd *MapStringInterfaceSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringInterfaceSliceCmd) Clone() Cmder { + var val []map[string]interface{} + if cmd.val != nil { + val = make([]map[string]interface{}, len(cmd.val)) + for i, m := range cmd.val { + if m != nil { + val[i] = make(map[string]interface{}, len(m)) + for k, v := range m { + val[i][k] = v + } + } + } + } + return &MapStringInterfaceSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type KeyValuesCmd struct { @@ -4108,8 +4949,9 @@ var _ Cmder = (*KeyValuesCmd)(nil) func NewKeyValuesCmd(ctx context.Context, args ...interface{}) *KeyValuesCmd { return &KeyValuesCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyValues, }, } } @@ -4156,6 +4998,19 @@ func (cmd *KeyValuesCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *KeyValuesCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &KeyValuesCmd{ + baseCmd: cmd.cloneBaseCmd(), + key: cmd.key, + val: val, + } +} + //------------------------------------------------------------------------------ type ZSliceWithKeyCmd struct { @@ -4170,8 +5025,9 @@ var _ Cmder = (*ZSliceWithKeyCmd)(nil) func NewZSliceWithKeyCmd(ctx context.Context, args ...interface{}) *ZSliceWithKeyCmd { return &ZSliceWithKeyCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZSliceWithKey, }, } } @@ -4239,6 +5095,19 @@ func (cmd *ZSliceWithKeyCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ZSliceWithKeyCmd) Clone() Cmder { + var val []Z + if cmd.val != nil { + val = make([]Z, len(cmd.val)) + copy(val, cmd.val) + } + return &ZSliceWithKeyCmd{ + baseCmd: cmd.cloneBaseCmd(), + key: cmd.key, + val: val, + } +} + type Function struct { Name string Description string @@ -4263,8 +5132,9 @@ var _ Cmder = (*FunctionListCmd)(nil) func NewFunctionListCmd(ctx context.Context, args ...interface{}) *FunctionListCmd { return &FunctionListCmd{ baseCmd: baseCmd{ 
- ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFunctionList, }, } } @@ -4391,6 +5261,37 @@ func (cmd *FunctionListCmd) readFunctions(rd *proto.Reader) ([]Function, error) return functions, nil } +func (cmd *FunctionListCmd) Clone() Cmder { + var val []Library + if cmd.val != nil { + val = make([]Library, len(cmd.val)) + for i, lib := range cmd.val { + val[i] = Library{ + Name: lib.Name, + Engine: lib.Engine, + Code: lib.Code, + } + if lib.Functions != nil { + val[i].Functions = make([]Function, len(lib.Functions)) + for j, fn := range lib.Functions { + val[i].Functions[j] = Function{ + Name: fn.Name, + Description: fn.Description, + } + if fn.Flags != nil { + val[i].Functions[j].Flags = make([]string, len(fn.Flags)) + copy(val[i].Functions[j].Flags, fn.Flags) + } + } + } + } + } + return &FunctionListCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // FunctionStats contains information about the scripts currently executing on the server, and the available engines // - Engines: // Statistics about the engine like number of functions and number of libraries @@ -4444,8 +5345,9 @@ var _ Cmder = (*FunctionStatsCmd)(nil) func NewFunctionStatsCmd(ctx context.Context, args ...interface{}) *FunctionStatsCmd { return &FunctionStatsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFunctionStats, }, } } @@ -4616,6 +5518,34 @@ func (cmd *FunctionStatsCmd) readRunningScripts(rd *proto.Reader) ([]RunningScri return runningScripts, len(runningScripts) > 0, nil } +func (cmd *FunctionStatsCmd) Clone() Cmder { + val := FunctionStats{ + isRunning: cmd.val.isRunning, + rs: cmd.val.rs, // RunningScript is a simple struct, can be copied directly + } + if cmd.val.Engines != nil { + val.Engines = make([]Engine, len(cmd.val.Engines)) + copy(val.Engines, cmd.val.Engines) + } + if cmd.val.allrs != nil { + val.allrs = make([]RunningScript, len(cmd.val.allrs)) + for i, rs := range cmd.val.allrs { + val.allrs[i] = RunningScript{ + Name: rs.Name, + Duration: rs.Duration, + } + if rs.Command != nil { + val.allrs[i].Command = make([]string, len(rs.Command)) + copy(val.allrs[i].Command, rs.Command) + } + } + } + return &FunctionStatsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // LCSQuery is a parameter used for the LCS command @@ -4679,8 +5609,9 @@ func NewLCSCmd(ctx context.Context, q *LCSQuery) *LCSCmd { } } cmd.baseCmd = baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeLCS, } return cmd @@ -4792,6 +5723,25 @@ func (cmd *LCSCmd) readPosition(rd *proto.Reader) (pos LCSPosition, err error) { return pos, nil } +func (cmd *LCSCmd) Clone() Cmder { + var val *LCSMatch + if cmd.val != nil { + val = &LCSMatch{ + MatchString: cmd.val.MatchString, + Len: cmd.val.Len, + } + if cmd.val.Matches != nil { + val.Matches = make([]LCSMatchedPosition, len(cmd.val.Matches)) + copy(val.Matches, cmd.val.Matches) + } + } + return &LCSCmd{ + baseCmd: cmd.cloneBaseCmd(), + readType: cmd.readType, + val: val, + } +} + // ------------------------------------------------------------------------ type KeyFlags struct { @@ -4810,8 +5760,9 @@ var _ Cmder = (*KeyFlagsCmd)(nil) func NewKeyFlagsCmd(ctx context.Context, args ...interface{}) *KeyFlagsCmd { return &KeyFlagsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyFlags, }, } } @@ -4870,6 +5821,26 @@ func (cmd *KeyFlagsCmd) readReply(rd *proto.Reader) error { 
return nil } +func (cmd *KeyFlagsCmd) Clone() Cmder { + var val []KeyFlags + if cmd.val != nil { + val = make([]KeyFlags, len(cmd.val)) + for i, kf := range cmd.val { + val[i] = KeyFlags{ + Key: kf.Key, + } + if kf.Flags != nil { + val[i].Flags = make([]string, len(kf.Flags)) + copy(val[i].Flags, kf.Flags) + } + } + } + return &KeyFlagsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // --------------------------------------------------------------------------------------------------- type ClusterLink struct { @@ -4892,8 +5863,9 @@ var _ Cmder = (*ClusterLinksCmd)(nil) func NewClusterLinksCmd(ctx context.Context, args ...interface{}) *ClusterLinksCmd { return &ClusterLinksCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterLinks, }, } } @@ -4959,6 +5931,18 @@ func (cmd *ClusterLinksCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterLinksCmd) Clone() Cmder { + var val []ClusterLink + if cmd.val != nil { + val = make([]ClusterLink, len(cmd.val)) + copy(val, cmd.val) + } + return &ClusterLinksCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------------------------------------------------------------------------------ type SlotRange struct { @@ -4994,8 +5978,9 @@ var _ Cmder = (*ClusterShardsCmd)(nil) func NewClusterShardsCmd(ctx context.Context, args ...interface{}) *ClusterShardsCmd { return &ClusterShardsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterShards, }, } } @@ -5109,6 +6094,28 @@ func (cmd *ClusterShardsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterShardsCmd) Clone() Cmder { + var val []ClusterShard + if cmd.val != nil { + val = make([]ClusterShard, len(cmd.val)) + for i, shard := range cmd.val { + val[i] = ClusterShard{} + if shard.Slots != nil { + val[i].Slots = make([]SlotRange, len(shard.Slots)) + copy(val[i].Slots, shard.Slots) + } + if shard.Nodes != nil { + val[i].Nodes = make([]Node, len(shard.Nodes)) + copy(val[i].Nodes, shard.Nodes) + } + } + } + return &ClusterShardsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ----------------------------------------- type RankScore struct { @@ -5127,8 +6134,9 @@ var _ Cmder = (*RankWithScoreCmd)(nil) func NewRankWithScoreCmd(ctx context.Context, args ...interface{}) *RankWithScoreCmd { return &RankWithScoreCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeRankWithScore, }, } } @@ -5169,6 +6177,13 @@ func (cmd *RankWithScoreCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *RankWithScoreCmd) Clone() Cmder { + return &RankWithScoreCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // RankScore is a simple struct, can be copied directly + } +} + // -------------------------------------------------------------------------------------------------- // ClientFlags is redis-server client flags, copy from redis/src/server.h (redis 7.0) @@ -5278,8 +6293,9 @@ var _ Cmder = (*ClientInfoCmd)(nil) func NewClientInfoCmd(ctx context.Context, args ...interface{}) *ClientInfoCmd { return &ClientInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClientInfo, }, } } @@ -5456,6 +6472,50 @@ func parseClientInfo(txt string) (info *ClientInfo, err error) { return info, nil } +func (cmd *ClientInfoCmd) Clone() Cmder { + var val *ClientInfo + if cmd.val != nil { + val = &ClientInfo{ + ID: cmd.val.ID, + Addr: cmd.val.Addr, + LAddr: 
cmd.val.LAddr, + FD: cmd.val.FD, + Name: cmd.val.Name, + Age: cmd.val.Age, + Idle: cmd.val.Idle, + Flags: cmd.val.Flags, + DB: cmd.val.DB, + Sub: cmd.val.Sub, + PSub: cmd.val.PSub, + SSub: cmd.val.SSub, + Multi: cmd.val.Multi, + Watch: cmd.val.Watch, + QueryBuf: cmd.val.QueryBuf, + QueryBufFree: cmd.val.QueryBufFree, + ArgvMem: cmd.val.ArgvMem, + MultiMem: cmd.val.MultiMem, + BufferSize: cmd.val.BufferSize, + BufferPeak: cmd.val.BufferPeak, + OutputBufferLength: cmd.val.OutputBufferLength, + OutputListLength: cmd.val.OutputListLength, + OutputMemory: cmd.val.OutputMemory, + TotalMemory: cmd.val.TotalMemory, + IoThread: cmd.val.IoThread, + Events: cmd.val.Events, + LastCmd: cmd.val.LastCmd, + User: cmd.val.User, + Redir: cmd.val.Redir, + Resp: cmd.val.Resp, + LibName: cmd.val.LibName, + LibVer: cmd.val.LibVer, + } + } + return &ClientInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------- type ACLLogEntry struct { @@ -5482,8 +6542,9 @@ var _ Cmder = (*ACLLogCmd)(nil) func NewACLLogCmd(ctx context.Context, args ...interface{}) *ACLLogCmd { return &ACLLogCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeACLLog, }, } } @@ -5565,6 +6626,69 @@ func (cmd *ACLLogCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ACLLogCmd) Clone() Cmder { + var val []*ACLLogEntry + if cmd.val != nil { + val = make([]*ACLLogEntry, len(cmd.val)) + for i, entry := range cmd.val { + if entry != nil { + val[i] = &ACLLogEntry{ + Count: entry.Count, + Reason: entry.Reason, + Context: entry.Context, + Object: entry.Object, + Username: entry.Username, + AgeSeconds: entry.AgeSeconds, + EntryID: entry.EntryID, + TimestampCreated: entry.TimestampCreated, + TimestampLastUpdated: entry.TimestampLastUpdated, + } + // Clone ClientInfo if present + if entry.ClientInfo != nil { + val[i].ClientInfo = &ClientInfo{ + ID: entry.ClientInfo.ID, + Addr: entry.ClientInfo.Addr, + LAddr: entry.ClientInfo.LAddr, + FD: entry.ClientInfo.FD, + Name: entry.ClientInfo.Name, + Age: entry.ClientInfo.Age, + Idle: entry.ClientInfo.Idle, + Flags: entry.ClientInfo.Flags, + DB: entry.ClientInfo.DB, + Sub: entry.ClientInfo.Sub, + PSub: entry.ClientInfo.PSub, + SSub: entry.ClientInfo.SSub, + Multi: entry.ClientInfo.Multi, + Watch: entry.ClientInfo.Watch, + QueryBuf: entry.ClientInfo.QueryBuf, + QueryBufFree: entry.ClientInfo.QueryBufFree, + ArgvMem: entry.ClientInfo.ArgvMem, + MultiMem: entry.ClientInfo.MultiMem, + BufferSize: entry.ClientInfo.BufferSize, + BufferPeak: entry.ClientInfo.BufferPeak, + OutputBufferLength: entry.ClientInfo.OutputBufferLength, + OutputListLength: entry.ClientInfo.OutputListLength, + OutputMemory: entry.ClientInfo.OutputMemory, + TotalMemory: entry.ClientInfo.TotalMemory, + IoThread: entry.ClientInfo.IoThread, + Events: entry.ClientInfo.Events, + LastCmd: entry.ClientInfo.LastCmd, + User: entry.ClientInfo.User, + Redir: entry.ClientInfo.Redir, + Resp: entry.ClientInfo.Resp, + LibName: entry.ClientInfo.LibName, + LibVer: entry.ClientInfo.LibVer, + } + } + } + } + } + return &ACLLogCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // LibraryInfo holds the library info. 
type LibraryInfo struct { LibName *string @@ -5593,8 +6717,9 @@ var _ Cmder = (*InfoCmd)(nil) func NewInfoCmd(ctx context.Context, args ...interface{}) *InfoCmd { return &InfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeInfo, }, } } @@ -5660,6 +6785,25 @@ func (cmd *InfoCmd) Item(section, key string) string { } } +func (cmd *InfoCmd) Clone() Cmder { + var val map[string]map[string]string + if cmd.val != nil { + val = make(map[string]map[string]string, len(cmd.val)) + for section, sectionMap := range cmd.val { + if sectionMap != nil { + val[section] = make(map[string]string, len(sectionMap)) + for k, v := range sectionMap { + val[section][k] = v + } + } + } + } + return &InfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + type MonitorStatus int const ( @@ -5678,8 +6822,9 @@ type MonitorCmd struct { func newMonitorCmd(ctx context.Context, ch chan string) *MonitorCmd { return &MonitorCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"monitor"}, + ctx: ctx, + args: []interface{}{"monitor"}, + cmdType: CmdTypeMonitor, }, ch: ch, status: monitorStatusIdle, @@ -5800,3 +6945,9 @@ func (cmd *VectorScoreSliceCmd) readReply(rd *proto.Reader) error { } return nil } + +func (cmd *MonitorCmd) Clone() Cmder { + // MonitorCmd cannot be safely cloned due to channels and goroutines + // Return a new MonitorCmd with the same channel + return newMonitorCmd(cmd.ctx, cmd.ch) +} diff --git a/go.mod b/go.mod index 3bbb8ac4d8..9da8e58a64 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/bsm/gomega v1.27.10 github.com/cespare/xxhash/v2 v2.3.0 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f + github.com/fortytw2/leaktest v1.3.0 ) retract ( diff --git a/go.sum b/go.sum index 4db68f6d4f..a60f6d5880 100644 --- a/go.sum +++ b/go.sum @@ -6,3 +6,5 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go new file mode 100644 index 0000000000..f065415f60 --- /dev/null +++ b/internal/routing/aggregator.go @@ -0,0 +1,933 @@ +package routing + +import ( + "fmt" + "math" + "sync" +) + +type CmdTyper interface { + GetCmdType() CmdType +} + +type CmdType uint8 + +const ( + CmdTypeGeneric CmdType = iota + CmdTypeString + CmdTypeInt + CmdTypeBool + CmdTypeFloat + CmdTypeStringSlice + CmdTypeIntSlice + CmdTypeFloatSlice + CmdTypeBoolSlice + CmdTypeMapStringString + CmdTypeMapStringInt + CmdTypeMapStringInterface + CmdTypeMapStringInterfaceSlice + CmdTypeSlice + CmdTypeStatus + CmdTypeDuration + CmdTypeTime + CmdTypeKeyValueSlice + CmdTypeStringStructMap + CmdTypeXMessageSlice + CmdTypeXStreamSlice + CmdTypeXPending + CmdTypeXPendingExt + CmdTypeXAutoClaim + CmdTypeXAutoClaimJustID + CmdTypeXInfoConsumers + CmdTypeXInfoGroups + CmdTypeXInfoStream + CmdTypeXInfoStreamFull + CmdTypeZSlice + CmdTypeZWithKey + CmdTypeScan + CmdTypeClusterSlots + CmdTypeGeoLocation + CmdTypeGeoSearchLocation + CmdTypeGeoPos + CmdTypeCommandsInfo + 
CmdTypeSlowLog + CmdTypeMapStringStringSlice + CmdTypeMapMapStringInterface + CmdTypeKeyValues + CmdTypeZSliceWithKey + CmdTypeFunctionList + CmdTypeFunctionStats + CmdTypeLCS + CmdTypeKeyFlags + CmdTypeClusterLinks + CmdTypeClusterShards + CmdTypeRankWithScore + CmdTypeClientInfo + CmdTypeACLLog + CmdTypeInfo + CmdTypeMonitor + CmdTypeJSON + CmdTypeJSONSlice + CmdTypeIntPointerSlice + CmdTypeScanDump + CmdTypeBFInfo + CmdTypeCFInfo + CmdTypeCMSInfo + CmdTypeTopKInfo + CmdTypeTDigestInfo + CmdTypeFTSynDump + CmdTypeAggregate + CmdTypeFTInfo + CmdTypeFTSpellCheck + CmdTypeFTSearch + CmdTypeTSTimestampValue + CmdTypeTSTimestampValueSlice +) + +// ResponseAggregator defines the interface for aggregating responses from multiple shards. +type ResponseAggregator interface { + // Add processes a single shard response. + Add(result interface{}, err error) error + + // AddWithKey processes a single shard response for a specific key (used by keyed aggregators). + AddWithKey(key string, result interface{}, err error) error + + // Finish returns the final aggregated result and any error. + Finish() (interface{}, error) +} + +// NewResponseAggregator creates an aggregator based on the response policy. +func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggregator { + switch policy { + case RespDefaultKeyless: + return &DefaultKeylessAggregator{} + case RespDefaultHashSlot: + return &DefaultKeyedAggregator{} + case RespAllSucceeded: + return &AllSucceededAggregator{} + case RespOneSucceeded: + return &OneSucceededAggregator{} + case RespAggSum: + return &AggSumAggregator{} + case RespAggMin: + return &AggMinAggregator{} + case RespAggMax: + return &AggMaxAggregator{} + case RespAggLogicalAnd: + return &AggLogicalAndAggregator{} + case RespAggLogicalOr: + return &AggLogicalOrAggregator{} + case RespSpecial: + return NewSpecialAggregator(cmdName) + default: + return &AllSucceededAggregator{} + } +} + +func NewDefaultAggregator(isKeyed bool) ResponseAggregator { + if isKeyed { + return &DefaultKeyedAggregator{ + results: make(map[string]interface{}), + } + } + return &DefaultKeylessAggregator{} +} + +// AllSucceededAggregator returns one non-error reply if every shard succeeded, +// propagates the first error otherwise. +type AllSucceededAggregator struct { + mu sync.Mutex + result interface{} + firstErr error + hasResult bool +} + +func (a *AllSucceededAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil && !a.hasResult { + a.result = result + a.hasResult = true + } + return nil +} + +func (a *AllSucceededAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AllSucceededAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.result, nil +} + +// OneSucceededAggregator returns the first non-error reply, +// if all shards errored, returns any one of those errors. 
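In other words, this aggregator surfaces an error only when no shard produced a success. A sketch of driving it by hand, written as if inside package routing (imports of errors and fmt assumed; RespOneSucceeded is the policy constant referenced by the NewResponseAggregator factory above):

```go
// exampleOneSucceeded shows the "first success wins" behavior.
func exampleOneSucceeded() {
	agg := NewResponseAggregator(RespOneSucceeded, "get")
	_ = agg.Add(nil, errors.New("shard 1 down")) // kept as firstErr
	_ = agg.Add("value-from-shard-2", nil)       // first success wins
	res, err := agg.Finish()
	fmt.Println(res, err) // value-from-shard-2 <nil>
}
```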
+type OneSucceededAggregator struct { + mu sync.Mutex + result interface{} + firstErr error + hasResult bool +} + +func (a *OneSucceededAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil && !a.hasResult { + a.result = result + a.hasResult = true + } + return nil +} + +func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *OneSucceededAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.hasResult { + return a.result, nil + } + return nil, a.firstErr +} + +// AggSumAggregator sums numeric replies from all shards. +type AggSumAggregator struct { + mu sync.Mutex + sum int64 + hasResult bool + firstErr error +} + +func (a *AggSumAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toInt64(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.sum += val + a.hasResult = true + } + } + return nil +} + +func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggSumAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.sum, nil +} + +// AggMinAggregator returns the minimum numeric value from all shards. +type AggMinAggregator struct { + mu sync.Mutex + min int64 + hasResult bool + firstErr error +} + +func (a *AggMinAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toInt64(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult || val < a.min { + a.min = val + a.hasResult = true + } + } + } + return nil +} + +func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggMinAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.min, nil +} + +// AggMaxAggregator returns the maximum numeric value from all shards. +type AggMaxAggregator struct { + mu sync.Mutex + max int64 + hasResult bool + firstErr error +} + +func (a *AggMaxAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toInt64(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult || val > a.max { + a.max = val + a.hasResult = true + } + } + } + return nil +} + +func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggMaxAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.max, nil +} + +// AggLogicalAndAggregator performs logical AND on boolean values. 
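This mirrors fan-out commands such as SCRIPT EXISTS, where a script counts as present only if every shard reports true. A sketch, as if inside package routing; the command name string is purely illustrative:

```go
// exampleLogicalAnd: one false shard makes the aggregate false.
func exampleLogicalAnd() {
	agg := NewResponseAggregator(RespAggLogicalAnd, "script|exists")
	_ = agg.Add(true, nil)
	_ = agg.Add(false, nil) // this shard does not have the script
	res, _ := agg.Finish()
	fmt.Println(res) // false
}
```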
+type AggLogicalAndAggregator struct { + mu sync.Mutex + result bool + hasResult bool + firstErr error +} + +func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toBool(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult { + a.result = val + a.hasResult = true + } else { + a.result = a.result && val + } + } + } + return nil +} + +func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggLogicalAndAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.result, nil +} + +// AggLogicalOrAggregator performs logical OR on boolean values. +type AggLogicalOrAggregator struct { + mu sync.Mutex + result bool + hasResult bool + firstErr error +} + +func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toBool(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult { + a.result = val + a.hasResult = true + } else { + a.result = a.result || val + } + } + } + return nil +} + +func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggLogicalOrAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.result, nil +} + +func toInt64(val interface{}) (int64, error) { + switch v := val.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case int32: + return int64(v), nil + case float64: + if v != math.Trunc(v) { + return 0, fmt.Errorf("cannot convert float %f to int64", v) + } + return int64(v), nil + default: + return 0, fmt.Errorf("cannot convert %T to int64", val) + } +} + +func toBool(val interface{}) (bool, error) { + switch v := val.(type) { + case bool: + return v, nil + case int64: + return v != 0, nil + case int: + return v != 0, nil + default: + return false, fmt.Errorf("cannot convert %T to bool", val) + } +} + +// DefaultKeylessAggregator collects all results in an array, order doesn't matter. +type DefaultKeylessAggregator struct { + mu sync.Mutex + results []interface{} + firstErr error +} + +func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.results = append(a.results, result) + } + return nil +} + +func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *DefaultKeylessAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.results, nil +} + +// DefaultKeyedAggregator reassembles replies in the exact key order of the original request. 
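Note that the NewResponseAggregator factory above builds this type as a zero value, so its results map starts nil; the NewDefaultKeyedAggregator constructor below initializes it, which is why this sketch (as if inside package routing) goes through the constructor. Shard replies for an MGET-style fan-out can arrive in any order and are still emitted in the caller's key order:

```go
// exampleKeyed reassembles out-of-order shard replies by key.
func exampleKeyed() {
	agg := NewDefaultKeyedAggregator([]string{"a", "b", "c"})
	_ = agg.AddWithKey("c", "3", nil) // replies arrive out of order
	_ = agg.AddWithKey("a", "1", nil)
	_ = agg.AddWithKey("b", "2", nil)
	res, _ := agg.Finish()
	fmt.Println(res) // [1 2 3], matching the requested key order
}
```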
+type DefaultKeyedAggregator struct { + mu sync.Mutex + results map[string]interface{} + keyOrder []string + firstErr error +} + +func NewDefaultKeyedAggregator(keyOrder []string) *DefaultKeyedAggregator { + return &DefaultKeyedAggregator{ + results: make(map[string]interface{}), + keyOrder: keyOrder, + } +} + +func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + // For non-keyed Add, just collect the result without ordering + if err == nil { + a.results["__default__"] = result + } + return nil +} + +func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.results[key] = result + } + return nil +} + +func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { + a.mu.Lock() + defer a.mu.Unlock() + a.keyOrder = keyOrder +} + +func (a *DefaultKeyedAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + + // If no explicit key order is set, return results in any order + if len(a.keyOrder) == 0 { + orderedResults := make([]interface{}, 0, len(a.results)) + for _, result := range a.results { + orderedResults = append(orderedResults, result) + } + return orderedResults, nil + } + + // Return results in the exact key order + orderedResults := make([]interface{}, len(a.keyOrder)) + for i, key := range a.keyOrder { + if result, exists := a.results[key]; exists { + orderedResults[i] = result + } + } + return orderedResults, nil +} + +// SpecialAggregator provides a registry for command-specific aggregation logic. +type SpecialAggregator struct { + mu sync.Mutex + aggregatorFunc func([]interface{}, []error) (interface{}, error) + results []interface{} + errors []error +} + +func (a *SpecialAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + a.results = append(a.results, result) + a.errors = append(a.errors, err) + return nil +} + +func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *SpecialAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.aggregatorFunc != nil { + return a.aggregatorFunc(a.results, a.errors) + } + // Default behavior: return first non-error result or first error + for i, err := range a.errors { + if err == nil { + return a.results[i], nil + } + } + if len(a.errors) > 0 { + return nil, a.errors[0] + } + return nil, nil +} + +// SetAggregatorFunc allows setting custom aggregation logic for special commands. +func (a *SpecialAggregator) SetAggregatorFunc(fn func([]interface{}, []error) (interface{}, error)) { + a.mu.Lock() + defer a.mu.Unlock() + a.aggregatorFunc = fn +} + +// SpecialAggregatorRegistry holds custom aggregation functions for specific commands. +var SpecialAggregatorRegistry = make(map[string]func([]interface{}, []error) (interface{}, error)) + +// RegisterSpecialAggregator registers a custom aggregation function for a command. +func RegisterSpecialAggregator(cmdName string, fn func([]interface{}, []error) (interface{}, error)) { + SpecialAggregatorRegistry[cmdName] = fn +} + +// NewSpecialAggregator creates a special aggregator with command-specific logic if available. 
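+// A custom merge function can be registered ahead of time. A minimal sketch,
+// using a hypothetical command name:
+//
+//	RegisterSpecialAggregator("example.count", func(results []interface{}, errs []error) (interface{}, error) {
+//		var total int64
+//		for i, res := range results {
+//			if errs[i] != nil {
+//				return nil, errs[i]
+//			}
+//			if n, ok := res.(int64); ok {
+//				total += n
+//			}
+//		}
+//		return total, nil
+//	})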
+func NewSpecialAggregator(cmdName string) *SpecialAggregator { + agg := &SpecialAggregator{} + if fn, exists := SpecialAggregatorRegistry[cmdName]; exists { + agg.SetAggregatorFunc(fn) + } + return agg +} + +// CmdTypeGetter interface for getting command type without circular imports +type CmdTypeGetter interface { + GetCmdType() CmdType +} + +// ExtractCommandValue extracts the value from a command result using the fast enum-based approach +func ExtractCommandValue(cmd interface{}) interface{} { + // First try to get the command type using the interface + if cmdTypeGetter, ok := cmd.(CmdTypeGetter); ok { + cmdType := cmdTypeGetter.GetCmdType() + + // Use fast type-based extraction + switch cmdType { + case CmdTypeString: + if stringCmd, ok := cmd.(interface{ Val() string }); ok { + return stringCmd.Val() + } + case CmdTypeInt: + if intCmd, ok := cmd.(interface{ Val() int64 }); ok { + return intCmd.Val() + } + case CmdTypeBool: + if boolCmd, ok := cmd.(interface{ Val() bool }); ok { + return boolCmd.Val() + } + case CmdTypeFloat: + if floatCmd, ok := cmd.(interface{ Val() float64 }); ok { + return floatCmd.Val() + } + case CmdTypeDuration: + if durationCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return durationCmd.Val() + } + case CmdTypeTime: + if timeCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return timeCmd.Val() + } + case CmdTypeStringSlice: + if stringSliceCmd, ok := cmd.(interface{ Val() []string }); ok { + return stringSliceCmd.Val() + } + case CmdTypeIntSlice: + if intSliceCmd, ok := cmd.(interface{ Val() []int64 }); ok { + return intSliceCmd.Val() + } + case CmdTypeBoolSlice: + if boolSliceCmd, ok := cmd.(interface{ Val() []bool }); ok { + return boolSliceCmd.Val() + } + case CmdTypeFloatSlice: + if floatSliceCmd, ok := cmd.(interface{ Val() []float64 }); ok { + return floatSliceCmd.Val() + } + case CmdTypeMapStringString: + if mapCmd, ok := cmd.(interface{ Val() map[string]string }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInt: + if mapCmd, ok := cmd.(interface{ Val() map[string]int64 }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInterfaceSlice: + if mapCmd, ok := cmd.(interface { + Val() map[string][]interface{} + }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInterface: + if mapCmd, ok := cmd.(interface{ Val() map[string]interface{} }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringStringSlice: + if mapCmd, ok := cmd.(interface{ Val() map[string][]string }); ok { + return mapCmd.Val() + } + case CmdTypeMapMapStringInterface: + if mapCmd, ok := cmd.(interface { + Val() map[string][]interface{} + }); ok { + return mapCmd.Val() + } + case CmdTypeStringStructMap: + if mapCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return mapCmd.Val() + } + case CmdTypeXMessageSlice: + if xMsgCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xMsgCmd.Val() + } + case CmdTypeXStreamSlice: + if xStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xStreamCmd.Val() + } + case CmdTypeXPending: + if xPendingCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xPendingCmd.Val() + } + case CmdTypeXPendingExt: + if xPendingExtCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xPendingExtCmd.Val() + } + case CmdTypeXAutoClaim: + if xAutoClaimCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xAutoClaimCmd.Val() + } + case CmdTypeXAutoClaimJustID: + if xAutoClaimJustIDCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xAutoClaimJustIDCmd.Val() + } + case 
CmdTypeXInfoConsumers: + if xInfoConsumersCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoConsumersCmd.Val() + } + case CmdTypeXInfoGroups: + if xInfoGroupsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoGroupsCmd.Val() + } + case CmdTypeXInfoStream: + if xInfoStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoStreamCmd.Val() + } + case CmdTypeXInfoStreamFull: + if xInfoStreamFullCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoStreamFullCmd.Val() + } + case CmdTypeZSlice: + if zSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zSliceCmd.Val() + } + case CmdTypeZWithKey: + if zWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zWithKeyCmd.Val() + } + case CmdTypeScan: + if scanCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return scanCmd.Val() + } + case CmdTypeClusterSlots: + if clusterSlotsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterSlotsCmd.Val() + } + case CmdTypeGeoSearchLocation: + if geoSearchLocationCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return geoSearchLocationCmd.Val() + } + case CmdTypeGeoPos: + if geoPosCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return geoPosCmd.Val() + } + case CmdTypeCommandsInfo: + if commandsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return commandsInfoCmd.Val() + } + case CmdTypeSlowLog: + if slowLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return slowLogCmd.Val() + } + + case CmdTypeKeyValues: + if keyValuesCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return keyValuesCmd.Val() + } + case CmdTypeZSliceWithKey: + if zSliceWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zSliceWithKeyCmd.Val() + } + case CmdTypeFunctionList: + if functionListCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return functionListCmd.Val() + } + case CmdTypeFunctionStats: + if functionStatsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return functionStatsCmd.Val() + } + case CmdTypeLCS: + if lcsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return lcsCmd.Val() + } + case CmdTypeKeyFlags: + if keyFlagsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return keyFlagsCmd.Val() + } + case CmdTypeClusterLinks: + if clusterLinksCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterLinksCmd.Val() + } + case CmdTypeClusterShards: + if clusterShardsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterShardsCmd.Val() + } + case CmdTypeRankWithScore: + if rankWithScoreCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return rankWithScoreCmd.Val() + } + case CmdTypeClientInfo: + if clientInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clientInfoCmd.Val() + } + case CmdTypeACLLog: + if aclLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return aclLogCmd.Val() + } + case CmdTypeInfo: + if infoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return infoCmd.Val() + } + case CmdTypeMonitor: + if monitorCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return monitorCmd.Val() + } + case CmdTypeJSON: + if jsonCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return jsonCmd.Val() + } + case CmdTypeJSONSlice: + if jsonSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return jsonSliceCmd.Val() + } + case CmdTypeIntPointerSlice: + if intPointerSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return 
intPointerSliceCmd.Val() + } + case CmdTypeScanDump: + if scanDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return scanDumpCmd.Val() + } + case CmdTypeBFInfo: + if bfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return bfInfoCmd.Val() + } + case CmdTypeCFInfo: + if cfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return cfInfoCmd.Val() + } + case CmdTypeCMSInfo: + if cmsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return cmsInfoCmd.Val() + } + case CmdTypeTopKInfo: + if topKInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return topKInfoCmd.Val() + } + case CmdTypeTDigestInfo: + if tDigestInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tDigestInfoCmd.Val() + } + case CmdTypeFTSearch: + if ftSearchCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSearchCmd.Val() + } + case CmdTypeFTInfo: + if ftInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftInfoCmd.Val() + } + case CmdTypeFTSpellCheck: + if ftSpellCheckCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSpellCheckCmd.Val() + } + case CmdTypeFTSynDump: + if ftSynDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSynDumpCmd.Val() + } + case CmdTypeAggregate: + if aggregateCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return aggregateCmd.Val() + } + case CmdTypeTSTimestampValue: + if tsTimestampValueCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tsTimestampValueCmd.Val() + } + case CmdTypeTSTimestampValueSlice: + if tsTimestampValueSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tsTimestampValueSliceCmd.Val() + } + default: + // For unknown command types, return nil + return nil + } + } + + // If we can't get the command type, return nil + return nil +} diff --git a/internal/routing/aggregator_test.go b/internal/routing/aggregator_test.go new file mode 100644 index 0000000000..4de29396df --- /dev/null +++ b/internal/routing/aggregator_test.go @@ -0,0 +1,427 @@ +package routing + +import ( + "errors" + "testing" +) + +// Mock command types for testing +type MockStringCmd struct { + cmdType CmdType + val string +} + +func (cmd *MockStringCmd) GetCmdType() CmdType { + return cmd.cmdType +} + +func (cmd *MockStringCmd) Val() string { + return cmd.val +} + +type MockIntCmd struct { + cmdType CmdType + val int64 +} + +func (cmd *MockIntCmd) GetCmdType() CmdType { + return cmd.cmdType +} + +func (cmd *MockIntCmd) Val() int64 { + return cmd.val +} + +type MockBoolCmd struct { + cmdType CmdType + val bool +} + +func (cmd *MockBoolCmd) GetCmdType() CmdType { + return cmd.cmdType +} + +func (cmd *MockBoolCmd) Val() bool { + return cmd.val +} + +// Legacy command without GetCmdType for comparison +type LegacyStringCmd struct { + val string +} + +func (cmd *LegacyStringCmd) Val() string { + return cmd.val +} + +func BenchmarkExtractCommandValueOptimized(b *testing.B) { + commands := []interface{}{ + &MockStringCmd{cmdType: CmdTypeString, val: "test-value"}, + &MockIntCmd{cmdType: CmdTypeInt, val: 42}, + &MockBoolCmd{cmdType: CmdTypeBool, val: true}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, cmd := range commands { + ExtractCommandValue(cmd) + } + } +} + +func BenchmarkExtractCommandValueLegacy(b *testing.B) { + commands := []interface{}{ + &LegacyStringCmd{val: "test-value"}, + &MockIntCmd{cmdType: CmdTypeInt, val: 42}, + &MockBoolCmd{cmdType: CmdTypeBool, val: true}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, cmd := range 
commands { + ExtractCommandValue(cmd) + } + } +} + +func TestExtractCommandValue(t *testing.T) { + tests := []struct { + name string + cmd interface{} + expected interface{} + }{ + { + name: "string command", + cmd: &MockStringCmd{cmdType: CmdTypeString, val: "hello"}, + expected: "hello", + }, + { + name: "int command", + cmd: &MockIntCmd{cmdType: CmdTypeInt, val: 123}, + expected: int64(123), + }, + { + name: "bool command", + cmd: &MockBoolCmd{cmdType: CmdTypeBool, val: true}, + expected: true, + }, + { + name: "unsupported command", + cmd: &LegacyStringCmd{val: "test"}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractCommandValue(tt.cmd) + if result != tt.expected { + t.Errorf("ExtractCommandValue() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestExtractCommandValueIntegration(t *testing.T) { + tests := []struct { + name string + cmd interface{} + expected interface{} + }{ + { + name: "optimized string command", + cmd: &MockStringCmd{cmdType: CmdTypeString, val: "hello"}, + expected: "hello", + }, + { + name: "legacy string command returns nil (no GetCmdType)", + cmd: &LegacyStringCmd{val: "legacy"}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractCommandValue(tt.cmd) + if result != tt.expected { + t.Errorf("ExtractCommandValue() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestAllSucceededAggregator(t *testing.T) { + agg := &AllSucceededAggregator{} + + err := agg.Add("result1", nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add("result2", nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + if result != "result1" { + t.Errorf("Expected 'result1', got %v", result) + } + + agg = &AllSucceededAggregator{} + testErr := errors.New("test error") + err = agg.Add("result1", nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add("result2", testErr) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err = agg.Finish() + if err != testErr { + t.Errorf("Expected test error, got %v", err) + } +} + +func TestOneSucceededAggregator(t *testing.T) { + agg := &OneSucceededAggregator{} + + testErr := errors.New("test error") + err := agg.Add("result1", testErr) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add("result2", nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + if result != "result2" { + t.Errorf("Expected 'result2', got %v", result) + } + + agg = &OneSucceededAggregator{} + err = agg.Add("result1", testErr) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add("result2", testErr) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err = agg.Finish() + if err != testErr { + t.Errorf("Expected test error, got %v", err) + } +} + +func TestAggSumAggregator(t *testing.T) { + agg := &AggSumAggregator{} + + err := agg.Add(int64(10), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(int64(20), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(int64(30), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + if result != int64(60) { + t.Errorf("Expected 60, got %v", 
result) + } + + agg = &AggSumAggregator{} + testErr := errors.New("test error") + err = agg.Add(int64(10), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(int64(20), testErr) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err = agg.Finish() + if err != testErr { + t.Errorf("Expected test error, got %v", err) + } +} + +func TestAggMinAggregator(t *testing.T) { + agg := &AggMinAggregator{} + + err := agg.Add(int64(30), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(int64(10), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(int64(20), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + if result != int64(10) { + t.Errorf("Expected 10, got %v", result) + } +} + +func TestAggMaxAggregator(t *testing.T) { + agg := &AggMaxAggregator{} + + err := agg.Add(int64(10), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(int64(30), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(int64(20), nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + if result != int64(30) { + t.Errorf("Expected 30, got %v", result) + } +} + +func TestAggLogicalAndAggregator(t *testing.T) { + agg := &AggLogicalAndAggregator{} + + err := agg.Add(true, nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(true, nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(false, nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + if result != false { + t.Errorf("Expected false, got %v", result) + } +} + +func TestAggLogicalOrAggregator(t *testing.T) { + agg := &AggLogicalOrAggregator{} + + err := agg.Add(false, nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(true, nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add(false, nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + if result != true { + t.Errorf("Expected true, got %v", result) + } +} + +func TestDefaultKeylessAggregator(t *testing.T) { + agg := &DefaultKeylessAggregator{} + + err := agg.Add("result1", nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add("result2", nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + err = agg.Add("result3", nil) + if err != nil { + t.Errorf("Add failed: %v", err) + } + + result, err := agg.Finish() + if err != nil { + t.Errorf("Finish failed: %v", err) + } + + results, ok := result.([]interface{}) + if !ok { + t.Errorf("Expected []interface{}, got %T", result) + } + if len(results) != 3 { + t.Errorf("Expected 3 results, got %d", len(results)) + } + if results[0] != "result1" || results[1] != "result2" || results[2] != "result3" { + t.Errorf("Unexpected results: %v", results) + } +} + +func TestNewResponseAggregator(t *testing.T) { + tests := []struct { + policy ResponsePolicy + cmdName string + expected string + }{ + {RespAllSucceeded, "test", "*routing.AllSucceededAggregator"}, + {RespOneSucceeded, "test", "*routing.OneSucceededAggregator"}, + {RespAggSum, "test", "*routing.AggSumAggregator"}, + {RespAggMin, "test", "*routing.AggMinAggregator"}, + 
{RespAggMax, "test", "*routing.AggMaxAggregator"}, + {RespAggLogicalAnd, "test", "*routing.AggLogicalAndAggregator"}, + {RespAggLogicalOr, "test", "*routing.AggLogicalOrAggregator"}, + {RespSpecial, "test", "*routing.SpecialAggregator"}, + } + + for _, test := range tests { + agg := NewResponseAggregator(test.policy, test.cmdName) + if agg == nil { + t.Errorf("NewResponseAggregator returned nil for policy %v", test.policy) + } + _, ok := agg.(ResponseAggregator) + if !ok { + t.Errorf("Aggregator does not implement ResponseAggregator interface") + } + } +} diff --git a/internal/routing/policy.go b/internal/routing/policy.go index 18c03cd2dd..d65efb8aef 100644 --- a/internal/routing/policy.go +++ b/internal/routing/policy.go @@ -56,7 +56,9 @@ func ParseRequestPolicy(raw string) (RequestPolicy, error) { type ResponsePolicy uint8 const ( - RespAllSucceeded ResponsePolicy = iota + RespDefaultKeyless ResponsePolicy = iota + RespDefaultHashSlot + RespAllSucceeded RespOneSucceeded RespAggSum RespAggMin @@ -68,6 +70,10 @@ const ( func (p ResponsePolicy) String() string { switch p { + case RespDefaultKeyless: + return "default(keyless)" + case RespDefaultHashSlot: + return "default(hashslot)" case RespAllSucceeded: return "all_succeeded" case RespOneSucceeded: @@ -85,12 +91,16 @@ func (p ResponsePolicy) String() string { case RespSpecial: return "special" default: - return fmt.Sprintf("unknown_response_policy(%d)", p) + return "all_succeeded" } } func ParseResponsePolicy(raw string) (ResponsePolicy, error) { switch strings.ToLower(raw) { + case "default(keyless)": + return RespDefaultKeyless, nil + case "default(hashslot)": + return RespDefaultHashSlot, nil case "all_succeeded": return RespAllSucceeded, nil case "one_succeeded": @@ -108,7 +118,7 @@ func ParseResponsePolicy(raw string) (ResponsePolicy, error) { case "special": return RespSpecial, nil default: - return RespAllSucceeded, fmt.Errorf("routing: unknown response_policy %q", raw) + return RespDefaultKeyless, fmt.Errorf("routing: unknown response_policy %q", raw) } } diff --git a/internal/routing/shard_picker.go b/internal/routing/shard_picker.go new file mode 100644 index 0000000000..e29d526b0b --- /dev/null +++ b/internal/routing/shard_picker.go @@ -0,0 +1,41 @@ +package routing + +import ( + "math/rand" + "sync/atomic" +) + +// ShardPicker chooses “one arbitrary shard” when the request_policy is +// ReqDefault and the command has no keys. 
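+// Implementations should be safe for concurrent use, as both pickers below
+// are. A fixed picker is a one-liner when deterministic routing is wanted in
+// tests (illustrative sketch):
+//
+//	type FirstShardPicker struct{}
+//
+//	func (FirstShardPicker) Next(total int) int { return 0 }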
+type ShardPicker interface { + Next(total int) int // returns an index in [0,total) +} + +/*─────────────────────────────── + Round-robin (default) +────────────────────────────────*/ + +type RoundRobinPicker struct { + cnt atomic.Uint32 +} + +func (p *RoundRobinPicker) Next(total int) int { + if total == 0 { + return 0 + } + i := p.cnt.Add(1) + return int(i-1) % total +} + +/*─────────────────────────────── + Random +────────────────────────────────*/ + +type RandomPicker struct{} + +func (RandomPicker) Next(total int) int { + if total == 0 { + return 0 + } + return rand.Intn(total) +} diff --git a/json.go b/json.go index 2b9fa527ee..44b09d4754 100644 --- a/json.go +++ b/json.go @@ -68,8 +68,9 @@ var _ Cmder = (*JSONCmd)(nil) func newJSONCmd(ctx context.Context, args ...interface{}) *JSONCmd { return &JSONCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeJSON, }, } } @@ -165,6 +166,14 @@ func (cmd *JSONCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *JSONCmd) Clone() Cmder { + return &JSONCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + expanded: cmd.expanded, // interface{} can be shared as it should be immutable after parsing + } +} + // ------------------------------------------- type JSONSliceCmd struct { @@ -175,8 +184,9 @@ type JSONSliceCmd struct { func NewJSONSliceCmd(ctx context.Context, args ...interface{}) *JSONSliceCmd { return &JSONSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeJSONSlice, }, } } @@ -233,6 +243,18 @@ func (cmd *JSONSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *JSONSliceCmd) Clone() Cmder { + var val []interface{} + if cmd.val != nil { + val = make([]interface{}, len(cmd.val)) + copy(val, cmd.val) + } + return &JSONSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + /******************************************************************************* * * IntPointerSliceCmd @@ -249,8 +271,9 @@ type IntPointerSliceCmd struct { func NewIntPointerSliceCmd(ctx context.Context, args ...interface{}) *IntPointerSliceCmd { return &IntPointerSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeIntPointerSlice, }, } } @@ -290,6 +313,23 @@ func (cmd *IntPointerSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *IntPointerSliceCmd) Clone() Cmder { + var val []*int64 + if cmd.val != nil { + val = make([]*int64, len(cmd.val)) + for i, ptr := range cmd.val { + if ptr != nil { + newVal := *ptr + val[i] = &newVal + } + } + } + return &IntPointerSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // JSONArrAppend adds the provided JSON values to the end of the array at the given path. 
diff --git a/main_test.go b/main_test.go index 9d8efe3d98..81cc6d2aaf 100644 --- a/main_test.go +++ b/main_test.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "os" + "runtime" + "runtime/pprof" "strconv" "strings" "sync" @@ -107,6 +109,7 @@ var _ = BeforeSuite(func() { if RedisVersion < 7.0 || RedisVersion > 9 { panic("incorrect or not supported redis version") + } redisPort = redisStackPort @@ -148,12 +151,22 @@ var _ = BeforeSuite(func() { // populate cluster node information Expect(configureClusterTopology(ctx, cluster)).NotTo(HaveOccurred()) } + runtime.SetBlockProfileRate(1) + runtime.SetMutexProfileFraction(1) }) var _ = AfterSuite(func() { if !RECluster { Expect(cluster.Close()).NotTo(HaveOccurred()) } + if f, err := os.Create("block.pprof"); err == nil { + pprof.Lookup("block").WriteTo(f, 0) + f.Close() + } + if f, err := os.Create("mutex.pprof"); err == nil { + pprof.Lookup("mutex").WriteTo(f, 0) + f.Close() + } }) func TestGinkgoSuite(t *testing.T) { diff --git a/osscluster.go b/osscluster.go index 0db0699f8a..598b9409af 100644 --- a/osscluster.go +++ b/osscluster.go @@ -20,6 +20,7 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/internal/routing" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -148,6 +149,9 @@ type ClusterOptions struct { // If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it. // The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications. MaintNotificationsConfig *maintnotifications.Config + // ShardPicker is used to pick a shard when the request_policy is + // ReqDefault and the command has no keys. + ShardPicker routing.ShardPicker } func (opt *ClusterOptions) init() { @@ -208,6 +212,10 @@ func (opt *ClusterOptions) init() { if opt.FailingTimeoutSeconds == 0 { opt.FailingTimeoutSeconds = 15 } + + if opt.ShardPicker == nil { + opt.ShardPicker = &routing.RoundRobinPicker{} + } } // ParseClusterURL parses a URL into ClusterOptions that can be used to connect to Redis. @@ -1090,13 +1098,13 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { if ask { ask = false - pipe := node.Client.Pipeline() _ = pipe.Process(ctx, NewCmd(ctx, "asking")) _ = pipe.Process(ctx, cmd) _, lastErr = pipe.Exec(ctx) } else { - lastErr = node.Client.Process(ctx, cmd) + // Execute the command on the selected node + lastErr = c.routeAndRun(ctx, cmd, node) } // If there is no error - we are done. 
@@ -1370,11 +1378,23 @@ func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) } func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { - cmdsMap := newCmdsMap() + // Separate commands into those that can be batched vs those that need individual routing + batchableCmds := make([]Cmder, 0) + individualCmds := make([]Cmder, 0) - if err := c.mapCmdsByNode(ctx, cmdsMap, cmds); err != nil { - setCmdsErr(cmds, err) - return err + for _, cmd := range cmds { + policy := c.getCommandPolicy(ctx, cmd) + + // Commands that need special routing should be handled individually + if policy != nil && (policy.Request == routing.ReqAllNodes || + policy.Request == routing.ReqAllShards || + policy.Request == routing.ReqMultiShard || + policy.Request == routing.ReqSpecial) { + individualCmds = append(individualCmds, cmd) + } else { + // Single-node commands can be batched + batchableCmds = append(batchableCmds, cmd) + } } for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { @@ -1385,73 +1405,68 @@ func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error } } - failedCmds := newCmdsMap() - var wg sync.WaitGroup + var allSucceeded = true + var failedBatchableCmds []Cmder + var failedIndividualCmds []Cmder - for node, cmds := range cmdsMap.m { - wg.Add(1) - go func(node *clusterNode, cmds []Cmder) { - defer wg.Done() - c.processPipelineNode(ctx, node, cmds, failedCmds) - }(node, cmds) + // Handle individual commands using existing router + for _, cmd := range individualCmds { + if err := c.routeAndRun(ctx, cmd, nil); err != nil { + allSucceeded = false + failedIndividualCmds = append(failedIndividualCmds, cmd) + } } - wg.Wait() - if len(failedCmds.m) == 0 { - break - } - cmdsMap = failedCmds - } + // Handle batchable commands using original pipeline logic + if len(batchableCmds) > 0 { + cmdsMap := newCmdsMap() - return cmdsFirstErr(cmds) -} + if err := c.mapCmdsByNode(ctx, cmdsMap, batchableCmds); err != nil { + setCmdsErr(batchableCmds, err) + allSucceeded = false + failedBatchableCmds = append(failedBatchableCmds, batchableCmds...) + } else { + batchFailedCmds := newCmdsMap() + var wg sync.WaitGroup + + for node, nodeCmds := range cmdsMap.m { + wg.Add(1) + go func(node *clusterNode, nodeCmds []Cmder) { + defer wg.Done() + c.processPipelineNode(ctx, node, nodeCmds, batchFailedCmds) + }(node, nodeCmds) + } -func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmds []Cmder) error { - state, err := c.state.Get(ctx) - if err != nil { - return err - } + wg.Wait() - preferredRandomSlot := -1 - if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { - for _, cmd := range cmds { - slot := c.cmdSlot(cmd, preferredRandomSlot) - if preferredRandomSlot == -1 { - preferredRandomSlot = slot - } - node, err := c.slotReadOnlyNode(state, slot) - if err != nil { - return err + if len(batchFailedCmds.m) > 0 { + allSucceeded = false + for _, nodeCmds := range batchFailedCmds.m { + failedBatchableCmds = append(failedBatchableCmds, nodeCmds...) 
+ } + } } - cmdsMap.Add(node, cmd) } - return nil - } - for _, cmd := range cmds { - slot := c.cmdSlot(cmd, preferredRandomSlot) - if preferredRandomSlot == -1 { - preferredRandomSlot = slot - } - node, err := state.slotMasterNode(slot) - if err != nil { - return err + // If all commands succeeded, we're done + if allSucceeded { + break } - cmdsMap.Add(node, cmd) - } - return nil -} -func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool { - for _, cmd := range cmds { - cmdInfo := c.cmdInfo(ctx, cmd.Name()) - if cmdInfo == nil || !cmdInfo.ReadOnly { - return false + // If this was the last attempt, return the error + if attempt == c.opt.MaxRedirects { + break } + + // Update command lists for retry - no reclassification needed + batchableCmds = failedBatchableCmds + individualCmds = failedIndividualCmds } - return true + + return cmdsFirstErr(cmds) } +// processPipelineNode handles batched pipeline commands for a single node func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { @@ -1461,7 +1476,8 @@ func (c *ClusterClient) processPipelineNode( if !isContextError(err) { node.MarkAsFailing() } - _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + // Commands are already mapped to this node, just add them as failed + failedCmds.Add(node, cmds...) setCmdsErr(cmds, err) return err } @@ -1486,7 +1502,8 @@ func (c *ClusterClient) processPipelineNodeConn( node.MarkAsFailing() } if shouldRetry(err, true) { - _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + // Commands are already mapped to this node, just add them as failed + failedCmds.Add(node, cmds...) } setCmdsErr(cmds, err) return err @@ -1522,7 +1539,8 @@ func (c *ClusterClient) pipelineReadCmds( if !isRedisError(err) { if shouldRetry(err, true) { - _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + // Commands are already mapped to this node, just add them as failed + failedCmds.Add(node, cmds[i:]...) } setCmdsErr(cmds[i+1:], err) return err @@ -1530,13 +1548,61 @@ func (c *ClusterClient) pipelineReadCmds( } if err := cmds[0].Err(); err != nil && shouldRetry(err, true) { - _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + // Commands are already mapped to this node, just add them as failed + failedCmds.Add(node, cmds...) 
return err } return nil } +// Legacy functions needed for transaction pipeline processing +func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmds []Cmder) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } + + preferredRandomSlot := -1 + if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { + for _, cmd := range cmds { + slot := c.cmdSlot(cmd, preferredRandomSlot) + if preferredRandomSlot == -1 { + preferredRandomSlot = slot + } + node, err := c.slotReadOnlyNode(state, slot) + if err != nil { + return err + } + cmdsMap.Add(node, cmd) + } + return nil + } + + for _, cmd := range cmds { + slot := c.cmdSlot(cmd, preferredRandomSlot) + if preferredRandomSlot == -1 { + preferredRandomSlot = slot + } + node, err := state.slotMasterNode(slot) + if err != nil { + return err + } + cmdsMap.Add(node, cmd) + } + return nil +} + +func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool { + for _, cmd := range cmds { + cmdInfo := c.cmdInfo(ctx, cmd.Name()) + if cmdInfo == nil || !cmdInfo.ReadOnly { + return false + } + } + return true +} + func (c *ClusterClient) checkMovedErr( ctx context.Context, cmd Cmder, err error, failedCmds *cmdsMap, ) bool { @@ -1564,6 +1630,35 @@ func (c *ClusterClient) checkMovedErr( panic("not reached") } +func (c *ClusterClient) cmdsMoved( + ctx context.Context, cmds []Cmder, + moved, ask bool, + addr string, + failedCmds *cmdsMap, +) error { + node, err := c.nodes.GetOrCreate(addr) + if err != nil { + return err + } + + if moved { + c.state.LazyReload() + for _, cmd := range cmds { + failedCmds.Add(node, cmd) + } + return nil + } + + if ask { + for _, cmd := range cmds { + failedCmds.Add(node, NewCmd(ctx, "asking"), cmd) + } + return nil + } + + return nil +} + // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 
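+// In cluster mode the queued commands are grouped per node, and each node's
+// batch is wrapped in its own MULTI/EXEC, so a pipeline that spans slots is
+// not a single atomic transaction.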
func (c *ClusterClient) TxPipeline() Pipeliner { pipe := Pipeline{ @@ -1787,35 +1882,6 @@ func (c *ClusterClient) txPipelineReadQueued( return nil } -func (c *ClusterClient) cmdsMoved( - ctx context.Context, cmds []Cmder, - moved, ask bool, - addr string, - failedCmds *cmdsMap, -) error { - node, err := c.nodes.GetOrCreate(addr) - if err != nil { - return err - } - - if moved { - c.state.LazyReload() - for _, cmd := range cmds { - failedCmds.Add(node, cmd) - } - return nil - } - - if ask { - for _, cmd := range cmds { - failedCmds.Add(node, NewCmd(ctx, "asking"), cmd) - } - return nil - } - - return nil -} - func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { return fmt.Errorf("redis: Watch requires at least one key") @@ -1994,7 +2060,6 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, for _, idx := range perm { addr := addrs[idx] - node, err := c.nodes.GetOrCreate(addr) if err != nil { if firstErr == nil { @@ -2007,6 +2072,7 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, if err == nil { return info, nil } + if firstErr == nil { firstErr = err } @@ -2019,7 +2085,17 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, } func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Get(ctx) + // Use a separate context that won't be canceled to ensure command info lookup + // doesn't fail due to original context cancellation + cmdInfoCtx := context.Background() + if c.opt.ContextTimeoutEnabled && ctx != nil { + // If context timeout is enabled, still use a reasonable timeout + var cancel context.CancelFunc + cmdInfoCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + + cmdsInfo, err := c.cmdsInfoCache.Get(cmdInfoCtx) if err != nil { internal.Logger.Printf(context.TODO(), "getting command info: %s", err) return nil diff --git a/osscluster_router.go b/osscluster_router.go new file mode 100644 index 0000000000..a1fe669736 --- /dev/null +++ b/osscluster_router.go @@ -0,0 +1,847 @@ +package redis + +import ( + "context" + "fmt" + "reflect" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal/hashtag" + "github.com/redis/go-redis/v9/internal/routing" +) + +// slotResult represents the result of executing a command on a specific slot +type slotResult struct { + cmd Cmder + keys []string + err error +} + +// routeAndRun routes a command to the appropriate cluster nodes and executes it +func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { + policy := c.getCommandPolicy(ctx, cmd) + + switch { + case policy != nil && policy.Request == routing.ReqAllNodes: + return c.executeOnAllNodes(ctx, cmd, policy) + case policy != nil && policy.Request == routing.ReqAllShards: + return c.executeOnAllShards(ctx, cmd, policy) + case policy != nil && policy.Request == routing.ReqMultiShard: + return c.executeMultiShard(ctx, cmd, policy) + case policy != nil && policy.Request == routing.ReqSpecial: + return c.executeSpecialCommand(ctx, cmd, policy, node) + default: + return c.executeDefault(ctx, cmd, node) + } +} + +// getCommandPolicy retrieves the routing policy for a command +func (c *ClusterClient) getCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.Tips != nil { + return cmdInfo.Tips + } + return nil +} + +// executeDefault 
handles standard command routing based on keys +func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, node *clusterNode) error { + if c.hasKeys(cmd) { + // execute on key based shard + return node.Client.Process(ctx, cmd) + } + return c.executeOnArbitraryShard(ctx, cmd) +} + +// executeOnArbitraryShard routes command to an arbitrary shard +func (c *ClusterClient) executeOnArbitraryShard(ctx context.Context, cmd Cmder) error { + node := c.pickArbitraryShard(ctx) + if node == nil { + return errClusterNoNodes + } + return node.Client.Process(ctx, cmd) +} + +// executeOnAllNodes executes command on all nodes (masters and replicas) +func (c *ClusterClient) executeOnAllNodes(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } + + nodes := append(state.Masters, state.Slaves...) + if len(nodes) == 0 { + return errClusterNoNodes + } + + return c.executeParallel(ctx, cmd, nodes, policy) +} + +// executeOnAllShards executes command on all master shards +func (c *ClusterClient) executeOnAllShards(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } + + if len(state.Masters) == 0 { + return errClusterNoNodes + } + + return c.executeParallel(ctx, cmd, state.Masters, policy) +} + +// executeMultiShard handles commands that operate on multiple keys across shards +func (c *ClusterClient) executeMultiShard(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + args := cmd.Args() + firstKeyPos := int(cmdFirstKeyPos(cmd)) + + if firstKeyPos == 0 || firstKeyPos >= len(args) { + return fmt.Errorf("redis: multi-shard command %s has no key arguments", cmd.Name()) + } + + // Group keys by slot + slotMap := make(map[int][]string) + keyOrder := make([]string, 0) + + for i := firstKeyPos; i < len(args); i++ { + key, ok := args[i].(string) + if !ok { + return fmt.Errorf("redis: non-string key at position %d: %v", i, args[i]) + } + + slot := hashtag.Slot(key) + slotMap[slot] = append(slotMap[slot], key) + keyOrder = append(keyOrder, key) + } + + return c.executeMultiSlot(ctx, cmd, slotMap, keyOrder, policy) +} + +// executeMultiSlot executes commands across multiple slots concurrently +func (c *ClusterClient) executeMultiSlot(ctx context.Context, cmd Cmder, slotMap map[int][]string, keyOrder []string, policy *routing.CommandPolicy) error { + results := make(chan slotResult, len(slotMap)) + var wg sync.WaitGroup + + // Execute on each slot concurrently + for slot, keys := range slotMap { + wg.Add(1) + go func(slot int, keys []string) { + defer wg.Done() + + node, err := c.cmdNode(ctx, cmd.Name(), slot) + if err != nil { + results <- slotResult{nil, keys, err} + return + } + + // Create a command for this specific slot's keys + subCmd := c.createSlotSpecificCommand(ctx, cmd, keys) + err = node.Client.Process(ctx, subCmd) + results <- slotResult{subCmd, keys, err} + }(slot, keys) + } + + go func() { + wg.Wait() + close(results) + }() + + return c.aggregateMultiSlotResults(ctx, cmd, results, keyOrder, policy) +} + +// createSlotSpecificCommand creates a new command for a specific slot's keys +func (c *ClusterClient) createSlotSpecificCommand(ctx context.Context, originalCmd Cmder, keys []string) Cmder { + originalArgs := originalCmd.Args() + firstKeyPos := int(cmdFirstKeyPos(originalCmd)) + + // Build new args with only the specified keys + newArgs := make([]interface{}, 0, firstKeyPos+len(keys)) + + // Copy command 
name and arguments before the keys + newArgs = append(newArgs, originalArgs[:firstKeyPos]...) + + // Add the slot-specific keys + for _, key := range keys { + newArgs = append(newArgs, key) + } + + // Create new command with the filtered keys + return NewCmd(ctx, newArgs...) +} + +// executeSpecialCommand handles commands with special routing requirements +func (c *ClusterClient) executeSpecialCommand(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy, node *clusterNode) error { + switch cmd.Name() { + case "ft.cursor": + return c.executeCursorCommand(ctx, cmd) + default: + return c.executeDefault(ctx, cmd, node) + } +} + +// executeCursorCommand handles FT.CURSOR commands with sticky routing +func (c *ClusterClient) executeCursorCommand(ctx context.Context, cmd Cmder) error { + args := cmd.Args() + if len(args) < 4 { + return fmt.Errorf("redis: FT.CURSOR command requires at least 3 arguments") + } + + cursorID, ok := args[3].(string) + if !ok { + return fmt.Errorf("redis: invalid cursor ID type") + } + + // Route based on cursor ID to maintain stickiness + slot := hashtag.Slot(cursorID) + node, err := c.cmdNode(ctx, cmd.Name(), slot) + if err != nil { + return err + } + + return node.Client.Process(ctx, cmd) +} + +// executeParallel executes a command on multiple nodes concurrently +func (c *ClusterClient) executeParallel(ctx context.Context, cmd Cmder, nodes []*clusterNode, policy *routing.CommandPolicy) error { + if len(nodes) == 0 { + return errClusterNoNodes + } + + if len(nodes) == 1 { + return nodes[0].Client.Process(ctx, cmd) + } + + type nodeResult struct { + cmd Cmder + err error + } + + results := make(chan nodeResult, len(nodes)) + var wg sync.WaitGroup + + for _, node := range nodes { + wg.Add(1) + go func(n *clusterNode) { + defer wg.Done() + cmdCopy := cmd.Clone() + err := n.Client.Process(ctx, cmdCopy) + results <- nodeResult{cmdCopy, err} + }(node) + } + + go func() { + wg.Wait() + close(results) + }() + + // Collect results + cmds := make([]Cmder, 0, len(nodes)) + for result := range results { + cmds = append(cmds, result.cmd) + } + + return c.aggregateResponses(cmd, cmds, policy) +} + +// aggregateMultiSlotResults aggregates results from multi-slot execution +func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder, results <-chan slotResult, keyOrder []string, policy *routing.CommandPolicy) error { + keyedResults := make(map[string]Cmder) + var firstErr error + + for result := range results { + if result.err != nil && firstErr == nil { + firstErr = result.err + } + if result.cmd != nil { + for _, key := range result.keys { + keyedResults[key] = result.cmd + } + } + } + + if firstErr != nil { + cmd.SetErr(firstErr) + return firstErr + } + + return c.aggregateKeyedResponses(ctx, cmd, keyedResults, keyOrder, policy) +} + +// aggregateKeyedResponses aggregates responses while preserving key order +func (c *ClusterClient) aggregateKeyedResponses(ctx context.Context, cmd Cmder, keyedResults map[string]Cmder, keyOrder []string, policy *routing.CommandPolicy) error { + if len(keyedResults) == 0 { + return fmt.Errorf("redis: no results to aggregate") + } + + aggregator := c.createAggregator(policy, cmd, true) + + // Set key order for keyed aggregators + if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok { + keyedAgg.SetKeyOrder(keyOrder) + } + + // Add results with keys + for key, shardCmd := range keyedResults { + value := routing.ExtractCommandValue(shardCmd) + if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); 
ok { + if err := keyedAgg.AddWithKey(key, value, shardCmd.Err()); err != nil { + return err + } + } else { + if err := aggregator.Add(value, shardCmd.Err()); err != nil { + return err + } + } + } + + return c.finishAggregation(cmd, aggregator) +} + +// aggregateResponses aggregates multiple shard responses +func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *routing.CommandPolicy) error { + if len(cmds) == 0 { + return fmt.Errorf("redis: no commands to aggregate") + } + + if len(cmds) == 1 { + shardCmd := cmds[0] + if err := shardCmd.Err(); err != nil { + cmd.SetErr(err) + return err + } + value := routing.ExtractCommandValue(shardCmd) + return c.setCommandValue(cmd, value) + } + + aggregator := c.createAggregator(policy, cmd, false) + + // Add all results to aggregator + for _, shardCmd := range cmds { + value := routing.ExtractCommandValue(shardCmd) + if err := aggregator.Add(value, shardCmd.Err()); err != nil { + return err + } + } + + return c.finishAggregation(cmd, aggregator) +} + +// createAggregator creates the appropriate response aggregator +func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator { + if policy != nil { + return routing.NewResponseAggregator(policy.Response, cmd.Name()) + } + + if !isKeyed { + firstKeyPos := cmdFirstKeyPos(cmd) + isKeyed = firstKeyPos > 0 + } + + return routing.NewDefaultAggregator(isKeyed) +} + +// finishAggregation completes the aggregation process and sets the result +func (c *ClusterClient) finishAggregation(cmd Cmder, aggregator routing.ResponseAggregator) error { + finalValue, finalErr := aggregator.Finish() + if finalErr != nil { + cmd.SetErr(finalErr) + return finalErr + } + + return c.setCommandValue(cmd, finalValue) +} + +// pickArbitraryShard selects a master shard using the configured ShardPicker +func (c *ClusterClient) pickArbitraryShard(ctx context.Context) *clusterNode { + state, err := c.state.Get(ctx) + if err != nil || len(state.Masters) == 0 { + return nil + } + + idx := c.opt.ShardPicker.Next(len(state.Masters)) + return state.Masters[idx] +} + +// hasKeys checks if a command operates on keys +func (c *ClusterClient) hasKeys(cmd Cmder) bool { + firstKeyPos := cmdFirstKeyPos(cmd) + return firstKeyPos > 0 +} + +// setCommandValue sets the aggregated value on a command using the enum-based approach +func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { + // If value is nil, it might mean ExtractCommandValue couldn't extract the value + // but the command might have executed successfully. In this case, don't set an error. 
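+	// (Extraction can legitimately come back nil even on success:
+	// ExtractCommandValue only matches a command whose Val method has the
+	// exact asserted signature, and many command types return a concrete
+	// type rather than interface{}.)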
+ if value == nil { + // Check if the original command has an error - if not, the nil value is not an error + if cmd.Err() == nil { + // Command executed successfully but value extraction failed + // This is common for complex commands like CLUSTER SLOTS + // The command already has its result set correctly, so just return + return nil + } + // If the command does have an error, set Nil error + cmd.SetErr(Nil) + return Nil + } + + switch cmd.GetCmdType() { + case CmdTypeGeneric: + if c, ok := cmd.(*Cmd); ok { + c.SetVal(value) + } + case CmdTypeString: + if c, ok := cmd.(*StringCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeInt: + if c, ok := cmd.(*IntCmd); ok { + if v, ok := value.(int64); ok { + c.SetVal(v) + } + } + case CmdTypeBool: + if c, ok := cmd.(*BoolCmd); ok { + if v, ok := value.(bool); ok { + c.SetVal(v) + } + } + case CmdTypeFloat: + if c, ok := cmd.(*FloatCmd); ok { + if v, ok := value.(float64); ok { + c.SetVal(v) + } + } + case CmdTypeStringSlice: + if c, ok := cmd.(*StringSliceCmd); ok { + if v, ok := value.([]string); ok { + c.SetVal(v) + } + } + case CmdTypeIntSlice: + if c, ok := cmd.(*IntSliceCmd); ok { + if v, ok := value.([]int64); ok { + c.SetVal(v) + } + } + case CmdTypeFloatSlice: + if c, ok := cmd.(*FloatSliceCmd); ok { + if v, ok := value.([]float64); ok { + c.SetVal(v) + } + } + case CmdTypeBoolSlice: + if c, ok := cmd.(*BoolSliceCmd); ok { + if v, ok := value.([]bool); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringString: + if c, ok := cmd.(*MapStringStringCmd); ok { + if v, ok := value.(map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInt: + if c, ok := cmd.(*MapStringIntCmd); ok { + if v, ok := value.(map[string]int64); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInterface: + if c, ok := cmd.(*MapStringInterfaceCmd); ok { + if v, ok := value.(map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeSlice: + if c, ok := cmd.(*SliceCmd); ok { + if v, ok := value.([]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeStatus: + if c, ok := cmd.(*StatusCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeDuration: + if c, ok := cmd.(*DurationCmd); ok { + if v, ok := value.(time.Duration); ok { + c.SetVal(v) + } + } + case CmdTypeTime: + if c, ok := cmd.(*TimeCmd); ok { + if v, ok := value.(time.Time); ok { + c.SetVal(v) + } + } + case CmdTypeKeyValueSlice: + if c, ok := cmd.(*KeyValueSliceCmd); ok { + if v, ok := value.([]KeyValue); ok { + c.SetVal(v) + } + } + case CmdTypeStringStructMap: + if c, ok := cmd.(*StringStructMapCmd); ok { + if v, ok := value.(map[string]struct{}); ok { + c.SetVal(v) + } + } + case CmdTypeXMessageSlice: + if c, ok := cmd.(*XMessageSliceCmd); ok { + if v, ok := value.([]XMessage); ok { + c.SetVal(v) + } + } + case CmdTypeXStreamSlice: + if c, ok := cmd.(*XStreamSliceCmd); ok { + if v, ok := value.([]XStream); ok { + c.SetVal(v) + } + } + case CmdTypeXPending: + if c, ok := cmd.(*XPendingCmd); ok { + if v, ok := value.(*XPending); ok { + c.SetVal(v) + } + } + case CmdTypeXPendingExt: + if c, ok := cmd.(*XPendingExtCmd); ok { + if v, ok := value.([]XPendingExt); ok { + c.SetVal(v) + } + } + case CmdTypeXAutoClaim: + if c, ok := cmd.(*XAutoClaimCmd); ok { + if v, ok := value.([]XMessage); ok { + c.SetVal(v, "") // Default start value + } + } + case CmdTypeXAutoClaimJustID: + if c, ok := cmd.(*XAutoClaimJustIDCmd); ok { + if v, ok := value.([]string); ok { + c.SetVal(v, "") // Default start value + } + } + case CmdTypeXInfoConsumers: 
+ if c, ok := cmd.(*XInfoConsumersCmd); ok { + if v, ok := value.([]XInfoConsumer); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoGroups: + if c, ok := cmd.(*XInfoGroupsCmd); ok { + if v, ok := value.([]XInfoGroup); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoStream: + if c, ok := cmd.(*XInfoStreamCmd); ok { + if v, ok := value.(*XInfoStream); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoStreamFull: + if c, ok := cmd.(*XInfoStreamFullCmd); ok { + if v, ok := value.(*XInfoStreamFull); ok { + c.SetVal(v) + } + } + case CmdTypeZSlice: + if c, ok := cmd.(*ZSliceCmd); ok { + if v, ok := value.([]Z); ok { + c.SetVal(v) + } + } + case CmdTypeZWithKey: + if c, ok := cmd.(*ZWithKeyCmd); ok { + if v, ok := value.(*ZWithKey); ok { + c.SetVal(v) + } + } + case CmdTypeScan: + if c, ok := cmd.(*ScanCmd); ok { + if v, ok := value.([]string); ok { + c.SetVal(v, uint64(0)) // Default cursor + } + } + case CmdTypeClusterSlots: + if c, ok := cmd.(*ClusterSlotsCmd); ok { + if v, ok := value.([]ClusterSlot); ok { + c.SetVal(v) + } + } + case CmdTypeGeoLocation: + if c, ok := cmd.(*GeoLocationCmd); ok { + if v, ok := value.([]GeoLocation); ok { + c.SetVal(v) + } + } + case CmdTypeGeoSearchLocation: + if c, ok := cmd.(*GeoSearchLocationCmd); ok { + if v, ok := value.([]GeoLocation); ok { + c.SetVal(v) + } + } + case CmdTypeGeoPos: + if c, ok := cmd.(*GeoPosCmd); ok { + if v, ok := value.([]*GeoPos); ok { + c.SetVal(v) + } + } + case CmdTypeCommandsInfo: + if c, ok := cmd.(*CommandsInfoCmd); ok { + if v, ok := value.(map[string]*CommandInfo); ok { + c.SetVal(v) + } + } + case CmdTypeSlowLog: + if c, ok := cmd.(*SlowLogCmd); ok { + if v, ok := value.([]SlowLog); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringStringSlice: + if c, ok := cmd.(*MapStringStringSliceCmd); ok { + if v, ok := value.([]map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMapMapStringInterface: + if c, ok := cmd.(*MapMapStringInterfaceCmd); ok { + if v, ok := value.(map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInterfaceSlice: + if c, ok := cmd.(*MapStringInterfaceSliceCmd); ok { + if v, ok := value.([]map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeKeyValues: + if c, ok := cmd.(*KeyValuesCmd); ok { + // KeyValuesCmd needs a key string and values slice + if key, ok := value.(string); ok { + c.SetVal(key, []string{}) // Default empty values + } + } + case CmdTypeZSliceWithKey: + if c, ok := cmd.(*ZSliceWithKeyCmd); ok { + // ZSliceWithKeyCmd needs a key string and Z slice + if key, ok := value.(string); ok { + c.SetVal(key, []Z{}) // Default empty Z slice + } + } + case CmdTypeFunctionList: + if c, ok := cmd.(*FunctionListCmd); ok { + if v, ok := value.([]Library); ok { + c.SetVal(v) + } + } + case CmdTypeFunctionStats: + if c, ok := cmd.(*FunctionStatsCmd); ok { + if v, ok := value.(FunctionStats); ok { + c.SetVal(v) + } + } + case CmdTypeLCS: + if c, ok := cmd.(*LCSCmd); ok { + if v, ok := value.(*LCSMatch); ok { + c.SetVal(v) + } + } + case CmdTypeKeyFlags: + if c, ok := cmd.(*KeyFlagsCmd); ok { + if v, ok := value.([]KeyFlags); ok { + c.SetVal(v) + } + } + case CmdTypeClusterLinks: + if c, ok := cmd.(*ClusterLinksCmd); ok { + if v, ok := value.([]ClusterLink); ok { + c.SetVal(v) + } + } + case CmdTypeClusterShards: + if c, ok := cmd.(*ClusterShardsCmd); ok { + if v, ok := value.([]ClusterShard); ok { + c.SetVal(v) + } + } + case CmdTypeRankWithScore: + if c, ok := cmd.(*RankWithScoreCmd); ok { + if v, ok := value.(RankScore); ok { + c.SetVal(v) + } + } + case CmdTypeClientInfo: 
+ if c, ok := cmd.(*ClientInfoCmd); ok { + if v, ok := value.(*ClientInfo); ok { + c.SetVal(v) + } + } + case CmdTypeACLLog: + if c, ok := cmd.(*ACLLogCmd); ok { + if v, ok := value.([]*ACLLogEntry); ok { + c.SetVal(v) + } + } + case CmdTypeInfo: + if c, ok := cmd.(*InfoCmd); ok { + if v, ok := value.(map[string]map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMonitor: + // MonitorCmd doesn't have SetVal method + // Skip setting value for MonitorCmd + case CmdTypeJSON: + if c, ok := cmd.(*JSONCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeJSONSlice: + if c, ok := cmd.(*JSONSliceCmd); ok { + if v, ok := value.([]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeIntPointerSlice: + if c, ok := cmd.(*IntPointerSliceCmd); ok { + if v, ok := value.([]*int64); ok { + c.SetVal(v) + } + } + case CmdTypeScanDump: + if c, ok := cmd.(*ScanDumpCmd); ok { + if v, ok := value.(ScanDump); ok { + c.SetVal(v) + } + } + case CmdTypeBFInfo: + if c, ok := cmd.(*BFInfoCmd); ok { + if v, ok := value.(BFInfo); ok { + c.SetVal(v) + } + } + case CmdTypeCFInfo: + if c, ok := cmd.(*CFInfoCmd); ok { + if v, ok := value.(CFInfo); ok { + c.SetVal(v) + } + } + case CmdTypeCMSInfo: + if c, ok := cmd.(*CMSInfoCmd); ok { + if v, ok := value.(CMSInfo); ok { + c.SetVal(v) + } + } + case CmdTypeTopKInfo: + if c, ok := cmd.(*TopKInfoCmd); ok { + if v, ok := value.(TopKInfo); ok { + c.SetVal(v) + } + } + case CmdTypeTDigestInfo: + if c, ok := cmd.(*TDigestInfoCmd); ok { + if v, ok := value.(TDigestInfo); ok { + c.SetVal(v) + } + } + case CmdTypeFTSynDump: + if c, ok := cmd.(*FTSynDumpCmd); ok { + if v, ok := value.([]FTSynDumpResult); ok { + c.SetVal(v) + } + } + case CmdTypeAggregate: + if c, ok := cmd.(*AggregateCmd); ok { + if v, ok := value.(*FTAggregateResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTInfo: + if c, ok := cmd.(*FTInfoCmd); ok { + if v, ok := value.(FTInfoResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTSpellCheck: + if c, ok := cmd.(*FTSpellCheckCmd); ok { + if v, ok := value.([]SpellCheckResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTSearch: + if c, ok := cmd.(*FTSearchCmd); ok { + if v, ok := value.(FTSearchResult); ok { + c.SetVal(v) + } + } + case CmdTypeTSTimestampValue: + if c, ok := cmd.(*TSTimestampValueCmd); ok { + if v, ok := value.(TSTimestampValue); ok { + c.SetVal(v) + } + } + case CmdTypeTSTimestampValueSlice: + if c, ok := cmd.(*TSTimestampValueSliceCmd); ok { + if v, ok := value.([]TSTimestampValue); ok { + c.SetVal(v) + } + } + default: + // Fallback to reflection for unknown types + return c.setCommandValueReflection(cmd, value) + } + + return nil +} + +// setCommandValueReflection is a fallback function that uses reflection +func (c *ClusterClient) setCommandValueReflection(cmd Cmder, value interface{}) error { + cmdValue := reflect.ValueOf(cmd) + if cmdValue.Kind() != reflect.Ptr || cmdValue.IsNil() { + return fmt.Errorf("redis: invalid command pointer") + } + + setValMethod := cmdValue.MethodByName("SetVal") + if !setValMethod.IsValid() { + return fmt.Errorf("redis: command %T does not have SetVal method", cmd) + } + + args := []reflect.Value{reflect.ValueOf(value)} + + switch cmd.(type) { + case *XAutoClaimCmd, *XAutoClaimJustIDCmd: + args = append(args, reflect.ValueOf("")) + case *ScanCmd: + args = append(args, reflect.ValueOf(uint64(0))) + case *KeyValuesCmd, *ZSliceWithKeyCmd: + if key, ok := value.(string); ok { + args = []reflect.Value{reflect.ValueOf(key)} + if _, ok := cmd.(*ZSliceWithKeyCmd); ok { + args = 
append(args, reflect.ValueOf([]Z{})) + } else { + args = append(args, reflect.ValueOf([]string{})) + } + } + } + + defer func() { + if r := recover(); r != nil { + cmd.SetErr(fmt.Errorf("redis: failed to set command value: %v", r)) + } + }() + + setValMethod.Call(args) + return nil +} diff --git a/osscluster_router_test.go b/osscluster_router_test.go new file mode 100644 index 0000000000..d2b3f94440 --- /dev/null +++ b/osscluster_router_test.go @@ -0,0 +1,379 @@ +package redis + +// import ( +// "context" +// "sync" +// "testing" +// "time" + +// . "github.com/bsm/ginkgo/v2" +// . "github.com/bsm/gomega" + +// "github.com/redis/go-redis/v9/internal/routing" +// ) + +// var _ = Describe("ExtractCommandValue", func() { +// It("should extract value from generic command", func() { +// cmd := NewCmd(nil, "test") +// cmd.SetVal("value") +// val := routing.ExtractCommandValue(cmd) +// Expect(val).To(Equal("value")) +// }) + +// It("should extract value from integer command", func() { +// intCmd := NewIntCmd(nil, "test") +// intCmd.SetVal(42) +// val := routing.ExtractCommandValue(intCmd) +// Expect(val).To(Equal(int64(42))) +// }) + +// It("should handle nil command", func() { +// val := routing.ExtractCommandValue(nil) +// Expect(val).To(BeNil()) +// }) +// }) + +// var _ = Describe("ClusterClient setCommandValue", func() { +// var client *ClusterClient + +// BeforeEach(func() { +// client = &ClusterClient{} +// }) + +// It("should set generic value", func() { +// cmd := NewCmd(nil, "test") +// err := client.setCommandValue(cmd, "new_value") +// Expect(err).NotTo(HaveOccurred()) +// Expect(cmd.Val()).To(Equal("new_value")) +// }) + +// It("should set integer value", func() { +// intCmd := NewIntCmd(nil, "test") +// err := client.setCommandValue(intCmd, int64(100)) +// Expect(err).NotTo(HaveOccurred()) +// Expect(intCmd.Val()).To(Equal(int64(100))) +// }) + +// It("should return error for type mismatch", func() { +// intCmd := NewIntCmd(nil, "test") +// err := client.setCommandValue(intCmd, "string_value") +// Expect(err).To(HaveOccurred()) +// Expect(err.Error()).To(ContainSubstring("cannot set IntCmd value from string")) +// }) +// }) + +// func TestConcurrentRouting(t *testing.T) { +// // This test ensures that concurrent execution doesn't cause response mismatches +// // or MOVED errors due to race conditions + +// // Mock cluster client for testing +// opt := &ClusterOptions{ +// Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"}, +// } + +// // Skip if no cluster available +// if testing.Short() { +// t.Skip("skipping cluster test in short mode") +// } + +// client := NewClusterClient(opt) +// defer client.Close() + +// // Test concurrent execution of commands with different policies +// var wg sync.WaitGroup +// numRoutines := 50 +// numCommands := 100 + +// // Channel to collect errors +// errors := make(chan error, numRoutines*numCommands) + +// for i := 0; i < numRoutines; i++ { +// wg.Add(1) +// go func(routineID int) { +// defer wg.Done() + +// for j := 0; j < numCommands; j++ { +// ctx := context.Background() + +// // Test different command types +// switch j % 4 { +// case 0: +// // Test keyless command (should use arbitrary shard) +// cmd := NewCmd(ctx, "PING") +// err := client.routeAndRun(ctx, cmd) +// if err != nil { +// errors <- err +// } +// case 1: +// // Test keyed command (should use slot-based routing) +// key := "test_key_" + string(rune(routineID)) + "_" + string(rune(j)) +// cmd := NewCmd(ctx, "GET", key) +// err := client.routeAndRun(ctx, cmd) +// 
if err != nil { +// errors <- err +// } +// case 2: +// // Test multi-shard command +// cmd := NewCmd(ctx, "MGET", "key1", "key2", "key3") +// err := client.routeAndRun(ctx, cmd) +// if err != nil { +// errors <- err +// } +// case 3: +// // Test all-shards command +// cmd := NewCmd(ctx, "DBSIZE") +// // Note: In actual implementation, the policy would come from COMMAND tips +// err := client.routeAndRun(ctx, cmd) +// if err != nil { +// errors <- err +// } +// } +// } +// }(i) +// } + +// // Wait for all routines to complete +// wg.Wait() +// close(errors) + +// // Check for errors +// var errorCount int +// for err := range errors { +// t.Errorf("Concurrent routing error: %v", err) +// errorCount++ +// if errorCount > 10 { // Limit error output +// break +// } +// } + +// if errorCount > 0 { +// t.Fatalf("Found %d errors in concurrent routing test", errorCount) +// } +// } + +// func TestResponseAggregation(t *testing.T) { +// // Test that response aggregation works correctly for different policies + +// if testing.Short() { +// t.Skip("skipping cluster test in short mode") +// } + +// // Test all_succeeded aggregation +// t.Run("AllSucceeded", func(t *testing.T) { +// aggregator := routing.NewResponseAggregator(routing.RespAllSucceeded, "TEST") + +// // Add successful results +// err := aggregator.Add("result1", nil) +// if err != nil { +// t.Errorf("Failed to add result: %v", err) +// } + +// err = aggregator.Add("result2", nil) +// if err != nil { +// t.Errorf("Failed to add result: %v", err) +// } + +// result, err := aggregator.Finish() +// if err != nil { +// t.Errorf("AllSucceeded aggregation failed: %v", err) +// } + +// if result != "result1" { +// t.Errorf("Expected 'result1', got %v", result) +// } +// }) + +// // Test agg_sum aggregation +// t.Run("AggSum", func(t *testing.T) { +// aggregator := routing.NewResponseAggregator(routing.RespAggSum, "TEST") + +// // Add numeric results +// err := aggregator.Add(int64(5), nil) +// if err != nil { +// t.Errorf("Failed to add result: %v", err) +// } + +// err = aggregator.Add(int64(10), nil) +// if err != nil { +// t.Errorf("Failed to add result: %v", err) +// } + +// result, err := aggregator.Finish() +// if err != nil { +// t.Errorf("AggSum aggregation failed: %v", err) +// } + +// if result != int64(15) { +// t.Errorf("Expected 15, got %v", result) +// } +// }) + +// // Test special aggregation for search commands +// t.Run("Special", func(t *testing.T) { +// aggregator := routing.NewResponseAggregator(routing.RespSpecial, "FT.SEARCH") + +// // Add search results +// searchResult := map[string]interface{}{ +// "total": 5, +// "docs": []interface{}{"doc1", "doc2"}, +// } + +// err := aggregator.Add(searchResult, nil) +// if err != nil { +// t.Errorf("Failed to add result: %v", err) +// } + +// result, err := aggregator.Finish() +// if err != nil { +// t.Errorf("Special aggregation failed: %v", err) +// } + +// if result == nil { +// t.Error("Expected non-nil result from special aggregation") +// } +// }) +// } + +// func TestShardPicking(t *testing.T) { +// // Test that arbitrary shard picking works correctly and doesn't always pick the first shard + +// opt := &ClusterOptions{ +// Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"}, +// } + +// if testing.Short() { +// t.Skip("skipping cluster test in short mode") +// } + +// client := NewClusterClient(opt) +// defer client.Close() + +// ctx := context.Background() + +// // Track which shards are picked +// shardCounts := make(map[string]int) +// var mu 
sync.Mutex + +// // Execute keyless commands multiple times +// var wg sync.WaitGroup +// numRequests := 100 + +// for i := 0; i < numRequests; i++ { +// wg.Add(1) +// go func() { +// defer wg.Done() + +// node := client.pickArbitraryShard(ctx) +// if node != nil { +// addr := node.Client.Options().Addr +// mu.Lock() +// shardCounts[addr]++ +// mu.Unlock() +// } +// }() +// } + +// wg.Wait() + +// // Verify that multiple shards were used (not just the first one) +// if len(shardCounts) < 2 { +// t.Error("Shard picking should distribute across multiple shards") +// } + +// // Verify reasonable distribution (no shard should have more than 80% of requests) +// for addr, count := range shardCounts { +// percentage := float64(count) / float64(numRequests) * 100 +// if percentage > 80 { +// t.Errorf("Shard %s got %d%% of requests, distribution should be more even", addr, int(percentage)) +// } +// t.Logf("Shard %s: %d requests (%.1f%%)", addr, count, percentage) +// } +// } + +// func TestCursorRouting(t *testing.T) { +// // Test that cursor commands are routed to the correct shard + +// opt := &ClusterOptions{ +// Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"}, +// } + +// if testing.Short() { +// t.Skip("skipping cluster test in short mode") +// } + +// client := NewClusterClient(opt) +// defer client.Close() + +// ctx := context.Background() + +// // Test FT.CURSOR command routing +// cmd := NewCmd(ctx, "FT.CURSOR", "READ", "myindex", "cursor123", "COUNT", "10") + +// // This should not panic or return an error due to incorrect routing +// err := client.executeSpecial(ctx, cmd, &routing.CommandPolicy{ +// Request: routing.ReqSpecial, +// Response: routing.RespSpecial, +// }) + +// // We expect this to fail with connection error in test environment, but not with routing error +// if err != nil && err.Error() != "redis: connection refused" { +// t.Logf("Cursor routing test completed with expected connection error: %v", err) +// } +// } + +// // Mock command methods for testing +// type testCmd struct { +// *Cmd +// requestPolicy routing.RequestPolicy +// responsePolicy routing.ResponsePolicy +// } + +// func (c *testCmd) setRequestPolicy(policy routing.RequestPolicy) { +// c.requestPolicy = policy +// } + +// func (c *testCmd) setResponsePolicy(policy routing.ResponsePolicy) { +// c.responsePolicy = policy +// } + +// func TestRaceConditionFree(t *testing.T) { +// // Test to ensure no race conditions in concurrent access + +// opt := &ClusterOptions{ +// Addrs: []string{"127.0.0.1:7000"}, +// } + +// if testing.Short() { +// t.Skip("skipping cluster test in short mode") +// } + +// client := NewClusterClient(opt) +// defer client.Close() + +// // Run with race detector enabled: go test -race +// var wg sync.WaitGroup +// numGoroutines := 100 + +// for i := 0; i < numGoroutines; i++ { +// wg.Add(1) +// go func(id int) { +// defer wg.Done() + +// ctx := context.Background() + +// // Simulate concurrent command execution +// for j := 0; j < 10; j++ { +// cmd := NewCmd(ctx, "PING") +// _ = client.routeAndRun(ctx, cmd) + +// // Small delay to increase chance of race conditions +// time.Sleep(time.Microsecond) +// } +// }(i) +// } + +// wg.Wait() + +// // If we reach here without race detector complaints, test passes +// t.Log("Race condition test completed successfully") +// } diff --git a/osscluster_test.go b/osscluster_test.go index 3659ec6547..4562d2f1c1 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -6,6 +6,9 @@ import ( "errors" "fmt" "net" + "os" + "runtime" 
+ "runtime/pprof" "slices" "strconv" "strings" @@ -15,11 +18,19 @@ import ( . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" - + "github.com/fortytw2/leaktest" "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/internal/hashtag" ) +// leakCleanup holds the per-spec leak check function +var leakCleanup func() + +// sanitizeFilename converts spaces and slashes into underscores +func sanitizeFilename(s string) string { + return strings.NewReplacer(" ", "_", "/", "_").Replace(s) +} + type clusterScenario struct { ports []string nodeIDs []string @@ -257,7 +268,7 @@ func slotEqual(s1, s2 redis.ClusterSlot) bool { return true } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ var _ = Describe("ClusterClient", func() { var failover bool @@ -1186,6 +1197,7 @@ var _ = Describe("ClusterClient", func() { }) Expect(err).NotTo(HaveOccurred()) + var mu sync.Mutex var stack []string clusterHook := &hook{ @@ -1198,12 +1210,16 @@ var _ = Describe("ClusterClient", func() { } Expect(cmd.String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "cluster.BeforeProcess") + mu.Unlock() err := hook(ctx, cmd) Expect(cmd.String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "cluster.AfterProcess") + mu.Unlock() return err } @@ -1221,12 +1237,16 @@ var _ = Describe("ClusterClient", func() { } Expect(cmd.String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "shard.BeforeProcess") + mu.Unlock() err := hook(ctx, cmd) Expect(cmd.String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "shard.AfterProcess") + mu.Unlock() return err } @@ -1240,7 +1260,13 @@ var _ = Describe("ClusterClient", func() { err = client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) - Expect(stack).To(Equal([]string{ + + mu.Lock() + finalStack := make([]string, len(stack)) + copy(finalStack, stack) + mu.Unlock() + + Expect(finalStack).To(ContainElements([]string{ "cluster.BeforeProcess", "shard.BeforeProcess", "shard.AfterProcess", @@ -1257,6 +1283,7 @@ var _ = Describe("ClusterClient", func() { }) Expect(err).NotTo(HaveOccurred()) + var mu sync.Mutex var stack []string client.AddHook(&hook{ @@ -1264,13 +1291,17 @@ var _ = Describe("ClusterClient", func() { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "cluster.BeforeProcessPipeline") + mu.Unlock() err := hook(ctx, cmds) Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "cluster.AfterProcessPipeline") + mu.Unlock() return err } @@ -1283,13 +1314,17 @@ var _ = Describe("ClusterClient", func() { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "shard.BeforeProcessPipeline") + mu.Unlock() err := hook(ctx, cmds) Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "shard.AfterProcessPipeline") + mu.Unlock() return err } @@ -1303,7 +1338,13 @@ var _ = Describe("ClusterClient", func() { return nil }) Expect(err).NotTo(HaveOccurred()) - Expect(stack).To(Equal([]string{ + + mu.Lock() + finalStack := make([]string, len(stack)) + copy(finalStack, stack) + mu.Unlock() + + Expect(finalStack).To(Equal([]string{ "cluster.BeforeProcessPipeline", "shard.BeforeProcessPipeline", 
"shard.AfterProcessPipeline", @@ -1320,6 +1361,7 @@ var _ = Describe("ClusterClient", func() { }) Expect(err).NotTo(HaveOccurred()) + var mu sync.Mutex var stack []string client.AddHook(&hook{ @@ -1327,13 +1369,17 @@ var _ = Describe("ClusterClient", func() { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "cluster.BeforeProcessPipeline") + mu.Unlock() err := hook(ctx, cmds) Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "cluster.AfterProcessPipeline") + mu.Unlock() return err } @@ -1346,13 +1392,17 @@ var _ = Describe("ClusterClient", func() { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "shard.BeforeProcessPipeline") + mu.Unlock() err := hook(ctx, cmds) Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "shard.AfterProcessPipeline") + mu.Unlock() return err } @@ -1366,7 +1416,13 @@ var _ = Describe("ClusterClient", func() { return nil }) Expect(err).NotTo(HaveOccurred()) - Expect(stack).To(Equal([]string{ + + mu.Lock() + finalStack := make([]string, len(stack)) + copy(finalStack, stack) + mu.Unlock() + + Expect(finalStack).To(Equal([]string{ "cluster.BeforeProcessPipeline", "shard.BeforeProcessPipeline", "shard.AfterProcessPipeline", @@ -1533,6 +1589,8 @@ var _ = Describe("ClusterClient", func() { Describe("ClusterClient with ClusterSlots with multiple nodes per slot", func() { BeforeEach(func() { + leakCleanup = leaktest.Check(GinkgoT()) + GinkgoWriter.Printf("[DEBUG] goroutines at start: %d\n", runtime.NumGoroutine()) failover = true opt = redisClusterOptions() @@ -1582,6 +1640,21 @@ var _ = Describe("ClusterClient", func() { }) AfterEach(func() { + leakCleanup() + + // on failure, write out a full goroutine dump + if CurrentSpecReport().Failed() { + fname := fmt.Sprintf("goroutines-%s.txt", sanitizeFilename(CurrentSpecReport().LeafNodeText)) + if f, err := os.Create(fname); err == nil { + pprof.Lookup("goroutine").WriteTo(f, 2) + f.Close() + GinkgoWriter.Printf("[DEBUG] wrote goroutine dump to %s\n", fname) + } else { + GinkgoWriter.Printf("[DEBUG] failed to write goroutine dump: %v\n", err) + } + } + + GinkgoWriter.Printf("[DEBUG] goroutines at end: %d\n", runtime.NumGoroutine()) failover = false err := client.Close() diff --git a/probabilistic.go b/probabilistic.go index c26e7cadbd..ee67911e69 100644 --- a/probabilistic.go +++ b/probabilistic.go @@ -225,8 +225,9 @@ type ScanDumpCmd struct { func newScanDumpCmd(ctx context.Context, args ...interface{}) *ScanDumpCmd { return &ScanDumpCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeScanDump, }, } } @@ -270,6 +271,13 @@ func (cmd *ScanDumpCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ScanDumpCmd) Clone() Cmder { + return &ScanDumpCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // ScanDump is a simple struct, can be copied directly + } +} + // Returns information about a Bloom filter. 
// For more information - https://redis.io/commands/bf.info/ func (c cmdable) BFInfo(ctx context.Context, key string) *BFInfoCmd { @@ -296,8 +304,9 @@ type BFInfoCmd struct { func NewBFInfoCmd(ctx context.Context, args ...interface{}) *BFInfoCmd { return &BFInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBFInfo, }, } } @@ -388,6 +397,13 @@ func (cmd *BFInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *BFInfoCmd) Clone() Cmder { + return &BFInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // BFInfo is a simple struct, can be copied directly + } +} + // BFInfoCapacity returns information about the capacity of a Bloom filter. // For more information - https://redis.io/commands/bf.info/ func (c cmdable) BFInfoCapacity(ctx context.Context, key string) *BFInfoCmd { @@ -625,8 +641,9 @@ type CFInfoCmd struct { func NewCFInfoCmd(ctx context.Context, args ...interface{}) *CFInfoCmd { return &CFInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCFInfo, }, } } @@ -692,6 +709,13 @@ func (cmd *CFInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *CFInfoCmd) Clone() Cmder { + return &CFInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // CFInfo is a simple struct, can be copied directly + } +} + // CFInfo returns information about a Cuckoo filter. // For more information - https://redis.io/commands/cf.info/ func (c cmdable) CFInfo(ctx context.Context, key string) *CFInfoCmd { @@ -787,8 +811,9 @@ type CMSInfoCmd struct { func NewCMSInfoCmd(ctx context.Context, args ...interface{}) *CMSInfoCmd { return &CMSInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCMSInfo, }, } } @@ -843,6 +868,13 @@ func (cmd *CMSInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *CMSInfoCmd) Clone() Cmder { + return &CMSInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // CMSInfo is a simple struct, can be copied directly + } +} + // CMSInfo returns information about a Count-Min Sketch filter. // For more information - https://redis.io/commands/cms.info/ func (c cmdable) CMSInfo(ctx context.Context, key string) *CMSInfoCmd { @@ -980,8 +1012,9 @@ type TopKInfoCmd struct { func NewTopKInfoCmd(ctx context.Context, args ...interface{}) *TopKInfoCmd { return &TopKInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTopKInfo, }, } } @@ -1038,6 +1071,13 @@ func (cmd *TopKInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TopKInfoCmd) Clone() Cmder { + return &TopKInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // TopKInfo is a simple struct, can be copied directly + } +} + // TopKInfo returns information about a Top-K filter. 
// For more information - https://redis.io/commands/topk.info/ func (c cmdable) TopKInfo(ctx context.Context, key string) *TopKInfoCmd { @@ -1227,8 +1267,9 @@ type TDigestInfoCmd struct { func NewTDigestInfoCmd(ctx context.Context, args ...interface{}) *TDigestInfoCmd { return &TDigestInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTDigestInfo, }, } } @@ -1295,6 +1336,13 @@ func (cmd *TDigestInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TDigestInfoCmd) Clone() Cmder { + return &TDigestInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // TDigestInfo is a simple struct, can be copied directly + } +} + // TDigestInfo returns information about a t-Digest data structure. // For more information - https://redis.io/commands/tdigest.info/ func (c cmdable) TDigestInfo(ctx context.Context, key string) *TDigestInfoCmd { diff --git a/search_commands.go b/search_commands.go index f0ca1bfede..e1f21c4749 100644 --- a/search_commands.go +++ b/search_commands.go @@ -672,8 +672,9 @@ func ProcessAggregateResult(data []interface{}) (*FTAggregateResult, error) { func NewAggregateCmd(ctx context.Context, args ...interface{}) *AggregateCmd { return &AggregateCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeAggregate, }, } } @@ -714,6 +715,31 @@ func (cmd *AggregateCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *AggregateCmd) Clone() Cmder { + var val *FTAggregateResult + if cmd.val != nil { + val = &FTAggregateResult{ + Total: cmd.val.Total, + } + if cmd.val.Rows != nil { + val.Rows = make([]AggregateRow, len(cmd.val.Rows)) + for i, row := range cmd.val.Rows { + val.Rows[i] = AggregateRow{} + if row.Fields != nil { + val.Rows[i].Fields = make(map[string]interface{}, len(row.Fields)) + for k, v := range row.Fields { + val.Rows[i].Fields[k] = v + } + } + } + } + } + return &AggregateCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // FTAggregateWithArgs - Performs a search query on an index and applies a series of aggregate transformations to the result. // The 'index' parameter specifies the index to search, and the 'query' parameter specifies the search query. // This function also allows for specifying additional options such as: Verbatim, LoadAll, Load, Timeout, GroupBy, SortBy, SortByMax, Apply, LimitOffset, Limit, Filter, WithCursor, Params, and DialectVersion. 
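
The Clone methods introduced across this series exist so the cluster router can fan a single logical command out to several shards without the per-shard copies sharing mutable state. A minimal sketch of the intended calling pattern, using only the Clone method added above; runOnShard and the fixed three-shard loop are hypothetical, for illustration only:

	orig := NewAggregateCmd(ctx, "FT.AGGREGATE", "idx", "*")
	for shard := 0; shard < 3; shard++ {
		perShard := orig.Clone().(*AggregateCmd) // deep copy: Rows slice and each row's Fields map
		go runOnShard(shard, perShard)           // each goroutine may SetVal on its own copy
	}

Because each AggregateRow's Fields map is copied key by key rather than aliased, one shard writing its reply cannot corrupt another shard's view of the same row.
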
@@ -1464,8 +1490,9 @@ type FTInfoCmd struct { func newFTInfoCmd(ctx context.Context, args ...interface{}) *FTInfoCmd { return &FTInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTInfo, }, } } @@ -1527,6 +1554,68 @@ func (cmd *FTInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *FTInfoCmd) Clone() Cmder { + val := FTInfoResult{ + IndexErrors: cmd.val.IndexErrors, + BytesPerRecordAvg: cmd.val.BytesPerRecordAvg, + Cleaning: cmd.val.Cleaning, + CursorStats: cmd.val.CursorStats, + DocTableSizeMB: cmd.val.DocTableSizeMB, + GCStats: cmd.val.GCStats, + GeoshapesSzMB: cmd.val.GeoshapesSzMB, + HashIndexingFailures: cmd.val.HashIndexingFailures, + IndexDefinition: cmd.val.IndexDefinition, + IndexName: cmd.val.IndexName, + Indexing: cmd.val.Indexing, + InvertedSzMB: cmd.val.InvertedSzMB, + KeyTableSizeMB: cmd.val.KeyTableSizeMB, + MaxDocID: cmd.val.MaxDocID, + NumDocs: cmd.val.NumDocs, + NumRecords: cmd.val.NumRecords, + NumTerms: cmd.val.NumTerms, + NumberOfUses: cmd.val.NumberOfUses, + OffsetBitsPerRecordAvg: cmd.val.OffsetBitsPerRecordAvg, + OffsetVectorsSzMB: cmd.val.OffsetVectorsSzMB, + OffsetsPerTermAvg: cmd.val.OffsetsPerTermAvg, + PercentIndexed: cmd.val.PercentIndexed, + RecordsPerDocAvg: cmd.val.RecordsPerDocAvg, + SortableValuesSizeMB: cmd.val.SortableValuesSizeMB, + TagOverheadSzMB: cmd.val.TagOverheadSzMB, + TextOverheadSzMB: cmd.val.TextOverheadSzMB, + TotalIndexMemorySzMB: cmd.val.TotalIndexMemorySzMB, + TotalIndexingTime: cmd.val.TotalIndexingTime, + TotalInvertedIndexBlocks: cmd.val.TotalInvertedIndexBlocks, + VectorIndexSzMB: cmd.val.VectorIndexSzMB, + } + // Clone slices and maps + if cmd.val.Attributes != nil { + val.Attributes = make([]FTAttribute, len(cmd.val.Attributes)) + copy(val.Attributes, cmd.val.Attributes) + } + if cmd.val.DialectStats != nil { + val.DialectStats = make(map[string]int, len(cmd.val.DialectStats)) + for k, v := range cmd.val.DialectStats { + val.DialectStats[k] = v + } + } + if cmd.val.FieldStatistics != nil { + val.FieldStatistics = make([]FieldStatistic, len(cmd.val.FieldStatistics)) + copy(val.FieldStatistics, cmd.val.FieldStatistics) + } + if cmd.val.IndexOptions != nil { + val.IndexOptions = make([]string, len(cmd.val.IndexOptions)) + copy(val.IndexOptions, cmd.val.IndexOptions) + } + if cmd.val.IndexDefinition.Prefixes != nil { + val.IndexDefinition.Prefixes = make([]string, len(cmd.val.IndexDefinition.Prefixes)) + copy(val.IndexDefinition.Prefixes, cmd.val.IndexDefinition.Prefixes) + } + return &FTInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // FTInfo - Retrieves information about an index. // The 'index' parameter specifies the index to retrieve information about. 
// For more information, please refer to the Redis documentation: @@ -1583,8 +1672,9 @@ type FTSpellCheckCmd struct { func newFTSpellCheckCmd(ctx context.Context, args ...interface{}) *FTSpellCheckCmd { return &FTSpellCheckCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTSpellCheck, }, } } @@ -1680,6 +1770,26 @@ func parseFTSpellCheck(data []interface{}) ([]SpellCheckResult, error) { return results, nil } +func (cmd *FTSpellCheckCmd) Clone() Cmder { + var val []SpellCheckResult + if cmd.val != nil { + val = make([]SpellCheckResult, len(cmd.val)) + for i, result := range cmd.val { + val[i] = SpellCheckResult{ + Term: result.Term, + } + if result.Suggestions != nil { + val[i].Suggestions = make([]SpellCheckSuggestion, len(result.Suggestions)) + copy(val[i].Suggestions, result.Suggestions) + } + } + } + return &FTSpellCheckCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + func parseFTSearch(data []interface{}, noContent, withScores, withPayloads, withSortKeys bool) (FTSearchResult, error) { if len(data) < 1 { return FTSearchResult{}, fmt.Errorf("unexpected search result format") @@ -1776,8 +1886,9 @@ type FTSearchCmd struct { func newFTSearchCmd(ctx context.Context, options *FTSearchOptions, args ...interface{}) *FTSearchCmd { return &FTSearchCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTSearch, }, options: options, } @@ -1819,6 +1930,89 @@ func (cmd *FTSearchCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *FTSearchCmd) Clone() Cmder { + val := FTSearchResult{ + Total: cmd.val.Total, + } + if cmd.val.Docs != nil { + val.Docs = make([]Document, len(cmd.val.Docs)) + for i, doc := range cmd.val.Docs { + val.Docs[i] = Document{ + ID: doc.ID, + Score: doc.Score, + Payload: doc.Payload, + SortKey: doc.SortKey, + } + if doc.Fields != nil { + val.Docs[i].Fields = make(map[string]string, len(doc.Fields)) + for k, v := range doc.Fields { + val.Docs[i].Fields[k] = v + } + } + } + } + var options *FTSearchOptions + if cmd.options != nil { + options = &FTSearchOptions{ + NoContent: cmd.options.NoContent, + Verbatim: cmd.options.Verbatim, + NoStopWords: cmd.options.NoStopWords, + WithScores: cmd.options.WithScores, + WithPayloads: cmd.options.WithPayloads, + WithSortKeys: cmd.options.WithSortKeys, + Slop: cmd.options.Slop, + Timeout: cmd.options.Timeout, + InOrder: cmd.options.InOrder, + Language: cmd.options.Language, + Expander: cmd.options.Expander, + Scorer: cmd.options.Scorer, + ExplainScore: cmd.options.ExplainScore, + Payload: cmd.options.Payload, + SortByWithCount: cmd.options.SortByWithCount, + LimitOffset: cmd.options.LimitOffset, + Limit: cmd.options.Limit, + CountOnly: cmd.options.CountOnly, + DialectVersion: cmd.options.DialectVersion, + } + // Clone slices and maps + if cmd.options.Filters != nil { + options.Filters = make([]FTSearchFilter, len(cmd.options.Filters)) + copy(options.Filters, cmd.options.Filters) + } + if cmd.options.GeoFilter != nil { + options.GeoFilter = make([]FTSearchGeoFilter, len(cmd.options.GeoFilter)) + copy(options.GeoFilter, cmd.options.GeoFilter) + } + if cmd.options.InKeys != nil { + options.InKeys = make([]interface{}, len(cmd.options.InKeys)) + copy(options.InKeys, cmd.options.InKeys) + } + if cmd.options.InFields != nil { + options.InFields = make([]interface{}, len(cmd.options.InFields)) + copy(options.InFields, cmd.options.InFields) + } + if cmd.options.Return != nil { + options.Return = make([]FTSearchReturn, len(cmd.options.Return)) + 
copy(options.Return, cmd.options.Return) + } + if cmd.options.SortBy != nil { + options.SortBy = make([]FTSearchSortBy, len(cmd.options.SortBy)) + copy(options.SortBy, cmd.options.SortBy) + } + if cmd.options.Params != nil { + options.Params = make(map[string]interface{}, len(cmd.options.Params)) + for k, v := range cmd.options.Params { + options.Params[k] = v + } + } + } + return &FTSearchCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + options: options, + } +} + // FTSearch - Executes a search query on an index. // The 'index' parameter specifies the index to search, and the 'query' parameter specifies the search query. // For more information, please refer to the Redis documentation about [FT.SEARCH]. @@ -2078,8 +2272,9 @@ func (c cmdable) FTSearchWithArgs(ctx context.Context, index string, query strin func NewFTSynDumpCmd(ctx context.Context, args ...interface{}) *FTSynDumpCmd { return &FTSynDumpCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTSynDump, }, } } @@ -2145,6 +2340,26 @@ func (cmd *FTSynDumpCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *FTSynDumpCmd) Clone() Cmder { + var val []FTSynDumpResult + if cmd.val != nil { + val = make([]FTSynDumpResult, len(cmd.val)) + for i, result := range cmd.val { + val[i] = FTSynDumpResult{ + Term: result.Term, + } + if result.Synonyms != nil { + val[i].Synonyms = make([]string, len(result.Synonyms)) + copy(val[i].Synonyms, result.Synonyms) + } + } + } + return &FTSynDumpCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // FTSynDump - Dumps the contents of a synonym group. // The 'index' parameter specifies the index to dump. // For more information, please refer to the Redis documentation: diff --git a/timeseries_commands.go b/timeseries_commands.go index 82d8cdfcf5..71ed6af238 100644 --- a/timeseries_commands.go +++ b/timeseries_commands.go @@ -486,8 +486,9 @@ type TSTimestampValueCmd struct { func newTSTimestampValueCmd(ctx context.Context, args ...interface{}) *TSTimestampValueCmd { return &TSTimestampValueCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTSTimestampValue, }, } } @@ -533,6 +534,13 @@ func (cmd *TSTimestampValueCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TSTimestampValueCmd) Clone() Cmder { + return &TSTimestampValueCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // TSTimestampValue is a simple struct, can be copied directly + } +} + // TSInfo - Returns information about a time-series key. // For more information - https://redis.io/commands/ts.info/ func (c cmdable) TSInfo(ctx context.Context, key string) *MapStringInterfaceCmd { @@ -704,8 +712,9 @@ type TSTimestampValueSliceCmd struct { func newTSTimestampValueSliceCmd(ctx context.Context, args ...interface{}) *TSTimestampValueSliceCmd { return &TSTimestampValueSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTSTimestampValueSlice, }, } } @@ -752,6 +761,18 @@ func (cmd *TSTimestampValueSliceCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TSTimestampValueSliceCmd) Clone() Cmder { + var val []TSTimestampValue + if cmd.val != nil { + val = make([]TSTimestampValue, len(cmd.val)) + copy(val, cmd.val) + } + return &TSTimestampValueSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // TSMRange - Returns a range of samples from multiple time-series keys. 
// For more information - https://redis.io/commands/ts.mrange/ func (c cmdable) TSMRange(ctx context.Context, fromTimestamp int, toTimestamp int, filterExpr []string) *MapStringSliceInterfaceCmd { From 9cffa7941b73087e9401f9e8da01ab6a943458ec Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Mon, 30 Jun 2025 15:45:29 +0300 Subject: [PATCH 10/62] remove thread debugging code --- main_test.go | 12 -- osscluster_router_test.go | 379 -------------------------------------- 2 files changed, 391 deletions(-) delete mode 100644 osscluster_router_test.go diff --git a/main_test.go b/main_test.go index 81cc6d2aaf..4cd9bc6719 100644 --- a/main_test.go +++ b/main_test.go @@ -4,8 +4,6 @@ import ( "fmt" "net" "os" - "runtime" - "runtime/pprof" "strconv" "strings" "sync" @@ -151,22 +149,12 @@ var _ = BeforeSuite(func() { // populate cluster node information Expect(configureClusterTopology(ctx, cluster)).NotTo(HaveOccurred()) } - runtime.SetBlockProfileRate(1) - runtime.SetMutexProfileFraction(1) }) var _ = AfterSuite(func() { if !RECluster { Expect(cluster.Close()).NotTo(HaveOccurred()) } - if f, err := os.Create("block.pprof"); err == nil { - pprof.Lookup("block").WriteTo(f, 0) - f.Close() - } - if f, err := os.Create("mutex.pprof"); err == nil { - pprof.Lookup("mutex").WriteTo(f, 0) - f.Close() - } }) func TestGinkgoSuite(t *testing.T) { diff --git a/osscluster_router_test.go b/osscluster_router_test.go deleted file mode 100644 index d2b3f94440..0000000000 --- a/osscluster_router_test.go +++ /dev/null @@ -1,379 +0,0 @@ -package redis - -// import ( -// "context" -// "sync" -// "testing" -// "time" - -// . "github.com/bsm/ginkgo/v2" -// . "github.com/bsm/gomega" - -// "github.com/redis/go-redis/v9/internal/routing" -// ) - -// var _ = Describe("ExtractCommandValue", func() { -// It("should extract value from generic command", func() { -// cmd := NewCmd(nil, "test") -// cmd.SetVal("value") -// val := routing.ExtractCommandValue(cmd) -// Expect(val).To(Equal("value")) -// }) - -// It("should extract value from integer command", func() { -// intCmd := NewIntCmd(nil, "test") -// intCmd.SetVal(42) -// val := routing.ExtractCommandValue(intCmd) -// Expect(val).To(Equal(int64(42))) -// }) - -// It("should handle nil command", func() { -// val := routing.ExtractCommandValue(nil) -// Expect(val).To(BeNil()) -// }) -// }) - -// var _ = Describe("ClusterClient setCommandValue", func() { -// var client *ClusterClient - -// BeforeEach(func() { -// client = &ClusterClient{} -// }) - -// It("should set generic value", func() { -// cmd := NewCmd(nil, "test") -// err := client.setCommandValue(cmd, "new_value") -// Expect(err).NotTo(HaveOccurred()) -// Expect(cmd.Val()).To(Equal("new_value")) -// }) - -// It("should set integer value", func() { -// intCmd := NewIntCmd(nil, "test") -// err := client.setCommandValue(intCmd, int64(100)) -// Expect(err).NotTo(HaveOccurred()) -// Expect(intCmd.Val()).To(Equal(int64(100))) -// }) - -// It("should return error for type mismatch", func() { -// intCmd := NewIntCmd(nil, "test") -// err := client.setCommandValue(intCmd, "string_value") -// Expect(err).To(HaveOccurred()) -// Expect(err.Error()).To(ContainSubstring("cannot set IntCmd value from string")) -// }) -// }) - -// func TestConcurrentRouting(t *testing.T) { -// // This test ensures that concurrent execution doesn't cause response mismatches -// // or MOVED errors due to race conditions - -// // Mock cluster client for testing -// opt := &ClusterOptions{ -// Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"}, 
-// } - -// // Skip if no cluster available -// if testing.Short() { -// t.Skip("skipping cluster test in short mode") -// } - -// client := NewClusterClient(opt) -// defer client.Close() - -// // Test concurrent execution of commands with different policies -// var wg sync.WaitGroup -// numRoutines := 50 -// numCommands := 100 - -// // Channel to collect errors -// errors := make(chan error, numRoutines*numCommands) - -// for i := 0; i < numRoutines; i++ { -// wg.Add(1) -// go func(routineID int) { -// defer wg.Done() - -// for j := 0; j < numCommands; j++ { -// ctx := context.Background() - -// // Test different command types -// switch j % 4 { -// case 0: -// // Test keyless command (should use arbitrary shard) -// cmd := NewCmd(ctx, "PING") -// err := client.routeAndRun(ctx, cmd) -// if err != nil { -// errors <- err -// } -// case 1: -// // Test keyed command (should use slot-based routing) -// key := "test_key_" + string(rune(routineID)) + "_" + string(rune(j)) -// cmd := NewCmd(ctx, "GET", key) -// err := client.routeAndRun(ctx, cmd) -// if err != nil { -// errors <- err -// } -// case 2: -// // Test multi-shard command -// cmd := NewCmd(ctx, "MGET", "key1", "key2", "key3") -// err := client.routeAndRun(ctx, cmd) -// if err != nil { -// errors <- err -// } -// case 3: -// // Test all-shards command -// cmd := NewCmd(ctx, "DBSIZE") -// // Note: In actual implementation, the policy would come from COMMAND tips -// err := client.routeAndRun(ctx, cmd) -// if err != nil { -// errors <- err -// } -// } -// } -// }(i) -// } - -// // Wait for all routines to complete -// wg.Wait() -// close(errors) - -// // Check for errors -// var errorCount int -// for err := range errors { -// t.Errorf("Concurrent routing error: %v", err) -// errorCount++ -// if errorCount > 10 { // Limit error output -// break -// } -// } - -// if errorCount > 0 { -// t.Fatalf("Found %d errors in concurrent routing test", errorCount) -// } -// } - -// func TestResponseAggregation(t *testing.T) { -// // Test that response aggregation works correctly for different policies - -// if testing.Short() { -// t.Skip("skipping cluster test in short mode") -// } - -// // Test all_succeeded aggregation -// t.Run("AllSucceeded", func(t *testing.T) { -// aggregator := routing.NewResponseAggregator(routing.RespAllSucceeded, "TEST") - -// // Add successful results -// err := aggregator.Add("result1", nil) -// if err != nil { -// t.Errorf("Failed to add result: %v", err) -// } - -// err = aggregator.Add("result2", nil) -// if err != nil { -// t.Errorf("Failed to add result: %v", err) -// } - -// result, err := aggregator.Finish() -// if err != nil { -// t.Errorf("AllSucceeded aggregation failed: %v", err) -// } - -// if result != "result1" { -// t.Errorf("Expected 'result1', got %v", result) -// } -// }) - -// // Test agg_sum aggregation -// t.Run("AggSum", func(t *testing.T) { -// aggregator := routing.NewResponseAggregator(routing.RespAggSum, "TEST") - -// // Add numeric results -// err := aggregator.Add(int64(5), nil) -// if err != nil { -// t.Errorf("Failed to add result: %v", err) -// } - -// err = aggregator.Add(int64(10), nil) -// if err != nil { -// t.Errorf("Failed to add result: %v", err) -// } - -// result, err := aggregator.Finish() -// if err != nil { -// t.Errorf("AggSum aggregation failed: %v", err) -// } - -// if result != int64(15) { -// t.Errorf("Expected 15, got %v", result) -// } -// }) - -// // Test special aggregation for search commands -// t.Run("Special", func(t *testing.T) { -// aggregator := 
routing.NewResponseAggregator(routing.RespSpecial, "FT.SEARCH") - -// // Add search results -// searchResult := map[string]interface{}{ -// "total": 5, -// "docs": []interface{}{"doc1", "doc2"}, -// } - -// err := aggregator.Add(searchResult, nil) -// if err != nil { -// t.Errorf("Failed to add result: %v", err) -// } - -// result, err := aggregator.Finish() -// if err != nil { -// t.Errorf("Special aggregation failed: %v", err) -// } - -// if result == nil { -// t.Error("Expected non-nil result from special aggregation") -// } -// }) -// } - -// func TestShardPicking(t *testing.T) { -// // Test that arbitrary shard picking works correctly and doesn't always pick the first shard - -// opt := &ClusterOptions{ -// Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"}, -// } - -// if testing.Short() { -// t.Skip("skipping cluster test in short mode") -// } - -// client := NewClusterClient(opt) -// defer client.Close() - -// ctx := context.Background() - -// // Track which shards are picked -// shardCounts := make(map[string]int) -// var mu sync.Mutex - -// // Execute keyless commands multiple times -// var wg sync.WaitGroup -// numRequests := 100 - -// for i := 0; i < numRequests; i++ { -// wg.Add(1) -// go func() { -// defer wg.Done() - -// node := client.pickArbitraryShard(ctx) -// if node != nil { -// addr := node.Client.Options().Addr -// mu.Lock() -// shardCounts[addr]++ -// mu.Unlock() -// } -// }() -// } - -// wg.Wait() - -// // Verify that multiple shards were used (not just the first one) -// if len(shardCounts) < 2 { -// t.Error("Shard picking should distribute across multiple shards") -// } - -// // Verify reasonable distribution (no shard should have more than 80% of requests) -// for addr, count := range shardCounts { -// percentage := float64(count) / float64(numRequests) * 100 -// if percentage > 80 { -// t.Errorf("Shard %s got %d%% of requests, distribution should be more even", addr, int(percentage)) -// } -// t.Logf("Shard %s: %d requests (%.1f%%)", addr, count, percentage) -// } -// } - -// func TestCursorRouting(t *testing.T) { -// // Test that cursor commands are routed to the correct shard - -// opt := &ClusterOptions{ -// Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"}, -// } - -// if testing.Short() { -// t.Skip("skipping cluster test in short mode") -// } - -// client := NewClusterClient(opt) -// defer client.Close() - -// ctx := context.Background() - -// // Test FT.CURSOR command routing -// cmd := NewCmd(ctx, "FT.CURSOR", "READ", "myindex", "cursor123", "COUNT", "10") - -// // This should not panic or return an error due to incorrect routing -// err := client.executeSpecial(ctx, cmd, &routing.CommandPolicy{ -// Request: routing.ReqSpecial, -// Response: routing.RespSpecial, -// }) - -// // We expect this to fail with connection error in test environment, but not with routing error -// if err != nil && err.Error() != "redis: connection refused" { -// t.Logf("Cursor routing test completed with expected connection error: %v", err) -// } -// } - -// // Mock command methods for testing -// type testCmd struct { -// *Cmd -// requestPolicy routing.RequestPolicy -// responsePolicy routing.ResponsePolicy -// } - -// func (c *testCmd) setRequestPolicy(policy routing.RequestPolicy) { -// c.requestPolicy = policy -// } - -// func (c *testCmd) setResponsePolicy(policy routing.ResponsePolicy) { -// c.responsePolicy = policy -// } - -// func TestRaceConditionFree(t *testing.T) { -// // Test to ensure no race conditions in concurrent access - 
-// opt := &ClusterOptions{ -// Addrs: []string{"127.0.0.1:7000"}, -// } - -// if testing.Short() { -// t.Skip("skipping cluster test in short mode") -// } - -// client := NewClusterClient(opt) -// defer client.Close() - -// // Run with race detector enabled: go test -race -// var wg sync.WaitGroup -// numGoroutines := 100 - -// for i := 0; i < numGoroutines; i++ { -// wg.Add(1) -// go func(id int) { -// defer wg.Done() - -// ctx := context.Background() - -// // Simulate concurrent command execution -// for j := 0; j < 10; j++ { -// cmd := NewCmd(ctx, "PING") -// _ = client.routeAndRun(ctx, cmd) - -// // Small delay to increase chance of race conditions -// time.Sleep(time.Microsecond) -// } -// }(i) -// } - -// wg.Wait() - -// // If we reach here without race detector complaints, test passes -// t.Log("Race condition test completed successfully") -// } From 9f6f2c9165df78afedf5cb4175eb79744f3d7640 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Fri, 4 Jul 2025 16:05:08 +0300 Subject: [PATCH 11/62] remove thread debugging code && reject commands with policy that cannot be used in pipeline --- internal/routing/policy.go | 4 + osscluster.go | 193 +++++++++++++++++-------------------- osscluster_test.go | 31 +----- 3 files changed, 95 insertions(+), 133 deletions(-) diff --git a/internal/routing/policy.go b/internal/routing/policy.go index d65efb8aef..a76dfaf19b 100644 --- a/internal/routing/policy.go +++ b/internal/routing/policy.go @@ -129,3 +129,7 @@ type CommandPolicy struct { // e.g nondeterministic_output, nondeterministic_output_order. Tips map[string]string } + +func (p *CommandPolicy) CanBeUsedInPipeline() bool { + return p.Request != ReqAllNodes && p.Request != ReqAllShards && p.Request != ReqMultiShard +} diff --git a/osscluster.go b/osscluster.go index 598b9409af..61847ed799 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1025,9 +1025,6 @@ type ClusterClient struct { // NewClusterClient returns a Redis Cluster client as described in // http://redis.io/topics/cluster-spec. 
func NewClusterClient(opt *ClusterOptions) *ClusterClient { - if opt == nil { - panic("redis: NewClusterClient nil options") - } opt.init() c := &ClusterClient{ @@ -1378,23 +1375,11 @@ func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) } func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { - // Separate commands into those that can be batched vs those that need individual routing - batchableCmds := make([]Cmder, 0) - individualCmds := make([]Cmder, 0) - - for _, cmd := range cmds { - policy := c.getCommandPolicy(ctx, cmd) + cmdsMap := newCmdsMap() - // Commands that need special routing should be handled individually - if policy != nil && (policy.Request == routing.ReqAllNodes || - policy.Request == routing.ReqAllShards || - policy.Request == routing.ReqMultiShard || - policy.Request == routing.ReqSpecial) { - individualCmds = append(individualCmds, cmd) - } else { - // Single-node commands can be batched - batchableCmds = append(batchableCmds, cmd) - } + if err := c.mapCmdsByNode(ctx, cmdsMap, cmds); err != nil { + setCmdsErr(cmds, err) + return err } for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { @@ -1405,68 +1390,74 @@ func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error } } - var allSucceeded = true - var failedBatchableCmds []Cmder - var failedIndividualCmds []Cmder + failedCmds := newCmdsMap() + var wg sync.WaitGroup - // Handle individual commands using existing router - for _, cmd := range individualCmds { - if err := c.routeAndRun(ctx, cmd, nil); err != nil { - allSucceeded = false - failedIndividualCmds = append(failedIndividualCmds, cmd) - } + for node, cmds := range cmdsMap.m { + wg.Add(1) + go func(node *clusterNode, cmds []Cmder) { + defer wg.Done() + c.processPipelineNode(ctx, node, cmds, failedCmds) + }(node, cmds) } - // Handle batchable commands using original pipeline logic - if len(batchableCmds) > 0 { - cmdsMap := newCmdsMap() + wg.Wait() + if len(failedCmds.m) == 0 { + break + } + cmdsMap = failedCmds + } - if err := c.mapCmdsByNode(ctx, cmdsMap, batchableCmds); err != nil { - setCmdsErr(batchableCmds, err) - allSucceeded = false - failedBatchableCmds = append(failedBatchableCmds, batchableCmds...) - } else { - batchFailedCmds := newCmdsMap() - var wg sync.WaitGroup - - for node, nodeCmds := range cmdsMap.m { - wg.Add(1) - go func(node *clusterNode, nodeCmds []Cmder) { - defer wg.Done() - c.processPipelineNode(ctx, node, nodeCmds, batchFailedCmds) - }(node, nodeCmds) - } + return cmdsFirstErr(cmds) +} - wg.Wait() +func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmds []Cmder) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } - if len(batchFailedCmds.m) > 0 { - allSucceeded = false - for _, nodeCmds := range batchFailedCmds.m { - failedBatchableCmds = append(failedBatchableCmds, nodeCmds...) 
- } - } + if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { + for _, cmd := range cmds { + policy := c.getCommandPolicy(ctx, cmd) + if policy != nil && !policy.CanBeUsedInPipeline() { + return fmt.Errorf("redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard", cmd.Name()) + } + slot := c.cmdSlot(ctx, cmd) + node, err := c.slotReadOnlyNode(state, slot) + if err != nil { + return err } + cmdsMap.Add(node, cmd) } + return nil + } - // If all commands succeeded, we're done - if allSucceeded { - break + for _, cmd := range cmds { + policy := c.getCommandPolicy(ctx, cmd) + if policy != nil && !policy.CanBeUsedInPipeline() { + return fmt.Errorf("redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard", cmd.Name()) } - - // If this was the last attempt, return the error - if attempt == c.opt.MaxRedirects { - break + slot := c.cmdSlot(ctx, cmd) + node, err := state.slotMasterNode(slot) + if err != nil { + return err } - - // Update command lists for retry - no reclassification needed - batchableCmds = failedBatchableCmds - individualCmds = failedIndividualCmds + cmdsMap.Add(node, cmd) } + return nil +} - return cmdsFirstErr(cmds) +func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool { + for _, cmd := range cmds { + cmdInfo := c.cmdInfo(ctx, cmd.Name()) + if cmdInfo == nil || !cmdInfo.ReadOnly { + return false + } + } + return true } -// processPipelineNode handles batched pipeline commands for a single node func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { @@ -1476,8 +1467,7 @@ func (c *ClusterClient) processPipelineNode( if !isContextError(err) { node.MarkAsFailing() } - // Commands are already mapped to this node, just add them as failed - failedCmds.Add(node, cmds...) + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) setCmdsErr(cmds, err) return err } @@ -1502,8 +1492,7 @@ func (c *ClusterClient) processPipelineNodeConn( node.MarkAsFailing() } if shouldRetry(err, true) { - // Commands are already mapped to this node, just add them as failed - failedCmds.Add(node, cmds...) + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) } setCmdsErr(cmds, err) return err @@ -1539,8 +1528,7 @@ func (c *ClusterClient) pipelineReadCmds( if !isRedisError(err) { if shouldRetry(err, true) { - // Commands are already mapped to this node, just add them as failed - failedCmds.Add(node, cmds[i:]...) + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) } setCmdsErr(cmds[i+1:], err) return err @@ -1548,8 +1536,7 @@ func (c *ClusterClient) pipelineReadCmds( } if err := cmds[0].Err(); err != nil && shouldRetry(err, true) { - // Commands are already mapped to this node, just add them as failed - failedCmds.Add(node, cmds...) + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) return err } @@ -1630,35 +1617,6 @@ func (c *ClusterClient) checkMovedErr( panic("not reached") } -func (c *ClusterClient) cmdsMoved( - ctx context.Context, cmds []Cmder, - moved, ask bool, - addr string, - failedCmds *cmdsMap, -) error { - node, err := c.nodes.GetOrCreate(addr) - if err != nil { - return err - } - - if moved { - c.state.LazyReload() - for _, cmd := range cmds { - failedCmds.Add(node, cmd) - } - return nil - } - - if ask { - for _, cmd := range cmds { - failedCmds.Add(node, NewCmd(ctx, "asking"), cmd) - } - return nil - } - - return nil -} - // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 
func (c *ClusterClient) TxPipeline() Pipeliner { pipe := Pipeline{ @@ -1882,6 +1840,35 @@ func (c *ClusterClient) txPipelineReadQueued( return nil } +func (c *ClusterClient) cmdsMoved( + ctx context.Context, cmds []Cmder, + moved, ask bool, + addr string, + failedCmds *cmdsMap, +) error { + node, err := c.nodes.GetOrCreate(addr) + if err != nil { + return err + } + + if moved { + c.state.LazyReload() + for _, cmd := range cmds { + failedCmds.Add(node, cmd) + } + return nil + } + + if ask { + for _, cmd := range cmds { + failedCmds.Add(node, NewCmd(ctx, "asking"), cmd) + } + return nil + } + + return nil +} + func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { return fmt.Errorf("redis: Watch requires at least one key") diff --git a/osscluster_test.go b/osscluster_test.go index 4562d2f1c1..6860388bee 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -6,9 +6,6 @@ import ( "errors" "fmt" "net" - "os" - "runtime" - "runtime/pprof" "slices" "strconv" "strings" @@ -18,19 +15,10 @@ import ( . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" - "github.com/fortytw2/leaktest" "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/internal/hashtag" ) -// leakCleanup holds the per-spec leak check function -var leakCleanup func() - -// sanitizeFilename converts spaces and slashes into underscores -func sanitizeFilename(s string) string { - return strings.NewReplacer(" ", "_", "/", "_").Replace(s) -} - type clusterScenario struct { ports []string nodeIDs []string @@ -270,7 +258,7 @@ func slotEqual(s1, s2 redis.ClusterSlot) bool { // ------------------------------------------------------------------------------ -var _ = Describe("ClusterClient", func() { +var _ = FDescribe("ClusterClient", func() { var failover bool var opt *redis.ClusterOptions var client *redis.ClusterClient @@ -1589,8 +1577,6 @@ var _ = Describe("ClusterClient", func() { Describe("ClusterClient with ClusterSlots with multiple nodes per slot", func() { BeforeEach(func() { - leakCleanup = leaktest.Check(GinkgoT()) - GinkgoWriter.Printf("[DEBUG] goroutines at start: %d\n", runtime.NumGoroutine()) failover = true opt = redisClusterOptions() @@ -1640,21 +1626,6 @@ var _ = Describe("ClusterClient", func() { }) AfterEach(func() { - leakCleanup() - - // on failure, write out a full goroutine dump - if CurrentSpecReport().Failed() { - fname := fmt.Sprintf("goroutines-%s.txt", sanitizeFilename(CurrentSpecReport().LeafNodeText)) - if f, err := os.Create(fname); err == nil { - pprof.Lookup("goroutine").WriteTo(f, 2) - f.Close() - GinkgoWriter.Printf("[DEBUG] wrote goroutine dump to %s\n", fname) - } else { - GinkgoWriter.Printf("[DEBUG] failed to write goroutine dump: %v\n", err) - } - } - - GinkgoWriter.Printf("[DEBUG] goroutines at end: %d\n", runtime.NumGoroutine()) failover = false err := client.Close() From 9087c217e29b7813246f72395df66016e3c61b7b Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Fri, 4 Jul 2025 18:02:39 +0300 Subject: [PATCH 12/62] refactor processPipline and cmdType enum --- command.go | 435 +++++++++++++++++++++++----- internal/routing/aggregator.go | 368 +---------------------- internal/routing/aggregator_test.go | 427 --------------------------- osscluster.go | 8 +- osscluster_router.go | 87 +++++- osscluster_test.go | 71 +++-- 6 files changed, 506 insertions(+), 890 deletions(-) delete mode 100644 internal/routing/aggregator_test.go diff --git a/command.go b/command.go index c7acf2237c..2ce3a11328 100644 --- a/command.go +++ b/command.go 
@@ -18,7 +18,6 @@ import ( "github.com/redis/go-redis/v9/internal/util" ) -<<<<<<< HEAD // keylessCommands contains Redis commands that have empty key specifications (9th slot empty) // Only includes core Redis commands, excludes FT.*, ts.*, timeseries.*, search.* and subcommands var keylessCommands = map[string]struct{}{ @@ -67,79 +66,90 @@ var keylessCommands = map[string]struct{}{ "unsubscribe": {}, "unwatch": {}, } -======= type CmdType = routing.CmdType +// CmdTyper interface for getting command type +type CmdTyper interface { + GetCmdType() CmdType +} + +// CmdTypeGetter interface for getting command type without circular imports +type CmdTypeGetter interface { + GetCmdType() CmdType +} + +type CmdType uint8 + const ( - CmdTypeGeneric = routing.CmdTypeGeneric - CmdTypeString = routing.CmdTypeString - CmdTypeInt = routing.CmdTypeInt - CmdTypeBool = routing.CmdTypeBool - CmdTypeFloat = routing.CmdTypeFloat - CmdTypeStringSlice = routing.CmdTypeStringSlice - CmdTypeIntSlice = routing.CmdTypeIntSlice - CmdTypeFloatSlice = routing.CmdTypeFloatSlice - CmdTypeBoolSlice = routing.CmdTypeBoolSlice - CmdTypeMapStringString = routing.CmdTypeMapStringString - CmdTypeMapStringInt = routing.CmdTypeMapStringInt - CmdTypeMapStringInterface = routing.CmdTypeMapStringInterface - CmdTypeMapStringInterfaceSlice = routing.CmdTypeMapStringInterfaceSlice - CmdTypeSlice = routing.CmdTypeSlice - CmdTypeStatus = routing.CmdTypeStatus - CmdTypeDuration = routing.CmdTypeDuration - CmdTypeTime = routing.CmdTypeTime - CmdTypeKeyValueSlice = routing.CmdTypeKeyValueSlice - CmdTypeStringStructMap = routing.CmdTypeStringStructMap - CmdTypeXMessageSlice = routing.CmdTypeXMessageSlice - CmdTypeXStreamSlice = routing.CmdTypeXStreamSlice - CmdTypeXPending = routing.CmdTypeXPending - CmdTypeXPendingExt = routing.CmdTypeXPendingExt - CmdTypeXAutoClaim = routing.CmdTypeXAutoClaim - CmdTypeXAutoClaimJustID = routing.CmdTypeXAutoClaimJustID - CmdTypeXInfoConsumers = routing.CmdTypeXInfoConsumers - CmdTypeXInfoGroups = routing.CmdTypeXInfoGroups - CmdTypeXInfoStream = routing.CmdTypeXInfoStream - CmdTypeXInfoStreamFull = routing.CmdTypeXInfoStreamFull - CmdTypeZSlice = routing.CmdTypeZSlice - CmdTypeZWithKey = routing.CmdTypeZWithKey - CmdTypeScan = routing.CmdTypeScan - CmdTypeClusterSlots = routing.CmdTypeClusterSlots - CmdTypeGeoLocation = routing.CmdTypeGeoLocation - CmdTypeGeoSearchLocation = routing.CmdTypeGeoSearchLocation - CmdTypeGeoPos = routing.CmdTypeGeoPos - CmdTypeCommandsInfo = routing.CmdTypeCommandsInfo - CmdTypeSlowLog = routing.CmdTypeSlowLog - CmdTypeMapStringStringSlice = routing.CmdTypeMapStringStringSlice - CmdTypeMapMapStringInterface = routing.CmdTypeMapMapStringInterface - CmdTypeKeyValues = routing.CmdTypeKeyValues - CmdTypeZSliceWithKey = routing.CmdTypeZSliceWithKey - CmdTypeFunctionList = routing.CmdTypeFunctionList - CmdTypeFunctionStats = routing.CmdTypeFunctionStats - CmdTypeLCS = routing.CmdTypeLCS - CmdTypeKeyFlags = routing.CmdTypeKeyFlags - CmdTypeClusterLinks = routing.CmdTypeClusterLinks - CmdTypeClusterShards = routing.CmdTypeClusterShards - CmdTypeRankWithScore = routing.CmdTypeRankWithScore - CmdTypeClientInfo = routing.CmdTypeClientInfo - CmdTypeACLLog = routing.CmdTypeACLLog - CmdTypeInfo = routing.CmdTypeInfo - CmdTypeMonitor = routing.CmdTypeMonitor - CmdTypeJSON = routing.CmdTypeJSON - CmdTypeJSONSlice = routing.CmdTypeJSONSlice - CmdTypeIntPointerSlice = routing.CmdTypeIntPointerSlice - CmdTypeScanDump = routing.CmdTypeScanDump - CmdTypeBFInfo = routing.CmdTypeBFInfo - 
CmdTypeCFInfo = routing.CmdTypeCFInfo - CmdTypeCMSInfo = routing.CmdTypeCMSInfo - CmdTypeTopKInfo = routing.CmdTypeTopKInfo - CmdTypeTDigestInfo = routing.CmdTypeTDigestInfo - CmdTypeFTSynDump = routing.CmdTypeFTSynDump - CmdTypeAggregate = routing.CmdTypeAggregate - CmdTypeFTInfo = routing.CmdTypeFTInfo - CmdTypeFTSpellCheck = routing.CmdTypeFTSpellCheck - CmdTypeFTSearch = routing.CmdTypeFTSearch - CmdTypeTSTimestampValue = routing.CmdTypeTSTimestampValue - CmdTypeTSTimestampValueSlice = routing.CmdTypeTSTimestampValueSlice + CmdTypeGeneric CmdType = iota + CmdTypeString + CmdTypeInt + CmdTypeBool + CmdTypeFloat + CmdTypeStringSlice + CmdTypeIntSlice + CmdTypeFloatSlice + CmdTypeBoolSlice + CmdTypeMapStringString + CmdTypeMapStringInt + CmdTypeMapStringInterface + CmdTypeMapStringInterfaceSlice + CmdTypeSlice + CmdTypeStatus + CmdTypeDuration + CmdTypeTime + CmdTypeKeyValueSlice + CmdTypeStringStructMap + CmdTypeXMessageSlice + CmdTypeXStreamSlice + CmdTypeXPending + CmdTypeXPendingExt + CmdTypeXAutoClaim + CmdTypeXAutoClaimJustID + CmdTypeXInfoConsumers + CmdTypeXInfoGroups + CmdTypeXInfoStream + CmdTypeXInfoStreamFull + CmdTypeZSlice + CmdTypeZWithKey + CmdTypeScan + CmdTypeClusterSlots + CmdTypeGeoLocation + CmdTypeGeoSearchLocation + CmdTypeGeoPos + CmdTypeCommandsInfo + CmdTypeSlowLog + CmdTypeMapStringStringSlice + CmdTypeMapMapStringInterface + CmdTypeKeyValues + CmdTypeZSliceWithKey + CmdTypeFunctionList + CmdTypeFunctionStats + CmdTypeLCS + CmdTypeKeyFlags + CmdTypeClusterLinks + CmdTypeClusterShards + CmdTypeRankWithScore + CmdTypeClientInfo + CmdTypeACLLog + CmdTypeInfo + CmdTypeMonitor + CmdTypeJSON + CmdTypeJSONSlice + CmdTypeIntPointerSlice + CmdTypeScanDump + CmdTypeBFInfo + CmdTypeCFInfo + CmdTypeCMSInfo + CmdTypeTopKInfo + CmdTypeTDigestInfo + CmdTypeFTSynDump + CmdTypeAggregate + CmdTypeFTInfo + CmdTypeFTSpellCheck + CmdTypeFTSearch + CmdTypeTSTimestampValue + CmdTypeTSTimestampValueSlice ) >>>>>>> b6633bf9 (centralize cluster command routing in osscluster_router.go and refactor osscluster.go (#6)) @@ -6943,6 +6953,289 @@ func (cmd *VectorScoreSliceCmd) readReply(rd *proto.Reader) error { } cmd.val[i].Score = score } + + return nil +} +// ExtractCommandValue extracts the value from a command result using the fast enum-based approach +func ExtractCommandValue(cmd interface{}) interface{} { + // First try to get the command type using the interface + if cmdTypeGetter, ok := cmd.(CmdTypeGetter); ok { + cmdType := cmdTypeGetter.GetCmdType() + + // Use fast type-based extraction + switch cmdType { + case CmdTypeString: + if stringCmd, ok := cmd.(interface{ Val() string }); ok { + return stringCmd.Val() + } + case CmdTypeInt: + if intCmd, ok := cmd.(interface{ Val() int64 }); ok { + return intCmd.Val() + } + case CmdTypeBool: + if boolCmd, ok := cmd.(interface{ Val() bool }); ok { + return boolCmd.Val() + } + case CmdTypeFloat: + if floatCmd, ok := cmd.(interface{ Val() float64 }); ok { + return floatCmd.Val() + } + case CmdTypeStatus: + if statusCmd, ok := cmd.(interface{ Val() string }); ok { + return statusCmd.Val() + } + case CmdTypeDuration: + if durationCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return durationCmd.Val() + } + case CmdTypeTime: + if timeCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return timeCmd.Val() + } + case CmdTypeStringSlice: + if stringSliceCmd, ok := cmd.(interface{ Val() []string }); ok { + return stringSliceCmd.Val() + } + case CmdTypeIntSlice: + if intSliceCmd, ok := cmd.(interface{ Val() []int64 }); ok { + return 
intSliceCmd.Val() + } + case CmdTypeBoolSlice: + if boolSliceCmd, ok := cmd.(interface{ Val() []bool }); ok { + return boolSliceCmd.Val() + } + case CmdTypeFloatSlice: + if floatSliceCmd, ok := cmd.(interface{ Val() []float64 }); ok { + return floatSliceCmd.Val() + } + case CmdTypeMapStringString: + if mapCmd, ok := cmd.(interface{ Val() map[string]string }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInt: + if mapCmd, ok := cmd.(interface{ Val() map[string]int64 }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInterfaceSlice: + if mapCmd, ok := cmd.(interface { + Val() map[string][]interface{} + }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInterface: + if mapCmd, ok := cmd.(interface{ Val() map[string]interface{} }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringStringSlice: + if mapCmd, ok := cmd.(interface{ Val() map[string][]string }); ok { + return mapCmd.Val() + } + case CmdTypeMapMapStringInterface: + if mapCmd, ok := cmd.(interface { + Val() map[string][]interface{} + }); ok { + return mapCmd.Val() + } + case CmdTypeStringStructMap: + if mapCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return mapCmd.Val() + } + case CmdTypeXMessageSlice: + if xMsgCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xMsgCmd.Val() + } + case CmdTypeXStreamSlice: + if xStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xStreamCmd.Val() + } + case CmdTypeXPending: + if xPendingCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xPendingCmd.Val() + } + case CmdTypeXPendingExt: + if xPendingExtCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xPendingExtCmd.Val() + } + case CmdTypeXAutoClaim: + if xAutoClaimCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xAutoClaimCmd.Val() + } + case CmdTypeXAutoClaimJustID: + if xAutoClaimJustIDCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xAutoClaimJustIDCmd.Val() + } + case CmdTypeXInfoConsumers: + if xInfoConsumersCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoConsumersCmd.Val() + } + case CmdTypeXInfoGroups: + if xInfoGroupsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoGroupsCmd.Val() + } + case CmdTypeXInfoStream: + if xInfoStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoStreamCmd.Val() + } + case CmdTypeXInfoStreamFull: + if xInfoStreamFullCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoStreamFullCmd.Val() + } + case CmdTypeZSlice: + if zSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zSliceCmd.Val() + } + case CmdTypeZWithKey: + if zWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zWithKeyCmd.Val() + } + case CmdTypeScan: + if scanCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return scanCmd.Val() + } + case CmdTypeClusterSlots: + if clusterSlotsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterSlotsCmd.Val() + } + case CmdTypeGeoSearchLocation: + if geoSearchLocationCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return geoSearchLocationCmd.Val() + } + case CmdTypeGeoPos: + if geoPosCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return geoPosCmd.Val() + } + case CmdTypeCommandsInfo: + if commandsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return commandsInfoCmd.Val() + } + case CmdTypeSlowLog: + if slowLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return slowLogCmd.Val() + } + + case CmdTypeKeyValues: + if keyValuesCmd, ok := 
cmd.(interface{ Val() interface{} }); ok { + return keyValuesCmd.Val() + } + case CmdTypeZSliceWithKey: + if zSliceWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zSliceWithKeyCmd.Val() + } + case CmdTypeFunctionList: + if functionListCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return functionListCmd.Val() + } + case CmdTypeFunctionStats: + if functionStatsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return functionStatsCmd.Val() + } + case CmdTypeLCS: + if lcsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return lcsCmd.Val() + } + case CmdTypeKeyFlags: + if keyFlagsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return keyFlagsCmd.Val() + } + case CmdTypeClusterLinks: + if clusterLinksCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterLinksCmd.Val() + } + case CmdTypeClusterShards: + if clusterShardsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterShardsCmd.Val() + } + case CmdTypeRankWithScore: + if rankWithScoreCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return rankWithScoreCmd.Val() + } + case CmdTypeClientInfo: + if clientInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clientInfoCmd.Val() + } + case CmdTypeACLLog: + if aclLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return aclLogCmd.Val() + } + case CmdTypeInfo: + if infoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return infoCmd.Val() + } + case CmdTypeMonitor: + if monitorCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return monitorCmd.Val() + } + case CmdTypeJSON: + if jsonCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return jsonCmd.Val() + } + case CmdTypeJSONSlice: + if jsonSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return jsonSliceCmd.Val() + } + case CmdTypeIntPointerSlice: + if intPointerSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return intPointerSliceCmd.Val() + } + case CmdTypeScanDump: + if scanDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return scanDumpCmd.Val() + } + case CmdTypeBFInfo: + if bfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return bfInfoCmd.Val() + } + case CmdTypeCFInfo: + if cfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return cfInfoCmd.Val() + } + case CmdTypeCMSInfo: + if cmsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return cmsInfoCmd.Val() + } + case CmdTypeTopKInfo: + if topKInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return topKInfoCmd.Val() + } + case CmdTypeTDigestInfo: + if tDigestInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tDigestInfoCmd.Val() + } + case CmdTypeFTSearch: + if ftSearchCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSearchCmd.Val() + } + case CmdTypeFTInfo: + if ftInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftInfoCmd.Val() + } + case CmdTypeFTSpellCheck: + if ftSpellCheckCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSpellCheckCmd.Val() + } + case CmdTypeFTSynDump: + if ftSynDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSynDumpCmd.Val() + } + case CmdTypeAggregate: + if aggregateCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return aggregateCmd.Val() + } + case CmdTypeTSTimestampValue: + if tsTimestampValueCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tsTimestampValueCmd.Val() + } + case CmdTypeTSTimestampValueSlice: + if tsTimestampValueSliceCmd, ok := cmd.(interface{ 
Val() interface{} }); ok { + return tsTimestampValueSliceCmd.Val() + } + default: + // For unknown command types, return nil + return nil + } + } + + // If we can't get the command type, return nil return nil } diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index f065415f60..3c2e072622 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -6,84 +6,6 @@ import ( "sync" ) -type CmdTyper interface { - GetCmdType() CmdType -} - -type CmdType uint8 - -const ( - CmdTypeGeneric CmdType = iota - CmdTypeString - CmdTypeInt - CmdTypeBool - CmdTypeFloat - CmdTypeStringSlice - CmdTypeIntSlice - CmdTypeFloatSlice - CmdTypeBoolSlice - CmdTypeMapStringString - CmdTypeMapStringInt - CmdTypeMapStringInterface - CmdTypeMapStringInterfaceSlice - CmdTypeSlice - CmdTypeStatus - CmdTypeDuration - CmdTypeTime - CmdTypeKeyValueSlice - CmdTypeStringStructMap - CmdTypeXMessageSlice - CmdTypeXStreamSlice - CmdTypeXPending - CmdTypeXPendingExt - CmdTypeXAutoClaim - CmdTypeXAutoClaimJustID - CmdTypeXInfoConsumers - CmdTypeXInfoGroups - CmdTypeXInfoStream - CmdTypeXInfoStreamFull - CmdTypeZSlice - CmdTypeZWithKey - CmdTypeScan - CmdTypeClusterSlots - CmdTypeGeoLocation - CmdTypeGeoSearchLocation - CmdTypeGeoPos - CmdTypeCommandsInfo - CmdTypeSlowLog - CmdTypeMapStringStringSlice - CmdTypeMapMapStringInterface - CmdTypeKeyValues - CmdTypeZSliceWithKey - CmdTypeFunctionList - CmdTypeFunctionStats - CmdTypeLCS - CmdTypeKeyFlags - CmdTypeClusterLinks - CmdTypeClusterShards - CmdTypeRankWithScore - CmdTypeClientInfo - CmdTypeACLLog - CmdTypeInfo - CmdTypeMonitor - CmdTypeJSON - CmdTypeJSONSlice - CmdTypeIntPointerSlice - CmdTypeScanDump - CmdTypeBFInfo - CmdTypeCFInfo - CmdTypeCMSInfo - CmdTypeTopKInfo - CmdTypeTDigestInfo - CmdTypeFTSynDump - CmdTypeAggregate - CmdTypeFTInfo - CmdTypeFTSpellCheck - CmdTypeFTSearch - CmdTypeTSTimestampValue - CmdTypeTSTimestampValueSlice -) - // ResponseAggregator defines the interface for aggregating responses from multiple shards. type ResponseAggregator interface { // Add processes a single shard response. 
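The hunks above relocate the value-extraction fast path: the CmdType enum and ExtractCommandValue move out of internal/routing and into command.go, next to the command types they describe. Each case first narrows the command to an anonymous interface{ Val() T } and only then reads the value, so extraction costs one uint8 comparison plus one interface assertion instead of reflection. A reduced, self-contained sketch of the pattern, with the enum cut down to two value types and StringCmd standing in for the real command structs:

```go
package main

import "fmt"

type CmdType uint8

const (
	CmdTypeGeneric CmdType = iota
	CmdTypeString
	CmdTypeInt
)

// CmdTypeGetter mirrors the tagging interface in the patch: commands
// self-report their result type so callers can dispatch on an enum.
type CmdTypeGetter interface{ GetCmdType() CmdType }

type StringCmd struct{ val string }

func (c *StringCmd) GetCmdType() CmdType { return CmdTypeString }
func (c *StringCmd) Val() string         { return c.val }

// ExtractValue is the reduced form of ExtractCommandValue: switch on
// the enum first, then assert to the matching Val() signature.
func ExtractValue(cmd interface{}) interface{} {
	g, ok := cmd.(CmdTypeGetter)
	if !ok {
		return nil // legacy commands without a type tag yield nil
	}
	switch g.GetCmdType() {
	case CmdTypeString:
		if c, ok := cmd.(interface{ Val() string }); ok {
			return c.Val()
		}
	case CmdTypeInt:
		if c, ok := cmd.(interface{ Val() int64 }); ok {
			return c.Val()
		}
	}
	return nil
}

func main() {
	fmt.Println(ExtractValue(&StringCmd{val: "hello"})) // hello
}
```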
@@ -442,6 +364,9 @@ func (a *AggLogicalOrAggregator) Finish() (interface{}, error) { } func toInt64(val interface{}) (int64, error) { + if val == nil { + return 0, nil + } switch v := val.(type) { case int64: return v, nil @@ -460,6 +385,9 @@ func toInt64(val interface{}) (int64, error) { } func toBool(val interface{}) (bool, error) { + if val == nil { + return false, nil + } switch v := val.(type) { case bool: return v, nil @@ -647,287 +575,3 @@ func NewSpecialAggregator(cmdName string) *SpecialAggregator { } return agg } - -// CmdTypeGetter interface for getting command type without circular imports -type CmdTypeGetter interface { - GetCmdType() CmdType -} - -// ExtractCommandValue extracts the value from a command result using the fast enum-based approach -func ExtractCommandValue(cmd interface{}) interface{} { - // First try to get the command type using the interface - if cmdTypeGetter, ok := cmd.(CmdTypeGetter); ok { - cmdType := cmdTypeGetter.GetCmdType() - - // Use fast type-based extraction - switch cmdType { - case CmdTypeString: - if stringCmd, ok := cmd.(interface{ Val() string }); ok { - return stringCmd.Val() - } - case CmdTypeInt: - if intCmd, ok := cmd.(interface{ Val() int64 }); ok { - return intCmd.Val() - } - case CmdTypeBool: - if boolCmd, ok := cmd.(interface{ Val() bool }); ok { - return boolCmd.Val() - } - case CmdTypeFloat: - if floatCmd, ok := cmd.(interface{ Val() float64 }); ok { - return floatCmd.Val() - } - case CmdTypeDuration: - if durationCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return durationCmd.Val() - } - case CmdTypeTime: - if timeCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return timeCmd.Val() - } - case CmdTypeStringSlice: - if stringSliceCmd, ok := cmd.(interface{ Val() []string }); ok { - return stringSliceCmd.Val() - } - case CmdTypeIntSlice: - if intSliceCmd, ok := cmd.(interface{ Val() []int64 }); ok { - return intSliceCmd.Val() - } - case CmdTypeBoolSlice: - if boolSliceCmd, ok := cmd.(interface{ Val() []bool }); ok { - return boolSliceCmd.Val() - } - case CmdTypeFloatSlice: - if floatSliceCmd, ok := cmd.(interface{ Val() []float64 }); ok { - return floatSliceCmd.Val() - } - case CmdTypeMapStringString: - if mapCmd, ok := cmd.(interface{ Val() map[string]string }); ok { - return mapCmd.Val() - } - case CmdTypeMapStringInt: - if mapCmd, ok := cmd.(interface{ Val() map[string]int64 }); ok { - return mapCmd.Val() - } - case CmdTypeMapStringInterfaceSlice: - if mapCmd, ok := cmd.(interface { - Val() map[string][]interface{} - }); ok { - return mapCmd.Val() - } - case CmdTypeMapStringInterface: - if mapCmd, ok := cmd.(interface{ Val() map[string]interface{} }); ok { - return mapCmd.Val() - } - case CmdTypeMapStringStringSlice: - if mapCmd, ok := cmd.(interface{ Val() map[string][]string }); ok { - return mapCmd.Val() - } - case CmdTypeMapMapStringInterface: - if mapCmd, ok := cmd.(interface { - Val() map[string][]interface{} - }); ok { - return mapCmd.Val() - } - case CmdTypeStringStructMap: - if mapCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return mapCmd.Val() - } - case CmdTypeXMessageSlice: - if xMsgCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xMsgCmd.Val() - } - case CmdTypeXStreamSlice: - if xStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xStreamCmd.Val() - } - case CmdTypeXPending: - if xPendingCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xPendingCmd.Val() - } - case CmdTypeXPendingExt: - if xPendingExtCmd, ok := cmd.(interface{ Val() interface{} 
}); ok { - return xPendingExtCmd.Val() - } - case CmdTypeXAutoClaim: - if xAutoClaimCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xAutoClaimCmd.Val() - } - case CmdTypeXAutoClaimJustID: - if xAutoClaimJustIDCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xAutoClaimJustIDCmd.Val() - } - case CmdTypeXInfoConsumers: - if xInfoConsumersCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoConsumersCmd.Val() - } - case CmdTypeXInfoGroups: - if xInfoGroupsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoGroupsCmd.Val() - } - case CmdTypeXInfoStream: - if xInfoStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoStreamCmd.Val() - } - case CmdTypeXInfoStreamFull: - if xInfoStreamFullCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoStreamFullCmd.Val() - } - case CmdTypeZSlice: - if zSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return zSliceCmd.Val() - } - case CmdTypeZWithKey: - if zWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return zWithKeyCmd.Val() - } - case CmdTypeScan: - if scanCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return scanCmd.Val() - } - case CmdTypeClusterSlots: - if clusterSlotsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clusterSlotsCmd.Val() - } - case CmdTypeGeoSearchLocation: - if geoSearchLocationCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return geoSearchLocationCmd.Val() - } - case CmdTypeGeoPos: - if geoPosCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return geoPosCmd.Val() - } - case CmdTypeCommandsInfo: - if commandsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return commandsInfoCmd.Val() - } - case CmdTypeSlowLog: - if slowLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return slowLogCmd.Val() - } - - case CmdTypeKeyValues: - if keyValuesCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return keyValuesCmd.Val() - } - case CmdTypeZSliceWithKey: - if zSliceWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return zSliceWithKeyCmd.Val() - } - case CmdTypeFunctionList: - if functionListCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return functionListCmd.Val() - } - case CmdTypeFunctionStats: - if functionStatsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return functionStatsCmd.Val() - } - case CmdTypeLCS: - if lcsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return lcsCmd.Val() - } - case CmdTypeKeyFlags: - if keyFlagsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return keyFlagsCmd.Val() - } - case CmdTypeClusterLinks: - if clusterLinksCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clusterLinksCmd.Val() - } - case CmdTypeClusterShards: - if clusterShardsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clusterShardsCmd.Val() - } - case CmdTypeRankWithScore: - if rankWithScoreCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return rankWithScoreCmd.Val() - } - case CmdTypeClientInfo: - if clientInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clientInfoCmd.Val() - } - case CmdTypeACLLog: - if aclLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return aclLogCmd.Val() - } - case CmdTypeInfo: - if infoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return infoCmd.Val() - } - case CmdTypeMonitor: - if monitorCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return monitorCmd.Val() - } - case CmdTypeJSON: - if jsonCmd, ok := 
cmd.(interface{ Val() interface{} }); ok { - return jsonCmd.Val() - } - case CmdTypeJSONSlice: - if jsonSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return jsonSliceCmd.Val() - } - case CmdTypeIntPointerSlice: - if intPointerSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return intPointerSliceCmd.Val() - } - case CmdTypeScanDump: - if scanDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return scanDumpCmd.Val() - } - case CmdTypeBFInfo: - if bfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return bfInfoCmd.Val() - } - case CmdTypeCFInfo: - if cfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return cfInfoCmd.Val() - } - case CmdTypeCMSInfo: - if cmsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return cmsInfoCmd.Val() - } - case CmdTypeTopKInfo: - if topKInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return topKInfoCmd.Val() - } - case CmdTypeTDigestInfo: - if tDigestInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return tDigestInfoCmd.Val() - } - case CmdTypeFTSearch: - if ftSearchCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return ftSearchCmd.Val() - } - case CmdTypeFTInfo: - if ftInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return ftInfoCmd.Val() - } - case CmdTypeFTSpellCheck: - if ftSpellCheckCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return ftSpellCheckCmd.Val() - } - case CmdTypeFTSynDump: - if ftSynDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return ftSynDumpCmd.Val() - } - case CmdTypeAggregate: - if aggregateCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return aggregateCmd.Val() - } - case CmdTypeTSTimestampValue: - if tsTimestampValueCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return tsTimestampValueCmd.Val() - } - case CmdTypeTSTimestampValueSlice: - if tsTimestampValueSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return tsTimestampValueSliceCmd.Val() - } - default: - // For unknown command types, return nil - return nil - } - } - - // If we can't get the command type, return nil - return nil -} diff --git a/internal/routing/aggregator_test.go b/internal/routing/aggregator_test.go deleted file mode 100644 index 4de29396df..0000000000 --- a/internal/routing/aggregator_test.go +++ /dev/null @@ -1,427 +0,0 @@ -package routing - -import ( - "errors" - "testing" -) - -// Mock command types for testing -type MockStringCmd struct { - cmdType CmdType - val string -} - -func (cmd *MockStringCmd) GetCmdType() CmdType { - return cmd.cmdType -} - -func (cmd *MockStringCmd) Val() string { - return cmd.val -} - -type MockIntCmd struct { - cmdType CmdType - val int64 -} - -func (cmd *MockIntCmd) GetCmdType() CmdType { - return cmd.cmdType -} - -func (cmd *MockIntCmd) Val() int64 { - return cmd.val -} - -type MockBoolCmd struct { - cmdType CmdType - val bool -} - -func (cmd *MockBoolCmd) GetCmdType() CmdType { - return cmd.cmdType -} - -func (cmd *MockBoolCmd) Val() bool { - return cmd.val -} - -// Legacy command without GetCmdType for comparison -type LegacyStringCmd struct { - val string -} - -func (cmd *LegacyStringCmd) Val() string { - return cmd.val -} - -func BenchmarkExtractCommandValueOptimized(b *testing.B) { - commands := []interface{}{ - &MockStringCmd{cmdType: CmdTypeString, val: "test-value"}, - &MockIntCmd{cmdType: CmdTypeInt, val: 42}, - &MockBoolCmd{cmdType: CmdTypeBool, val: true}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - for _, cmd := range commands { - 
ExtractCommandValue(cmd) - } - } -} - -func BenchmarkExtractCommandValueLegacy(b *testing.B) { - commands := []interface{}{ - &LegacyStringCmd{val: "test-value"}, - &MockIntCmd{cmdType: CmdTypeInt, val: 42}, - &MockBoolCmd{cmdType: CmdTypeBool, val: true}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - for _, cmd := range commands { - ExtractCommandValue(cmd) - } - } -} - -func TestExtractCommandValue(t *testing.T) { - tests := []struct { - name string - cmd interface{} - expected interface{} - }{ - { - name: "string command", - cmd: &MockStringCmd{cmdType: CmdTypeString, val: "hello"}, - expected: "hello", - }, - { - name: "int command", - cmd: &MockIntCmd{cmdType: CmdTypeInt, val: 123}, - expected: int64(123), - }, - { - name: "bool command", - cmd: &MockBoolCmd{cmdType: CmdTypeBool, val: true}, - expected: true, - }, - { - name: "unsupported command", - cmd: &LegacyStringCmd{val: "test"}, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := ExtractCommandValue(tt.cmd) - if result != tt.expected { - t.Errorf("ExtractCommandValue() = %v, want %v", result, tt.expected) - } - }) - } -} - -func TestExtractCommandValueIntegration(t *testing.T) { - tests := []struct { - name string - cmd interface{} - expected interface{} - }{ - { - name: "optimized string command", - cmd: &MockStringCmd{cmdType: CmdTypeString, val: "hello"}, - expected: "hello", - }, - { - name: "legacy string command returns nil (no GetCmdType)", - cmd: &LegacyStringCmd{val: "legacy"}, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := ExtractCommandValue(tt.cmd) - if result != tt.expected { - t.Errorf("ExtractCommandValue() = %v, want %v", result, tt.expected) - } - }) - } -} - -func TestAllSucceededAggregator(t *testing.T) { - agg := &AllSucceededAggregator{} - - err := agg.Add("result1", nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add("result2", nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - if result != "result1" { - t.Errorf("Expected 'result1', got %v", result) - } - - agg = &AllSucceededAggregator{} - testErr := errors.New("test error") - err = agg.Add("result1", nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add("result2", testErr) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err = agg.Finish() - if err != testErr { - t.Errorf("Expected test error, got %v", err) - } -} - -func TestOneSucceededAggregator(t *testing.T) { - agg := &OneSucceededAggregator{} - - testErr := errors.New("test error") - err := agg.Add("result1", testErr) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add("result2", nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - if result != "result2" { - t.Errorf("Expected 'result2', got %v", result) - } - - agg = &OneSucceededAggregator{} - err = agg.Add("result1", testErr) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add("result2", testErr) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err = agg.Finish() - if err != testErr { - t.Errorf("Expected test error, got %v", err) - } -} - -func TestAggSumAggregator(t *testing.T) { - agg := &AggSumAggregator{} - - err := agg.Add(int64(10), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } 
- err = agg.Add(int64(20), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(int64(30), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - if result != int64(60) { - t.Errorf("Expected 60, got %v", result) - } - - agg = &AggSumAggregator{} - testErr := errors.New("test error") - err = agg.Add(int64(10), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(int64(20), testErr) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err = agg.Finish() - if err != testErr { - t.Errorf("Expected test error, got %v", err) - } -} - -func TestAggMinAggregator(t *testing.T) { - agg := &AggMinAggregator{} - - err := agg.Add(int64(30), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(int64(10), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(int64(20), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - if result != int64(10) { - t.Errorf("Expected 10, got %v", result) - } -} - -func TestAggMaxAggregator(t *testing.T) { - agg := &AggMaxAggregator{} - - err := agg.Add(int64(10), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(int64(30), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(int64(20), nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - if result != int64(30) { - t.Errorf("Expected 30, got %v", result) - } -} - -func TestAggLogicalAndAggregator(t *testing.T) { - agg := &AggLogicalAndAggregator{} - - err := agg.Add(true, nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(true, nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(false, nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - if result != false { - t.Errorf("Expected false, got %v", result) - } -} - -func TestAggLogicalOrAggregator(t *testing.T) { - agg := &AggLogicalOrAggregator{} - - err := agg.Add(false, nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(true, nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add(false, nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - if result != true { - t.Errorf("Expected true, got %v", result) - } -} - -func TestDefaultKeylessAggregator(t *testing.T) { - agg := &DefaultKeylessAggregator{} - - err := agg.Add("result1", nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add("result2", nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - err = agg.Add("result3", nil) - if err != nil { - t.Errorf("Add failed: %v", err) - } - - result, err := agg.Finish() - if err != nil { - t.Errorf("Finish failed: %v", err) - } - - results, ok := result.([]interface{}) - if !ok { - t.Errorf("Expected []interface{}, got %T", result) - } - if len(results) != 3 { - t.Errorf("Expected 3 results, got %d", len(results)) - } - if results[0] != "result1" || results[1] != "result2" || results[2] != "result3" { - t.Errorf("Unexpected results: %v", results) - } -} - -func TestNewResponseAggregator(t *testing.T) { - 
tests := []struct { - policy ResponsePolicy - cmdName string - expected string - }{ - {RespAllSucceeded, "test", "*routing.AllSucceededAggregator"}, - {RespOneSucceeded, "test", "*routing.OneSucceededAggregator"}, - {RespAggSum, "test", "*routing.AggSumAggregator"}, - {RespAggMin, "test", "*routing.AggMinAggregator"}, - {RespAggMax, "test", "*routing.AggMaxAggregator"}, - {RespAggLogicalAnd, "test", "*routing.AggLogicalAndAggregator"}, - {RespAggLogicalOr, "test", "*routing.AggLogicalOrAggregator"}, - {RespSpecial, "test", "*routing.SpecialAggregator"}, - } - - for _, test := range tests { - agg := NewResponseAggregator(test.policy, test.cmdName) - if agg == nil { - t.Errorf("NewResponseAggregator returned nil for policy %v", test.policy) - } - _, ok := agg.(ResponseAggregator) - if !ok { - t.Errorf("Aggregator does not implement ResponseAggregator interface") - } - } -} diff --git a/osscluster.go b/osscluster.go index 61847ed799..72fca6fdb3 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1421,7 +1421,9 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd for _, cmd := range cmds { policy := c.getCommandPolicy(ctx, cmd) if policy != nil && !policy.CanBeUsedInPipeline() { - return fmt.Errorf("redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard", cmd.Name()) + return fmt.Errorf( + "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), + ) } slot := c.cmdSlot(ctx, cmd) node, err := c.slotReadOnlyNode(state, slot) @@ -1436,7 +1438,9 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd for _, cmd := range cmds { policy := c.getCommandPolicy(ctx, cmd) if policy != nil && !policy.CanBeUsedInPipeline() { - return fmt.Errorf("redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard", cmd.Name()) + return fmt.Errorf( + "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), + ) } slot := c.cmdSlot(ctx, cmd) node, err := state.slotMasterNode(slot) diff --git a/osscluster_router.go b/osscluster_router.go index a1fe669736..99bc598a35 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -166,8 +166,85 @@ func (c *ClusterClient) createSlotSpecificCommand(ctx context.Context, originalC newArgs = append(newArgs, key) } - // Create new command with the filtered keys - return NewCmd(ctx, newArgs...) + // Create a new command of the same type using the helper function + return createCommandByType(ctx, originalCmd.GetCmdType(), newArgs...) +} + +// createCommandByType creates a new command of the specified type with the given arguments +func createCommandByType(ctx context.Context, cmdType CmdType, args ...interface{}) Cmder { + switch cmdType { + case CmdTypeString: + return NewStringCmd(ctx, args...) + case CmdTypeInt: + return NewIntCmd(ctx, args...) + case CmdTypeBool: + return NewBoolCmd(ctx, args...) + case CmdTypeFloat: + return NewFloatCmd(ctx, args...) + case CmdTypeStringSlice: + return NewStringSliceCmd(ctx, args...) + case CmdTypeIntSlice: + return NewIntSliceCmd(ctx, args...) + case CmdTypeFloatSlice: + return NewFloatSliceCmd(ctx, args...) + case CmdTypeBoolSlice: + return NewBoolSliceCmd(ctx, args...) + case CmdTypeStatus: + return NewStatusCmd(ctx, args...) + case CmdTypeTime: + return NewTimeCmd(ctx, args...) 
+ case CmdTypeMapStringString: + return NewMapStringStringCmd(ctx, args...) + case CmdTypeMapStringInt: + return NewMapStringIntCmd(ctx, args...) + case CmdTypeMapStringInterface: + return NewMapStringInterfaceCmd(ctx, args...) + case CmdTypeMapStringInterfaceSlice: + return NewMapStringInterfaceSliceCmd(ctx, args...) + case CmdTypeSlice: + return NewSliceCmd(ctx, args...) + case CmdTypeStringStructMap: + return NewStringStructMapCmd(ctx, args...) + case CmdTypeXMessageSlice: + return NewXMessageSliceCmd(ctx, args...) + case CmdTypeXStreamSlice: + return NewXStreamSliceCmd(ctx, args...) + case CmdTypeXPending: + return NewXPendingCmd(ctx, args...) + case CmdTypeXPendingExt: + return NewXPendingExtCmd(ctx, args...) + case CmdTypeXAutoClaim: + return NewXAutoClaimCmd(ctx, args...) + case CmdTypeXAutoClaimJustID: + return NewXAutoClaimJustIDCmd(ctx, args...) + case CmdTypeXInfoStreamFull: + return NewXInfoStreamFullCmd(ctx, args...) + case CmdTypeZSlice: + return NewZSliceCmd(ctx, args...) + case CmdTypeZWithKey: + return NewZWithKeyCmd(ctx, args...) + case CmdTypeClusterSlots: + return NewClusterSlotsCmd(ctx, args...) + case CmdTypeGeoPos: + return NewGeoPosCmd(ctx, args...) + case CmdTypeCommandsInfo: + return NewCommandsInfoCmd(ctx, args...) + case CmdTypeSlowLog: + return NewSlowLogCmd(ctx, args...) + case CmdTypeKeyValues: + return NewKeyValuesCmd(ctx, args...) + case CmdTypeZSliceWithKey: + return NewZSliceWithKeyCmd(ctx, args...) + case CmdTypeFunctionList: + return NewFunctionListCmd(ctx, args...) + case CmdTypeFunctionStats: + return NewFunctionStatsCmd(ctx, args...) + case CmdTypeKeyFlags: + return NewKeyFlagsCmd(ctx, args...) + case CmdTypeDuration: + return NewDurationCmd(ctx, time.Second, args...) + } + return NewCmd(ctx, args...) } // executeSpecialCommand handles commands with special routing requirements @@ -283,7 +360,7 @@ func (c *ClusterClient) aggregateKeyedResponses(ctx context.Context, cmd Cmder, // Add results with keys for key, shardCmd := range keyedResults { - value := routing.ExtractCommandValue(shardCmd) + value := ExtractCommandValue(shardCmd) if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok { if err := keyedAgg.AddWithKey(key, value, shardCmd.Err()); err != nil { return err @@ -310,7 +387,7 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout cmd.SetErr(err) return err } - value := routing.ExtractCommandValue(shardCmd) + value := ExtractCommandValue(shardCmd) return c.setCommandValue(cmd, value) } @@ -318,7 +395,7 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout // Add all results to aggregator for _, shardCmd := range cmds { - value := routing.ExtractCommandValue(shardCmd) + value := ExtractCommandValue(shardCmd) if err := aggregator.Add(value, shardCmd.Err()); err != nil { return err } diff --git a/osscluster_test.go b/osscluster_test.go index 6860388bee..3ce21d1e52 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -1278,20 +1278,28 @@ var _ = FDescribe("ClusterClient", func() { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - mu.Lock() - stack = append(stack, "cluster.BeforeProcessPipeline") - mu.Unlock() + cmdStr := cmds[0].String() - err := hook(ctx, cmds) + // Handle SET command (should succeed) + if cmdStr == "set pipeline_test_key pipeline_test_value: " { + mu.Lock() + stack = append(stack, 
"cluster.BeforeProcessPipeline") + mu.Unlock() - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - mu.Lock() - stack = append(stack, "cluster.AfterProcessPipeline") - mu.Unlock() + err := hook(ctx, cmds) - return err + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("set pipeline_test_key pipeline_test_value: OK")) + mu.Lock() + stack = append(stack, "cluster.AfterProcessPipeline") + mu.Unlock() + + return err + } + + // For other commands (like ping), just pass through without expectations + // since they might fail before reaching this point + return hook(ctx, cmds) } }, }) @@ -1301,20 +1309,27 @@ var _ = FDescribe("ClusterClient", func() { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - mu.Lock() - stack = append(stack, "shard.BeforeProcessPipeline") - mu.Unlock() + cmdStr := cmds[0].String() - err := hook(ctx, cmds) + // Handle SET command (should succeed) + if cmdStr == "set pipeline_test_key pipeline_test_value: " { + mu.Lock() + stack = append(stack, "shard.BeforeProcessPipeline") + mu.Unlock() - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - mu.Lock() - stack = append(stack, "shard.AfterProcessPipeline") - mu.Unlock() + err := hook(ctx, cmds) - return err + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("set pipeline_test_key pipeline_test_value: OK")) + mu.Lock() + stack = append(stack, "shard.AfterProcessPipeline") + mu.Unlock() + + return err + } + + // For other commands (like ping), just pass through without expectations + return hook(ctx, cmds) } }, }) @@ -1322,7 +1337,7 @@ var _ = FDescribe("ClusterClient", func() { }) _, err = client.Pipelined(ctx, func(pipe redis.Pipeliner) error { - pipe.Ping(ctx) + pipe.Set(ctx, "pipeline_test_key", "pipeline_test_value", 0) return nil }) Expect(err).NotTo(HaveOccurred()) @@ -1340,6 +1355,16 @@ var _ = FDescribe("ClusterClient", func() { })) }) + It("rejects ping command in pipeline", func() { + // Test that ping command fails in pipeline as expected + _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Ping(ctx) + return nil + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("redis: cannot pipeline command \"ping\" with request policy ReqAllNodes/ReqAllShards/ReqMultiShard")) + }) + It("supports TxPipeline hook", func() { err := client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) From 7ce5f7842cedffee5c285270e0a464ea0be1e99d Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Fri, 4 Jul 2025 18:11:44 +0300 Subject: [PATCH 13/62] remove FDescribe from cluster tests --- go.mod | 1 - go.sum | 2 -- osscluster_test.go | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 9da8e58a64..3bbb8ac4d8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/bsm/gomega v1.27.10 github.com/cespare/xxhash/v2 v2.3.0 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f - github.com/fortytw2/leaktest v1.3.0 ) retract ( diff --git a/go.sum b/go.sum index a60f6d5880..4db68f6d4f 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,3 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f 
h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= -github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= diff --git a/osscluster_test.go b/osscluster_test.go index 3ce21d1e52..6475d9859e 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -258,7 +258,7 @@ func slotEqual(s1, s2 redis.ClusterSlot) bool { // ------------------------------------------------------------------------------ -var _ = FDescribe("ClusterClient", func() { +var _ = Describe("ClusterClient", func() { var failover bool var opt *redis.ClusterOptions var client *redis.ClusterClient From 4780dd89a0c03191c1256b48b16a0f3d69aeabe4 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 6 Jul 2025 13:21:22 +0300 Subject: [PATCH 14/62] Add tests --- internal/routing/aggregator.go | 12 + osscluster_router.go | 17 +- osscluster_test.go | 1599 +++++++++++++++++++++++++------- 3 files changed, 1298 insertions(+), 330 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 3c2e072622..962e592647 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -218,6 +218,9 @@ func (a *AggMinAggregator) Finish() (interface{}, error) { if a.firstErr != nil { return nil, a.firstErr } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for min operation") + } return a.min, nil } @@ -264,6 +267,9 @@ func (a *AggMaxAggregator) Finish() (interface{}, error) { if a.firstErr != nil { return nil, a.firstErr } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for max operation") + } return a.max, nil } @@ -312,6 +318,9 @@ func (a *AggLogicalAndAggregator) Finish() (interface{}, error) { if a.firstErr != nil { return nil, a.firstErr } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for logical AND operation") + } return a.result, nil } @@ -360,6 +369,9 @@ func (a *AggLogicalOrAggregator) Finish() (interface{}, error) { if a.firstErr != nil { return nil, a.firstErr } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for logical OR operation") + } return a.result, nil } diff --git a/osscluster_router.go b/osscluster_router.go index 99bc598a35..7b765964b8 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -312,12 +312,23 @@ func (c *ClusterClient) executeParallel(ctx context.Context, cmd Cmder, nodes [] close(results) }() - // Collect results + // Collect results and check for errors cmds := make([]Cmder, 0, len(nodes)) + var firstErr error + for result := range results { + if result.err != nil && firstErr == nil { + firstErr = result.err + } cmds = append(cmds, result.cmd) } + // If there was an error and no policy specified, fail fast + if firstErr != nil && (policy == nil || policy.Response == routing.RespDefaultKeyless) { + cmd.SetErr(firstErr) + return firstErr + } + return c.aggregateResponses(cmd, cmds, policy) } @@ -342,11 +353,11 @@ func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder return firstErr } - return c.aggregateKeyedResponses(ctx, cmd, keyedResults, keyOrder, policy) + return c.aggregateKeyedResponses(cmd, keyedResults, keyOrder, policy) } // aggregateKeyedResponses aggregates responses while preserving key order -func (c *ClusterClient) 
aggregateKeyedResponses(ctx context.Context, cmd Cmder, keyedResults map[string]Cmder, keyOrder []string, policy *routing.CommandPolicy) error { +func (c *ClusterClient) aggregateKeyedResponses(cmd Cmder, keyedResults map[string]Cmder, keyOrder []string, policy *routing.CommandPolicy) error { if len(keyedResults) == 0 { return fmt.Errorf("redis: no results to aggregate") } diff --git a/osscluster_test.go b/osscluster_test.go index 6475d9859e..de91c4d602 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -10,13 +10,13 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/internal/hashtag" + "github.com/redis/go-redis/v9/internal/routing" ) type clusterScenario struct { @@ -286,7 +286,7 @@ var _ = Describe("ClusterClient", func() { Expect(cnt).To(Equal(int64(1))) }) - It("GET follows redirects", func() { + It("should follow redirects for GET", func() { err := client.Set(ctx, "A", "VALUE", 0).Err() Expect(err).NotTo(HaveOccurred()) @@ -309,7 +309,7 @@ var _ = Describe("ClusterClient", func() { Expect(v).To(Equal("VALUE")) }) - It("SET follows redirects", func() { + It("should follow redirects for SET", func() { if !failover { Eventually(func() error { return client.SwapNodes(ctx, "A") @@ -324,7 +324,7 @@ var _ = Describe("ClusterClient", func() { Expect(v).To(Equal("VALUE")) }) - It("distributes keys", func() { + It("should distribute keys", func() { for i := 0; i < 100; i++ { err := client.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err() Expect(err).NotTo(HaveOccurred()) @@ -345,7 +345,7 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) }) - It("distributes keys when using EVAL", func() { + It("should distribute keys when using EVAL", func() { script := redis.NewScript(` local r = redis.call('SET', KEYS[1], ARGV[1]) return r @@ -373,7 +373,7 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) }) - It("distributes scripts when using Script Load", func() { + It("should distribute scripts when using Script Load", func() { client.ScriptFlush(ctx) script := redis.NewScript(`return 'Unique script'`) @@ -390,7 +390,7 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) }) - It("checks all shards when using Script Exists", func() { + It("should check all shards when using Script Exists", func() { client.ScriptFlush(ctx) script := redis.NewScript(`return 'First script'`) @@ -405,7 +405,7 @@ var _ = Describe("ClusterClient", func() { Expect(val).To(Equal([]bool{true, false})) }) - It("flushes scripts from all shards when using ScriptFlush", func() { + It("should flush scripts from all shards when using ScriptFlush", func() { script := redis.NewScript(`return 'Unnecessary script'`) script.Load(ctx, client) @@ -418,7 +418,7 @@ var _ = Describe("ClusterClient", func() { Expect(val).To(Equal([]bool{false})) }) - It("supports Watch", func() { + It("should support Watch", func() { var incr func(string) error // Transactionally increments key using GET and SET commands. 
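Two behavioral changes ride along with the test renames here: executeParallel now records the first shard error while draining the results channel and fails fast when the command has no aggregating response policy, and the Min/Max and logical AND/OR aggregators refuse to Finish() over zero valid results instead of silently returning a zero value. The signature change to aggregateKeyedResponses simply drops a ctx parameter the function never used. A sketch of the fan-out-and-collect shape, with nodes reduced to plain closures (fanOut and result are illustrative names, and unlike the real router this sketch treats every error as fatal):

```go
package main

import (
	"fmt"
	"sync"
)

// result mirrors the per-node (cmd, err) pair collected over the channel.
type result struct {
	val interface{}
	err error
}

// fanOut runs every shard closure concurrently, drains all results,
// and returns the first error seen; the real executeParallel only
// fails fast like this when no response policy would aggregate the
// partial results.
func fanOut(shards []func() (interface{}, error)) ([]interface{}, error) {
	results := make(chan result, len(shards))
	var wg sync.WaitGroup
	for _, run := range shards {
		wg.Add(1)
		go func(run func() (interface{}, error)) {
			defer wg.Done()
			v, err := run()
			results <- result{val: v, err: err}
		}(run)
	}
	go func() {
		wg.Wait()
		close(results)
	}()

	vals := make([]interface{}, 0, len(shards))
	var firstErr error
	for r := range results {
		if r.err != nil && firstErr == nil {
			firstErr = r.err
		}
		vals = append(vals, r.val)
	}
	if firstErr != nil {
		return nil, firstErr
	}
	return vals, nil
}

func main() {
	vals, err := fanOut([]func() (interface{}, error){
		func() (interface{}, error) { return int64(1), nil },
		func() (interface{}, error) { return int64(2), nil },
	})
	fmt.Println(vals, err) // [1 2] <nil>
}
```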
@@ -464,7 +464,7 @@ var _ = Describe("ClusterClient", func() { assertPipeline := func(keys []string) { - It("follows redirects", func() { + It("should follow redirects", func() { if !failover { for _, key := range keys { Eventually(func() error { @@ -514,9 +514,9 @@ var _ = Describe("ClusterClient", func() { } }) - It("works with missing keys", func() { - pipe.Set(ctx, "A{s}", "A_value", 0) - pipe.Set(ctx, "C{s}", "C_value", 0) + It("should work with missing keys", func() { + pipe.Set(ctx, "A", "A_value", 0) + pipe.Set(ctx, "C", "C_value", 0) _, err := pipe.Exec(ctx) Expect(err).NotTo(HaveOccurred()) @@ -548,7 +548,7 @@ var _ = Describe("ClusterClient", func() { keys := []string{"A", "B", "C", "D", "E", "F", "G"} assertPipeline(keys) - It("doesn't fail node with context.Canceled error", func() { + It("should not fail node with context.Canceled error", func() { ctx, cancel := context.WithCancel(context.Background()) cancel() pipe.Set(ctx, "A", "A_value", 0) @@ -564,7 +564,7 @@ var _ = Describe("ClusterClient", func() { } }) - It("doesn't fail node with context.DeadlineExceeded error", func() { + It("should not fail node with context.DeadlineExceeded error", func() { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() @@ -620,91 +620,10 @@ var _ = Describe("ClusterClient", func() { }) }) - It("supports PubSub", func() { - pubsub := client.Subscribe(ctx, "mychannel") - defer pubsub.Close() - - Eventually(func() error { - _, err := client.Publish(ctx, "mychannel", "hello").Result() - if err != nil { - return err - } - - msg, err := pubsub.ReceiveTimeout(ctx, time.Second) - if err != nil { - return err - } - - _, ok := msg.(*redis.Message) - if !ok { - return fmt.Errorf("got %T, wanted *redis.Message", msg) - } - - return nil - }, 30*time.Second).ShouldNot(HaveOccurred()) - }) - - It("supports PubSub with ReadOnly option", func() { - opt = redisClusterOptions() - opt.ReadOnly = true - client = cluster.newClusterClient(ctx, opt) - + It("should support PubSub", func() { pubsub := client.Subscribe(ctx, "mychannel") defer pubsub.Close() - Eventually(func() error { - var masterPubsubChannels atomic.Int64 - var slavePubsubChannels atomic.Int64 - - err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { - info := master.InfoMap(ctx, "stats") - if info.Err() != nil { - return info.Err() - } - - pc, err := strconv.Atoi(info.Item("Stats", "pubsub_channels")) - if err != nil { - return err - } - - masterPubsubChannels.Add(int64(pc)) - - return nil - }) - if err != nil { - return err - } - - err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error { - info := slave.InfoMap(ctx, "stats") - if info.Err() != nil { - return info.Err() - } - - pc, err := strconv.Atoi(info.Item("Stats", "pubsub_channels")) - if err != nil { - return err - } - - slavePubsubChannels.Add(int64(pc)) - - return nil - }) - if err != nil { - return err - } - - if c := masterPubsubChannels.Load(); c != int64(0) { - return fmt.Errorf("total master pubsub_channels is %d; expected 0", c) - } - - if c := slavePubsubChannels.Load(); c != int64(1) { - return fmt.Errorf("total slave pubsub_channels is %d; expected 1", c) - } - - return nil - }, 30*time.Second).ShouldNot(HaveOccurred()) - Eventually(func() error { _, err := client.Publish(ctx, "mychannel", "hello").Result() if err != nil { @@ -725,91 +644,10 @@ var _ = Describe("ClusterClient", func() { }, 30*time.Second).ShouldNot(HaveOccurred()) }) - It("supports sharded PubSub", func() { - 
pubsub := client.SSubscribe(ctx, "mychannel") - defer pubsub.Close() - - Eventually(func() error { - _, err := client.SPublish(ctx, "mychannel", "hello").Result() - if err != nil { - return err - } - - msg, err := pubsub.ReceiveTimeout(ctx, time.Second) - if err != nil { - return err - } - - _, ok := msg.(*redis.Message) - if !ok { - return fmt.Errorf("got %T, wanted *redis.Message", msg) - } - - return nil - }, 30*time.Second).ShouldNot(HaveOccurred()) - }) - - It("supports sharded PubSub with ReadOnly option", func() { - opt = redisClusterOptions() - opt.ReadOnly = true - client = cluster.newClusterClient(ctx, opt) - + It("should support sharded PubSub", func() { pubsub := client.SSubscribe(ctx, "mychannel") defer pubsub.Close() - Eventually(func() error { - var masterPubsubShardChannels atomic.Int64 - var slavePubsubShardChannels atomic.Int64 - - err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { - info := master.InfoMap(ctx, "stats") - if info.Err() != nil { - return info.Err() - } - - pc, err := strconv.Atoi(info.Item("Stats", "pubsubshard_channels")) - if err != nil { - return err - } - - masterPubsubShardChannels.Add(int64(pc)) - - return nil - }) - if err != nil { - return err - } - - err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error { - info := slave.InfoMap(ctx, "stats") - if info.Err() != nil { - return info.Err() - } - - pc, err := strconv.Atoi(info.Item("Stats", "pubsubshard_channels")) - if err != nil { - return err - } - - slavePubsubShardChannels.Add(int64(pc)) - - return nil - }) - if err != nil { - return err - } - - if c := masterPubsubShardChannels.Load(); c != int64(0) { - return fmt.Errorf("total master pubsubshard_channels is %d; expected 0", c) - } - - if c := slavePubsubShardChannels.Load(); c != int64(1) { - return fmt.Errorf("total slave pubsubshard_channels is %d; expected 1", c) - } - - return nil - }, 30*time.Second).ShouldNot(HaveOccurred()) - Eventually(func() error { _, err := client.SPublish(ctx, "mychannel", "hello").Result() if err != nil { @@ -830,7 +668,7 @@ var _ = Describe("ClusterClient", func() { }, 30*time.Second).ShouldNot(HaveOccurred()) }) - It("supports PubSub.Ping without channels", func() { + It("should support PubSub.Ping without channels", func() { pubsub := client.Subscribe(ctx) defer pubsub.Close() @@ -887,12 +725,12 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("returns pool stats", func() { + It("should return pool stats", func() { stats := client.PoolStats() Expect(stats).To(BeAssignableToTypeOf(&redis.PoolStats{})) }) - It("returns an error when there are no attempts left", func() { + It("should return an error when there are no attempts left", func() { opt := redisClusterOptions() opt.MaxRedirects = -1 client := cluster.newClusterClient(ctx, opt) @@ -908,7 +746,7 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("determines hash slots correctly for generic commands", func() { + It("should determine hash slots correctly for generic commands", func() { opt := redisClusterOptions() opt.MaxRedirects = -1 client := cluster.newClusterClient(ctx, opt) @@ -934,7 +772,7 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("follows node redirection immediately", func() { + It("should follow node redirection immediately", func() { // Configure retry backoffs far in excess of the expected duration of redirection opt := 
redisClusterOptions() opt.MinRetryBackoff = 10 * time.Minute @@ -960,7 +798,7 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("calls fn for every master node", func() { + It("should call fn for every master node", func() { for i := 0; i < 10; i++ { Expect(client.Set(ctx, strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred()) } @@ -1173,7 +1011,7 @@ var _ = Describe("ClusterClient", func() { Expect(len(keys)).To(BeNumerically("~", nkeys, nkeys/10)) }) - It("supports Process hook", func() { + It("should support Process hook", func() { testCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -1262,7 +1100,7 @@ var _ = Describe("ClusterClient", func() { })) }) - It("supports Pipeline hook", func() { + It("should support Pipeline hook", func() { err := client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) @@ -1355,7 +1193,7 @@ var _ = Describe("ClusterClient", func() { })) }) - It("rejects ping command in pipeline", func() { + It("should reject ping command in pipeline", func() { // Test that ping command fails in pipeline as expected _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) @@ -1365,7 +1203,7 @@ var _ = Describe("ClusterClient", func() { Expect(err.Error()).To(ContainSubstring("redis: cannot pipeline command \"ping\" with request policy ReqAllNodes/ReqAllShards/ReqMultiShard")) }) - It("supports TxPipeline hook", func() { + It("should support TxPipeline hook", func() { err := client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) @@ -1672,12 +1510,12 @@ var _ = Describe("ClusterClient without nodes", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("Ping returns an error", func() { + It("should return an error for Ping", func() { err := client.Ping(ctx).Err() Expect(err).To(MatchError("redis: cluster has no nodes")) }) - It("pipeline returns an error", func() { + It("should return an error for pipeline", func() { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) return nil @@ -1699,12 +1537,12 @@ var _ = Describe("ClusterClient without valid nodes", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("returns an error", func() { + It("should return an error when cluster support is disabled", func() { err := client.Ping(ctx).Err() Expect(err).To(MatchError("ERR This instance has cluster support disabled")) }) - It("pipeline returns an error", func() { + It("should return an error for pipeline when cluster support is disabled", func() { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) return nil @@ -1734,7 +1572,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("recovers when Cluster recovers", func() { + It("should recover when Cluster recovers", func() { err := client.Ping(ctx).Err() Expect(err).To(HaveOccurred()) @@ -1752,13 +1590,13 @@ var _ = Describe("ClusterClient timeout", func() { }) testTimeout := func() { - It("Ping timeouts", func() { + It("should timeout Ping", func() { err := client.Ping(ctx).Err() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) }) - It("Pipeline timeouts", func() { + It("should timeout Pipeline", func() { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) return nil @@ -1767,7 +1605,7 @@ var _ = Describe("ClusterClient timeout", func() { Expect(err.(net.Error).Timeout()).To(BeTrue()) }) - It("Tx timeouts", func() { + It("should timeout Tx", 
func() { err := client.Watch(ctx, func(tx *redis.Tx) error { return tx.Ping(ctx).Err() }, "foo") @@ -1775,7 +1613,7 @@ var _ = Describe("ClusterClient timeout", func() { Expect(err.(net.Error).Timeout()).To(BeTrue()) }) - It("Tx Pipeline timeouts", func() { + It("should timeout Tx Pipeline", func() { err := client.Watch(ctx, func(tx *redis.Tx) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) @@ -1834,140 +1672,1247 @@ var _ = Describe("ClusterClient timeout", func() { }) }) -var _ = Describe("ClusterClient ParseURL", func() { - cases := []struct { - test string - url string - o *redis.ClusterOptions // expected value - err error - }{ - { - test: "ParseRedisURL", - url: "redis://localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}}, - }, { - test: "ParseRedissURL", - url: "rediss://localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "MissingRedisPort", - url: "redis://localhost", - o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}}, - }, { - test: "MissingRedissPort", - url: "rediss://localhost", - o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "MultipleRedisURLs", - url: "redis://localhost:123?addr=localhost:1234&addr=localhost:12345", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}}, - }, { - test: "MultipleRedissURLs", - url: "rediss://localhost:123?addr=localhost:1234&addr=localhost:12345", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "OnlyPassword", - url: "redis://:bar@localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Password: "bar"}, - }, { - test: "OnlyUser", - url: "redis://foo@localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo"}, - }, { - test: "RedisUsernamePassword", - url: "redis://foo:bar@localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo", Password: "bar"}, - }, { - test: "RedissUsernamePassword", - url: "rediss://foo:bar@localhost:123?addr=localhost:1234", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, Username: "foo", Password: "bar", TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "QueryParameters", - url: "redis://localhost:123?read_timeout=2&pool_fifo=true&addr=localhost:1234", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, ReadTimeout: 2 * time.Second, PoolFIFO: true}, - }, { - test: "DisabledTimeout", - url: "redis://localhost:123?conn_max_idle_time=0", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, - }, { - test: "DisabledTimeoutNeg", - url: "redis://localhost:123?conn_max_idle_time=-1", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, - }, { - test: "UseDefault", - url: "redis://localhost:123?conn_max_idle_time=", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, - }, { - test: "FailingTimeoutSeconds", - url: "redis://localhost:123?failing_timeout_seconds=25", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, FailingTimeoutSeconds: 25}, - }, { - test: "Protocol", - url: "redis://localhost:123?protocol=2", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, 
Protocol: 2}, - }, { - test: "ClientName", - url: "redis://localhost:123?client_name=cluster_hi", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ClientName: "cluster_hi"}, - }, { - test: "UseDefaultMissing=", - url: "redis://localhost:123?conn_max_idle_time", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, - }, { - test: "InvalidQueryAddr", - url: "rediss://foo:bar@localhost:123?addr=rediss://foo:barr@localhost:1234", - err: errors.New(`redis: unable to parse addr param: rediss://foo:barr@localhost:1234`), - }, { - test: "InvalidInt", - url: "redis://localhost?pool_size=five", - err: errors.New(`redis: invalid pool_size number: strconv.Atoi: parsing "five": invalid syntax`), - }, { - test: "InvalidBool", - url: "redis://localhost?pool_fifo=yes", - err: errors.New(`redis: invalid pool_fifo boolean: expected true/false/1/0 or an empty string, got "yes"`), - }, { - test: "UnknownParam", - url: "redis://localhost?abc=123", - err: errors.New("redis: unexpected option: abc"), - }, { - test: "InvalidScheme", - url: "https://google.com", - err: errors.New("redis: invalid URL scheme: https"), - }, - } +var _ = Describe("Command Tips tests", func() { + var client *redis.ClusterClient + + BeforeEach(func() { + opt := redisClusterOptions() + client = cluster.newClusterClient(ctx, opt) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should verify COMMAND tips match router policy types", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + expectedPolicies := map[string]struct { + RequestPolicy string + ResponsePolicy string + }{ + "touch": { + RequestPolicy: "multi_shard", + ResponsePolicy: "agg_sum", + }, + "flushall": { + RequestPolicy: "all_shards", + ResponsePolicy: "all_succeeded", + }, + } + + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + for cmdName, expected := range expectedPolicies { + actualCmd := cmds[cmdName] + + Expect(actualCmd.Tips).NotTo(BeNil()) + + // Verify request_policy from COMMAND matches router policy + actualRequestPolicy := actualCmd.Tips.Request.String() + Expect(actualRequestPolicy).To(Equal(expected.RequestPolicy)) + + // Verify response_policy from COMMAND matches router policy + actualResponsePolicy := actualCmd.Tips.Response.String() + Expect(actualResponsePolicy).To(Equal(expected.ResponsePolicy)) + } + }) + + Describe("Explicit Routing Policy Tests", func() { + It("should test explicit routing policy for TOUCH", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify TOUCH command has multi_shard policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + touchCmd := cmds["touch"] + + Expect(touchCmd.Tips).NotTo(BeNil()) + Expect(touchCmd.Tips.Request.String()).To(Equal("multi_shard")) + Expect(touchCmd.Tips.Response.String()).To(Equal("agg_sum")) - It("match ParseClusterURL", func() { - for i := range cases { - tc := cases[i] - actual, err := redis.ParseClusterURL(tc.url) - if tc.err != nil { - Expect(err).Should(MatchError(tc.err)) - } else { + keys := []string{"key1", "key2", "key3", "key4", "key5"} + for _, key := range keys { + err := client.Set(ctx, key, "value", 0).Err() Expect(err).NotTo(HaveOccurred()) } - if err == nil { - Expect(tc.o).NotTo(BeNil()) - - Expect(tc.o.Addrs).To(Equal(actual.Addrs)) - Expect(tc.o.TLSConfig).To(Equal(actual.TLSConfig)) - Expect(tc.o.Username).To(Equal(actual.Username)) - 
Expect(tc.o.Password).To(Equal(actual.Password)) - Expect(tc.o.MaxRetries).To(Equal(actual.MaxRetries)) - Expect(tc.o.MinRetryBackoff).To(Equal(actual.MinRetryBackoff)) - Expect(tc.o.MaxRetryBackoff).To(Equal(actual.MaxRetryBackoff)) - Expect(tc.o.DialTimeout).To(Equal(actual.DialTimeout)) - Expect(tc.o.ReadTimeout).To(Equal(actual.ReadTimeout)) - Expect(tc.o.WriteTimeout).To(Equal(actual.WriteTimeout)) - Expect(tc.o.PoolFIFO).To(Equal(actual.PoolFIFO)) - Expect(tc.o.PoolSize).To(Equal(actual.PoolSize)) - Expect(tc.o.MinIdleConns).To(Equal(actual.MinIdleConns)) - Expect(tc.o.ConnMaxLifetime).To(Equal(actual.ConnMaxLifetime)) - Expect(tc.o.ConnMaxIdleTime).To(Equal(actual.ConnMaxIdleTime)) - Expect(tc.o.PoolTimeout).To(Equal(actual.PoolTimeout)) - Expect(tc.o.FailingTimeoutSeconds).To(Equal(actual.FailingTimeoutSeconds)) + result := client.Touch(ctx, keys...) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal(int64(len(keys)))) + }) + + It("should test explicit routing policy for FLUSHALL", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify FLUSHALL command has all_shards policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + flushallCmd := cmds["flushall"] + + Expect(flushallCmd.Tips).NotTo(BeNil()) + Expect(flushallCmd.Tips.Request.String()).To(Equal("all_shards")) + Expect(flushallCmd.Tips.Response.String()).To(Equal("all_succeeded")) + + testKeys := []string{"test1", "test2", "test3"} + for _, key := range testKeys { + err := client.Set(ctx, key, "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) } - } + + err = client.FlushAll(ctx).Err() + Expect(err).NotTo(HaveOccurred()) + + for _, key := range testKeys { + exists := client.Exists(ctx, key) + Expect(exists.Val()).To(Equal(int64(0))) + } + }) + + It("should test explicit routing policy for PING", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify PING command has all_shards policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + pingCmd := cmds["ping"] + Expect(pingCmd.Tips).NotTo(BeNil()) + Expect(pingCmd.Tips.Request.String()).To(Equal("all_shards")) + Expect(pingCmd.Tips.Response.String()).To(Equal("all_succeeded")) + + result := client.Ping(ctx) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("PONG")) + }) + + It("should test explicit routing policy for DBSIZE", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify DBSIZE command has all_shards policy with agg_sum response + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + dbsizeCmd := cmds["dbsize"] + Expect(dbsizeCmd.Tips).NotTo(BeNil()) + Expect(dbsizeCmd.Tips.Request.String()).To(Equal("all_shards")) + Expect(dbsizeCmd.Tips.Response.String()).To(Equal("agg_sum")) + + testKeys := []string{"dbsize_test1", "dbsize_test2", "dbsize_test3"} + for _, key := range testKeys { + err := client.Set(ctx, key, "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) + } + + size := client.DBSize(ctx) + Expect(size.Err()).NotTo(HaveOccurred()) + Expect(size.Val()).To(BeNumerically(">=", int64(len(testKeys)))) + }) + }) + + Describe("DDL Commands Routing Policy Tests", func() { + BeforeEach(func() { + info := client.Info(ctx, "modules") + if info.Err() != nil || !strings.Contains(info.Val(), "search") { + Skip("Search module not available") + } + }) + + It("should test DDL commands routing policy for FT.CREATE", func() { 
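+ // DDL commands such as FT.CREATE mutate shared index metadata rather than per-slot data, so the expectation here is that the router sends them to a single coordinator node; broadcasting them to every shard could create duplicate or conflicting index definitions.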
+ SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify FT.CREATE command routing policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + ftCreateCmd, exists := cmds["ft.create"] + if !exists || ftCreateCmd.Tips == nil { + Skip("FT.CREATE command or tips not available") + } + + // DDL commands should NOT be broadcasted - they should go to coordinator only + Expect(ftCreateCmd.Tips).NotTo(BeNil()) + requestPolicy := ftCreateCmd.Tips.Request.String() + Expect(requestPolicy).NotTo(Equal("all_shards")) + Expect(requestPolicy).NotTo(Equal("all_nodes")) + + indexName := "test_index_create" + client.FTDropIndex(ctx, indexName) + + result := client.FTCreate(ctx, indexName, + &redis.FTCreateOptions{ + OnHash: true, + Prefix: []interface{}{"doc:"}, + }, + &redis.FieldSchema{ + FieldName: "title", + FieldType: redis.SearchFieldTypeText, + }) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("OK")) + + infoResult := client.FTInfo(ctx, indexName) + Expect(infoResult.Err()).NotTo(HaveOccurred()) + Expect(infoResult.Val().IndexName).To(Equal(indexName)) + client.FTDropIndex(ctx, indexName) + }) + + It("should test DDL commands routing policy for FT.ALTER", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify FT.ALTER command routing policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + ftAlterCmd, exists := cmds["ft.alter"] + if !exists || ftAlterCmd.Tips == nil { + Skip("FT.ALTER command or tips not available") + } + + Expect(ftAlterCmd.Tips).NotTo(BeNil()) + requestPolicy := ftAlterCmd.Tips.Request.String() + Expect(requestPolicy).NotTo(Equal("all_shards")) + Expect(requestPolicy).NotTo(Equal("all_nodes")) + + indexName := "test_index_alter" + client.FTDropIndex(ctx, indexName) + + result := client.FTCreate(ctx, indexName, + &redis.FTCreateOptions{ + OnHash: true, + Prefix: []interface{}{"doc:"}, + }, + &redis.FieldSchema{ + FieldName: "title", + FieldType: redis.SearchFieldTypeText, + }) + Expect(result.Err()).NotTo(HaveOccurred()) + + alterResult := client.FTAlter(ctx, indexName, false, + []interface{}{"description", redis.SearchFieldTypeText.String()}) + Expect(alterResult.Err()).NotTo(HaveOccurred()) + Expect(alterResult.Val()).To(Equal("OK")) + client.FTDropIndex(ctx, indexName) + }) + + It("should route keyed commands to correct shard based on hash slot", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // Single keyed command should go to exactly one shard - determined by hash slot + testKey := "test_key_12345" + testValue := "test_value" + + result := client.Set(ctx, testKey, testValue, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("OK")) + + time.Sleep(200 * time.Millisecond) + + var targetNodeAddr string + foundNodes := 0 + + for _, node := range masterNodes { + getResult := node.client.Get(ctx, testKey) + if getResult.Err() == nil && getResult.Val() == testValue { + foundNodes++ + targetNodeAddr = node.addr + } 
+ } + + Expect(foundNodes).To(Equal(1)) + Expect(targetNodeAddr).NotTo(BeEmpty()) + + // Multiple commands with same key should go to same shard + finalValue := "" + for i := 0; i < 5; i++ { + finalValue = fmt.Sprintf("value_%d", i) + result := client.Set(ctx, testKey, finalValue, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("OK")) + } + + time.Sleep(200 * time.Millisecond) + + var currentTargetNode string + foundNodesAfterUpdate := 0 + + for _, node := range masterNodes { + getResult := node.client.Get(ctx, testKey) + if getResult.Err() == nil && getResult.Val() == finalValue { + foundNodesAfterUpdate++ + currentTargetNode = node.addr + } + } + + // All commands with same key should go to same shard + Expect(foundNodesAfterUpdate).To(Equal(1)) + Expect(currentTargetNode).To(Equal(targetNodeAddr)) + }) + + It("should aggregate responses according to explicit aggregation policies", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // Verify TOUCH command has agg_sum policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + touchCmd, exists := cmds["touch"] + if !exists || touchCmd.Tips == nil { + Skip("TOUCH command or tips not available") + } + + Expect(touchCmd.Tips.Response.String()).To(Equal("agg_sum")) + + testKeys := []string{ + "touch_test_key_1111", // These keys should map to different hash slots + "touch_test_key_2222", + "touch_test_key_3333", + "touch_test_key_4444", + "touch_test_key_5555", + } + + // Set keys on different shards + keysPerShard := make(map[string][]string) + for _, key := range testKeys { + result := client.Set(ctx, key, "test_value", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + + // Find which shard contains this key + for _, node := range masterNodes { + getResult := node.client.Get(ctx, key) + if getResult.Err() == nil { + keysPerShard[node.addr] = append(keysPerShard[node.addr], key) + break + } + } + } + + // Verify keys are distributed across multiple shards + shardsWithKeys := len(keysPerShard) + Expect(shardsWithKeys).To(BeNumerically(">", 1)) + + // Execute TOUCH command on all keys - this should aggregate results using agg_sum + touchResult := client.Touch(ctx, testKeys...)
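+ // With agg_sum the router is expected to split TOUCH by hash slot, run the pieces on their shards, and sum the per-shard counts, so the aggregated total below should equal len(testKeys) regardless of how the keys are distributed.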
+ Expect(touchResult.Err()).NotTo(HaveOccurred()) + + totalTouched := touchResult.Val() + Expect(totalTouched).To(Equal(int64(len(testKeys)))) + + totalKeysOnShards := 0 + for _, keys := range keysPerShard { + totalKeysOnShards += len(keys) + } + + Expect(totalKeysOnShards).To(Equal(len(testKeys))) + + // FLUSHALL command with all_succeeded aggregation policy + flushallCmd, exists := cmds["flushall"] + if !exists || flushallCmd.Tips == nil { + Skip("FLUSHALL command or tips not available") + } + + Expect(flushallCmd.Tips.Response.String()).To(Equal("all_succeeded")) + + for i := 0; i < len(masterNodes); i++ { + testKey := fmt.Sprintf("flush_test_key_%d_%d", i, time.Now().UnixNano()) + result := client.Set(ctx, testKey, "test_data", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + } + + flushResult := client.FlushAll(ctx) + Expect(flushResult.Err()).NotTo(HaveOccurred()) + Expect(flushResult.Val()).To(Equal("OK")) + + for _, node := range masterNodes { + dbSizeResult := node.client.DBSize(ctx) + Expect(dbSizeResult.Err()).NotTo(HaveOccurred()) + Expect(dbSizeResult.Val()).To(Equal(int64(0))) + } + + // PFCOUNT command aggregation policy - verify agg_min policy + pfcountCmd, exists := cmds["pfcount"] + if !exists || pfcountCmd.Tips == nil { + Skip("PFCOUNT command or tips not available") + } + + actualPfcountPolicy := pfcountCmd.Tips.Response.String() + + if actualPfcountPolicy != "agg_min" { + Skip("PFCOUNT does not have agg_min policy in this Redis version") + } + + // Create HyperLogLog keys on different shards with different cardinalities + hllKeys := []string{ + "hll_test_key_1111", + "hll_test_key_2222", + "hll_test_key_3333", + } + + hllData := map[string][]string{ + "hll_test_key_1111": {"elem1", "elem2", "elem3", "elem4", "elem5"}, + "hll_test_key_2222": {"elem6", "elem7", "elem8"}, + "hll_test_key_3333": {"elem9", "elem10", "elem11", "elem12", "elem13", "elem14", "elem15"}, + } + + hllKeysPerShard := make(map[string][]string) + expectedCounts := make(map[string]int64) + + for key, elements := range hllData { + + interfaceElements := make([]interface{}, len(elements)) + for i, elem := range elements { + interfaceElements[i] = elem + } + result := client.PFAdd(ctx, key, interfaceElements...) + Expect(result.Err()).NotTo(HaveOccurred()) + + countResult := client.PFCount(ctx, key) + Expect(countResult.Err()).NotTo(HaveOccurred()) + expectedCounts[key] = countResult.Val() + + for _, node := range masterNodes { + // Check if key exists on this shard by trying to get its count + shardCountResult := node.client.PFCount(ctx, key) + if shardCountResult.Err() == nil && shardCountResult.Val() > 0 { + hllKeysPerShard[node.addr] = append(hllKeysPerShard[node.addr], key) + break + } + } + } + + // Verify keys are distributed across multiple shards + shardsWithHLLKeys := len(hllKeysPerShard) + Expect(shardsWithHLLKeys).To(BeNumerically(">", 1)) + + // Execute PFCOUNT command on all keys - should aggregate using agg_min + pfcountResult := client.PFCount(ctx, hllKeys...) 
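+ // agg_min keeps the smallest per-shard reply. For PFCOUNT over keys that live on different shards this is a lower bound rather than the true union cardinality, which is why the test compares against the manually computed minimum of the per-shard counts below.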
+ Expect(pfcountResult.Err()).NotTo(HaveOccurred()) + + aggregatedCount := pfcountResult.Val() + + // Verify the aggregation by manually getting counts from each shard + var shardCounts []int64 + for shardAddr, keys := range hllKeysPerShard { + if len(keys) == 0 { + continue + } + + // Find the node for this shard + var shardNode *masterNode + for i := range masterNodes { + if masterNodes[i].addr == shardAddr { + shardNode = &masterNodes[i] + break + } + } + Expect(shardNode).NotTo(BeNil()) + + // Get count for keys on this specific shard + shardResult := shardNode.client.PFCount(ctx, keys...) + Expect(shardResult.Err()).NotTo(HaveOccurred()) + + shardCount := shardResult.Val() + shardCounts = append(shardCounts, shardCount) + } + + // Find the minimum count from all shards + expectedMin := shardCounts[0] + for _, count := range shardCounts[1:] { + if count < expectedMin { + expectedMin = count + } + } + + // Verify agg_min aggregation worked correctly + Expect(aggregatedCount).To(Equal(expectedMin)) + + // EXISTS command aggregation policy - verify agg_logical_and policy + existsCmd, exists := cmds["exists"] + if !exists || existsCmd.Tips == nil { + Skip("EXISTS command or tips not available") + } + + actualExistsPolicy := existsCmd.Tips.Response.String() + if actualExistsPolicy != "agg_logical_and" { + Skip("EXISTS does not have agg_logical_and policy in this Redis version") + } + + existsTestKeys := []string{ + "exists_test_key_1111", + "exists_test_key_2222", + "exists_test_key_3333", + } + + for _, key := range existsTestKeys { + result := client.Set(ctx, key, "exists_test_value", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + } + + // All keys exist - the aggregated count should equal the number of keys + existsResult := client.Exists(ctx, existsTestKeys...) + Expect(existsResult.Err()).NotTo(HaveOccurred()) + + allExistCount := existsResult.Val() + Expect(allExistCount).To(Equal(int64(len(existsTestKeys)))) + + // Delete one key and test again - logical AND should handle mixed results + deletedKey := existsTestKeys[0] + delResult := client.Del(ctx, deletedKey) + Expect(delResult.Err()).NotTo(HaveOccurred()) + + // Check EXISTS again - now one key is missing + existsResult2 := client.Exists(ctx, existsTestKeys...)
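+ // Although the advertised policy above is agg_logical_and, the Go Exists API returns the number of keys that exist, so the assertions here work with counts rather than booleans.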
+ Expect(existsResult2.Err()).NotTo(HaveOccurred()) + + partialExistCount := existsResult2.Val() + + // Should return count of existing keys (2 out of 3) + Expect(partialExistCount).To(Equal(int64(len(existsTestKeys) - 1))) + }) + + It("should verify command aggregation policies", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + commandPolicies := map[string]string{ + "touch": "agg_sum", + "flushall": "all_succeeded", + "pfcount": "agg_min", + "exists": "agg_logical_and", + } + + for cmdName, expectedPolicy := range commandPolicies { + cmd, exists := cmds[cmdName] + if !exists { + continue + } + + if cmd.Tips == nil { + continue + } + + actualPolicy := cmd.Tips.Response.String() + Expect(actualPolicy).To(Equal(expectedPolicy)) + } + }) + + It("should properly aggregate responses from keyless commands executed on multiple shards", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // PING command with all_shards policy - should aggregate responses + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + pingCmd, exists := cmds["ping"] + if exists && pingCmd.Tips != nil { + Expect(pingCmd.Tips.Request.String()).To(Equal("all_shards")) + } + + pingResult := client.Ping(ctx) + Expect(pingResult.Err()).NotTo(HaveOccurred()) + Expect(pingResult.Val()).To(Equal("PONG")) + + // Verify PING was executed on all shards by checking individual nodes + for _, node := range masterNodes { + nodePingResult := node.client.Ping(ctx) + Expect(nodePingResult.Err()).NotTo(HaveOccurred()) + Expect(nodePingResult.Val()).To(Equal("PONG")) + } + + // DBSIZE command aggregation across shards - verify agg_sum policy + testKeys := []string{ + "dbsize_test_key_1111", + "dbsize_test_key_2222", + "dbsize_test_key_3333", + "dbsize_test_key_4444", + } + + for _, key := range testKeys { + result := client.Set(ctx, key, "test_value", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + } + + dbSizeResult := client.DBSize(ctx) + Expect(dbSizeResult.Err()).NotTo(HaveOccurred()) + + totalSize := dbSizeResult.Val() + Expect(totalSize).To(BeNumerically(">=", int64(len(testKeys)))) + + // Verify aggregation by manually getting sizes from each shard + totalManualSize := int64(0) + + for _, node := range masterNodes { + nodeDbSizeResult := node.client.DBSize(ctx) + Expect(nodeDbSizeResult.Err()).NotTo(HaveOccurred()) + + nodeSize := nodeDbSizeResult.Val() + totalManualSize += nodeSize + } + + // Verify aggregation worked correctly + Expect(totalSize).To(Equal(totalManualSize)) + }) + + It("should properly aggregate responses from keyed commands executed on multiple shards", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, 
+ addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // MGET command aggregation across multiple keys on different shards - verify agg_sum policy + testData := map[string]string{ + "mget_test_key_1111": "value1", + "mget_test_key_2222": "value2", + "mget_test_key_3333": "value3", + "mget_test_key_4444": "value4", + "mget_test_key_5555": "value5", + } + + keyLocations := make(map[string]string) + for key, value := range testData { + + result := client.Set(ctx, key, value, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + + for _, node := range masterNodes { + getResult := node.client.Get(ctx, key) + if getResult.Err() == nil && getResult.Val() == value { + keyLocations[key] = node.addr + break + } + } + } + + shardsUsed := make(map[string]bool) + for _, shardAddr := range keyLocations { + shardsUsed[shardAddr] = true + } + Expect(len(shardsUsed)).To(BeNumerically(">", 1)) + + keys := make([]string, 0, len(testData)) + expectedValues := make([]interface{}, 0, len(testData)) + + for key, value := range testData { + keys = append(keys, key) + expectedValues = append(expectedValues, value) + } + + mgetResult := client.MGet(ctx, keys...) + Expect(mgetResult.Err()).NotTo(HaveOccurred()) + + actualValues := mgetResult.Val() + Expect(len(actualValues)).To(Equal(len(keys))) + Expect(actualValues).To(ConsistOf(expectedValues)) + + // Verify all values are correctly aggregated + for i, key := range keys { + expectedValue := testData[key] + actualValue := actualValues[i] + Expect(actualValue).To(Equal(expectedValue)) + } + + // DEL command aggregation across multiple keys on different shards + delResult := client.Del(ctx, keys...) + Expect(delResult.Err()).NotTo(HaveOccurred()) + + deletedCount := delResult.Val() + Expect(deletedCount).To(Equal(int64(len(keys)))) + + // Verify keys are actually deleted from their respective shards + for key, shardAddr := range keyLocations { + var targetNode *masterNode + for i := range masterNodes { + if masterNodes[i].addr == shardAddr { + targetNode = &masterNodes[i] + break + } + } + Expect(targetNode).NotTo(BeNil()) + + getResult := targetNode.client.Get(ctx, key) + Expect(getResult.Err()).To(HaveOccurred()) + } + + // EXISTS command aggregation across multiple keys + existsTestData := map[string]string{ + "exists_agg_key_1111": "value1", + "exists_agg_key_2222": "value2", + "exists_agg_key_3333": "value3", + } + + existsKeys := make([]string, 0, len(existsTestData)) + for key, value := range existsTestData { + result := client.Set(ctx, key, value, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + existsKeys = append(existsKeys, key) + } + + // Add a non-existent key to the list + nonExistentKey := "non_existent_key_9999" + existsKeys = append(existsKeys, nonExistentKey) + + existsResult := client.Exists(ctx, existsKeys...) 
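+ // existsKeys spans several shards and ends with one key that was never set, so the aggregated reply should count only the keys that actually exist.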
+ Expect(existsResult.Err()).NotTo(HaveOccurred()) + + existsCount := existsResult.Val() + Expect(existsCount).To(Equal(int64(len(existsTestData)))) + }) + + It("should propagate coordinator errors to client without modification", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 0)) + + invalidSlotResult := client.ClusterAddSlotsRange(ctx, 99999, 100000) + coordinatorErr := invalidSlotResult.Err() + + if coordinatorErr != nil { + // Verify the error is a Redis error + var redisErr redis.Error + Expect(errors.As(coordinatorErr, &redisErr)).To(BeTrue()) + + // Verify error message is preserved exactly as returned by coordinator + errorMsg := coordinatorErr.Error() + Expect(errorMsg).To(SatisfyAny( + ContainSubstring("slot"), + ContainSubstring("ERR"), + ContainSubstring("Invalid"), + )) + + // Test that the same error occurs when calling coordinator directly + coordinatorNode := masterNodes[0] + directResult := coordinatorNode.client.ClusterAddSlotsRange(ctx, 99999, 100000) + directErr := directResult.Err() + + if directErr != nil { + Expect(coordinatorErr.Error()).To(Equal(directErr.Error())) + } + } + + // Try cluster forget with invalid node ID + invalidNodeID := "invalid_node_id_12345" + forgetResult := client.ClusterForget(ctx, invalidNodeID) + forgetErr := forgetResult.Err() + + if forgetErr != nil { + var redisErr redis.Error + Expect(errors.As(forgetErr, &redisErr)).To(BeTrue()) + + errorMsg := forgetErr.Error() + Expect(errorMsg).To(SatisfyAny( + ContainSubstring("Unknown node"), + ContainSubstring("Invalid node"), + ContainSubstring("ERR"), + )) + + coordinatorNode := masterNodes[0] + directForgetResult := coordinatorNode.client.ClusterForget(ctx, invalidNodeID) + directForgetErr := directForgetResult.Err() + + if directForgetErr != nil { + Expect(forgetErr.Error()).To(Equal(directForgetErr.Error())) + } + } + + // Test error type preservation and format + keySlotResult := client.ClusterKeySlot(ctx, "") + keySlotErr := keySlotResult.Err() + + if keySlotErr != nil { + var redisErr redis.Error + Expect(errors.As(keySlotErr, &redisErr)).To(BeTrue()) + + errorMsg := keySlotErr.Error() + Expect(len(errorMsg)).To(BeNumerically(">", 0)) + Expect(errorMsg).NotTo(ContainSubstring("wrapped")) + Expect(errorMsg).NotTo(ContainSubstring("context")) + } + + // Verify error propagation consistency + clusterInfoResult := client.ClusterInfo(ctx) + clusterInfoErr := clusterInfoResult.Err() + + if clusterInfoErr != nil { + var redisErr redis.Error + Expect(errors.As(clusterInfoErr, &redisErr)).To(BeTrue()) + + coordinatorNode := masterNodes[0] + directInfoResult := coordinatorNode.client.ClusterInfo(ctx) + directInfoErr := directInfoResult.Err() + + if directInfoErr != nil { + Expect(clusterInfoErr.Error()).To(Equal(directInfoErr.Error())) + } + } + + // Verify no error modification in router + invalidReplicateResult := client.ClusterReplicate(ctx, "00000000000000000000000000000000invalid00") + invalidReplicateErr := invalidReplicateResult.Err() + + if invalidReplicateErr != nil { + var redisErr redis.Error + 
Expect(errors.As(invalidReplicateErr, &redisErr)).To(BeTrue()) + + errorMsg := invalidReplicateErr.Error() + Expect(errorMsg).NotTo(ContainSubstring("router")) + Expect(errorMsg).NotTo(ContainSubstring("cluster client")) + Expect(errorMsg).NotTo(ContainSubstring("failed to execute")) + + Expect(errorMsg).To(SatisfyAny( + HavePrefix("ERR"), + ContainSubstring("Invalid"), + ContainSubstring("Unknown"), + )) + } + }) + + It("should route keyless commands to arbitrary shards using round robin", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + var numMasters int + var numMastersMu sync.Mutex + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + numMastersMu.Lock() + numMasters++ + numMastersMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(numMasters).To(BeNumerically(">", 1)) + + err = client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + return master.ConfigResetStat(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + // Helper function to get ECHO command counts from all nodes + getEchoCounts := func() map[string]int { + echoCounts := make(map[string]int) + var echoCountsMu sync.Mutex + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + info := master.Info(ctx, "server") + Expect(info.Err()).NotTo(HaveOccurred()) + + serverInfo := info.Val() + portStart := strings.Index(serverInfo, "tcp_port:") + portLine := serverInfo[portStart:] + portEnd := strings.Index(portLine, "\r\n") + if portEnd == -1 { + portEnd = len(portLine) + } + port := strings.TrimPrefix(portLine[:portEnd], "tcp_port:") + + commandStats := master.Info(ctx, "commandstats") + count := 0 + if commandStats.Err() == nil { + stats := commandStats.Val() + cmdStatKey := "cmdstat_echo:" + if strings.Contains(stats, cmdStatKey) { + statStart := strings.Index(stats, cmdStatKey) + statLine := stats[statStart:] + statEnd := strings.Index(statLine, "\r\n") + if statEnd == -1 { + statEnd = len(statLine) + } + statLine = statLine[:statEnd] + + callsStart := strings.Index(statLine, "calls=") + if callsStart != -1 { + callsStr := statLine[callsStart+6:] + callsEnd := strings.Index(callsStr, ",") + if callsEnd == -1 { + callsEnd = strings.Index(callsStr, "\r") + if callsEnd == -1 { + callsEnd = len(callsStr) + } + } + if callsCount, err := strconv.Atoi(callsStr[:callsEnd]); err == nil { + count = callsCount + } + } + } + } + + echoCountsMu.Lock() + echoCounts[port] = count + echoCountsMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + return echoCounts + } + + // Single ECHO command should go to exactly one shard + result := client.Echo(ctx, "single_test") + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("single_test")) + + time.Sleep(200 * time.Millisecond) + + // Verify single command went to exactly one shard + echoCounts := getEchoCounts() + shardsWithEcho := 0 + for _, count := range echoCounts { + if count > 0 { + shardsWithEcho++ + Expect(count).To(Equal(1)) + } + } + Expect(shardsWithEcho).To(Equal(1)) + + // Test Multiple ECHO commands should distribute across all shards using round robin + numCommands := numMasters * 3 + + for i := 0; i < numCommands; i++ { + result := client.Echo(ctx, fmt.Sprintf("multi_test_%d", i)) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal(fmt.Sprintf("multi_test_%d", i))) + } + + time.Sleep(200 * time.Millisecond) + + echoCounts = getEchoCounts() + totalEchos := 0 + 
shardsWithEchos := 0 + for _, count := range echoCounts { + if count > 0 { + shardsWithEchos++ + } + totalEchos += count + } + + // All shards should now have some ECHO commands + Expect(shardsWithEchos).To(Equal(numMasters)) + + expectedTotal := 1 + numCommands + Expect(totalEchos).To(Equal(expectedTotal)) + }) + }) + + var _ = Describe("ClusterClient ParseURL", func() { + cases := []struct { + test string + url string + o *redis.ClusterOptions // expected value + err error + }{ + { + test: "ParseRedisURL", + url: "redis://localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}}, + }, { + test: "ParseRedissURL", + url: "rediss://localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "MissingRedisPort", + url: "redis://localhost", + o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}}, + }, { + test: "MissingRedissPort", + url: "rediss://localhost", + o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "MultipleRedisURLs", + url: "redis://localhost:123?addr=localhost:1234&addr=localhost:12345", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}}, + }, { + test: "MultipleRedissURLs", + url: "rediss://localhost:123?addr=localhost:1234&addr=localhost:12345", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "OnlyPassword", + url: "redis://:bar@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Password: "bar"}, + }, { + test: "OnlyUser", + url: "redis://foo@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo"}, + }, { + test: "RedisUsernamePassword", + url: "redis://foo:bar@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo", Password: "bar"}, + }, { + test: "RedissUsernamePassword", + url: "rediss://foo:bar@localhost:123?addr=localhost:1234", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, Username: "foo", Password: "bar", TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "QueryParameters", + url: "redis://localhost:123?read_timeout=2&pool_fifo=true&addr=localhost:1234", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, ReadTimeout: 2 * time.Second, PoolFIFO: true}, + }, { + test: "DisabledTimeout", + url: "redis://localhost:123?conn_max_idle_time=0", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, + }, { + test: "DisabledTimeoutNeg", + url: "redis://localhost:123?conn_max_idle_time=-1", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, + }, { + test: "UseDefault", + url: "redis://localhost:123?conn_max_idle_time=", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, + }, { + test: "Protocol", + url: "redis://localhost:123?protocol=2", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Protocol: 2}, + }, { + test: "ClientName", + url: "redis://localhost:123?client_name=cluster_hi", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ClientName: "cluster_hi"}, + }, { + test: "UseDefaultMissing=", + url: "redis://localhost:123?conn_max_idle_time", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, + }, { + test: 
"InvalidQueryAddr", + url: "rediss://foo:bar@localhost:123?addr=rediss://foo:barr@localhost:1234", + err: errors.New(`redis: unable to parse addr param: rediss://foo:barr@localhost:1234`), + }, { + test: "InvalidInt", + url: "redis://localhost?pool_size=five", + err: errors.New(`redis: invalid pool_size number: strconv.Atoi: parsing "five": invalid syntax`), + }, { + test: "InvalidBool", + url: "redis://localhost?pool_fifo=yes", + err: errors.New(`redis: invalid pool_fifo boolean: expected true/false/1/0 or an empty string, got "yes"`), + }, { + test: "UnknownParam", + url: "redis://localhost?abc=123", + err: errors.New("redis: unexpected option: abc"), + }, { + test: "InvalidScheme", + url: "https://google.com", + err: errors.New("redis: invalid URL scheme: https"), + }, + } + + It("should match ParseClusterURL", func() { + for i := range cases { + tc := cases[i] + actual, err := redis.ParseClusterURL(tc.url) + if tc.err != nil { + Expect(err).Should(MatchError(tc.err)) + } else { + Expect(err).NotTo(HaveOccurred()) + } + + if err == nil { + Expect(tc.o).NotTo(BeNil()) + + Expect(tc.o.Addrs).To(Equal(actual.Addrs)) + Expect(tc.o.TLSConfig).To(Equal(actual.TLSConfig)) + Expect(tc.o.Username).To(Equal(actual.Username)) + Expect(tc.o.Password).To(Equal(actual.Password)) + Expect(tc.o.MaxRetries).To(Equal(actual.MaxRetries)) + Expect(tc.o.MinRetryBackoff).To(Equal(actual.MinRetryBackoff)) + Expect(tc.o.MaxRetryBackoff).To(Equal(actual.MaxRetryBackoff)) + Expect(tc.o.DialTimeout).To(Equal(actual.DialTimeout)) + Expect(tc.o.ReadTimeout).To(Equal(actual.ReadTimeout)) + Expect(tc.o.WriteTimeout).To(Equal(actual.WriteTimeout)) + Expect(tc.o.PoolFIFO).To(Equal(actual.PoolFIFO)) + Expect(tc.o.PoolSize).To(Equal(actual.PoolSize)) + Expect(tc.o.MinIdleConns).To(Equal(actual.MinIdleConns)) + Expect(tc.o.ConnMaxLifetime).To(Equal(actual.ConnMaxLifetime)) + Expect(tc.o.ConnMaxIdleTime).To(Equal(actual.ConnMaxIdleTime)) + Expect(tc.o.PoolTimeout).To(Equal(actual.PoolTimeout)) + } + } + }) + + It("should distribute keyless commands randomly across shards using random shard picker", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Create a cluster client with random shard picker + opt := redisClusterOptions() + opt.ShardPicker = &routing.RandomPicker{} + randomClient := cluster.newClusterClient(ctx, opt) + defer randomClient.Close() + + Eventually(func() error { + return randomClient.Ping(ctx).Err() + }, 30*time.Second).ShouldNot(HaveOccurred()) + + var numMasters int + var numMastersMu sync.Mutex + err := randomClient.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + numMastersMu.Lock() + numMasters++ + numMastersMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(numMasters).To(BeNumerically(">", 1)) + + // Reset command statistics on all masters + err = randomClient.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + return master.ConfigResetStat(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + // Helper function to get ECHO command counts from all nodes + getEchoCounts := func() map[string]int { + echoCounts := make(map[string]int) + var echoCountsMu sync.Mutex + err := randomClient.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + port := addr[strings.LastIndex(addr, ":")+1:] + + info, err := master.Info(ctx, "commandstats").Result() + if err != nil { + return err + } + + count := 0 + if strings.Contains(info, "cmdstat_echo:") { 
+ lines := strings.Split(info, "\n") + for _, line := range lines { + if strings.HasPrefix(line, "cmdstat_echo:") { + parts := strings.Split(line, ",") + if len(parts) > 0 { + callsPart := strings.Split(parts[0], "=") + if len(callsPart) > 1 { + if parsedCount, parseErr := strconv.Atoi(callsPart[1]); parseErr == nil { + count = parsedCount + } + } + } + break + } + } + } + + echoCountsMu.Lock() + echoCounts[port] = count + echoCountsMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + return echoCounts + } + + // Execute multiple ECHO commands and measure distribution + numCommands := 100 + for i := 0; i < numCommands; i++ { + result := randomClient.Echo(ctx, fmt.Sprintf("random_test_%d", i)) + Expect(result.Err()).NotTo(HaveOccurred()) + } + + echoCounts := getEchoCounts() + + totalEchos := 0 + shardsWithEchos := 0 + + for _, count := range echoCounts { + if count > 0 { + shardsWithEchos++ + } + totalEchos += count + } + + Expect(totalEchos).To(Equal(numCommands)) + Expect(shardsWithEchos).To(BeNumerically(">=", 2)) + }) }) }) From 7d80b8a09696a9eb69fd1734fcf5db2e4896b22a Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 6 Jul 2025 14:44:18 +0300 Subject: [PATCH 15/62] fix aggregation test --- osscluster_router.go | 5 ++ osscluster_test.go | 182 +++++++++++++------------------------------ 2 files changed, 57 insertions(+), 130 deletions(-) diff --git a/osscluster_router.go b/osscluster_router.go index 7b765964b8..39a217cca0 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -418,6 +418,11 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout // createAggregator creates the appropriate response aggregator func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator { if policy != nil { + // For multi-shard commands that operate on multiple keys (like MGET), + // use keyed aggregator even if policy says all_succeeded + if policy.Request == routing.ReqMultiShard && isKeyed { + return routing.NewDefaultAggregator(true) + } return routing.NewResponseAggregator(policy.Response, cmd.Name()) } diff --git a/osscluster_test.go b/osscluster_test.go index de91c4d602..16ab536bc4 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -2073,145 +2073,65 @@ var _ = Describe("Command Tips tests", func() { Expect(dbSizeResult.Val()).To(Equal(int64(0))) } - // PFCOUNT command aggregation policy - verify agg_min policy - pfcountCmd, exists := cmds["pfcount"] - if !exists || pfcountCmd.Tips == nil { - Skip("PFCOUNT command or tips not available") + // WAIT command aggregation policy - verify agg_min policy + waitCmd, exists := cmds["wait"] + if !exists || waitCmd.Tips == nil { + Skip("WAIT command or tips not available") } - actualPfcountPolicy := pfcountCmd.Tips.Response.String() + Expect(waitCmd.Tips.Response.String()).To(Equal("agg_min")) - if actualPfcountPolicy != "agg_min" { - Skip("PFCOUNT does not have agg_min policy in this Redis version") - } - - // Create HyperLogLog keys on different shards with different cardinalities - hllKeys := []string{ - "hll_test_key_1111", - "hll_test_key_2222", - "hll_test_key_3333", - } - - hllData := map[string][]string{ - "hll_test_key_1111": {"elem1", "elem2", "elem3", "elem4", "elem5"}, - "hll_test_key_2222": {"elem6", "elem7", "elem8"}, - "hll_test_key_3333": {"elem9", "elem10", "elem11", "elem12", "elem13", "elem14", "elem15"}, - } - - hllKeysPerShard := make(map[string][]string) - expectedCounts := make(map[string]int64) - - for key, 
elements := range hllData { - - interfaceElements := make([]interface{}, len(elements)) - for i, elem := range elements { - interfaceElements[i] = elem - } - result := client.PFAdd(ctx, key, interfaceElements...) - Expect(result.Err()).NotTo(HaveOccurred()) - - countResult := client.PFCount(ctx, key) - Expect(countResult.Err()).NotTo(HaveOccurred()) - expectedCounts[key] = countResult.Val() - - for _, node := range masterNodes { - // Check if key exists on this shard by trying to get its count - shardCountResult := node.client.PFCount(ctx, key) - if shardCountResult.Err() == nil && shardCountResult.Val() > 0 { - hllKeysPerShard[node.addr] = append(hllKeysPerShard[node.addr], key) - break - } - } - } - - // Verify keys are distributed across multiple shards - shardsWithHLLKeys := len(hllKeysPerShard) - Expect(shardsWithHLLKeys).To(BeNumerically(">", 1)) - - // Execute PFCOUNT command on all keys - should aggregate using agg_min - pfcountResult := client.PFCount(ctx, hllKeys...) - Expect(pfcountResult.Err()).NotTo(HaveOccurred()) - - aggregatedCount := pfcountResult.Val() - - // Verify the aggregation by manually getting counts from each shard - var shardCounts []int64 - for shardAddr, keys := range hllKeysPerShard { - if len(keys) == 0 { - continue - } - - // Find the node for this shard - var shardNode *masterNode - for i := range masterNodes { - if masterNodes[i].addr == shardAddr { - shardNode = &masterNodes[i] - break - } - } - Expect(shardNode).NotTo(BeNil()) - - // Get count for keys on this specific shard - shardResult := shardNode.client.PFCount(ctx, keys...) - Expect(shardResult.Err()).NotTo(HaveOccurred()) - - shardCount := shardResult.Val() - shardCounts = append(shardCounts, shardCount) - } + // Set up some data to replicate + testKey := "wait_test_key_1111" + result := client.Set(ctx, testKey, "test_value", 0) + Expect(result.Err()).NotTo(HaveOccurred()) - // Find the minimum count from all shards - expectedMin := shardCounts[0] - for _, count := range shardCounts[1:] { - if count < expectedMin { - expectedMin = count - } - } + // Execute WAIT command - should aggregate using agg_min across all shards + // WAIT waits for a given number of replicas to acknowledge writes + // With agg_min policy, it returns the minimum number of replicas that acknowledged + waitResult := client.Wait(ctx, 0, 1000) // Wait for 0 replicas with 1 second timeout + Expect(waitResult.Err()).NotTo(HaveOccurred()) - // Verify agg_min aggregation worked correctly - Expect(aggregatedCount).To(Equal(expectedMin)) + // The result should be the minimum number of replicas across all shards + // Since we're asking for 0 replicas, all shards should return 0, so min is 0 + minReplicas := waitResult.Val() + Expect(minReplicas).To(BeNumerically(">=", 0)) - // EXISTS command aggregation policy - verify agg_logical_and policy - existsCmd, exists := cmds["exists"] - if !exists || existsCmd.Tips == nil { - Skip("EXISTS command or tips not available") + // SCRIPT EXISTS command aggregation policy - verify agg_logical_and policy + scriptExistsCmd, exists := cmds["script exists"] + if !exists || scriptExistsCmd.Tips == nil { + Skip("SCRIPT EXISTS command or tips not available") } - actualExistsPolicy := existsCmd.Tips.Response.String() - if actualExistsPolicy != "agg_logical_and" { - Skip("EXISTS does not have agg_logical_and policy in this Redis version") - } + Expect(scriptExistsCmd.Tips.Response.String()).To(Equal("agg_logical_and")) - existsTestKeys := []string{ - "exists_test_key_1111", - "exists_test_key_2222", - 
"exists_test_key_3333", - } + // Load a script on all shards + testScript := "return 'hello'" + scriptLoadResult := client.ScriptLoad(ctx, testScript) + Expect(scriptLoadResult.Err()).NotTo(HaveOccurred()) + scriptSHA := scriptLoadResult.Val() - for _, key := range existsTestKeys { - result := client.Set(ctx, key, "exists_test_value", 0) - Expect(result.Err()).NotTo(HaveOccurred()) - } - - //All keys exist - should return true - existsResult := client.Exists(ctx, existsTestKeys...) - Expect(existsResult.Err()).NotTo(HaveOccurred()) + // Verify script exists on all shards using SCRIPT EXISTS + // With agg_logical_and policy, it should return true only if script exists on ALL shards + scriptExistsResult := client.ScriptExists(ctx, scriptSHA) + Expect(scriptExistsResult.Err()).NotTo(HaveOccurred()) - allExistCount := existsResult.Val() - Expect(allExistCount).To(Equal(int64(len(existsTestKeys)))) - - // Delete one key and test again - logical AND should handle mixed results - deletedKey := existsTestKeys[0] - delResult := client.Del(ctx, deletedKey) - Expect(delResult.Err()).NotTo(HaveOccurred()) + existsResults := scriptExistsResult.Val() + Expect(len(existsResults)).To(Equal(1)) + Expect(existsResults[0]).To(BeTrue()) // Script should exist on all shards - // Check EXISTS again - now one key is missing - existsResult2 := client.Exists(ctx, existsTestKeys...) - Expect(existsResult2.Err()).NotTo(HaveOccurred()) + // Test with a non-existent script SHA + nonExistentSHA := "0000000000000000000000000000000000000000" + scriptExistsResult2 := client.ScriptExists(ctx, nonExistentSHA) + Expect(scriptExistsResult2.Err()).NotTo(HaveOccurred()) - partialExistCount := existsResult2.Val() + existsResults2 := scriptExistsResult2.Val() + Expect(len(existsResults2)).To(Equal(1)) + Expect(existsResults2[0]).To(BeFalse()) // Script should not exist on any shard - // Should return count of existing keys (2 out of 3) - Expect(partialExistCount).To(Equal(int64(len(existsTestKeys) - 1))) + // Test with mixed scenario - flush scripts from one shard manually + // This is harder to test in practice since SCRIPT FLUSH affects all shards + // So we'll just verify the basic functionality works }) It("should verify command aggregation policies", func() { @@ -2221,10 +2141,12 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) commandPolicies := map[string]string{ - "touch": "agg_sum", - "flushall": "all_succeeded", - "pfcount": "agg_min", - "exists": "agg_logical_and", + "touch": "agg_sum", + "flushall": "all_succeeded", + "pfcount": "all_succeeded", + "exists": "agg_sum", + "script exists": "agg_logical_and", + "wait": "agg_min", } for cmdName, expectedPolicy := range commandPolicies { @@ -2341,7 +2263,7 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) Expect(len(masterNodes)).To(BeNumerically(">", 1)) - // MGET command aggregation across multiple keys on different shards - verify agg_sum policy + // MGET command aggregation across multiple keys on different shards - verify all_succeeded policy with keyed aggregation testData := map[string]string{ "mget_test_key_1111": "value1", "mget_test_key_2222": "value2", From d70bf764358e3e585ef8abdfc5673109e4bc6554 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 6 Jul 2025 14:59:30 +0300 Subject: [PATCH 16/62] fix mget test --- osscluster_router.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/osscluster_router.go b/osscluster_router.go index 39a217cca0..a7647f9040 100644 
--- a/osscluster_router.go +++ b/osscluster_router.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "strings" "sync" "time" @@ -418,10 +419,13 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout // createAggregator creates the appropriate response aggregator func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator { if policy != nil { - // For multi-shard commands that operate on multiple keys (like MGET), - // use keyed aggregator even if policy says all_succeeded - if policy.Request == routing.ReqMultiShard && isKeyed { - return routing.NewDefaultAggregator(true) + // For specific multi-shard commands that need keyed aggregation despite having + // all_succeeded policy (like MGET which needs to preserve key order) + if policy.Request == routing.ReqMultiShard && policy.Response == routing.RespAllSucceeded && isKeyed { + cmdName := strings.ToLower(cmd.Name()) + if cmdName == "mget" { + return routing.NewDefaultAggregator(true) + } } return routing.NewResponseAggregator(policy.Response, cmd.Name()) } From 9a64190719ea9aac05c2252f8b18fe3a4a13b6d2 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 6 Jul 2025 15:20:53 +0300 Subject: [PATCH 17/62] fix mget test --- osscluster_router.go | 78 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/osscluster_router.go b/osscluster_router.go index a7647f9040..ea64ccd55b 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -335,16 +335,41 @@ func (c *ClusterClient) executeParallel(ctx context.Context, cmd Cmder, nodes [] // aggregateMultiSlotResults aggregates results from multi-slot execution func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder, results <-chan slotResult, keyOrder []string, policy *routing.CommandPolicy) error { - keyedResults := make(map[string]Cmder) + keyedResults := make(map[string]interface{}) var firstErr error for result := range results { if result.err != nil && firstErr == nil { firstErr = result.err } - if result.cmd != nil { - for _, key := range result.keys { - keyedResults[key] = result.cmd + if result.cmd != nil && result.err == nil { + // For MGET, extract individual values from the array result + if strings.ToLower(cmd.Name()) == "mget" { + if sliceCmd, ok := result.cmd.(*SliceCmd); ok { + values := sliceCmd.Val() + if len(values) == len(result.keys) { + for i, key := range result.keys { + keyedResults[key] = values[i] + } + } else { + // Fallback: map all keys to the entire result + for _, key := range result.keys { + keyedResults[key] = values + } + } + } else { + // Fallback for non-SliceCmd results + value := ExtractCommandValue(result.cmd) + for _, key := range result.keys { + keyedResults[key] = value + } + } + } else { + // For other commands, map each key to the entire result + value := ExtractCommandValue(result.cmd) + for _, key := range result.keys { + keyedResults[key] = value + } } } } @@ -354,7 +379,36 @@ func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder return firstErr } - return c.aggregateKeyedResponses(cmd, keyedResults, keyOrder, policy) + return c.aggregateKeyedValues(cmd, keyedResults, keyOrder, policy) +} + +// aggregateKeyedValues aggregates individual key-value pairs while preserving key order +func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]interface{}, keyOrder []string, policy *routing.CommandPolicy) error { + if len(keyedResults) 
== 0 {
+		return fmt.Errorf("redis: no results to aggregate")
+	}
+
+	aggregator := c.createAggregator(policy, cmd, true)
+
+	// Set key order for keyed aggregators
+	if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok {
+		keyedAgg.SetKeyOrder(keyOrder)
+	}
+
+	// Add results with keys
+	for key, value := range keyedResults {
+		if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok {
+			if err := keyedAgg.AddWithKey(key, value, nil); err != nil {
+				return err
+			}
+		} else {
+			if err := aggregator.Add(value, nil); err != nil {
+				return err
+			}
+		}
+	}
+
+	return c.finishAggregation(cmd, aggregator)
+}

 // aggregateKeyedResponses aggregates responses while preserving key order
@@ -418,15 +472,13 @@
 // createAggregator creates the appropriate response aggregator
 func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator {
+	cmdName := strings.ToLower(cmd.Name())
+	// MGET always needs a keyed aggregator so values can be reassembled
+	// in key order, whether or not a command policy is present
+	if cmdName == "mget" {
+		return routing.NewDefaultAggregator(true)
+	}
+
 	if policy != nil {
-		// For specific multi-shard commands that need keyed aggregation despite having
-		// all_succeeded policy (like MGET which needs to preserve key order)
-		if policy.Request == routing.ReqMultiShard && policy.Response == routing.RespAllSucceeded && isKeyed {
-			cmdName := strings.ToLower(cmd.Name())
-			if cmdName == "mget" {
-				return routing.NewDefaultAggregator(true)
-			}
-		}
 		return routing.NewResponseAggregator(policy.Response, cmd.Name())
 	}

From c45ed9b4b9c71e23e55e8adab700d9b6b6e1c08b Mon Sep 17 00:00:00 2001
From: ofekshenawa
Date: Sun, 6 Jul 2025 15:49:19 +0300
Subject: [PATCH 18/62] remove aggregateKeyedResponses

---
 osscluster_router.go | 30 ------------------------------
 1 file changed, 30 deletions(-)

diff --git a/osscluster_router.go b/osscluster_router.go
index ea64ccd55b..a4ddbbf4ab 100644
--- a/osscluster_router.go
+++ b/osscluster_router.go
@@ -411,36 +411,6 @@ func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]
 	return c.finishAggregation(cmd, aggregator)
 }

-// aggregateKeyedResponses aggregates responses while preserving key order
-func (c *ClusterClient) aggregateKeyedResponses(cmd Cmder, keyedResults map[string]Cmder, keyOrder []string, policy *routing.CommandPolicy) error {
-	if len(keyedResults) == 0 {
-		return fmt.Errorf("redis: no results to aggregate")
-	}
-
-	aggregator := c.createAggregator(policy, cmd, true)
-
-	// Set key order for keyed aggregators
-	if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok {
-		keyedAgg.SetKeyOrder(keyOrder)
-	}
-
-	// Add results with keys
-	for key, shardCmd := range keyedResults {
-		value := ExtractCommandValue(shardCmd)
-		if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok {
-			if err := keyedAgg.AddWithKey(key, value, shardCmd.Err()); err != nil {
-				return err
-			}
-		} else {
-			if err := aggregator.Add(value, shardCmd.Err()); err != nil {
-				return err
-			}
-		}
-	}
-
-	return c.finishAggregation(cmd, aggregator)
-}
-
 // aggregateResponses aggregates multiple shard responses
 func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *routing.CommandPolicy) error {
 	if len(cmds) == 0 {

From ae268a6334baee40320b74e736e40571daeb0adf Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Wed, 8 Oct 2025 09:24:06 +0300
Subject: [PATCH 19/62] added scaffolding for the req-resp manager

---
command_policy_manager.go | 45 +++++++++++++++++++++++++++++++++++++++ osscluster.go | 14 ++++++------ osscluster_router.go | 14 ++++++------ 3 files changed, 61 insertions(+), 12 deletions(-) create mode 100644 command_policy_manager.go diff --git a/command_policy_manager.go b/command_policy_manager.go new file mode 100644 index 0000000000..c456ce244a --- /dev/null +++ b/command_policy_manager.go @@ -0,0 +1,45 @@ +package redis + +import ( + "sync" + + "github.com/redis/go-redis/v9/internal/routing" +) + +var defaultPolicies = map[string]*routing.CommandPolicy{} + +type commandPolicyManager struct { + rwmutex *sync.RWMutex + clientPolicies map[string]*routing.CommandPolicy + overwrittenPolicies map[string]*routing.CommandPolicy +} + +func newCommandPolicyManager(overwrites interface{}) *commandPolicyManager { + return &commandPolicyManager{} +} + +func (cpm *commandPolicyManager) updateClientPolicies(policies interface{}) { + cpm.rwmutex.Lock() + defer cpm.rwmutex.Unlock() +} + +func (cpm *commandPolicyManager) getCmdPolicy(cmd Cmder) *routing.CommandPolicy { + cpm.rwmutex.RLock() + defer cpm.rwmutex.RUnlock() + + cmdName := cmd.Name() + + if policy, ok := cpm.overwrittenPolicies[cmdName]; ok { + return policy + } + + if policy, ok := cpm.clientPolicies[cmdName]; ok { + return policy + } + + if policy, ok := defaultPolicies[cmdName]; ok { + return policy + } + + return nil +} diff --git a/osscluster.go b/osscluster.go index 72fca6fdb3..3ec91f2ad2 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1014,10 +1014,11 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er // or more underlying connections. It's safe for concurrent use by // multiple goroutines. type ClusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder - cmdsInfoCache *cmdsInfoCache + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache + cmdPolicyManager *commandPolicyManager cmdable hooksMixin } @@ -1035,6 +1036,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdable = c.Process + c.cmdPolicyManager = newCommandPolicyManager(nil) c.initHooks(hooks{ dial: nil, process: c.process, @@ -1419,7 +1421,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - policy := c.getCommandPolicy(ctx, cmd) + policy := c.cmdPolicyManager.getCmdPolicy(cmd) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -1436,7 +1438,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - policy := c.getCommandPolicy(ctx, cmd) + policy := c.cmdPolicyManager.getCmdPolicy(cmd) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), diff --git a/osscluster_router.go b/osscluster_router.go index a4ddbbf4ab..4f4acd4e1e 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -21,7 +21,7 @@ type slotResult struct { // routeAndRun routes a command to the appropriate cluster nodes and executes it func (c 
*ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { - policy := c.getCommandPolicy(ctx, cmd) + policy := c.cmdPolicyManager.getCmdPolicy(cmd) switch { case policy != nil && policy.Request == routing.ReqAllNodes: @@ -38,11 +38,13 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste } // getCommandPolicy retrieves the routing policy for a command -func (c *ClusterClient) getCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { - if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.Tips != nil { - return cmdInfo.Tips - } - return nil +func (c *ClusterClient) getCommandPolicy(cmd Cmder) *routing.CommandPolicy { + + return c.cmdPolicyManager.getCmdPolicy(cmd) + // if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.Tips != nil { + // return cmdInfo.Tips + // } + // return nil } // executeDefault handles standard command routing based on keys From 39e578aae22fd15ebb5758ae08c15b8156093002 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Wed, 8 Oct 2025 10:59:27 +0300 Subject: [PATCH 20/62] added default policies for the search commands --- command_policy_manager.go | 113 ++++++++++++++++++++++++++++++++++++-- osscluster_router.go | 10 ---- 2 files changed, 109 insertions(+), 14 deletions(-) diff --git a/command_policy_manager.go b/command_policy_manager.go index c456ce244a..451b668f13 100644 --- a/command_policy_manager.go +++ b/command_policy_manager.go @@ -1,12 +1,114 @@ package redis import ( + "strings" "sync" "github.com/redis/go-redis/v9/internal/routing" ) -var defaultPolicies = map[string]*routing.CommandPolicy{} +var defaultPolicies = map[string]*routing.CommandPolicy{ + "ft.create": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.search": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.aggregate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.dictadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.dictdump": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.dictdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.suglen": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "ft.cursor": { + Request: routing.ReqSpecial, + Response: routing.RespDefaultKeyless, + }, + "ft.sugadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "ft.sugget": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "ft.sugdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "ft.spellcheck": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.explain": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.explaincli": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.aliasadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.aliasupdate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.aliasdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.info": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.tagvals": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.syndump": { + Request: 
routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.synupdate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.profile": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.alter": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.dropindex": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "ft.drop": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, +} type commandPolicyManager struct { rwmutex *sync.RWMutex @@ -15,10 +117,14 @@ type commandPolicyManager struct { } func newCommandPolicyManager(overwrites interface{}) *commandPolicyManager { - return &commandPolicyManager{} + // TODO: To be implemented in the next req-resp development stage + return &commandPolicyManager{ + rwmutex: &sync.RWMutex{}, + } } func (cpm *commandPolicyManager) updateClientPolicies(policies interface{}) { + // TODO: To be implemented in the next req-resp development stage cpm.rwmutex.Lock() defer cpm.rwmutex.Unlock() } @@ -27,8 +133,7 @@ func (cpm *commandPolicyManager) getCmdPolicy(cmd Cmder) *routing.CommandPolicy cpm.rwmutex.RLock() defer cpm.rwmutex.RUnlock() - cmdName := cmd.Name() - + cmdName := strings.ToLower(cmd.Name()) if policy, ok := cpm.overwrittenPolicies[cmdName]; ok { return policy } diff --git a/osscluster_router.go b/osscluster_router.go index 4f4acd4e1e..ceae4daa58 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -37,16 +37,6 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste } } -// getCommandPolicy retrieves the routing policy for a command -func (c *ClusterClient) getCommandPolicy(cmd Cmder) *routing.CommandPolicy { - - return c.cmdPolicyManager.getCmdPolicy(cmd) - // if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.Tips != nil { - // return cmdInfo.Tips - // } - // return nil -} - // executeDefault handles standard command routing based on keys func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, node *clusterNode) error { if c.hasKeys(cmd) { From c4dd3d584ed38b579a3991f4d38fa37a5c39d92a Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Wed, 8 Oct 2025 14:50:50 +0300 Subject: [PATCH 21/62] split command map into module->command --- command_policy_manager.go | 231 +++++++++++++++++++++----------------- 1 file changed, 125 insertions(+), 106 deletions(-) diff --git a/command_policy_manager.go b/command_policy_manager.go index 451b668f13..0893566ede 100644 --- a/command_policy_manager.go +++ b/command_policy_manager.go @@ -7,113 +7,120 @@ import ( "github.com/redis/go-redis/v9/internal/routing" ) -var defaultPolicies = map[string]*routing.CommandPolicy{ - "ft.create": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.search": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.aggregate": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.dictadd": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.dictdump": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.dictdel": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.suglen": { - Request: routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "ft.cursor": { - Request: routing.ReqSpecial, - Response: routing.RespDefaultKeyless, - }, - "ft.sugadd": { - Request: 
routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "ft.sugget": { - Request: routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "ft.sugdel": { - Request: routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "ft.spellcheck": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.explain": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.explaincli": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.aliasadd": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.aliasupdate": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.aliasdel": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.info": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.tagvals": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.syndump": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.synupdate": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.profile": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.alter": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.dropindex": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "ft.drop": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, +type ( + module = string + commandName = string +) + +var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ + "ft": { + "create": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "search": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aggregate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dictadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dictdump": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dictdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "suglen": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "cursor": { + Request: routing.ReqSpecial, + Response: routing.RespDefaultKeyless, + }, + "sugadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "sugget": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "sugdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "spellcheck": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "explain": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "explaincli": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aliasadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aliasupdate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aliasdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "info": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "tagvals": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "syndump": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "synupdate": 
{ + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "profile": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "alter": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dropindex": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "drop": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, }, } type commandPolicyManager struct { rwmutex *sync.RWMutex - clientPolicies map[string]*routing.CommandPolicy - overwrittenPolicies map[string]*routing.CommandPolicy + clientPolicies map[module]map[commandName]*routing.CommandPolicy + overwrittenPolicies map[module]map[commandName]*routing.CommandPolicy } func newCommandPolicyManager(overwrites interface{}) *commandPolicyManager { @@ -134,17 +141,29 @@ func (cpm *commandPolicyManager) getCmdPolicy(cmd Cmder) *routing.CommandPolicy defer cpm.rwmutex.RUnlock() cmdName := strings.ToLower(cmd.Name()) - if policy, ok := cpm.overwrittenPolicies[cmdName]; ok { + + module := "code" + command := cmdName + cmdParts := strings.Split(cmdName, ".") + if len(cmdParts) == 2 { + module = cmdParts[0] + command = cmdParts[1] + } + + if policy, ok := cpm.overwrittenPolicies[module][command]; ok { return policy } - if policy, ok := cpm.clientPolicies[cmdName]; ok { + if policy, ok := cpm.clientPolicies[module][command]; ok { return policy } - if policy, ok := defaultPolicies[cmdName]; ok { + if policy, ok := defaultPolicies[module][command]; ok { return policy } - return nil + return &routing.CommandPolicy{ + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + } } From ded8eb225a62b5ef88f6bacdd9bb0f2a0df42262 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Thu, 9 Oct 2025 11:55:40 +0300 Subject: [PATCH 22/62] cleanup, added logic to refresh the cache --- command.go | 16 +++- command_policy_manager.go | 169 -------------------------------------- osscluster.go | 16 ++-- osscluster_router.go | 10 ++- osscluster_test.go | 2 + 5 files changed, 33 insertions(+), 180 deletions(-) delete mode 100644 command_policy_manager.go diff --git a/command.go b/command.go index 2ce3a11328..f74a3315c1 100644 --- a/command.go +++ b/command.go @@ -4409,8 +4409,9 @@ func (cmd *CommandsInfoCmd) Clone() Cmder { type cmdsInfoCache struct { fn func(ctx context.Context) (map[string]*CommandInfo, error) - once internal.Once - cmds map[string]*CommandInfo + once internal.Once + refreshLock sync.Mutex + cmds map[string]*CommandInfo } func newCmdsInfoCache(fn func(ctx context.Context) (map[string]*CommandInfo, error)) *cmdsInfoCache { @@ -4420,6 +4421,9 @@ func newCmdsInfoCache(fn func(ctx context.Context) (map[string]*CommandInfo, err } func (c *cmdsInfoCache) Get(ctx context.Context) (map[string]*CommandInfo, error) { + c.refreshLock.Lock() + defer c.refreshLock.Unlock() + err := c.once.Do(func() error { cmds, err := c.fn(ctx) if err != nil { @@ -4439,6 +4443,14 @@ func (c *cmdsInfoCache) Get(ctx context.Context) (map[string]*CommandInfo, error return c.cmds, err } +// TODO: Call it on client reconnect +func (c *cmdsInfoCache) Refresh() { + c.refreshLock.Lock() + defer c.refreshLock.Unlock() + + c.once = internal.Once{} +} + // ------------------------------------------------------------------------------ const requestPolicy = "request_policy" const responsePolicy = "response_policy" diff --git a/command_policy_manager.go b/command_policy_manager.go deleted file mode 100644 index 0893566ede..0000000000 --- 
a/command_policy_manager.go +++ /dev/null @@ -1,169 +0,0 @@ -package redis - -import ( - "strings" - "sync" - - "github.com/redis/go-redis/v9/internal/routing" -) - -type ( - module = string - commandName = string -) - -var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ - "ft": { - "create": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "search": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "aggregate": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "dictadd": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "dictdump": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "dictdel": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "suglen": { - Request: routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "cursor": { - Request: routing.ReqSpecial, - Response: routing.RespDefaultKeyless, - }, - "sugadd": { - Request: routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "sugget": { - Request: routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "sugdel": { - Request: routing.ReqDefault, - Response: routing.RespDefaultHashSlot, - }, - "spellcheck": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "explain": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "explaincli": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "aliasadd": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "aliasupdate": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "aliasdel": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "info": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "tagvals": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "syndump": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "synupdate": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "profile": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "alter": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "dropindex": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - "drop": { - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - }, - }, -} - -type commandPolicyManager struct { - rwmutex *sync.RWMutex - clientPolicies map[module]map[commandName]*routing.CommandPolicy - overwrittenPolicies map[module]map[commandName]*routing.CommandPolicy -} - -func newCommandPolicyManager(overwrites interface{}) *commandPolicyManager { - // TODO: To be implemented in the next req-resp development stage - return &commandPolicyManager{ - rwmutex: &sync.RWMutex{}, - } -} - -func (cpm *commandPolicyManager) updateClientPolicies(policies interface{}) { - // TODO: To be implemented in the next req-resp development stage - cpm.rwmutex.Lock() - defer cpm.rwmutex.Unlock() -} - -func (cpm *commandPolicyManager) getCmdPolicy(cmd Cmder) *routing.CommandPolicy { - cpm.rwmutex.RLock() - defer cpm.rwmutex.RUnlock() - - cmdName := strings.ToLower(cmd.Name()) - - module := "code" - command := cmdName - cmdParts := strings.Split(cmdName, ".") - if len(cmdParts) == 2 { - module = cmdParts[0] - command = cmdParts[1] - 
} - - if policy, ok := cpm.overwrittenPolicies[module][command]; ok { - return policy - } - - if policy, ok := cpm.clientPolicies[module][command]; ok { - return policy - } - - if policy, ok := defaultPolicies[module][command]; ok { - return policy - } - - return &routing.CommandPolicy{ - Request: routing.ReqDefault, - Response: routing.RespDefaultKeyless, - } -} diff --git a/osscluster.go b/osscluster.go index 3ec91f2ad2..75c1400a07 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1014,11 +1014,10 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er // or more underlying connections. It's safe for concurrent use by // multiple goroutines. type ClusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder - cmdsInfoCache *cmdsInfoCache - cmdPolicyManager *commandPolicyManager + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache cmdable hooksMixin } @@ -1034,9 +1033,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { } c.state = newClusterStateHolder(c.loadState) + // TODO: execute on handshake, should be called again on reconnect c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdable = c.Process - c.cmdPolicyManager = newCommandPolicyManager(nil) c.initHooks(hooks{ dial: nil, process: c.process, @@ -1421,7 +1420,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - policy := c.cmdPolicyManager.getCmdPolicy(cmd) + policy := c.getCommandPolicy(ctx, cmd) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -1438,7 +1437,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - policy := c.cmdPolicyManager.getCmdPolicy(cmd) + policy := c.getCommandPolicy(ctx, cmd) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -2077,6 +2076,7 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, return nil, firstErr } +// cmdInfo will fetch and cache the command policies after the first execution func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { // Use a separate context that won't be canceled to ensure command info lookup // doesn't fail due to original context cancellation diff --git a/osscluster_router.go b/osscluster_router.go index ceae4daa58..a4ddbbf4ab 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -21,7 +21,7 @@ type slotResult struct { // routeAndRun routes a command to the appropriate cluster nodes and executes it func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { - policy := c.cmdPolicyManager.getCmdPolicy(cmd) + policy := c.getCommandPolicy(ctx, cmd) switch { case policy != nil && policy.Request == routing.ReqAllNodes: @@ -37,6 +37,14 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste } } +// getCommandPolicy retrieves the routing policy for a command +func (c *ClusterClient) getCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + if cmdInfo := c.cmdInfo(ctx, 
cmd.Name()); cmdInfo != nil && cmdInfo.Tips != nil {
+		return cmdInfo.Tips
+	}
+	return nil
+}
+
 // executeDefault handles standard command routing based on keys
 func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, node *clusterNode) error {
 	if c.hasKeys(cmd) {
diff --git a/osscluster_test.go b/osscluster_test.go
index 16ab536bc4..b089cb73d0 100644
--- a/osscluster_test.go
+++ b/osscluster_test.go
@@ -6,6 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"net"
+	"reflect"
 	"slices"
 	"strconv"
 	"strings"
@@ -1602,6 +1603,7 @@ var _ = Describe("ClusterClient timeout", func() {
 				return nil
 			})
 			Expect(err).To(HaveOccurred())
+			fmt.Println("what kind of errors, man:", reflect.TypeOf(err).String(), reflect.TypeOf(err).Kind().String())
 			Expect(err.(net.Error).Timeout()).To(BeTrue())
 		})

From b2a0f1b9a0d0da7e20559140e884cf02af9a0ce5 Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Thu, 9 Oct 2025 12:23:51 +0300
Subject: [PATCH 23/62] added reactive cache refresh

---
 command.go    |  1 -
 osscluster.go | 21 +++++++++++++--------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/command.go b/command.go
index f74a3315c1..c8840bd328 100644
--- a/command.go
+++ b/command.go
@@ -4443,7 +4443,6 @@ func (c *cmdsInfoCache) Get(ctx context.Context) (map[string]*CommandInfo, error
 	return c.cmds, err
 }

-// TODO: Call it on client reconnect
 func (c *cmdsInfoCache) Refresh() {
 	c.refreshLock.Lock()
 	defer c.refreshLock.Unlock()
diff --git a/osscluster.go b/osscluster.go
index 75c1400a07..5e0850a11d 100644
--- a/osscluster.go
+++ b/osscluster.go
@@ -951,15 +951,19 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {
 //------------------------------------------------------------------------------

 type clusterStateHolder struct {
-	load func(ctx context.Context) (*clusterState, error)
-
-	state     atomic.Value
-	reloading uint32 // atomic
+	load                 func(ctx context.Context) (*clusterState, error)
+	commandsCacheRefresh func()
+	state                atomic.Value
+	reloading            uint32 // atomic
 }

-func newClusterStateHolder(fn func(ctx context.Context) (*clusterState, error)) *clusterStateHolder {
+func newClusterStateHolder(
+	load func(ctx context.Context) (*clusterState, error),
+	commandsCacheRefresh func(),
+) *clusterStateHolder {
 	return &clusterStateHolder{
-		load: fn,
+		load:                 load,
+		commandsCacheRefresh: commandsCacheRefresh,
 	}
 }

@@ -968,6 +972,7 @@ func (c *clusterStateHolder) Reload(ctx context.Context) (*clusterState, error)
 	if err != nil {
 		return nil, err
 	}
+	c.commandsCacheRefresh()
 	c.state.Store(state)
 	return state, nil
 }
@@ -1032,9 +1037,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
 	c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)

-	c.state = newClusterStateHolder(c.loadState)
+	c.state = newClusterStateHolder(c.loadState, c.cmdsInfoCache.Refresh)
 	c.cmdable = c.Process
 	c.initHooks(hooks{
 		dial:    nil,
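An aside on the mechanism wired up in patches 22 and 23: cmdsInfoCache.Refresh simply re-arms the internal.Once guarding Get, so the next call re-fetches COMMAND output; Reload then triggers that refresh whenever cluster state is reloaded. A minimal standalone sketch of the same re-armable cache idea follows; it uses a plain sync.Mutex rather than go-redis's error-aware internal.Once, and refreshableCache with its methods are hypothetical names for illustration, not library API.

package main

import (
	"fmt"
	"sync"
)

// refreshableCache caches the result of an expensive load until Refresh is called.
type refreshableCache[T any] struct {
	mu     sync.Mutex
	loaded bool
	val    T
}

// Get returns the cached value, loading it on first use (or after Refresh).
func (c *refreshableCache[T]) Get(load func() (T, error)) (T, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.loaded {
		v, err := load()
		if err != nil {
			// A failed load leaves the cache unarmed so a later Get can retry
			// (an assumption about internal.Once's semantics).
			return v, err
		}
		c.val, c.loaded = v, true
	}
	return c.val, nil
}

// Refresh re-arms the cache so the next Get reloads, mirroring
// cmdsInfoCache.Refresh resetting its internal.Once.
func (c *refreshableCache[T]) Refresh() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.loaded = false
}

func main() {
	cache := &refreshableCache[int]{}
	v, _ := cache.Get(func() (int, error) { return 42, nil })
	fmt.Println(v) // 42, loaded once
	cache.Refresh()
	v, _ = cache.Get(func() (int, error) { return 43, nil })
	fmt.Println(v) // 43, reloaded after Refresh
}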
From b7e0bf646a07ca41a1a01ebf56e2250295ac5d41 Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Thu, 9 Oct 2025 12:42:58 +0300
Subject: [PATCH 24/62] revert cluster refresh

---
 osscluster.go | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/osscluster.go b/osscluster.go
index 5e0850a11d..268cd7b8ae 100644
--- a/osscluster.go
+++ b/osscluster.go
@@ -957,13 +957,9 @@ type clusterStateHolder struct {
 	reloading uint32 // atomic
 }

-func newClusterStateHolder(
-	load func(ctx context.Context) (*clusterState, error),
-	commandsCacheRefresh func(),
-) *clusterStateHolder {
+func newClusterStateHolder(load func(ctx context.Context) (*clusterState, error)) *clusterStateHolder {
 	return &clusterStateHolder{
-		load:                 load,
-		commandsCacheRefresh: commandsCacheRefresh,
+		load: load,
 	}
 }

@@ -972,7 +968,6 @@ func (c *clusterStateHolder) Reload(ctx context.Context) (*clusterState, error)
 	if err != nil {
 		return nil, err
 	}
-	c.commandsCacheRefresh()
 	c.state.Store(state)
 	return state, nil
 }
@@ -1039,7 +1034,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
 	c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)

-	c.state = newClusterStateHolder(c.loadState, c.cmdsInfoCache.Refresh)
+	c.state = newClusterStateHolder(c.loadState)
 	c.cmdable = c.Process
 	c.initHooks(hooks{
 		dial:    nil,

From 2f08a82259195c21f640efc9e617c928a195ccc0 Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Thu, 9 Oct 2025 12:59:19 +0300
Subject: [PATCH 25/62] fixed lint

---
 osscluster.go | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/osscluster.go b/osscluster.go
index 268cd7b8ae..d1fd059ea7 100644
--- a/osscluster.go
+++ b/osscluster.go
@@ -951,10 +951,9 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {
 //------------------------------------------------------------------------------

 type clusterStateHolder struct {
-	load                 func(ctx context.Context) (*clusterState, error)
-	commandsCacheRefresh func()
-	state                atomic.Value
-	reloading            uint32 // atomic
+	load      func(ctx context.Context) (*clusterState, error)
+	state     atomic.Value
+	reloading uint32 // atomic
 }

 func newClusterStateHolder(load func(ctx context.Context) (*clusterState, error)) *clusterStateHolder {

From 5b0e0b18179ba7c704cd6aa04b63b20f84b24221 Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Fri, 10 Oct 2025 14:14:13 +0300
Subject: [PATCH 26/62] addressed first batch of comments

---
 command.go                     | 244 ++++-----------------------------
 commands_test.go               |   8 +-
 internal/routing/aggregator.go |  24 ++--
 osscluster_router.go           |  22 +--
 osscluster_test.go             |  64 ++++-----
 5 files changed, 89 insertions(+), 273 deletions(-)

diff --git a/command.go b/command.go
index c8840bd328..7e2375a127 100644
--- a/command.go
+++ b/command.go
@@ -4193,15 +4193,15 @@ func (cmd *GeoPosCmd) Clone() Cmder {
 //------------------------------------------------------------------------------

 type CommandInfo struct {
-	Name        string
-	Arity       int8
-	Flags       []string
-	ACLFlags    []string
-	FirstKeyPos int8
-	LastKeyPos  int8
-	StepCount   int8
-	ReadOnly    bool
-	Tips        *routing.CommandPolicy
+	Name          string
+	Arity         int8
+	Flags         []string
+	ACLFlags      []string
+	FirstKeyPos   int8
+	LastKeyPos    int8
+	StepCount     int8
+	ReadOnly      bool
+	CommandPolicy *routing.CommandPolicy
 }

 type CommandsInfoCmd struct {
@@ -4355,7 +4355,7 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error {
 			}
 			rawTips[k] = v
 		}
-		cmdInfo.Tips = parseCommandPolicies(rawTips)
+		cmdInfo.CommandPolicy = parseCommandPolicies(rawTips)

 		if err := rd.DiscardNext(); err != nil {
 			return err
@@ -4378,13 +4378,13 @@ func (cmd *CommandsInfoCmd) Clone() Cmder {
 	for k, v := range cmd.val {
 		if v != nil {
 			newInfo := &CommandInfo{
-				Name:        v.Name,
-				Arity:       v.Arity,
-				FirstKeyPos: v.FirstKeyPos,
-				LastKeyPos:  v.LastKeyPos,
-				StepCount:   v.StepCount,
-				ReadOnly:    v.ReadOnly,
-				Tips:        v.Tips, // CommandPolicy can be shared as it's immutable
+				Name:          v.Name,
+				Arity:         v.Arity,
+				FirstKeyPos:   v.FirstKeyPos,
+				LastKeyPos:    v.LastKeyPos,
+				StepCount:     v.StepCount,
+				ReadOnly:      v.ReadOnly,
+				CommandPolicy: v.CommandPolicy, // CommandPolicy can be shared as it's immutable
 			}
 			if v.Flags != nil {
 				newInfo.Flags =
make([]string, len(v.Flags)) @@ -6995,14 +6995,21 @@ func ExtractCommandValue(cmd interface{}) interface{} { if statusCmd, ok := cmd.(interface{ Val() string }); ok { return statusCmd.Val() } - case CmdTypeDuration: + case CmdTypeDuration, CmdTypeTime, CmdTypeStringStructMap, CmdTypeXMessageSlice, + CmdTypeXStreamSlice, CmdTypeXPending, CmdTypeXPendingExt, CmdTypeXAutoClaim, + CmdTypeXAutoClaimJustID, CmdTypeXInfoConsumers, CmdTypeXInfoGroups, CmdTypeXInfoStream, + CmdTypeXInfoStreamFull, CmdTypeZSlice, CmdTypeZWithKey, CmdTypeScan, CmdTypeClusterSlots, + CmdTypeGeoSearchLocation, CmdTypeGeoPos, CmdTypeCommandsInfo, CmdTypeSlowLog, + CmdTypeKeyValues, CmdTypeZSliceWithKey, CmdTypeFunctionList, CmdTypeFunctionStats, + CmdTypeLCS, CmdTypeKeyFlags, CmdTypeClusterLinks, CmdTypeClusterShards, + CmdTypeRankWithScore, CmdTypeClientInfo, CmdTypeACLLog, CmdTypeInfo, CmdTypeMonitor, + CmdTypeJSON, CmdTypeJSONSlice, CmdTypeIntPointerSlice, CmdTypeScanDump, CmdTypeBFInfo, + CmdTypeCFInfo, CmdTypeCMSInfo, CmdTypeTopKInfo, CmdTypeTDigestInfo, CmdTypeFTSearch, + CmdTypeFTInfo, CmdTypeFTSpellCheck, CmdTypeFTSynDump, CmdTypeAggregate, + CmdTypeTSTimestampValue, CmdTypeTSTimestampValueSlice: if durationCmd, ok := cmd.(interface{ Val() interface{} }); ok { return durationCmd.Val() } - case CmdTypeTime: - if timeCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return timeCmd.Val() - } case CmdTypeStringSlice: if stringSliceCmd, ok := cmd.(interface{ Val() []string }); ok { return stringSliceCmd.Val() @@ -7047,199 +7054,6 @@ func ExtractCommandValue(cmd interface{}) interface{} { }); ok { return mapCmd.Val() } - case CmdTypeStringStructMap: - if mapCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return mapCmd.Val() - } - case CmdTypeXMessageSlice: - if xMsgCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xMsgCmd.Val() - } - case CmdTypeXStreamSlice: - if xStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xStreamCmd.Val() - } - case CmdTypeXPending: - if xPendingCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xPendingCmd.Val() - } - case CmdTypeXPendingExt: - if xPendingExtCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xPendingExtCmd.Val() - } - case CmdTypeXAutoClaim: - if xAutoClaimCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xAutoClaimCmd.Val() - } - case CmdTypeXAutoClaimJustID: - if xAutoClaimJustIDCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xAutoClaimJustIDCmd.Val() - } - case CmdTypeXInfoConsumers: - if xInfoConsumersCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoConsumersCmd.Val() - } - case CmdTypeXInfoGroups: - if xInfoGroupsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoGroupsCmd.Val() - } - case CmdTypeXInfoStream: - if xInfoStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoStreamCmd.Val() - } - case CmdTypeXInfoStreamFull: - if xInfoStreamFullCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return xInfoStreamFullCmd.Val() - } - case CmdTypeZSlice: - if zSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return zSliceCmd.Val() - } - case CmdTypeZWithKey: - if zWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return zWithKeyCmd.Val() - } - case CmdTypeScan: - if scanCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return scanCmd.Val() - } - case CmdTypeClusterSlots: - if clusterSlotsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clusterSlotsCmd.Val() - } - case 
CmdTypeGeoSearchLocation: - if geoSearchLocationCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return geoSearchLocationCmd.Val() - } - case CmdTypeGeoPos: - if geoPosCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return geoPosCmd.Val() - } - case CmdTypeCommandsInfo: - if commandsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return commandsInfoCmd.Val() - } - case CmdTypeSlowLog: - if slowLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return slowLogCmd.Val() - } - - case CmdTypeKeyValues: - if keyValuesCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return keyValuesCmd.Val() - } - case CmdTypeZSliceWithKey: - if zSliceWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return zSliceWithKeyCmd.Val() - } - case CmdTypeFunctionList: - if functionListCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return functionListCmd.Val() - } - case CmdTypeFunctionStats: - if functionStatsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return functionStatsCmd.Val() - } - case CmdTypeLCS: - if lcsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return lcsCmd.Val() - } - case CmdTypeKeyFlags: - if keyFlagsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return keyFlagsCmd.Val() - } - case CmdTypeClusterLinks: - if clusterLinksCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clusterLinksCmd.Val() - } - case CmdTypeClusterShards: - if clusterShardsCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clusterShardsCmd.Val() - } - case CmdTypeRankWithScore: - if rankWithScoreCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return rankWithScoreCmd.Val() - } - case CmdTypeClientInfo: - if clientInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return clientInfoCmd.Val() - } - case CmdTypeACLLog: - if aclLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return aclLogCmd.Val() - } - case CmdTypeInfo: - if infoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return infoCmd.Val() - } - case CmdTypeMonitor: - if monitorCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return monitorCmd.Val() - } - case CmdTypeJSON: - if jsonCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return jsonCmd.Val() - } - case CmdTypeJSONSlice: - if jsonSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return jsonSliceCmd.Val() - } - case CmdTypeIntPointerSlice: - if intPointerSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return intPointerSliceCmd.Val() - } - case CmdTypeScanDump: - if scanDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return scanDumpCmd.Val() - } - case CmdTypeBFInfo: - if bfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return bfInfoCmd.Val() - } - case CmdTypeCFInfo: - if cfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return cfInfoCmd.Val() - } - case CmdTypeCMSInfo: - if cmsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return cmsInfoCmd.Val() - } - case CmdTypeTopKInfo: - if topKInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return topKInfoCmd.Val() - } - case CmdTypeTDigestInfo: - if tDigestInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return tDigestInfoCmd.Val() - } - case CmdTypeFTSearch: - if ftSearchCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return ftSearchCmd.Val() - } - case CmdTypeFTInfo: - if ftInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return ftInfoCmd.Val() - } - case CmdTypeFTSpellCheck: - if ftSpellCheckCmd, ok := 
cmd.(interface{ Val() interface{} }); ok { - return ftSpellCheckCmd.Val() - } - case CmdTypeFTSynDump: - if ftSynDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return ftSynDumpCmd.Val() - } - case CmdTypeAggregate: - if aggregateCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return aggregateCmd.Val() - } - case CmdTypeTSTimestampValue: - if tsTimestampValueCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return tsTimestampValueCmd.Val() - } - case CmdTypeTSTimestampValueSlice: - if tsTimestampValueSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return tsTimestampValueSliceCmd.Val() - } default: // For unknown command types, return nil return nil diff --git a/commands_test.go b/commands_test.go index 476a08a302..66bade4ead 100644 --- a/commands_test.go +++ b/commands_test.go @@ -665,13 +665,13 @@ var _ = Describe("Commands", func() { cmd := cmds["touch"] Expect(cmd.Name).To(Equal("touch")) - Expect(cmd.Tips.Request).To(Equal(routing.ReqMultiShard)) - Expect(cmd.Tips.Response).To(Equal(routing.RespAggSum)) + Expect(cmd.CommandPolicy.Request).To(Equal(routing.ReqMultiShard)) + Expect(cmd.CommandPolicy.Response).To(Equal(routing.RespAggSum)) cmd = cmds["flushall"] Expect(cmd.Name).To(Equal("flushall")) - Expect(cmd.Tips.Request).To(Equal(routing.ReqAllShards)) - Expect(cmd.Tips.Response).To(Equal(routing.RespAllSucceeded)) + Expect(cmd.CommandPolicy.Request).To(Equal(routing.ReqAllShards)) + Expect(cmd.CommandPolicy.Response).To(Equal(routing.RespAllSucceeded)) }) It("should return all command names", func() { diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 962e592647..fd69153799 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -14,8 +14,8 @@ type ResponseAggregator interface { // AddWithKey processes a single shard response for a specific key (used by keyed aggregators). AddWithKey(key string, result interface{}, err error) error - // Finish returns the final aggregated result and any error. - Finish() (interface{}, error) + // Result returns the final aggregated result and any error. + Result() (interface{}, error) } // NewResponseAggregator creates an aggregator based on the response policy. 
@@ -83,7 +83,7 @@ func (a *AllSucceededAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AllSucceededAggregator) Finish() (interface{}, error) { +func (a *AllSucceededAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -121,7 +121,7 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *OneSucceededAggregator) Finish() (interface{}, error) { +func (a *OneSucceededAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -165,7 +165,7 @@ func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggSumAggregator) Finish() (interface{}, error) { +func (a *AggSumAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -211,7 +211,7 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggMinAggregator) Finish() (interface{}, error) { +func (a *AggMinAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -260,7 +260,7 @@ func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggMaxAggregator) Finish() (interface{}, error) { +func (a *AggMaxAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -311,7 +311,7 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalAndAggregator) Finish() (interface{}, error) { +func (a *AggLogicalAndAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -362,7 +362,7 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalOrAggregator) Finish() (interface{}, error) { +func (a *AggLogicalOrAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -437,7 +437,7 @@ func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, er return a.Add(result, err) } -func (a *DefaultKeylessAggregator) Finish() (interface{}, error) { +func (a *DefaultKeylessAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -497,7 +497,7 @@ func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { a.keyOrder = keyOrder } -func (a *DefaultKeyedAggregator) Finish() (interface{}, error) { +func (a *DefaultKeyedAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -545,7 +545,7 @@ func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error return a.Add(result, err) } -func (a *SpecialAggregator) Finish() (interface{}, error) { +func (a *SpecialAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() diff --git a/osscluster_router.go b/osscluster_router.go index a4ddbbf4ab..669bfea66f 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -22,15 +22,17 @@ type slotResult struct { // routeAndRun routes a command to the appropriate cluster nodes and executes it func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { policy := c.getCommandPolicy(ctx, cmd) - - switch { - case policy != nil && policy.Request == routing.ReqAllNodes: + if policy == nil { + return c.executeDefault(ctx, cmd, node) + } + switch policy.Request { + case routing.ReqAllNodes: return c.executeOnAllNodes(ctx, cmd, policy) - case policy != nil && policy.Request 
== routing.ReqAllShards: + case routing.ReqAllShards: return c.executeOnAllShards(ctx, cmd, policy) - case policy != nil && policy.Request == routing.ReqMultiShard: + case routing.ReqMultiShard: return c.executeMultiShard(ctx, cmd, policy) - case policy != nil && policy.Request == routing.ReqSpecial: + case routing.ReqSpecial: return c.executeSpecialCommand(ctx, cmd, policy, node) default: return c.executeDefault(ctx, cmd, node) @@ -39,8 +41,8 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste // getCommandPolicy retrieves the routing policy for a command func (c *ClusterClient) getCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { - if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.Tips != nil { - return cmdInfo.Tips + if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.CommandPolicy != nil { + return cmdInfo.CommandPolicy } return nil } @@ -243,7 +245,7 @@ func createCommandByType(ctx context.Context, cmdType CmdType, args ...interface case CmdTypeKeyFlags: return NewKeyFlagsCmd(ctx, args...) case CmdTypeDuration: - return NewDurationCmd(ctx, time.Second, args...) + return NewDurationCmd(ctx, time.Millisecond, args...) } return NewCmd(ctx, args...) } @@ -462,7 +464,7 @@ func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmde // finishAggregation completes the aggregation process and sets the result func (c *ClusterClient) finishAggregation(cmd Cmder, aggregator routing.ResponseAggregator) error { - finalValue, finalErr := aggregator.Finish() + finalValue, finalErr := aggregator.Result() if finalErr != nil { cmd.SetErr(finalErr) return finalErr diff --git a/osscluster_test.go b/osscluster_test.go index b089cb73d0..fc2a3be429 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -1708,14 +1708,14 @@ var _ = Describe("Command Tips tests", func() { for cmdName, expected := range expectedPolicies { actualCmd := cmds[cmdName] - Expect(actualCmd.Tips).NotTo(BeNil()) + Expect(actualCmd.CommandPolicy).NotTo(BeNil()) // Verify request_policy from COMMAND matches router policy - actualRequestPolicy := actualCmd.Tips.Request.String() + actualRequestPolicy := actualCmd.CommandPolicy.Request.String() Expect(actualRequestPolicy).To(Equal(expected.RequestPolicy)) // Verify response_policy from COMMAND matches router policy - actualResponsePolicy := actualCmd.Tips.Response.String() + actualResponsePolicy := actualCmd.CommandPolicy.Response.String() Expect(actualResponsePolicy).To(Equal(expected.ResponsePolicy)) } }) @@ -1730,9 +1730,9 @@ var _ = Describe("Command Tips tests", func() { touchCmd := cmds["touch"] - Expect(touchCmd.Tips).NotTo(BeNil()) - Expect(touchCmd.Tips.Request.String()).To(Equal("multi_shard")) - Expect(touchCmd.Tips.Response.String()).To(Equal("agg_sum")) + Expect(touchCmd.CommandPolicy).NotTo(BeNil()) + Expect(touchCmd.CommandPolicy.Request.String()).To(Equal("multi_shard")) + Expect(touchCmd.CommandPolicy.Response.String()).To(Equal("agg_sum")) keys := []string{"key1", "key2", "key3", "key4", "key5"} for _, key := range keys { @@ -1754,9 +1754,9 @@ var _ = Describe("Command Tips tests", func() { flushallCmd := cmds["flushall"] - Expect(flushallCmd.Tips).NotTo(BeNil()) - Expect(flushallCmd.Tips.Request.String()).To(Equal("all_shards")) - Expect(flushallCmd.Tips.Response.String()).To(Equal("all_succeeded")) + Expect(flushallCmd.CommandPolicy).NotTo(BeNil()) + Expect(flushallCmd.CommandPolicy.Request.String()).To(Equal("all_shards")) + 
Expect(flushallCmd.CommandPolicy.Response.String()).To(Equal("all_succeeded")) testKeys := []string{"test1", "test2", "test3"} for _, key := range testKeys { @@ -1781,9 +1781,9 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) pingCmd := cmds["ping"] - Expect(pingCmd.Tips).NotTo(BeNil()) - Expect(pingCmd.Tips.Request.String()).To(Equal("all_shards")) - Expect(pingCmd.Tips.Response.String()).To(Equal("all_succeeded")) + Expect(pingCmd.CommandPolicy).NotTo(BeNil()) + Expect(pingCmd.CommandPolicy.Request.String()).To(Equal("all_shards")) + Expect(pingCmd.CommandPolicy.Response.String()).To(Equal("all_succeeded")) result := client.Ping(ctx) Expect(result.Err()).NotTo(HaveOccurred()) @@ -1798,9 +1798,9 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) dbsizeCmd := cmds["dbsize"] - Expect(dbsizeCmd.Tips).NotTo(BeNil()) - Expect(dbsizeCmd.Tips.Request.String()).To(Equal("all_shards")) - Expect(dbsizeCmd.Tips.Response.String()).To(Equal("agg_sum")) + Expect(dbsizeCmd.CommandPolicy).NotTo(BeNil()) + Expect(dbsizeCmd.CommandPolicy.Request.String()).To(Equal("all_shards")) + Expect(dbsizeCmd.CommandPolicy.Response.String()).To(Equal("agg_sum")) testKeys := []string{"dbsize_test1", "dbsize_test2", "dbsize_test3"} for _, key := range testKeys { @@ -1830,13 +1830,13 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) ftCreateCmd, exists := cmds["ft.create"] - if !exists || ftCreateCmd.Tips == nil { + if !exists || ftCreateCmd.CommandPolicy == nil { Skip("FT.CREATE command or tips not available") } // DDL commands should NOT be broadcasted - they should go to coordinator only - Expect(ftCreateCmd.Tips).NotTo(BeNil()) - requestPolicy := ftCreateCmd.Tips.Request.String() + Expect(ftCreateCmd.CommandPolicy).NotTo(BeNil()) + requestPolicy := ftCreateCmd.CommandPolicy.Request.String() Expect(requestPolicy).NotTo(Equal("all_shards")) Expect(requestPolicy).NotTo(Equal("all_nodes")) @@ -1869,12 +1869,12 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) ftAlterCmd, exists := cmds["ft.alter"] - if !exists || ftAlterCmd.Tips == nil { + if !exists || ftAlterCmd.CommandPolicy == nil { Skip("FT.ALTER command or tips not available") } - Expect(ftAlterCmd.Tips).NotTo(BeNil()) - requestPolicy := ftAlterCmd.Tips.Request.String() + Expect(ftAlterCmd.CommandPolicy).NotTo(BeNil()) + requestPolicy := ftAlterCmd.CommandPolicy.Request.String() Expect(requestPolicy).NotTo(Equal("all_shards")) Expect(requestPolicy).NotTo(Equal("all_nodes")) @@ -2003,11 +2003,11 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) touchCmd, exists := cmds["touch"] - if !exists || touchCmd.Tips == nil { + if !exists || touchCmd.CommandPolicy == nil { Skip("TOUCH command or tips not available") } - Expect(touchCmd.Tips.Response.String()).To(Equal("agg_sum")) + Expect(touchCmd.CommandPolicy.Response.String()).To(Equal("agg_sum")) testKeys := []string{ "touch_test_key_1111", // These keys should map to different hash slots @@ -2053,11 +2053,11 @@ var _ = Describe("Command Tips tests", func() { // FLUSHALL command with all_succeeded aggregation policy flushallCmd, exists := cmds["flushall"] - if !exists || flushallCmd.Tips == nil { + if !exists || flushallCmd.CommandPolicy == nil { Skip("FLUSHALL command or tips not available") } - Expect(flushallCmd.Tips.Response.String()).To(Equal("all_succeeded")) + Expect(flushallCmd.CommandPolicy.Response.String()).To(Equal("all_succeeded")) for i 
:= 0; i < len(masterNodes); i++ {
 				testKey := fmt.Sprintf("flush_test_key_%d_%d", i, time.Now().UnixNano())
@@ -2077,11 +2077,11 @@
 			// WAIT command aggregation policy - verify agg_min policy
 			waitCmd, exists := cmds["wait"]
-			if !exists || waitCmd.Tips == nil {
+			if !exists || waitCmd.CommandPolicy == nil {
 				Skip("WAIT command or tips not available")
 			}

-			Expect(waitCmd.Tips.Response.String()).To(Equal("agg_min"))
+			Expect(waitCmd.CommandPolicy.Response.String()).To(Equal("agg_min"))

 			// Set up some data to replicate
 			testKey := "wait_test_key_1111"
@@ -2101,11 +2101,11 @@
 			// SCRIPT EXISTS command aggregation policy - verify agg_logical_and policy
 			scriptExistsCmd, exists := cmds["script exists"]
-			if !exists || scriptExistsCmd.Tips == nil {
+			if !exists || scriptExistsCmd.CommandPolicy == nil {
 				Skip("SCRIPT EXISTS command or tips not available")
 			}

-			Expect(scriptExistsCmd.Tips.Response.String()).To(Equal("agg_logical_and"))
+			Expect(scriptExistsCmd.CommandPolicy.Response.String()).To(Equal("agg_logical_and"))

 			// Load a script on all shards
 			testScript := "return 'hello'"
@@ -2157,11 +2157,11 @@
 					continue
 				}

-				if cmd.Tips == nil {
+				if cmd.CommandPolicy == nil {
 					continue
 				}

-				actualPolicy := cmd.Tips.Response.String()
+				actualPolicy := cmd.CommandPolicy.Response.String()
 				Expect(actualPolicy).To(Equal(expectedPolicy))
 			}
 		})
@@ -2194,7 +2194,7 @@
 			Expect(err).NotTo(HaveOccurred())

 			pingCmd, exists := cmds["ping"]
-			if exists && pingCmd.Tips != nil {
+			if exists && pingCmd.CommandPolicy != nil {
 			}

 			pingResult := client.Ping(ctx)

From 4e25c3e5915a9607d71a0e338bfcbb2f8beb2b3f Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Sun, 12 Oct 2025 23:36:22 +0300
Subject: [PATCH 27/62] rewrote aggregator implementations with atomics for native or near-native primitives

---
 internal/routing/aggregator.go | 323 +++++++++++++++------------------
 internal/util/atomic_max.go    |  96 ++++++++++
 internal/util/atomic_min.go    |  95 ++++++++++
 3 files changed, 337 insertions(+), 177 deletions(-)
 create mode 100644 internal/util/atomic_max.go
 create mode 100644 internal/util/atomic_min.go

diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go
index fd69153799..b7d8260f92 100644
--- a/internal/routing/aggregator.go
+++ b/internal/routing/aggregator.go
@@ -4,6 +4,9 @@ import (
 	"fmt"
 	"math"
 	"sync"
+	"sync/atomic"
+
+	"github.com/redis/go-redis/v9/internal/util"
 )

 // ResponseAggregator defines the interface for aggregating responses from multiple shards.
@@ -32,11 +35,18 @@ func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggreg
 	case RespAggSum:
 		return &AggSumAggregator{}
 	case RespAggMin:
-		return &AggMinAggregator{}
+		return &AggMinAggregator{
+			res: util.NewAtomicMin(),
+		}
 	case RespAggMax:
-		return &AggMaxAggregator{}
+		return &AggMaxAggregator{
+			res: util.NewAtomicMax(),
+		}
 	case RespAggLogicalAnd:
-		return &AggLogicalAndAggregator{}
+		andAgg := &AggLogicalAndAggregator{}
+		andAgg.res.Add(1)
+
+		return andAgg
 	case RespAggLogicalOr:
 		return &AggLogicalOrAggregator{}
 	case RespSpecial:
@@ -58,62 +68,54 @@ func NewDefaultAggregator(isKeyed bool) ResponseAggregator {

 // AllSucceededAggregator returns one non-error reply if every shard succeeded,
 // propagates the first error otherwise.
type AllSucceededAggregator struct {
-	mu        sync.Mutex
-	result    interface{}
-	firstErr  error
-	hasResult bool
+	err atomic.Value
+	res atomic.Value
 }
 
 func (a *AllSucceededAggregator) Add(result interface{}, err error) error {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if err != nil && a.firstErr == nil {
-		a.firstErr = err
+	if err != nil {
+		a.err.CompareAndSwap(nil, err)
 		return nil
 	}
-	if err == nil && !a.hasResult {
-		a.result = result
-		a.hasResult = true
+
+	if result != nil {
+		a.res.CompareAndSwap(nil, result)
 	}
-	return nil
-}
 
-func (a *AllSucceededAggregator) AddWithKey(key string, result interface{}, err error) error {
-	return a.Add(result, err)
+	return nil
 }
 
 func (a *AllSucceededAggregator) Result() (interface{}, error) {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if a.firstErr != nil {
-		return nil, a.firstErr
+	var err error
+	res, e := a.res.Load(), a.err.Load()
+	if e != nil {
+		err = e.(error)
 	}
-	return a.result, nil
+
+	return res, err
+}
+
+func (a *AllSucceededAggregator) AddWithKey(key string, result interface{}, err error) error {
+	return a.Add(result, err)
 }
 
 // OneSucceededAggregator returns the first non-error reply;
 // if all shards errored, it returns any one of those errors.
 type OneSucceededAggregator struct {
-	mu        sync.Mutex
-	result    interface{}
-	firstErr  error
-	hasResult bool
+	err atomic.Value
+	res atomic.Value
 }
 
 func (a *OneSucceededAggregator) Add(result interface{}, err error) error {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if err != nil && a.firstErr == nil {
-		a.firstErr = err
+	if err != nil {
+		a.err.CompareAndSwap(nil, err)
 		return nil
 	}
-	if err == nil && !a.hasResult {
-		a.result = result
-		a.hasResult = true
+
+	if result != nil {
+		a.res.CompareAndSwap(nil, result)
 	}
+
 	return nil
 }
 
@@ -122,42 +124,33 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err
 }
 
 func (a *OneSucceededAggregator) Result() (interface{}, error) {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if a.hasResult {
-		return a.result, nil
+	res, e := a.res.Load(), a.err.Load()
+	if res == nil && e != nil {
+		return nil, e.(error)
 	}
-	return nil, a.firstErr
+
+	return res, nil
 }
 
 // AggSumAggregator sums numeric replies from all shards.
 type AggSumAggregator struct {
-	mu        sync.Mutex
-	sum       int64
-	hasResult bool
-	firstErr  error
+	err atomic.Value
+	res int64
 }
 
 func (a *AggSumAggregator) Add(result interface{}, err error) error {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if err != nil && a.firstErr == nil {
-		a.firstErr = err
-		return nil
+	if err != nil {
+		a.err.CompareAndSwap(nil, err)
 	}
-	if err == nil {
+
+	if result != nil {
 		val, err := toInt64(result)
-		if err != nil && a.firstErr == nil {
-			a.firstErr = err
-			return nil
-		}
-		if err == nil {
-			a.sum += val
-			a.hasResult = true
+		if err != nil {
+			return err
 		}
+		atomic.AddInt64(&a.res, val)
 	}
+
 	return nil
 }
 
@@ -166,17 +159,19 @@ func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error)
 }
 
 func (a *AggSumAggregator) Result() (interface{}, error) {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if a.firstErr != nil {
-		return nil, a.firstErr
+	res, err := atomic.LoadInt64(&a.res), a.err.Load()
+	if err != nil {
+		return nil, err.(error)
 	}
-	return a.sum, nil
+
+	return res, nil
 }
 
 // AggMinAggregator returns the minimum numeric value from all shards.
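[Editorial aside, not part of the patch.] Before the min/max aggregators: the sum policy above shows the simplest case, where aggregation collapses to a single atomic counter once each reply is coerced to int64. A standalone sketch of that shape (illustration only; atomic.Int64 requires Go 1.19+):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var sum atomic.Int64 // zero value is ready to use, no locking required
	var wg sync.WaitGroup
	for _, reply := range []int64{10, 20, 30} { // e.g. per-shard DBSIZE replies
		wg.Add(1)
		go func(n int64) {
			defer wg.Done()
			sum.Add(n) // each shard's goroutine folds its reply in
		}(reply)
	}
	wg.Wait()
	fmt.Println(sum.Load()) // 60
}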
type AggMinAggregator struct {
+	err atomic.Value
+	res *util.AtomicMin
+
 	mu        sync.Mutex
 	min       int64
 	hasResult bool
@@ -184,26 +179,19 @@ type AggMinAggregator struct {
 }
 
 func (a *AggMinAggregator) Add(result interface{}, err error) error {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if err != nil && a.firstErr == nil {
-		a.firstErr = err
+	if err != nil {
+		a.err.CompareAndSwap(nil, err)
 		return nil
 	}
-	if err == nil {
-		val, err := toInt64(result)
-		if err != nil && a.firstErr == nil {
-			a.firstErr = err
-			return nil
-		}
-		if err == nil {
-			if !a.hasResult || val < a.min {
-				a.min = val
-				a.hasResult = true
-			}
-		}
+
+	intVal, e := toInt64(result)
+	if e != nil {
+		a.err.CompareAndSwap(nil, e)
+		return nil
 	}
+
+	a.res.Value(intVal)
+
 	return nil
 }
 
@@ -212,47 +200,38 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error)
 }
 
 func (a *AggMinAggregator) Result() (interface{}, error) {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if a.firstErr != nil {
-		return nil, a.firstErr
+	err := a.err.Load()
+	if err != nil {
+		return nil, err.(error)
 	}
-	if !a.hasResult {
+
+	val, hasVal := a.res.Min()
+	if !hasVal {
 		return nil, fmt.Errorf("redis: no valid results to aggregate for min operation")
 	}
-	return a.min, nil
+	return val, nil
 }
 
 // AggMaxAggregator returns the maximum numeric value from all shards.
 type AggMaxAggregator struct {
-	mu        sync.Mutex
-	max       int64
-	hasResult bool
-	firstErr  error
+	err atomic.Value
+	res *util.AtomicMax
 }
 
 func (a *AggMaxAggregator) Add(result interface{}, err error) error {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if err != nil && a.firstErr == nil {
-		a.firstErr = err
+	if err != nil {
+		a.err.CompareAndSwap(nil, err)
 		return nil
 	}
-	if err == nil {
-		val, err := toInt64(result)
-		if err != nil && a.firstErr == nil {
-			a.firstErr = err
-			return nil
-		}
-		if err == nil {
-			if !a.hasResult || val > a.max {
-				a.max = val
-				a.hasResult = true
-			}
-		}
+
+	intVal, e := toInt64(result)
+	if e != nil {
+		a.err.CompareAndSwap(nil, e)
+		return nil
 	}
+
+	a.res.Value(intVal)
+
 	return nil
 }
 
@@ -261,49 +240,45 @@ func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error)
 }
 
 func (a *AggMaxAggregator) Result() (interface{}, error) {
-	a.mu.Lock()
-	defer a.mu.Unlock()
-
-	if a.firstErr != nil {
-		return nil, a.firstErr
+	err := a.err.Load()
+	if err != nil {
+		return nil, err.(error)
 	}
-	if !a.hasResult {
+
+	val, hasVal := a.res.Max()
+	if !hasVal {
 		return nil, fmt.Errorf("redis: no valid results to aggregate for max operation")
 	}
-	return a.max, nil
+	return val, nil
}
 
 // AggLogicalAndAggregator performs logical AND on boolean values.
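[Editorial aside, not part of the patch.] The min/max aggregators above delegate the comparison loop to the util.AtomicMin and util.AtomicMax containers added later in this patch. A usage sketch against that API (NewAtomicMin, Value, Min, as defined below; the package is internal, so this only compiles from inside the go-redis module, e.g. as a test under internal/):

package util_test

import (
	"sync"
	"testing"

	"github.com/redis/go-redis/v9/internal/util"
)

func TestAtomicMinSketch(t *testing.T) {
	min := util.NewAtomicMin()
	var wg sync.WaitGroup
	for _, v := range []int64{7, 3, 9} { // e.g. per-shard WAIT replies
		wg.Add(1)
		go func(v int64) {
			defer wg.Done()
			min.Value(v) // records v only if it undercuts the current minimum
		}(v)
	}
	wg.Wait()
	// ok stays false until the first Value call lands
	if v, ok := min.Min(); !ok || v != 3 {
		t.Fatalf("want 3, got %d (ok=%v)", v, ok)
	}
}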
type AggLogicalAndAggregator struct { - mu sync.Mutex - result bool - hasResult bool - firstErr error + err atomic.Value + res atomic.Int64 + hasResult atomic.Bool } func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { - a.mu.Lock() - defer a.mu.Unlock() + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } - if err != nil && a.firstErr == nil { - a.firstErr = err + val, e := toBool(result) + if e != nil { + a.err.CompareAndSwap(nil, e) return nil } - if err == nil { - val, err := toBool(result) - if err != nil && a.firstErr == nil { - a.firstErr = err - return nil - } - if err == nil { - if !a.hasResult { - a.result = val - a.hasResult = true - } else { - a.result = a.result && val - } - } + + if val { + a.res.And(1) + } else { + a.res.And(0) } + + a.hasResult.Store(true) + return nil } @@ -312,49 +287,44 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err } func (a *AggLogicalAndAggregator) Result() (interface{}, error) { - a.mu.Lock() - defer a.mu.Unlock() - - if a.firstErr != nil { - return nil, a.firstErr + err := a.err.Load() + if err != nil { + return nil, err.(error) } - if !a.hasResult { + + if !a.hasResult.Load() { return nil, fmt.Errorf("redis: no valid results to aggregate for logical AND operation") } - return a.result, nil + return a.res.Load() != 0, nil } // AggLogicalOrAggregator performs logical OR on boolean values. type AggLogicalOrAggregator struct { - mu sync.Mutex - result bool - hasResult bool - firstErr error + err atomic.Value + res atomic.Int64 + hasResult atomic.Bool } func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { - a.mu.Lock() - defer a.mu.Unlock() + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } - if err != nil && a.firstErr == nil { - a.firstErr = err + val, e := toBool(result) + if e != nil { + a.err.CompareAndSwap(nil, e) return nil } - if err == nil { - val, err := toBool(result) - if err != nil && a.firstErr == nil { - a.firstErr = err - return nil - } - if err == nil { - if !a.hasResult { - a.result = val - a.hasResult = true - } else { - a.result = a.result || val - } - } + + if val { + a.res.Or(1) + } else { + a.res.Or(0) } + + a.hasResult.Store(true) + return nil } @@ -363,16 +333,15 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err } func (a *AggLogicalOrAggregator) Result() (interface{}, error) { - a.mu.Lock() - defer a.mu.Unlock() - - if a.firstErr != nil { - return nil, a.firstErr + err := a.err.Load() + if err != nil { + return nil, err.(error) } - if !a.hasResult { + + if !a.hasResult.Load() { return nil, fmt.Errorf("redis: no valid results to aggregate for logical OR operation") } - return a.result, nil + return a.res.Load() != 0, nil } func toInt64(val interface{}) (int64, error) { diff --git a/internal/util/atomic_max.go b/internal/util/atomic_max.go new file mode 100644 index 0000000000..ccee0e4c88 --- /dev/null +++ b/internal/util/atomic_max.go @@ -0,0 +1,96 @@ +/* +© 2023–present Harald Rudell (https://haraldrudell.github.io/haraldrudell/) +ISC License + +Modified by htemelski-redis +Removed the treshold, adapted it to work with int64 +*/ + +package util + +import ( + "math" + "sync/atomic" +) + +// AtomicMax is a thread-safe max container +// - hasValue indicator true if a value was equal to or greater than threshold +// - optional threshold for minimum accepted max value +// - if threshold is not used, initialization-free +// - — +// - wait-free CompareAndSwap mechanic +type AtomicMax 
struct { + + // value is current max + value atomic.Int64 + // whether [AtomicMax.Value] has been invoked + // with value equal or greater to threshold + hasValue atomic.Bool +} + +// NewAtomicMax returns a thread-safe max container +// - if threshold is not used, AtomicMax is initialization-free +func NewAtomicMax() (atomicMax *AtomicMax) { + m := AtomicMax{} + m.value.Store(math.MinInt64) + return &m +} + +// Value updates the container with a possible max value +// - isNewMax is true if: +// - — value is equal to or greater than any threshold and +// - — invocation recorded the first 0 or +// - — a new max +// - upon return, Max and Max1 are guaranteed to reflect the invocation +// - the return order of concurrent Value invocations is not guaranteed +// - Thread-safe +func (m *AtomicMax) Value(value int64) (isNewMax bool) { + // math.MinInt64 as max case + var hasValue0 = m.hasValue.Load() + if value == math.MinInt64 { + if !hasValue0 { + isNewMax = m.hasValue.CompareAndSwap(false, true) + } + return // math.MinInt64 as max: isNewMax true for first 0 writer + } + + // check against present value + var current = m.value.Load() + if isNewMax = value > current; !isNewMax { + return // not a new max return: isNewMax false + } + + // store the new max + for { + + // try to write value to *max + if isNewMax = m.value.CompareAndSwap(current, value); isNewMax { + if !hasValue0 { + // may be rarely written multiple times + // still faster than CompareAndSwap + m.hasValue.Store(true) + } + return // new max written return: isNewMax true + } + if current = m.value.Load(); current >= value { + return // no longer a need to write return: isNewMax false + } + } +} + +// Max returns current max and value-present flag +// - hasValue true indicates that value reflects a Value invocation +// - hasValue false: value is zero-value +// - Thread-safe +func (m *AtomicMax) Max() (value int64, hasValue bool) { + if hasValue = m.hasValue.Load(); !hasValue { + return + } + value = m.value.Load() + return +} + +// Max1 returns current maximum whether zero-value or set by Value +// - threshold is ignored +// - Thread-safe +func (m *AtomicMax) Max1() (value int64) { return m.value.Load() } diff --git a/internal/util/atomic_min.go b/internal/util/atomic_min.go new file mode 100644 index 0000000000..962d2a8070 --- /dev/null +++ b/internal/util/atomic_min.go @@ -0,0 +1,95 @@ +package util + +/* +© 2023–present Harald Rudell (https://haraldrudell.github.io/haraldrudell/) +ISC License + +Modified by htemelski-redis +Adapted from the modified atomic_max, but with inverted logic +*/ + +import ( + "math" + "sync/atomic" +) + +// AtomicMin is a thread-safe Min container +// - hasValue indicator true if a value was equal to or greater than threshold +// - optional threshold for minimum accepted Min value +// - — +// - wait-free CompareAndSwap mechanic +type AtomicMin struct { + + // value is current Min + value atomic.Int64 + // whether [AtomicMin.Value] has been invoked + // with value equal or greater to threshold + hasValue atomic.Bool +} + +// NewAtomicMin returns a thread-safe Min container +// - if threshold is not used, AtomicMin is initialization-free +func NewAtomicMin() (atomicMin *AtomicMin) { + m := AtomicMin{} + m.value.Store(math.MaxInt64) + return &m +} + +// Value updates the container with a possible Min value +// - isNewMin is true if: +// - — value is equal to or greater than any threshold and +// - — invocation recorded the first 0 or +// - — a new Min +// - upon return, Min and Min1 are guaranteed to 
reflect the invocation +// - the return order of concurrent Value invocations is not guaranteed +// - Thread-safe +func (m *AtomicMin) Value(value int64) (isNewMin bool) { + // math.MaxInt64 as Min case + var hasValue0 = m.hasValue.Load() + if value == math.MaxInt64 { + if !hasValue0 { + isNewMin = m.hasValue.CompareAndSwap(false, true) + } + return // math.MaxInt64 as Min: isNewMin true for first 0 writer + } + + // check against present value + var current = m.value.Load() + if isNewMin = value < current; !isNewMin { + return // not a new Min return: isNewMin false + } + + // store the new Min + for { + + // try to write value to *Min + if isNewMin = m.value.CompareAndSwap(current, value); isNewMin { + if !hasValue0 { + // may be rarely written multiple times + // still faster than CompareAndSwap + m.hasValue.Store(true) + } + return // new Min written return: isNewMin true + } + if current = m.value.Load(); current <= value { + return // no longer a need to write return: isNewMin false + } + } +} + +// Min returns current min and value-present flag +// - hasValue true indicates that value reflects a Value invocation +// - hasValue false: value is zero-value +// - Thread-safe +func (m *AtomicMin) Min() (value int64, hasValue bool) { + if hasValue = m.hasValue.Load(); !hasValue { + return + } + value = m.value.Load() + return +} + +// Min1 returns current Minimum whether zero-value or set by Value +// - threshold is ignored +// - Thread-safe +func (m *AtomicMin) Min1() (value int64) { return m.value.Load() } From b945fbd8ddb22e14924c50fe955eac4b81ecf183 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Mon, 13 Oct 2025 10:37:07 +0300 Subject: [PATCH 28/62] addressed more comments, fixed lint --- command.go | 12 ++++++------ internal/routing/aggregator.go | 30 +++++++++++++----------------- osscluster.go | 6 +++--- 3 files changed, 22 insertions(+), 26 deletions(-) diff --git a/command.go b/command.go index 7e2375a127..49db484920 100644 --- a/command.go +++ b/command.go @@ -6967,6 +6967,12 @@ func (cmd *VectorScoreSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MonitorCmd) Clone() Cmder { + // MonitorCmd cannot be safely cloned due to channels and goroutines + // Return a new MonitorCmd with the same channel + return newMonitorCmd(cmd.ctx, cmd.ch) +} + // ExtractCommandValue extracts the value from a command result using the fast enum-based approach func ExtractCommandValue(cmd interface{}) interface{} { // First try to get the command type using the interface @@ -7063,9 +7069,3 @@ func ExtractCommandValue(cmd interface{}) interface{} { // If we can't get the command type, return nil return nil } - -func (cmd *MonitorCmd) Clone() Cmder { - // MonitorCmd cannot be safely cloned due to channels and goroutines - // Return a new MonitorCmd with the same channel - return newMonitorCmd(cmd.ctx, cmd.ch) -} diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index b7d8260f92..5c2455c194 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -1,6 +1,7 @@ package routing import ( + "errors" "fmt" "math" "sync" @@ -9,6 +10,13 @@ import ( "github.com/redis/go-redis/v9/internal/util" ) +var ( + ErrMaxAggregation = errors.New("redis: no valid results to aggregate for max operation") + ErrMinAggregation = errors.New("redis: no valid results to aggregate for min operation") + ErrAndAggregation = errors.New("redis: no valid results to aggregate for logical AND operation") + ErrOrAggregation = errors.New("redis: no valid results 
to aggregate for logical OR operation") +) + // ResponseAggregator defines the interface for aggregating responses from multiple shards. type ResponseAggregator interface { // Add processes a single shard response. @@ -171,11 +179,6 @@ func (a *AggSumAggregator) Result() (interface{}, error) { type AggMinAggregator struct { err atomic.Value res *util.AtomicMin - - mu sync.Mutex - min int64 - hasResult bool - firstErr error } func (a *AggMinAggregator) Add(result interface{}, err error) error { @@ -207,7 +210,7 @@ func (a *AggMinAggregator) Result() (interface{}, error) { val, hasVal := a.res.Min() if !hasVal { - return nil, fmt.Errorf("redis: no valid results to aggregate for min operation") + return nil, ErrMinAggregation } return val, nil } @@ -247,7 +250,7 @@ func (a *AggMaxAggregator) Result() (interface{}, error) { val, hasVal := a.res.Max() if !hasVal { - return nil, fmt.Errorf("redis: no valid results to aggregate for max operation") + return nil, ErrMaxAggregation } return val, nil } @@ -293,7 +296,7 @@ func (a *AggLogicalAndAggregator) Result() (interface{}, error) { } if !a.hasResult.Load() { - return nil, fmt.Errorf("redis: no valid results to aggregate for logical AND operation") + return nil, ErrAndAggregation } return a.res.Load() != 0, nil } @@ -339,7 +342,7 @@ func (a *AggLogicalOrAggregator) Result() (interface{}, error) { } if !a.hasResult.Load() { - return nil, fmt.Errorf("redis: no valid results to aggregate for logical OR operation") + return nil, ErrOrAggregation } return a.res.Load() != 0, nil } @@ -533,13 +536,6 @@ func (a *SpecialAggregator) Result() (interface{}, error) { return nil, nil } -// SetAggregatorFunc allows setting custom aggregation logic for special commands. -func (a *SpecialAggregator) SetAggregatorFunc(fn func([]interface{}, []error) (interface{}, error)) { - a.mu.Lock() - defer a.mu.Unlock() - a.aggregatorFunc = fn -} - // SpecialAggregatorRegistry holds custom aggregation functions for specific commands. 
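[Editorial aside, not part of the patch.] The registry declared below is the hook for commands whose shard replies cannot be merged by any generic policy. A sketch of how a reducer might be registered; the command name and merge logic here are invented for illustration, and since the routing package is internal, registration like this can only happen inside the module:

// Hypothetical reducer, e.g. from an init() inside internal/routing:
func init() {
	RegisterSpecialAggregator("some.command", func(results []interface{}, errs []error) (interface{}, error) {
		for _, err := range errs {
			if err != nil {
				return nil, err // surface the first shard error
			}
		}
		// keep each shard's raw reply; a real reducer would merge by command semantics
		merged := make([]interface{}, 0, len(results))
		return append(merged, results...), nil
	})
}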
var SpecialAggregatorRegistry = make(map[string]func([]interface{}, []error) (interface{}, error)) @@ -552,7 +548,7 @@ func RegisterSpecialAggregator(cmdName string, fn func([]interface{}, []error) ( func NewSpecialAggregator(cmdName string) *SpecialAggregator { agg := &SpecialAggregator{} if fn, exists := SpecialAggregatorRegistry[cmdName]; exists { - agg.SetAggregatorFunc(fn) + agg.aggregatorFunc = fn } return agg } diff --git a/osscluster.go b/osscluster.go index d1fd059ea7..2ceedad045 100644 --- a/osscluster.go +++ b/osscluster.go @@ -2079,7 +2079,7 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { // Use a separate context that won't be canceled to ensure command info lookup // doesn't fail due to original context cancellation - cmdInfoCtx := context.Background() + cmdInfoCtx := c.context(ctx) if c.opt.ContextTimeoutEnabled && ctx != nil { // If context timeout is enabled, still use a reasonable timeout var cancel context.CancelFunc @@ -2089,13 +2089,13 @@ func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { cmdsInfo, err := c.cmdsInfoCache.Get(cmdInfoCtx) if err != nil { - internal.Logger.Printf(context.TODO(), "getting command info: %s", err) + internal.Logger.Printf(cmdInfoCtx, "getting command info: %s", err) return nil } info := cmdsInfo[name] if info == nil { - internal.Logger.Printf(context.TODO(), "info for cmd=%s not found", name) + internal.Logger.Printf(cmdInfoCtx, "info for cmd=%s not found", name) } return info } From 06dfd2ce9f8a461a9cd9b006d32d31a01ab65f8a Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Tue, 14 Oct 2025 10:11:38 +0300 Subject: [PATCH 29/62] added batch aggregator operations --- internal/routing/aggregator.go | 133 ++++++++++++++++++++++++++++++--- osscluster_router.go | 22 +++--- 2 files changed, 131 insertions(+), 24 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 5c2455c194..305b52a77f 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -25,6 +25,8 @@ type ResponseAggregator interface { // AddWithKey processes a single shard response for a specific key (used by keyed aggregators). AddWithKey(key string, result interface{}, err error) error + BatchAdd(map[string]interface{}, error) error + // Result returns the final aggregated result and any error. 
Result() (interface{}, error) } @@ -93,6 +95,14 @@ func (a *AllSucceededAggregator) Add(result interface{}, err error) error { return nil } +func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.Add(res, err) + } + + return nil +} + func (a *AllSucceededAggregator) Result() (interface{}, error) { var err error res, e := a.res.Load(), a.err.Load() @@ -127,6 +137,14 @@ func (a *OneSucceededAggregator) Add(result interface{}, err error) error { return nil } +func (a *OneSucceededAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.Add(res, err) + } + + return nil +} + func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } @@ -162,6 +180,14 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { return nil } +func (a *AggSumAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.Add(res, err) + } + + return nil +} + func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } @@ -198,6 +224,14 @@ func (a *AggMinAggregator) Add(result interface{}, err error) error { return nil } +func (a *AggMinAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.Add(res, err) + } + + return nil +} + func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } @@ -238,6 +272,14 @@ func (a *AggMaxAggregator) Add(result interface{}, err error) error { return nil } +func (a *AggMaxAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.Add(res, err) + } + + return nil +} + func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } @@ -285,6 +327,14 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { return nil } +func (a *AggLogicalAndAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.Add(res, err) + } + + return nil +} + func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } @@ -331,6 +381,14 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { return nil } +func (a *AggLogicalOrAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.Add(res, err) + } + + return nil +} + func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } @@ -391,10 +449,7 @@ type DefaultKeylessAggregator struct { firstErr error } -func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { - a.mu.Lock() - defer a.mu.Unlock() - +func (a *DefaultKeylessAggregator) add(result interface{}, err error) error { if err != nil && a.firstErr == nil { a.firstErr = err return nil @@ -405,6 +460,21 @@ func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { return nil } +func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + return a.add(result, err) +} + +func (a *DefaultKeylessAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.add(res, err) + } + + return nil +} + func 
(a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } @@ -434,10 +504,7 @@ func NewDefaultKeyedAggregator(keyOrder []string) *DefaultKeyedAggregator { } } -func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { - a.mu.Lock() - defer a.mu.Unlock() - +func (a *DefaultKeyedAggregator) add(result interface{}, err error) error { if err != nil && a.firstErr == nil { a.firstErr = err return nil @@ -449,10 +516,22 @@ func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { return nil } -func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err error) error { +func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { a.mu.Lock() defer a.mu.Unlock() + return a.add(result, err) +} + +func (a *DefaultKeyedAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.add(res, err) + } + + return nil +} + +func (a *DefaultKeyedAggregator) addWithKey(key string, result interface{}, err error) error { if err != nil && a.firstErr == nil { a.firstErr = err return nil @@ -463,6 +542,26 @@ func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err return nil } +func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + a.addWithKey(key, result, err) + return nil +} + +func (a *DefaultKeyedAggregator) BatchAddWithKeyOrder(results map[string]interface{}, keyOrder []string) error { + a.mu.Lock() + defer a.mu.Unlock() + + a.keyOrder = keyOrder + for key, val := range results { + _ = a.addWithKey(key, val, nil) + } + + return nil +} + func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { a.mu.Lock() defer a.mu.Unlock() @@ -504,12 +603,24 @@ type SpecialAggregator struct { errors []error } +func (a *SpecialAggregator) add(result interface{}, err error) error { + a.results = append(a.results, result) + a.errors = append(a.errors, err) + return nil +} + func (a *SpecialAggregator) Add(result interface{}, err error) error { a.mu.Lock() defer a.mu.Unlock() - a.results = append(a.results, result) - a.errors = append(a.errors, err) + return a.add(result, err) +} + +func (a *SpecialAggregator) BatchAdd(results map[string]interface{}, err error) error { + for _, res := range results { + a.add(res, err) + } + return nil } diff --git a/osscluster_router.go b/osscluster_router.go index 669bfea66f..013c363325 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -393,21 +393,17 @@ func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string] aggregator := c.createAggregator(policy, cmd, true) // Set key order for keyed aggregators - if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok { - keyedAgg.SetKeyOrder(keyOrder) + var keyedAgg *routing.DefaultKeyedAggregator + var isKeyedAgg bool + var err error + if keyedAgg, isKeyedAgg = aggregator.(*routing.DefaultKeyedAggregator); isKeyedAgg { + err = keyedAgg.BatchAddWithKeyOrder(keyedResults, keyOrder) + } else { + err = aggregator.BatchAdd(keyedResults, nil) } - // Add results with keys - for key, value := range keyedResults { - if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok { - if err := keyedAgg.AddWithKey(key, value, nil); err != nil { - return err - } - } else { - if err := aggregator.Add(value, nil); err != nil { - return err - } - } + if err != nil { + return err } return c.finishAggregation(cmd, aggregator) From 
bd5386fcd3d245c5f8f847dfecf6ce9d1e43e3a8 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Tue, 14 Oct 2025 10:26:04 +0300 Subject: [PATCH 30/62] fixed lint --- internal/routing/aggregator.go | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 305b52a77f..fe59f628c1 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -97,7 +97,7 @@ func (a *AllSucceededAggregator) Add(result interface{}, err error) error { func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.Add(res, err) + _ = a.Add(res, err) } return nil @@ -139,7 +139,7 @@ func (a *OneSucceededAggregator) Add(result interface{}, err error) error { func (a *OneSucceededAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.Add(res, err) + _ = a.Add(res, err) } return nil @@ -182,7 +182,7 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { func (a *AggSumAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.Add(res, err) + _ = a.Add(res, err) } return nil @@ -226,7 +226,7 @@ func (a *AggMinAggregator) Add(result interface{}, err error) error { func (a *AggMinAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.Add(res, err) + _ = a.Add(res, err) } return nil @@ -274,7 +274,7 @@ func (a *AggMaxAggregator) Add(result interface{}, err error) error { func (a *AggMaxAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.Add(res, err) + _ = a.Add(res, err) } return nil @@ -329,7 +329,7 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { func (a *AggLogicalAndAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.Add(res, err) + _ = a.Add(res, err) } return nil @@ -383,7 +383,7 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { func (a *AggLogicalOrAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.Add(res, err) + _ = a.Add(res, err) } return nil @@ -469,7 +469,7 @@ func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { func (a *DefaultKeylessAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.add(res, err) + _ = a.add(res, err) } return nil @@ -525,7 +525,7 @@ func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { func (a *DefaultKeyedAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.add(res, err) + _ = a.add(res, err) } return nil @@ -546,8 +546,7 @@ func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err a.mu.Lock() defer a.mu.Unlock() - a.addWithKey(key, result, err) - return nil + return a.addWithKey(key, result, err) } func (a *DefaultKeyedAggregator) BatchAddWithKeyOrder(results map[string]interface{}, keyOrder []string) error { @@ -618,7 +617,7 @@ func (a *SpecialAggregator) Add(result interface{}, err error) error { func (a *SpecialAggregator) BatchAdd(results map[string]interface{}, err error) error { for _, res := range results { - a.add(res, err) + _ = a.add(res, err) } return nil From a402e47282e43d5ff5bef1292965b01bcc2316bf Mon Sep 17 00:00:00 2001 From: 
Hristo Temelski Date: Tue, 14 Oct 2025 14:05:56 +0300 Subject: [PATCH 31/62] updated batch aggregator, fixed extractcommandvalue --- command.go | 267 +++++++++++++++++++++++++++++++-- internal/routing/aggregator.go | 144 +++++++++++++++--- osscluster_router.go | 23 +-- 3 files changed, 386 insertions(+), 48 deletions(-) diff --git a/command.go b/command.go index 49db484920..27743081e8 100644 --- a/command.go +++ b/command.go @@ -153,6 +153,33 @@ const ( ) >>>>>>> b6633bf9 (centralize cluster command routing in osscluster_router.go and refactor osscluster.go (#6)) +type ( + CmdTypeXAutoClaimValue struct { + messages []XMessage + start string + } + + CmdTypeXAutoClaimJustIDValue struct { + ids []string + start string + } + + CmdTypeScanValue struct { + keys []string + cursor uint64 + } + + CmdTypeKeyValuesValue struct { + key string + values []string + } + + CmdTypeZSliceWithKeyValue struct { + key string + zSlice []Z + } +) + type Cmder interface { // command name. // e.g. "set k v ex 10" -> "set", "cluster info" -> "cluster". @@ -6981,6 +7008,10 @@ func ExtractCommandValue(cmd interface{}) interface{} { // Use fast type-based extraction switch cmdType { + case CmdTypeGeneric: + if genericCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return genericCmd.Val() + } case CmdTypeString: if stringCmd, ok := cmd.(interface{ Val() string }); ok { return stringCmd.Val() @@ -7001,21 +7032,217 @@ func ExtractCommandValue(cmd interface{}) interface{} { if statusCmd, ok := cmd.(interface{ Val() string }); ok { return statusCmd.Val() } - case CmdTypeDuration, CmdTypeTime, CmdTypeStringStructMap, CmdTypeXMessageSlice, - CmdTypeXStreamSlice, CmdTypeXPending, CmdTypeXPendingExt, CmdTypeXAutoClaim, - CmdTypeXAutoClaimJustID, CmdTypeXInfoConsumers, CmdTypeXInfoGroups, CmdTypeXInfoStream, - CmdTypeXInfoStreamFull, CmdTypeZSlice, CmdTypeZWithKey, CmdTypeScan, CmdTypeClusterSlots, - CmdTypeGeoSearchLocation, CmdTypeGeoPos, CmdTypeCommandsInfo, CmdTypeSlowLog, - CmdTypeKeyValues, CmdTypeZSliceWithKey, CmdTypeFunctionList, CmdTypeFunctionStats, - CmdTypeLCS, CmdTypeKeyFlags, CmdTypeClusterLinks, CmdTypeClusterShards, - CmdTypeRankWithScore, CmdTypeClientInfo, CmdTypeACLLog, CmdTypeInfo, CmdTypeMonitor, - CmdTypeJSON, CmdTypeJSONSlice, CmdTypeIntPointerSlice, CmdTypeScanDump, CmdTypeBFInfo, - CmdTypeCFInfo, CmdTypeCMSInfo, CmdTypeTopKInfo, CmdTypeTDigestInfo, CmdTypeFTSearch, - CmdTypeFTInfo, CmdTypeFTSpellCheck, CmdTypeFTSynDump, CmdTypeAggregate, - CmdTypeTSTimestampValue, CmdTypeTSTimestampValueSlice: - if durationCmd, ok := cmd.(interface{ Val() interface{} }); ok { + case CmdTypeDuration: + if durationCmd, ok := cmd.(interface{ Val() time.Duration }); ok { return durationCmd.Val() } + case CmdTypeTime: + if timeCmd, ok := cmd.(interface{ Val() time.Time }); ok { + return timeCmd.Val() + } + case CmdTypeStringStructMap: + if structMapCmd, ok := cmd.(interface{ Val() map[string]struct{} }); ok { + return structMapCmd.Val() + } + case CmdTypeXMessageSlice: + if xMessageSliceCmd, ok := cmd.(interface{ Val() []XMessage }); ok { + return xMessageSliceCmd.Val() + } + case CmdTypeXStreamSlice: + if xStreamSliceCmd, ok := cmd.(interface{ Val() []XStream }); ok { + return xStreamSliceCmd.Val() + } + case CmdTypeXPending: + if xPendingCmd, ok := cmd.(interface{ Val() *XPending }); ok { + return xPendingCmd.Val() + } + case CmdTypeXPendingExt: + if xPendingExtCmd, ok := cmd.(interface{ Val() []XPendingExt }); ok { + return xPendingExtCmd.Val() + } + case CmdTypeXAutoClaim: + if xAutoClaimCmd, ok := 
cmd.(interface{ Val() ([]XMessage, string) }); ok { + messages, start := xAutoClaimCmd.Val() + return CmdTypeXAutoClaimValue{messages: messages, start: start} + } + case CmdTypeXAutoClaimJustID: + if xAutoClaimJustIDCmd, ok := cmd.(interface{ Val() ([]string, string) }); ok { + ids, start := xAutoClaimJustIDCmd.Val() + return CmdTypeXAutoClaimJustIDValue{ids: ids, start: start} + } + case CmdTypeXInfoConsumers: + if xInfoConsumersCmd, ok := cmd.(interface{ Val() []XInfoConsumer }); ok { + return xInfoConsumersCmd.Val() + } + case CmdTypeXInfoGroups: + if xInfoGroupsCmd, ok := cmd.(interface{ Val() []XInfoGroup }); ok { + return xInfoGroupsCmd.Val() + } + case CmdTypeXInfoStream: + if xInfoStreamCmd, ok := cmd.(interface{ Val() *XInfoStream }); ok { + return xInfoStreamCmd.Val() + } + case CmdTypeXInfoStreamFull: + if xInfoStreamFullCmd, ok := cmd.(interface{ Val() *XInfoStreamFull }); ok { + return xInfoStreamFullCmd.Val() + } + case CmdTypeZSlice: + if zSliceCmd, ok := cmd.(interface{ Val() []Z }); ok { + return zSliceCmd.Val() + } + case CmdTypeZWithKey: + if zWithKeyCmd, ok := cmd.(interface{ Val() *ZWithKey }); ok { + return zWithKeyCmd.Val() + } + case CmdTypeScan: + if scanCmd, ok := cmd.(interface{ Val() ([]string, uint64) }); ok { + keys, cursor := scanCmd.Val() + return CmdTypeScanValue{keys: keys, cursor: cursor} + } + case CmdTypeClusterSlots: + if clusterSlotsCmd, ok := cmd.(interface{ Val() []ClusterSlot }); ok { + return clusterSlotsCmd.Val() + } + case CmdTypeGeoLocation: + if geoLocationCmd, ok := cmd.(interface{ Val() []GeoLocation }); ok { + return geoLocationCmd.Val() + } + case CmdTypeGeoSearchLocation: + if geoSearchLocationCmd, ok := cmd.(interface{ Val() []GeoLocation }); ok { + return geoSearchLocationCmd.Val() + } + case CmdTypeGeoPos: + if geoPosCmd, ok := cmd.(interface{ Val() []*GeoPos }); ok { + return geoPosCmd.Val() + } + case CmdTypeCommandsInfo: + if commandsInfoCmd, ok := cmd.(interface { + Val() map[string]*CommandInfo + }); ok { + return commandsInfoCmd.Val() + } + case CmdTypeSlowLog: + if slowLogCmd, ok := cmd.(interface{ Val() []SlowLog }); ok { + return slowLogCmd.Val() + } + case CmdTypeKeyValues: + if keyValuesCmd, ok := cmd.(interface{ Val() (string, []string) }); ok { + key, values := keyValuesCmd.Val() + return CmdTypeKeyValuesValue{key: key, values: values} + } + case CmdTypeZSliceWithKey: + if zSliceWithKeyCmd, ok := cmd.(interface{ Val() (string, []Z) }); ok { + key, zSlice := zSliceWithKeyCmd.Val() + return CmdTypeZSliceWithKeyValue{key: key, zSlice: zSlice} + } + case CmdTypeFunctionList: + if functionListCmd, ok := cmd.(interface{ Val() []Library }); ok { + return functionListCmd.Val() + } + case CmdTypeFunctionStats: + if functionStatsCmd, ok := cmd.(interface{ Val() FunctionStats }); ok { + return functionStatsCmd.Val() + } + case CmdTypeLCS: + if lcsCmd, ok := cmd.(interface{ Val() *LCSMatch }); ok { + return lcsCmd.Val() + } + case CmdTypeKeyFlags: + if keyFlagsCmd, ok := cmd.(interface{ Val() []KeyFlags }); ok { + return keyFlagsCmd.Val() + } + case CmdTypeClusterLinks: + if clusterLinksCmd, ok := cmd.(interface{ Val() []ClusterLink }); ok { + return clusterLinksCmd.Val() + } + case CmdTypeClusterShards: + if clusterShardsCmd, ok := cmd.(interface{ Val() []ClusterShard }); ok { + return clusterShardsCmd.Val() + } + case CmdTypeRankWithScore: + if rankWithScoreCmd, ok := cmd.(interface{ Val() RankScore }); ok { + return rankWithScoreCmd.Val() + } + case CmdTypeClientInfo: + if clientInfoCmd, ok := cmd.(interface{ Val() *ClientInfo }); 
ok { + return clientInfoCmd.Val() + } + case CmdTypeACLLog: + if aclLogCmd, ok := cmd.(interface{ Val() []*ACLLogEntry }); ok { + return aclLogCmd.Val() + } + case CmdTypeInfo: + if infoCmd, ok := cmd.(interface{ Val() string }); ok { + return infoCmd.Val() + } + case CmdTypeMonitor: + if monitorCmd, ok := cmd.(interface{ Val() string }); ok { + return monitorCmd.Val() + } + case CmdTypeJSON: + if jsonCmd, ok := cmd.(interface{ Val() string }); ok { + return jsonCmd.Val() + } + case CmdTypeJSONSlice: + if jsonSliceCmd, ok := cmd.(interface{ Val() []interface{} }); ok { + return jsonSliceCmd.Val() + } + case CmdTypeIntPointerSlice: + if intPointerSliceCmd, ok := cmd.(interface{ Val() []*int64 }); ok { + return intPointerSliceCmd.Val() + } + case CmdTypeScanDump: + if scanDumpCmd, ok := cmd.(interface{ Val() ScanDump }); ok { + return scanDumpCmd.Val() + } + case CmdTypeBFInfo: + if bfInfoCmd, ok := cmd.(interface{ Val() BFInfo }); ok { + return bfInfoCmd.Val() + } + case CmdTypeCFInfo: + if cfInfoCmd, ok := cmd.(interface{ Val() CFInfo }); ok { + return cfInfoCmd.Val() + } + case CmdTypeCMSInfo: + if cmsInfoCmd, ok := cmd.(interface{ Val() CMSInfo }); ok { + return cmsInfoCmd.Val() + } + case CmdTypeTopKInfo: + if topKInfoCmd, ok := cmd.(interface{ Val() TopKInfo }); ok { + return topKInfoCmd.Val() + } + case CmdTypeTDigestInfo: + if tDigestInfoCmd, ok := cmd.(interface{ Val() TDigestInfo }); ok { + return tDigestInfoCmd.Val() + } + case CmdTypeFTSearch: + if ftSearchCmd, ok := cmd.(interface{ Val() FTSearchResult }); ok { + return ftSearchCmd.Val() + } + case CmdTypeFTInfo: + if ftInfoCmd, ok := cmd.(interface{ Val() FTInfoResult }); ok { + return ftInfoCmd.Val() + } + case CmdTypeFTSpellCheck: + if ftSpellCheckCmd, ok := cmd.(interface{ Val() []SpellCheckResult }); ok { + return ftSpellCheckCmd.Val() + } + case CmdTypeFTSynDump: + if ftSynDumpCmd, ok := cmd.(interface{ Val() []FTSynDumpResult }); ok { + return ftSynDumpCmd.Val() + } + case CmdTypeAggregate: + if aggregateCmd, ok := cmd.(interface{ Val() *FTAggregateResult }); ok { + return aggregateCmd.Val() + } + case CmdTypeTSTimestampValue: + if tsTimestampValueCmd, ok := cmd.(interface{ Val() TSTimestampValue }); ok { + return tsTimestampValueCmd.Val() + } + case CmdTypeTSTimestampValueSlice: + if tsTimestampValueSliceCmd, ok := cmd.(interface{ Val() []TSTimestampValue }); ok { + return tsTimestampValueSliceCmd.Val() + } case CmdTypeStringSlice: if stringSliceCmd, ok := cmd.(interface{ Val() []string }); ok { return stringSliceCmd.Val() @@ -7032,6 +7259,14 @@ func ExtractCommandValue(cmd interface{}) interface{} { if floatSliceCmd, ok := cmd.(interface{ Val() []float64 }); ok { return floatSliceCmd.Val() } + case CmdTypeSlice: + if sliceCmd, ok := cmd.(interface{ Val() []interface{} }); ok { + return sliceCmd.Val() + } + case CmdTypeKeyValueSlice: + if keyValueSliceCmd, ok := cmd.(interface{ Val() []KeyValue }); ok { + return keyValueSliceCmd.Val() + } case CmdTypeMapStringString: if mapCmd, ok := cmd.(interface{ Val() map[string]string }); ok { return mapCmd.Val() @@ -7042,7 +7277,7 @@ func ExtractCommandValue(cmd interface{}) interface{} { } case CmdTypeMapStringInterfaceSlice: if mapCmd, ok := cmd.(interface { - Val() map[string][]interface{} + Val() []map[string]interface{} }); ok { return mapCmd.Val() } @@ -7051,12 +7286,12 @@ func ExtractCommandValue(cmd interface{}) interface{} { return mapCmd.Val() } case CmdTypeMapStringStringSlice: - if mapCmd, ok := cmd.(interface{ Val() map[string][]string }); ok { + if mapCmd, ok := 
cmd.(interface{ Val() []map[string]string }); ok { return mapCmd.Val() } case CmdTypeMapMapStringInterface: if mapCmd, ok := cmd.(interface { - Val() map[string][]interface{} + Val() map[string]interface{} }); ok { return mapCmd.Val() } diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index fe59f628c1..d44c8e5ca6 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -25,12 +25,19 @@ type ResponseAggregator interface { // AddWithKey processes a single shard response for a specific key (used by keyed aggregators). AddWithKey(key string, result interface{}, err error) error - BatchAdd(map[string]interface{}, error) error + BatchAdd(map[string]interface{}) error + + BatchWithErrs([]AggregatorResErr) error // Result returns the final aggregated result and any error. Result() (interface{}, error) } +type AggregatorResErr struct { + result interface{} + err error +} + // NewResponseAggregator creates an aggregator based on the response policy. func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggregator { switch policy { @@ -95,9 +102,17 @@ func (a *AllSucceededAggregator) Add(result interface{}, err error) error { return nil } -func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, err) + _ = a.Add(res, nil) + } + + return nil +} + +func (a *AllSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { + for _, val := range values { + a.Add(val.result, val.err) } return nil @@ -137,9 +152,9 @@ func (a *OneSucceededAggregator) Add(result interface{}, err error) error { return nil } -func (a *OneSucceededAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *OneSucceededAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, err) + _ = a.Add(res, nil) } return nil @@ -149,6 +164,14 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } +func (a *OneSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { + for _, val := range values { + a.Add(val.result, val.err) + } + + return nil +} + func (a *OneSucceededAggregator) Result() (interface{}, error) { res, e := a.res.Load(), a.err.Load() if res != nil { @@ -180,9 +203,9 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggSumAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *AggSumAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, err) + _ = a.Add(res, nil) } return nil @@ -192,6 +215,14 @@ func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } +func (a *AggSumAggregator) BatchWithErrs(values []AggregatorResErr) error { + for _, val := range values { + a.Add(val.result, val.err) + } + + return nil +} + func (a *AggSumAggregator) Result() (interface{}, error) { res, err := atomic.LoadInt64(a.res), a.err.Load() if err != nil { @@ -224,9 +255,9 @@ func (a *AggMinAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggMinAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *AggMinAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, err) + _ = a.Add(res, nil) } 
return nil @@ -236,6 +267,14 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } +func (a *AggMinAggregator) BatchWithErrs(values []AggregatorResErr) error { + for _, val := range values { + a.Add(val.result, val.err) + } + + return nil +} + func (a *AggMinAggregator) Result() (interface{}, error) { err := a.err.Load() if err != nil { @@ -272,9 +311,9 @@ func (a *AggMaxAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggMaxAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *AggMaxAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, err) + _ = a.Add(res, nil) } return nil @@ -284,6 +323,14 @@ func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } +func (a *AggMaxAggregator) BatchWithErrs(values []AggregatorResErr) error { + for _, val := range values { + a.Add(val.result, val.err) + } + + return nil +} + func (a *AggMaxAggregator) Result() (interface{}, error) { err := a.err.Load() if err != nil { @@ -327,9 +374,9 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggLogicalAndAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *AggLogicalAndAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, err) + _ = a.Add(res, nil) } return nil @@ -339,6 +386,14 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } +func (a *AggLogicalAndAggregator) BatchWithErrs(values []AggregatorResErr) error { + for _, val := range values { + a.Add(val.result, val.err) + } + + return nil +} + func (a *AggLogicalAndAggregator) Result() (interface{}, error) { err := a.err.Load() if err != nil { @@ -381,9 +436,9 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggLogicalOrAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *AggLogicalOrAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, err) + _ = a.Add(res, nil) } return nil @@ -393,6 +448,14 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } +func (a *AggLogicalOrAggregator) BatchWithErrs(values []AggregatorResErr) error { + for _, val := range values { + a.Add(val.result, val.err) + } + + return nil +} + func (a *AggLogicalOrAggregator) Result() (interface{}, error) { err := a.err.Load() if err != nil { @@ -467,9 +530,9 @@ func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { return a.add(result, err) } -func (a *DefaultKeylessAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *DefaultKeylessAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.add(res, err) + _ = a.add(res, nil) } return nil @@ -479,6 +542,17 @@ func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, er return a.Add(result, err) } +func (a *DefaultKeylessAggregator) BatchWithErrs(values []AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, val := range values { + a.add(val.result, val.err) + } + + return nil +} + func (a *DefaultKeylessAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -523,9 +597,12 @@ func (a 
*DefaultKeyedAggregator) Add(result interface{}, err error) error { return a.add(result, err) } -func (a *DefaultKeyedAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *DefaultKeyedAggregator) BatchAdd(results map[string]interface{}) error { + a.mu.Lock() + defer a.mu.Unlock() + for _, res := range results { - _ = a.add(res, err) + _ = a.add(res, nil) } return nil @@ -567,6 +644,17 @@ func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { a.keyOrder = keyOrder } +func (a *DefaultKeyedAggregator) BatchWithErrs(values []AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, val := range values { + a.add(val.result, val.err) + } + + return nil +} + func (a *DefaultKeyedAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() @@ -615,9 +703,12 @@ func (a *SpecialAggregator) Add(result interface{}, err error) error { return a.add(result, err) } -func (a *SpecialAggregator) BatchAdd(results map[string]interface{}, err error) error { +func (a *SpecialAggregator) BatchAdd(results map[string]interface{}) error { + a.mu.Lock() + defer a.mu.Unlock() + for _, res := range results { - _ = a.add(res, err) + _ = a.add(res, nil) } return nil @@ -627,6 +718,17 @@ func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error return a.Add(result, err) } +func (a *SpecialAggregator) BatchWithErrs(values []AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, val := range values { + a.add(val.result, val.err) + } + + return nil +} + func (a *SpecialAggregator) Result() (interface{}, error) { a.mu.Lock() defer a.mu.Unlock() diff --git a/osscluster_router.go b/osscluster_router.go index 013c363325..b5ba608634 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -399,7 +399,7 @@ func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string] if keyedAgg, isKeyedAgg = aggregator.(*routing.DefaultKeyedAggregator); isKeyedAgg { err = keyedAgg.BatchAddWithKeyOrder(keyedResults, keyOrder) } else { - err = aggregator.BatchAdd(keyedResults, nil) + err = aggregator.BatchAdd(keyedResults) } if err != nil { @@ -430,6 +430,7 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout // Add all results to aggregator for _, shardCmd := range cmds { value := ExtractCommandValue(shardCmd) + //TODO: Rewrite as batch if err := aggregator.Add(value, shardCmd.Err()); err != nil { return err } @@ -636,14 +637,14 @@ func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { } case CmdTypeXAutoClaim: if c, ok := cmd.(*XAutoClaimCmd); ok { - if v, ok := value.([]XMessage); ok { - c.SetVal(v, "") // Default start value + if v, ok := value.(CmdTypeXAutoClaimValue); ok { + c.SetVal(v.messages, v.start) } } case CmdTypeXAutoClaimJustID: if c, ok := cmd.(*XAutoClaimJustIDCmd); ok { - if v, ok := value.([]string); ok { - c.SetVal(v, "") // Default start value + if v, ok := value.(CmdTypeXAutoClaimJustIDValue); ok { + c.SetVal(v.ids, v.start) } } case CmdTypeXInfoConsumers: @@ -684,8 +685,8 @@ func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { } case CmdTypeScan: if c, ok := cmd.(*ScanCmd); ok { - if v, ok := value.([]string); ok { - c.SetVal(v, uint64(0)) // Default cursor + if v, ok := value.(CmdTypeScanValue); ok { + c.SetVal(v.keys, v.cursor) } } case CmdTypeClusterSlots: @@ -745,15 +746,15 @@ func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { case CmdTypeKeyValues: if c, ok := cmd.(*KeyValuesCmd); 
ok { // KeyValuesCmd needs a key string and values slice - if key, ok := value.(string); ok { - c.SetVal(key, []string{}) // Default empty values + if v, ok := value.(CmdTypeKeyValuesValue); ok { + c.SetVal(v.key, v.values) } } case CmdTypeZSliceWithKey: if c, ok := cmd.(*ZSliceWithKeyCmd); ok { // ZSliceWithKeyCmd needs a key string and Z slice - if key, ok := value.(string); ok { - c.SetVal(key, []Z{}) // Default empty Z slice + if v, ok := value.(CmdTypeZSliceWithKeyValue); ok { + c.SetVal(v.key, v.zSlice) } } case CmdTypeFunctionList: From 4267f7c607df8ffe4f9fc4b2dcfb91bcf38b60d0 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Tue, 14 Oct 2025 14:08:32 +0300 Subject: [PATCH 32/62] fixed lint --- internal/routing/aggregator.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index d44c8e5ca6..aa6deb669c 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -112,7 +112,7 @@ func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}) error func (a *AllSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - a.Add(val.result, val.err) + _ = a.Add(val.result, val.err) } return nil @@ -166,7 +166,7 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err func (a *OneSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - a.Add(val.result, val.err) + _ = a.Add(val.result, val.err) } return nil @@ -217,7 +217,7 @@ func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) func (a *AggSumAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - a.Add(val.result, val.err) + _ = a.Add(val.result, val.err) } return nil @@ -269,7 +269,7 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) func (a *AggMinAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - a.Add(val.result, val.err) + _ = a.Add(val.result, val.err) } return nil @@ -325,7 +325,7 @@ func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) func (a *AggMaxAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - a.Add(val.result, val.err) + _ = a.Add(val.result, val.err) } return nil @@ -388,7 +388,7 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err func (a *AggLogicalAndAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - a.Add(val.result, val.err) + _ = a.Add(val.result, val.err) } return nil @@ -450,7 +450,7 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err func (a *AggLogicalOrAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - a.Add(val.result, val.err) + _ = a.Add(val.result, val.err) } return nil @@ -547,7 +547,7 @@ func (a *DefaultKeylessAggregator) BatchWithErrs(values []AggregatorResErr) erro defer a.mu.Unlock() for _, val := range values { - a.add(val.result, val.err) + _ = a.add(val.result, val.err) } return nil @@ -649,7 +649,7 @@ func (a *DefaultKeyedAggregator) BatchWithErrs(values []AggregatorResErr) error defer a.mu.Unlock() for _, val := range values { - a.add(val.result, val.err) + _ = a.add(val.result, val.err) } return nil @@ -723,7 +723,7 @@ func (a *SpecialAggregator) BatchWithErrs(values []AggregatorResErr) error { defer 
a.mu.Unlock() for _, val := range values { - a.add(val.result, val.err) + _ = a.add(val.result, val.err) } return nil From 41c4a43a8f453a835d8de3f564e53cf175b4f6c0 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Tue, 14 Oct 2025 15:39:12 +0300 Subject: [PATCH 33/62] added batching to aggregateResponses --- internal/routing/aggregator.go | 104 ++++++++++++++++++++++++++------- osscluster_router.go | 14 +++-- 2 files changed, 92 insertions(+), 26 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index aa6deb669c..9e2943f7e4 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -34,8 +34,8 @@ type ResponseAggregator interface { } type AggregatorResErr struct { - result interface{} - err error + Result interface{} + Err error } // NewResponseAggregator creates an aggregator based on the response policy. @@ -104,7 +104,10 @@ func (a *AllSucceededAggregator) Add(result interface{}, err error) error { func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -112,7 +115,10 @@ func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}) error func (a *AllSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - _ = a.Add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -154,7 +160,10 @@ func (a *OneSucceededAggregator) Add(result interface{}, err error) error { func (a *OneSucceededAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -166,7 +175,10 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err func (a *OneSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - _ = a.Add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -205,7 +217,10 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { func (a *AggSumAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -217,7 +232,10 @@ func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) func (a *AggSumAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - _ = a.Add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -257,7 +275,10 @@ func (a *AggMinAggregator) Add(result interface{}, err error) error { func (a *AggMinAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -269,7 +290,10 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) func (a *AggMinAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - _ = a.Add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -313,7 +337,10 @@ func (a *AggMaxAggregator) Add(result interface{}, err error) error { func (a *AggMaxAggregator) BatchAdd(results map[string]interface{}) error { for 
_, res := range results { - _ = a.Add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -325,7 +352,10 @@ func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) func (a *AggMaxAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - _ = a.Add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -376,7 +406,10 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { func (a *AggLogicalAndAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -388,7 +421,10 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err func (a *AggLogicalAndAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - _ = a.Add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -438,7 +474,10 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { func (a *AggLogicalOrAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.Add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -450,7 +489,10 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err func (a *AggLogicalOrAggregator) BatchWithErrs(values []AggregatorResErr) error { for _, val := range values { - _ = a.Add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -532,7 +574,10 @@ func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { func (a *DefaultKeylessAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - _ = a.add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -547,7 +592,10 @@ func (a *DefaultKeylessAggregator) BatchWithErrs(values []AggregatorResErr) erro defer a.mu.Unlock() for _, val := range values { - _ = a.add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -602,7 +650,10 @@ func (a *DefaultKeyedAggregator) BatchAdd(results map[string]interface{}) error defer a.mu.Unlock() for _, res := range results { - _ = a.add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -649,7 +700,10 @@ func (a *DefaultKeyedAggregator) BatchWithErrs(values []AggregatorResErr) error defer a.mu.Unlock() for _, val := range values { - _ = a.add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil @@ -708,7 +762,10 @@ func (a *SpecialAggregator) BatchAdd(results map[string]interface{}) error { defer a.mu.Unlock() for _, res := range results { - _ = a.add(res, nil) + err := a.Add(res, nil) + if err != nil { + return err + } } return nil @@ -723,7 +780,10 @@ func (a *SpecialAggregator) BatchWithErrs(values []AggregatorResErr) error { defer a.mu.Unlock() for _, val := range values { - _ = a.add(val.result, val.err) + err := a.Add(val.Result, val.Err) + if err != nil { + return err + } } return nil diff --git a/osscluster_router.go b/osscluster_router.go index b5ba608634..fe426125aa 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -427,13 +427,19 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy 
*rout aggregator := c.createAggregator(policy, cmd, false) + batchWithErrs := []routing.AggregatorResErr{} // Add all results to aggregator for _, shardCmd := range cmds { value := ExtractCommandValue(shardCmd) - //TODO: Rewrite as batch - if err := aggregator.Add(value, shardCmd.Err()); err != nil { - return err - } + batchWithErrs = append(batchWithErrs, routing.AggregatorResErr{ + Result: value, + Err: shardCmd.Err(), + }) + } + + err := aggregator.BatchWithErrs(batchWithErrs) + if err != nil { + return err } return c.finishAggregation(cmd, aggregator) From 77e25b6156d5305fb08d268ed99286ca598b2e2a Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Tue, 14 Oct 2025 16:48:57 +0300 Subject: [PATCH 34/62] fixed deadlocks --- internal/routing/aggregator.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 9e2943f7e4..7d7e9c1be0 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -574,7 +574,7 @@ func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { func (a *DefaultKeylessAggregator) BatchAdd(results map[string]interface{}) error { for _, res := range results { - err := a.Add(res, nil) + err := a.add(res, nil) if err != nil { return err } @@ -592,7 +592,7 @@ func (a *DefaultKeylessAggregator) BatchWithErrs(values []AggregatorResErr) erro defer a.mu.Unlock() for _, val := range values { - err := a.Add(val.Result, val.Err) + err := a.add(val.Result, val.Err) if err != nil { return err } @@ -650,7 +650,7 @@ func (a *DefaultKeyedAggregator) BatchAdd(results map[string]interface{}) error defer a.mu.Unlock() for _, res := range results { - err := a.Add(res, nil) + err := a.add(res, nil) if err != nil { return err } @@ -700,7 +700,7 @@ func (a *DefaultKeyedAggregator) BatchWithErrs(values []AggregatorResErr) error defer a.mu.Unlock() for _, val := range values { - err := a.Add(val.Result, val.Err) + err := a.add(val.Result, val.Err) if err != nil { return err } @@ -762,7 +762,7 @@ func (a *SpecialAggregator) BatchAdd(results map[string]interface{}) error { defer a.mu.Unlock() for _, res := range results { - err := a.Add(res, nil) + err := a.add(res, nil) if err != nil { return err } @@ -780,7 +780,7 @@ func (a *SpecialAggregator) BatchWithErrs(values []AggregatorResErr) error { defer a.mu.Unlock() for _, val := range values { - err := a.Add(val.Result, val.Err) + err := a.add(val.Result, val.Err) if err != nil { return err } From 23b35bed5e9c0281746ee9acaa86dbc2802bd42a Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Wed, 15 Oct 2025 11:10:06 +0300 Subject: [PATCH 35/62] changed aggregator logic, added error params --- command.go | 477 +++++++++++++++++++++++---------- internal/routing/aggregator.go | 71 ++--- osscluster_router.go | 30 +-- 3 files changed, 389 insertions(+), 189 deletions(-) diff --git a/command.go b/command.go index 27743081e8..ce693004cb 100644 --- a/command.go +++ b/command.go @@ -7001,7 +7001,7 @@ func (cmd *MonitorCmd) Clone() Cmder { } // ExtractCommandValue extracts the value from a command result using the fast enum-based approach -func ExtractCommandValue(cmd interface{}) interface{} { +func ExtractCommandValue(cmd interface{}) (interface{}, error) { // First try to get the command type using the interface if cmdTypeGetter, ok := cmd.(CmdTypeGetter); ok { cmdType := cmdTypeGetter.GetCmdType() @@ -7009,298 +7009,499 @@ func ExtractCommandValue(cmd interface{}) interface{} { // Use fast type-based extraction switch 
cmdType { case CmdTypeGeneric: - if genericCmd, ok := cmd.(interface{ Val() interface{} }); ok { - return genericCmd.Val() + if genericCmd, ok := cmd.(interface { + Val() interface{} + Err() error + }); ok { + return genericCmd.Val(), genericCmd.Err() } case CmdTypeString: - if stringCmd, ok := cmd.(interface{ Val() string }); ok { - return stringCmd.Val() + if stringCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return stringCmd.Val(), stringCmd.Err() } case CmdTypeInt: - if intCmd, ok := cmd.(interface{ Val() int64 }); ok { - return intCmd.Val() + if intCmd, ok := cmd.(interface { + Val() int64 + Err() error + }); ok { + return intCmd.Val(), intCmd.Err() } case CmdTypeBool: - if boolCmd, ok := cmd.(interface{ Val() bool }); ok { - return boolCmd.Val() + if boolCmd, ok := cmd.(interface { + Val() bool + Err() error + }); ok { + return boolCmd.Val(), boolCmd.Err() } case CmdTypeFloat: - if floatCmd, ok := cmd.(interface{ Val() float64 }); ok { - return floatCmd.Val() + if floatCmd, ok := cmd.(interface { + Val() float64 + Err() error + }); ok { + return floatCmd.Val(), floatCmd.Err() } case CmdTypeStatus: - if statusCmd, ok := cmd.(interface{ Val() string }); ok { - return statusCmd.Val() + if statusCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return statusCmd.Val(), statusCmd.Err() } case CmdTypeDuration: - if durationCmd, ok := cmd.(interface{ Val() time.Duration }); ok { - return durationCmd.Val() + if durationCmd, ok := cmd.(interface { + Val() time.Duration + Err() error + }); ok { + return durationCmd.Val(), durationCmd.Err() } case CmdTypeTime: - if timeCmd, ok := cmd.(interface{ Val() time.Time }); ok { - return timeCmd.Val() + if timeCmd, ok := cmd.(interface { + Val() time.Time + Err() error + }); ok { + return timeCmd.Val(), timeCmd.Err() } case CmdTypeStringStructMap: - if structMapCmd, ok := cmd.(interface{ Val() map[string]struct{} }); ok { - return structMapCmd.Val() + if structMapCmd, ok := cmd.(interface { + Val() map[string]struct{} + Err() error + }); ok { + return structMapCmd.Val(), structMapCmd.Err() } case CmdTypeXMessageSlice: - if xMessageSliceCmd, ok := cmd.(interface{ Val() []XMessage }); ok { - return xMessageSliceCmd.Val() + if xMessageSliceCmd, ok := cmd.(interface { + Val() []XMessage + Err() error + }); ok { + return xMessageSliceCmd.Val(), xMessageSliceCmd.Err() } case CmdTypeXStreamSlice: - if xStreamSliceCmd, ok := cmd.(interface{ Val() []XStream }); ok { - return xStreamSliceCmd.Val() + if xStreamSliceCmd, ok := cmd.(interface { + Val() []XStream + Err() error + }); ok { + return xStreamSliceCmd.Val(), xStreamSliceCmd.Err() } case CmdTypeXPending: - if xPendingCmd, ok := cmd.(interface{ Val() *XPending }); ok { - return xPendingCmd.Val() + if xPendingCmd, ok := cmd.(interface { + Val() *XPending + Err() error + }); ok { + return xPendingCmd.Val(), xPendingCmd.Err() } case CmdTypeXPendingExt: - if xPendingExtCmd, ok := cmd.(interface{ Val() []XPendingExt }); ok { - return xPendingExtCmd.Val() + if xPendingExtCmd, ok := cmd.(interface { + Val() []XPendingExt + Err() error + }); ok { + return xPendingExtCmd.Val(), xPendingExtCmd.Err() } case CmdTypeXAutoClaim: - if xAutoClaimCmd, ok := cmd.(interface{ Val() ([]XMessage, string) }); ok { + if xAutoClaimCmd, ok := cmd.(interface { + Val() ([]XMessage, string) + Err() error + }); ok { messages, start := xAutoClaimCmd.Val() - return CmdTypeXAutoClaimValue{messages: messages, start: start} + return CmdTypeXAutoClaimValue{messages: messages, start: start}, 
xAutoClaimCmd.Err() } case CmdTypeXAutoClaimJustID: - if xAutoClaimJustIDCmd, ok := cmd.(interface{ Val() ([]string, string) }); ok { + if xAutoClaimJustIDCmd, ok := cmd.(interface { + Val() ([]string, string) + Err() error + }); ok { ids, start := xAutoClaimJustIDCmd.Val() - return CmdTypeXAutoClaimJustIDValue{ids: ids, start: start} + return CmdTypeXAutoClaimJustIDValue{ids: ids, start: start}, xAutoClaimJustIDCmd.Err() } case CmdTypeXInfoConsumers: - if xInfoConsumersCmd, ok := cmd.(interface{ Val() []XInfoConsumer }); ok { - return xInfoConsumersCmd.Val() + if xInfoConsumersCmd, ok := cmd.(interface { + Val() []XInfoConsumer + Err() error + }); ok { + return xInfoConsumersCmd.Val(), xInfoConsumersCmd.Err() } case CmdTypeXInfoGroups: - if xInfoGroupsCmd, ok := cmd.(interface{ Val() []XInfoGroup }); ok { - return xInfoGroupsCmd.Val() + if xInfoGroupsCmd, ok := cmd.(interface { + Val() []XInfoGroup + Err() error + }); ok { + return xInfoGroupsCmd.Val(), xInfoGroupsCmd.Err() } case CmdTypeXInfoStream: - if xInfoStreamCmd, ok := cmd.(interface{ Val() *XInfoStream }); ok { - return xInfoStreamCmd.Val() + if xInfoStreamCmd, ok := cmd.(interface { + Val() *XInfoStream + Err() error + }); ok { + return xInfoStreamCmd.Val(), xInfoStreamCmd.Err() } case CmdTypeXInfoStreamFull: - if xInfoStreamFullCmd, ok := cmd.(interface{ Val() *XInfoStreamFull }); ok { - return xInfoStreamFullCmd.Val() + if xInfoStreamFullCmd, ok := cmd.(interface { + Val() *XInfoStreamFull + Err() error + }); ok { + return xInfoStreamFullCmd.Val(), xInfoStreamFullCmd.Err() } case CmdTypeZSlice: - if zSliceCmd, ok := cmd.(interface{ Val() []Z }); ok { - return zSliceCmd.Val() + if zSliceCmd, ok := cmd.(interface { + Val() []Z + Err() error + }); ok { + return zSliceCmd.Val(), zSliceCmd.Err() } case CmdTypeZWithKey: - if zWithKeyCmd, ok := cmd.(interface{ Val() *ZWithKey }); ok { - return zWithKeyCmd.Val() + if zWithKeyCmd, ok := cmd.(interface { + Val() *ZWithKey + Err() error + }); ok { + return zWithKeyCmd.Val(), zWithKeyCmd.Err() } case CmdTypeScan: - if scanCmd, ok := cmd.(interface{ Val() ([]string, uint64) }); ok { + if scanCmd, ok := cmd.(interface { + Val() ([]string, uint64) + Err() error + }); ok { keys, cursor := scanCmd.Val() - return CmdTypeScanValue{keys: keys, cursor: cursor} + return CmdTypeScanValue{keys: keys, cursor: cursor}, scanCmd.Err() } case CmdTypeClusterSlots: - if clusterSlotsCmd, ok := cmd.(interface{ Val() []ClusterSlot }); ok { - return clusterSlotsCmd.Val() + if clusterSlotsCmd, ok := cmd.(interface { + Val() []ClusterSlot + Err() error + }); ok { + return clusterSlotsCmd.Val(), clusterSlotsCmd.Err() } case CmdTypeGeoLocation: - if geoLocationCmd, ok := cmd.(interface{ Val() []GeoLocation }); ok { - return geoLocationCmd.Val() + if geoLocationCmd, ok := cmd.(interface { + Val() []GeoLocation + Err() error + }); ok { + return geoLocationCmd.Val(), geoLocationCmd.Err() } case CmdTypeGeoSearchLocation: - if geoSearchLocationCmd, ok := cmd.(interface{ Val() []GeoLocation }); ok { - return geoSearchLocationCmd.Val() + if geoSearchLocationCmd, ok := cmd.(interface { + Val() []GeoLocation + Err() error + }); ok { + return geoSearchLocationCmd.Val(), geoSearchLocationCmd.Err() } case CmdTypeGeoPos: - if geoPosCmd, ok := cmd.(interface{ Val() []*GeoPos }); ok { - return geoPosCmd.Val() + if geoPosCmd, ok := cmd.(interface { + Val() []*GeoPos + Err() error + }); ok { + return geoPosCmd.Val(), geoPosCmd.Err() } case CmdTypeCommandsInfo: if commandsInfoCmd, ok := cmd.(interface { Val() map[string]*CommandInfo 
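// The widening below repeats mechanically for every command type: each
// one-method assertion cmd.(interface{ Val() T }) grows an Err() error
// method, so the value and the shard's error leave the command in a single
// step. A minimal sketch of the shape (illustrative only, not a specific
// call site):
//
//	if c, ok := cmd.(interface {
//		Val() string
//		Err() error
//	}); ok {
//		return c.Val(), c.Err()
//	}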
+ Err() error }); ok { - return commandsInfoCmd.Val() + return commandsInfoCmd.Val(), commandsInfoCmd.Err() } case CmdTypeSlowLog: - if slowLogCmd, ok := cmd.(interface{ Val() []SlowLog }); ok { - return slowLogCmd.Val() + if slowLogCmd, ok := cmd.(interface { + Val() []SlowLog + Err() error + }); ok { + return slowLogCmd.Val(), slowLogCmd.Err() } case CmdTypeKeyValues: - if keyValuesCmd, ok := cmd.(interface{ Val() (string, []string) }); ok { + if keyValuesCmd, ok := cmd.(interface { + Val() (string, []string) + Err() error + }); ok { key, values := keyValuesCmd.Val() - return CmdTypeKeyValuesValue{key: key, values: values} + return CmdTypeKeyValuesValue{key: key, values: values}, keyValuesCmd.Err() } case CmdTypeZSliceWithKey: - if zSliceWithKeyCmd, ok := cmd.(interface{ Val() (string, []Z) }); ok { + if zSliceWithKeyCmd, ok := cmd.(interface { + Val() (string, []Z) + Err() error + }); ok { key, zSlice := zSliceWithKeyCmd.Val() - return CmdTypeZSliceWithKeyValue{key: key, zSlice: zSlice} + return CmdTypeZSliceWithKeyValue{key: key, zSlice: zSlice}, zSliceWithKeyCmd.Err() } case CmdTypeFunctionList: - if functionListCmd, ok := cmd.(interface{ Val() []Library }); ok { - return functionListCmd.Val() + if functionListCmd, ok := cmd.(interface { + Val() []Library + Err() error + }); ok { + return functionListCmd.Val(), functionListCmd.Err() } case CmdTypeFunctionStats: - if functionStatsCmd, ok := cmd.(interface{ Val() FunctionStats }); ok { - return functionStatsCmd.Val() + if functionStatsCmd, ok := cmd.(interface { + Val() FunctionStats + Err() error + }); ok { + return functionStatsCmd.Val(), functionStatsCmd.Err() } case CmdTypeLCS: - if lcsCmd, ok := cmd.(interface{ Val() *LCSMatch }); ok { - return lcsCmd.Val() + if lcsCmd, ok := cmd.(interface { + Val() *LCSMatch + Err() error + }); ok { + return lcsCmd.Val(), lcsCmd.Err() } case CmdTypeKeyFlags: - if keyFlagsCmd, ok := cmd.(interface{ Val() []KeyFlags }); ok { - return keyFlagsCmd.Val() + if keyFlagsCmd, ok := cmd.(interface { + Val() []KeyFlags + Err() error + }); ok { + return keyFlagsCmd.Val(), keyFlagsCmd.Err() } case CmdTypeClusterLinks: - if clusterLinksCmd, ok := cmd.(interface{ Val() []ClusterLink }); ok { - return clusterLinksCmd.Val() + if clusterLinksCmd, ok := cmd.(interface { + Val() []ClusterLink + Err() error + }); ok { + return clusterLinksCmd.Val(), clusterLinksCmd.Err() } case CmdTypeClusterShards: - if clusterShardsCmd, ok := cmd.(interface{ Val() []ClusterShard }); ok { - return clusterShardsCmd.Val() + if clusterShardsCmd, ok := cmd.(interface { + Val() []ClusterShard + Err() error + }); ok { + return clusterShardsCmd.Val(), clusterShardsCmd.Err() } case CmdTypeRankWithScore: - if rankWithScoreCmd, ok := cmd.(interface{ Val() RankScore }); ok { - return rankWithScoreCmd.Val() + if rankWithScoreCmd, ok := cmd.(interface { + Val() RankScore + Err() error + }); ok { + return rankWithScoreCmd.Val(), rankWithScoreCmd.Err() } case CmdTypeClientInfo: - if clientInfoCmd, ok := cmd.(interface{ Val() *ClientInfo }); ok { - return clientInfoCmd.Val() + if clientInfoCmd, ok := cmd.(interface { + Val() *ClientInfo + Err() error + }); ok { + return clientInfoCmd.Val(), clientInfoCmd.Err() } case CmdTypeACLLog: - if aclLogCmd, ok := cmd.(interface{ Val() []*ACLLogEntry }); ok { - return aclLogCmd.Val() + if aclLogCmd, ok := cmd.(interface { + Val() []*ACLLogEntry + Err() error + }); ok { + return aclLogCmd.Val(), aclLogCmd.Err() } case CmdTypeInfo: - if infoCmd, ok := cmd.(interface{ Val() string }); ok { - return infoCmd.Val() 
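// Commands whose Val() yields two values cannot travel through a single
// interface{}, so the extractor keeps wrapping them in small value structs
// (CmdTypeScanValue, CmdTypeXAutoClaimValue, CmdTypeKeyValuesValue, ...);
// only the error now rides along as the second return, e.g.:
//
//	keys, cursor := scanCmd.Val()
//	return CmdTypeScanValue{keys: keys, cursor: cursor}, scanCmd.Err()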
+ if infoCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return infoCmd.Val(), infoCmd.Err() } case CmdTypeMonitor: - if monitorCmd, ok := cmd.(interface{ Val() string }); ok { - return monitorCmd.Val() + if monitorCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return monitorCmd.Val(), monitorCmd.Err() } case CmdTypeJSON: - if jsonCmd, ok := cmd.(interface{ Val() string }); ok { - return jsonCmd.Val() + if jsonCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return jsonCmd.Val(), jsonCmd.Err() } case CmdTypeJSONSlice: - if jsonSliceCmd, ok := cmd.(interface{ Val() []interface{} }); ok { - return jsonSliceCmd.Val() + if jsonSliceCmd, ok := cmd.(interface { + Val() []interface{} + Err() error + }); ok { + return jsonSliceCmd.Val(), jsonSliceCmd.Err() } case CmdTypeIntPointerSlice: - if intPointerSliceCmd, ok := cmd.(interface{ Val() []*int64 }); ok { - return intPointerSliceCmd.Val() + if intPointerSliceCmd, ok := cmd.(interface { + Val() []*int64 + Err() error + }); ok { + return intPointerSliceCmd.Val(), intPointerSliceCmd.Err() } case CmdTypeScanDump: - if scanDumpCmd, ok := cmd.(interface{ Val() ScanDump }); ok { - return scanDumpCmd.Val() + if scanDumpCmd, ok := cmd.(interface { + Val() ScanDump + Err() error + }); ok { + return scanDumpCmd.Val(), scanDumpCmd.Err() } case CmdTypeBFInfo: - if bfInfoCmd, ok := cmd.(interface{ Val() BFInfo }); ok { - return bfInfoCmd.Val() + if bfInfoCmd, ok := cmd.(interface { + Val() BFInfo + Err() error + }); ok { + return bfInfoCmd.Val(), bfInfoCmd.Err() } case CmdTypeCFInfo: - if cfInfoCmd, ok := cmd.(interface{ Val() CFInfo }); ok { - return cfInfoCmd.Val() + if cfInfoCmd, ok := cmd.(interface { + Val() CFInfo + Err() error + }); ok { + return cfInfoCmd.Val(), cfInfoCmd.Err() } case CmdTypeCMSInfo: - if cmsInfoCmd, ok := cmd.(interface{ Val() CMSInfo }); ok { - return cmsInfoCmd.Val() + if cmsInfoCmd, ok := cmd.(interface { + Val() CMSInfo + Err() error + }); ok { + return cmsInfoCmd.Val(), cmsInfoCmd.Err() } case CmdTypeTopKInfo: - if topKInfoCmd, ok := cmd.(interface{ Val() TopKInfo }); ok { - return topKInfoCmd.Val() + if topKInfoCmd, ok := cmd.(interface { + Val() TopKInfo + Err() error + }); ok { + return topKInfoCmd.Val(), topKInfoCmd.Err() } case CmdTypeTDigestInfo: - if tDigestInfoCmd, ok := cmd.(interface{ Val() TDigestInfo }); ok { - return tDigestInfoCmd.Val() + if tDigestInfoCmd, ok := cmd.(interface { + Val() TDigestInfo + Err() error + }); ok { + return tDigestInfoCmd.Val(), tDigestInfoCmd.Err() } case CmdTypeFTSearch: - if ftSearchCmd, ok := cmd.(interface{ Val() FTSearchResult }); ok { - return ftSearchCmd.Val() + if ftSearchCmd, ok := cmd.(interface { + Val() FTSearchResult + Err() error + }); ok { + return ftSearchCmd.Val(), ftSearchCmd.Err() } case CmdTypeFTInfo: - if ftInfoCmd, ok := cmd.(interface{ Val() FTInfoResult }); ok { - return ftInfoCmd.Val() + if ftInfoCmd, ok := cmd.(interface { + Val() FTInfoResult + Err() error + }); ok { + return ftInfoCmd.Val(), ftInfoCmd.Err() } case CmdTypeFTSpellCheck: - if ftSpellCheckCmd, ok := cmd.(interface{ Val() []SpellCheckResult }); ok { - return ftSpellCheckCmd.Val() + if ftSpellCheckCmd, ok := cmd.(interface { + Val() []SpellCheckResult + Err() error + }); ok { + return ftSpellCheckCmd.Val(), ftSpellCheckCmd.Err() } case CmdTypeFTSynDump: - if ftSynDumpCmd, ok := cmd.(interface{ Val() []FTSynDumpResult }); ok { - return ftSynDumpCmd.Val() + if ftSynDumpCmd, ok := cmd.(interface { + Val() []FTSynDumpResult + Err() error + 
}); ok { + return ftSynDumpCmd.Val(), ftSynDumpCmd.Err() } case CmdTypeAggregate: - if aggregateCmd, ok := cmd.(interface{ Val() *FTAggregateResult }); ok { - return aggregateCmd.Val() + if aggregateCmd, ok := cmd.(interface { + Val() *FTAggregateResult + Err() error + }); ok { + return aggregateCmd.Val(), aggregateCmd.Err() } case CmdTypeTSTimestampValue: - if tsTimestampValueCmd, ok := cmd.(interface{ Val() TSTimestampValue }); ok { - return tsTimestampValueCmd.Val() + if tsTimestampValueCmd, ok := cmd.(interface { + Val() TSTimestampValue + Err() error + }); ok { + return tsTimestampValueCmd.Val(), tsTimestampValueCmd.Err() } case CmdTypeTSTimestampValueSlice: - if tsTimestampValueSliceCmd, ok := cmd.(interface{ Val() []TSTimestampValue }); ok { - return tsTimestampValueSliceCmd.Val() + if tsTimestampValueSliceCmd, ok := cmd.(interface { + Val() []TSTimestampValue + Err() error + }); ok { + return tsTimestampValueSliceCmd.Val(), tsTimestampValueSliceCmd.Err() } case CmdTypeStringSlice: - if stringSliceCmd, ok := cmd.(interface{ Val() []string }); ok { - return stringSliceCmd.Val() + if stringSliceCmd, ok := cmd.(interface { + Val() []string + Err() error + }); ok { + return stringSliceCmd.Val(), stringSliceCmd.Err() } case CmdTypeIntSlice: - if intSliceCmd, ok := cmd.(interface{ Val() []int64 }); ok { - return intSliceCmd.Val() + if intSliceCmd, ok := cmd.(interface { + Val() []int64 + Err() error + }); ok { + return intSliceCmd.Val(), intSliceCmd.Err() } case CmdTypeBoolSlice: - if boolSliceCmd, ok := cmd.(interface{ Val() []bool }); ok { - return boolSliceCmd.Val() + if boolSliceCmd, ok := cmd.(interface { + Val() []bool + Err() error + }); ok { + return boolSliceCmd.Val(), boolSliceCmd.Err() } case CmdTypeFloatSlice: - if floatSliceCmd, ok := cmd.(interface{ Val() []float64 }); ok { - return floatSliceCmd.Val() + if floatSliceCmd, ok := cmd.(interface { + Val() []float64 + Err() error + }); ok { + return floatSliceCmd.Val(), floatSliceCmd.Err() } case CmdTypeSlice: - if sliceCmd, ok := cmd.(interface{ Val() []interface{} }); ok { - return sliceCmd.Val() + if sliceCmd, ok := cmd.(interface { + Val() []interface{} + Err() error + }); ok { + return sliceCmd.Val(), sliceCmd.Err() } case CmdTypeKeyValueSlice: - if keyValueSliceCmd, ok := cmd.(interface{ Val() []KeyValue }); ok { - return keyValueSliceCmd.Val() + if keyValueSliceCmd, ok := cmd.(interface { + Val() []KeyValue + Err() error + }); ok { + return keyValueSliceCmd.Val(), keyValueSliceCmd.Err() } case CmdTypeMapStringString: - if mapCmd, ok := cmd.(interface{ Val() map[string]string }); ok { - return mapCmd.Val() + if mapCmd, ok := cmd.(interface { + Val() map[string]string + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() } case CmdTypeMapStringInt: - if mapCmd, ok := cmd.(interface{ Val() map[string]int64 }); ok { - return mapCmd.Val() + if mapCmd, ok := cmd.(interface { + Val() map[string]int64 + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() } case CmdTypeMapStringInterfaceSlice: if mapCmd, ok := cmd.(interface { Val() []map[string]interface{} + Err() error }); ok { - return mapCmd.Val() + return mapCmd.Val(), mapCmd.Err() } case CmdTypeMapStringInterface: - if mapCmd, ok := cmd.(interface{ Val() map[string]interface{} }); ok { - return mapCmd.Val() + if mapCmd, ok := cmd.(interface { + Val() map[string]interface{} + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() } case CmdTypeMapStringStringSlice: - if mapCmd, ok := cmd.(interface{ Val() []map[string]string }); ok { - return mapCmd.Val() + 
if mapCmd, ok := cmd.(interface { + Val() []map[string]string + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() } case CmdTypeMapMapStringInterface: if mapCmd, ok := cmd.(interface { Val() map[string]interface{} + Err() error }); ok { - return mapCmd.Val() + return mapCmd.Val(), mapCmd.Err() } default: // For unknown command types, return nil - return nil + return nil, nil } } // If we can't get the command type, return nil - return nil + return nil, nil } diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 7d7e9c1be0..8327cd8896 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -25,9 +25,9 @@ type ResponseAggregator interface { // AddWithKey processes a single shard response for a specific key (used by keyed aggregators). AddWithKey(key string, result interface{}, err error) error - BatchAdd(map[string]interface{}) error + BatchAdd(map[string]AggregatorResErr) error - BatchWithErrs([]AggregatorResErr) error + BatchSlice([]AggregatorResErr) error // Result returns the final aggregated result and any error. Result() (interface{}, error) @@ -102,9 +102,9 @@ func (a *AllSucceededAggregator) Add(result interface{}, err error) error { return nil } -func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}) error { +func (a *AllSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { - err := a.Add(res, nil) + err := a.Add(res.Result, res.Err) if err != nil { return err } @@ -113,7 +113,7 @@ func (a *AllSucceededAggregator) BatchAdd(results map[string]interface{}) error return nil } -func (a *AllSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *AllSucceededAggregator) BatchSlice(values []AggregatorResErr) error { for _, val := range values { err := a.Add(val.Result, val.Err) if err != nil { @@ -158,9 +158,9 @@ func (a *OneSucceededAggregator) Add(result interface{}, err error) error { return nil } -func (a *OneSucceededAggregator) BatchAdd(results map[string]interface{}) error { +func (a *OneSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { - err := a.Add(res, nil) + err := a.Add(res.Result, res.Err) if err != nil { return err } @@ -173,7 +173,7 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *OneSucceededAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *OneSucceededAggregator) BatchSlice(values []AggregatorResErr) error { for _, val := range values { err := a.Add(val.Result, val.Err) if err != nil { @@ -215,9 +215,9 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggSumAggregator) BatchAdd(results map[string]interface{}) error { +func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { - err := a.Add(res, nil) + err := a.Add(res.Result, res.Err) if err != nil { return err } @@ -230,7 +230,7 @@ func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggSumAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *AggSumAggregator) BatchSlice(values []AggregatorResErr) error { for _, val := range values { err := a.Add(val.Result, val.Err) if err != nil { @@ -273,9 +273,9 @@ func (a *AggMinAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggMinAggregator) BatchAdd(results 
map[string]interface{}) error { +func (a *AggMinAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { - err := a.Add(res, nil) + err := a.Add(res.Result, res.Err) if err != nil { return err } @@ -288,7 +288,7 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggMinAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *AggMinAggregator) BatchSlice(values []AggregatorResErr) error { for _, val := range values { err := a.Add(val.Result, val.Err) if err != nil { @@ -335,9 +335,9 @@ func (a *AggMaxAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggMaxAggregator) BatchAdd(results map[string]interface{}) error { +func (a *AggMaxAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { - err := a.Add(res, nil) + err := a.Add(res.Result, res.Err) if err != nil { return err } @@ -350,7 +350,7 @@ func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggMaxAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *AggMaxAggregator) BatchSlice(values []AggregatorResErr) error { for _, val := range values { err := a.Add(val.Result, val.Err) if err != nil { @@ -404,9 +404,9 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggLogicalAndAggregator) BatchAdd(results map[string]interface{}) error { +func (a *AggLogicalAndAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { - err := a.Add(res, nil) + err := a.Add(res.Result, res.Err) if err != nil { return err } @@ -419,7 +419,7 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalAndAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *AggLogicalAndAggregator) BatchSlice(values []AggregatorResErr) error { for _, val := range values { err := a.Add(val.Result, val.Err) if err != nil { @@ -472,9 +472,9 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { return nil } -func (a *AggLogicalOrAggregator) BatchAdd(results map[string]interface{}) error { +func (a *AggLogicalOrAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { - err := a.Add(res, nil) + err := a.Add(res.Result, res.Err) if err != nil { return err } @@ -487,7 +487,7 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalOrAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *AggLogicalOrAggregator) BatchSlice(values []AggregatorResErr) error { for _, val := range values { err := a.Add(val.Result, val.Err) if err != nil { @@ -572,9 +572,12 @@ func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { return a.add(result, err) } -func (a *DefaultKeylessAggregator) BatchAdd(results map[string]interface{}) error { +func (a *DefaultKeylessAggregator) BatchAdd(results map[string]AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + for _, res := range results { - err := a.add(res, nil) + err := a.add(res.Result, res.Err) if err != nil { return err } @@ -587,7 +590,7 @@ func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, er return a.Add(result, err) } -func (a *DefaultKeylessAggregator) BatchWithErrs(values []AggregatorResErr) error { 
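// The locking discipline restored by the earlier "fixed deadlocks" commit is
// preserved through the rename below: a batch method that already holds a.mu
// must call the unexported add, never the exported Add, because sync.Mutex is
// not reentrant. A sketch of the invariant, assuming Add acquires a.mu itself:
//
//	a.mu.Lock()
//	defer a.mu.Unlock()
//	for _, v := range values {
//		if err := a.add(v.Result, v.Err); err != nil { // lock already held
//			return err
//		}
//	}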
+func (a *DefaultKeylessAggregator) BatchSlice(values []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() @@ -645,12 +648,12 @@ func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { return a.add(result, err) } -func (a *DefaultKeyedAggregator) BatchAdd(results map[string]interface{}) error { +func (a *DefaultKeyedAggregator) BatchAdd(results map[string]AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() for _, res := range results { - err := a.add(res, nil) + err := a.add(res.Result, res.Err) if err != nil { return err } @@ -677,13 +680,13 @@ func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err return a.addWithKey(key, result, err) } -func (a *DefaultKeyedAggregator) BatchAddWithKeyOrder(results map[string]interface{}, keyOrder []string) error { +func (a *DefaultKeyedAggregator) BatchAddWithKeyOrder(results map[string]AggregatorResErr, keyOrder []string) error { a.mu.Lock() defer a.mu.Unlock() a.keyOrder = keyOrder for key, val := range results { - _ = a.addWithKey(key, val, nil) + _ = a.addWithKey(key, val.Result, val.Err) } return nil @@ -695,7 +698,7 @@ func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { a.keyOrder = keyOrder } -func (a *DefaultKeyedAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *DefaultKeyedAggregator) BatchSlice(values []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() @@ -757,12 +760,12 @@ func (a *SpecialAggregator) Add(result interface{}, err error) error { return a.add(result, err) } -func (a *SpecialAggregator) BatchAdd(results map[string]interface{}) error { +func (a *SpecialAggregator) BatchAdd(results map[string]AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() for _, res := range results { - err := a.add(res, nil) + err := a.add(res.Result, res.Err) if err != nil { return err } @@ -775,7 +778,7 @@ func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error return a.Add(result, err) } -func (a *SpecialAggregator) BatchWithErrs(values []AggregatorResErr) error { +func (a *SpecialAggregator) BatchSlice(values []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() diff --git a/osscluster_router.go b/osscluster_router.go index fe426125aa..295767d3b7 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -337,7 +337,7 @@ func (c *ClusterClient) executeParallel(ctx context.Context, cmd Cmder, nodes [] // aggregateMultiSlotResults aggregates results from multi-slot execution func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder, results <-chan slotResult, keyOrder []string, policy *routing.CommandPolicy) error { - keyedResults := make(map[string]interface{}) + keyedResults := make(map[string]routing.AggregatorResErr) var firstErr error for result := range results { @@ -349,43 +349,39 @@ func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder if strings.ToLower(cmd.Name()) == "mget" { if sliceCmd, ok := result.cmd.(*SliceCmd); ok { values := sliceCmd.Val() + err := sliceCmd.Err() if len(values) == len(result.keys) { for i, key := range result.keys { - keyedResults[key] = values[i] + keyedResults[key] = routing.AggregatorResErr{Result: values[i], Err: err} } } else { // Fallback: map all keys to the entire result for _, key := range result.keys { - keyedResults[key] = values + keyedResults[key] = routing.AggregatorResErr{Result: values, Err: err} } } } else { // Fallback for non-SliceCmd results - value := ExtractCommandValue(result.cmd) + value, err := 
ExtractCommandValue(result.cmd) for _, key := range result.keys { - keyedResults[key] = value + keyedResults[key] = routing.AggregatorResErr{Result: value, Err: err} } } } else { // For other commands, map each key to the entire result - value := ExtractCommandValue(result.cmd) + value, err := ExtractCommandValue(result.cmd) for _, key := range result.keys { - keyedResults[key] = value + keyedResults[key] = routing.AggregatorResErr{Result: value, Err: err} } } } } - if firstErr != nil { - cmd.SetErr(firstErr) - return firstErr - } - return c.aggregateKeyedValues(cmd, keyedResults, keyOrder, policy) } // aggregateKeyedValues aggregates individual key-value pairs while preserving key order -func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]interface{}, keyOrder []string, policy *routing.CommandPolicy) error { +func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]routing.AggregatorResErr, keyOrder []string, policy *routing.CommandPolicy) error { if len(keyedResults) == 0 { return fmt.Errorf("redis: no results to aggregate") } @@ -421,7 +417,7 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout cmd.SetErr(err) return err } - value := ExtractCommandValue(shardCmd) + value, _ := ExtractCommandValue(shardCmd) return c.setCommandValue(cmd, value) } @@ -430,14 +426,14 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout batchWithErrs := []routing.AggregatorResErr{} // Add all results to aggregator for _, shardCmd := range cmds { - value := ExtractCommandValue(shardCmd) + value, err := ExtractCommandValue(shardCmd) batchWithErrs = append(batchWithErrs, routing.AggregatorResErr{ Result: value, - Err: shardCmd.Err(), + Err: err, }) } - err := aggregator.BatchWithErrs(batchWithErrs) + err := aggregator.BatchSlice(batchWithErrs) if err != nil { return err } From 4801633ee2725903a99b1a1d9ea695e6d391e241 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Wed, 15 Oct 2025 12:55:38 +0300 Subject: [PATCH 36/62] added preemptive return to the aggregators --- internal/routing/aggregator.go | 151 ++++++++++++++++++++++++++------- 1 file changed, 119 insertions(+), 32 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 8327cd8896..19a6244fa0 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -108,17 +108,25 @@ func (a *AllSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) e if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil } -func (a *AllSucceededAggregator) BatchSlice(values []AggregatorResErr) error { - for _, val := range values { - err := a.Add(val.Result, val.Err) +func (a *AllSucceededAggregator) BatchSlice(result []AggregatorResErr) error { + for _, res := range result { + err := a.Add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -164,6 +172,10 @@ func (a *OneSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) e if err != nil { return err } + + if res.Err == nil { + return nil + } } return nil @@ -173,12 +185,16 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *OneSucceededAggregator) BatchSlice(values []AggregatorResErr) error { - for _, val := range values { - err := a.Add(val.Result, val.Err) +func (a *OneSucceededAggregator) BatchSlice(result []AggregatorResErr) error { + for _, res := range result { + err := 
a.Add(res.Result, res.Err) if err != nil { return err } + + if res.Err == nil { + return nil + } } return nil @@ -221,6 +237,10 @@ func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error { if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -230,12 +250,16 @@ func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggSumAggregator) BatchSlice(values []AggregatorResErr) error { - for _, val := range values { - err := a.Add(val.Result, val.Err) +func (a *AggSumAggregator) BatchSlice(result []AggregatorResErr) error { + for _, res := range result { + err := a.Add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -279,6 +303,10 @@ func (a *AggMinAggregator) BatchAdd(results map[string]AggregatorResErr) error { if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -288,12 +316,16 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggMinAggregator) BatchSlice(values []AggregatorResErr) error { - for _, val := range values { - err := a.Add(val.Result, val.Err) +func (a *AggMinAggregator) BatchSlice(result []AggregatorResErr) error { + for _, res := range result { + err := a.Add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -341,6 +373,10 @@ func (a *AggMaxAggregator) BatchAdd(results map[string]AggregatorResErr) error { if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -350,12 +386,16 @@ func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) return a.Add(result, err) } -func (a *AggMaxAggregator) BatchSlice(values []AggregatorResErr) error { - for _, val := range values { - err := a.Add(val.Result, val.Err) +func (a *AggMaxAggregator) BatchSlice(result []AggregatorResErr) error { + for _, res := range result { + err := a.Add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -410,6 +450,10 @@ func (a *AggLogicalAndAggregator) BatchAdd(results map[string]AggregatorResErr) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -419,12 +463,16 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalAndAggregator) BatchSlice(values []AggregatorResErr) error { - for _, val := range values { - err := a.Add(val.Result, val.Err) +func (a *AggLogicalAndAggregator) BatchSlice(result []AggregatorResErr) error { + for _, res := range result { + err := a.Add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -478,6 +526,10 @@ func (a *AggLogicalOrAggregator) BatchAdd(results map[string]AggregatorResErr) e if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -487,12 +539,16 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalOrAggregator) BatchSlice(values []AggregatorResErr) error { - for _, val := range values { - err := a.Add(val.Result, val.Err) +func (a *AggLogicalOrAggregator) BatchSlice(result []AggregatorResErr) error { + for _, res := range result { + err := a.Add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -581,6 
+637,10 @@ func (a *DefaultKeylessAggregator) BatchAdd(results map[string]AggregatorResErr) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -590,15 +650,19 @@ func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, er return a.Add(result, err) } -func (a *DefaultKeylessAggregator) BatchSlice(values []AggregatorResErr) error { +func (a *DefaultKeylessAggregator) BatchSlice(result []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() - for _, val := range values { - err := a.add(val.Result, val.Err) + for _, res := range result { + err := a.add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -657,6 +721,10 @@ func (a *DefaultKeyedAggregator) BatchAdd(results map[string]AggregatorResErr) e if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -685,8 +753,15 @@ func (a *DefaultKeyedAggregator) BatchAddWithKeyOrder(results map[string]Aggrega defer a.mu.Unlock() a.keyOrder = keyOrder - for key, val := range results { - _ = a.addWithKey(key, val.Result, val.Err) + for key, res := range results { + err := a.addWithKey(key, res.Result, res.Err) + if err != nil { + return nil + } + + if res.Err != nil { + return nil + } } return nil @@ -698,15 +773,19 @@ func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { a.keyOrder = keyOrder } -func (a *DefaultKeyedAggregator) BatchSlice(values []AggregatorResErr) error { +func (a *DefaultKeyedAggregator) BatchSlice(result []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() - for _, val := range values { - err := a.add(val.Result, val.Err) + for _, res := range result { + err := a.add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -769,6 +848,10 @@ func (a *SpecialAggregator) BatchAdd(results map[string]AggregatorResErr) error if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil @@ -778,15 +861,19 @@ func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error return a.Add(result, err) } -func (a *SpecialAggregator) BatchSlice(values []AggregatorResErr) error { +func (a *SpecialAggregator) BatchSlice(result []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() - for _, val := range values { - err := a.add(val.Result, val.Err) + for _, res := range result { + err := a.add(res.Result, res.Err) if err != nil { return err } + + if res.Err != nil { + return nil + } } return nil From 4143e5342d45b183e8c878b8420d4cefa9c8caae Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Wed, 15 Oct 2025 14:31:32 +0300 Subject: [PATCH 37/62] more work on the aggregators --- internal/routing/aggregator.go | 151 ++++++++++++++++++++++----------- 1 file changed, 100 insertions(+), 51 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 19a6244fa0..a11a8cd30c 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -117,8 +117,8 @@ func (a *AllSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) e return nil } -func (a *AllSucceededAggregator) BatchSlice(result []AggregatorResErr) error { - for _, res := range result { +func (a *AllSucceededAggregator) BatchSlice(results []AggregatorResErr) error { + for _, res := range results { err := a.Add(res.Result, res.Err) if err != nil { return err @@ -185,8 +185,8 @@ func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) 
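// The preemptive returns added throughout this patch share one shape: a
// non-nil shard error is handed to Add/add, which records it for Result(),
// and the loop then exits with nil because the aggregate outcome is already
// decided (OneSucceededAggregator inverts the test and stops on the first
// success instead):
//
//	for _, res := range results {
//		if err := a.Add(res.Result, res.Err); err != nil {
//			return err // aggregation itself failed
//		}
//		if res.Err != nil {
//			return nil // shard error recorded; remaining replies are moot
//		}
//	}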
} -func (a *OneSucceededAggregator) BatchSlice(result []AggregatorResErr) error { - for _, res := range result { +func (a *OneSucceededAggregator) BatchSlice(results []AggregatorResErr) error { + for _, res := range results { err := a.Add(res.Result, res.Err) if err != nil { return err @@ -202,7 +202,7 @@ func (a *OneSucceededAggregator) BatchSlice(result []AggregatorResErr) error { func (a *OneSucceededAggregator) Result() (interface{}, error) { res, e := a.res.Load(), a.err.Load() - if res != nil { + if res == nil { return nil, e.(error) } @@ -223,6 +223,7 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { if result != nil { val, err := toInt64(result) if err != nil { + a.err.CompareAndSwap(nil, err) return err } atomic.AddInt64(a.res, val) @@ -232,37 +233,49 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { } func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error { + var sum int64 + for _, res := range results { - err := a.Add(res.Result, res.Err) - if err != nil { - return err + if res.Err != nil { + a.Add(res.Result, res.Err) + return nil } - if res.Err != nil { + intRes, err := toInt64(res) + if err != nil { + a.Add(nil, err) return nil } + + sum += intRes } - return nil + return a.Add(sum, nil) } func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } -func (a *AggSumAggregator) BatchSlice(result []AggregatorResErr) error { - for _, res := range result { - err := a.Add(res.Result, res.Err) - if err != nil { - return err - } +func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error { + var sum int64 + for _, res := range results { if res.Err != nil { + a.Add(res.Result, res.Err) return nil } + + intRes, err := toInt64(res) + if err != nil { + a.Add(nil, err) + return nil + } + + sum += intRes } - return nil + return a.Add(sum, nil) } func (a *AggSumAggregator) Result() (interface{}, error) { @@ -298,37 +311,55 @@ func (a *AggMinAggregator) Add(result interface{}, err error) error { } func (a *AggMinAggregator) BatchAdd(results map[string]AggregatorResErr) error { + min := int64(math.MaxInt64) + for _, res := range results { - err := a.Add(res.Result, res.Err) - if err != nil { - return err + if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil } - if res.Err != nil { + resInt, err := toInt64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) return nil } + + if resInt < min { + min = resInt + } + } - return nil + return a.Add(min, nil) } func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } -func (a *AggMinAggregator) BatchSlice(result []AggregatorResErr) error { - for _, res := range result { - err := a.Add(res.Result, res.Err) - if err != nil { - return err - } +func (a *AggMinAggregator) BatchSlice(results []AggregatorResErr) error { + min := int64(math.MaxInt64) + for _, res := range results { if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + resInt, err := toInt64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) return nil } + + if resInt < min { + min = resInt + } + } - return nil + return a.Add(min, nil) } func (a *AggMinAggregator) Result() (interface{}, error) { @@ -368,37 +399,55 @@ func (a *AggMaxAggregator) Add(result interface{}, err error) error { } func (a *AggMaxAggregator) BatchAdd(results map[string]AggregatorResErr) error { + max := int64(math.MinInt64) + for _, res := range results { - err := a.Add(res.Result, res.Err) - if 
err != nil { - return err + if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil } - if res.Err != nil { + resInt, err := toInt64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) return nil } + + if resInt > max { + max = resInt + } + } - return nil + return a.Add(max, nil) } func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) error { return a.Add(result, err) } -func (a *AggMaxAggregator) BatchSlice(result []AggregatorResErr) error { - for _, res := range result { - err := a.Add(res.Result, res.Err) - if err != nil { - return err - } +func (a *AggMaxAggregator) BatchSlice(results []AggregatorResErr) error { + max := int64(math.MinInt64) + for _, res := range results { if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + resInt, err := toInt64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) return nil } + + if resInt > max { + max = resInt + } + } - return nil + return a.Add(max, nil) } func (a *AggMaxAggregator) Result() (interface{}, error) { @@ -463,8 +512,8 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalAndAggregator) BatchSlice(result []AggregatorResErr) error { - for _, res := range result { +func (a *AggLogicalAndAggregator) BatchSlice(results []AggregatorResErr) error { + for _, res := range results { err := a.Add(res.Result, res.Err) if err != nil { return err @@ -539,8 +588,8 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err return a.Add(result, err) } -func (a *AggLogicalOrAggregator) BatchSlice(result []AggregatorResErr) error { - for _, res := range result { +func (a *AggLogicalOrAggregator) BatchSlice(results []AggregatorResErr) error { + for _, res := range results { err := a.Add(res.Result, res.Err) if err != nil { return err @@ -650,11 +699,11 @@ func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, er return a.Add(result, err) } -func (a *DefaultKeylessAggregator) BatchSlice(result []AggregatorResErr) error { +func (a *DefaultKeylessAggregator) BatchSlice(results []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() - for _, res := range result { + for _, res := range results { err := a.add(res.Result, res.Err) if err != nil { return err @@ -773,11 +822,11 @@ func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { a.keyOrder = keyOrder } -func (a *DefaultKeyedAggregator) BatchSlice(result []AggregatorResErr) error { +func (a *DefaultKeyedAggregator) BatchSlice(results []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() - for _, res := range result { + for _, res := range results { err := a.add(res.Result, res.Err) if err != nil { return err @@ -861,11 +910,11 @@ func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error return a.Add(result, err) } -func (a *SpecialAggregator) BatchSlice(result []AggregatorResErr) error { +func (a *SpecialAggregator) BatchSlice(results []AggregatorResErr) error { a.mu.Lock() defer a.mu.Unlock() - for _, res := range result { + for _, res := range results { err := a.add(res.Result, res.Err) if err != nil { return err From 212619b015bfe316ac2d1820f72de50c93817306 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Thu, 16 Oct 2025 13:02:10 +0300 Subject: [PATCH 38/62] updated and and or aggregators --- internal/routing/aggregator.go | 64 +++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/internal/routing/aggregator.go 
b/internal/routing/aggregator.go index a11a8cd30c..948f66952d 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -494,18 +494,22 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { } func (a *AggLogicalAndAggregator) BatchAdd(results map[string]AggregatorResErr) error { + var result bool = true + for _, res := range results { - err := a.Add(res.Result, res.Err) - if err != nil { - return err + if res.Err != nil { + return a.Add(nil, res.Err) } - if res.Err != nil { - return nil + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) } + + result = result && boolRes } - return nil + return a.Add(result, nil) } func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err error) error { @@ -513,18 +517,22 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err } func (a *AggLogicalAndAggregator) BatchSlice(results []AggregatorResErr) error { + var result bool = true + for _, res := range results { - err := a.Add(res.Result, res.Err) - if err != nil { - return err + if res.Err != nil { + return a.Add(nil, res.Err) } - if res.Err != nil { - return nil + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) } + + result = result && boolRes } - return nil + return a.Add(result, nil) } func (a *AggLogicalAndAggregator) Result() (interface{}, error) { @@ -570,18 +578,22 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { } func (a *AggLogicalOrAggregator) BatchAdd(results map[string]AggregatorResErr) error { + var result bool = false + for _, res := range results { - err := a.Add(res.Result, res.Err) - if err != nil { - return err + if res.Err != nil { + return a.Add(nil, res.Err) } - if res.Err != nil { - return nil + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) } + + result = result || boolRes } - return nil + return a.Add(result, nil) } func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err error) error { @@ -589,18 +601,22 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err } func (a *AggLogicalOrAggregator) BatchSlice(results []AggregatorResErr) error { + var result bool = false + for _, res := range results { - err := a.Add(res.Result, res.Err) - if err != nil { - return err + if res.Err != nil { + return a.Add(nil, res.Err) } - if res.Err != nil { - return nil + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) } + + result = result || boolRes } - return nil + return a.Add(result, nil) } func (a *AggLogicalOrAggregator) Result() (interface{}, error) { From 024339a51a016dd61c288cf71d9849cfa37f51cf Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Thu, 16 Oct 2025 13:59:05 +0300 Subject: [PATCH 39/62] fixed lint --- internal/routing/aggregator.go | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 948f66952d..ac0f1eb9c0 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -237,14 +237,12 @@ func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error { for _, res := range results { if res.Err != nil { - a.Add(res.Result, res.Err) - return nil + return a.Add(res.Result, res.Err) } intRes, err := toInt64(res) if err != nil { - a.Add(nil, err) - return nil + return a.Add(nil, err) } sum += intRes @@ -262,14 +260,12 @@ func (a *AggSumAggregator) 
BatchSlice(results []AggregatorResErr) error { for _, res := range results { if res.Err != nil { - a.Add(res.Result, res.Err) - return nil + return a.Add(res.Result, res.Err) } intRes, err := toInt64(res) if err != nil { - a.Add(nil, err) - return nil + return a.Add(nil, err) } sum += intRes @@ -494,7 +490,7 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { } func (a *AggLogicalAndAggregator) BatchAdd(results map[string]AggregatorResErr) error { - var result bool = true + result := true for _, res := range results { if res.Err != nil { @@ -517,7 +513,7 @@ func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err } func (a *AggLogicalAndAggregator) BatchSlice(results []AggregatorResErr) error { - var result bool = true + result := true for _, res := range results { if res.Err != nil { @@ -578,7 +574,7 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { } func (a *AggLogicalOrAggregator) BatchAdd(results map[string]AggregatorResErr) error { - var result bool = false + result := false for _, res := range results { if res.Err != nil { @@ -601,7 +597,7 @@ func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err } func (a *AggLogicalOrAggregator) BatchSlice(results []AggregatorResErr) error { - var result bool = false + result := false for _, res := range results { if res.Err != nil { From f4cb0f58e61a8355802d0a4773da657322b8965b Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Fri, 17 Oct 2025 13:26:37 +0300 Subject: [PATCH 40/62] added configurable policy resolvers --- command_policy_resolver.go | 165 +++++++++++++++++++++++++++++++++ internal/routing/aggregator.go | 14 +-- osscluster.go | 48 ++++++++-- osscluster_router.go | 10 +- osscluster_test.go | 3 +- 5 files changed, 217 insertions(+), 23 deletions(-) create mode 100644 command_policy_resolver.go diff --git a/command_policy_resolver.go b/command_policy_resolver.go new file mode 100644 index 0000000000..0e945a14e3 --- /dev/null +++ b/command_policy_resolver.go @@ -0,0 +1,165 @@ +package redis + +import ( + "context" + "strings" + + "github.com/redis/go-redis/v9/internal/routing" +) + +type ( + module = string + commandName = string +) + +var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ + "ft": { + "create": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "search": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aggregate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dictadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dictdump": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dictdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "suglen": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "cursor": { + Request: routing.ReqSpecial, + Response: routing.RespDefaultKeyless, + }, + "sugadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "sugget": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "sugdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "spellcheck": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "explain": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "explaincli": { + Request: routing.ReqDefault, + 
Response: routing.RespDefaultKeyless, + }, + "aliasadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aliasupdate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aliasdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "info": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "tagvals": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "syndump": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "synupdate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "profile": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "alter": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dropindex": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "drop": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + }, +} + +type CommandInfoResolver interface { + getCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy + setFallbackResolver(fallback CommandInfoResolver) +} + +type resolver struct { + resolve func(ctx context.Context, cmdName string) *routing.CommandPolicy + fallBackResolver CommandInfoResolver +} + +func NewDefaultCommandPolicyResolver() *resolver { + return &resolver{ + resolve: func(ctx context.Context, cmdName string) *routing.CommandPolicy { + module := "core" + command := cmdName + cmdParts := strings.Split(cmdName, ".") + if len(cmdParts) == 2 { + module = cmdParts[0] + command = cmdParts[1] + } + + if policy, ok := defaultPolicies[module][command]; ok { + return policy + } + + return nil + }, + } +} + +func (r *resolver) getCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy { + policy := r.resolve(ctx, cmdName) + if policy != nil { + return policy + } + + if r.fallBackResolver != nil { + return r.fallBackResolver.getCommandPolicy(ctx, cmdName) + } + + return nil +} + +func (r *resolver) setFallbackResolver(fallbackResolver CommandInfoResolver) { + r.fallBackResolver = fallbackResolver +} diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index ac0f1eb9c0..5022e8a707 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -50,7 +50,9 @@ func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggreg case RespOneSucceeded: return &OneSucceededAggregator{} case RespAggSum: - return &AggSumAggregator{} + return &AggSumAggregator{ + // res: + } case RespAggMin: return &AggMinAggregator{ res: util.NewAtomicMin(), @@ -212,7 +214,7 @@ func (a *OneSucceededAggregator) Result() (interface{}, error) { // AggSumAggregator sums numeric replies from all shards. 
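
The resolver added above in command_policy_resolver.go is a small chain of responsibility: resolve either answers or returns nil, and nil defers to fallBackResolver. A minimal sketch of composing one at this stage of the series, when the methods are still unexported (module-internal only, since routing lives under internal/; ftOnly is a hypothetical resolver, not part of the patch):

ctx := context.Background()
ftOnly := &resolver{
	resolve: func(ctx context.Context, cmdName string) *routing.CommandPolicy {
		// Answer only for RediSearch commands; nil defers to the fallback.
		if !strings.HasPrefix(cmdName, "ft.") {
			return nil
		}
		// Missing entries yield a nil *CommandPolicy, which also falls through.
		return defaultPolicies["ft"][strings.TrimPrefix(cmdName, "ft.")]
	},
}
ftOnly.setFallbackResolver(NewDefaultCommandPolicyResolver())
policy := ftOnly.getCommandPolicy(ctx, "ft.search") // resolved by ftOnly's own table
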
type AggSumAggregator struct { err atomic.Value - res *int64 + res int64 } func (a *AggSumAggregator) Add(result interface{}, err error) error { @@ -226,7 +228,7 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { a.err.CompareAndSwap(nil, err) return err } - atomic.AddInt64(a.res, val) + atomic.AddInt64(&a.res, val) } return nil @@ -240,7 +242,7 @@ func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error { return a.Add(res.Result, res.Err) } - intRes, err := toInt64(res) + intRes, err := toInt64(res.Result) if err != nil { return a.Add(nil, err) } @@ -263,7 +265,7 @@ func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error { return a.Add(res.Result, res.Err) } - intRes, err := toInt64(res) + intRes, err := toInt64(res.Result) if err != nil { return a.Add(nil, err) } @@ -275,7 +277,7 @@ func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error { } func (a *AggSumAggregator) Result() (interface{}, error) { - res, err := atomic.LoadInt64(a.res), a.err.Load() + res, err := atomic.LoadInt64(&a.res), a.err.Load() if err != nil { return nil, err.(error) } diff --git a/osscluster.go b/osscluster.go index 2ceedad045..1ec9c471c8 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1013,10 +1013,11 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er // or more underlying connections. It's safe for concurrent use by // multiple goroutines. type ClusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder - cmdsInfoCache *cmdsInfoCache + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache + cmdInfoResolver CommandInfoResolver cmdable hooksMixin } @@ -1034,6 +1035,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.state = newClusterStateHolder(c.loadState) + + c.cmdInfoResolver = c.NewDynamicResolver() + c.cmdable = c.Process c.initHooks(hooks{ dial: nil, @@ -1419,7 +1423,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - policy := c.getCommandPolicy(ctx, cmd) + policy := c.extractCommandInfo(ctx, cmd.Name()) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -1436,7 +1440,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - policy := c.getCommandPolicy(ctx, cmd) + policy := c.extractCommandInfo(ctx, cmd.Name()) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -2193,6 +2197,38 @@ func (c *ClusterClient) context(ctx context.Context) context.Context { return context.Background() } +func (c *ClusterClient) GetResolver() CommandInfoResolver { + return c.cmdInfoResolver +} + +func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver CommandInfoResolver) { + c.cmdInfoResolver = cmdInfoResolver +} + +// extractCommandInfo retrieves the routing policy for a command +func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmdName string) *routing.CommandPolicy { + if cmdInfo := c.cmdInfo(ctx, cmdName); cmdInfo 
!= nil && cmdInfo.CommandPolicy != nil { + return cmdInfo.CommandPolicy + } + + return nil +} + +func (c *ClusterClient) NewDynamicResolver() CommandInfoResolver { + return &resolver{ + resolve: c.extractCommandInfo, + } +} + +func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { + for _, n := range nodes { + if n == node { + return nodes + } + } + return append(nodes, node) +} + func appendIfNotExist[T comparable](vals []T, newVal T) []T { for _, v := range vals { if v == newVal { diff --git a/osscluster_router.go b/osscluster_router.go index 295767d3b7..342f904c03 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -21,7 +21,7 @@ type slotResult struct { // routeAndRun routes a command to the appropriate cluster nodes and executes it func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { - policy := c.getCommandPolicy(ctx, cmd) + policy := c.cmdInfoResolver.getCommandPolicy(ctx, cmd.Name()) if policy == nil { return c.executeDefault(ctx, cmd, node) } @@ -39,14 +39,6 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste } } -// getCommandPolicy retrieves the routing policy for a command -func (c *ClusterClient) getCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { - if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.CommandPolicy != nil { - return cmdInfo.CommandPolicy - } - return nil -} - // executeDefault handles standard command routing based on keys func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, node *clusterNode) error { if c.hasKeys(cmd) { diff --git a/osscluster_test.go b/osscluster_test.go index fc2a3be429..62425463a5 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -63,7 +63,7 @@ func (s *clusterScenario) newClusterClient( ctx context.Context, opt *redis.ClusterOptions, ) *redis.ClusterClient { client := s.newClusterClientUnstable(opt) - + client.SetCommandInfoResolver(client.NewDynamicResolver()) err := eventually(func() error { if opt.ClusterSlots != nil { return nil @@ -1360,7 +1360,6 @@ var _ = Describe("ClusterClient", func() { return slots, nil } client = cluster.newClusterClient(ctx, opt) - err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { return master.FlushDB(ctx).Err() }) From 20392b339b773776ed8a5e531cf0f29775336185 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Fri, 17 Oct 2025 14:25:33 +0300 Subject: [PATCH 41/62] slight refactor --- command_policy_resolver.go | 16 ++++++++-------- osscluster.go | 2 +- osscluster_router.go | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/command_policy_resolver.go b/command_policy_resolver.go index 0e945a14e3..733de92b06 100644 --- a/command_policy_resolver.go +++ b/command_policy_resolver.go @@ -118,17 +118,17 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ } type CommandInfoResolver interface { - getCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy - setFallbackResolver(fallback CommandInfoResolver) + GetCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy + SetFallbackResolver(fallback CommandInfoResolver) } -type resolver struct { +type internalCommandInfoResolver struct { resolve func(ctx context.Context, cmdName string) *routing.CommandPolicy fallBackResolver CommandInfoResolver } -func NewDefaultCommandPolicyResolver() *resolver { - return &resolver{ +func NewDefaultCommandPolicyResolver() 
*internalCommandInfoResolver { + return &internalCommandInfoResolver{ resolve: func(ctx context.Context, cmdName string) *routing.CommandPolicy { module := "core" command := cmdName @@ -147,19 +147,19 @@ func NewDefaultCommandPolicyResolver() *resolver { } } -func (r *resolver) getCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy { +func (r *internalCommandInfoResolver) GetCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy { policy := r.resolve(ctx, cmdName) if policy != nil { return policy } if r.fallBackResolver != nil { - return r.fallBackResolver.getCommandPolicy(ctx, cmdName) + return r.fallBackResolver.GetCommandPolicy(ctx, cmdName) } return nil } -func (r *resolver) setFallbackResolver(fallbackResolver CommandInfoResolver) { +func (r *internalCommandInfoResolver) SetFallbackResolver(fallbackResolver CommandInfoResolver) { r.fallBackResolver = fallbackResolver } diff --git a/osscluster.go b/osscluster.go index 1ec9c471c8..6300e0abbe 100644 --- a/osscluster.go +++ b/osscluster.go @@ -2215,7 +2215,7 @@ func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmdName string) } func (c *ClusterClient) NewDynamicResolver() CommandInfoResolver { - return &resolver{ + return &internalCommandInfoResolver{ resolve: c.extractCommandInfo, } } diff --git a/osscluster_router.go b/osscluster_router.go index 342f904c03..0a0a37307e 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -21,7 +21,7 @@ type slotResult struct { // routeAndRun routes a command to the appropriate cluster nodes and executes it func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { - policy := c.cmdInfoResolver.getCommandPolicy(ctx, cmd.Name()) + policy := c.cmdInfoResolver.GetCommandPolicy(ctx, cmd.Name()) if policy == nil { return c.executeDefault(ctx, cmd, node) } From 1ad51cd944ef43740082598c50835a69a9d984ef Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Fri, 17 Oct 2025 15:01:20 +0300 Subject: [PATCH 42/62] removed the interface, slight refactor --- command_policy_resolver.go | 51 ++++++++++++++++++++------------------ osscluster.go | 16 +++++++----- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/command_policy_resolver.go b/command_policy_resolver.go index 733de92b06..3cd1a8449c 100644 --- a/command_policy_resolver.go +++ b/command_policy_resolver.go @@ -117,37 +117,40 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ }, } -type CommandInfoResolver interface { - GetCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy - SetFallbackResolver(fallback CommandInfoResolver) +type CommandInfoResolver struct { + resolve func(ctx context.Context, cmdName string) *routing.CommandPolicy + fallBackResolver *CommandInfoResolver } -type internalCommandInfoResolver struct { - resolve func(ctx context.Context, cmdName string) *routing.CommandPolicy - fallBackResolver CommandInfoResolver +func NewCommandInfoResolver(resolver func(ctx context.Context, cmdName string) *routing.CommandPolicy) *CommandInfoResolver { + return &CommandInfoResolver{ + resolve: resolver, + } } -func NewDefaultCommandPolicyResolver() *internalCommandInfoResolver { - return &internalCommandInfoResolver{ - resolve: func(ctx context.Context, cmdName string) *routing.CommandPolicy { - module := "core" - command := cmdName - cmdParts := strings.Split(cmdName, ".") - if len(cmdParts) == 2 { - module = cmdParts[0] - command = cmdParts[1] - } +func NewDefaultCommandPolicyResolver() 
*CommandInfoResolver { + return NewCommandInfoResolver(func(ctx context.Context, cmdName string) *routing.CommandPolicy { + module := "core" + command := cmdName + cmdParts := strings.Split(cmdName, ".") + if len(cmdParts) == 2 { + module = cmdParts[0] + command = cmdParts[1] + } - if policy, ok := defaultPolicies[module][command]; ok { - return policy - } + if policy, ok := defaultPolicies[module][command]; ok { + return policy + } - return nil - }, - } + return nil + }) } -func (r *internalCommandInfoResolver) GetCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy { +func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy { + if r.resolve == nil { + return nil + } + policy := r.resolve(ctx, cmdName) if policy != nil { return policy @@ -160,6 +163,6 @@ func (r *internalCommandInfoResolver) GetCommandPolicy(ctx context.Context, cmdN return nil } -func (r *internalCommandInfoResolver) SetFallbackResolver(fallbackResolver CommandInfoResolver) { +func (r *CommandInfoResolver) SetFallbackResolver(fallbackResolver *CommandInfoResolver) { r.fallBackResolver = fallbackResolver } diff --git a/osscluster.go b/osscluster.go index 6300e0abbe..587f2173ee 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1017,7 +1017,7 @@ type ClusterClient struct { nodes *clusterNodes state *clusterStateHolder cmdsInfoCache *cmdsInfoCache - cmdInfoResolver CommandInfoResolver + cmdInfoResolver *CommandInfoResolver cmdable hooksMixin } @@ -1036,7 +1036,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.state = newClusterStateHolder(c.loadState) - c.cmdInfoResolver = c.NewDynamicResolver() + dynamicResolver := c.NewDynamicResolver() + dynamicResolver.SetFallbackResolver(NewDefaultCommandPolicyResolver()) + c.SetCommandInfoResolver(dynamicResolver) c.cmdable = c.Process c.initHooks(hooks{ @@ -2197,11 +2199,11 @@ func (c *ClusterClient) context(ctx context.Context) context.Context { return context.Background() } -func (c *ClusterClient) GetResolver() CommandInfoResolver { +func (c *ClusterClient) GetResolver() *CommandInfoResolver { return c.cmdInfoResolver } -func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver CommandInfoResolver) { +func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *CommandInfoResolver) { c.cmdInfoResolver = cmdInfoResolver } @@ -2214,8 +2216,10 @@ func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmdName string) return nil } -func (c *ClusterClient) NewDynamicResolver() CommandInfoResolver { - return &internalCommandInfoResolver{ +// NewDynamicResolver returns a CommandInfoResolver +// that uses the underlying cmdInfo cache to resolve the policies +func (c *ClusterClient) NewDynamicResolver() *CommandInfoResolver { + return &CommandInfoResolver{ resolve: c.extractCommandInfo, } } From 30c0c6390f6aebbad08435b628df5602f6c112d6 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Fri, 17 Oct 2025 16:29:02 +0300 Subject: [PATCH 43/62] change func signature from cmdName to cmder --- command_policy_resolver.go | 16 ++++++++-------- osscluster.go | 8 ++++---- osscluster_router.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/command_policy_resolver.go b/command_policy_resolver.go index 3cd1a8449c..0bc18c22b9 100644 --- a/command_policy_resolver.go +++ b/command_policy_resolver.go @@ -118,21 +118,21 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ } type CommandInfoResolver struct { - resolve func(ctx context.Context, cmdName string) 
*routing.CommandPolicy + resolve func(ctx context.Context, cmd Cmder) *routing.CommandPolicy fallBackResolver *CommandInfoResolver } -func NewCommandInfoResolver(resolver func(ctx context.Context, cmdName string) *routing.CommandPolicy) *CommandInfoResolver { +func NewCommandInfoResolver(resolver func(ctx context.Context, cmd Cmder) *routing.CommandPolicy) *CommandInfoResolver { return &CommandInfoResolver{ resolve: resolver, } } func NewDefaultCommandPolicyResolver() *CommandInfoResolver { - return NewCommandInfoResolver(func(ctx context.Context, cmdName string) *routing.CommandPolicy { + return NewCommandInfoResolver(func(ctx context.Context, cmd Cmder) *routing.CommandPolicy { module := "core" - command := cmdName - cmdParts := strings.Split(cmdName, ".") + command := cmd.Name() + cmdParts := strings.Split(command, ".") if len(cmdParts) == 2 { module = cmdParts[0] command = cmdParts[1] @@ -146,18 +146,18 @@ func NewDefaultCommandPolicyResolver() *CommandInfoResolver { }) } -func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy { +func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { if r.resolve == nil { return nil } - policy := r.resolve(ctx, cmdName) + policy := r.resolve(ctx, cmd) if policy != nil { return policy } if r.fallBackResolver != nil { - return r.fallBackResolver.GetCommandPolicy(ctx, cmdName) + return r.fallBackResolver.GetCommandPolicy(ctx, cmd) } return nil diff --git a/osscluster.go b/osscluster.go index 587f2173ee..422f9d824c 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1425,7 +1425,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - policy := c.extractCommandInfo(ctx, cmd.Name()) + policy := c.extractCommandInfo(ctx, cmd) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -1442,7 +1442,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - policy := c.extractCommandInfo(ctx, cmd.Name()) + policy := c.extractCommandInfo(ctx, cmd) if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -2208,8 +2208,8 @@ func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *CommandInfoResol } // extractCommandInfo retrieves the routing policy for a command -func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmdName string) *routing.CommandPolicy { - if cmdInfo := c.cmdInfo(ctx, cmdName); cmdInfo != nil && cmdInfo.CommandPolicy != nil { +func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.CommandPolicy != nil { return cmdInfo.CommandPolicy } diff --git a/osscluster_router.go b/osscluster_router.go index 0a0a37307e..bbd3be8f46 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -21,7 +21,7 @@ type slotResult struct { // routeAndRun routes a command to the appropriate cluster nodes and executes it func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { - 
policy := c.cmdInfoResolver.GetCommandPolicy(ctx, cmd.Name()) + policy := c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) if policy == nil { return c.executeDefault(ctx, cmd, node) } From 5a510cfb1e29064456b81485dc1a563c63d32654 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Mon, 20 Oct 2025 10:32:52 +0300 Subject: [PATCH 44/62] added nil safety assertions --- command_policy_resolver.go | 24 +++++++++++++----------- osscluster.go | 22 ++++++++++++++-------- osscluster_router.go | 6 +++++- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/command_policy_resolver.go b/command_policy_resolver.go index 0bc18c22b9..fdf5196d76 100644 --- a/command_policy_resolver.go +++ b/command_policy_resolver.go @@ -117,18 +117,20 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ }, } -type CommandInfoResolver struct { - resolve func(ctx context.Context, cmd Cmder) *routing.CommandPolicy - fallBackResolver *CommandInfoResolver +type CommandInfoResolveFunc func(ctx context.Context, cmd Cmder) *routing.CommandPolicy + +type commandInfoResolver struct { + resolveFunc CommandInfoResolveFunc + fallBackResolver *commandInfoResolver } -func NewCommandInfoResolver(resolver func(ctx context.Context, cmd Cmder) *routing.CommandPolicy) *CommandInfoResolver { - return &CommandInfoResolver{ - resolve: resolver, +func NewCommandInfoResolver(resolveFunc CommandInfoResolveFunc) *commandInfoResolver { + return &commandInfoResolver{ + resolveFunc: resolveFunc, } } -func NewDefaultCommandPolicyResolver() *CommandInfoResolver { +func NewDefaultCommandPolicyResolver() *commandInfoResolver { return NewCommandInfoResolver(func(ctx context.Context, cmd Cmder) *routing.CommandPolicy { module := "core" command := cmd.Name() @@ -146,12 +148,12 @@ func NewDefaultCommandPolicyResolver() *CommandInfoResolver { }) } -func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { - if r.resolve == nil { +func (r *commandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + if r.resolveFunc == nil { return nil } - policy := r.resolve(ctx, cmd) + policy := r.resolveFunc(ctx, cmd) if policy != nil { return policy } @@ -163,6 +165,6 @@ func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) * return nil } -func (r *CommandInfoResolver) SetFallbackResolver(fallbackResolver *CommandInfoResolver) { +func (r *commandInfoResolver) SetFallbackResolver(fallbackResolver *commandInfoResolver) { r.fallBackResolver = fallbackResolver } diff --git a/osscluster.go b/osscluster.go index 422f9d824c..308f1c0c66 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1017,7 +1017,7 @@ type ClusterClient struct { nodes *clusterNodes state *clusterStateHolder cmdsInfoCache *cmdsInfoCache - cmdInfoResolver *CommandInfoResolver + cmdInfoResolver *commandInfoResolver cmdable hooksMixin } @@ -1425,7 +1425,10 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - policy := c.extractCommandInfo(ctx, cmd) + var policy *routing.CommandPolicy + if c.cmdInfoResolver != nil { + policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) + } if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -1442,7 +1445,10 @@ func (c *ClusterClient) mapCmdsByNode(ctx 
context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - policy := c.extractCommandInfo(ctx, cmd) + var policy *routing.CommandPolicy + if c.cmdInfoResolver != nil { + policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) + } if policy != nil && !policy.CanBeUsedInPipeline() { return fmt.Errorf( "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), @@ -2199,11 +2205,11 @@ func (c *ClusterClient) context(ctx context.Context) context.Context { return context.Background() } -func (c *ClusterClient) GetResolver() *CommandInfoResolver { +func (c *ClusterClient) GetResolver() *commandInfoResolver { return c.cmdInfoResolver } -func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *CommandInfoResolver) { +func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *commandInfoResolver) { c.cmdInfoResolver = cmdInfoResolver } @@ -2218,9 +2224,9 @@ func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmd Cmder) *rout // NewDynamicResolver returns a CommandInfoResolver // that uses the underlying cmdInfo cache to resolve the policies -func (c *ClusterClient) NewDynamicResolver() *CommandInfoResolver { - return &CommandInfoResolver{ - resolve: c.extractCommandInfo, +func (c *ClusterClient) NewDynamicResolver() *commandInfoResolver { + return &commandInfoResolver{ + resolveFunc: c.extractCommandInfo, } } diff --git a/osscluster_router.go b/osscluster_router.go index bbd3be8f46..438db5f7c2 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -21,7 +21,11 @@ type slotResult struct { // routeAndRun routes a command to the appropriate cluster nodes and executes it func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { - policy := c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) + var policy *routing.CommandPolicy + if c.cmdInfoResolver != nil { + policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) + } + if policy == nil { return c.executeDefault(ctx, cmd, node) } From 14dde5c7a06be58778652460d8654eaae9582c59 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Tue, 21 Oct 2025 10:10:57 +0300 Subject: [PATCH 45/62] few small refactors --- command.go | 14 +++++--------- internal/routing/aggregator.go | 4 ++-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/command.go b/command.go index ce693004cb..ee716442de 100644 --- a/command.go +++ b/command.go @@ -4369,18 +4369,14 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { return err } - // Handle tips that don't have a colon (like "nondeterministic_output") - if !strings.Contains(tip, ":") { - rawTips[tip] = "" - continue - } - - // Handle normal key:value tips k, v, ok := strings.Cut(tip, ":") if !ok { - return fmt.Errorf("redis: unexpected tip %q in COMMAND reply", tip) + // Handle tips that don't have a colon (like "nondeterministic_output") + rawTips[tip] = "" + } else { + // Handle normal key:value tips + rawTips[k] = v } - rawTips[k] = v } cmdInfo.CommandPolicy = parseCommandPolicies(rawTips) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 5022e8a707..49d5ce600c 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -477,7 +477,7 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { val, e := toBool(result) if e != nil { a.err.CompareAndSwap(nil, e) - return nil + return e } if val { @@ -561,7 +561,7 @@ func (a *AggLogicalOrAggregator) 
Add(result interface{}, err error) error { val, e := toBool(result) if e != nil { a.err.CompareAndSwap(nil, e) - return nil + return e } if val { From b72becc53bae7603dcf537bd59c84347d8c66eea Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Mon, 27 Oct 2025 09:19:48 +0200 Subject: [PATCH 46/62] added read only policies --- command.go | 3 +++ command_policy_resolver.go | 39 ++++++++++++++++++++++++++++++++++ internal/routing/policy.go | 10 +++++++++ osscluster.go | 5 ++--- osscluster_router.go | 43 ++++++++++++++++++++++++++++++++++---- 5 files changed, 93 insertions(+), 7 deletions(-) diff --git a/command.go b/command.go index ee716442de..5f0af64d9e 100644 --- a/command.go +++ b/command.go @@ -4363,6 +4363,9 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { } rawTips := make(map[string]string, tipsLen) + if cmdInfo.ReadOnly { + rawTips[routing.ReadOnlyCMD] = "" + } for f := 0; f < tipsLen; f++ { tip, err := rd.ReadString() if err != nil { diff --git a/command_policy_resolver.go b/command_policy_resolver.go index fdf5196d76..da8c6d314c 100644 --- a/command_policy_resolver.go +++ b/command_policy_resolver.go @@ -21,10 +21,16 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ "search": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "aggregate": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "dictadd": { Request: routing.ReqDefault, @@ -33,6 +39,9 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ "dictdump": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "dictdel": { Request: routing.ReqDefault, @@ -41,10 +50,16 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ "suglen": { Request: routing.ReqDefault, Response: routing.RespDefaultHashSlot, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "cursor": { Request: routing.ReqSpecial, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "sugadd": { Request: routing.ReqDefault, @@ -53,6 +68,9 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ "sugget": { Request: routing.ReqDefault, Response: routing.RespDefaultHashSlot, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "sugdel": { Request: routing.ReqDefault, @@ -61,14 +79,23 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ "spellcheck": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "explain": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "explaincli": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "aliasadd": { Request: routing.ReqDefault, @@ -85,14 +112,23 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ "info": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "tagvals": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "syndump": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, 
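
These bare-value Tips entries exist so the router can ask a policy whether its command is read-only: the presence of the "readonly" key is the whole signal, which is why the IsReadOnly helper added to internal/routing/policy.go below is a single comma-ok map probe. A hedged sketch of the check they enable:

p := &routing.CommandPolicy{
	Request:  routing.ReqDefault,
	Response: routing.RespDefaultKeyless,
	Tips:     map[string]string{routing.ReadOnlyCMD: ""}, // bare marker, no value
}
readOnly := p.IsReadOnly() // true: the key exists even though its value is empty
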
+ Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "synupdate": { Request: routing.ReqDefault, @@ -101,6 +137,9 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ "profile": { Request: routing.ReqDefault, Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, }, "alter": { Request: routing.ReqDefault, diff --git a/internal/routing/policy.go b/internal/routing/policy.go index a76dfaf19b..f40eb15f84 100644 --- a/internal/routing/policy.go +++ b/internal/routing/policy.go @@ -19,6 +19,10 @@ const ( ReqSpecial ) +const ( + ReadOnlyCMD string = "readonly" +) + func (p RequestPolicy) String() string { switch p { case ReqDefault: @@ -133,3 +137,9 @@ type CommandPolicy struct { func (p *CommandPolicy) CanBeUsedInPipeline() bool { return p.Request != ReqAllNodes && p.Request != ReqAllShards && p.Request != ReqMultiShard } + +func (p *CommandPolicy) IsReadOnly() bool { + _, readOnly := p.Tips[ReadOnlyCMD] + fmt.Println(readOnly) + return readOnly +} diff --git a/osscluster.go b/osscluster.go index 308f1c0c66..3a45fc06ad 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1036,9 +1036,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.state = newClusterStateHolder(c.loadState) - dynamicResolver := c.NewDynamicResolver() - dynamicResolver.SetFallbackResolver(NewDefaultCommandPolicyResolver()) - c.SetCommandInfoResolver(dynamicResolver) + c.SetCommandInfoResolver(NewDefaultCommandPolicyResolver()) c.cmdable = c.Process c.initHooks(hooks{ @@ -2109,6 +2107,7 @@ func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { if info == nil { internal.Logger.Printf(cmdInfoCtx, "info for cmd=%s not found", name) } + return info } diff --git a/osscluster_router.go b/osscluster_router.go index 438db5f7c2..437d0f527a 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -27,7 +27,7 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste } if policy == nil { - return c.executeDefault(ctx, cmd, node) + return c.executeDefault(ctx, cmd, policy, node) } switch policy.Request { case routing.ReqAllNodes: @@ -39,16 +39,25 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste case routing.ReqSpecial: return c.executeSpecialCommand(ctx, cmd, policy, node) default: - return c.executeDefault(ctx, cmd, node) + return c.executeDefault(ctx, cmd, policy, node) } } // executeDefault handles standard command routing based on keys -func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, node *clusterNode) error { +func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy, node *clusterNode) error { if c.hasKeys(cmd) { // execute on key based shard return node.Client.Process(ctx, cmd) } + if policy != nil { + + fmt.Println(policy.Tips) + if c.readOnlyEnabled() && policy.IsReadOnly() { + fmt.Println("will execute on arbitrary node") + return c.executeOnArbitraryNode(ctx, cmd) + } + } + return c.executeOnArbitraryShard(ctx, cmd) } @@ -61,6 +70,15 @@ func (c *ClusterClient) executeOnArbitraryShard(ctx context.Context, cmd Cmder) return node.Client.Process(ctx, cmd) } +// executeOnArbitraryNode routes command to an arbitrary node +func (c *ClusterClient) executeOnArbitraryNode(ctx context.Context, cmd Cmder) error { + node := c.pickArbitraryNode(ctx) + if node == nil { + return errClusterNoNodes + } + return node.Client.Process(ctx, cmd) +} + // executeOnAllNodes executes command on all nodes (masters 
and replicas)
 func (c *ClusterClient) executeOnAllNodes(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error {
 	state, err := c.state.Get(ctx)
@@ -252,7 +270,7 @@ func (c *ClusterClient) executeSpecialCommand(ctx context.Context, cmd Cmder, po
 	case "ft.cursor":
 		return c.executeCursorCommand(ctx, cmd)
 	default:
-		return c.executeDefault(ctx, cmd, node)
+		return c.executeDefault(ctx, cmd, policy, node)
 	}
 }
 
@@ -479,12 +497,32 @@ func (c *ClusterClient) pickArbitraryShard(ctx context.Context) *clusterNode {
 	return state.Masters[idx]
 }
 
+// pickArbitraryNode selects a node (master or slave) using the configured ShardPicker
+func (c *ClusterClient) pickArbitraryNode(ctx context.Context) *clusterNode {
+	state, err := c.state.Get(ctx)
+	if err != nil || len(state.Masters) == 0 {
+		return nil
+	}
+
+	// Copy into a fresh slice so appending replicas cannot clobber the
+	// backing array shared with state.Masters.
+	allNodes := make([]*clusterNode, 0, len(state.Masters)+len(state.Slaves))
+	allNodes = append(allNodes, state.Masters...)
+	allNodes = append(allNodes, state.Slaves...)
+
+	idx := c.opt.ShardPicker.Next(len(allNodes))
+	return allNodes[idx]
+}
+
 // hasKeys checks if a command operates on keys
 func (c *ClusterClient) hasKeys(cmd Cmder) bool {
 	firstKeyPos := cmdFirstKeyPos(cmd)
 	return firstKeyPos > 0
 }
 
+func (c *ClusterClient) readOnlyEnabled() bool {
+	return c.opt.ReadOnly
+}
+
 // setCommandValue sets the aggregated value on a command using the enum-based approach
 func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error {
 	// If value is nil, it might mean ExtractCommandValue couldn't extract the value

From 731505bacda7722c38ba78b698b2b16c97f3ff5d Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Wed, 29 Oct 2025 10:48:15 +0200
Subject: [PATCH 47/62] removed leftover prints

---
 internal/routing/policy.go | 1 -
 osscluster_router.go       | 3 ---
 osscluster_test.go         | 2 --
 3 files changed, 6 deletions(-)

diff --git a/internal/routing/policy.go b/internal/routing/policy.go
index f40eb15f84..7f784b5061 100644
--- a/internal/routing/policy.go
+++ b/internal/routing/policy.go
@@ -140,6 +140,5 @@ func (p *CommandPolicy) CanBeUsedInPipeline() bool {
 
 func (p *CommandPolicy) IsReadOnly() bool {
 	_, readOnly := p.Tips[ReadOnlyCMD]
-	fmt.Println(readOnly)
 	return readOnly
 }
diff --git a/osscluster_router.go b/osscluster_router.go
index 437d0f527a..5443857b5d 100644
--- a/osscluster_router.go
+++ b/osscluster_router.go
@@ -50,10 +50,7 @@ func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, policy *r
 		return node.Client.Process(ctx, cmd)
 	}
 	if policy != nil {
-
-		fmt.Println(policy.Tips)
 		if c.readOnlyEnabled() && policy.IsReadOnly() {
-			fmt.Println("will execute on arbitrary node")
 			return c.executeOnArbitraryNode(ctx, cmd)
 		}
 	}
diff --git a/osscluster_test.go b/osscluster_test.go
index 62425463a5..ef261b14be 100644
--- a/osscluster_test.go
+++ b/osscluster_test.go
@@ -6,7 +6,6 @@ import (
 	"errors"
 	"fmt"
 	"net"
-	"reflect"
 	"slices"
 	"strconv"
 	"strings"
@@ -1602,7 +1601,6 @@ var _ = Describe("ClusterClient timeout", func() {
 			return nil
 		})
 		Expect(err).To(HaveOccurred())
-		fmt.Println("qko greshki male", reflect.TypeOf(err).String(), reflect.TypeOf(err).Kind().String())
 		Expect(err.(net.Error).Timeout()).To(BeTrue())
 	})

From 888f791ab83c5cfbe5f752972feaadb1b471c8f3 Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Thu, 30 Oct 2025 10:44:00 +0200
Subject: [PATCH 48/62] Rebased to master, resolved conflicts

---
 command.go    | 10 +++++++--
 osscluster.go | 59 ++++++----------------------------------------
 2 files changed, 14 insertions(+), 55 deletions(-)

diff --git a/command.go b/command.go
index 5f0af64d9e..c4d4118d1b 100644
--- a/command.go
+++ b/command.go
@@ -66,7 +66,6 @@ var
keylessCommands = map[string]struct{}{ "unsubscribe": {}, "unwatch": {}, } -type CmdType = routing.CmdType // CmdTyper interface for getting command type type CmdTyper interface { @@ -151,7 +150,6 @@ const ( CmdTypeTSTimestampValue CmdTypeTSTimestampValueSlice ) ->>>>>>> b6633bf9 (centralize cluster command routing in osscluster_router.go and refactor osscluster.go (#6)) type ( CmdTypeXAutoClaimValue struct { @@ -6993,6 +6991,14 @@ func (cmd *VectorScoreSliceCmd) readReply(rd *proto.Reader) error { return nil } + +func (cmd *VectorScoreSliceCmd) Clone() Cmder { + return &VectorScoreSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + func (cmd *MonitorCmd) Clone() Cmder { // MonitorCmd cannot be safely cloned due to channels and goroutines // Return a new MonitorCmd with the same channel diff --git a/osscluster.go b/osscluster.go index 3a45fc06ad..6ddd4ca632 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1432,7 +1432,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), ) } - slot := c.cmdSlot(ctx, cmd) + slot := c.cmdSlot(cmd, -1) node, err := c.slotReadOnlyNode(state, slot) if err != nil { return err @@ -1452,7 +1452,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), ) } - slot := c.cmdSlot(ctx, cmd) + slot := c.cmdSlot(cmd, -1) node, err := state.slotMasterNode(slot) if err != nil { return err @@ -1557,53 +1557,6 @@ func (c *ClusterClient) pipelineReadCmds( return nil } -// Legacy functions needed for transaction pipeline processing -func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmds []Cmder) error { - state, err := c.state.Get(ctx) - if err != nil { - return err - } - - preferredRandomSlot := -1 - if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { - for _, cmd := range cmds { - slot := c.cmdSlot(cmd, preferredRandomSlot) - if preferredRandomSlot == -1 { - preferredRandomSlot = slot - } - node, err := c.slotReadOnlyNode(state, slot) - if err != nil { - return err - } - cmdsMap.Add(node, cmd) - } - return nil - } - - for _, cmd := range cmds { - slot := c.cmdSlot(cmd, preferredRandomSlot) - if preferredRandomSlot == -1 { - preferredRandomSlot = slot - } - node, err := state.slotMasterNode(slot) - if err != nil { - return err - } - cmdsMap.Add(node, cmd) - } - return nil -} - -func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool { - for _, cmd := range cmds { - cmdInfo := c.cmdInfo(ctx, cmd.Name()) - if cmdInfo == nil || !cmdInfo.ReadOnly { - return false - } - } - return true -} - func (c *ClusterClient) checkMovedErr( ctx context.Context, cmd Cmder, err error, failedCmds *cmdsMap, ) bool { @@ -1661,7 +1614,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err return err } - keyedCmdsBySlot := c.slottedKeyedCommands(cmds) + keyedCmdsBySlot := c.slottedKeyedCommands(ctx, cmds) slot := -1 switch len(keyedCmdsBySlot) { case 0: @@ -1715,7 +1668,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err // slottedKeyedCommands returns a map of slot to commands taking into account // only commands that have keys. 
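
The Clone added for VectorScoreSliceCmd above follows the convention the other Cmder implementations use: produce an independent copy of the command, presumably so the router's fan-out paths can hand each node its own reply slot. A sketch of the same shape for a hypothetical custom command (MyCmd and its val field are illustrative only, not part of the library):

type MyCmd struct {
	baseCmd
	val []string
}

func (cmd *MyCmd) Clone() Cmder {
	return &MyCmd{
		baseCmd: cmd.cloneBaseCmd(), // same helper VectorScoreSliceCmd uses above
		val:     cmd.val,
	}
}
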
-func (c *ClusterClient) slottedKeyedCommands(cmds []Cmder) map[int][]Cmder {
+func (c *ClusterClient) slottedKeyedCommands(ctx context.Context, cmds []Cmder) map[int][]Cmder {
 	cmdsSlots := map[int][]Cmder{}
 
 	preferredRandomSlot := -1
@@ -2111,13 +2064,13 @@ func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo {
 	return info
 }
 
-func (c *ClusterClient) cmdSlot(cmd Cmder, preferredRandomSlot int) int {
+func (c *ClusterClient) cmdSlot(cmd Cmder, preferredSlot int) int {
 	args := cmd.Args()
 	if args[0] == "cluster" && (args[1] == "getkeysinslot" || args[1] == "countkeysinslot") {
 		return args[2].(int)
 	}
 
-	return cmdSlot(cmd, cmdFirstKeyPos(cmd), preferredRandomSlot)
+	return cmdSlot(cmd, cmdFirstKeyPos(cmd), preferredSlot)
 }
 
 func cmdSlot(cmd Cmder, pos int, preferredRandomSlot int) int {

From cd74db0874ff3ca95b978efe6710d92e56e6ac42 Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Thu, 30 Oct 2025 10:47:35 +0200
Subject: [PATCH 49/62] fixed lint

---
 osscluster.go | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/osscluster.go b/osscluster.go
index 6ddd4ca632..542b1c08e1 100644
--- a/osscluster.go
+++ b/osscluster.go
@@ -2182,15 +2182,6 @@ func (c *ClusterClient) NewDynamicResolver() *commandInfoResolver {
 	}
 }
 
-func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode {
-	for _, n := range nodes {
-		if n == node {
-			return nodes
-		}
-	}
-	return append(nodes, node)
-}
-
 func appendIfNotExist[T comparable](vals []T, newVal T) []T {
 	for _, v := range vals {
 		if v == newVal {

From 68f7af8cb582885e0934bab0928832a03289cded Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Thu, 30 Oct 2025 13:19:12 +0200
Subject: [PATCH 50/62] updated gha

---
 .github/actions/run-tests/action.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml
index 0d6db09b31..1066e79002 100644
--- a/.github/actions/run-tests/action.yml
+++ b/.github/actions/run-tests/action.yml
@@ -49,4 +49,4 @@ runs:
         RE_CLUSTER: "false"
       run: |
         make test.ci
-      shell: bash
+      shell: bash
\ No newline at end of file

From 5c447f925276b82b74c8c90f56b3436181b113bd Mon Sep 17 00:00:00 2001
From: Hristo Temelski
Date: Thu, 30 Oct 2025 14:06:49 +0200
Subject: [PATCH 51/62] fixed tests, minor consistency refactor

---
 command.go         | 16 ++++++++--------
 osscluster_test.go |  4 ++--
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/command.go b/command.go
index c4d4118d1b..6a85d0b61d 100644
--- a/command.go
+++ b/command.go
@@ -1156,35 +1156,35 @@ func (cmd *StringCmd) Bool() (bool, error) {
 	if cmd.err != nil {
 		return false, cmd.err
 	}
-	return strconv.ParseBool(cmd.Val())
+	return strconv.ParseBool(cmd.val)
 }
 
 func (cmd *StringCmd) Int() (int, error) {
 	if cmd.err != nil {
 		return 0, cmd.err
 	}
-	return strconv.Atoi(cmd.Val())
+	return strconv.Atoi(cmd.val)
 }
 
 func (cmd *StringCmd) Int64() (int64, error) {
 	if cmd.err != nil {
 		return 0, cmd.err
 	}
-	return strconv.ParseInt(cmd.Val(), 10, 64)
+	return strconv.ParseInt(cmd.val, 10, 64)
 }
 
 func (cmd *StringCmd) Uint64() (uint64, error) {
 	if cmd.err != nil {
 		return 0, cmd.err
 	}
-	return strconv.ParseUint(cmd.Val(), 10, 64)
+	return strconv.ParseUint(cmd.val, 10, 64)
 }
 
 func (cmd *StringCmd) Float32() (float32, error) {
 	if cmd.err != nil {
 		return 0, cmd.err
 	}
-	f, err := strconv.ParseFloat(cmd.Val(), 32)
+	f, err := strconv.ParseFloat(cmd.val, 32)
 	if err != nil {
 		return 0, err
 	}
@@ -1195,14 +1195,14 @@ func (cmd *StringCmd) Float64() (float64, error) {
 	if cmd.err !=
nil { return 0, cmd.err } - return strconv.ParseFloat(cmd.Val(), 64) + return strconv.ParseFloat(cmd.val, 64) } func (cmd *StringCmd) Time() (time.Time, error) { if cmd.err != nil { return time.Time{}, cmd.err } - return time.Parse(time.RFC3339Nano, cmd.Val()) + return time.Parse(time.RFC3339Nano, cmd.val) } func (cmd *StringCmd) Scan(val interface{}) error { @@ -1381,7 +1381,7 @@ func (cmd *StringSliceCmd) String() string { } func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { - return proto.ScanSlice(cmd.Val(), container) + return proto.ScanSlice(cmd.val, container) } func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { diff --git a/osscluster_test.go b/osscluster_test.go index ef261b14be..0e7dd4d01d 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -515,8 +515,8 @@ var _ = Describe("ClusterClient", func() { }) It("should work with missing keys", func() { - pipe.Set(ctx, "A", "A_value", 0) - pipe.Set(ctx, "C", "C_value", 0) + pipe.Set(ctx, "A{s}", "A_value", 0) + pipe.Set(ctx, "C{s}", "C_value", 0) _, err := pipe.Exec(ctx) Expect(err).NotTo(HaveOccurred()) From 86c73a00563d37c8eac5d0349ea00039fc628a31 Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Thu, 30 Oct 2025 14:27:21 +0200 Subject: [PATCH 52/62] preallocated simple errors --- osscluster.go | 12 ++++++++---- osscluster_router.go | 19 ++++++++++++++----- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/osscluster.go b/osscluster.go index 542b1c08e1..525bf8a490 100644 --- a/osscluster.go +++ b/osscluster.go @@ -3,6 +3,7 @@ package redis import ( "context" "crypto/tls" + "errors" "fmt" "math" "net" @@ -29,7 +30,11 @@ const ( minLatencyMeasurementInterval = 10 * time.Second ) -var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes") +var ( + errClusterNoNodes = errors.New("redis: cluster has no nodes") + errNoWatchKeys = errors.New("redis: Watch requires at least one key") + errWatchCrosslot = errors.New("redis: Watch requires all keys to be in the same slot") +) // ClusterOptions are used to configure a cluster client and should be // passed to NewClusterClient. 
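
The {s} suffixes in the test fix above lean on Redis Cluster hash tags, the same rule the Watch hunk below enforces with hashtag.Slot: when a key contains a {...} section, only that substring is hashed, so keys sharing a tag always land in one slot and can share a transaction or pipeline. A quick module-internal illustration against the package's own helper:

import "github.com/redis/go-redis/v9/internal/hashtag"

// Both keys hash only the tag "s", so they map to the same slot.
same := hashtag.Slot("A{s}") == hashtag.Slot("C{s}") // true
// Untagged keys hash in full and will usually land in different slots.
likelySame := hashtag.Slot("A") == hashtag.Slot("C") // almost certainly false
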
@@ -1838,14 +1843,13 @@ func (c *ClusterClient) cmdsMoved( func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { - return fmt.Errorf("redis: Watch requires at least one key") + return errNoWatchKeys } slot := hashtag.Slot(keys[0]) for _, key := range keys[1:] { if hashtag.Slot(key) != slot { - err := fmt.Errorf("redis: Watch requires all keys to be in the same slot") - return err + return errWatchCrosslot } } diff --git a/osscluster_router.go b/osscluster_router.go index 5443857b5d..23d89a40a2 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -2,6 +2,7 @@ package redis import ( "context" + "errors" "fmt" "reflect" "strings" @@ -12,6 +13,14 @@ import ( "github.com/redis/go-redis/v9/internal/routing" ) +var ( + errInvalidCmdPointer = errors.New("redis: invalid command pointer") + errNoCmdsToAggregate = errors.New("redis: no commands to aggregate") + errNoResToAggregate = errors.New("redis: no results to aggregate") + errInvalidCursorCmdArgsCount = errors.New("redis: FT.CURSOR command requires at least 3 arguments") + errInvalidCursorIdType = errors.New("redis: invalid cursor ID type") +) + // slotResult represents the result of executing a command on a specific slot type slotResult struct { cmd Cmder @@ -275,12 +284,12 @@ func (c *ClusterClient) executeSpecialCommand(ctx context.Context, cmd Cmder, po func (c *ClusterClient) executeCursorCommand(ctx context.Context, cmd Cmder) error { args := cmd.Args() if len(args) < 4 { - return fmt.Errorf("redis: FT.CURSOR command requires at least 3 arguments") + return errInvalidCursorCmdArgsCount } cursorID, ok := args[3].(string) if !ok { - return fmt.Errorf("redis: invalid cursor ID type") + return errInvalidCursorIdType } // Route based on cursor ID to maintain stickiness @@ -394,7 +403,7 @@ func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder // aggregateKeyedValues aggregates individual key-value pairs while preserving key order func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]routing.AggregatorResErr, keyOrder []string, policy *routing.CommandPolicy) error { if len(keyedResults) == 0 { - return fmt.Errorf("redis: no results to aggregate") + return errNoResToAggregate } aggregator := c.createAggregator(policy, cmd, true) @@ -419,7 +428,7 @@ func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string] // aggregateResponses aggregates multiple shard responses func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *routing.CommandPolicy) error { if len(cmds) == 0 { - return fmt.Errorf("redis: no commands to aggregate") + return errNoCmdsToAggregate } if len(cmds) == 1 { @@ -958,7 +967,7 @@ func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { func (c *ClusterClient) setCommandValueReflection(cmd Cmder, value interface{}) error { cmdValue := reflect.ValueOf(cmd) if cmdValue.Kind() != reflect.Ptr || cmdValue.IsNil() { - return fmt.Errorf("redis: invalid command pointer") + return errInvalidCmdPointer } setValMethod := cmdValue.MethodByName("SetVal") From 71262ecc0f9eb47c2b59b2684f695102b214796f Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Fri, 31 Oct 2025 13:05:39 +0200 Subject: [PATCH 53/62] changed numeric aggregators to use float64 --- go.mod | 2 ++ go.sum | 5 ++++ internal/routing/aggregator.go | 47 +++++++++++++++++++++++++--------- internal/util/atomic_max.go | 21 +++++++-------- internal/util/atomic_min.go | 19 +++++++------- 5 files changed, 
63 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index 3bbb8ac4d8..0d3144f695 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,8 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f ) +require go.uber.org/atomic v1.11.0 + retract ( v9.15.1 // This version is used to retract v9.15.0 v9.15.0 // This version was accidentally released. It is identical to 9.15.0-beta.2 diff --git a/go.sum b/go.sum index 4db68f6d4f..ab06e043de 100644 --- a/go.sum +++ b/go.sum @@ -4,5 +4,10 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go index 49d5ce600c..9bba778033 100644 --- a/internal/routing/aggregator.go +++ b/internal/routing/aggregator.go @@ -5,9 +5,11 @@ import ( "fmt" "math" "sync" + "sync/atomic" "github.com/redis/go-redis/v9/internal/util" + uberAtomic "go.uber.org/atomic" ) var ( @@ -214,7 +216,7 @@ func (a *OneSucceededAggregator) Result() (interface{}, error) { // AggSumAggregator sums numeric replies from all shards. 
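
sync/atomic offers no float type, which is what pulls go.uber.org/atomic into go.mod above: its Float64 implements Add as a compare-and-swap loop over the value's bits. A minimal, self-contained sketch of the behaviour the reworked sum aggregator below depends on:

package main

import (
	"fmt"
	"sync"

	uberAtomic "go.uber.org/atomic"
)

func main() {
	var sum uberAtomic.Float64 // zero value is ready to use
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			sum.Add(2.5) // CAS loop inside; no mutex needed
		}()
	}
	wg.Wait()
	fmt.Println(sum.Load()) // 10
}
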
type AggSumAggregator struct { err atomic.Value - res int64 + res uberAtomic.Float64 } func (a *AggSumAggregator) Add(result interface{}, err error) error { @@ -223,12 +225,12 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error { } if result != nil { - val, err := toInt64(result) + val, err := toFloat64(result) if err != nil { a.err.CompareAndSwap(nil, err) return err } - atomic.AddInt64(&a.res, val) + a.res.Add(val) } return nil @@ -277,7 +279,7 @@ func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error { } func (a *AggSumAggregator) Result() (interface{}, error) { - res, err := atomic.LoadInt64(&a.res), a.err.Load() + res, err := a.res.Load(), a.err.Load() if err != nil { return nil, err.(error) } @@ -297,13 +299,13 @@ func (a *AggMinAggregator) Add(result interface{}, err error) error { return nil } - intVal, e := toInt64(result) + floatVal, e := toFloat64(result) if e != nil { a.err.CompareAndSwap(nil, err) return nil } - a.res.Value(intVal) + a.res.Value(floatVal) return nil } @@ -337,7 +339,7 @@ func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) } func (a *AggMinAggregator) BatchSlice(results []AggregatorResErr) error { - min := int64(math.MaxInt64) + min := float64(math.MaxFloat64) for _, res := range results { if res.Err != nil { @@ -345,14 +347,14 @@ func (a *AggMinAggregator) BatchSlice(results []AggregatorResErr) error { return nil } - resInt, err := toInt64(res.Result) + floatVal, err := toFloat64(res.Result) if err != nil { _ = a.Add(nil, res.Err) return nil } - if resInt < min { - min = resInt + if floatVal < min { + min = floatVal } } @@ -385,13 +387,13 @@ func (a *AggMaxAggregator) Add(result interface{}, err error) error { return nil } - intVal, e := toInt64(result) + floatVal, e := toFloat64(result) if e != nil { a.err.CompareAndSwap(nil, err) return nil } - a.res.Value(intVal) + a.res.Value(floatVal) return nil } @@ -650,6 +652,27 @@ func toInt64(val interface{}) (int64, error) { } } +func toFloat64(val interface{}) (float64, error) { + if val == nil { + return 0, nil + } + + switch v := val.(type) { + case float64: + return v, nil + case int: + return float64(v), nil + case int32: + return float64(v), nil + case int64: + return float64(v), nil + case float32: + return float64(v), nil + default: + return 0, fmt.Errorf("cannot convert %T to float64", val) + } +} + func toBool(val interface{}) (bool, error) { if val == nil { return false, nil diff --git a/internal/util/atomic_max.go b/internal/util/atomic_max.go index ccee0e4c88..6c621ba850 100644 --- a/internal/util/atomic_max.go +++ b/internal/util/atomic_max.go @@ -3,14 +3,15 @@ ISC License Modified by htemelski-redis -Removed the treshold, adapted it to work with int64 +Removed the treshold, adapted it to work with float64 */ package util import ( "math" - "sync/atomic" + + "go.uber.org/atomic" ) // AtomicMax is a thread-safe max container @@ -22,7 +23,7 @@ import ( type AtomicMax struct { // value is current max - value atomic.Int64 + value atomic.Float64 // whether [AtomicMax.Value] has been invoked // with value equal or greater to threshold hasValue atomic.Bool @@ -32,7 +33,7 @@ type AtomicMax struct { // - if threshold is not used, AtomicMax is initialization-free func NewAtomicMax() (atomicMax *AtomicMax) { m := AtomicMax{} - m.value.Store(math.MinInt64) + m.value.Store((-math.MaxFloat64)) return &m } @@ -44,14 +45,14 @@ func NewAtomicMax() (atomicMax *AtomicMax) { // - upon return, Max and Max1 are guaranteed to reflect the invocation // - the 
return order of concurrent Value invocations is not guaranteed // - Thread-safe -func (m *AtomicMax) Value(value int64) (isNewMax bool) { - // math.MinInt64 as max case +func (m *AtomicMax) Value(value float64) (isNewMax bool) { + // -math.MaxFloat64 as max case var hasValue0 = m.hasValue.Load() - if value == math.MinInt64 { + if value == (-math.MaxFloat64) { if !hasValue0 { isNewMax = m.hasValue.CompareAndSwap(false, true) } - return // math.MinInt64 as max: isNewMax true for first 0 writer + return // -math.MaxFloat64 as max: isNewMax true for first 0 writer } // check against present value @@ -82,7 +83,7 @@ func (m *AtomicMax) Value(value int64) (isNewMax bool) { // - hasValue true indicates that value reflects a Value invocation // - hasValue false: value is zero-value // - Thread-safe -func (m *AtomicMax) Max() (value int64, hasValue bool) { +func (m *AtomicMax) Max() (value float64, hasValue bool) { if hasValue = m.hasValue.Load(); !hasValue { return } @@ -93,4 +94,4 @@ func (m *AtomicMax) Max() (value int64, hasValue bool) { // Max1 returns current maximum whether zero-value or set by Value // - threshold is ignored // - Thread-safe -func (m *AtomicMax) Max1() (value int64) { return m.value.Load() } +func (m *AtomicMax) Max1() (value float64) { return m.value.Load() } diff --git a/internal/util/atomic_min.go b/internal/util/atomic_min.go index 962d2a8070..e33d29cc21 100644 --- a/internal/util/atomic_min.go +++ b/internal/util/atomic_min.go @@ -10,7 +10,8 @@ Adapted from the modified atomic_max, but with inverted logic import ( "math" - "sync/atomic" + + "go.uber.org/atomic" ) // AtomicMin is a thread-safe Min container @@ -21,7 +22,7 @@ import ( type AtomicMin struct { // value is current Min - value atomic.Int64 + value atomic.Float64 // whether [AtomicMin.Value] has been invoked // with value equal or greater to threshold hasValue atomic.Bool @@ -31,7 +32,7 @@ type AtomicMin struct { // - if threshold is not used, AtomicMin is initialization-free func NewAtomicMin() (atomicMin *AtomicMin) { m := AtomicMin{} - m.value.Store(math.MaxInt64) + m.value.Store(math.MaxFloat64) return &m } @@ -43,14 +44,14 @@ func NewAtomicMin() (atomicMin *AtomicMin) { // - upon return, Min and Min1 are guaranteed to reflect the invocation // - the return order of concurrent Value invocations is not guaranteed // - Thread-safe -func (m *AtomicMin) Value(value int64) (isNewMin bool) { - // math.MaxInt64 as Min case +func (m *AtomicMin) Value(value float64) (isNewMin bool) { + // math.MaxFloat64 as Min case var hasValue0 = m.hasValue.Load() - if value == math.MaxInt64 { + if value == math.MaxFloat64 { if !hasValue0 { isNewMin = m.hasValue.CompareAndSwap(false, true) } - return // math.MaxInt64 as Min: isNewMin true for first 0 writer + return // math.MaxFloat64 as Min: isNewMin true for first 0 writer } // check against present value @@ -81,7 +82,7 @@ func (m *AtomicMin) Value(value int64) (isNewMin bool) { // - hasValue true indicates that value reflects a Value invocation // - hasValue false: value is zero-value // - Thread-safe -func (m *AtomicMin) Min() (value int64, hasValue bool) { +func (m *AtomicMin) Min() (value float64, hasValue bool) { if hasValue = m.hasValue.Load(); !hasValue { return } @@ -92,4 +93,4 @@ func (m *AtomicMin) Min() (value int64, hasValue bool) { // Min1 returns current Minimum whether zero-value or set by Value // - threshold is ignored // - Thread-safe -func (m *AtomicMin) Min1() (value int64) { return m.value.Load() } +func (m *AtomicMin) Min1() (value float64) { return 
m.value.Load() } From 79fd0cf6ded599aaccab619c72f3f056f29f36bf Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Fri, 31 Oct 2025 13:34:37 +0200 Subject: [PATCH 54/62] speculative test fix --- example/del-keys-without-ttl/go.mod | 2 +- example/del-keys-without-ttl/go.sum | 4 ++-- example/disable-maintnotifications/go.mod | 1 + example/disable-maintnotifications/go.sum | 8 ++++++++ example/hll/go.mod | 1 + example/hll/go.sum | 5 +++++ example/hset-struct/go.mod | 1 + example/hset-struct/go.sum | 4 ++++ example/lua-scripting/go.mod | 1 + example/lua-scripting/go.sum | 5 +++++ example/otel/go.mod | 1 + example/otel/go.sum | 2 ++ example/pubsub/go.mod | 1 + example/pubsub/go.sum | 5 +++++ example/redis-bloom/go.mod | 1 + example/redis-bloom/go.sum | 5 +++++ example/scan-struct/go.mod | 1 + example/scan-struct/go.sum | 4 ++++ extra/rediscensus/go.mod | 2 +- extra/rediscensus/go.sum | 6 ++++++ extra/rediscmd/go.mod | 2 +- extra/rediscmd/go.sum | 5 +++++ extra/redisotel/go.mod | 2 +- extra/redisotel/go.sum | 2 ++ extra/redisprometheus/go.mod | 3 ++- extra/redisprometheus/go.sum | 8 ++++++++ osscluster_router.go | 8 ++++++++ 27 files changed, 83 insertions(+), 7 deletions(-) diff --git a/example/del-keys-without-ttl/go.mod b/example/del-keys-without-ttl/go.mod index d808c723d0..d283d58422 100644 --- a/example/del-keys-without-ttl/go.mod +++ b/example/del-keys-without-ttl/go.mod @@ -12,6 +12,6 @@ require ( require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - go.uber.org/atomic v1.10.0 // indirect + go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.9.0 // indirect ) diff --git a/example/del-keys-without-ttl/go.sum b/example/del-keys-without-ttl/go.sum index 96beed56e1..c9432ef5c7 100644 --- a/example/del-keys-without-ttl/go.sum +++ b/example/del-keys-without-ttl/go.sum @@ -9,8 +9,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= -go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= diff --git a/example/disable-maintnotifications/go.mod b/example/disable-maintnotifications/go.mod index e342e2abc8..5c68b726e4 100644 --- a/example/disable-maintnotifications/go.mod +++ b/example/disable-maintnotifications/go.mod @@ -9,4 +9,5 @@ require github.com/redis/go-redis/v9 v9.7.0 require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/disable-maintnotifications/go.sum b/example/disable-maintnotifications/go.sum index 4db68f6d4f..4ac0b36ede 100644 --- a/example/disable-maintnotifications/go.sum +++ b/example/disable-maintnotifications/go.sum @@ -4,5 +4,13 @@ github.com/bsm/gomega v1.27.10 
h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/example/hll/go.mod b/example/hll/go.mod index 54178bf25c..54c03c01e7 100644 --- a/example/hll/go.mod +++ b/example/hll/go.mod @@ -9,4 +9,5 @@ require github.com/redis/go-redis/v9 v9.16.0 require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/hll/go.sum b/example/hll/go.sum index d64ea0303f..09cfe27956 100644 --- a/example/hll/go.sum +++ b/example/hll/go.sum @@ -2,5 +2,10 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/example/hset-struct/go.mod b/example/hset-struct/go.mod index 81c7c3dd3b..e6511cb36b 100644 --- a/example/hset-struct/go.mod +++ b/example/hset-struct/go.mod @@ -12,4 +12,5 @@ require ( require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/hset-struct/go.sum b/example/hset-struct/go.sum index 5496d29e58..22b30118a7 100644 --- a/example/hset-struct/go.sum +++ b/example/hset-struct/go.sum @@ -6,3 +6,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous 
v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/example/lua-scripting/go.mod b/example/lua-scripting/go.mod index 02023f7bbd..e4248fceed 100644 --- a/example/lua-scripting/go.mod +++ b/example/lua-scripting/go.mod @@ -9,4 +9,5 @@ require github.com/redis/go-redis/v9 v9.16.0 require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/lua-scripting/go.sum b/example/lua-scripting/go.sum index d64ea0303f..09cfe27956 100644 --- a/example/lua-scripting/go.sum +++ b/example/lua-scripting/go.sum @@ -2,5 +2,10 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/example/otel/go.mod b/example/otel/go.mod index 3883f54610..ac53f1dbaf 100644 --- a/example/otel/go.mod +++ b/example/otel/go.mod @@ -36,6 +36,7 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.21.0 // indirect go.opentelemetry.io/otel/trace v1.22.0 // indirect go.opentelemetry.io/proto/otlp v1.0.0 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/net v0.36.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/example/otel/go.sum b/example/otel/go.sum index fa94c15b6f..207bf9c326 100644 --- a/example/otel/go.sum +++ b/example/otel/go.sum @@ -51,6 +51,8 @@ go.opentelemetry.io/otel/trace v1.22.0 h1:Hg6pPujv0XG9QaVbGOBVHunyuLcCC3jN7WEhPx go.opentelemetry.io/otel/trace v1.22.0/go.mod h1:RbbHXVqKES9QhzZq/fE5UnOSILqRt40a21sPw2He1xo= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= diff --git 
a/example/pubsub/go.mod b/example/pubsub/go.mod index 731a92839d..9fabdb991e 100644 --- a/example/pubsub/go.mod +++ b/example/pubsub/go.mod @@ -9,4 +9,5 @@ require github.com/redis/go-redis/v9 v9.11.0 require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/pubsub/go.sum b/example/pubsub/go.sum index d64ea0303f..09cfe27956 100644 --- a/example/pubsub/go.sum +++ b/example/pubsub/go.sum @@ -2,5 +2,10 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/example/redis-bloom/go.mod b/example/redis-bloom/go.mod index f25b4b2488..423097069e 100644 --- a/example/redis-bloom/go.mod +++ b/example/redis-bloom/go.mod @@ -9,4 +9,5 @@ require github.com/redis/go-redis/v9 v9.16.0 require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/redis-bloom/go.sum b/example/redis-bloom/go.sum index d64ea0303f..09cfe27956 100644 --- a/example/redis-bloom/go.sum +++ b/example/redis-bloom/go.sum @@ -2,5 +2,10 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/example/scan-struct/go.mod b/example/scan-struct/go.mod index 81c7c3dd3b..e6511cb36b 100644 --- a/example/scan-struct/go.mod +++ b/example/scan-struct/go.mod @@ -12,4 +12,5 @@ require ( require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/scan-struct/go.sum 
b/example/scan-struct/go.sum index 5496d29e58..22b30118a7 100644 --- a/example/scan-struct/go.sum +++ b/example/scan-struct/go.sum @@ -6,3 +6,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/extra/rediscensus/go.mod b/extra/rediscensus/go.mod index d4272c977b..bc77b1f663 100644 --- a/extra/rediscensus/go.mod +++ b/extra/rediscensus/go.mod @@ -16,10 +16,10 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + go.uber.org/atomic v1.11.0 // indirect ) retract ( v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead. v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead. ) - diff --git a/extra/rediscensus/go.sum b/extra/rediscensus/go.sum index ab3a8984fa..0967fb718a 100644 --- a/extra/rediscensus/go.sum +++ b/extra/rediscensus/go.sum @@ -8,6 +8,7 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= @@ -36,6 +37,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -43,9 +45,12 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod 
h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -95,6 +100,7 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/extra/rediscmd/go.mod b/extra/rediscmd/go.mod index d8c03b6af4..e7760811a4 100644 --- a/extra/rediscmd/go.mod +++ b/extra/rediscmd/go.mod @@ -13,10 +13,10 @@ require ( require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + go.uber.org/atomic v1.11.0 // indirect ) retract ( v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead. v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead. 
) - diff --git a/extra/rediscmd/go.sum b/extra/rediscmd/go.sum index 4db68f6d4f..ab06e043de 100644 --- a/extra/rediscmd/go.sum +++ b/extra/rediscmd/go.sum @@ -4,5 +4,10 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/extra/redisotel/go.mod b/extra/redisotel/go.mod index 23cec11a2d..1c3c3104c5 100644 --- a/extra/redisotel/go.mod +++ b/extra/redisotel/go.mod @@ -20,6 +20,7 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/sys v0.16.0 // indirect ) @@ -27,4 +28,3 @@ retract ( v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead. v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead. ) - diff --git a/extra/redisotel/go.sum b/extra/redisotel/go.sum index 4b832c80f3..f95fc1c984 100644 --- a/extra/redisotel/go.sum +++ b/extra/redisotel/go.sum @@ -21,6 +21,8 @@ go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= go.opentelemetry.io/otel/trace v1.22.0 h1:Hg6pPujv0XG9QaVbGOBVHunyuLcCC3jN7WEhPx83XD0= go.opentelemetry.io/otel/trace v1.22.0/go.mod h1:RbbHXVqKES9QhzZq/fE5UnOSILqRt40a21sPw2He1xo= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/extra/redisprometheus/go.mod b/extra/redisprometheus/go.mod index fd4e2d93ee..88af3781de 100644 --- a/extra/redisprometheus/go.mod +++ b/extra/redisprometheus/go.mod @@ -18,12 +18,13 @@ require ( github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/common v0.39.0 // indirect github.com/prometheus/procfs v0.9.0 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/sys v0.4.0 // indirect google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) retract ( v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead. v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead. 
) - diff --git a/extra/redisprometheus/go.sum b/extra/redisprometheus/go.sum index 7093016eec..ac4cc4e90d 100644 --- a/extra/redisprometheus/go.sum +++ b/extra/redisprometheus/go.sum @@ -4,6 +4,7 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -15,6 +16,7 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw= github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4= @@ -23,6 +25,9 @@ github.com/prometheus/common v0.39.0 h1:oOyhkDq05hPZKItWVBkJ6g6AtGxi+fy7F4JvUV8u github.com/prometheus/common v0.39.0/go.mod h1:6XBZ7lYdLCbkAVhwRsWTZn+IN5AB9F/NXd5w0BbEX0Y= github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI= github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -31,3 +36,6 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/osscluster_router.go b/osscluster_router.go index 23d89a40a2..1e14db2645 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -558,6 +558,8 @@ func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { if c, ok := cmd.(*IntCmd); ok { if v, ok := value.(int64); ok { c.SetVal(v) + } else if v, ok := 
value.(float64); ok { + c.SetVal(int64(v)) } } case CmdTypeBool: @@ -582,6 +584,12 @@ func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { if c, ok := cmd.(*IntSliceCmd); ok { if v, ok := value.([]int64); ok { c.SetVal(v) + } else if v, ok := value.([]float64); ok { + els := len(v) + intSlc := make([]int, els) + for i := range v { + intSlc[i] = int(v[i]) + } } } case CmdTypeFloatSlice: From 513140688567f7436d3ed85af11948b138b06b4b Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Wed, 5 Nov 2025 13:01:11 +0200 Subject: [PATCH 55/62] Update command.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> --- command.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/command.go b/command.go index 8f632eafef..6c1ba4539d 100644 --- a/command.go +++ b/command.go @@ -4508,25 +4508,24 @@ func parseCommandPolicies(commandInfoTips map[string]string) *routing.CommandPol req := routing.ReqDefault resp := routing.RespAllSucceeded - if commandInfoTips != nil { - if v, ok := commandInfoTips[requestPolicy]; ok { + + tips := make(map[string]string, len(commandInfoTips)) + for k, v := range commandInfoTips { + if k == requestPolicy { if p, err := routing.ParseRequestPolicy(v); err == nil { req = p } + continue } - if v, ok := commandInfoTips[responsePolicy]; ok { + if k == responsePolicy { if p, err := routing.ParseResponsePolicy(v); err == nil { resp = p } - } - } - tips := make(map[string]string, len(commandInfoTips)) - for k, v := range commandInfoTips { - if k == requestPolicy || k == responsePolicy { continue } tips[k] = v } + return &routing.CommandPolicy{Request: req, Response: resp, Tips: tips} } From 340dcbe03371095f52658a1c834d09b1bc1430af Mon Sep 17 00:00:00 2001 From: Hristo Temelski Date: Wed, 5 Nov 2025 13:01:35 +0200 Subject: [PATCH 56/62] Update main_test.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> --- main_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/main_test.go b/main_test.go index 4cd9bc6719..9d8efe3d98 100644 --- a/main_test.go +++ b/main_test.go @@ -107,7 +107,6 @@ var _ = BeforeSuite(func() { if RedisVersion < 7.0 || RedisVersion > 9 { panic("incorrect or not supported redis version") - } redisPort = redisStackPort From dbe75097c6ba242167b8afb1af76ea015e56f0f1 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Thu, 6 Nov 2025 12:55:18 +0200 Subject: [PATCH 57/62] Add static shard picker --- internal/routing/shard_picker.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/internal/routing/shard_picker.go b/internal/routing/shard_picker.go index e29d526b0b..8e6228dd20 100644 --- a/internal/routing/shard_picker.go +++ b/internal/routing/shard_picker.go @@ -11,6 +11,22 @@ type ShardPicker interface { Next(total int) int // returns an index in [0,total) } +// StaticShardPicker always returns the same shard index. 
+type StaticShardPicker struct { + index int +} + +func NewStaticShardPicker(index int) *StaticShardPicker { + return &StaticShardPicker{index: index} +} + +func (p *StaticShardPicker) Next(total int) int { + if total == 0 || p.index >= total { + return 0 + } + return p.index +} + /*─────────────────────────────── Round-robin (default) ────────────────────────────────*/ From 5544624da0e3e1c089587745efefb2eaebd542f8 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Thu, 6 Nov 2025 15:19:18 +0200 Subject: [PATCH 58/62] Fix nil value handling in command aggregation --- osscluster_router.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/osscluster_router.go b/osscluster_router.go index 1e14db2645..90a3f00a9e 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -531,16 +531,16 @@ func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { // If value is nil, it might mean ExtractCommandValue couldn't extract the value // but the command might have executed successfully. In this case, don't set an error. if value == nil { - // Check if the original command has an error - if not, the nil value is not an error - if cmd.Err() == nil { - // Command executed successfully but value extraction failed - // This is common for complex commands like CLUSTER SLOTS - // The command already has its result set correctly, so just return - return nil - } - // If the command does have an error, set Nil error - cmd.SetErr(Nil) - return Nil + // ExtractCommandValue returned nil - this means the command type is not supported + // in the aggregation flow. This is a programming error, not a runtime error. + if cmd.Err() != nil { + // Command already has an error, preserve it + return cmd.Err() + } + // Command executed successfully but we can't extract/set the aggregated value + // This indicates the command type needs to be added to ExtractCommandValue + return fmt.Errorf("redis: cannot aggregate command %s: unsupported command type %d", + cmd.Name(), cmd.GetCmdType()) } switch cmd.GetCmdType() { From c235f6c8c07b848b58b662eb45e71daba324332e Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 9 Nov 2025 09:42:24 +0200 Subject: [PATCH 59/62] Modify the Clone method to return a shallow copy --- json.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/json.go b/json.go index 44b09d4754..781cc46869 100644 --- a/json.go +++ b/json.go @@ -317,12 +317,7 @@ func (cmd *IntPointerSliceCmd) Clone() Cmder { var val []*int64 if cmd.val != nil { val = make([]*int64, len(cmd.val)) - for i, ptr := range cmd.val { - if ptr != nil { - newVal := *ptr - val[i] = &newVal - } - } + copy(val, cmd.val) } return &IntPointerSliceCmd{ baseCmd: cmd.cloneBaseCmd(), From ceeb2c80411c748858a3dfe77e5dde0f66403a6b Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 9 Nov 2025 10:16:42 +0200 Subject: [PATCH 60/62] Add clone method to digest command --- command.go | 7 +++++++ example/digest-optimistic-locking/go.mod | 1 + example/digest-optimistic-locking/go.sum | 5 +++++ 3 files changed, 13 insertions(+) diff --git a/command.go b/command.go index c2b7d95870..53cd03d0a0 100644 --- a/command.go +++ b/command.go @@ -926,6 +926,13 @@ func (cmd *DigestCmd) String() string { return cmdString(cmd, cmd.val) } +func (cmd *DigestCmd) Clone() Cmder { + return &DigestCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + func (cmd *DigestCmd) readReply(rd *proto.Reader) (err error) { // Redis DIGEST command returns a hex string (e.g., "a1b2c3d4e5f67890") // We 
parse it as a uint64 xxh3 hash value diff --git a/example/digest-optimistic-locking/go.mod b/example/digest-optimistic-locking/go.mod index d27d92020a..9dcf4d5bce 100644 --- a/example/digest-optimistic-locking/go.mod +++ b/example/digest-optimistic-locking/go.mod @@ -13,4 +13,5 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/example/digest-optimistic-locking/go.sum b/example/digest-optimistic-locking/go.sum index 1efe9a309b..d4bb2c2467 100644 --- a/example/digest-optimistic-locking/go.sum +++ b/example/digest-optimistic-locking/go.sum @@ -2,10 +2,15 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= From 93226d0122c1264eab7c5328b42acd641b0c48bf Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Sun, 9 Nov 2025 16:51:43 +0200 Subject: [PATCH 61/62] Optimize keyless command routing to respect ShardPicker policy --- osscluster.go | 10 ++++++++++ osscluster_router.go | 26 +------------------------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/osscluster.go b/osscluster.go index 525bf8a490..db54c66036 100644 --- a/osscluster.go +++ b/osscluster.go @@ -2098,6 +2098,16 @@ func (c *ClusterClient) cmdNode( return nil, err } + // For keyless commands (slot == -1), use ShardPicker to select a shard + // This respects the user's configured ShardPicker policy + if slot == -1 { + if len(state.Masters) == 0 { + return nil, errClusterNoNodes + } + idx := c.opt.ShardPicker.Next(len(state.Masters)) + return state.Masters[idx], nil + } + if c.opt.ReadOnly { cmdInfo := c.cmdInfo(ctx, cmdName) if cmdInfo != nil && cmdInfo.ReadOnly { diff --git a/osscluster_router.go b/osscluster_router.go index 90a3f00a9e..a7c05f9fc4 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -54,25 +54,12 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste // executeDefault handles standard command routing based on keys func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy, node *clusterNode) error { - if 
c.hasKeys(cmd) { - // execute on key based shard - return node.Client.Process(ctx, cmd) - } - if policy != nil { + if policy != nil && !c.hasKeys(cmd) { if c.readOnlyEnabled() && policy.IsReadOnly() { return c.executeOnArbitraryNode(ctx, cmd) } } - return c.executeOnArbitraryShard(ctx, cmd) -} - -// executeOnArbitraryShard routes command to an arbitrary shard -func (c *ClusterClient) executeOnArbitraryShard(ctx context.Context, cmd Cmder) error { - node := c.pickArbitraryShard(ctx) - if node == nil { - return errClusterNoNodes - } return node.Client.Process(ctx, cmd) } @@ -492,17 +479,6 @@ func (c *ClusterClient) finishAggregation(cmd Cmder, aggregator routing.Response return c.setCommandValue(cmd, finalValue) } -// pickArbitraryShard selects a master shard using the configured ShardPicker -func (c *ClusterClient) pickArbitraryShard(ctx context.Context) *clusterNode { - state, err := c.state.Get(ctx) - if err != nil || len(state.Masters) == 0 { - return nil - } - - idx := c.opt.ShardPicker.Next(len(state.Masters)) - return state.Masters[idx] -} - // pickArbitraryNode selects a master or slave shard using the configured ShardPicker func (c *ClusterClient) pickArbitraryNode(ctx context.Context) *clusterNode { state, err := c.state.Get(ctx) From aea0f1886d1c57f4d2d806ad5bdedbc1640fd3e3 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Mon, 10 Nov 2025 11:40:30 +0200 Subject: [PATCH 62/62] Remove MGET references --- osscluster.go | 14 ++--- osscluster_router.go | 39 ++---------- osscluster_test.go | 142 +++++++++++++++++++++---------------------- 3 files changed, 82 insertions(+), 113 deletions(-) diff --git a/osscluster.go b/osscluster.go index db54c66036..bab587efc6 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1676,15 +1676,15 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err func (c *ClusterClient) slottedKeyedCommands(ctx context.Context, cmds []Cmder) map[int][]Cmder { cmdsSlots := map[int][]Cmder{} - preferredRandomSlot := -1 + prefferedRandomSlot := -1 for _, cmd := range cmds { if cmdFirstKeyPos(cmd) == 0 { continue } - slot := c.cmdSlot(cmd, preferredRandomSlot) - if preferredRandomSlot == -1 { - preferredRandomSlot = slot + slot := c.cmdSlot(cmd, prefferedRandomSlot) + if prefferedRandomSlot == -1 { + prefferedRandomSlot = slot } cmdsSlots[slot] = append(cmdsSlots[slot], cmd) @@ -2077,10 +2077,10 @@ func (c *ClusterClient) cmdSlot(cmd Cmder, prefferedSlot int) int { return cmdSlot(cmd, cmdFirstKeyPos(cmd), prefferedSlot) } -func cmdSlot(cmd Cmder, pos int, preferredRandomSlot int) int { +func cmdSlot(cmd Cmder, pos int, prefferedRandomSlot int) int { if pos == 0 { - if preferredRandomSlot != -1 { - return preferredRandomSlot + if prefferedRandomSlot != -1 { + return prefferedRandomSlot } return hashtag.RandomSlot() } diff --git a/osscluster_router.go b/osscluster_router.go index a7c05f9fc4..a3d606b5d7 100644 --- a/osscluster_router.go +++ b/osscluster_router.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "reflect" - "strings" "sync" "time" @@ -352,34 +351,10 @@ func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder firstErr = result.err } if result.cmd != nil && result.err == nil { - // For MGET, extract individual values from the array result - if strings.ToLower(cmd.Name()) == "mget" { - if sliceCmd, ok := result.cmd.(*SliceCmd); ok { - values := sliceCmd.Val() - err := sliceCmd.Err() - if len(values) == len(result.keys) { - for i, key := range result.keys { - keyedResults[key] = routing.AggregatorResErr{Result: 
values[i], Err: err} - } - } else { - // Fallback: map all keys to the entire result - for _, key := range result.keys { - keyedResults[key] = routing.AggregatorResErr{Result: values, Err: err} - } - } - } else { - // Fallback for non-SliceCmd results - value, err := ExtractCommandValue(result.cmd) - for _, key := range result.keys { - keyedResults[key] = routing.AggregatorResErr{Result: value, Err: err} - } - } - } else { - // For other commands, map each key to the entire result - value, err := ExtractCommandValue(result.cmd) - for _, key := range result.keys { - keyedResults[key] = routing.AggregatorResErr{Result: value, Err: err} - } + // For other commands, map each key to the entire result + value, err := ExtractCommandValue(result.cmd) + for _, key := range result.keys { + keyedResults[key] = routing.AggregatorResErr{Result: value, Err: err} } } } @@ -450,12 +425,6 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout // createAggregator creates the appropriate response aggregator func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator { - cmdName := strings.ToLower(cmd.Name()) - // For MGET without policy, use keyed aggregator - if cmdName == "mget" { - return routing.NewDefaultAggregator(true) - } - if policy != nil { return routing.NewResponseAggregator(policy.Response, cmd.Name()) } diff --git a/osscluster_test.go b/osscluster_test.go index 0e7dd4d01d..d3ece5f001 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -2262,79 +2262,79 @@ var _ = Describe("Command Tips tests", func() { Expect(err).NotTo(HaveOccurred()) Expect(len(masterNodes)).To(BeNumerically(">", 1)) - // MGET command aggregation across multiple keys on different shards - verify all_succeeded policy with keyed aggregation - testData := map[string]string{ - "mget_test_key_1111": "value1", - "mget_test_key_2222": "value2", - "mget_test_key_3333": "value3", - "mget_test_key_4444": "value4", - "mget_test_key_5555": "value5", - } - - keyLocations := make(map[string]string) - for key, value := range testData { - - result := client.Set(ctx, key, value, 0) - Expect(result.Err()).NotTo(HaveOccurred()) - - for _, node := range masterNodes { - getResult := node.client.Get(ctx, key) - if getResult.Err() == nil && getResult.Val() == value { - keyLocations[key] = node.addr - break - } - } - } - - shardsUsed := make(map[string]bool) - for _, shardAddr := range keyLocations { - shardsUsed[shardAddr] = true - } - Expect(len(shardsUsed)).To(BeNumerically(">", 1)) - - keys := make([]string, 0, len(testData)) - expectedValues := make([]interface{}, 0, len(testData)) - - for key, value := range testData { - keys = append(keys, key) - expectedValues = append(expectedValues, value) - } - - mgetResult := client.MGet(ctx, keys...) 
- Expect(mgetResult.Err()).NotTo(HaveOccurred()) - - actualValues := mgetResult.Val() - Expect(len(actualValues)).To(Equal(len(keys))) - Expect(actualValues).To(ConsistOf(expectedValues)) - - // Verify all values are correctly aggregated - for i, key := range keys { - expectedValue := testData[key] - actualValue := actualValues[i] - Expect(actualValue).To(Equal(expectedValue)) - } + // // MGET command aggregation across multiple keys on different shards - verify all_succeeded policy with keyed aggregation + // testData := map[string]string{ + // "mget_test_key_1111": "value1", + // "mget_test_key_2222": "value2", + // "mget_test_key_3333": "value3", + // "mget_test_key_4444": "value4", + // "mget_test_key_5555": "value5", + // } + + // keyLocations := make(map[string]string) + // for key, value := range testData { + + // result := client.Set(ctx, key, value, 0) + // Expect(result.Err()).NotTo(HaveOccurred()) + + // for _, node := range masterNodes { + // getResult := node.client.Get(ctx, key) + // if getResult.Err() == nil && getResult.Val() == value { + // keyLocations[key] = node.addr + // break + // } + // } + // } + + // shardsUsed := make(map[string]bool) + // for _, shardAddr := range keyLocations { + // shardsUsed[shardAddr] = true + // } + // Expect(len(shardsUsed)).To(BeNumerically(">", 1)) + + // keys := make([]string, 0, len(testData)) + // expectedValues := make([]interface{}, 0, len(testData)) + + // for key, value := range testData { + // keys = append(keys, key) + // expectedValues = append(expectedValues, value) + // } + + // mgetResult := client.MGet(ctx, keys...) + // Expect(mgetResult.Err()).NotTo(HaveOccurred()) + + // actualValues := mgetResult.Val() + // Expect(len(actualValues)).To(Equal(len(keys))) + // Expect(actualValues).To(ConsistOf(expectedValues)) + + // // Verify all values are correctly aggregated + // for i, key := range keys { + // expectedValue := testData[key] + // actualValue := actualValues[i] + // Expect(actualValue).To(Equal(expectedValue)) + // } // DEL command aggregation across multiple keys on different shards - delResult := client.Del(ctx, keys...) - Expect(delResult.Err()).NotTo(HaveOccurred()) - - deletedCount := delResult.Val() - Expect(deletedCount).To(Equal(int64(len(keys)))) - - // Verify keys are actually deleted from their respective shards - for key, shardAddr := range keyLocations { - var targetNode *masterNode - for i := range masterNodes { - if masterNodes[i].addr == shardAddr { - targetNode = &masterNodes[i] - break - } - } - Expect(targetNode).NotTo(BeNil()) - - getResult := targetNode.client.Get(ctx, key) - Expect(getResult.Err()).To(HaveOccurred()) - } + // delResult := client.Del(ctx, keys...) + // Expect(delResult.Err()).NotTo(HaveOccurred()) + + // deletedCount := delResult.Val() + // Expect(deletedCount).To(Equal(int64(len(keys)))) + + // // Verify keys are actually deleted from their respective shards + // for key, shardAddr := range keyLocations { + // var targetNode *masterNode + // for i := range masterNodes { + // if masterNodes[i].addr == shardAddr { + // targetNode = &masterNodes[i] + // break + // } + // } + // Expect(targetNode).NotTo(BeNil()) + + // getResult := targetNode.client.Get(ctx, key) + // Expect(getResult.Err()).To(HaveOccurred()) + // } // EXISTS command aggregation across multiple keys existsTestData := map[string]string{