99 changes: 86 additions & 13 deletions src/gitingest/output_formatter.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import ssl
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

import requests.exceptions
import tiktoken
@@ -23,6 +23,66 @@
(1_000, "k"),
]

# Cache the tiktoken encoding for performance
_TIKTOKEN_ENCODING: Any | None = None

def _get_tiktoken_encoding() -> Any:
"""Get cached tiktoken encoding, initializing only once."""
global _TIKTOKEN_ENCODING
if _TIKTOKEN_ENCODING is None:
_TIKTOKEN_ENCODING = tiktoken.get_encoding("o200k_base")
return _TIKTOKEN_ENCODING
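
Reviewer note (not part of the diff): `tiktoken.get_encoding` builds the BPE ranks on first use, and may download them, so memoizing the encoder is worth doing. The module-level global above works; an equivalent sketch using `functools.lru_cache` would avoid the `global` statement — same assumed `o200k_base` encoding, purely illustrative:

```python
from functools import lru_cache

import tiktoken


@lru_cache(maxsize=1)
def _get_tiktoken_encoding() -> "tiktoken.Encoding":
    # Built on the first call, then served from the cache thereafter.
    return tiktoken.get_encoding("o200k_base")
```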


def _estimate_tokens(text: str) -> int:
"""Estimate token count for a given text.

Parameters
----------
text : str
The text string for which the token count is to be estimated.

Returns
-------
int
The number of tokens, or 0 if an error occurs.

"""
if not text:
return 0
try:
encoding = _get_tiktoken_encoding()
return len(encoding.encode(text, disallowed_special=()))
except (ValueError, UnicodeEncodeError) as exc:
logger.warning("Failed to estimate token size", extra={"error": str(exc)})
return 0
except (requests.exceptions.RequestException, ssl.SSLError) as exc:
# If network errors occur, skip token count estimation instead of erroring out
logger.warning("Failed to download tiktoken model", extra={"error": str(exc)})
return 0
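
To make the fallback contract concrete: empty input and any handled failure both yield `0`, which callers read as "no estimate". A quick illustrative check (the exact count depends on the BPE; nothing here is asserted from the PR):

```python
assert _estimate_tokens("") == 0  # empty input short-circuits before encoding

# A normal snippet yields a small positive count; 0 would mean the
# encoding could not be loaded (e.g., no network on first use).
print(_estimate_tokens("def main():\n    pass\n"))
```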


def _format_token_number(count: int) -> str:
"""Return a human-readable token-count string (e.g. 1.2k, 1.2M).

Parameters
----------
count : int
The token count to format.

Returns
-------
str
The formatted number of tokens as a string (e.g., ``"1.2k"``, ``"1.2M"``), or an empty string if ``count`` is 0.

"""
if count == 0:
return ""
for threshold, suffix in _TOKEN_THRESHOLDS:
if count >= threshold:
return f"{count / threshold:.1f}{suffix}"
return str(count)
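
Only the `(1_000, "k")` entry of `_TOKEN_THRESHOLDS` is visible in this hunk; assuming the list also contains `(1_000_000, "M")` ahead of it in descending order, as the ``"1.2M"`` docstring example implies, the function behaves like this sketch:

```python
assert _format_token_number(0) == ""          # zero suppresses the annotation
assert _format_token_number(950) == "950"     # below 1k: plain integer string
assert _format_token_number(1_234) == "1.2k"
assert _format_token_number(2_500_000) == "2.5M"
```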


def format_node(node: FileSystemNode, query: IngestionQuery) -> tuple[str, str, str]:
"""Generate a summary, directory structure, and file contents for a given file system node.
@@ -51,9 +111,17 @@ def format_node(node: FileSystemNode, query: IngestionQuery) -> tuple[str, str,
summary += f"File: {node.name}\n"
summary += f"Lines: {len(node.content.splitlines()):,}\n"

-    tree = "Directory structure:\n" + _create_tree_structure(query, node=node)
-
-    content = _gather_file_contents(node)
+    # Gather contents first so per-node token counts are populated for the tree
+    content = _gather_file_contents(node)
+
+    tree = "Directory structure:\n" + _create_tree_structure(query, node=node)
+
+    # calculate total tokens for entire digest (tree + content) - what users download/copy
+    total_tokens = _estimate_tokens(tree + content)

+    # set root node token count to match the total exactly
+    node.token_count = total_tokens

token_estimate = _format_token_count(tree + content)
if token_estimate:
@@ -107,6 +175,7 @@ def _gather_file_contents(node: FileSystemNode) -> str:

This function recursively processes a directory node and gathers the contents of all files
under that node. It returns the concatenated content of all files as a single string.
It also estimates per-file token counts and aggregates them into directory totals during the traversal.

Parameters
----------
@@ -120,10 +189,17 @@ def _gather_file_contents(node: FileSystemNode) -> str:

"""
    if node.type != FileSystemNodeType.DIRECTORY:
+        node.token_count = _estimate_tokens(node.content)
        return node.content_string

-    # Recursively gather contents of all files under the current directory
-    return "\n".join(_gather_file_contents(child) for child in node.children)
+    # Recursively gather contents and aggregate token counts
+    node.token_count = 0
+    contents = []
+    for child in node.children:
+        contents.append(_gather_file_contents(child))
+        node.token_count += child.token_count
+
+    return "\n".join(contents)
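
The aggregation is bottom-up: a directory's `token_count` becomes the sum of its descendants' counts in the same pass that concatenates their contents. A toy model of that invariant (simplified stand-in, not the PR's `FileSystemNode`):

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    tokens: int = 0
    children: list["ToyNode"] = field(default_factory=list)


def aggregate(node: ToyNode) -> int:
    """Leaves keep their own count; directories sum their children."""
    if node.children:
        node.tokens = sum(aggregate(child) for child in node.children)
    return node.tokens


root = ToyNode(children=[ToyNode(tokens=3),
                         ToyNode(children=[ToyNode(tokens=5), ToyNode(tokens=2)])])
assert aggregate(root) == 10  # 3 + (5 + 2)
```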


def _create_tree_structure(
@@ -169,6 +245,10 @@ def _create_tree_structure(
elif node.type == FileSystemNodeType.SYMLINK:
display_name += " -> " + readlink(node.path).name

if node.token_count > 0:
formatted_tokens = _format_token_number(node.token_count)
display_name += f" ({formatted_tokens} tokens)"

tree_str += f"{prefix}{current_prefix}{display_name}\n"

if node.type == FileSystemNodeType.DIRECTORY and node.children:
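
With the annotation in place, a rendered tree would look roughly like this (file names and counts invented for illustration; note that each directory line shows the sum of its children, per `_gather_file_contents`):

```
Directory structure:
└── gitingest/ (3.7k tokens)
    ├── __init__.py (120 tokens)
    ├── output_formatter.py (2.1k tokens)
    └── schemas/ (1.5k tokens)
        └── filesystem.py (1.5k tokens)
```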
@@ -192,15 +272,8 @@ def _format_token_count(text: str) -> str | None:
The formatted number of tokens as a string (e.g., ``"1.2k"``, ``"1.2M"``), or ``None`` if an error occurs.

"""
-    try:
-        encoding = tiktoken.get_encoding("o200k_base")  # gpt-4o, gpt-4o-mini
-        total_tokens = len(encoding.encode(text, disallowed_special=()))
-    except (ValueError, UnicodeEncodeError) as exc:
-        logger.warning("Failed to estimate token size", extra={"error": str(exc)})
-        return None
-    except (requests.exceptions.RequestException, ssl.SSLError) as exc:
-        # If network errors, skip token count estimation instead of erroring out
-        logger.warning("Failed to download tiktoken model", extra={"error": str(exc)})
+    total_tokens = _estimate_tokens(text)
+    if total_tokens == 0:
        return None

for threshold, suffix in _TOKEN_THRESHOLDS:
1 change: 1 addition & 0 deletions src/gitingest/schemas/filesystem.py
@@ -48,6 +48,7 @@ class FileSystemNode:  # pylint: disable=too-many-instance-attributes
file_count: int = 0
dir_count: int = 0
depth: int = 0
token_count: int = 0
children: list[FileSystemNode] = field(default_factory=list)

def sort_children(self) -> None: