diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 38d66fce19..cc6c9ae914 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -153,7 +153,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(validate.In(c.catalog, raw)); err != nil { return nil, err } - rvs := rangeVars(raw.Stmt) + scopedRVs := rangeVarsWithScope(raw.Stmt) refs, errs := findParameters(raw.Stmt) if len(errs) > 0 { if failfast { @@ -173,7 +173,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + params, err := c.resolveCatalogRefs(qc, scopedRVs, refs, namedParams, embeds) if err := check(err); err != nil { return nil, err } diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 8199addd33..d570fdd709 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -24,6 +24,8 @@ type paramRef struct { rv *ast.RangeVar ref *ast.ParamRef name string // Named parameter support + + cteName *string // Current CTE name, nil if not inside a CTE. } type paramSearch struct { @@ -36,6 +38,8 @@ type paramSearch struct { // XXX: Gross state hack for limit limitCount ast.Node limitOffset ast.Node + + cteName *string // Current CTE name, nil if not inside a CTE. } type limitCount struct { @@ -55,6 +59,10 @@ func (l *limitOffset) Pos() int { func (p paramSearch) Visit(node ast.Node) astutils.Visitor { switch n := node.(type) { + case *ast.CommonTableExpr: + p.cteName = n.Ctename + return p + case *ast.A_Expr: p.parent = node @@ -87,7 +95,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, cteName: p.cteName}) p.seen[ref.Location] = struct{}{} } for _, item := range s.ValuesLists.Items { @@ -104,7 +112,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, cteName: p.cteName}) p.seen[ref.Location] = struct{}{} } } @@ -125,7 +133,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv}) + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, cteName: p.cteName}) } p.seen[ref.Location] = struct{}{} } @@ -186,7 +194,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } if set { - *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) + *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar, cteName: p.cteName}) p.seen[n.Location] = struct{}{} } return nil diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index b0a15e6ac4..d58bf8a515 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -22,6 +22,10 @@ func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) { return nil, err } + return convertColumnsToCatalog(cols), nil +} + +func convertColumnsToCatalog(cols []*Column) []*catalog.Column { catCols := make([]*catalog.Column, 0, len(cols)) for _, col := range cols { catCols = append(catCols, &catalog.Column{ @@ -35,7 +39,7 @@ func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) { Length: col.Length, }) } - return catCols, nil + return catCols } func hasStarRef(cf *ast.ColumnRef) bool { diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 681d291122..e4ae525aa7 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -132,16 +132,36 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, }, nil } -func rangeVars(root ast.Node) []*ast.RangeVar { - var vars []*ast.RangeVar - find := astutils.VisitorFunc(func(node ast.Node) { - switch n := node.(type) { - case *ast.RangeVar: - vars = append(vars, n) - } - }) - astutils.Walk(find, root) - return vars +// scopedRangeVar associates a RangeVar with a scope. +type scopedRangeVar struct { + rv *ast.RangeVar + + cteName *string // Current CTE name, nil if not inside a CTE. +} + +// rangeVarsWithScope collects all RangeVars with their scope. +func rangeVarsWithScope(root ast.Node) []scopedRangeVar { + var rvs []scopedRangeVar + visitor := &rvSearch{rvs: &rvs, cteName: nil} + astutils.Walk(visitor, root) + return rvs +} + +// rvSearch finds all RangeVars and tracks their scope. +type rvSearch struct { + rvs *[]scopedRangeVar + + cteName *string // Current CTE name, nil if not inside a CTE. +} + +func (v *rvSearch) Visit(node ast.Node) astutils.Visitor { + switch n := node.(type) { + case *ast.CommonTableExpr: + return &rvSearch{rvs: v.rvs, cteName: n.Ctename} + case *ast.RangeVar: + *v.rvs = append(*v.rvs, scopedRangeVar{rv: n, cteName: v.cteName}) + } + return v } func uniqueParamRefs(in []paramRef, dollar bool) []paramRef { diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b1fbb1990e..ea602a9fa7 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -21,36 +21,50 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, scopedRVs []scopedRangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { c := comp.catalog + scopeMap := make(map[*string][]*ast.TableName) + outerAliasMap := map[string]*ast.TableName{} aliasMap := map[string]*ast.TableName{} + tableNameMap := map[string]*ast.TableName{} // TODO: Deprecate defaultTable var defaultTable *ast.TableName var tables []*ast.TableName - typeMap := map[string]map[string]map[string]*catalog.Column{} - indexTable := func(table catalog.Table) error { - tables = append(tables, table.Rel) + + indexTableWithColumns := func(rel *ast.TableName, cols []*catalog.Column) error { + tables = append(tables, rel) + tableNameMap[rel.Name] = rel if defaultTable == nil { - defaultTable = table.Rel + defaultTable = rel } - schema := table.Rel.Schema + schema := rel.Schema if schema == "" { schema = c.DefaultSchema } if _, exists := typeMap[schema]; !exists { typeMap[schema] = map[string]map[string]*catalog.Column{} } - typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{} - for _, c := range table.Columns { - cc := c - typeMap[schema][table.Rel.Name][c.Name] = cc + typeMap[schema][rel.Name] = map[string]*catalog.Column{} + for _, c := range cols { + typeMap[schema][rel.Name][c.Name] = c } return nil } - for _, rv := range rvs { + indexTable := func(table catalog.Table) error { + return indexTableWithColumns(table.Rel, table.Columns) + } + + indexCTE := func(cte *Table) error { + catalogCols := convertColumnsToCatalog(cte.Columns) + return indexTableWithColumns(cte.Rel, catalogCols) + } + + for _, srv := range scopedRVs { + rv := srv.rv + scope := srv.cteName if rv.Relname == nil { continue } @@ -58,6 +72,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if err != nil { return nil, err } + + scopeMap[scope] = append(scopeMap[scope], fqn) + if scope == nil && rv.Alias != nil { + outerAliasMap[*rv.Alias.Aliasname] = fqn + } + if _, found := aliasMap[fqn.Name]; found { continue } @@ -67,9 +87,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } // If the table name doesn't exist, first check if it's a CTE - if _, qcerr := qc.GetTable(fqn); qcerr != nil { + cte, qcerr := qc.GetTable(fqn) + if qcerr != nil { return nil, err } + if err := indexCTE(cte); err != nil { + return nil, err + } + if rv.Alias != nil { + aliasMap[*rv.Alias.Aliasname] = fqn + } continue } err = indexTable(table) @@ -89,7 +116,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } - if alias, ok := aliasMap[embed.Table.Name]; ok { + if alias, ok := outerAliasMap[embed.Table.Name]; ok { embed.Table = alias continue } @@ -195,24 +222,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, panic("too many field items: " + strconv.Itoa(len(items))) } - search := tables + search := scopeMap[ref.cteName] if alias != "" { if original, ok := aliasMap[alias]; ok { search = []*ast.TableName{original} + } else if tableName, ok := tableNameMap[alias]; ok { + search = []*ast.TableName{tableName} } else { - var located bool - for _, fqn := range tables { - if fqn.Name == alias { - located = true - search = []*ast.TableName{fqn} - } - } - if !located { - return nil, &sqlerr.Error{ - Code: "42703", - Message: fmt.Sprintf("table alias %q does not exist", alias), - Location: node.Location, - } + return nil, &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("table alias %q does not exist", alias), + Location: node.Location, } } } @@ -573,12 +593,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if alias != "" { if original, ok := aliasMap[alias]; ok { search = []*ast.TableName{original} - } else { - for _, fqn := range tables { - if fqn.Name == alias { - search = []*ast.TableName{fqn} - } - } + } else if tableName, ok := tableNameMap[alias]; ok { + search = []*ast.TableName{tableName} } } diff --git a/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..36d0da90d3 --- /dev/null +++ b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/models.go @@ -0,0 +1,22 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "github.com/google/uuid" +) + +type Task struct { + ID uuid.UUID + WorkspaceID uuid.NullUUID + OwnerID uuid.UUID + Name string +} + +type Workspace struct { + ID uuid.UUID + OwnerID uuid.UUID + Name string +} diff --git a/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..7fcd84e086 --- /dev/null +++ b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,155 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" + + "github.com/google/uuid" +) + +const getWorkspacesJoinTasks = `-- name: GetWorkspacesJoinTasks :many +WITH wtask AS ( + SELECT + workspaces.id, workspaces.owner_id, workspaces.name, + tasks.id IS NOT NULL::boolean AS has_task + FROM workspaces + LEFT JOIN tasks ON tasks.workspace_id = workspaces.id +) +SELECT id, owner_id, name, has_task +FROM wtask +ORDER BY CASE WHEN owner_id = $1 THEN 0 ELSE 1 END +` + +type GetWorkspacesJoinTasksRow struct { + ID uuid.UUID + OwnerID uuid.UUID + Name string + HasTask bool +} + +func (q *Queries) GetWorkspacesJoinTasks(ctx context.Context, ownerID uuid.UUID) ([]GetWorkspacesJoinTasksRow, error) { + rows, err := q.db.QueryContext(ctx, getWorkspacesJoinTasks, ownerID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetWorkspacesJoinTasksRow + for rows.Next() { + var i GetWorkspacesJoinTasksRow + if err := rows.Scan( + &i.ID, + &i.OwnerID, + &i.Name, + &i.HasTask, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getWorkspacesJoinTasksRenameColumn = `-- name: GetWorkspacesJoinTasksRenameColumn :many +WITH wtask AS ( + SELECT + workspaces.owner_id AS w_owner_id, + workspaces.id, workspaces.owner_id, workspaces.name, + tasks.id IS NOT NULL::boolean AS has_task + FROM workspaces + LEFT JOIN tasks ON tasks.workspace_id = workspaces.id +) +SELECT w_owner_id, id, owner_id, name, has_task +FROM wtask +ORDER BY CASE WHEN w_owner_id = $1 THEN 0 ELSE 1 END +` + +type GetWorkspacesJoinTasksRenameColumnRow struct { + WOwnerID uuid.UUID + ID uuid.UUID + OwnerID uuid.UUID + Name string + HasTask bool +} + +func (q *Queries) GetWorkspacesJoinTasksRenameColumn(ctx context.Context, ownerID uuid.UUID) ([]GetWorkspacesJoinTasksRenameColumnRow, error) { + rows, err := q.db.QueryContext(ctx, getWorkspacesJoinTasksRenameColumn, ownerID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetWorkspacesJoinTasksRenameColumnRow + for rows.Next() { + var i GetWorkspacesJoinTasksRenameColumnRow + if err := rows.Scan( + &i.WOwnerID, + &i.ID, + &i.OwnerID, + &i.Name, + &i.HasTask, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getWorkspacesSubQueryTasks = `-- name: GetWorkspacesSubQueryTasks :many +WITH wfiltered AS ( + SELECT workspaces.id, workspaces.owner_id, workspaces.name + FROM workspaces + WHERE EXISTS ( + SELECT 1 + FROM tasks + WHERE tasks.workspace_id = workspaces.id + ) +) +SELECT id, owner_id, name +FROM wfiltered +ORDER BY CASE WHEN owner_id = $1 THEN 0 ELSE 1 END +` + +type GetWorkspacesSubQueryTasksRow struct { + ID uuid.UUID + OwnerID uuid.UUID + Name string +} + +func (q *Queries) GetWorkspacesSubQueryTasks(ctx context.Context, ownerID uuid.UUID) ([]GetWorkspacesSubQueryTasksRow, error) { + rows, err := q.db.QueryContext(ctx, getWorkspacesSubQueryTasks, ownerID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetWorkspacesSubQueryTasksRow + for rows.Next() { + var i GetWorkspacesSubQueryTasksRow + if err := rows.Scan(&i.ID, &i.OwnerID, &i.Name); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/query.sql b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..cfb5c43fb5 --- /dev/null +++ b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/query.sql @@ -0,0 +1,38 @@ +-- name: GetWorkspacesJoinTasks :many +WITH wtask AS ( + SELECT + workspaces.*, + tasks.id IS NOT NULL::boolean AS has_task + FROM workspaces + LEFT JOIN tasks ON tasks.workspace_id = workspaces.id +) +SELECT * +FROM wtask +ORDER BY CASE WHEN owner_id = @owner_id THEN 0 ELSE 1 END; + +-- name: GetWorkspacesJoinTasksRenameColumn :many +WITH wtask AS ( + SELECT + workspaces.owner_id AS w_owner_id, + workspaces.*, + tasks.id IS NOT NULL::boolean AS has_task + FROM workspaces + LEFT JOIN tasks ON tasks.workspace_id = workspaces.id +) +SELECT * +FROM wtask +ORDER BY CASE WHEN w_owner_id = @owner_id THEN 0 ELSE 1 END; + +-- name: GetWorkspacesSubQueryTasks :many +WITH wfiltered AS ( + SELECT workspaces.* + FROM workspaces + WHERE EXISTS ( + SELECT 1 + FROM tasks + WHERE tasks.workspace_id = workspaces.id + ) +) +SELECT * +FROM wfiltered +ORDER BY CASE WHEN owner_id = @owner_id THEN 0 ELSE 1 END; diff --git a/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..9c9b0cc44c --- /dev/null +++ b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/schema.sql @@ -0,0 +1,12 @@ +CREATE TABLE workspaces ( + id uuid NOT NULL, + owner_id uuid NOT NULL, + name text NOT NULL +); + +CREATE TABLE tasks ( + id uuid NOT NULL, + workspace_id uuid, + owner_id uuid NOT NULL, + name text NOT NULL +); diff --git a/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/sqlc.yaml b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/sqlc.yaml new file mode 100644 index 0000000000..6f830a0c9a --- /dev/null +++ b/internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/sqlc.yaml @@ -0,0 +1,9 @@ +version: "2" +sql: + - schema: "./schema.sql" + queries: "./" + engine: "postgresql" + gen: + go: + package: "querytest" + out: "./go" diff --git a/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go index 564b33b190..4c087d16df 100644 --- a/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go @@ -7,8 +7,6 @@ package querytest import ( "context" - - "github.com/jackc/pgx/v5/pgtype" ) const badQuery = `-- name: BadQuery :exec @@ -28,7 +26,7 @@ FROM WHERE c1.name = $1 ` -func (q *Queries) BadQuery(ctx context.Context, dollar_1 pgtype.Text) error { - _, err := q.db.Exec(ctx, badQuery, dollar_1) +func (q *Queries) BadQuery(ctx context.Context, name string) error { + _, err := q.db.Exec(ctx, badQuery, name) return err } diff --git a/internal/endtoend/testdata/cte_update/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/cte_update/postgresql/pgx/go/query.sql.go index 61ba601b90..8e8ff6238b 100644 --- a/internal/endtoend/testdata/cte_update/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/cte_update/postgresql/pgx/go/query.sql.go @@ -23,9 +23,9 @@ from updated_attribute ` type UpdateAttributeParams struct { - FilterValue pgtype.Bool - Value pgtype.Text - ID pgtype.Int8 + FilterValue bool + Value string + ID int64 } type UpdateAttributeRow struct {