|
| 1 | +import re |
1 | 2 | from collections.abc import Mapping |
2 | 3 | from functools import lru_cache |
3 | | -from typing import Any, NamedTuple, Sequence, Tuple, Union |
| 4 | +from inspect import signature |
| 5 | +from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union |
4 | 6 | from warnings import warn |
5 | 7 |
|
6 | 8 | from . import _array_module as xp |
7 | 9 | from ._array_module import _UndefinedStub |
| 10 | +from .stubs import name_to_func |
8 | 11 | from .typing import DataType, ScalarType |
9 | 12 |
|
10 | 13 | __all__ = [ |
@@ -242,67 +245,31 @@ def result_type(*dtypes: DataType): |
242 | 245 | return result |
243 | 246 |
|
244 | 247 |
|
245 | | -func_in_dtypes = { |
246 | | - # elementwise |
247 | | - "abs": numeric_dtypes, |
248 | | - "acos": float_dtypes, |
249 | | - "acosh": float_dtypes, |
250 | | - "add": numeric_dtypes, |
251 | | - "asin": float_dtypes, |
252 | | - "asinh": float_dtypes, |
253 | | - "atan": float_dtypes, |
254 | | - "atan2": float_dtypes, |
255 | | - "atanh": float_dtypes, |
256 | | - "bitwise_and": bool_and_all_int_dtypes, |
257 | | - "bitwise_invert": bool_and_all_int_dtypes, |
258 | | - "bitwise_left_shift": all_int_dtypes, |
259 | | - "bitwise_or": bool_and_all_int_dtypes, |
260 | | - "bitwise_right_shift": all_int_dtypes, |
261 | | - "bitwise_xor": bool_and_all_int_dtypes, |
262 | | - "ceil": numeric_dtypes, |
263 | | - "cos": float_dtypes, |
264 | | - "cosh": float_dtypes, |
265 | | - "divide": float_dtypes, |
266 | | - "equal": all_dtypes, |
267 | | - "exp": float_dtypes, |
268 | | - "expm1": float_dtypes, |
269 | | - "floor": numeric_dtypes, |
270 | | - "floor_divide": numeric_dtypes, |
271 | | - "greater": numeric_dtypes, |
272 | | - "greater_equal": numeric_dtypes, |
273 | | - "isfinite": numeric_dtypes, |
274 | | - "isinf": numeric_dtypes, |
275 | | - "isnan": numeric_dtypes, |
276 | | - "less": numeric_dtypes, |
277 | | - "less_equal": numeric_dtypes, |
278 | | - "log": float_dtypes, |
279 | | - "logaddexp": float_dtypes, |
280 | | - "log10": float_dtypes, |
281 | | - "log1p": float_dtypes, |
282 | | - "log2": float_dtypes, |
283 | | - "logical_and": (xp.bool,), |
284 | | - "logical_not": (xp.bool,), |
285 | | - "logical_or": (xp.bool,), |
286 | | - "logical_xor": (xp.bool,), |
287 | | - "multiply": numeric_dtypes, |
288 | | - "negative": numeric_dtypes, |
289 | | - "not_equal": all_dtypes, |
290 | | - "positive": numeric_dtypes, |
291 | | - "pow": numeric_dtypes, |
292 | | - "remainder": numeric_dtypes, |
293 | | - "round": numeric_dtypes, |
294 | | - "sign": numeric_dtypes, |
295 | | - "sin": float_dtypes, |
296 | | - "sinh": float_dtypes, |
297 | | - "sqrt": float_dtypes, |
298 | | - "square": numeric_dtypes, |
299 | | - "subtract": numeric_dtypes, |
300 | | - "tan": float_dtypes, |
301 | | - "tanh": float_dtypes, |
302 | | - "trunc": numeric_dtypes, |
303 | | - # searching |
304 | | - "where": all_dtypes, |
| 248 | +r_alias = re.compile("[aA]lias") |
| 249 | +r_in_dtypes = re.compile("x1?: array\n.+have an? (.+) data type.") |
| 250 | +r_int_note = re.compile( |
| 251 | + "If one or both of the input arrays have integer data types, " |
| 252 | + "the result is implementation-dependent" |
| 253 | +) |
| 254 | +category_to_dtypes = { |
| 255 | + "boolean": (xp.bool,), |
| 256 | + "integer": all_int_dtypes, |
| 257 | + "floating-point": float_dtypes, |
| 258 | + "numeric": numeric_dtypes, |
| 259 | + "integer or boolean": bool_and_all_int_dtypes, |
305 | 260 | } |
| 261 | +func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {} |
| 262 | +for name, func in name_to_func.items(): |
| 263 | + if m := r_in_dtypes.search(func.__doc__): |
| 264 | + dtype_category = m.group(1) |
| 265 | + if dtype_category == "numeric" and r_int_note.search(func.__doc__): |
| 266 | + dtype_category = "floating-point" |
| 267 | + dtypes = category_to_dtypes[dtype_category] |
| 268 | + func_in_dtypes[name] = dtypes |
| 269 | + elif any("x" in name for name in signature(func).parameters.keys()): |
| 270 | + func_in_dtypes[name] = all_dtypes |
| 271 | +# See https://github.com/data-apis/array-api/pull/413 |
| 272 | +func_in_dtypes["expm1"] = float_dtypes |
306 | 273 |
|
307 | 274 |
|
308 | 275 | func_returns_bool = { |
@@ -365,6 +332,8 @@ def result_type(*dtypes: DataType): |
365 | 332 | "trunc": False, |
366 | 333 | # searching |
367 | 334 | "where": False, |
| 335 | + # linalg |
| 336 | + "matmul": False, |
368 | 337 | } |
369 | 338 |
|
370 | 339 |
|
@@ -408,7 +377,7 @@ def result_type(*dtypes: DataType): |
408 | 377 | "__gt__": "greater", |
409 | 378 | "__le__": "less_equal", |
410 | 379 | "__lt__": "less", |
411 | | - # '__matmul__': 'matmul', # TODO: support matmul |
| 380 | + "__matmul__": "matmul", |
412 | 381 | "__mod__": "remainder", |
413 | 382 | "__mul__": "multiply", |
414 | 383 | "__ne__": "not_equal", |
@@ -440,6 +409,14 @@ def result_type(*dtypes: DataType): |
440 | 409 | func_returns_bool[iop] = func_returns_bool[op] |
441 | 410 |
|
442 | 411 |
|
| 412 | +func_in_dtypes["__bool__"] = (xp.bool,) |
| 413 | +func_in_dtypes["__int__"] = all_int_dtypes |
| 414 | +func_in_dtypes["__index__"] = all_int_dtypes |
| 415 | +func_in_dtypes["__float__"] = float_dtypes |
| 416 | +func_in_dtypes["from_dlpack"] = numeric_dtypes |
| 417 | +func_in_dtypes["__dlpack__"] = numeric_dtypes |
| 418 | + |
| 419 | + |
443 | 420 | @lru_cache |
444 | 421 | def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: |
445 | 422 | f_types = [] |
|
0 commit comments