diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py new file mode 100644 index 0000000..1c697b9 --- /dev/null +++ b/tests/test_async_connection.py @@ -0,0 +1,313 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import asyncio + + +class TestCase(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + """ + Setup connection parameters. + """ + self.user = test_env.get_test_user() + self.password = test_env.get_test_password() + self.dsn = test_env.get_localhost_connect_string() + + # --- Connection Tests --- + # async def test_async_connection_success(self): + # try: + # # Use the helper function to create the connection + # await test_env.create_async_connection(use_wallet=True) + + # # Check if connected + # is_connected = await select_ai.async_is_connected() + # self.assertTrue(is_connected, "Connection to DB failed") + + # finally: + # if is_connected: + # await select_ai.async_disconnect() + # # Ensure connection is now closed + # is_still_connected = await select_ai.async_is_connected() + # self.assertFalse(is_still_connected, "Connection should be closed after close()") + + + async def test_async_connection_without_wallet(self): + try: + # Use helper to create async connection without wallet + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + + # Check if connected + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed without wallet") + finally: + if is_connected: + await select_ai.async_disconnect() + # Ensure connection is now closed + is_still_connected = await select_ai.async_is_connected() + self.assertFalse(is_still_connected, "Connection should be closed after close()") + + + async def test_async_is_connected_bool(self): + # connection version is a string + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + is_connected = await select_ai.async_is_connected() + self.assertIsInstance(is_connected, bool) + await select_ai.async_disconnect() + + async def test_async_conn_failure_wrong_password(self): + with self.assertRaises(oracledb.DatabaseError): + await select_ai.async_connect(user=self.user, password="wrong_pass", dsn=self.dsn) + + + async def test_async_connect_bad_string(self): + # connection to database with bad connect string + with self.assertRaises(TypeError) as cm: + await select_ai.async_connect("not a valid connect string!!") + + self.assertIn( + "missing 2 required positional arguments", + str(cm.exception) + ) + + + async def test_async_connect_bad_dsn(self): + # Expecting a standard DatabaseError for bad DSN + dsn = 'invalid_dsn' + with self.assertRaises(oracledb.DatabaseError) as context: + await select_ai.async_connect(user=self.user, password=self.password, dsn=dsn) + + self.assertIn("DPY-4027", str(context.exception)) + + + async def test_asyn_connect_bad_password(self): + # connection to database with bad password + with self.assertRaises(oracledb.DatabaseError) as cm: + await select_ai.async_connect(user=self.user, password=test_env.get_test_password() + "X", dsn=self.dsn) + + # Validate that the error contains ORA-01017 + self.assertIn("ORA-01017", str(cm.exception)) + + + # --- Query Tests --- + async def test_async_query_execution(self): + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + async with select_ai.async_cursor() as cr: + await cr.execute("SELECT 1 FROM DUAL") + result = await cr.fetchone() + self.assertEqual(result[0], 1) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_query_with_parameters(self): + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + async with select_ai.async_cursor() as cr: + await cr.execute("SELECT :val FROM dual", val=42) + result = await cr.fetchone() + self.assertEqual(result[0], 42) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_fetchall(self): + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + async with select_ai.async_cursor() as cursor: + await cursor.execute("SELECT level FROM dual CONNECT BY level <= 5") + results = await cursor.fetchall() + self.assertEqual(len(results), 5) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_invalid_query(self): + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + async with select_ai.async_cursor() as cursor: + with self.assertRaises(oracledb.DatabaseError): + await cursor.execute("SELECT * FROM non_existent_table") + + # Disconnect + await select_ai.async_disconnect() + + + # --- Transaction Tests --- + async def test_async_commit_and_rollback(self): + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + async with select_ai.async_cursor() as cursor: + # Create the table only if it doesn't exist + await cursor.execute(""" + begin + execute immediate 'create table test_cr_tab (id int)'; + exception + when others then + if sqlcode != -955 then + raise; + end if; + end; + """) + await cursor.execute("commit") + + # Clean up any leftover data + await cursor.execute("truncate table test_cr_tab") + + # Test rollback + await cursor.execute("insert into test_cr_tab values (1)") + await cursor.execute("rollback") + + await cursor.execute("select count(*) from test_cr_tab") + (count,) = await cursor.fetchone() + print (count) + self.assertEqual(count, 0, "Rollback should undo the insert.") + + # Disconnect + await select_ai.async_disconnect() + + + # --- Lifecycle Tests --- + async def test_async_connection_close(self): + # Create connection + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + + # Close connection + await select_ai.async_disconnect() + + # Attempt to use cursor after disconnect should raise InterfaceError + with self.assertRaises(oracledb.InterfaceError): + async with select_ai.async_cursor() as cr: + await cr.execute("SELECT 1 FROM DUAL") + + + async def test_async_connection_reclose(self): + # Create connection + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + + # First disconnect + await select_ai.async_disconnect() + # Second disconnect should not raise + await select_ai.async_disconnect() + + # Assert connection is closed + is_connected = await select_ai.async_is_connected() + self.assertFalse(is_connected, "Connection should be closed after repeated disconnects") + + + async def test_async_dbms_output(self): + # Create connection + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + + test_string = "Testing DBMS_OUTPUT package" + + async with select_ai.async_cursor() as cursor: + await cursor.callproc("dbms_output.enable") + await cursor.callproc("dbms_output.put_line", [test_string]) + string_var = cursor.var(str) + number_var = cursor.var(int) + await cursor.callproc("dbms_output.get_line", (string_var, number_var)) + self.assertEqual(string_var.getvalue(), test_string) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_connection_instance(self): + # Create connection + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + + async with select_ai.async_cursor() as cursor: + await cursor.execute( + """ + select upper(sys_context('userenv', 'instance_name')) + from dual + """ + ) + (instance_name,) = await cursor.fetchone() + self.assertIsInstance(instance_name, str, "Expected service_name to be a string") + + # Disconnect + await select_ai.async_disconnect() + + async def test_async_max_open_cursors(self): + # test getting max_open_cursors + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "select value from V$PARAMETER where name='open_cursors'" + ) + (max_open_cursors,) = await cursor.fetchone() + + self.assertEqual(1000, int(max_open_cursors)) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_get_service_name(self): + # test getting service_name + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "select sys_context('userenv', 'service_name') from dual" + ) + (service_name,) = await cursor.fetchone() + + # Verify service_name is a string + self.assertIsInstance(service_name, str, "Expected service_name to be a string") + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_user_and_table(self): + # Create connection + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + + test_username = "TEST_USER1" + test_password = self.password + + # Drop user if exists and create new one + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + await admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + await admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + await admin_cursor.execute("commit") + + # Connect as test user + await test_env.create_async_connection(user=test_username, password=test_password, dsn=self.dsn, use_wallet=False) + + async with select_ai.async_cursor() as test_cursor: + await test_cursor.execute("CREATE TABLE test_table (id INT)") + await test_cursor.execute("INSERT INTO test_table (id) VALUES (100)") + await test_cursor.execute("commit") + + await test_cursor.execute("SELECT id FROM test_table") + result = await test_cursor.fetchone() + self.assertEqual(result[0], 100) + + # Disconnect + await select_ai.async_disconnect() + + # Clean up user + await test_env.create_async_connection(dsn=self.dsn, use_wallet=False) + async with select_ai.async_cursor() as admin_cursor: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + await admin_cursor.execute("commit") + + # Disconnect + await select_ai.async_disconnect() + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_async_create_cred.py b/tests/test_async_create_cred.py new file mode 100644 index 0000000..f33cb70 --- /dev/null +++ b/tests/test_async_create_cred.py @@ -0,0 +1,576 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import asyncio + + +class TestAsyncCreateCredential(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + """ + Setup connection parameters. + """ + self.user = test_env.get_test_user() + self.password = test_env.get_test_password() + self.dsn = test_env.get_localhost_connect_string() + + # Get Native cred secrets + self.user_ocid = test_env.get_user_ocid() + self.tenancy_ocid = test_env.get_tenancy_ocid() + self.private_key = test_env.get_private_key() + self.fingerprint = test_env.get_fingerprint() + + # Get basic cred secrets + self.cred_username = test_env.get_cred_username() + self.cred_password = test_env.get_cred_password() + + + def get_native_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + user_ocid = self.user_ocid, + tenancy_ocid = self.tenancy_ocid, + private_key = self.private_key, + fingerprint = self.fingerprint + ) + + + def get_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = self.cred_username, + password = self.cred_password + ) + + + async def drop_async_credential_cursor(self, cursor, cred_name='GENAI_CRED'): + await cursor.callproc( + "DBMS_CLOUD.DROP_CREDENTIAL", + keyword_parameters={ + "credential_name": cred_name + }, + ) + + + async def test_async_create_cred(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + try: + await select_ai.async_create_credential(credential=credential, replace=False) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_twice(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + await select_ai.async_create_credential(credential=credential) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + with self.assertRaises(oracledb.DatabaseError) as cm: + await select_ai.async_create_credential(credential=credential) + + + # Verify specific error code/message + self.assertIn( + 'ORA-20022', + str(cm.exception), + "Expected ORA-20022 error when creating credential without replace" + ) + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_same_data_multiple_times(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + for _ in range(5): + await select_ai.async_create_credential(credential=credential, replace=True) + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_rtrue(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_twice_rtrue(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Assert passed if no exception raised + self.assertTrue(True, "Credential creation and replacement passed without exception.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_twice_rtrue_rfalse(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + with self.assertRaises(oracledb.DatabaseError) as cm: + await select_ai.async_create_credential(credential=credential) + + + # Verify specific error code/message + self.assertIn( + 'ORA-20022', + str(cm.exception), + "Expected ORA-20022 error when creating credential without replace" + ) + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_twice_rfalse_rtrue(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + await select_ai.async_create_credential(credential=credential) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Assert passed if no exception raised + self.assertTrue(True, "Credential creation and replacement passed without exception.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_native_cred(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + try: + await select_ai.async_create_credential(credential=credential, replace=False) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_native_cred_twice(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + # First creation + try: + await select_ai.async_create_credential(credential=credential) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + with self.assertRaises(oracledb.DatabaseError) as cm: + await select_ai.async_create_credential(credential=credential) + + # Verify specific error code/message + self.assertIn( + 'ORA-20022', + str(cm.exception), + "Expected ORA-20022 error when creating credential without replace" + ) + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_native_cred_rtrue(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + # First creation + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Assert passed if no exception raised + self.assertTrue(True, "Credential creation and replacement passed without exception.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_native_cred_twice_rtrue(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + # Create credential + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_empty_name(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param() + + # Verify create credential + with self.assertRaises(Exception) as cm: + await select_ai.async_create_credential(credential=credential) + self.assertIn("ORA-20010: Missing credential name", str(cm.exception)) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_credential_empty_dict(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = dict() + + # Verify create credential + with self.assertRaises(oracledb.DatabaseError) as cm: + await select_ai.async_create_credential(credential=credential) + + self.assertIn( + "PLS-00306: wrong number or types of arguments in call to 'CREATE_CREDENTIAL'", + str(cm.exception) + ) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_invalid_username(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = dict( + credential_name = 'GENAI_CRED', + username = 'invalid_username', + password = self.cred_password + ) + + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_invalid_password(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = dict( + credential_name = 'GENAI_CRED', + username = self.cred_username, + password = 'invalid_pwd' + ) + + try: + await select_ai.async_create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_db_unavailable(self): + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + with self.assertRaisesRegex( + select_ai.errors.DatabaseNotConnectedError, + r"Not connected to the Database" + ): + await select_ai.async_create_credential(credential=credential, replace=False) + + + async def test_async_create_cred_local_user(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + test_username = "TEST_USER1" + test_password = self.password + + # Drop user if exists and create new one + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + await admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + await admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + await admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + + # Connect as test user + await test_env.create_async_connection( + user=test_username, + password=test_password, + dsn=self.dsn, + use_wallet=False + ) + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED_USER1') + + try: + await select_ai.async_create_credential(credential=credential, replace=False) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + async with select_ai.async_cursor() as cursor: + await self.drop_async_credential_cursor(cursor, 'GENAI_CRED_USER1') + + # Disconnect + await select_ai.async_disconnect() + + # Clean up user + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + async with select_ai.async_cursor() as admin_cursor: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + + + # Negative Tests + async def test_async_create_cred_special_characters(self): + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED!@#') + + with self.assertRaisesRegex(oracledb.DatabaseError, r"ORA-20010: Invalid credential name"): + await select_ai.async_create_credential(credential=credential, replace=False) + + # Disconnect + await select_ai.async_disconnect() + + + async def test_async_create_cred_long_name(self): + long_name = "GENAI_CRED" + "_" + "a" * (128 - len('GENAI_CRED')) + + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param(long_name) + + with self.assertRaisesRegex( + oracledb.DatabaseError, + r"ORA-20008: Credential name length \(129\) exceeds maximum length \(128\)" + ): + await select_ai.async_create_credential(credential=credential, replace=False) + + # Disconnect + await select_ai.async_disconnect() + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_async_create_vector_index.py b/tests/test_async_create_vector_index.py new file mode 100644 index 0000000..83b964e --- /dev/null +++ b/tests/test_async_create_vector_index.py @@ -0,0 +1,409 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import asyncio + + +class TestAsyncCreateVectorIndex(unittest.IsolatedAsyncioTestCase): + def get_native_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + user_ocid = self.__class__.user_ocid, + tenancy_ocid = self.__class__.tenancy_ocid, + private_key = self.__class__.private_key, + fingerprint = self.__class__.fingerprint + ) + + + def get_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = self.__class__.cred_username, + password = self.__class__.cred_password + ) + + + async def create_async_credential(self, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + # Get credential secret + genai_credential = self.get_native_cred_param(genai_cred) + objstore_credential = self.get_cred_param(objstore_cred) + + # Create GenAI Credential + try: + await select_ai.async_create_credential(credential=genai_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + # Create ObjStore Credential + try: + await select_ai.async_create_credential(credential=objstore_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + + async def create_async_profile(self, profile_name="vector_ai_profile"): + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=self.__class__.oci_compartment_id, + oci_apiformat="GENERIC" + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + self.async_profile = await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + + + async def delete_async_profile(self): + try: + await self.async_profile.delete() + except Exception as e: + raise AssertionError(f"profile.delete() raised {e} unexpectedly.") + + + async def delete_async_credential(self, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + # Create GenAI Credential + try: + await select_ai.async_delete_credential(genai_cred, force=True) + except Exception as e: + raise AssertionError(f"delete_credential() raised {e} unexpectedly.") + + # Create ObjStore Credential + try: + await select_ai.async_delete_credential(objstore_cred, force=True) + except Exception as e: + raise AssertionError(f"delete_credential() raised {e} unexpectedly.") + + + @classmethod + def setUpClass(cls): + """ + Get Env Variabels + """ + # Assign password from test_env + cls.user = test_env.get_test_user() + cls.password = test_env.get_test_password() + cls.dsn = test_env.get_localhost_connect_string() + + # Get Native cred secrets + cls.user_ocid = test_env.get_user_ocid() + cls.tenancy_ocid = test_env.get_tenancy_ocid() + cls.private_key = test_env.get_private_key() + cls.fingerprint = test_env.get_fingerprint() + + # Get basic cred secrets + cls.cred_username = test_env.get_cred_username() + cls.cred_password = test_env.get_cred_password() + + # Get OCI Provider + cls.oci_compartment_id = test_env.get_compartment_id() + cls.embedding_location = test_env.get_embedding_location() + + + async def asyncSetUp(self): + self.embedding_location = self.__class__.embedding_location + self.dsn = self.__class__.dsn + self.objstore_cred = "OBJSTORE_CRED" + + """ + Create Credential, Profile for all tests. + """ + # Create async connection + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + assert is_connected, "Connection to DB failed" + + # Create Credential + await self.create_async_credential() + # Create Profile + await self.create_async_profile() + + # Specify objects to create an embedding for. + # The objects reside in ObjectStore and the vector database is Oracle + self.vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name=self.objstore_cred + ) + + # Create vector index object + self.async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.async_profile + ) + + + async def asyncTearDown(self): + # Delete Vector Index + async_vector_index = select_ai.AsyncVectorIndex(index_name="test_vector_index") + await async_vector_index.delete(force=True) + + # Delete Profile + await self.delete_async_profile() + + # Delete Credentials + await self.delete_async_credential() + + # Disconnect from DB + try: + await select_ai.async_disconnect() + except Exception as e: + print(f"Warning: disconnect failed ({e})") + + + async def test_async_create_vector_index_success(self): + try: + await self.async_vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Verify list + async_vector_index = select_ai.AsyncVectorIndex() + indexes = [i.index_name async for i in async_vector_index.list()] + self.assertIn("TEST_VECTOR_INDEX", indexes) + + + async def test_async_create_vector_index_success_replace_false(self): + try: + await self.async_vector_index.create(replace=False) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Verify list + async_vector_index = select_ai.AsyncVectorIndex() + indexes = [i.index_name async for i in async_vector_index.list()] + self.assertIn("TEST_VECTOR_INDEX", indexes) + + + async def test_async_create_vector_index_empty_description(self): + # Create vector index object + async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="", + profile=self.async_profile + ) + try: + await async_vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Verify list + async_vector_index = select_ai.AsyncVectorIndex() + indexes = [i.index_name async for i in async_vector_index.list()] + self.assertIn("TEST_VECTOR_INDEX", indexes) + + + async def test_async_create_vector_index_replace_true(self): + # First creation + try: + await self.async_vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Second creation + try: + await self.async_vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + + async def test_async_create_vector_index_replace_false(self): + # First creation should succeed + try: + await self.async_vector_index.create(replace=False) + except Exception as e: + self.fail(f"Create vector index failed unexpectedly with exception: {e}") + + # Second creation should raise ORA-20048 + with self.assertRaises(oracledb.DatabaseError) as cm: + await self.async_vector_index.create(replace=False) + + # Verify the error code/message + self.assertIn("ORA-20048", str(cm.exception)) + self.assertIn("already exists", str(cm.exception)) + + + async def test_async_create_vector_index_minimal_attributes(self): + # Create vector index object + async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + profile=self.async_profile + ) + + try: + await async_vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + + async def test_async_create_vector_index_recreate_after_delete(self): + try: + await self.async_vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Delete Vector Index + async_vector_index = select_ai.AsyncVectorIndex(index_name="test_vector_index") + await async_vector_index.delete(force=True) + + try: + self.async_vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + + # Negative Case + async def test_async_create_vector_index_invalid_credential(self): + vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="invalidObjStore_cred" + ) + + # Create vector index object + async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.async_profile + ) + + with self.assertRaises(oracledb.DatabaseError): + await async_vector_index.create(replace=True) + + + async def test_async_create_vector_index_invalid_location(self): + vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location="invalid_location", + object_storage_credential_name=self.objstore_cred + ) + + # Create vector index object + async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.async_profile + ) + + with self.assertRaises(oracledb.DatabaseError): + await async_vector_index.create(replace=True) + + + async def test_async_create_vector_index_missing_attributes(self): + with self.assertRaises(AttributeError): + await select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=None, + profile=self.async_profile + ).create() + + + async def test_async_create_vector_index_invalid_attributes_type(self): + with self.assertRaises(TypeError): + await select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes="invalid_attributes", # invalid type + profile=self.async_profile + ).create() + + + async def test_async_create_vector_index_invalid_name_type(self): + with self.assertRaises(oracledb.DatabaseError) as cm: + await select_ai.AsyncVectorIndex( + index_name=12345, # invalid type (int instead of str) + attributes=self.vector_index_attributes, + profile=self.async_profile + ).create() + + # Verify error + self.assertIn("ORA-20048", str(cm.exception)) + self.assertIn("Invalid vector index name", str(cm.exception)) + + + async def test_async_create_vector_index_empty_name(self): + with self.assertRaises(oracledb.DatabaseError) as cm: + await select_ai.AsyncVectorIndex( + index_name="", + attributes=self.vector_index_attributes, + profile=self.async_profile + ).create() + + # Verify the error code/message + self.assertIn("ORA-20048", str(cm.exception)) + self.assertIn("Missing vector index name", str(cm.exception)) + + + async def test_async_create_vector_index_invalid_profile(self): + # Create vector index object + async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="Test vector index", + profile="invalid_profile" + ) + with self.assertRaises(ValueError): + await async_vector_index.create() + + + async def test_async_create_vector_index_none_attributes(self): + async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=None, + description="Test vector index", + profile="invalid_profile" + ) + with self.assertRaises(TypeError): + await async_vector_index.create() + + + # Boundary Cases + async def test_async_create_vector_index_long_name(self): + long_name = "X" * 150 # > Oracle identifier length + async_vector_index = select_ai.AsyncVectorIndex( + index_name=long_name, + attributes=self.vector_index_attributes, + profile=self.async_profile + ) + with self.assertRaises(oracledb.DatabaseError): + await async_vector_index.create() + + + async def test_async_create_vector_index_long_description(self): + long_desc = "D" * 5000 # deliberately too long + + # Create vector index object + async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description=long_desc, + profile=self.async_profile + ) + + # Expect DatabaseError due to description length + with self.assertRaises(oracledb.DatabaseError) as cm: + await async_vector_index.create(replace=True) + + # Verify Oracle error details + self.assertIn("ORA-20045", str(cm.exception)) + self.assertIn("description is too long", str(cm.exception)) + + + async def test_async_create_vector_index_multiple_recreates(self): + for _ in range(10): + await self.async_vector_index.create(replace=True) + + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_async_drop_cred.py b/tests/test_async_drop_cred.py new file mode 100644 index 0000000..52ab494 --- /dev/null +++ b/tests/test_async_drop_cred.py @@ -0,0 +1,203 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import asyncio + + +class TestAsyncDropCredential(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + """ + Async setup for test connection parameters and connection creation. + """ + self.user = test_env.get_test_user() + self.password = test_env.get_test_password() + self.dsn = test_env.get_localhost_connect_string() + + # Get basic cred secrets + self.cred_username = test_env.get_cred_username() + self.cred_password = test_env.get_cred_password() + + # Create async connection + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + + is_connected = await select_ai.async_is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + + + async def asyncTearDown(self): + # Disconnect after each test + await select_ai.async_disconnect() + + + def get_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = self.cred_username, + password = self.cred_password + ) + + async def create_test_credential(self, cred_name="GENAI_CRED"): + """ + Async helper to create a test credential. + """ + # Get credential secret + credential = self.get_cred_param(cred_name) + + try: + await select_ai.async_create_credential( + credential=credential, + replace=False + ) + except Exception as e: + self.fail(f"async_create_credential() raised {e} unexpectedly.") + + + async def create_local_user(self, test_username="TEST_USER1"): + """ + Async helper to drop and create a local test user with required grants. + """ + test_password = self.password + + # Drop user if exists and create new one + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + await admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + await admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + await admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + + + async def test_async_delete_cred_success(self): + # Create credential + await self.create_test_credential() + + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + + async def test_async_delete_cred_twice_force_true(self): + # Create credential + await self.create_test_credential() + + # First delete should succeed + await select_ai.async_delete_credential("GENAI_CRED", force=True) + + # Second delete should also succeed (no exception, since force=True) + await select_ai.async_delete_credential("GENAI_CRED", force=True) + + + async def test_async_delete_cred_twice_force_false(self): + # Create credential + await self.create_test_credential() + + # First delete should succeed + await select_ai.async_delete_credential("GENAI_CRED", force=False) + + # Second delete should raise DatabaseError since credential is already deleted + with self.assertRaises(oracledb.DatabaseError): + await select_ai.async_delete_credential("GENAI_CRED", force=False) + + + async def test_async_delete_nonexistent_cred_default(self): + with self.assertRaises(oracledb.DatabaseError): + await select_ai.async_delete_credential("nonexistent_cred") + + + async def test_async_delete_nonexistent_cred_without_force(self): + with self.assertRaises(oracledb.DatabaseError): + await select_ai.async_delete_credential("nonexistent_cred", force=False) + + + async def test_async_delete_nonexistent_cred_with_force(self): + # Should not raise error when force=True + try: + await select_ai.async_delete_credential("nonexistent_cred", force=True) + except Exception as e: + self.fail(f"delete_credential(force=True) raised {e} unexpectedly.") + + + async def test_async_delete_cred_local_user(self): + test_username = "TEST_USER1" + + await self.create_local_user(test_username) + + # Connect as test user + await test_env.create_async_connection( + user=test_username, + password=self.password, + dsn=self.dsn, + use_wallet=False + ) + + # Get credential secret + credential = self.get_cred_param("GENAI_CRED_USER1") + + try: + await select_ai.async_delete_credential("GENAI_CRED_USER1", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + # Disconnect + await select_ai.async_disconnect() + + # Clean up user + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + async with select_ai.async_cursor() as admin_cursor: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + + + async def test_async_invalid_cred_name(self): + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20010: Invalid credential name"): + await select_ai.async_delete_credential("invalid!@#", force=True) + + + async def test_async_delete_cred_not_connected(self): + await select_ai.async_disconnect() + with self.assertRaises(select_ai.errors.DatabaseNotConnectedError): + await select_ai.async_delete_credential("GENAI_CRED", force=True) + + + async def test_async_credential_name_too_long(self): + long_name = "GENAI_CRED_" + "a" * 120 + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20008: Credential name length .* exceeds maximum length"): + await select_ai.async_delete_credential(long_name, force=True) + + + async def test_async_delete_cred_case_sensitive(self): + # Create credential + await self.create_test_credential("GENAI_CRED") + + # Try deleting with lower case + try: + await select_ai.async_delete_credential(credential_name="genai_cred") + except Exception as e: + self.fail(f"async_delete_credential raised {e} unexpectedly for lowercase name") + + + async def test_async_delete_cred_empty_name(self): + # Empty string → ORA-20010: Missing credential name + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20010: Missing credential name"): + await select_ai.async_delete_credential(credential_name="", force=True) + + # None → should also end up with ORA-20010 + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20010: Missing credential name"): + await select_ai.async_delete_credential(credential_name=None, force=True) + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_async_enable_provider.py b/tests/test_async_enable_provider.py new file mode 100644 index 0000000..323aa8a --- /dev/null +++ b/tests/test_async_enable_provider.py @@ -0,0 +1,102 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import asyncio + + +class TestAsyncEnableProvider(unittest.IsolatedAsyncioTestCase): + @classmethod + async def create_async_local_user(cls, test_username="TEST_USER1"): + """ + Helper to drop and create a local test user with required grants. + """ + test_password = cls.password + + # Drop user if exists and create new one + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + await admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + await admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + await admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + + + @classmethod + async def _asyncSetUpClass(cls): + """ + Create DB users once before all tests. + """ + # Assign password from test_env so create_local_user can use it + cls.user = test_env.get_test_user() + cls.password = test_env.get_test_password() + cls.dsn = test_env.get_localhost_connect_string() + + # Create async connection + await test_env.create_async_connection( + dsn=cls.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + assert is_connected, "Connection to DB failed" + + + cls.db_users = list() + # Create multiple DB users (DB_USER1 ... DB_USER5) + for i in range(1, 6): + user = f"DB_USER{i}" + await cls.create_async_local_user(user) + cls.db_users.append(user) + + @classmethod + def setUpClass(cls): + """ + Sync wrapper that runs async setup once for all tests. + """ + asyncio.run(cls._asyncSetUpClass()) + + + @classmethod + async def _asyncTearDownClass(cls): + """ + Drop DB users after all tests finish. + """ + async with select_ai.async_cursor() as admin_cursor: + for user in cls.db_users: + try: + await admin_cursor.execute(f"DROP USER {user} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if already dropped + + # Disconnect from DB + try: + await select_ai.async_disconnect() + except Exception as e: + print(f"Warning: disconnect failed ({e})") + + + @classmethod + def tearDownClass(cls): + asyncio.run(cls._asyncTearDownClass()) + + + async def asyncSetUp(self): + self.provider_endpoint = "*.openai.azure.com" + + + async def test_async_enable_provider_success(self): + # Enabling provider with valid users and endpoint should succeed. + try: + await select_ai.async_enable_provider( + users=self.__class__.db_users, + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"enable_provider() raised {e} unexpectedly.") + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_async_list_vector_index.py b/tests/test_async_list_vector_index.py new file mode 100644 index 0000000..dfba62b --- /dev/null +++ b/tests/test_async_list_vector_index.py @@ -0,0 +1,379 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import asyncio +import re + + +class TestAsyncListVectorIndex(unittest.IsolatedAsyncioTestCase): + def get_native_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + user_ocid = self.__class__.user_ocid, + tenancy_ocid = self.__class__.tenancy_ocid, + private_key = self.__class__.private_key, + fingerprint = self.__class__.fingerprint + ) + + + def get_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = self.__class__.cred_username, + password = self.__class__.cred_password + ) + + + async def create_async_credential(self, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + # Get credential secret + genai_credential = self.get_native_cred_param(genai_cred) + objstore_credential = self.get_cred_param(objstore_cred) + + # Create GenAI Credential + try: + await select_ai.async_create_credential(credential=genai_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + # Create ObjStore Credential + try: + await select_ai.async_create_credential(credential=objstore_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + + async def create_async_profile(self, profile_name="vector_ai_profile"): + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=self.__class__.oci_compartment_id, + oci_apiformat="GENERIC" + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + self.async_profile = await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + + + async def delete_async_profile(self): + try: + await self.async_profile.delete() + except Exception as e: + raise AssertionError(f"profile.delete() raised {e} unexpectedly.") + + + async def delete_async_credential(self, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + # Create GenAI Credential + try: + await select_ai.async_delete_credential(genai_cred, force=True) + except Exception as e: + raise AssertionError(f"delete_credential() raised {e} unexpectedly.") + + # Create ObjStore Credential + try: + await select_ai.async_delete_credential(objstore_cred, force=True) + except Exception as e: + raise AssertionError(f"delete_credential() raised {e} unexpectedly.") + + + async def create_async_vector_index(self, index_name): + # Specify objects to create an embedding for. + # The objects reside in ObjectStore and the vector database is Oracle + vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name=self.objstore_cred + ) + + # Create vector index object + async_vector_index = select_ai.AsyncVectorIndex( + index_name=index_name, + attributes=vector_index_attributes, + description="Test vector index", + profile=self.async_profile + ) + + # Create vector index + await async_vector_index.create(replace=True) + + + @classmethod + def setUpClass(cls): + """ + Get Env Variabels + """ + # Assign password from test_env + cls.user = test_env.get_test_user() + cls.password = test_env.get_test_password() + cls.dsn = test_env.get_localhost_connect_string() + + # Get Native cred secrets + cls.user_ocid = test_env.get_user_ocid() + cls.tenancy_ocid = test_env.get_tenancy_ocid() + cls.private_key = test_env.get_private_key() + cls.fingerprint = test_env.get_fingerprint() + + # Get basic cred secrets + cls.cred_username = test_env.get_cred_username() + cls.cred_password = test_env.get_cred_password() + + # Get OCI Provider + cls.oci_compartment_id = test_env.get_compartment_id() + cls.embedding_location = test_env.get_embedding_location() + + + async def asyncSetUp(self): + self.embedding_location = self.__class__.embedding_location + self.dsn = self.__class__.dsn + self.objstore_cred = "OBJSTORE_CRED" + + """ + Create Credential, Profile for all tests. + """ + # Create async connection + await test_env.create_async_connection( + dsn=self.dsn, use_wallet=False + ) + is_connected = await select_ai.async_is_connected() + assert is_connected, "Connection to DB failed" + + # Create Credential + await self.create_async_credential() + # Create Profile + await self.create_async_profile() + + # Create some vector indexes + self.objstore_cred = "OBJSTORE_CRED" + + self.indexes = [f"test_vector_index{i}" for i in range(1, 6)] + \ + [f"test_vecidx{i}" for i in range(1, 3)] + + for idx in self.indexes: + try: + await self.create_async_vector_index(index_name=idx) + except Exception: + pass + + self.async_vector_index = select_ai.AsyncVectorIndex() + + + async def asyncTearDown(self): + # Clean up test indexes and close connection. + for idx in self.indexes: + try: + # Delete Vector Index + async_vector_index = select_ai.AsyncVectorIndex(index_name_pattern=idx) + await async_vector_index.delete(force=True) + except Exception: + pass + + # Delete Profile + await self.delete_async_profile() + + # Delete Credentials + await self.delete_async_credential() + + # Disconnect from DB + try: + await select_ai.async_disconnect() + except Exception as e: + print(f"Warning: disconnect failed ({e})") + + + # ---------- Positive Test Cases ---------- + async def test_async_list_matching_names(self): + expected_index_names = [f"test_vector_index{i}".upper() for i in range(1, 6)] + \ + [f"test_vecidx{i}".upper() for i in range(1, 3)] + + actual_indexes = [idx.index_name async for idx in self.async_vector_index.list(index_name_pattern=".*")] + + # Verify count of indexes + self.assertEqual( + len(actual_indexes), + len(expected_index_names), + f"Expected {len(expected_index_names)} indexes, got {len(actual_indexes)}" + ) + + # Verify same set of names, ignoring order + self.assertEqual( + sorted(actual_indexes), + sorted(expected_index_names), + f"Expected names {sorted(expected_index_names)}, got {sorted(actual_indexes)}" + ) + + + async def test_async_list_matching_profile_name(self): + expected_profile = "vector_ai_profile" + async for index in self.async_vector_index.list(index_name_pattern=".*"): + # Verify profile name + self.assertEqual( + index.profile.profile_name, + expected_profile, + f"Profile mismatch for {index.index_name}: expected {expected_profile}, got {index.profile.profile_name}" + ) + + + async def test_async_list_matching_credential_name(self): + expected_credential = "OBJSTORE_CRED" + async for index in self.async_vector_index.list(index_name_pattern=".*"): + # Verify object store credential + self.assertEqual( + index.attributes.object_storage_credential_name, + expected_credential, + f"Credential mismatch for {index.index_name}: expected {expected_credential}, got {index.attributes.object_storage_credential_name}" + ) + + + async def test_async_list_matching_description(self): + expected_description = "Test vector index" + async for index in self.async_vector_index.list(index_name_pattern=".*"): + # Verify description + self.assertEqual( + index.description, + expected_description, + f"Description mismatch for {index.index_name}: expected {expected_description}, got {index.description}" + ) + + + async def test_async_list_exact_match(self): + indexes = [idx async for idx in self.async_vector_index.list(index_name_pattern="^test_vector_index1$")] + self.assertEqual(indexes[0].index_name, "TEST_VECTOR_INDEX1") + + + async def test_async_list_multiple_matches(self): + actual_indexes = [] + async for index in self.async_vector_index.list(index_name_pattern="^test_vector_index"): + actual_indexes.append(index) + + # Verify count + expected_count = 5 + self.assertEqual( + len(list(actual_indexes)), + expected_count, + f"Expected {expected_count} indexes, got {len(list(actual_indexes))}" + ) + + # Verify each index name + for i, index in enumerate(actual_indexes, start=1): + expected_index_name = f"TEST_VECTOR_INDEX{i}" + self.assertEqual( + index.index_name, + expected_index_name, + f"Index name mismatch: expected {expected_index_name}, got {index.index_name}" + ) + + + async def test_async_list_case_sensitive_pattern(self): + indexes = [idx async for idx in self.async_vector_index.list(index_name_pattern="^TEST_VECTOR_INDEX?")] + self.assertTrue(any(idx.index_name == "TEST_VECTOR_INDEX2" for idx in indexes)) + + + async def test_async_list_case_insensitive_pattern(self): + indexes = [] + async for index in self.async_vector_index.list(index_name_pattern="(?i)^TEST"): + indexes.append(index) + + # for index in indexes: + # print(index.index_name) + self.assertTrue(any(idx.index_name == "TEST_VECTOR_INDEX1" for idx in indexes)) + + + async def test_async_list_complex_regex_or_operator(self): + indexes = [] + async for index in self.async_vector_index.list(index_name_pattern="^(test_vector_index|test_vecidx)"): + indexes.append(index) + + names = [idx.index_name for idx in indexes] + self.assertIn("TEST_VECTOR_INDEX1", names) + self.assertIn("TEST_VECIDX1", names) + + # Invalid Index + self.assertNotIn("INVALID_VECIDX1", names) + + + # ----- Negative Cases ----- + async def test_async_list_non_matching_pattern(self): + indexes = [] + async for index in self.async_vector_index.list(index_name_pattern="^xyz"): + indexes.append(index) + + self.assertEqual(len(list(indexes)), 0) + + + async def test_async_list_invalid_regex_pattern(self): + with self.assertRaises(oracledb.DatabaseError) as cm: + _ = [idx async for idx in self.async_vector_index.list("[unclosed")] + + # Optional: verify the error code/message + self.assertIn( + "ORA-12726", + str(cm.exception), + f"Expected ORA-12726 error, got {cm.exception}" + ) + + + async def test_async_list_invalid_type_pattern(self): + with self.assertRaises(TypeError): + _ = [idx async for idx in self.async_vector_index.list(123)] + + + async def test_async_list_invalid_type_pattern(self): + # Invalid type -> expect empty list + indexes = [idx async for idx in self.async_vector_index.list(123)] + self.assertEqual( + len(indexes), 0, + f"Expected 0 indexes for invalid type pattern, got {len(indexes)}" + ) + + # ----- Edge Cases ----- + async def test_async_list_none_pattern_match(self): + # None should usually mean "match all" + indexes = [idx async for idx in self.async_vector_index.list(None)] + self.assertNotEqual(len(indexes), len(self.indexes)) + + + async def test_async_list_empty_string_pattern_matches(self): + # Empty string should typically return all (but verify) + indexes = [idx async for idx in self.async_vector_index.list("")] + self.assertNotEqual(len(indexes), len(self.indexes)) + + async def test_async_list_whitespace_pattern(self): + indexes = [idx async for idx in self.async_vector_index.list(" ")] + self.assertEqual(len(indexes), 0) + + + async def test_async_list_numeric_pattern(self): + indexes = [idx async for idx in self.async_vector_index.list("test123")] + self.assertEqual( + len(indexes), 0, + f"Expected no indexes to match 'test123', but got {len(indexes)}" + ) + + + async def test_async_list_special_characters_in_pattern(self): + indexes = [idx async for idx in self.async_vector_index.list("test_vector_index1$")] + self.assertEqual(len(indexes), 1) + + + async def test_async_list_long_pattern_no_match(self): + pattern = "^" + "a" * 1000 + "$" + with self.assertRaises(oracledb.DatabaseError) as cm: + _ = [idx async for idx in self.async_vector_index.list(pattern)] + self.assertIn( + "ORA-12733", str(cm.exception), + f"Expected ORA-12733 error, got {cm.exception}" + ) + + + async def test_async_list_case_insensitive_match(self): + indexes = [idx async for idx in self.async_vector_index.list("^TEST")] + self.assertEqual(len(indexes), 7) + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..079cb93 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,307 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import time + + +class TestCase(unittest.TestCase): + + def setUp(self): + """ + Setup connection parameters. + """ + self.user = test_env.get_test_user() + self.password = test_env.get_test_password() + self.dsn = test_env.get_localhost_connect_string() + + test_env.create_connection() + with select_ai.cursor() as cursor: + cursor.execute(""" + begin + execute immediate 'create table testtemptable (intcol int)'; + exception + when others then + if sqlcode != -955 then + raise; + end if; + end; + """) + cursor.execute("commit") + + + # --- Connection Tests --- + def test_connection_success(self): + try: + # Use the helper function to create the connection + test_env.create_connection(use_wallet=True) + is_connected = select_ai.is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + finally: + if is_connected: + select_ai.disconnect() + self.assertFalse(select_ai.is_connected(), "Connection to DB failed") + + + def test_connection_without_wallet(self): + try: + # Use the helper function to create the connection + test_env.create_connection(dsn=self.dsn, use_wallet=False) + is_connected = select_ai.is_connected() + self.assertTrue(is_connected, "Connection to DB failed") + finally: + if is_connected: + select_ai.disconnect() + self.assertFalse(select_ai.is_connected(), "Connection to DB failed") + + + def test_is_connected_bool(self): + # connection version is a string + test_env.create_connection() + self.assertIsInstance(select_ai.is_connected(), bool) + select_ai.disconnect() + + + def test_conn_failure_wrong_password(self): + with self.assertRaises(oracledb.DatabaseError): + select_ai.connect(user=self.user, password="wrong_pass", dsn=self.dsn) + + + def test_connect_bad_string(self): + # connection to database with bad connect string + with self.assertRaises(TypeError) as cm: + select_ai.connect("not a valid connect string!!") + + self.assertIn( + "missing 2 required positional arguments", + str(cm.exception) + ) + + + def test_connect_bad_dsn(self): + # Expecting a standard DatabaseError for bad DSN + dsn = 'invalid_dsn' + with self.assertRaises(oracledb.DatabaseError) as context: + select_ai.connect(user=self.user, password=self.password, dsn=dsn) + + self.assertIn("DPY-4027", str(context.exception)) + + + def test_connect_bad_password(self): + # connection to database with bad password + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.connect(user=self.user, password=test_env.get_test_password() + "X", dsn=self.dsn) + + # Validate that the error contains ORA-01017 + self.assertIn("ORA-01017", str(cm.exception)) + + + # --- Query Tests --- + def test_query_execution(self): + test_env.create_connection() + with select_ai.cursor() as cr: + cr.execute("SELECT 1 FROM DUAL") + result = cr.fetchone() + self.assertEqual(result[0], 1) + + # Disconnect + select_ai.disconnect() + + def test_query_with_parameters(self): + test_env.create_connection() + with select_ai.cursor() as cr: + cr.execute("SELECT :val FROM dual", val=42) + result = cr.fetchone() + self.assertEqual(result[0], 42) + + # Disconnect + select_ai.disconnect() + + def test_fetchall(self): + test_env.create_connection() + with select_ai.cursor() as cursor: + cursor.execute("SELECT level FROM dual CONNECT BY level <= 5") + results = cursor.fetchall() + self.assertEqual(len(results), 5) + + # Disconnect + select_ai.disconnect() + + def test_invalid_query(self): + test_env.create_connection() + with select_ai.cursor() as cursor: + with self.assertRaises(oracledb.DatabaseError): + cursor.execute("SELECT * FROM non_existent_table") + + # Disconnect + select_ai.disconnect() + + + # --- Transaction Tests --- + def test_commit_and_rollback(self): + test_env.create_connection() + with select_ai.cursor() as cursor: + # Create the table only if it doesn't exist + cursor.execute(""" + begin + execute immediate 'create table test_cr_tab (id int)'; + exception + when others then + if sqlcode != -955 then + raise; + end if; + end; + """) + cursor.execute("commit") + + # Clean up any leftover data + cursor.execute("truncate table test_cr_tab") + + # Test rollback + cursor.execute("insert into test_cr_tab values (1)") + cursor.execute("rollback") + + cursor.execute("select count(*) from test_cr_tab") + (count,) = cursor.fetchone() + self.assertEqual(count, 0, "Rollback should undo the insert.") + + + # --- Lifecycle Tests --- + def test_connection_close(self): + # Create connection + test_env.create_connection() + + # Close connection + select_ai.disconnect() + + # Attempt to use cursor after disconnect should raise InterfaceError + with self.assertRaises(oracledb.InterfaceError): + with select_ai.cursor() as cr: + cr.execute("SELECT 1 FROM DUAL") + + + def test_connection_reclose(self): + # Create connection + test_env.create_connection() + + # First disconnect + select_ai.disconnect() + # Second disconnect should not raise + select_ai.disconnect() + + # Assert connection is closed + is_connected = select_ai.is_connected() + self.assertFalse(is_connected, "Connection should be closed after repeated disconnects") + + + def test_dbms_output(self): + # test dbms_output package + test_env.create_connection() + test_string = "Testing DBMS_OUTPUT package" + + with select_ai.cursor() as cursor: + cursor.callproc("dbms_output.enable") + cursor.callproc("dbms_output.put_line", [test_string]) + string_var = cursor.var(str) + number_var = cursor.var(int) + cursor.callproc("dbms_output.get_line", (string_var, number_var)) + self.assertEqual(string_var.getvalue(), test_string) + + # Disconnect + select_ai.disconnect() + + + def test_connection_instance(self): + # test connection instance name + test_env.create_connection() + + with select_ai.cursor() as cursor: + cursor.execute( + """ + select upper(sys_context('userenv', 'instance_name')) + from dual + """ + ) + (instance_name,) = cursor.fetchone() + self.assertIsInstance(instance_name, str, "Expected service_name to be a string") + + # Disconnect + select_ai.disconnect() + + + def test_max_open_cursors(self): + # test getting max_open_cursors + test_env.create_connection() + + with select_ai.cursor() as cursor: + cursor.execute( + "select value from V$PARAMETER where name='open_cursors'" + ) + (max_open_cursors,) = cursor.fetchone() + + self.assertEqual(1000, int(max_open_cursors)) + + # Disconnect + select_ai.disconnect() + + + def test_get_service_name(self): + # test getting service_name + test_env.create_connection() + + with select_ai.cursor() as cursor: + cursor.execute( + "select sys_context('userenv', 'service_name') from dual" + ) + (service_name,) = cursor.fetchone() + + # Verify service_name is a string + self.assertIsInstance(service_name, str, "Expected service_name to be a string") + + # Disconnect + select_ai.disconnect() + + + def test_create_user_and_table(self): + test_env.create_connection() + + test_username = "TEST_USER1" + test_password = self.password + + # Drop user if exists and create new one + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + + + # Connect as test user + test_env.create_connection(user=test_username, password=test_password) + + with select_ai.cursor() as test_cursor: + test_cursor.execute("CREATE TABLE test_table (id INT)") + test_cursor.execute("INSERT INTO test_table (id) VALUES (100)") + # test_user_conn.commit() + + test_cursor.execute("SELECT id FROM test_table") + result = test_cursor.fetchone() + self.assertEqual(result[0], 100) + + # Disconnect + select_ai.disconnect() + + # Clean up user + test_env.create_connection() + with select_ai.cursor() as admin_cursor: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + # admin_conn.commit() + + +if __name__ == "__main__": + # unittest.main() + test_env.run_test_cases() diff --git a/tests/test_create_cred.py b/tests/test_create_cred.py new file mode 100644 index 0000000..befc236 --- /dev/null +++ b/tests/test_create_cred.py @@ -0,0 +1,466 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import time + + +class TestCreateCredential(unittest.TestCase): + def setUp(self): + """ + Setup connection parameters. + """ + self.user = test_env.get_test_user() + self.password = test_env.get_test_password() + self.dsn = test_env.get_localhost_connect_string() + + # Get Native cred secrets + self.user_ocid = test_env.get_user_ocid() + self.tenancy_ocid = test_env.get_tenancy_ocid() + self.private_key = test_env.get_private_key() + self.fingerprint = test_env.get_fingerprint() + + # Get basic cred secrets + self.cred_username = test_env.get_cred_username() + self.cred_password = test_env.get_cred_password() + + + def tearDown(self): + # Disconnect after each test + select_ai.disconnect() + + + def get_native_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + user_ocid = self.user_ocid, + tenancy_ocid = self.tenancy_ocid, + private_key = self.private_key, + fingerprint = self.fingerprint + ) + + + def get_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = self.cred_username, + password = self.cred_password + ) + + + def drop_credential_cursor(self, cursor, cred_name='GENAI_CRED'): + cursor.callproc( + "DBMS_CLOUD.DROP_CREDENTIAL", + keyword_parameters={ + "credential_name": cred_name + }, + ) + + + def test_create_cred(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + try: + select_ai.create_credential(credential=credential, replace=False) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_twice(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + select_ai.create_credential(credential=credential) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + + + # Verify specific error code/message + self.assertIn( + 'ORA-20022', + str(cm.exception), + "Expected ORA-20022 error when creating credential without replace" + ) + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_same_data_multiple_times(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + for _ in range(5): + select_ai.create_credential(credential=credential, replace=True) + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_rtrue(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_twice_rtrue(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Assert passed if no exception raised + self.assertTrue(True, "Credential creation and replacement passed without exception.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_twice_rtrue_rfalse(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + + + # Verify specific error code/message + self.assertIn( + 'ORA-20022', + str(cm.exception), + "Expected ORA-20022 error when creating credential without replace" + ) + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_twice_rfalse_rtrue(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + # First creation + try: + select_ai.create_credential(credential=credential) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Assert passed if no exception raised + self.assertTrue(True, "Credential creation and replacement passed without exception.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_native_cred(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + try: + select_ai.create_credential(credential=credential, replace=False) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_native_cred_twice(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + # First creation + try: + select_ai.create_credential(credential=credential) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + + # Verify specific error code/message + self.assertIn( + 'ORA-20022', + str(cm.exception), + "Expected ORA-20022 error when creating credential without replace" + ) + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_native_cred_rtrue(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + # First creation + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Second creation + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Assert passed if no exception raised + self.assertTrue(True, "Credential creation and replacement passed without exception.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_native_cred_twice_rtrue(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get native cred secrets + credential = self.get_native_cred_param('GENAI_CRED') + + # Create credential + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_empty_name(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param() + + # Verify create credential + with self.assertRaises(Exception) as cm: + select_ai.create_credential(credential=credential) + self.assertIn("ORA-20010: Missing credential name", str(cm.exception)) + + + def test_create_cred_empty_dict(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = dict() + + # Verify create credential + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + + self.assertIn( + "PLS-00306: wrong number or types of arguments in call to 'CREATE_CREDENTIAL'", + str(cm.exception) + ) + + + def test_create_cred_invalid_username(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = dict( + credential_name = 'GENAI_CRED', + username = 'invalid_username', + password = self.cred_password + ) + + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_invalid_password(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = dict( + credential_name = 'GENAI_CRED', + username = self.cred_username, + password = 'invalid_pwd' + ) + + try: + select_ai.create_credential(credential=credential, replace=True) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + + def test_create_cred_db_unavailable(self): + # Get credential secret + credential = self.get_cred_param('GENAI_CRED') + + with self.assertRaisesRegex(oracledb.InterfaceError, + r"DPY-1001: not connected to database"): + select_ai.create_credential(credential=credential, replace=False) + + + def test_create_cred_local_user(self): + test_env.create_connection() + + test_username = "TEST_USER1" + test_password = self.password + + # Drop user if exists and create new one + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + + # Connect as test user + test_env.create_connection(user=test_username, password=test_password) + # Get credential secret + credential = self.get_cred_param('GENAI_CRED_USER1') + + try: + select_ai.create_credential(credential=credential, replace=False) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + # Drop Credential + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor, 'GENAI_CRED_USER1') + + # Disconnect + select_ai.disconnect() + + # Clean up user + test_env.create_connection() + with select_ai.cursor() as admin_cursor: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + + + # Negative Tests + def test_create_cred_special_characters(self): + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param('GENAI_CRED!@#') + + with self.assertRaisesRegex(oracledb.DatabaseError, r"ORA-20010: Invalid credential name"): + select_ai.create_credential(credential=credential, replace=False) + + + def test_create_cred_long_name(self): + long_name = "GENAI_CRED" + "_" + "a" * (128 - len('GENAI_CRED')) + + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + # Get credential secret + credential = self.get_cred_param(long_name) + + with self.assertRaisesRegex( + oracledb.DatabaseError, + r"ORA-20008: Credential name length \(129\) exceeds maximum length \(128\)" + ): + select_ai.create_credential(credential=credential, replace=False) + + + +if __name__ == "__main__": + test_env.run_test_cases() diff --git a/tests/test_create_vector_index.py b/tests/test_create_vector_index.py new file mode 100644 index 0000000..730af8f --- /dev/null +++ b/tests/test_create_vector_index.py @@ -0,0 +1,413 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import time + + +class TestCreateVectorIndex(unittest.TestCase): + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + user_ocid = cls.user_ocid, + tenancy_ocid = cls.tenancy_ocid, + private_key = cls.private_key, + fingerprint = cls.fingerprint + ) + + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = cls.cred_username, + password = cls.cred_password + ) + + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + # Get credential secret + genai_credential = cls.get_native_cred_param(genai_cred) + objstore_credential = cls.get_cred_param(objstore_cred) + + # Create GenAI Credential + try: + select_ai.create_credential(credential=genai_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + # Create ObjStore Credential + try: + select_ai.create_credential(credential=objstore_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + + @classmethod + def create_profile(cls, profile_name="vector_ai_profile"): + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=cls.oci_compartment_id, + oci_apiformat="GENERIC" + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + cls.profile = select_ai.Profile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + + + @classmethod + def delete_profile(cls): + try: + cls.profile.delete() + except Exception as e: + raise AssertionError(f"profile.delete() raised {e} unexpectedly.") + + + @classmethod + def delete_credential(cls): + try: + select_ai.delete_credential("GENAI_CRED", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + + @classmethod + def setUpClass(cls): + """ + Create Credential, Profile once before all tests. + """ + # Assign password from test_env + cls.user = test_env.get_test_user() + cls.password = test_env.get_test_password() + cls.dsn = test_env.get_localhost_connect_string() + + # Create connection + # test_env.create_connection() + test_env.create_connection( + dsn=cls.dsn, use_wallet=False + ) + assert select_ai.is_connected(), "Connection to DB failed" + + # Get Native cred secrets + cls.user_ocid = test_env.get_user_ocid() + cls.tenancy_ocid = test_env.get_tenancy_ocid() + cls.private_key = test_env.get_private_key() + cls.fingerprint = test_env.get_fingerprint() + + # Get basic cred secrets + cls.cred_username = test_env.get_cred_username() + cls.cred_password = test_env.get_cred_password() + + # Get OCI Provider + cls.oci_compartment_id = test_env.get_compartment_id() + cls.embedding_location = test_env.get_embedding_location() + + # Create Credential + cls.create_credential() + # Create Profile + cls.create_profile() + + + @classmethod + def tearDownClass(cls): + # Delete Profile + cls.delete_profile() + + # Delete Credential + cls.delete_credential() + + # Disconnect from DB + try: + select_ai.disconnect() + except Exception as e: + print(f"Warning: disconnect failed ({e})") + + + def setUp(self): + self.embedding_location = self.__class__.embedding_location + self.profile = self.__class__.profile + self.dsn = self.__class__.dsn + self.objstore_cred = "OBJSTORE_CRED" + + # Specify objects to create an embedding for. + # The objects reside in ObjectStore and the vector database is Oracle + self.vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name=self.objstore_cred + ) + + # Create vector index object + self.vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + + + def tearDown(self): + # Delete Vector Index + vector_index = select_ai.VectorIndex(index_name="test_vector_index") + vector_index.delete(force=True) + + + def test_create_vector_index_success(self): + try: + self.vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Verify list + vector_index = select_ai.VectorIndex() + indexes = [i.index_name for i in vector_index.list()] + self.assertIn("TEST_VECTOR_INDEX", indexes) + + + def test_create_vector_index_success_replace_false(self): + try: + self.vector_index.create(replace=False) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Verify list + vector_index = select_ai.VectorIndex() + indexes = [i.index_name for i in vector_index.list()] + self.assertIn("TEST_VECTOR_INDEX", indexes) + + + def test_create_vector_index_empty_description(self): + # Create vector index object + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="", + profile=self.profile + ) + try: + vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Verify list + vector_index = select_ai.VectorIndex() + indexes = [i.index_name for i in vector_index.list()] + self.assertIn("TEST_VECTOR_INDEX", indexes) + + + def test_create_vector_index_replace_true(self): + # First creation + try: + self.vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Second creation + try: + self.vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + + def test_create_vector_index_replace_false(self): + # First creation should succeed + try: + self.vector_index.create(replace=False) + except Exception as e: + self.fail(f"Create vector index failed unexpectedly with exception: {e}") + + # Second creation should raise ORA-20048 + with self.assertRaises(oracledb.DatabaseError) as cm: + self.vector_index.create(replace=False) + + # Verify the error code/message + self.assertIn("ORA-20048", str(cm.exception)) + self.assertIn("already exists", str(cm.exception)) + + + def test_create_vector_index_minimal_attributes(self): + # Create vector index object + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + profile=self.profile + ) + + try: + vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + + def test_create_vector_index_recreate_after_delete(self): + try: + self.vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + # Delete Vector Index + vector_index = select_ai.VectorIndex(index_name="test_vector_index") + vector_index.delete(force=True) + + try: + self.vector_index.create(replace=True) + except Exception as e: + self.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + + # Negative Case + def test_create_vector_index_invalid_credential(self): + vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="invalidObjStore_cred" + ) + + # Create vector index object + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + + with self.assertRaises(oracledb.DatabaseError): + vector_index.create(replace=True) + + def test_create_vector_index_invalid_location(self): + vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location="invalid_location", + object_storage_credential_name=self.objstore_cred + ) + + # Create vector index object + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + + with self.assertRaises(oracledb.DatabaseError): + vector_index.create(replace=True) + + + def test_create_vector_index_missing_attributes(self): + with self.assertRaises(AttributeError): + select_ai.VectorIndex( + index_name="test_vector_index", + attributes=None, + profile=self.profile + ).create() + + + def test_create_vector_index_invalid_attributes_type(self): + with self.assertRaises(TypeError): + select_ai.VectorIndex( + index_name="test_vector_index", + attributes="invalid_attributes", # invalid type + profile=self.profile + ).create() + + + def test_create_vector_index_invalid_name_type(self): + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.VectorIndex( + index_name=12345, # invalid type (int instead of str) + attributes=self.vector_index_attributes, + profile=self.profile + ).create() + + # Verify error + self.assertIn("ORA-20048", str(cm.exception)) + self.assertIn("Invalid vector index name", str(cm.exception)) + + + def test_create_vector_index_empty_name(self): + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.VectorIndex( + index_name="", + attributes=self.vector_index_attributes, + profile=self.profile + ).create() + + # Verify the error code/message + self.assertIn("ORA-20048", str(cm.exception)) + self.assertIn("Missing vector index name", str(cm.exception)) + + + def test_create_vector_index_invalid_profile(self): + # Create vector index object + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="Test vector index", + profile="invalid_profile" + ) + with self.assertRaises(ValueError): + vector_index.create() + + + def test_create_vector_index_none_attributes(self): + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=None, + description="Test vector index", + profile="invalid_profile" + ) + with self.assertRaises(TypeError): + vector_index.create() + + + # Boundary Cases + def test_create_vector_index_long_name(self): + long_name = "X" * 150 # > Oracle identifier length + vector_index = select_ai.VectorIndex( + index_name=long_name, + attributes=self.vector_index_attributes, + profile=self.profile + ) + with self.assertRaises(oracledb.DatabaseError): + vector_index.create() + + + def test_create_vector_index_long_description(self): + long_desc = "D" * 5000 # deliberately too long + + # Create vector index object + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description=long_desc, + profile=self.profile + ) + + # Expect DatabaseError due to description length + with self.assertRaises(oracledb.DatabaseError) as cm: + vector_index.create(replace=True) + + # Verify Oracle error details + self.assertIn("ORA-20045", str(cm.exception)) + self.assertIn("description is too long", str(cm.exception)) + + + def test_create_vector_index_multiple_recreates(self): + for _ in range(10): + self.vector_index.create(replace=True) + + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_disable_provider.py b/tests/test_disable_provider.py new file mode 100644 index 0000000..df068d0 --- /dev/null +++ b/tests/test_disable_provider.py @@ -0,0 +1,248 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import time + + +class TestDisableProvider(unittest.TestCase): + @classmethod + def create_local_user(cls, test_username="TEST_USER1"): + """ + Helper to drop and create a local test user with required grants. + """ + test_password = cls.password + + # Drop user if exists and create new one + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + + + @classmethod + def setUpClass(cls): + """ + Create DB users once before all tests. + """ + # Create connection + test_env.create_connection() + assert select_ai.is_connected(), "Connection to DB failed" + + # Assign password from test_env so create_local_user can use it + cls.user = test_env.get_test_user() + cls.password = test_env.get_test_password() + cls.dsn = test_env.get_localhost_connect_string() + + + cls.db_users = list() + # Create multiple DB users (DB_USER1 ... DB_USER5) + for i in range(1, 6): + user = f"DB_USER{i}" + cls.create_local_user(user) + cls.db_users.append(user) + + # Create Additional user + cls.create_local_user("DB_USER6") + + + @classmethod + def tearDownClass(cls): + """ + Drop DB users after all tests finish. + """ + cls.db_users.append("DB_USER6") + with select_ai.cursor() as admin_cursor: + for user in cls.db_users: + try: + admin_cursor.execute(f"DROP USER {user} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if already dropped + + # Disconnect from DB + try: + select_ai.disconnect() + except Exception as e: + print(f"Warning: disconnect failed ({e})") + + + def setUp(self): + self.provider_endpoint = "*.openai.azure.com" + self.db_users = self.__class__.db_users + + # Enabling provider with valid users + try: + select_ai.enable_provider( + users=self.db_users, + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"enable_provider() raised {e} unexpectedly.") + + + def test_disable_provider_success(self): + # Disabling provider for valid users and endpoint should succeed. + try: + select_ai.disable_provider( + users=self.db_users, + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"disable_provider() raised {e} unexpectedly.") + + + def test_disable_provider_nonexistent_user(self): + # Disabling provider with non-existent users should raise DatabaseError. + db_users = ["DB_USER1", "TEST_USER2"] + + with self.assertRaises(oracledb.DatabaseError): + select_ai.disable_provider( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + + + def test_disable_provider_nonexistent_users(self): + # Disabling provider with non-existent users should raise DatabaseError. + with self.assertRaises(oracledb.DatabaseError): + select_ai.disable_provider( + users=["INVALID_USER1", "INVALID_USER2"], + provider_endpoint=self.provider_endpoint + ) + + + def test_disable_provider_invalid_users_type_int(self): + # Disabling provider with int as users should raise TypeError/ValueError. + with self.assertRaises((TypeError, ValueError)): + select_ai.disable_provider( + users=123, + provider_endpoint=self.provider_endpoint + ) + + + def test_disable_provider_invalid_users_type_string(self): + # Disabling provider with string instead of list should raise TypeError/ValueError. + with self.assertRaises((TypeError, ValueError)): + select_ai.disable_provider( + users="DB_USER1", + provider_endpoint=self.provider_endpoint + ) + + + def test_disable_provider_invalid_users_type_none(self): + # Disabling provider with None as users should raise TypeError/ValueError. + with self.assertRaises((TypeError, ValueError)): + select_ai.disable_provider( + users=None, + provider_endpoint=self.provider_endpoint + ) + + + def test_disable_provider_missing_endpoint(self): + # None endpoint should raise ValueError. + with self.assertRaises(ValueError): + select_ai.disable_provider( + users=self.db_users, + provider_endpoint=None + ) + + + def test_disable_provider_invalid_endpoint(self): + # Disabling provider with invalid endpoint should raise DatabaseError. + with self.assertRaises(oracledb.DatabaseError): + select_ai.disable_provider( + users=self.db_users, + provider_endpoint="invalid.endpoint" + ) + + + def test_disable_provider_with_empty_users(self): + # Disabling provider with empty users list should succeed without error. + try: + select_ai.disable_provider( + users=[], + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"disable_provider() raised {e} unexpectedly with empty users list.") + + + def test_disable_provider_duplicate_users(self): + # Disabling provider with duplicate users should raise ORA-01927 + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.disable_provider( + users=[self.db_users[0], self.db_users[0]], + provider_endpoint=self.provider_endpoint + ) + + # Verify error code is ORA-01927 + self.assertIn("ORA-01927", str(cm.exception)) + + + def test_disable_provider_case_insensitive_username(self): + # Lowercase username should work if DB is case-insensitive + try: + select_ai.disable_provider( + users=[self.db_users[0].lower()], + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"disable_provider() raised {e} unexpectedly with lowercase username.") + + + def test_disable_provider_username_with_whitespace(self): + # Leading/trailing whitespace should be ignored and still succeed + db_users = [f" {self.db_users[0]} "] + try: + select_ai.disable_provider( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"disable_provider() raised {e} unexpectedly with whitespace in username.") + + + def test_disable_provider_valid_custom_endpoint(self): + # Enabling provider with a custom endpoint should raise ORA-24244 + with self.assertRaisesRegex( + oracledb.DatabaseError, + r"ORA-24244: invalid host or port for access control list \(ACL\) assignment" + ): + select_ai.disable_provider( + users=self.db_users, + provider_endpoint="https://custom.openai.azure.com" + ) + + + def test_disable_provider_non_granted_user(self): + # Disabling provider for a user who was never granted access should raise DatabaseError. + non_granted_user = "DB_USER6" + + with self.assertRaises(oracledb.DatabaseError) as cm: + select_ai.disable_provider( + users=[non_granted_user], + provider_endpoint=self.provider_endpoint + ) + + # Optionally check the specific Oracle error code + self.assertIn("ORA-01927", str(cm.exception)) + + + + def test_disable_provider_large_user_list(self): + db_users = [f"DB_USER_{i}" for i in range(1000)] + with self.assertRaises(oracledb.DatabaseError): + select_ai.disable_provider( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_drop_cred.py b/tests/test_drop_cred.py new file mode 100644 index 0000000..5010020 --- /dev/null +++ b/tests/test_drop_cred.py @@ -0,0 +1,187 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import time + + +class TestDropCredential(unittest.TestCase): + def setUp(self): + """ + Setup connection parameters. + """ + self.user = test_env.get_test_user() + self.password = test_env.get_test_password() + self.dsn = test_env.get_localhost_connect_string() + + # Get basic cred secrets + self.cred_username = test_env.get_cred_username() + self.cred_password = test_env.get_cred_password() + + # Create connection + test_env.create_connection() + self.assertTrue(select_ai.is_connected(), "Connection to DB failed") + + def tearDown(self): + # Disconnect after each test + select_ai.disconnect() + + + def get_cred_param(self, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = self.cred_username, + password = self.cred_password + ) + + def create_test_credential(self, cred_name="GENAI_CRED"): + """ + Helper to create a test credential. + """ + # Get credential secret + credential = self.get_cred_param(cred_name) + + try: + select_ai.create_credential(credential=credential, replace=False) + except Exception as e: + self.fail(f"create_credential() raised {e} unexpectedly.") + + + def create_local_user(self, test_username="TEST_USER1"): + """ + Helper to drop and create a local test user with required grants. + """ + test_password = self.password + + # Drop user if exists and create new one + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + + + def test_delete_cred_success(self): + # Create credential + self.create_test_credential() + + try: + select_ai.delete_credential("GENAI_CRED", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + + def test_delete_cred_twice_force_true(self): + # Create credential + self.create_test_credential() + + # First delete should succeed + select_ai.delete_credential("GENAI_CRED", force=True) + + # Second delete should also succeed (no exception, since force=True) + select_ai.delete_credential("GENAI_CRED", force=True) + + + def test_delete_cred_twice_force_false(self): + # Create credential + self.create_test_credential() + + # First delete should succeed + select_ai.delete_credential("GENAI_CRED", force=False) + + # Second delete should raise DatabaseError since credential is already deleted + with self.assertRaises(oracledb.DatabaseError): + select_ai.delete_credential("GENAI_CRED", force=False) + + + def test_delete_nonexistent_cred_default(self): + with self.assertRaises(oracledb.DatabaseError): + select_ai.delete_credential("nonexistent_cred") + + + def test_delete_nonexistent_cred_without_force(self): + with self.assertRaises(oracledb.DatabaseError): + select_ai.delete_credential("nonexistent_cred", force=False) + + + def test_delete_nonexistent_cred_with_force(self): + # Should not raise error when force=True + try: + select_ai.delete_credential("nonexistent_cred", force=True) + except Exception as e: + self.fail(f"delete_credential(force=True) raised {e} unexpectedly.") + + + def test_delete_cred_local_user(self): + test_username = "TEST_USER1" + + self.create_local_user(test_username) + + # Connect as test user + test_env.create_connection(user=test_username, password=self.password) + # Get credential secret + credential = self.get_cred_param("GENAI_CRED_USER1") + + try: + select_ai.delete_credential("GENAI_CRED_USER1", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + # Disconnect + select_ai.disconnect() + + # Clean up user + test_env.create_connection() + with select_ai.cursor() as admin_cursor: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + + + def test_invalid_cred_name(self): + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20010: Invalid credential name"): + select_ai.delete_credential("invalid!@#", force=True) + + + def test_delete_cred_not_connected(self): + select_ai.disconnect() + with self.assertRaises(select_ai.errors.DatabaseNotConnectedError): + select_ai.delete_credential("GENAI_CRED", force=True) + + + def test_credential_name_too_long(self): + long_name = "GENAI_CRED_" + "a" * 120 + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20008: Credential name length .* exceeds maximum length"): + select_ai.delete_credential(long_name, force=True) + + + def test_delete_cred_case_sensitive(self): + # Create credential + self.create_test_credential("GENAI_CRED") + + # Try deleting with lower case + try: + select_ai.delete_credential(credential_name="genai_cred") + except Exception as e: + self.fail(f"async_delete_credential raised {e} unexpectedly for lowercase name") + + + def test_delete_cred_empty_name(self): + # Empty string → ORA-20010: Missing credential name + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20010: Missing credential name"): + select_ai.delete_credential(credential_name="", force=True) + + # None → should also end up with ORA-20010 + with self.assertRaisesRegex(oracledb.DatabaseError, + r"ORA-20010: Missing credential name"): + select_ai.delete_credential(credential_name=None, force=True) + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_enable_provider.py b/tests/test_enable_provider.py new file mode 100644 index 0000000..95d4dcc --- /dev/null +++ b/tests/test_enable_provider.py @@ -0,0 +1,219 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import time + + + +class TestEnableProvider(unittest.TestCase): + @classmethod + def create_local_user(cls, test_username="TEST_USER1"): + """ + Helper to drop and create a local test user with required grants. + """ + test_password = cls.password + + # Drop user if exists and create new one + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + + admin_cursor.execute(f"CREATE USER {test_username} IDENTIFIED BY {test_password}") + admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + + + @classmethod + def setUpClass(cls): + """ + Create DB users once before all tests. + """ + # Create connection + test_env.create_connection() + assert select_ai.is_connected(), "Connection to DB failed" + + # Assign password from test_env so create_local_user can use it + cls.user = test_env.get_test_user() + cls.password = test_env.get_test_password() + cls.dsn = test_env.get_localhost_connect_string() + + + cls.db_users = list() + # Create multiple DB users (DB_USER1 ... DB_USER5) + for i in range(1, 6): + user = f"DB_USER{i}" + cls.create_local_user(user) + cls.db_users.append(user) + + + @classmethod + def tearDownClass(cls): + """ + Drop DB users after all tests finish. + """ + with select_ai.cursor() as admin_cursor: + for user in cls.db_users: + try: + admin_cursor.execute(f"DROP USER {user} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if already dropped + + # Disconnect from DB + try: + select_ai.disconnect() + except Exception as e: + print(f"Warning: disconnect failed ({e})") + + + def setUp(self): + self.provider_endpoint = "*.openai.azure.com" + self.db_users = self.__class__.db_users + + + def test_enable_provider_success(self): + # Enabling provider with valid users and endpoint should succeed. + try: + select_ai.enable_provider( + users=self.db_users, + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"enable_provider() raised {e} unexpectedly.") + + + def test_enable_provider_nonexistent_username(self): + # Enabling provider with nonexistent users should raise DatabaseError + db_users = ["DB_USER1", "TEST_USER2"] + + with self.assertRaisesRegex( + oracledb.DatabaseError, + r"ORA-01917: user or role 'TEST_USER2' does not exist" + ): + select_ai.enable_provider( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + + + def test_enable_provider_nonexistent_usernames(self): + # Enabling provider with nonexistent users should raise DatabaseError + db_users = ["TEST_USER1", "TEST_USER2"] + + with self.assertRaisesRegex( + oracledb.DatabaseError, + r"ORA-01917: user or role 'TEST_USER1' does not exist" + ): + select_ai.enable_provider( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + + + def test_enable_provider_empty_users(self): + # Empty users list should raise ValueError. + try: + select_ai.enable_provider( + users=[], + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"enable_provider() raised {e} unexpectedly with empty users.") + + + def test_enable_provider_invalid_users_type(self): + # Passing users as non-list should raise TypeError. + with self.assertRaises(TypeError): + select_ai.enable_provider( + users="DB_USER1", # not a list + provider_endpoint=self.provider_endpoint + ) + + + def test_enable_provider_invalid_users_type_int(self): + # Passing users as non-list should raise TypeError. + with self.assertRaises(TypeError): + select_ai.enable_provider( + users=2, # not a list + provider_endpoint=self.provider_endpoint + ) + + + def test_enable_provider_missing_endpoint(self): + # None endpoint should raise ValueError. + with self.assertRaises(ValueError): + select_ai.enable_provider( + users=self.db_users, + provider_endpoint=None + ) + + + def test_enable_provider_invalid_endpoint(self): + # DatabaseError from underlying call should propagate. + with self.assertRaises(ValueError): + select_ai.enable_provider( + users=self.db_users, + provider_endpoint="invalid.endpoint" + ) + + + def test_enable_provider_duplicate_users(self): + # Duplicate users should not cause failure + try: + select_ai.enable_provider( + users=[self.db_users[0], self.db_users[0]], + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"enable_provider() raised {e} unexpectedly with duplicate users.") + + + def test_enable_provider_case_insensitive_username(self): + # Lowercase username should work if DB is case-insensitive + try: + select_ai.enable_provider( + users=[self.db_users[0].lower()], + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"enable_provider() raised {e} unexpectedly with lowercase username.") + + + def test_enable_provider_username_with_whitespace(self): + # Leading/trailing whitespace should be ignored and still succeed + db_users = [f" {self.db_users[0]} "] + try: + select_ai.enable_provider( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + except Exception as e: + self.fail(f"enable_provider() raised {e} unexpectedly with whitespace in username.") + + + def test_enable_provider_large_user_list(self): + db_users = [f"DB_USER_{i}" for i in range(1000)] + with self.assertRaises(oracledb.DatabaseError): + select_ai.enable_provider( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + + + def test_enable_provider_valid_custom_endpoint(self): + # Enabling provider with a custom endpoint should raise ORA-24244 + with self.assertRaisesRegex( + oracledb.DatabaseError, + r"ORA-24244: invalid host or port for access control list \(ACL\) assignment" + ): + select_ai.enable_provider( + users=self.db_users, + provider_endpoint="https://custom.openai.azure.com" + ) + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file diff --git a/tests/test_env.py b/tests/test_env.py new file mode 100644 index 0000000..ab35646 --- /dev/null +++ b/tests/test_env.py @@ -0,0 +1,275 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2020, 2025, Oracle and/or its affiliates. +# +# This software is dual-licensed to you under the Universal Permissive License +# (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl and Apache License +# 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose +# either license. +# +# If you elect to accept the software under the Apache License, Version 2.0, +# the following applies: +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- + +import importlib +import os +import secrets +import sys +import string +import unittest +from threading import get_ident +from typing import Dict, Hashable + +import select_ai +import oracledb + +# __conn__: Dict[Hashable, oracledb.Connection] = {} +# __async_conn__: Dict[Hashable, oracledb.AsyncConnection] = {} + +DEFAULT_MAIN_USER = "admin" +DEFAULT_CONNECT_STRING = "localhost:1531/GCC59E2CF7A6F5F_ADWP" +DEFAULT_PROXY_USER = "selectai_testuser" + + +# dictionary containing all parameters; these are acquired as needed by the +# methods below (which should be used instead of consulting this dictionary +# directly) and then stored so that a value is not requested more than once +PARAMETERS = {} + + +def _initialize(): + """ + Performs initialization of the select_ai environment. + Ensures that OracleDB mode is set and required plugins are imported. + """ + if PARAMETERS.get("INITIALIZED"): + return + + # Initialize Oracle client if needed + if not get_is_thin() and oracledb.is_thin_mode(): + oracledb.init_oracle_client() + oracledb.defaults.thick_mode_dsn_passthrough = False + + # Load select_ai plugins from environment variable + plugin_names = os.environ.get("SAI_TEST_PLUGINS") + if plugin_names: + for name in plugin_names.split(","): + module_name = f"oracledb.plugins.{name.strip()}" + print(f"Importing module: {module_name}") + importlib.import_module(module_name) + + PARAMETERS["INITIALIZED"] = True + + +def run_test_cases(): + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) + + +def get_value(name, label, default_value=None, password=False): + """Retrieve a value from PARAMETERS or environment.""" + if name in PARAMETERS: + return PARAMETERS[name] + + env_name = "SAI_TEST_" + name + value = os.environ.get(env_name) + + if not value: + value = default_value + + PARAMETERS[name] = value + return value + + +def get_client_version(): + name = "CLIENT_VERSION" + value = PARAMETERS.get(name) + if value is None: + _initialize() + value = oracledb.clientversion()[:2] + PARAMETERS[name] = value + return value + + +def get_connection_args(): + """Get and return connection parameters""" + return { + "user": get_test_user(), + "password": get_test_password(), + "dsn": get_connect_string(), + "wallet_location": get_wallet_location(), + "wallet_password": get_wallet_password(), + } + + +def create_connection(use_wallet=True, **kwargs): + """Create a synchronous connection.""" + conn_args = get_connection_args() + + connect_kwargs = { + "user": kwargs.get("user", conn_args["user"]), + "password": kwargs.get("password", conn_args["password"]), + "dsn": kwargs.get("dsn", conn_args["dsn"]), + } + + if use_wallet: + connect_kwargs.update({ + "config_dir": kwargs.get("wallet_location", conn_args["wallet_location"]), + "wallet_location": kwargs.get("wallet_location", conn_args["wallet_location"]), + "wallet_password": kwargs.get("wallet_password", conn_args["wallet_password"]) + }) + + select_ai.connect(**connect_kwargs) + + +async def create_async_connection(use_wallet=True, **kwargs): + """Create an asynchronous connection.""" + conn_args = get_connection_args() + + connect_kwargs = { + "user": kwargs.get("user", conn_args["user"]), + "password": kwargs.get("password", conn_args["password"]), + "dsn": kwargs.get("dsn", conn_args["dsn"]), + } + + if use_wallet: + connect_kwargs.update({ + "config_dir": kwargs.get("wallet_location", conn_args["wallet_location"]), + "wallet_location": kwargs.get("wallet_location", conn_args["wallet_location"]), + "wallet_password": kwargs.get("wallet_password", conn_args["wallet_password"]) + }) + + print(connect_kwargs) + await select_ai.async_connect(**connect_kwargs) + + +def get_connect_string(): + return get_value( + "CONNECT_STRING", "Connect String", DEFAULT_CONNECT_STRING + ) + +def get_localhost_connect_string(): + return "localhost:1531/GCC59E2CF7A6F5F_ADWP" + + +def get_is_thin(): + driver_mode = get_value("DRIVER_MODE", "Driver mode (thin|thick)", "thin") + return driver_mode == "thin" + + +def get_test_password(): + return get_value( + "PASSWORD", f"Password for {get_test_user()}", password=True + ) + + +def get_test_user(): + return get_value("USER", "Test User Name", DEFAULT_MAIN_USER) + +def get_proxy_user(): + return get_value("PROXY_USER", "Proxy User Name", DEFAULT_PROXY_USER) + + +def get_wallet_location(): + return get_value("WALLET_LOCATION", "Wallet Location") + + +def get_cred_username(): + return get_value("CRED_USERNAME", "OCI credential username") + + +def get_cred_password(): + return get_value("CRED_PASSWORD", "OCI credential password") + + +def get_user_ocid(): + return get_value("USER_OCID", "user ocid") + + +def get_tenancy_ocid(): + return get_value("TENANCY_OCID", "tenancy ocid") + + +def get_private_key(): + return get_value("PRIVATE_KEY", "private key") + + +def get_fingerprint(): + return get_value("FINGERPRINT", "fingerprint") + + +def get_wallet_password(): + return get_value("WALLET_PASSWORD", "Wallet Password", password=True) + + +def get_compartment_id(provider="OCI"): + return get_value(f"{provider}_COMPARTMENT_ID", "Compartment ID") + + +def get_embedding_location(): + return get_value("EMBEDDING_LOCATION", "Vector Embedding Location") + + +def get_random_string(length=10): + return "".join(secrets.choice(string.ascii_letters) for i in range(length)) + + +def has_client_version(major_version, minor_version=0): + if get_is_thin(): + return True + return get_client_version() >= (major_version, minor_version) + + +def has_server_version(major_version, minor_version=0): + return get_server_version() >= (major_version, minor_version) + + +def run_sql_script(conn, script_name, **kwargs): + statement_parts = [] + cursor = conn.cursor() + replace_values = [("&" + k + ".", v) for k, v in kwargs.items()] + [ + ("&" + k, v) for k, v in kwargs.items() + ] + script_dir = os.path.dirname(os.path.abspath(sys.argv[0])) + file_name = os.path.join(script_dir, "sql", script_name + ".sql") + for line in open(file_name): + if line.strip() == "/": + statement = "".join(statement_parts).strip() + if statement: + for search_value, replace_value in replace_values: + statement = statement.replace(search_value, replace_value) + try: + cursor.execute(statement) + except: + print("Failed to execute SQL:", statement) + raise + statement_parts = [] + else: + statement_parts.append(line) + cursor.execute( + """ + select name, type, line, position, text + from dba_errors + where owner = upper(:owner) + order by name, type, line, position + """, + owner=get_test_user(), + ) + prev_name = prev_obj_type = None + for name, obj_type, line_num, position, text in cursor: + if name != prev_name or obj_type != prev_obj_type: + print("%s (%s)" % (name, obj_type)) + prev_name = name + prev_obj_type = obj_type + print(" %s/%s %s" % (line_num, position, text)) + diff --git a/tests/test_list_vector_index.py b/tests/test_list_vector_index.py new file mode 100644 index 0000000..205336b --- /dev/null +++ b/tests/test_list_vector_index.py @@ -0,0 +1,368 @@ +import unittest +import select_ai +import test_env +import oracledb +import os +import re + + +class TestListVectorIndex(unittest.TestCase): + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + user_ocid = cls.user_ocid, + tenancy_ocid = cls.tenancy_ocid, + private_key = cls.private_key, + fingerprint = cls.fingerprint + ) + + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + return dict( + credential_name = cred_name, + username = cls.cred_username, + password = cls.cred_password + ) + + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + # Get credential secret + genai_credential = cls.get_native_cred_param(genai_cred) + objstore_credential = cls.get_cred_param(objstore_cred) + + # Create GenAI Credential + try: + select_ai.create_credential(credential=genai_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + # Create ObjStore Credential + try: + select_ai.create_credential(credential=objstore_credential, replace=True) + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + + @classmethod + def create_profile(cls, profile_name="vector_ai_profile"): + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=cls.oci_compartment_id, + oci_apiformat="GENERIC" + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + cls.profile = select_ai.Profile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + + + @classmethod + def delete_profile(cls): + try: + cls.profile.delete() + except Exception as e: + raise AssertionError(f"profile.delete() raised {e} unexpectedly.") + + + @classmethod + def delete_credential(cls): + try: + select_ai.delete_credential("GENAI_CRED", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + except Exception as e: + self.fail(f"delete_credential() raised {e} unexpectedly.") + + + @classmethod + def create_vector_index(cls, index_name): + # Specify objects to create an embedding for. + # The objects reside in ObjectStore and the vector database is Oracle + vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location=cls.embedding_location, + object_storage_credential_name=cls.objstore_cred + ) + + # Create vector index object + vector_index = select_ai.VectorIndex( + index_name=index_name, + attributes=vector_index_attributes, + description="Test vector index", + profile=cls.profile + ) + + # Create vector index + vector_index.create(replace=True) + + + @classmethod + def setUpClass(cls): + """ + Create Credential, Profile once before all tests. + """ + # Assign password from test_env + cls.user = test_env.get_test_user() + cls.password = test_env.get_test_password() + cls.dsn = test_env.get_localhost_connect_string() + + # Create connection + # test_env.create_connection() + test_env.create_connection( + dsn=cls.dsn, use_wallet=False + ) + assert select_ai.is_connected(), "Connection to DB failed" + + # Get Native cred secrets + cls.user_ocid = test_env.get_user_ocid() + cls.tenancy_ocid = test_env.get_tenancy_ocid() + cls.private_key = test_env.get_private_key() + cls.fingerprint = test_env.get_fingerprint() + + # Get basic cred secrets + cls.cred_username = test_env.get_cred_username() + cls.cred_password = test_env.get_cred_password() + + # Get OCI Provider + cls.oci_compartment_id = test_env.get_compartment_id() + cls.embedding_location = test_env.get_embedding_location() + + # Create Credential + cls.create_credential() + # Create Profile + cls.create_profile() + + cls.objstore_cred = "OBJSTORE_CRED" + + # Create some vector indexes + cls.indexes = [f"test_vector_index{i}" for i in range(1, 6)] + \ + [f"test_vecidx{i}" for i in range(1, 3)] + for idx in cls.indexes: + try: + cls.create_vector_index(index_name=idx) + except Exception: + pass + + + @classmethod + def tearDownClass(cls): + # Clean up test indexes and close connection. + for idx in cls.indexes: + try: + # Delete Vector Index + vector_index = select_ai.VectorIndex(index_name_pattern=idx) + vector_index.delete(force=True) + except Exception: + pass + + # Delete Profile + cls.delete_profile() + + # Delete Credential + cls.delete_credential() + + # Disconnect from DB + try: + select_ai.disconnect() + except Exception as e: + print(f"Warning: disconnect failed ({e})") + + + def setUp(self): + self.indexes = self.__class__.indexes + self.vector_index = select_ai.VectorIndex() + + + # ---------- Positive Test Cases ---------- + def test_list_matching_names(self): + expected_index_names = [f"test_vector_index{i}".upper() for i in range(1, 6)] + \ + [f"test_vecidx{i}".upper() for i in range(1, 3)] + + actual_indexes = list(self.vector_index.list(index_name_pattern=".*")) + + # Verify count of indexes + self.assertEqual( + len(actual_indexes), + len(expected_index_names), + f"Expected {len(expected_index_names)} indexes, got {len(actual_indexes)}" + ) + + # Verify each index name + for index, expected_name in zip(actual_indexes, expected_index_names): + self.assertEqual( + index.index_name, + expected_name, + f"Index name mismatch: expected {expected_name}, got {index.index_name}" + ) + + + def test_list_matching_profile_name(self): + expected_profile = "vector_ai_profile" + for index in self.vector_index.list(index_name_pattern=".*"): + # Verify profile name + self.assertEqual( + index.profile.profile_name, + expected_profile, + f"Profile mismatch for {index.index_name}: expected {expected_profile}, got {index.profile.profile_name}" + ) + + + def test_list_matching_credential_name(self): + expected_credential = "OBJSTORE_CRED" + for index in self.vector_index.list(index_name_pattern=".*"): + # Verify object store credential + self.assertEqual( + index.attributes.object_storage_credential_name, + expected_credential, + f"Credential mismatch for {index.index_name}: expected {expected_credential}, got {index.attributes.object_storage_credential_name}" + ) + + + def test_list_matching_description(self): + expected_description = "Test vector index" + for index in self.vector_index.list(index_name_pattern=".*"): + # Verify description + self.assertEqual( + index.description, + expected_description, + f"Description mismatch for {index.index_name}: expected {expected_description}, got {index.description}" + ) + + + def test_list_exact_match(self): + indexes = self.vector_index.list(index_name_pattern="^test_vector_index1$") + self.assertEqual(list(indexes)[0].index_name, "TEST_VECTOR_INDEX1") + + + def test_list_multiple_matches(self): + actual_indexes = list(self.vector_index.list(index_name_pattern="^test_vector_index")) + # Verify count + expected_count = 5 + self.assertEqual( + len(list(actual_indexes)), + expected_count, + f"Expected {expected_count} indexes, got {len(list(actual_indexes))}" + ) + + + # Verify each index name + for i, index in enumerate(actual_indexes, start=1): + expected_index_name = f"TEST_VECTOR_INDEX{i}" + self.assertEqual( + index.index_name, + expected_index_name, + f"Index name mismatch: expected {expected_index_name}, got {index.index_name}" + ) + + + def test_list_case_sensitive_pattern(self): + indexes = self.vector_index.list("^TEST_VECTOR_INDEX?") + self.assertTrue(any(idx.index_name == "TEST_VECTOR_INDEX2" for idx in indexes)) + + + def test_list_case_insensitive_pattern(self): + indexes = self.vector_index.list("(?i)^TEST") + + # for index in indexes: + # print(index.index_name) + self.assertTrue(any(idx.index_name == "TEST_VECTOR_INDEX1" for idx in indexes)) + + + def test_list_complex_regex_or_operator(self): + indexes = self.vector_index.list("^(test_vector_index|test_vecidx)") + names = [idx.index_name for idx in indexes] + self.assertIn("TEST_VECTOR_INDEX1", names) + self.assertIn("TEST_VECIDX1", names) + + # Invalid Index + self.assertNotIn("INVALID_VECIDX1", names) + + + # ----- Negative Cases ----- + def test_list_non_matching_pattern(self): + indexes = self.vector_index.list(index_name_pattern="^xyz") + self.assertEqual(len(list(indexes)), 0) + + + def test_list_invalid_regex_pattern(self): + # Expect Oracle to raise an error for invalid regex + with self.assertRaises(oracledb.DatabaseError) as cm: + list(self.vector_index.list("[unclosed")) + + # Optional: verify the error code/message + self.assertIn("ORA-12726", str(cm.exception), + f"Expected ORA-12726 error, got {cm.exception}") + + + # def test_list_invalid_type_pattern(self): + # with self.assertRaises(TypeError): + # self.vector_index.list(123) + + + def test_list_invalid_type_pattern(self): + indexes = list(self.vector_index.list(123)) + + # Should just return an empty list (no matches) + self.assertEqual( + len(indexes), 0, + f"Expected 0 indexes for invalid type pattern, got {len(indexes)}" + ) + + # ----- Edge Cases ----- + def test_list_none_pattern_match(self): + indexes = self.vector_index.list(None) + self.assertNotEqual(len(list(indexes)), len(self.indexes)) + + def test_list_empty_string_pattern_matches(self): + indexes = self.vector_index.list("") + self.assertNotEqual(len(list(indexes)), len(self.indexes)) + + + def test_list_whitespace_pattern(self): + indexes = self.vector_index.list(" ") + self.assertEqual(len(list(indexes)), 0) + + + def test_list_numeric_pattern(self): + indexes = list(self.vector_index.list("test123")) + + # Expect no matches, so indexes should be empty + self.assertEqual( + len(indexes), + 0, + f"Expected no indexes to match 'test123', but got {len(indexes)}" + ) + + def test_list_special_characters_in_pattern(self): + indexes = self.vector_index.list("test_vector_index1$") + self.assertEqual(len(list(indexes)), 1) + + def test_list_long_pattern_no_match(self): + pattern = "^" + "a" * 1000 + "$" + + # Expect Oracle to raise regex-too-long error + with self.assertRaises(oracledb.DatabaseError) as cm: + list(self.vector_index.list(pattern)) + + # Optional: check correct Oracle error code + self.assertIn("ORA-12733", str(cm.exception), + f"Expected ORA-12733 error, got {cm.exception}") + + + def test_list_case_insensitive_match(self): + indexes = self.vector_index.list("^TEST") + self.assertEqual(len(list(indexes)), 7) + + +if __name__ == "__main__": + test_env.run_test_cases() \ No newline at end of file