diff --git a/typesense/src/traits/field_type.rs b/typesense/src/traits/field_type.rs index 9185828..da2b52d 100644 --- a/typesense/src/traits/field_type.rs +++ b/typesense/src/traits/field_type.rs @@ -1,3 +1,4 @@ +use crate::traits::Document; use std::collections::{BTreeMap, HashMap}; /// Type for a field. Currently it is a wrapping to a `String` but it could be extended to a enum pub type FieldType = String; @@ -8,6 +9,19 @@ pub trait ToTypesenseField { /// Static function that should implement the types of the typesense documents. fn to_typesense_type() -> &'static str; } +/// Generic implementation for any type that is also a Typesense document. +impl<T: Document> ToTypesenseField for T { + fn to_typesense_type() -> &'static str { + "object" + } +} + +/// Generic implementation for a Vec of any type that is also a Typesense document. +impl<T: Document> ToTypesenseField for Vec<T> { + fn to_typesense_type() -> &'static str { + "object[]" + } +} /// macro used internally to add implementations of ToTypesenseField for several rust types. #[macro_export] diff --git a/typesense/tests/client/derive_integration_test.rs b/typesense/tests/client/derive_integration_test.rs new file mode 100644 index 0000000..89178f9 --- /dev/null +++ b/typesense/tests/client/derive_integration_test.rs @@ -0,0 +1,723 @@ +use serde::{Deserialize, Serialize}; +use typesense::Typesense; +use typesense::models::Field; +use typesense::models::SearchParameters; +use typesense::prelude::*; + +use crate::{get_client, new_id}; + +/// A nested struct for deep nesting test. 
+#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +struct ExtraDetails { + model_year: i32, + #[typesense(facet)] + color: String, +} + +#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +struct SupplierInfo { + name: String, + contact: String, +} + +#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +struct Part { + part_id: String, + #[typesense(flatten)] + supplier: SupplierInfo, +} + +/// A nested struct that will be flattened into the parent. +#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +struct ProductDetails { + #[typesense(facet)] + part_number: String, + #[typesense(sort = false)] + weight_kg: f32, + #[typesense(skip)] + desc: String, + #[typesense(flatten)] + extra_details: ExtraDetails, +} +/// A nested struct that will be flattened and renamed. +#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +struct Logistics { + warehouse_code: String, + shipping_class: String, +} + +/// A nested struct that will be indexed as a single "object". +#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +struct Manufacturer { + name: String, + city: String, +} + +/// The main struct that uses every feature of the derive macro. 
+#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +#[typesense( + collection_name = "mega_products", + default_sorting_field = "price", + enable_nested_fields = true, + token_separators = ["-", "/"], + symbols_to_index = ["+"] +)] +struct MegaProduct { + id: String, + + #[typesense(infix, stem)] + title: String, + + #[typesense(rename = "product_name")] + #[serde(rename = "product_name")] + official_name: String, + + #[typesense(facet)] + brand: String, + + #[typesense(sort)] + price: f32, + + #[typesense(range_index)] + review_score: f32, + + #[typesense(index = false, store = false)] + internal_sku: Option<String>, + + #[typesense(type = "geopoint")] + location: (f32, f32), + + #[typesense(num_dim = 4, vec_dist = "cosine")] + embedding: Vec<f32>, + + #[typesense(flatten, skip)] + details: ProductDetails, + + #[typesense(flatten, rename = "logistics_data")] + #[serde(rename = "logistics_data")] + logistics: Logistics, + + manufacturer: Manufacturer, + + #[typesense(flatten)] + parts: Vec<Part>, + + tags: Option<Vec<String>>, + + #[typesense(rename = "primary_address.city")] + #[serde(rename = "primary_address.city")] + primary_city: String, + + #[typesense(locale = "vi")] + locale: String, + + #[typesense(optional)] + qty: i32, +} + +async fn logic_test_derive_macro_with_generic_client_lifecycle() { + let client = get_client(); + let collection_name = new_id("mega_products_test"); + + // Create Collection using the schema from the derive macro + let schema = MegaProduct::collection_schema(); + let mut schema_for_creation = schema.clone(); + schema_for_creation.name = collection_name.clone(); // Use the unique name + + let create_res = client.collections().create(schema_for_creation).await; + assert!( + create_res.is_ok(), + "Failed to create collection: {:?}", + create_res.err() + ); + + // Verify the schema on the server with targeted assertions + let retrieved_schema = client + .collection_schemaless(&collection_name) + .retrieve() + .await + .unwrap(); + + // Create a 
map of the actual fields for easy lookup. + let actual_fields_map: std::collections::HashMap<String, Field> = retrieved_schema + .fields + .into_iter() + .map(|f| (f.name.clone(), f)) + .collect(); + + // Iterate through our *expected* fields and assert only the attributes we set. + for expected_field in schema.fields { + let field_name = &expected_field.name; + // The 'id' field is a special primary key and not listed in the schema's "fields" array. + if field_name == "id" { + continue; + } + let actual_field = actual_fields_map.get(field_name).unwrap_or_else(|| { + panic!( + "Field '{}' expected but not found in retrieved schema", + field_name + ) + }); + + // Perform targeted checks based on the attributes set in MegaProduct struct + match field_name.as_str() { + "title" => { + assert_eq!( + actual_field.infix, + Some(true), + "Field 'title' should have infix: true" + ); + assert_eq!( + actual_field.stem, + Some(true), + "Field 'title' should have stem: true" + ); + } + "product_name" => { + // This is the renamed `official_name` + assert_eq!( + actual_field.name, "product_name", + "Field 'official_name' should be renamed to 'product_name'" + ); + } + "brand" => { + assert_eq!( + actual_field.facet, + Some(true), + "Field 'brand' should have facet: true" + ); + } + "price" => { + assert_eq!( + actual_field.sort, + Some(true), + "Field 'price' should have sort: true" + ); + } + "review_score" => { + assert_eq!( + actual_field.range_index, + Some(true), + "Field 'review_score' should have range_index: true" + ); + } + "internal_sku" => { + assert_eq!( + actual_field.index, + Some(false), + "Field 'internal_sku' should have index: false" + ); + assert_eq!( + actual_field.store, + Some(false), + "Field 'internal_sku' should have store: false" + ); + } + "location" => { + assert_eq!( + actual_field.r#type, "geopoint", + "Field 'location' should have type: 'geopoint'" + ); + } + "embedding" => { + assert_eq!( + actual_field.num_dim, + Some(4), + "Field 'embedding' should have 
num_dim: 4" + ); + assert_eq!( + actual_field.vec_dist.as_deref(), + Some("cosine"), + "Field 'embedding' should have vec_dist: 'cosine'" + ); + } + "manufacturer" => { + assert_eq!( + actual_field.r#type, "object", + "Field 'manufacturer' should have type: 'object'" + ); + } + "tags" => { + assert_eq!( + actual_field.optional, + Some(true), + "Field 'tags' should be optional" + ); + assert_eq!( + actual_field.r#type, "string[]", + "Field 'tags' should have type 'string[]'" + ); + } + "details" => { + assert!(false, "Parent field 'details' should have been skipped") + } + "details.part_number" => { + assert_eq!( + actual_field.facet, + Some(true), + "Flattened field 'details.part_number' should have facet: true" + ); + } + "details.weight_kg" => { + assert_eq!( + actual_field.sort, + Some(false), + "Flattened field 'details.weight_kg' should have sort: false" + ); + } + "details.extra_details" => { + assert_eq!( + actual_field.r#type, "object", + "Field 'details.extra_details' should have type: 'object'" + ); + } + "details.extra_details.model_year" => { + assert_eq!( + actual_field.r#type, "int32", + "Field 'details.extra_details.model_year' should have type: 'int32'" + ); + } + "details.extra_details.color" => { + assert_eq!( + actual_field.r#type, "string", + "Field 'details.extra_details.color' should have type: 'string'" + ); + assert_eq!( + actual_field.facet, + Some(true), + "Field 'details.extra_details.color' should have facet: true" + ); + } + "details.desc" => { + assert!( + false, + "Flattened field 'details.desc' should have been skipped" + ); + } + "logistics_data" => { + assert_eq!( + actual_field.r#type, "object", + "Renamed field 'logistics_data' should have type: 'object'" + ) + } + "logistics_data.warehouse_code" => { + assert_eq!(actual_field.name, "logistics_data.warehouse_code"); + } + "logistics_data.shipping_class" => { + assert_eq!(actual_field.name, "logistics_data.shipping_class"); + } + + "primary_address.city" => { + 
assert_eq!(actual_field.r#type, "string") + } + "parts" => { + assert_eq!( + actual_field.r#type, "object[]", + "Field 'parts' should have type 'object[]'" + ); + } + "parts.part_id" => { + assert_eq!( + actual_field.r#type, "string[]", + "Field 'parts.part_id' should have type 'string[]'" + ); + } + "parts.supplier" => { + assert_eq!( + actual_field.r#type, "object[]", + "Field 'parts.supplier' should have type 'object[]'" + ); + } + "parts.supplier.name" => { + assert_eq!( + actual_field.r#type, "string[]", + "Field 'parts.supplier.name' should have type 'string[]'" + ); + } + "parts.supplier.contact" => { + assert_eq!( + actual_field.r#type, "string[]", + "Field 'parts.supplier.contact' should have type 'string[]'" + ); + } + "locale" => { + assert_eq!( + actual_field.locale, + Some("vi".to_owned()), + "Field 'locale' should have locale of 'vi'" + ); + } + "qty" => { + assert_eq!( + actual_field.optional, + Some(true), + "Field 'qty' should have been optional" + ); + } + _ => { + // If we add a new field to MegaProduct, this panic will remind us to add a check for it. + panic!( + "Unhandled field '{}' in test assertion. 
Please add a check.", + field_name + ); + } + } + } + + // Create Documents using the strongly-typed client + let typed_collection = client.collection_named::<MegaProduct>(&collection_name); + let documents_client = typed_collection.documents(); + + let mut product1 = MegaProduct { + id: "product-1".to_owned(), + title: "Durable Steel Wrench".to_owned(), + official_name: "The Wrenchmaster 3000+".to_owned(), + brand: "MegaTools".to_owned(), + price: 29.99, + review_score: 4.8, + internal_sku: Some("INTERNAL-123".to_owned()), + location: (34.05, -118.24), + embedding: vec![0.1, 0.2, 0.3, 0.4], + details: ProductDetails { + part_number: "MT-WM-3000".to_owned(), + weight_kg: 1.5, + desc: "A high-quality wrench for all your needs.".to_owned(), + extra_details: ExtraDetails { + model_year: 2023, + color: "Red".to_string(), + }, + }, + logistics: Logistics { + warehouse_code: "WH-US-WEST-05".to_owned(), + shipping_class: "GROUND_FREIGHT".to_owned(), + }, + manufacturer: Manufacturer { + name: "MegaTools Inc.".to_owned(), + city: "Toolsville".to_owned(), + }, + parts: vec![ + Part { + part_id: "p-01".to_string(), + supplier: SupplierInfo { + name: "Supplier A".to_string(), + contact: "contact@supplier-a.com".to_string(), + }, + }, + Part { + part_id: "p-02".to_string(), + supplier: SupplierInfo { + name: "Supplier B".to_string(), + contact: "contact@supplier-b.com".to_string(), + }, + }, + ], + tags: Some(vec!["steel".to_owned(), "heavy-duty".to_owned()]), + primary_city: "City".to_owned(), + locale: "Xin chào!".to_owned(), + qty: 123, + }; + + let create_res = documents_client.create(&product1, None).await; + assert!( + create_res.is_ok(), + "Failed to create typed document: {:?}", + create_res.err() + ); + // we set store: false for internal_sku so it should not be present in the response + product1.internal_sku = None; + assert_eq!(create_res.unwrap(), product1); + + // Retrieve Document and verify deserialization + let retrieve_res = 
typed_collection.document("product-1").retrieve().await; + assert!(retrieve_res.is_ok(), "Failed to retrieve typed document"); + assert_eq!(retrieve_res.unwrap(), product1); + + // Search and Filter (Testing attributes) + // A. Search a normal field + let search_res1: Result< + typesense::models::SearchResult<MegaProduct>, + typesense::Error, + > = documents_client + .search(SearchParameters { + q: Some("Wrench".to_owned()), + query_by: Some("title".to_owned()), + ..Default::default() + }) + .await; + assert_eq!(search_res1.unwrap().found, Some(1)); + + // B. Search a renamed field + let search_res2 = documents_client + .search(SearchParameters { + q: Some("Wrenchmaster".to_owned()), + query_by: Some("product_name".to_owned()), + ..Default::default() + }) + .await; + assert_eq!(search_res2.unwrap().found, Some(1)); + + // C. Filter by a facet + let search_params3 = SearchParameters { + q: Some("*".to_owned()), + query_by: Some("title".to_owned()), + filter_by: Some("brand:='MegaTools'".to_owned()), + ..Default::default() + }; + let search_res3 = documents_client.search(search_params3).await; + assert_eq!(search_res3.unwrap().found, Some(1)); + + // D. Filter by a range_index + let search_params4 = SearchParameters { + q: Some("*".to_owned()), + query_by: Some("title".to_owned()), + filter_by: Some("review_score:>4.5".to_owned()), + ..Default::default() + }; + let search_res4 = documents_client.search(search_params4).await; + assert_eq!(search_res4.unwrap().found, Some(1)); + + // E. Search a flattened field + let search_params5 = SearchParameters { + q: Some("MT-WM-3000".to_owned()), + query_by: Some("details.part_number".to_owned()), + ..Default::default() + }; + let search_res5 = documents_client.search(search_params5).await; + assert_eq!(search_res5.unwrap().found, Some(1)); + + // F. 
Filter by a deep nested field + let search_params_deep = SearchParameters { + q: Some("*".to_owned()), + query_by: Some("title".to_owned()), + filter_by: Some("details.extra_details.color:='Red'".to_owned()), + ..Default::default() + }; + let search_res_deep = documents_client.search(search_params_deep).await; + assert_eq!( + search_res_deep.unwrap().found, + Some(1), + "Should find by deep nested field" + ); + + let search_params6 = SearchParameters { + q: Some("WH-US-WEST-05".to_owned()), + query_by: Some("logistics_data.warehouse_code".to_owned()), + ..Default::default() + }; + let search_res6 = documents_client.search(search_params6).await; + assert_eq!( + search_res6.unwrap().found, + Some(1), + "Should find by flattened field with a custom prefix" + ); + + // G. Search a field in a nested object array + let search_params7 = SearchParameters { + q: Some("p-01".to_owned()), + query_by: Some("parts.part_id".to_owned()), + ..Default::default() + }; + let search_res7 = documents_client.search(search_params7).await; + assert_eq!( + search_res7.unwrap().found, + Some(1), + "Should find by field in nested object array" + ); + + // H. 
Search a field in a flattened nested object array + let search_params8 = SearchParameters { + q: Some("Supplier A".to_owned()), + query_by: Some("parts.supplier.name".to_owned()), + ..Default::default() + }; + let search_res8 = documents_client.search(search_params8).await; + assert_eq!( + search_res8.unwrap().found, + Some(1), + "Should find by field in flattened nested object array" + ); + + // Update Document (with a partial struct) + let update_payload = MegaProductPartial { + price: Some(25.99), + tags: Some(Some(vec!["steel".to_owned(), "sale".to_owned()])), + ..Default::default() + }; + + let update_res = typed_collection + .document("product-1") + .update(&update_payload, None) + .await; + assert!(update_res.is_ok(), "Failed to update document"); + + // Retrieve again and check updated fields + let updated_product = typed_collection + .document("product-1") + .retrieve() + .await + .unwrap(); + assert_eq!(updated_product.price, 25.99); + assert_eq!( + updated_product.tags, + Some(vec!["steel".to_owned(), "sale".to_owned()]) + ); + assert_eq!(updated_product.title, product1.title); // Unchanged field + + // Delete Document + let delete_res = typed_collection.document("product-1").delete().await; + assert!(delete_res.is_ok(), "Failed to delete document"); + // Returned document should be the state before deletion + assert_eq!(delete_res.unwrap().id, "product-1"); + + // Verify Deletion + let retrieve_after_delete = typed_collection.document("product-1").retrieve().await; + assert!( + retrieve_after_delete.is_err(), + "Document should not exist after deletion" + ); +} + +// Indexing nested objects via flattening test + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +struct ManualProductDetails { + part_number: String, + weight_kg: f32, +} + +#[derive(Typesense, Serialize, Deserialize, Debug, PartialEq, Clone)] +#[typesense( + collection_name = "manual_flat_products", + // IMPORTANT: Nested fields are disabled for this strategy. 
+ enable_nested_fields = false +)] +struct ManualFlattenedProduct { + id: String, + title: String, + + // This field is part of the Rust struct and will be in the JSON document, + // but it will NOT be part of the Typesense schema. + #[typesense(skip)] + details: ManualProductDetails, + + // These fields represent the flattened data in the Typesense schema. + // Both `typesense(rename)` and `serde(rename)` are used to achieve the desired structure. + #[typesense(rename = "details.part_number")] + #[serde(rename = "details.part_number")] + details_part_number: String, + + #[typesense(rename = "details.weight_kg")] + #[serde(rename = "details.weight_kg")] + details_weight_kg: f32, +} + +async fn logic_test_manual_flattening_lifecycle() { + let client = get_client(); + let collection_name = new_id("manual_flat_test"); + + // 1. Create collection from the schema derived from `ManualFlattenedProduct` + let mut schema = ManualFlattenedProduct::collection_schema(); + schema.name = collection_name.clone(); + + // Verify the generated schema is correct *before* creating it + let schema_fields: Vec<_> = schema.fields.iter().map(|f| f.name.as_str()).collect(); + assert!( + !schema_fields.contains(&"details"), + "Schema should not contain the skipped 'details' field" + ); + assert!( + schema_fields.contains(&"details.part_number"), + "Schema must contain the renamed 'details.part_number' field" + ); + + let create_res = client.collections().create(schema).await; + assert!( + create_res.is_ok(), + "Failed to create collection: {:?}", + create_res.err() + ); + + let typed_collection = client.collection_named::<ManualFlattenedProduct>(&collection_name); + + // 2. Create the document. Note how we populate all fields of the Rust struct. 
+ let product = ManualFlattenedProduct { + id: "manual-1".to_owned(), + title: "Portable Generator".to_owned(), + details: ManualProductDetails { + part_number: "PG-123".to_owned(), + weight_kg: 25.5, + }, + details_part_number: "PG-123".to_owned(), + details_weight_kg: 25.5, + }; + + let create_res = typed_collection.documents().create(&product, None).await; + assert!( + create_res.is_ok(), + "Failed to create document with manual flattening" + ); + + // The created document in the response should be equal to our input struct. + assert_eq!(create_res.unwrap(), product); + + // 3. Retrieve and verify the document. + let retrieved_product = typed_collection + .document("manual-1") + .retrieve() + .await + .unwrap(); + assert_eq!(retrieved_product, product); + // We can access the nested struct for display purposes, even though it wasn't indexed. + assert_eq!(retrieved_product.details.part_number, "PG-123"); + + // 4. Search using the flattened (and indexed) field. + let search_res_indexed = typed_collection + .documents() + .search(SearchParameters { + q: Some("PG-123".to_owned()), + query_by: Some("details.part_number".to_owned()), + ..Default::default() + }) + .await + .unwrap(); + assert_eq!( + search_res_indexed.found, + Some(1), + "Should find document by indexed flattened field" + ); +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tokio_test { + use super::*; + + #[tokio::test] + async fn test_derive_macro_with_generic_client_lifecycle() { + logic_test_derive_macro_with_generic_client_lifecycle().await; + } + + #[tokio::test] + async fn test_manual_flattening_lifecycle() { + logic_test_manual_flattening_lifecycle().await; + } +} + +#[cfg(all(test, target_arch = "wasm32"))] +mod wasm_test { + use super::*; + use wasm_bindgen_test::wasm_bindgen_test; + + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + #[wasm_bindgen_test] + async fn test_derive_macro_with_generic_client_lifecycle() { + console_error_panic_hook::set_once(); + 
logic_test_derive_macro_with_generic_client_lifecycle().await; + } + + #[wasm_bindgen_test] + async fn test_manual_flattening_lifecycle() { + console_error_panic_hook::set_once(); + logic_test_manual_flattening_lifecycle().await; + } +} diff --git a/typesense/tests/client/documents_test.rs b/typesense/tests/client/documents_test.rs index 77cf73d..a98cce1 100644 --- a/typesense/tests/client/documents_test.rs +++ b/typesense/tests/client/documents_test.rs @@ -255,7 +255,7 @@ async fn run_test_generic_document_lifecycle() { // --- 6. Update a single document with a partial payload --- let partial_update_struct = BookPartial { publication_year: Some(1966), - in_stock: Some(Some(false)), + in_stock: Some(None), ..Default::default() }; let index_params = DocumentIndexParameters { @@ -269,7 +269,10 @@ async fn run_test_generic_document_lifecycle() { // The returned document should be the full, updated Book struct assert_eq!(updated_book.publication_year, 1966); - assert_eq!(updated_book.in_stock, Some(false)); + assert_eq!( + updated_book.in_stock, None, + "The updated 'in_stock' must be null" + ); assert_eq!(updated_book.title, book_1.title); // Other fields are preserved // --- 7. Bulk Update (via `documents().update()`) --- diff --git a/typesense/tests/client/mod.rs b/typesense/tests/client/mod.rs index a8a645e..41882a9 100644 --- a/typesense/tests/client/mod.rs +++ b/typesense/tests/client/mod.rs @@ -2,6 +2,7 @@ mod aliases_test; mod client_test; mod collections_test; mod conversation_models_test; +mod derive_integration_test; mod documents_test; mod keys_test; mod multi_search_test; @@ -10,9 +11,8 @@ mod presets_test; mod stemming_dictionaries_test; mod stopwords_test; -use reqwest_retry::policies::ExponentialBackoff; use std::time::Duration; -use typesense::Client; +use typesense::{Client, ExponentialBackoff}; use web_time::{SystemTime, UNIX_EPOCH}; /// Helper function to create a new client for all tests in this suite. 
@@ -21,7 +21,7 @@ pub fn get_client() -> Client { .nodes(vec!["http://localhost:8108"]) .api_key("xyz") .healthcheck_interval(Duration::from_secs(5)) - .retry_policy(ExponentialBackoff::builder().build_with_max_retries(3)) + .retry_policy(ExponentialBackoff::builder().build_with_max_retries(0)) .connection_timeout(Duration::from_secs(3)) .build() .expect("Failed to create Typesense client") diff --git a/typesense/tests/derive/collection.rs b/typesense/tests/derive/collection.rs index 71dcd91..9cc6821 100644 --- a/typesense/tests/derive/collection.rs +++ b/typesense/tests/derive/collection.rs @@ -1,8 +1,29 @@ +use std::collections::BTreeMap; +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; -use typesense::{Typesense, prelude::Document}; +use serde_json::json; +use typesense::Typesense; +use typesense::prelude::*; + +// Test 1: Basic Schema Generation (keeping the old test to ensure backward compatibility) +#[allow(dead_code)] +#[derive(Typesense, Serialize, Deserialize)] +#[typesense( + collection_name = "companies", + default_sorting_field = "num_employees", + enable_nested_fields = true +)] +struct Company { + company_name: String, + num_employees: i32, + #[typesense(facet)] + country: String, + keywords: Option<Vec<String>>, +} #[test] -fn derived_document_generates_schema() { +fn derived_document_generates_basic_schema() { let schema = Company::collection_schema(); let expected = serde_json::json!( @@ -36,17 +57,261 @@ fn derived_document_generates_schema() { assert_eq!(serde_json::to_value(&schema).unwrap(), expected) } +// Test 2: All Field-Level and Collection-Level Attributes + +type GeoPoint = (f32, f32); + #[allow(dead_code)] +#[derive(Typesense, Serialize, Deserialize)] +#[typesense( - collection_name = "companies", - default_sorting_field = "num_employees", - enable_nested_fields = true + collection_name = "kitchen_sink_products", + default_sorting_field = "renamed_price", + token_separators = ["-", "/"], + symbols_to_index = ["+"] )] -struct 
Company { - company_name: String, - num_employees: i32, +struct KitchenSinkProduct { + // Basic types and rename + #[typesense(rename = "product_name")] + name: String, + #[typesense(sort = false, rename = "renamed_price")] + price: f32, + + // Booleans for index, store, stem, infix, range_index + #[typesense(index = false, store = false)] + internal_id: u64, + #[typesense(stem = true, infix = true)] + description: String, + #[typesense(range_index = true)] + review_score: f32, + + // Facet and explicit optional + #[typesense(facet = true, optional = true)] + brand: String, + + // Locale and type override + #[typesense(locale = "ja")] + description_jp: String, + #[typesense(type = "geopoint")] + location: GeoPoint, + + // Vector search attributes + #[typesense(num_dim = 256, vec_dist = "cosine")] + image_embedding: Vec<f32>, + + // Auto type + #[typesense(type = "auto")] + misc_data: String, + + hash_map: HashMap<String, String>, + btree_map: BTreeMap<String, String>, + + hash_map_vec: Vec<HashMap<String, String>>, + btree_map_vec: Vec<BTreeMap<String, String>>, +} + +#[test] +fn derived_document_handles_all_attributes() { + let schema = KitchenSinkProduct::collection_schema(); + + let expected = json!({ + "name": "kitchen_sink_products", + "fields": [ + { "name": "product_name", "type": "string" }, + { "name": "renamed_price", "type": "float", "sort": false }, + { "name": "internal_id", "type": "int64", "index": false, "store": false }, + { "name": "description", "type": "string", "stem": true, "infix": true }, + { "name": "review_score", "type": "float", "range_index": true }, + { "name": "brand", "type": "string", "facet": true, "optional": true }, + { "name": "description_jp", "type": "string", "locale": "ja" }, + { "name": "location", "type": "geopoint" }, + { "name": "image_embedding", "type": "float[]", "num_dim": 256, "vec_dist": "cosine" }, + { "name": "misc_data", "type": "auto" }, + + { "name": "hash_map", "type": "object" }, + { "name": "btree_map", "type": "object" }, + + { "name": "hash_map_vec", "type": "object[]" }, + { "name": 
"btree_map_vec", "type": "object[]" }, + ], + "default_sorting_field": "renamed_price", + "token_separators": ["-", "/"], + "symbols_to_index": ["+"] + }); + + assert_eq!(serde_json::to_value(&schema).unwrap(), expected); +} + +// Test 3: Nested Objects and Flattening + +#[derive(Typesense, Serialize, Deserialize)] +struct Address { + line_1: String, + number: i32, + optional_field: Option<String>, + #[typesense(skip)] + city: String, +} + +#[derive(Typesense, Serialize, Deserialize)] +struct Profile { + #[typesense(facet, sort)] + name: String, + email: Option<String>, +} + +#[derive(Typesense, Serialize, Deserialize)] +struct AddressData { + primary_city: String, + work_zips: Vec<String>, +} + +#[derive(Typesense, Serialize, Deserialize)] +struct NestedStruct { + name: String, + #[typesense(flatten)] + address: AddressData, +} + +#[allow(dead_code)] +#[derive(Typesense, Serialize, Deserialize)] +#[typesense(collection_name = "nested_users", enable_nested_fields = true)] +struct User { + // --- Indexing as an object --- + primary_address: Address, + work_addresses: Vec<Address>, + optional_profile: Option<Profile>, + + // --- Sub-fields indexing --- + #[typesense(flatten)] + profile: Profile, + #[typesense(flatten)] + previous_addresses: Vec<Address>, + #[typesense(flatten, skip)] + sub_fields_only: Profile, + + #[typesense(flatten, skip)] + nested_struct: NestedStruct, + #[typesense(flatten)] + nested_struct_vec: Vec<NestedStruct>, + + // --- Manually flattened object --- + #[typesense(skip)] + data: AddressData, + #[typesense(rename = "primary_address.city")] + primary_city: String, + #[typesense(rename = "work_addresses.zip", type = "string[]")] + work_zips: Vec<String>, +} + +#[test] +fn derived_document_handles_nested_and_flattened_fields() { + let schema = User::collection_schema(); + + let expected = json!({ + "name": "nested_users", + "enable_nested_fields": true, + "fields": [ + // --- Object Indexing --- + { "name": "primary_address", "type": "object" }, + { "name": "work_addresses", "type": "object[]" }, + { "name": "optional_profile", "type": "object", "optional": true }, + + // --- Sub-fields indexing --- + { "name": "profile", "type": "object" }, + { "name": "profile.name", "type": "string", "facet": true, "sort": true}, + { "name": "profile.email", "type": "string", "optional": true }, + + { "name": "previous_addresses", "type": "object[]" }, + { "name": "previous_addresses.line_1", "type": "string[]" }, + { "name": "previous_addresses.number", "type": "int32[]" }, + { "name": "previous_addresses.optional_field", "type": "string[]", "optional": true}, + // { "name": "previous_addresses.city", "type": "string[]" }, correctly skipped + + { "name": "sub_fields_only.name", "type": "string", "facet": true, "sort": true}, + { "name": "sub_fields_only.email", "type": "string", "optional": true }, + + { "name": "nested_struct.name", "type": "string"}, + { "name": "nested_struct.address", "type": "object" }, + { "name": "nested_struct.address.primary_city", "type": "string" }, + { "name": "nested_struct.address.work_zips", "type": "string[]" }, + + { "name": "nested_struct_vec", "type": "object[]"}, + { "name": "nested_struct_vec.name", "type": "string[]"}, + { "name": "nested_struct_vec.address", "type": "object[]" }, + { 
"name": "nested_struct_vec.address.primary_city", "type": "string[]" }, + { "name": "nested_struct_vec.address.work_zips", "type": "string[]" }, + + // --- Manually flattened object --- + // correctly skipped `data` + { "name": "primary_address.city", "type": "string" }, + { "name": "work_addresses.zip", "type": "string[]" } + ] + }); + + assert_eq!(serde_json::to_value(schema).unwrap(), expected); +} + +// Test 4: All Boolean Shorthand Attributes + +#[allow(dead_code)] +#[derive(Typesense, Serialize, Deserialize)] +#[typesense(collection_name = "shorthand_products")] +struct ShorthandProduct { + // Shorthand for facet = true #[typesense(facet)] - country: String, - keywords: Option<Vec<String>>, + brand: String, + + // Shorthand for sort = true + #[typesense(sort)] + name: String, + + // Shorthand for index = true + #[typesense(index)] + category: String, + + // Shorthand for store = true + #[typesense(store)] + description: String, + + // Shorthand for infix = true + #[typesense(infix)] + tags: String, + + // Shorthand for stem = true + #[typesense(stem)] + title: String, + + // Shorthand for range_index = true + #[typesense(range_index)] + price: f32, + + // Shorthand for optional = true, overriding the non-Option type + #[typesense(optional)] + variant: String, + + // This field is for internal Rust logic only and should NOT be in the schema + #[typesense(skip)] + internal_metadata: String, +} + +#[test] +fn derived_document_handles_boolean_shorthand() { + let schema = ShorthandProduct::collection_schema(); + + let expected = json!({ + "name": "shorthand_products", + "fields": [ + { "name": "brand", "type": "string", "facet": true }, + { "name": "name", "type": "string", "sort": true }, + { "name": "category", "type": "string", "index": true }, + { "name": "description", "type": "string", "store": true }, + { "name": "tags", "type": "string", "infix": true }, + { "name": "title", "type": "string", "stem": true }, + { "name": "price", "type": "float", "range_index": true }, 
+ { "name": "variant", "type": "string", "optional": true } + // `internal_metadata` is correctly omitted from the fields array + ] + }); + + assert_eq!(serde_json::to_value(schema).unwrap(), expected); } diff --git a/typesense/tests/derive/ui/duplicate_attribute.rs b/typesense/tests/derive/ui/duplicate_attribute.rs new file mode 100644 index 0000000..34e3cba --- /dev/null +++ b/typesense/tests/derive/ui/duplicate_attribute.rs @@ -0,0 +1,11 @@ +use serde::{Deserialize, Serialize}; +use typesense::Typesense; +#[derive(Typesense, Serialize, Deserialize)] +struct Company { + company_name: String, + num_employees: i32, + #[typesense(facet, sort, facet)] + country_code: String, +} + +fn main() {} diff --git a/typesense/tests/derive/ui/duplicate_attribute.stderr b/typesense/tests/derive/ui/duplicate_attribute.stderr new file mode 100644 index 0000000..bee8a8c --- /dev/null +++ b/typesense/tests/derive/ui/duplicate_attribute.stderr @@ -0,0 +1,5 @@ +error: Attribute `facet` is duplicated + --> tests/derive/ui/duplicate_attribute.rs:7:30 + | +7 | #[typesense(facet, sort, facet)] + | ^^^^^ diff --git a/typesense/tests/derive/ui/duplicated_attribute.rs b/typesense/tests/derive/ui/duplicate_derive_attribute.rs similarity index 100% rename from typesense/tests/derive/ui/duplicated_attribute.rs rename to typesense/tests/derive/ui/duplicate_derive_attribute.rs diff --git a/typesense/tests/derive/ui/duplicated_attribute.stderr b/typesense/tests/derive/ui/duplicate_derive_attribute.stderr similarity index 53% rename from typesense/tests/derive/ui/duplicated_attribute.stderr rename to typesense/tests/derive/ui/duplicate_derive_attribute.stderr index 454650a..a5c9126 100644 --- a/typesense/tests/derive/ui/duplicated_attribute.stderr +++ b/typesense/tests/derive/ui/duplicate_derive_attribute.stderr @@ -1,5 +1,5 @@ -error: #[typesense(facet)] repeated more than one time. - --> $DIR/duplicated_attribute.rs:7:5 +error: #[typesense(...)] is repeated more than one time. 
+ --> tests/derive/ui/duplicate_derive_attribute.rs:7:5 | 7 | / #[typesense(facet)] 8 | | #[typesense(facet)] diff --git a/typesense/tests/derive/ui/flag_only_attribute.rs b/typesense/tests/derive/ui/flag_only_attribute.rs new file mode 100644 index 0000000..bc5c509 --- /dev/null +++ b/typesense/tests/derive/ui/flag_only_attribute.rs @@ -0,0 +1,11 @@ +use serde::{Deserialize, Serialize}; +use typesense::Typesense; +#[derive(Typesense, Serialize, Deserialize)] +struct Company { + company_name: String, + num_employees: i32, + #[typesense(skip = true)] + country_code: String, +} + +fn main() {} diff --git a/typesense/tests/derive/ui/flag_only_attribute.stderr b/typesense/tests/derive/ui/flag_only_attribute.stderr new file mode 100644 index 0000000..cc2321a --- /dev/null +++ b/typesense/tests/derive/ui/flag_only_attribute.stderr @@ -0,0 +1,5 @@ +error: `skip` is a flag and does not take a value. Use `#[typesense(skip)]` + --> tests/derive/ui/flag_only_attribute.rs:7:17 + | +7 | #[typesense(skip = true)] + | ^^^^ diff --git a/typesense/tests/derive/ui/unknown_attribute.stderr b/typesense/tests/derive/ui/unknown_attribute.stderr index 385b665..0d5ba83 100644 --- a/typesense/tests/derive/ui/unknown_attribute.stderr +++ b/typesense/tests/derive/ui/unknown_attribute.stderr @@ -1,5 +1,5 @@ -error: Unexpected token facets. Did you mean `facet`? 
- --> $DIR/unknown_attribute.rs:8:17 +error: Unexpected field attribute "facets" + --> tests/derive/ui/unknown_attribute.rs:8:17 | 8 | #[typesense(facets)] | ^^^^^^ diff --git a/typesense_derive/src/field_attributes.rs b/typesense_derive/src/field_attributes.rs new file mode 100644 index 0000000..8cdfcdb --- /dev/null +++ b/typesense_derive/src/field_attributes.rs @@ -0,0 +1,358 @@ +use crate::{bool_literal, get_inner_type, i32_literal, skip_eq, string_literal, ty_inner_type}; +use proc_macro2::TokenTree; +use quote::quote; +use syn::{Attribute, Field}; + +#[derive(Default)] +pub(crate) struct FieldAttributes { + type_override: Option, + facet: Option, + index: Option, + locale: Option, + sort: Option, + infix: Option, + num_dim: Option, + optional: Option, + store: Option, + stem: Option, + range_index: Option, + vec_dist: Option, + flatten: bool, + pub(crate) rename: Option, + skip: bool, +} + +// This function will parse #[typesense(...)] on a FIELD +pub(crate) fn extract_field_attrs(field: &Field) -> syn::Result { + let attrs = &field.attrs; + let mut res = FieldAttributes::default(); + + // Find the single #[typesense] attribute, erroring if there are more than one. 
+ let all_ts_attrs: Vec<&Attribute> = attrs + .iter() + .filter(|a| a.path.get_ident().is_some_and(|i| i == "typesense")) + .collect(); + + // Check for duplicates and create an error if found + if all_ts_attrs.len() > 1 { + return Err(syn::Error::new_spanned( + field, + "#[typesense(...)] is repeated more than one time.", + )); + } + + // Get the single attribute, or return default if none exist + let attr = if let Some(a) = all_ts_attrs.first() { + *a + } else { + return Ok(res); // No typesense attribute, return default + }; + + if let Some(TokenTree::Group(g)) = attr.tokens.clone().into_iter().next() { + let mut tt_iter = g.stream().into_iter().peekable(); + while let Some(tt) = tt_iter.next() { + if let TokenTree::Ident(i) = tt { + let is_shorthand = + tt_iter.peek().is_none() || tt_iter.peek().unwrap().to_string() == ","; + let ident_str = i.to_string(); + + match ident_str.as_str() { + // --- Boolean flags that support shorthand and key-value --- + "facet" | "sort" | "index" | "store" | "infix" | "stem" | "range_index" + | "optional" => { + let value = if is_shorthand { + true + } else { + skip_eq(&i, &mut tt_iter)?; + bool_literal(&mut tt_iter)? 
+ }; + + // Set the correct field on the result struct, checking for duplicates + match ident_str.as_str() { + "facet" => { + if res.facet.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `facet` is duplicated", + )); + } + res.facet = Some(value); + } + "sort" => { + if res.sort.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `sort` is duplicated", + )); + } + res.sort = Some(value); + } + "index" => { + if res.index.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `index` is duplicated", + )); + } + res.index = Some(value); + } + "store" => { + if res.store.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `store` is duplicated", + )); + } + res.store = Some(value); + } + "infix" => { + if res.infix.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `infix` is duplicated", + )); + } + res.infix = Some(value); + } + "stem" => { + if res.stem.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `stem` is duplicated", + )); + } + res.stem = Some(value); + } + "range_index" => { + if res.range_index.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `range_index` is duplicated", + )); + } + res.range_index = Some(value); + } + "optional" => { + if res.optional.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `optional` is duplicated", + )); + } + res.optional = Some(value); + } + _ => unreachable!(), + } + } + // --- Flags that are ONLY shorthand --- + "flatten" | "skip" => { + if !is_shorthand { + return Err(syn::Error::new( + i.span(), + format!( + "`{}` is a flag and does not take a value. 
Use `#[typesense({})]`", + ident_str, ident_str + ), + )); + } + match ident_str.as_str() { + "flatten" => { + if res.flatten { + return Err(syn::Error::new_spanned( + &i, + "Attribute `flatten` is duplicated", + )); + } + res.flatten = true; + } + "skip" => { + if res.skip { + return Err(syn::Error::new_spanned( + &i, + "Attribute `skip` is duplicated", + )); + } + res.skip = true; + } + _ => unreachable!(), + } + } + + // --- Key-value only attributes --- + "rename" => { + skip_eq(&i, &mut tt_iter)?; + if res.rename.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `rename` is duplicated", + )); + } + res.rename = Some(string_literal(&mut tt_iter)?); + } + "locale" => { + skip_eq(&i, &mut tt_iter)?; + if res.locale.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `locale` is duplicated", + )); + } + res.locale = Some(string_literal(&mut tt_iter)?); + } + "vec_dist" => { + skip_eq(&i, &mut tt_iter)?; + if res.vec_dist.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `vec_dist` is duplicated", + )); + } + res.vec_dist = Some(string_literal(&mut tt_iter)?); + } + "type" => { + skip_eq(&i, &mut tt_iter)?; + if res.type_override.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `type` is duplicated", + )); + } + res.type_override = Some(string_literal(&mut tt_iter)?); + } + "num_dim" => { + skip_eq(&i, &mut tt_iter)?; + if res.num_dim.is_some() { + return Err(syn::Error::new_spanned( + &i, + "Attribute `num_dim` is duplicated", + )); + } + res.num_dim = Some(i32_literal(&mut tt_iter)?); + } + // --- Error for unknown attributes --- + v => { + return Err(syn::Error::new( + i.span(), + format!("Unexpected field attribute \"{}\"", v), + )); + } + } + }; + + if let Some(TokenTree::Punct(p)) = tt_iter.peek() + && p.as_char() == ',' + { + tt_iter.next(); // Consume the comma + } + } + } + + Ok(res) +} + +fn build_regular_field(field: &Field, field_attrs: &FieldAttributes) -> 
proc_macro2::TokenStream { + let (ty, is_option_type) = if let Some(inner_ty) = ty_inner_type(&field.ty, "Option") { + (inner_ty, true) + } else { + (&field.ty, false) + }; + + let name_ident = &field.ident; + let field_name = if let Some(rename) = &field_attrs.rename { + quote! { #rename.to_string() } + } else { + quote! { std::string::String::from(stringify!(#name_ident)) } + }; + + let typesense_field_type = if let Some(override_str) = &field_attrs.type_override { + quote! { #override_str.to_owned() } + } else { + quote! { <#ty as typesense::prelude::ToTypesenseField>::to_typesense_type().to_owned() } + }; + + let optional = field_attrs + .optional + .or(if is_option_type { Some(true) } else { None }) + .map(|v| quote!(.optional(#v))); + let facet = field_attrs.facet.map(|v| quote!(.facet(#v))); + let index = field_attrs.index.map(|v| quote!(.index(#v))); + let store = field_attrs.store.map(|v| quote!(.store(#v))); + let sort = field_attrs.sort.map(|v| quote!(.sort(#v))); + let infix = field_attrs.infix.map(|v| quote!(.infix(#v))); + let stem = field_attrs.stem.map(|v| quote!(.stem(#v))); + let range_index = field_attrs.range_index.map(|v| quote!(.range_index(#v))); + let locale = field_attrs.locale.as_ref().map(|v| quote!(.locale(#v))); + let vec_dist = field_attrs.vec_dist.as_ref().map(|v| quote!(.vec_dist(#v))); + let num_dim = field_attrs.num_dim.map(|v| quote!(.num_dim(#v))); + + quote! { + vec![ + typesense::models::Field::builder().name(#field_name).r#type(#typesense_field_type) + #optional #facet #index #store #sort #infix #stem #range_index #locale #vec_dist #num_dim + .build() + ] + } +} + +/// Processes a single struct field. +/// Returns a TokenStream which evaluates to a `Vec`. +pub(crate) fn process_field(field: &Field) -> syn::Result { + let field_attrs = extract_field_attrs(field)?; + + if field_attrs.flatten { + // Determine the prefix: use the rename value if it exists, otherwise use the field's name. 
+ let prefix = if let Some(rename_prefix) = &field_attrs.rename { + quote! { #rename_prefix } + } else { + let name_ident = &field.ident; + quote! { stringify!(#name_ident) } + }; + + let inner_type = get_inner_type(&field.ty); + let is_vec = ty_inner_type(&field.ty, "Vec").is_some() + || ty_inner_type(&field.ty, "Option") + .is_some_and(|t| ty_inner_type(t, "Vec").is_some()); + + let flattened_fields = quote! { + <#inner_type as typesense::prelude::Document>::collection_schema().fields + .into_iter() + .map(|mut f| { + // Use the dynamically determined prefix here + f.name = format!("{}.{}", #prefix, f.name); + if #is_vec && !f.r#type.ends_with("[]") { + f.r#type.push_str("[]"); + } + f + }) + .collect::>() + }; + + if field_attrs.skip { + // `#[typesense(flatten, skip)]` -> Only flattened fields + return Ok(quote! { + { + #flattened_fields + } + }); + } + + // `#[typesense(flatten)]` -> Flattened fields + object field + let regular_field = build_regular_field(field, &field_attrs); + + Ok(quote! { + { + let mut fields = #regular_field; + fields.extend(#flattened_fields); + fields + } + }) + } else { + // --- REGULAR FIELD LOGIC --- + if field_attrs.skip { + return Ok(quote! 
{ + vec![] + }); + } + Ok(build_regular_field(field, &field_attrs)) + } +} diff --git a/typesense_derive/src/helpers.rs b/typesense_derive/src/helpers.rs new file mode 100644 index 0000000..337fc11 --- /dev/null +++ b/typesense_derive/src/helpers.rs @@ -0,0 +1,142 @@ +use proc_macro2::{Ident, TokenTree}; +use syn::spanned::Spanned; + +// Helper to parse a boolean literal +pub(crate) fn bool_literal(tt_iter: &mut impl Iterator) -> syn::Result { + match tt_iter.next() { + Some(TokenTree::Ident(i)) => { + if i == "true" { + Ok(true) + } else if i == "false" { + Ok(false) + } else { + Err(syn::Error::new_spanned( + i, + "Expected a boolean `true` or `false`", + )) + } + } + tt => Err(syn::Error::new(tt.span(), "Expected a boolean literal")), + } +} + +// Helper to parse an integer literal +pub(crate) fn i32_literal(tt_iter: &mut impl Iterator) -> syn::Result { + match tt_iter.next() { + Some(TokenTree::Literal(l)) => { + let lit = syn::Lit::new(l); + if let syn::Lit::Int(i) = lit { + i.base10_parse::() + } else { + Err(syn::Error::new_spanned( + lit, + "it must be equal to an integer literal", + )) + } + } + tt => Err(syn::Error::new(tt.span(), "Expected an integer literal")), + } +} + +pub(crate) fn string_literal(tt_iter: &mut impl Iterator) -> syn::Result { + match tt_iter.next() { + Some(TokenTree::Literal(l)) => { + let lit = syn::Lit::new(l); + if let syn::Lit::Str(s) = lit { + Ok(s.value()) + } else { + Err(syn::Error::new_spanned( + lit, + "it must be equal to a literal string", + )) + } + } + Some(TokenTree::Ident(i)) => Err(syn::Error::new( + i.span(), + format!("Expected string literal, did you mean \"{i}\"?"), + )), + tt => Err(syn::Error::new(tt.span(), "Expected string literal")), + } +} + +// Helper function to parse a bracketed list of string literals +pub(crate) fn string_list( + tt_iter: &mut impl Iterator, +) -> syn::Result> { + let group = match tt_iter.next() { + Some(TokenTree::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Bracket => g, + 
Some(tt) => { + return Err(syn::Error::new_spanned( + tt, + "Expected a list in brackets `[]`", + )); + } + None => { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + "Expected a list in brackets `[]`", + )); + } + }; + + let mut result = Vec::new(); + let mut inner_iter = group.stream().into_iter().peekable(); + + while let Some(tt) = inner_iter.next() { + if let TokenTree::Literal(l) = tt { + let lit = syn::Lit::new(l); + if let syn::Lit::Str(s) = lit { + result.push(s.value()); + } else { + return Err(syn::Error::new_spanned(lit, "Expected a string literal")); + } + } else { + return Err(syn::Error::new_spanned(tt, "Expected a string literal")); + } + + // Check for a trailing comma + if let Some(TokenTree::Punct(p)) = inner_iter.peek() + && p.as_char() == ',' + { + inner_iter.next(); // Consume the comma + } + } + + Ok(result) +} + +pub(crate) fn skip_eq(i: &Ident, tt_iter: &mut impl Iterator) -> syn::Result<()> { + match tt_iter.next() { + Some(TokenTree::Punct(p)) if p.as_char() == '=' => Ok(()), + Some(tt) => Err(syn::Error::new_spanned( + &tt, + format!("Unexpected \"{tt}\", expected equal sign \"=\""), + )), + None => Err(syn::Error::new_spanned(i, "expected: equal sign \"=\"")), + } +} + +// Get the inner type for a given wrappper +pub(crate) fn ty_inner_type<'a>(ty: &'a syn::Type, wrapper: &'static str) -> Option<&'a syn::Type> { + if let syn::Type::Path(p) = ty + && p.path.segments.len() == 1 + && p.path.segments[0].ident == wrapper + && let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments + && inner_ty.args.len() == 1 + { + // len is 1 so this should not fail + let inner_ty = inner_ty.args.first().unwrap(); + if let syn::GenericArgument::Type(t) = inner_ty { + return Some(t); + } + } + None +} + +/// Helper to get the inner-most type from nested Option/Vec wrappers. 
+pub(crate) fn get_inner_type(mut ty: &syn::Type) -> &syn::Type { + while let Some(inner) = ty_inner_type(ty, "Option").or_else(|| ty_inner_type(ty, "Vec")) { + ty = inner; + } + ty +} diff --git a/typesense_derive/src/lib.rs b/typesense_derive/src/lib.rs index 3248e62..e62a65f 100644 --- a/typesense_derive/src/lib.rs +++ b/typesense_derive/src/lib.rs @@ -1,7 +1,13 @@ +mod field_attributes; +mod helpers; + +use field_attributes::{extract_field_attrs, process_field}; +use helpers::*; + use proc_macro::TokenStream; use proc_macro2::{Ident, TokenTree}; use quote::{ToTokens, quote}; -use syn::{Attribute, Field, ItemStruct, spanned::Spanned}; +use syn::{Attribute, ItemStruct, spanned::Spanned}; #[proc_macro_derive(Typesense, attributes(typesense))] pub fn typesense_collection_derive(input: TokenStream) -> TokenStream { @@ -40,25 +46,39 @@ fn impl_typesense_collection(item: ItemStruct) -> syn::Result { collection_name, default_sorting_field, enable_nested_fields, + symbols_to_index, + token_separators, } = extract_attrs(attrs)?; let collection_name = collection_name.unwrap_or_else(|| ident.to_string().to_lowercase()); - if let Some(ref sorting_field) = default_sorting_field - && !fields.iter().any(|field| - // At this point we are sure that this field is a named field. - field.ident.as_ref().unwrap() == sorting_field) - { - return Err(syn::Error::new_spanned( - item_ts, - format!( - "defined default_sorting_field = \"{sorting_field}\" does not match with any field." 
- ), - )); + if let Some(ref sorting_field) = default_sorting_field { + let field_names_and_renames = fields + .iter() + .map(|field| { + extract_field_attrs(field).map(|attrs| { + attrs + .rename + .unwrap_or_else(|| field.ident.as_ref().unwrap().to_string()) + }) + }) + .collect::>>()?; + + if !field_names_and_renames + .iter() + .any(|name| name == sorting_field) + { + return Err(syn::Error::new_spanned( + item_ts, + format!( + "defined default_sorting_field = \"{sorting_field}\" does not match with any field." + ), + )); + } } let typesense_fields = fields .iter() - .map(to_typesense_field_type) + .map(process_field) .collect::>>()?; let default_sorting_field = if let Some(v) = default_sorting_field { @@ -77,6 +97,23 @@ fn impl_typesense_collection(item: ItemStruct) -> syn::Result { proc_macro2::TokenStream::new() }; + let symbols_to_index = if let Some(v) = symbols_to_index { + quote! { + let builder = builder.symbols_to_index(vec![#(#v.to_owned()),*]); + } + } else { + proc_macro2::TokenStream::new() + }; + + let token_separators = if let Some(v) = token_separators { + quote! 
{ + let builder = builder.token_separators(vec![#(#v.to_owned()),*]); + } + } else { + proc_macro2::TokenStream::new() + }; + + // Create Partial struct for document update let optional_fields = fields.iter().filter_map(|f| { let ident = f.ident.as_ref()?; if ident == "id" { @@ -106,12 +143,16 @@ fn impl_typesense_collection(item: ItemStruct) -> syn::Result { fn collection_schema() -> ::typesense::models::CollectionSchema { let name = Self::COLLECTION_NAME.to_owned(); - let fields = vec![#(#typesense_fields,)*]; + + let mut fields = Vec::new(); + #(fields.extend(#typesense_fields);)* let builder = ::typesense::models::CollectionSchema::builder().name(name).fields(fields); #default_sorting_field #enable_nested_fields + #token_separators + #symbols_to_index builder.build() } @@ -120,23 +161,6 @@ fn impl_typesense_collection(item: ItemStruct) -> syn::Result { Ok(generated_code.into()) } -// Get the inner type for a given wrapper -fn ty_inner_type<'a>(ty: &'a syn::Type, wrapper: &'static str) -> Option<&'a syn::Type> { - if let syn::Type::Path(p) = ty - && p.path.segments.len() == 1 - && p.path.segments[0].ident == wrapper - && let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments - && inner_ty.args.len() == 1 - { - // len is 1 so this should not fail - let inner_ty = inner_ty.args.first().unwrap(); - if let syn::GenericArgument::Type(t) = inner_ty { - return Some(t); - } - } - None -} - // Add a bound `T: ToTypesenseField` to every type parameter T. 
fn add_trait_bounds(mut generics: syn::Generics) -> syn::Generics { for param in &mut generics.params { @@ -153,39 +177,9 @@ fn add_trait_bounds(mut generics: syn::Generics) -> syn::Generics { struct Attrs { collection_name: Option, default_sorting_field: Option, + symbols_to_index: Option>, enable_nested_fields: Option, -} - -fn skip_eq(i: Ident, tt_iter: &mut impl Iterator) -> syn::Result<()> { - match tt_iter.next() { - Some(TokenTree::Punct(p)) if p.as_char() == '=' => Ok(()), - Some(tt) => Err(syn::Error::new_spanned( - &tt, - format!("Unexpected \"{tt}\", expected equal sign \"=\""), - )), - None => Err(syn::Error::new_spanned(i, "expected: equal sign \"=\"")), - } -} - -fn string_literal(tt_iter: &mut impl Iterator) -> syn::Result { - match tt_iter.next() { - Some(TokenTree::Literal(l)) => { - let lit = syn::Lit::new(l); - if let syn::Lit::Str(s) = lit { - Ok(s.value()) - } else { - Err(syn::Error::new_spanned( - lit, - "it must be equal to a literal string", - )) - } - } - Some(TokenTree::Ident(i)) => Err(syn::Error::new( - i.span(), - format!("Expected string literal, did you mean \"{i}\"?"), - )), - tt => Err(syn::Error::new(tt.span(), "Expected string literal")), - } + token_separators: Option>, } fn extract_attrs(attrs: Vec) -> syn::Result { @@ -205,15 +199,15 @@ fn extract_attrs(attrs: Vec) -> syn::Result { if let TokenTree::Ident(i) = tt { match &i.to_string() as &str { "collection_name" => { - skip_eq(i, &mut tt_iter)?; + skip_eq(&i, &mut tt_iter)?; res.collection_name = Some(string_literal(&mut tt_iter)?); } "default_sorting_field" => { - skip_eq(i, &mut tt_iter)?; + skip_eq(&i, &mut tt_iter)?; res.default_sorting_field = Some(string_literal(&mut tt_iter)?); } "enable_nested_fields" => { - skip_eq(i, &mut tt_iter)?; + skip_eq(&i, &mut tt_iter)?; let val = match tt_iter.next() { Some(TokenTree::Ident(i)) => &i.to_string() == "true", tt => { @@ -225,6 +219,14 @@ fn extract_attrs(attrs: Vec) -> syn::Result { }; res.enable_nested_fields = Some(val); } + 
"symbols_to_index" => { + skip_eq(&i, &mut tt_iter)?; + res.symbols_to_index = Some(string_list(&mut tt_iter)?); + } + "token_separators" => { + skip_eq(&i, &mut tt_iter)?; + res.token_separators = Some(string_list(&mut tt_iter)?); + } v => { return Err(syn::Error::new(i.span(), format!("Unexpected \"{v}\""))); } @@ -244,79 +246,3 @@ fn extract_attrs(attrs: Vec) -> syn::Result { Ok(res) } - -/// Convert a given field in a typesense field type. -fn to_typesense_field_type(field: &Field) -> syn::Result { - let name = &field.ident; - - let facet = { - let facet_vec = field - .attrs - .iter() - .filter_map(|attr| { - if attr.path.segments.len() == 1 - && attr.path.segments[0].ident == "typesense" - && let Some(proc_macro2::TokenTree::Group(g)) = - attr.tokens.clone().into_iter().next() - { - let mut tokens = g.stream().into_iter(); - match tokens.next() { - Some(proc_macro2::TokenTree::Ident(ref i)) => { - if i != "facet" { - return Some(Err(syn::Error::new_spanned( - i, - format!("Unexpected token {i}. Did you mean `facet`?"), - ))); - } - } - Some(ref tt) => { - return Some(Err(syn::Error::new_spanned( - tt, - format!("Unexpected token {tt}. Did you mean `facet`?"), - ))); - } - None => { - return Some(Err(syn::Error::new_spanned(attr, "expected `facet`"))); - } - } - - if let Some(ref tt) = tokens.next() { - return Some(Err(syn::Error::new_spanned( - tt, - "Unexpected token. 
Expected )", - ))); - } - return Some(Ok(())); - } - None - }) - .collect::>>()?; - let facet_count = facet_vec.len(); - if facet_count == 1 { - quote!(Some(true)) - } else if facet_count == 0 { - quote!(None) - } else { - return Err(syn::Error::new_spanned( - field, - "#[typesense(facet)] repeated more than one time.", - )); - } - }; - - let (ty, optional) = if let Some(inner_ty) = ty_inner_type(&field.ty, "Option") { - (inner_ty, quote!(Some(true))) - } else { - (&field.ty, quote!(None)) - }; - let typesense_field_type = quote!( - <#ty as ::typesense::prelude::ToTypesenseField>::to_typesense_type().to_owned() - ); - - Ok(quote! { - ::typesense::models::Field::builder().name(std::string::String::from(stringify!(#name))).r#type(#typesense_field_type) - .maybe_optional(#optional) - .maybe_facet(#facet) - .build() - }) -}