11module ModelingToolkitUnitfulExt
22
3- __precompile__ (false )
4-
53using ModelingToolkit
64using Unitful
7- using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, arguments, operation, iscall, getmetadata
5+ using Symbolics: Symbolic, value
86using SciMLBase
9- using RecursiveArrayTools
10- using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump
117
128# Import necessary types and functions from ModelingToolkit
13- import ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit,
14- get_systems, Conditional, Comparison, Differential,
15- Integral, Num, check_units
9+ import ModelingToolkit: ValidationError, _get_unittype, get_unit, screen_unit,
10+ equivalent, _is_dimension_error, convert_units, check_units
1611
1712const MT = ModelingToolkit
1813
19- # Method extension for Unitful unit detection
20- # This adds a method for the specific case where we have a Unitful unit
21- function MT. __get_scalar_unit_type (v)
22- u = MT. __get_literal_unit (v)
23- if u isa MT. DQ. AbstractQuantity
24- return Val (:DynamicQuantities )
25- elseif u isa Unitful. Unitlike
26- return Val (:Unitful )
27- end
28- return nothing
14+ # Add Unitful-specific unit type detection
15+ function MT. _get_unittype (u:: Unitful.Unitlike )
16+ return Val (:Unitful )
2917end
3018
3119# Base operations for mixing Symbolic and Unitful
32- Base.:* (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x * y
33- Base.:/ (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x / y
20+ Base.:* (x:: Union{MT.Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x * y
21+ Base.:/ (x:: Union{MT.Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x / y
22+
23+ # Unitful-specific get_unit method
24+ function MT. get_unit (x:: Unitful.Quantity )
25+ return screen_unit (Unitful. unit (x))
26+ end
3427
35- """
36- Throw exception on invalid unit types, otherwise return argument.
37- """
38- function screen_unit (result)
39- result isa Unitful. Unitlike ||
40- throw (ValidationError (" Unit must be a subtype of Unitful.Unitlike, not $(typeof (result)) ." ))
28+ # Unitful-specific screen_unit method
29+ function MT. screen_unit (result:: Unitful.Unitlike )
4130 result isa Unitful. ScalarUnits ||
4231 throw (ValidationError (" Non-scalar units such as $result are not supported. Use a scalar unit instead." ))
4332 result == Unitful. u " °" &&
4433 throw (ValidationError (" Degrees are not supported. Use radians instead." ))
45- result
46- end
47-
48- """
49- Test unit equivalence.
50-
51- Example of implemented behavior:
52-
53- ```julia
54- using ModelingToolkit, Unitful
55- MT = ModelingToolkit
56- @parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
57- @test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes
58- @test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
59- @test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents
60- ```
61- """
62- equivalent (x, y) = isequal (1 * x, 1 * y)
63- const unitless = Unitful. unit (1 )
64-
65- """
66- Find the unit of a symbolic item.
67- """
68- get_unit (x:: Real ) = unitless
69- get_unit (x:: Unitful.Quantity ) = screen_unit (Unitful. unit (x))
70- get_unit (x:: AbstractArray ) = map (get_unit, x)
71- get_unit (x:: Num ) = get_unit (value (x))
72- function get_unit (x:: Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata} )
73- get_literal_unit (x)
74- end
75- get_unit (op:: Differential , args) = get_unit (args[1 ]) / get_unit (op. x)
76- get_unit (op:: typeof (getindex), args) = get_unit (args[1 ])
77- get_unit (x:: SciMLBase.NullParameters ) = unitless
78- get_unit (op:: typeof (instream), args) = get_unit (args[1 ])
79-
80- get_literal_unit (x) = screen_unit (getmetadata (x, VariableUnit, unitless))
81-
82- function get_unit (op, args) # Fallback
83- result = op (1 .* get_unit .(args)... )
84- try
85- Unitful. unit (result)
86- catch
87- throw (ValidationError (" Unable to get unit for operation $op with arguments $args ." ))
88- end
89- end
90-
91- function get_unit (op:: Integral , args)
92- unit = 1
93- if op. domain. variables isa Vector
94- for u in op. domain. variables
95- unit *= get_unit (u)
96- end
97- else
98- unit *= get_unit (op. domain. variables)
99- end
100- return get_unit (args[1 ]) * unit
34+ return result
10135end
10236
103- function get_unit (op:: Conditional , args)
104- terms = get_unit .(args)
105- terms[1 ] == unitless ||
106- throw (ValidationError (" , in $op , [$(terms[1 ]) ] is not dimensionless." ))
107- equivalent (terms[2 ], terms[3 ]) ||
108- throw (ValidationError (" , in $op , units [$(terms[2 ]) ] and [$(terms[3 ]) ] do not match." ))
109- return terms[2 ]
37+ # Unitful-specific equivalence check
38+ function MT. equivalent (x:: Unitful.Unitlike , y:: Unitful.Unitlike )
39+ return isequal (1 * x, 1 * y)
11040end
11141
112- function get_unit (op:: typeof (Symbolics. _mapreduce), args)
113- if args[2 ] == +
114- get_unit (args[3 ])
115- else
116- throw (ValidationError (" Unsupported array operation $op " ))
117- end
118- end
119-
120- function get_unit (op:: Comparison , args)
121- terms = get_unit .(args)
122- equivalent (terms[1 ], terms[2 ]) ||
123- throw (ValidationError (" , in comparison $op , units [$(terms[1 ]) ] and [$(terms[2 ]) ] do not match." ))
124- return unitless
125- end
42+ # Mixed equivalence checks
43+ MT. equivalent (x:: Unitful.Unitlike , y) = isequal (1 * x, y)
44+ MT. equivalent (x, y:: Unitful.Unitlike ) = isequal (x, 1 * y)
12645
127- function get_unit (x:: Symbolic )
128- if issym (x)
129- get_literal_unit (x)
130- elseif isadd (x)
131- terms = get_unit .(arguments (x))
132- firstunit = terms[1 ]
133- for other in terms[2 : end ]
134- termlist = join (map (repr, terms), " , " )
135- equivalent (other, firstunit) ||
136- throw (ValidationError (" , in sum $x , units [$termlist ] do not match." ))
137- end
138- return firstunit
139- elseif ispow (x)
140- pargs = arguments (x)
141- base, expon = get_unit .(pargs)
142- @assert expon isa Unitful. DimensionlessUnits
143- if base == unitless
144- unitless
145- else
146- pargs[2 ] isa Number ? base^ pargs[2 ] : (1 * base)^ pargs[2 ]
147- end
148- elseif iscall (x)
149- op = operation (x)
150- if issym (op) || (iscall (op) && iscall (operation (op))) # Dependent variables, not function calls
151- return screen_unit (getmetadata (x, VariableUnit, unitless)) # Like x(t) or x[i]
152- elseif iscall (op) && ! iscall (operation (op))
153- gp = getmetadata (x, Symbolics. GetindexParent, nothing ) # Like x[1](t)
154- return screen_unit (getmetadata (gp, VariableUnit, unitless))
155- end # Actual function calls:
156- args = arguments (x)
157- return get_unit (op, args)
158- else # This function should only be reached by Terms, for which `iscall` is true
159- throw (ArgumentError (" Unsupported value $x ." ))
160- end
161- end
46+ # The safe_get_unit function stays in the main package and already handles DQ.DimensionError
47+ # We just need to make sure it can handle Unitful.DimensionError too
48+ # This will be handled by the main function's MethodError catch
16249
163- """
164- Get unit of term, returning nothing & showing warning instead of throwing errors.
165- """
166- function safe_get_unit (term, info)
167- side = nothing
168- try
169- side = get_unit (term)
170- catch err
171- if err isa Unitful. DimensionError
172- @warn (" $info : $(err. x) and $(err. y) are not dimensionally compatible." )
173- elseif err isa ValidationError
174- @warn (info* err. message)
175- elseif err isa MethodError
176- @warn (" $info : no method matching $(err. f) for arguments $(typeof .(err. args)) ." )
177- else
178- rethrow ()
179- end
180- end
181- side
182- end
183-
184- function _validate (terms:: Vector , labels:: Vector{String} ; info:: String = " " )
185- valid = true
186- first_unit = nothing
187- first_label = nothing
188- for (term, label) in zip (terms, labels)
189- equnit = safe_get_unit (term, info * label)
190- if equnit === nothing
191- valid = false
192- elseif ! isequal (term, 0 )
193- if first_unit === nothing
194- first_unit = equnit
195- first_label = label
196- elseif ! equivalent (first_unit, equnit)
197- valid = false
198- @warn (" $info : units [$(first_unit) ] for $(first_label) and [$(equnit) ] for $(label) do not match." )
199- end
200- end
201- end
202- valid
203- end
204-
205- function _validate (conn:: Connection ; info:: String = " " )
206- valid = true
207- syss = get_systems (conn)
208- sys = first (syss)
209- unks = MT. unknowns (sys)
210- for i in 2 : length (syss)
211- s = syss[i]
212- _unks = MT. unknowns (s)
213- if length (unks) != length (_unks)
214- valid = false
215- @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) have $(length (unks)) and $(length (_unks)) unknowns, cannot connect." )
216- continue
217- end
218- for (i, x) in enumerate (unks)
219- j = findfirst (isequal (x), _unks)
220- if j == nothing
221- valid = false
222- @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) do not have the same unknowns." )
223- else
224- aunit = safe_get_unit (x, info * string (MT. nameof (sys)) * " #$i " )
225- bunit = safe_get_unit (_unks[j], info * string (MT. nameof (s)) * " #$j " )
226- if ! equivalent (aunit, bunit)
227- valid = false
228- @warn (" $info : connected system unknowns $x and $(_unks[j]) have mismatched units." )
229- end
230- end
231- end
232- end
233- valid
234- end
235-
236- function validate (jump:: Union{VariableRateJump, ConstantRateJump} , t:: Symbolic ; info:: String = " " )
237- newinfo = replace (info, " eq." => " jump" )
238- _validate ([jump. rate, 1 / t], [" rate" , " 1/t" ], info = newinfo) && # Assuming the rate is per time units
239- validate (jump. affect!, info = newinfo)
240- end
241-
242- function validate (jump:: MassActionJump , t:: Symbolic ; info:: String = " " )
243- left_symbols = [x[1 ] for x in jump. reactant_stoch] # vector of pairs of symbol,int -> vector symbols
244- net_symbols = [x[1 ] for x in jump. net_stoch]
245- all_symbols = vcat (left_symbols, net_symbols)
246- allgood = _validate (all_symbols, string .(all_symbols); info)
247- n = sum (x -> x[2 ], jump. reactant_stoch, init = 0 )
248- base_unitful = all_symbols[1 ] # all same, get first
249- allgood && _validate ([jump. scaled_rates, 1 / (t * base_unitful^ n)],
250- [" scaled_rates" , " 1/(t*reactants^$n ))" ]; info)
251- end
252-
253- function validate (jumps:: Vector{JumpType} , t:: Symbolic )
254- labels = [" in Mass Action Jumps," , " in Constant Rate Jumps," , " in Variable Rate Jumps," ]
255- majs = filter (x -> x isa MassActionJump, jumps)
256- crjs = filter (x -> x isa ConstantRateJump, jumps)
257- vrjs = filter (x -> x isa VariableRateJump, jumps)
258- splitjumps = [majs, crjs, vrjs]
259- all ([validate (js, t; info) for (js, info) in zip (splitjumps, labels)])
260- end
261-
262- function validate (eq:: MT.Equation ; info:: String = " " )
263- if typeof (eq. lhs) == Connection
264- _validate (eq. rhs; info)
265- else
266- _validate ([eq. lhs, eq. rhs], [" left" , " right" ]; info)
267- end
268- end
269-
270- function validate (eq:: MT.Equation , term:: Union{Symbolic, Unitful.Quantity, Num} ; info:: String = " " )
271- _validate ([eq. lhs, eq. rhs, term], [" left" , " right" , " noise" ]; info)
272- end
273-
274- function validate (eq:: MT.Equation , terms:: Vector ; info:: String = " " )
275- _validate (vcat ([eq. lhs, eq. rhs], terms),
276- vcat ([" left" , " right" ], " noise #" .* string .(1 : length (terms))); info)
277- end
278-
279- """
280- Returns true iff units of equations are valid.
281- """
282- function validate (eqs:: Vector ; info:: String = " " )
283- all ([validate (eqs[idx], info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
284- end
285-
286- function validate (eqs:: Vector , noise:: Vector ; info:: String = " " )
287- all ([validate (eqs[idx], noise[idx], info = info * " in eq. #$idx " )
288- for idx in 1 : length (eqs)])
289- end
290-
291- function validate (eqs:: Vector , noise:: Matrix ; info:: String = " " )
292- all ([validate (eqs[idx], noise[idx, :], info = info * " in eq. #$idx " )
293- for idx in 1 : length (eqs)])
294- end
295-
296- function validate (eqs:: Vector , term:: Symbolic ; info:: String = " " )
297- all ([validate (eqs[idx], term, info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
298- end
299-
300- validate (term:: Symbolic ) = safe_get_unit (term, " " ) != = nothing
301-
302- """
303- Throws error if units of equations are invalid.
304- """
305- function check_units (:: Val{:Unitful} , eqs... )
306- validate (eqs... ) ||
307- throw (ValidationError (" Some equations had invalid units. See warnings for details." ))
308- end
50+ # Unitful-specific dimension error detection for model parsing
51+ MT. _is_dimension_error (e:: Unitful.DimensionError ) = true
30952
310- # Model parsing functions for Unitful
311- function convert_units (varunits:: Unitful.FreeUnits , value)
53+ # Unitful-specific convert_units methods for model parsing
54+ function MT . convert_units (varunits:: Unitful.FreeUnits , value)
31255 Unitful. ustrip (varunits, value)
31356end
31457
315- convert_units (:: Unitful.FreeUnits , value:: MT.NoValue ) = MT. NO_VALUE
58+ MT . convert_units (:: Unitful.FreeUnits , value:: MT.NoValue ) = MT. NO_VALUE
31659
317- function convert_units (varunits:: Unitful.FreeUnits , value:: AbstractArray{T} ) where {T}
60+ function MT . convert_units (varunits:: Unitful.FreeUnits , value:: AbstractArray{T} ) where {T}
31861 Unitful. ustrip .(varunits, value)
31962end
32063
321- convert_units (:: Unitful.FreeUnits , value:: Num ) = value
64+ MT . convert_units (:: Unitful.FreeUnits , value:: MT. Num ) = value
32265
323- # Extend model parsing error handling to include Unitful.DimensionError
324- MT. _is_dimension_error (e:: Unitful.DimensionError ) = true
66+ # Unitful-specific check_units method
67+ function MT. check_units (:: Val{:Unitful} , eqs... )
68+ # Use the main package's validate function
69+ MT. validate (eqs... ) ||
70+ throw (ValidationError (" Some equations had invalid units. See warnings for details." ))
71+ end
32572
32673# Define Unitful time variables (moved from main module)
32774const t_unitful = let
32875 MT. only (MT. @independent_variables t [unit = Unitful. u " s" ])
32976end
33077const D_unitful = MT. Differential (t_unitful)
33178
332- # Create a UnitfulUnitCheck module for backward compatibility
333- module UnitfulUnitCheck
334- using .. ModelingToolkitUnitfulExt
335- # Re-export all functions from the extension for backward compatibility
336- const equivalent = ModelingToolkitUnitfulExt . equivalent
337- const unitless = ModelingToolkitUnitfulExt . unitless
338- const get_unit = ModelingToolkitUnitfulExt . get_unit
339- const get_literal_unit = ModelingToolkitUnitfulExt . get_literal_unit
340- const safe_get_unit = ModelingToolkitUnitfulExt . safe_get_unit
341- const validate = ModelingToolkitUnitfulExt . validate
342- const screen_unit = ModelingToolkitUnitfulExt . screen_unit
343- end
79+ # For backward compatibility - provide UnitfulUnitCheck module interface
80+ # Extensions can access all the main package functions through MT
81+ const UnitfulUnitCheck = (
82+ equivalent = MT . equivalent,
83+ unitless = Unitful . unit ( 1 ),
84+ get_unit = MT . get_unit,
85+ get_literal_unit = MT . get_literal_unit,
86+ safe_get_unit = MT . safe_get_unit,
87+ validate = MT . validate,
88+ screen_unit = MT . screen_unit,
89+ ValidationError = ValidationError
90+ )
34491
34592end # module
0 commit comments