diff --git a/api/analyzers/analyzer.py b/api/analyzers/analyzer.py
index 73d2661..036857f 100644
--- a/api/analyzers/analyzer.py
+++ b/api/analyzers/analyzer.py
@@ -143,3 +143,32 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
 
         pass
 
+    @abstractmethod
+    def add_file_imports(self, file: File) -> None:
+        """
+        Collect the import statements of the file and record them on it.
+
+        Args:
+            file (File): The file to add imports to.
+        """
+
+        pass
+
+    @abstractmethod
+    def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]:
+        """
+        Resolve an import statement to entities.
+
+        Args:
+            files (dict[Path, File]): All files in the project.
+            lsp (SyncLanguageServer): The language server.
+            file_path (Path): The path to the file containing the import.
+            path (Path): The path to the project root.
+            import_node (Node): The import statement node.
+
+        Returns:
+            list[Entity]: List of resolved entities.
+        """
+
+        pass
+
diff --git a/api/analyzers/java/analyzer.py b/api/analyzers/java/analyzer.py
index 4ae01d5..199ffdd 100644
--- a/api/analyzers/java/analyzer.py
+++ b/api/analyzers/java/analyzer.py
@@ -132,3 +132,19 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
             return self.resolve_method(files, lsp, file_path, path, symbol)
         else:
             raise ValueError(f"Unknown key {key}")
+
+    def add_file_imports(self, file: File) -> None:
+        """
+        No-op placeholder for extracting import statements from the file.
+        Java imports are not yet implemented.
+        """
+        # TODO: Implement Java import tracking
+        pass
+
+    def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]:
+        """
+        Placeholder that resolves every import statement to no entities.
+        Java imports are not yet implemented.
+        """
+        # TODO: Implement Java import resolution
+        return []
diff --git a/api/analyzers/python/analyzer.py b/api/analyzers/python/analyzer.py
index 80bba88..7fd5b36 100644
--- a/api/analyzers/python/analyzer.py
+++ b/api/analyzers/python/analyzer.py
@@ -121,3 +121,92 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
             return self.resolve_method(files, lsp, file_path, path, symbol)
         else:
             raise ValueError(f"Unknown key {key}")
+
+    def add_file_imports(self, file: File) -> None:
+        """
+        Record every import statement node of the file on the File object.
+
+        Supports:
+        - import module
+        - import module as alias
+        - from module import name
+        - from module import name1, name2
+        - from module import name as alias
+
+        Args:
+            file (File): The parsed file whose import nodes are collected.
+        """
+        try:
+            import warnings
+            with warnings.catch_warnings():
+                # Warnings raised while building or running the query are muted.
+                warnings.simplefilter("ignore")
+                import_query = self.language.query("""
+                    (import_statement) @import
+                    (import_from_statement) @import_from
+                """)
+                captures = import_query.captures(file.tree.root_node)
+
+            # One loop for both capture names; a missing name means no such imports.
+            for capture_name in ('import', 'import_from'):
+                for import_node in captures.get(capture_name, []):
+                    file.add_import(import_node)
+        except Exception as e:
+            # Best effort: a failure here is only logged, never raised.
+            logger.debug(f"Failed to extract imports from {file.path}: {e}")
+
+    def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]:
+        """
+        Resolve an import statement to the entities it imports.
+
+        For ``import ...`` every imported name is resolved; for
+        ``from ... import ...`` only the names after the ``import`` keyword
+        are resolved. For ``name as alias`` the real name is resolved.
+
+        Args:
+            files (dict[Path, File]): All files in the project.
+            lsp (SyncLanguageServer): The language server.
+            file_path (Path): The path to the file containing the import.
+            path (Path): The path to the project root.
+            import_node (Node): The import statement node.
+
+        Returns:
+            list[Entity]: Entities resolved so far; resolution errors are
+            logged and end the resolution early.
+        """
+        res = []
+
+        try:
+            children = import_node.children
+            if import_node.type == 'import_statement':
+                candidates = children
+            elif import_node.type == 'import_from_statement':
+                # Imported names start after the 'import' keyword; what precedes
+                # it is the source module, which is not resolved here.
+                keyword_index = next(
+                    (i for i, child in enumerate(children) if child.type == 'import'),
+                    len(children),
+                )
+                candidates = children[keyword_index + 1:]
+            else:
+                candidates = []
+
+            for child in candidates:
+                if child.type == 'aliased_import':
+                    # "name as alias": the first child is the real name.
+                    if child.child_count == 0:
+                        continue
+                    target = child.children[0]
+                elif child.type == 'dotted_name':
+                    target = child
+                else:
+                    continue
+
+                # NOTE(review): a dotted name is resolved through its first component only - confirm intended.
+                if target.type == 'dotted_name' and target.child_count > 0:
+                    target = target.children[0]
+                res.extend(self.resolve_type(files, lsp, file_path, path, target))
+        except Exception as e:
+            logger.debug(f"Failed to resolve import: {e}")
+
+        return res
diff --git a/api/analyzers/source_analyzer.py b/api/analyzers/source_analyzer.py
index 12502ab..9ec3572 100644
--- a/api/analyzers/source_analyzer.py
+++ b/api/analyzers/source_analyzer.py
@@ -112,6 +112,10 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr
             # Walk thought the AST
             graph.add_file(file)
             self.create_hierarchy(file, analyzer, graph)
