diff --git a/pkg/github/projects.go b/pkg/github/projects.go index 37c411ed7..6550f8046 100644 --- a/pkg/github/projects.go +++ b/pkg/github/projects.go @@ -69,13 +69,13 @@ func ListProjects(getClient GetClientFn, t translations.TranslationHelperFunc) ( var resp *github.Response var projects []*github.ProjectV2 - minimalProjects := []MinimalProject{} - var queryPtr *string + if queryStr != "" { queryPtr = &queryStr } + minimalProjects := []MinimalProject{} opts := &github.ListProjectsOptions{ ListProjectsPaginationOptions: github.ListProjectsPaginationOptions{PerPage: &perPage}, Query: queryPtr, @@ -237,27 +237,19 @@ func ListProjectFields(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } - var url string - if ownerType == "org" { - url = fmt.Sprintf("orgs/%s/projectsV2/%d/fields", owner, projectNumber) - } else { - url = fmt.Sprintf("users/%s/projectsV2/%d/fields", owner, projectNumber) - } - projectFields := []projectV2Field{} - - opts := paginationOptions{PerPage: perPage} + var resp *github.Response + var projectFields []*github.ProjectV2Field - url, err = addOptions(url, opts) - if err != nil { - return nil, fmt.Errorf("failed to add options to request: %w", err) + opts := &github.ListProjectsOptions{ + ListProjectsPaginationOptions: github.ListProjectsPaginationOptions{PerPage: &perPage}, } - httpRequest, err := client.NewRequest("GET", url, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + if ownerType == "org" { + projectFields, resp, err = client.Projects.ListOrganizationProjectFields(ctx, owner, projectNumber, opts) + } else { + projectFields, resp, err = client.Projects.ListUserProjectFields(ctx, owner, projectNumber, opts) } - resp, err := client.Do(ctx, httpRequest, &projectFields) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list project fields", @@ -317,7 +309,7 @@ func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc if err != nil { return mcp.NewToolResultError(err.Error()), nil } - fieldID, err := RequiredInt(req, "field_id") + fieldID, err := RequiredBigInt(req, "field_id") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -326,21 +318,15 @@ func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc return mcp.NewToolResultError(err.Error()), nil } - var url string + var resp *github.Response + var projectField *github.ProjectV2Field + if ownerType == "org" { - url = fmt.Sprintf("orgs/%s/projectsV2/%d/fields/%d", owner, projectNumber, fieldID) + projectField, resp, err = client.Projects.GetOrganizationProjectField(ctx, owner, projectNumber, fieldID) } else { - url = fmt.Sprintf("users/%s/projectsV2/%d/fields/%d", owner, projectNumber, fieldID) - } - - projectField := projectV2Field{} - - httpRequest, err := client.NewRequest("GET", url, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + projectField, resp, err = client.Projects.GetUserProjectField(ctx, owner, projectNumber, fieldID) } - resp, err := client.Do(ctx, httpRequest, &projectField) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get project field", @@ -416,41 +402,37 @@ func ListProjectItems(getClient GetClientFn, t translations.TranslationHelperFun if err != nil { return mcp.NewToolResultError(err.Error()), nil } - fields, err := OptionalStringArrayParam(req, "fields") + fields, err := OptionalBigIntArrayParam(req, "fields") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - var url string - if ownerType == "org" { - url = fmt.Sprintf("orgs/%s/projectsV2/%d/items", owner, projectNumber) - } else { - url = fmt.Sprintf("users/%s/projectsV2/%d/items", owner, projectNumber) - } - projectItems := []projectV2Item{} + var resp *github.Response + var projectItems []*github.ProjectV2Item + var queryPtr *string - opts := listProjectItemsOptions{ - paginationOptions: paginationOptions{PerPage: perPage}, - filterQueryOptions: filterQueryOptions{Query: queryStr}, - fieldSelectionOptions: fieldSelectionOptions{Fields: fields}, + if queryStr != "" { + queryPtr = &queryStr } - url, err = addOptions(url, opts) - if err != nil { - return nil, fmt.Errorf("failed to add options to request: %w", err) + opts := &github.ListProjectItemsOptions{ + Fields: fields, + ListProjectsOptions: github.ListProjectsOptions{ + ListProjectsPaginationOptions: github.ListProjectsPaginationOptions{PerPage: &perPage}, + Query: queryPtr, + }, } - httpRequest, err := client.NewRequest("GET", url, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + if ownerType == "org" { + projectItems, resp, err = client.Projects.ListOrganizationProjectItems(ctx, owner, projectNumber, opts) + } else { + projectItems, resp, err = client.Projects.ListUserProjectItems(ctx, owner, projectNumber, opts) } - resp, err := client.Do(ctx, httpRequest, &projectItems) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, ProjectListFailedError, @@ -518,11 +500,11 @@ func GetProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - itemID, err := RequiredInt(req, "item_id") + itemID, err := RequiredBigInt(req, "item_id") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - fields, err := OptionalStringArrayParam(req, "fields") + fields, err := OptionalBigIntArrayParam(req, "fields") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -624,7 +606,7 @@ func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - itemID, err := RequiredInt(req, "item_id") + itemID, err := RequiredBigInt(req, "item_id") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -642,24 +624,20 @@ func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) return mcp.NewToolResultError(err.Error()), nil } - var projectsURL string - if ownerType == "org" { - projectsURL = fmt.Sprintf("orgs/%s/projectsV2/%d/items", owner, projectNumber) - } else { - projectsURL = fmt.Sprintf("users/%s/projectsV2/%d/items", owner, projectNumber) - } - - newItem := &newProjectItem{ - ID: int64(itemID), + newItem := &github.AddProjectItemOptions{ + ID: itemID, Type: toNewProjectType(itemType), } - httpRequest, err := client.NewRequest("POST", projectsURL, newItem) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + + var resp *github.Response + var addedItem *github.ProjectV2Item + + if ownerType == "org" { + addedItem, resp, err = client.Projects.AddOrganizationProjectItem(ctx, owner, projectNumber, newItem) + } else { + addedItem, resp, err = client.Projects.AddUserProjectItem(ctx, owner, projectNumber, newItem) } - addedItem := projectV2Item{} - resp, err := client.Do(ctx, httpRequest, &addedItem) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, ProjectAddFailedError, @@ -869,9 +847,10 @@ func DeleteProjectItem(getClient GetClientFn, t translations.TranslationHelperFu } } -type newProjectItem struct { - ID int64 `json:"id,omitempty"` - Type string `json:"type,omitempty"` +type fieldSelectionOptions struct { + // Specific list of field IDs to include in the response. If not provided, only the title field is included. + // Example: fields=102589,985201,169875 or fields[]=102589&fields[]=985201&fields[]=169875 + Fields []int64 `url:"fields,omitempty,comma"` } type updateProjectItemPayload struct { @@ -883,17 +862,6 @@ type updateProjectItem struct { Value any `json:"value"` } -type projectV2Field struct { - ID *int64 `json:"id,omitempty"` // The unique identifier for this field. - NodeID string `json:"node_id,omitempty"` // The GraphQL node ID for this field. - Name string `json:"name,omitempty"` // The display name of the field. - DataType string `json:"data_type,omitempty"` // The data type of the field (e.g., "text", "number", "date", "single_select", "multi_select"). - URL string `json:"url,omitempty"` // The API URL for this field. - Options []*any `json:"options,omitempty"` // Available options for single_select and multi_select fields. - CreatedAt *github.Timestamp `json:"created_at,omitempty"` // The time when this field was created. - UpdatedAt *github.Timestamp `json:"updated_at,omitempty"` // The time when this field was last updated. -} - type projectV2ItemFieldValue struct { ID *int64 `json:"id,omitempty"` // The unique identifier for this field. Name string `json:"name,omitempty"` // The display name of the field. @@ -931,26 +899,6 @@ type projectV2ItemContent struct { URL *string `json:"url,omitempty"` } -type paginationOptions struct { - PerPage int `url:"per_page,omitempty"` -} - -type filterQueryOptions struct { - Query string `url:"q,omitempty"` -} - -type fieldSelectionOptions struct { - // Specific list of field IDs to include in the response. If not provided, only the title field is included. - // Example: fields=102589,985201,169875 or fields[]=102589&fields[]=985201&fields[]=169875 - Fields []string `url:"fields,omitempty"` -} - -type listProjectItemsOptions struct { - paginationOptions - filterQueryOptions - fieldSelectionOptions -} - func toNewProjectType(projType string) string { switch strings.ToLower(projType) { case "issue": @@ -994,18 +942,28 @@ func addOptions(s string, opts any) (string, error) { return s, nil } - u, err := url.Parse(s) + origURL, err := url.Parse(s) if err != nil { return s, err } - qs, err := query.Values(opts) + origValues := origURL.Query() + + // Use the github.com/google/go-querystring library to parse the struct + newValues, err := query.Values(opts) if err != nil { return s, err } - u.RawQuery = qs.Encode() - return u.String(), nil + // Merge the values + for key, values := range newValues { + for _, value := range values { + origValues.Add(key, value) + } + } + + origURL.RawQuery = origValues.Encode() + return origURL.String(), nil } func ManageProjectItemsPrompt(t translations.TranslationHelperFunc) (tool mcp.Prompt, handler server.PromptHandlerFunc) { diff --git a/pkg/github/projects_test.go b/pkg/github/projects_test.go index 30a465ff4..ed198a97a 100644 --- a/pkg/github/projects_test.go +++ b/pkg/github/projects_test.go @@ -653,8 +653,8 @@ func Test_ListProjectItems(t *testing.T) { mock.EndpointPattern{Pattern: "/orgs/{org}/projectsV2/{project}/items", Method: http.MethodGet}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() - fieldParams := q["fields"] - if len(fieldParams) == 3 && fieldParams[0] == "123" && fieldParams[1] == "456" && fieldParams[2] == "789" { + fieldParams := q.Get("fields") + if fieldParams == "123,456,789" { w.WriteHeader(http.StatusOK) _, _ = w.Write(mock.MustMarshal(orgItems)) return @@ -852,8 +852,8 @@ func Test_GetProjectItem(t *testing.T) { mock.EndpointPattern{Pattern: "/orgs/{org}/projectsV2/{project}/items/{item_id}", Method: http.MethodGet}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() - fieldParams := q["fields"] - if len(fieldParams) == 2 && fieldParams[0] == "123" && fieldParams[1] == "456" { + fieldParams := q.Get("fields") + if fieldParams == "123,456" { w.WriteHeader(http.StatusOK) _, _ = w.Write(mock.MustMarshal(orgItem)) return diff --git a/pkg/github/server.go b/pkg/github/server.go index adff7359e..4db24504c 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "strconv" "github.com/google/go-github/v77/github" "github.com/mark3labs/mcp-go/mcp" @@ -99,6 +100,19 @@ func RequiredInt(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } +// RequiredBigInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredBigInt(r mcp.CallToolRequest, p string) (int64, error) { + v, err := RequiredParam[float64](r, p) + if err != nil { + return 0, err + } + return int64(v), nil +} + // OptionalParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value @@ -189,6 +203,52 @@ func OptionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) } } +func convertStringSliceToBigIntSlice(s []string) []int64 { + int64Slice := make([]int64, len(s)) + for i, str := range s { + int64Slice[i] = convertStringToBigInt(str, 0) + } + return int64Slice +} + +func convertStringToBigInt(s string, def int64) int64 { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return def + } + return v +} + +// OptionalBigIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, iterates the elements and checks each is a string +func OptionalBigIntArrayParam(r mcp.CallToolRequest, p string) ([]int64, error) { + // Check if the parameter is present in the request + if _, ok := r.GetArguments()[p]; !ok { + return []int64{}, nil + } + + switch v := r.GetArguments()[p].(type) { + case nil: + return []int64{}, nil + case []string: + return convertStringSliceToBigIntSlice(v), nil + case []any: + int64Slice := make([]int64, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []int64{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + int64Slice[i] = convertStringToBigInt(s, 0) + } + return int64Slice, nil + default: + return []int64{}, fmt.Errorf("parameter %s could not be coerced to []int64, is %T", p, r.GetArguments()[p]) + } +} + // WithPagination adds REST API pagination parameters to a tool. // https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api func WithPagination() mcp.ToolOption {