+
+            # Extract import statements
+            if not analyzer.is_dependency(str(file_path)):
+                analyzer.add_file_imports(file)
 
     def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
         """
@@ -141,6 +145,8 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
             for i, file_path in enumerate(files):
                 file = self.files[file_path]
                 logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}')
+
+                # Resolve entity symbols
                 for _, entity in file.entities.items():
                     entity.resolved_symbol(lambda key, symbol: analyzers[file_path.suffix].resolve_symbol(self.files, lsps[file_path.suffix], file_path, path, key, symbol))
                     for key, symbols in entity.resolved_symbols.items():
@@ -157,6 +163,13 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
                                 graph.connect_entities("RETURNS", entity.id, symbol.id)
                             elif key == "parameters":
                                 graph.connect_entities("PARAMETERS", entity.id, symbol.id)
+
+                # Resolve file imports
+                for import_node in file.imports:
+                    resolved_entities = analyzers[file_path.suffix].resolve_import(self.files, lsps[file_path.suffix], file_path, path, import_node)
+                    for resolved_entity in resolved_entities:
+                        file.add_resolved_import(resolved_entity)
+                        graph.connect_entities("IMPORTS", file.id, resolved_entity.id)
 
     def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None:
         self.first_pass(path, files, [], graph)
diff --git a/api/entities/file.py b/api/entities/file.py
index c59e2b6..a893734 100644
--- a/api/entities/file.py
+++ b/api/entities/file.py
@@ -21,10 +21,30 @@ def __init__(self, path: Path, tree: Tree) -> None:
         self.path = path
         self.tree = tree
         self.entities: dict[Node, Entity] = {}
+        self.imports: list[Node] = []
+        self.resolved_imports: set[Entity] = set()
 
     def add_entity(self, entity: Entity):
         entity.parent = self
         self.entities[entity.node] = entity
+
+    def add_import(self, import_node: Node) -> None:
+        """
+        Add an import statement node to track.
+
+        Args:
+            import_node (Node): The import statement node.
+        """
+        self.imports.append(import_node)
+
+    def add_resolved_import(self, resolved_entity: Entity) -> None:
+        """
+        Add a resolved import entity.
+
+        Args:
+            resolved_entity (Entity): The resolved entity that is imported.
+        """
+        self.resolved_imports.add(resolved_entity)
 
     def __str__(self) -> str:
         return f"path: {self.path}"
diff --git a/tests/source_files/py_imports/module_a.py b/tests/source_files/py_imports/module_a.py
new file mode 100644
index 0000000..b632304
--- /dev/null
+++ b/tests/source_files/py_imports/module_a.py
@@ -0,0 +1,12 @@
+"""Module A with a class definition."""
+
+class ClassA:
+    """A simple class in module A."""
+
+    def method_a(self):
+        """A method in ClassA."""
+        return "Method A"
+
+def function_a():
+    """A function in module A."""
+    return "Function A"
diff --git a/tests/source_files/py_imports/module_b.py b/tests/source_files/py_imports/module_b.py
new file mode 100644
index 0000000..c0c1c30
--- /dev/null
+++ b/tests/source_files/py_imports/module_b.py
@@ -0,0 +1,11 @@
+"""Module B that imports from module A."""
+
+from module_a import ClassA, function_a
+
+class ClassB(ClassA):
+    """A class that extends ClassA."""
+
+    def method_b(self):
+        """A method in ClassB."""
+        result = function_a()
+        return f"Method B: {result}"
diff --git a/tests/test_py_imports.py b/tests/test_py_imports.py
new file mode 100644
index 0000000..8e86603
--- /dev/null
+++ b/tests/test_py_imports.py
@@ -0,0 +1,62 @@
+import os
+import unittest
+
+from api import SourceAnalyzer, Graph
+
+
+class Test_PY_Imports(unittest.TestCase):
+    def test_import_tracking(self):
+        """Test that Python imports are tracked correctly."""
+        # Get test file path
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        test_path = os.path.join(current_dir, 'source_files', 'py_imports')
+
+        # Create graph and analyze
+        g = Graph("py_imports_test")
+        analyzer = SourceAnalyzer()
+
+        try:
+            analyzer.analyze_local_folder(test_path, g)
+
+            # Verify files were created
+            module_a = g.get_file('', 'module_a.py', '.py')
+            self.assertIsNotNone(module_a, "module_a.py should be in the graph")
+
+            module_b = g.get_file('', 'module_b.py', '.py')
+            self.assertIsNotNone(module_b, "module_b.py should be in the graph")
+
+            # Verify classes were created
+            class_a = g.get_class_by_name('ClassA')
+            self.assertIsNotNone(class_a, "ClassA should be in the graph")
+
+            class_b = g.get_class_by_name('ClassB')
+            self.assertIsNotNone(class_b, "ClassB should be in the graph")
+
+            # Verify function was created
+            func_a = g.get_function_by_name('function_a')
+            self.assertIsNotNone(func_a, "function_a should be in the graph")
+
+            # Test: module_b should have IMPORTS relationship to ClassA
+            query = """
+            MATCH (f:File {name: 'module_b.py'})-[:IMPORTS]->(c:Class {name: 'ClassA'})
+            RETURN c
+            """
+            result = g._query(query, {})
+            self.assertGreater(len(result.result_set), 0,
+                             "module_b.py should import ClassA")
+
+            # Test: module_b should have IMPORTS relationship to function_a
+            query = """
+            MATCH (f:File {name: 'module_b.py'})-[:IMPORTS]->(fn:Function {name: 'function_a'})
+            RETURN fn
+            """
+            result = g._query(query, {})
+            self.assertGreater(len(result.result_set), 0,
+                             "module_b.py should import function_a")
+        finally:
+            # Cleanup: delete the test graph
+            g.delete()
+
+
+if __name__ == '__main__':
+    unittest.main()