typeutils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
# @description: Convert dict / list data into (nested) dataclass instances with type coercion.
# @author: licanglong
# @date: 2025/6/9 11:39
  4. from dataclasses import is_dataclass, fields, MISSING
  5. from datetime import datetime
  6. from decimal import Decimal
  7. from typing import Type, TypeVar, get_args, get_origin, Union, List, Tuple, Dict, Any, Set, Iterable, FrozenSet
  8. T = TypeVar('T')
  9. def as_dataclass(cls: Type[T], data, ignore_case: bool = True) -> T:
  10. """
  11. 将 dict / list / Dict[str, dict] 转换为 dataclass 实例,支持嵌套和类型转换。
  12. 特性:
  13. - 支持 dataclass、List、Tuple、Set、FrozenSet、Dict、Union 类型。
  14. - 支持嵌套 dataclass、datetime、Decimal 转换。
  15. - 支持通过 `as_dataclass.register_type_converter` 注册自定义类型转换器。
  16. - ignore_case=True 时,支持字典 key 不区分大小写。
  17. 参数:
  18. cls: 目标 dataclass 类型
  19. data: 输入数据 (dict, list 等)
  20. ignore_case: 是否忽略 dict key 大小写
  21. 返回:
  22. cls 类型实例
  23. """
  24. # dataclass 字段缓存,提高性能
  25. _FIELD_CACHE = {}
  26. # 自定义类型转换注册表
  27. type_registry = {}
  28. def register_type_converter(from_type, to_type, converter_func):
  29. type_registry[(to_type, from_type)] = converter_func
  30. as_dataclass.register_type_converter = register_type_converter
  31. def _get_fields(c):
  32. if c not in _FIELD_CACHE:
  33. _FIELD_CACHE[c] = fields(c)
  34. return _FIELD_CACHE[c]
  35. def _convert_value(field_type, value, path="root", ignore_case_local: bool = True):
  36. if value is None:
  37. return None
  38. origin = get_origin(field_type) or getattr(field_type, "__origin__", None)
  39. args = get_args(field_type)
  40. # Union / Optional
  41. if origin is Union:
  42. non_none_args = [arg for arg in args if arg is not type(None)]
  43. # 优先尝试 dataclass
  44. for typ in non_none_args:
  45. if is_dataclass(typ):
  46. try:
  47. return _as_dataclass(typ, value, path, ignore_case_local)
  48. except Exception:
  49. continue
  50. # 尝试其他类型
  51. for typ in non_none_args:
  52. try:
  53. return _convert_value(typ, value, path, ignore_case_local)
  54. except Exception:
  55. continue
  56. return value
  57. # dataclass
  58. if is_dataclass(field_type):
  59. if not isinstance(value, dict):
  60. raise TypeError(f"{path}: Expected dict for dataclass {field_type}, got {type(value)}")
  61. return _as_dataclass(field_type, value, path, ignore_case_local)
  62. # 自定义类型转换器
  63. converter = type_registry.get((field_type, type(value)))
  64. if converter:
  65. return converter(value)
  66. # List
  67. if origin in (list, List):
  68. item_type = args[0] if args else Any
  69. if not isinstance(value, list):
  70. raise TypeError(f"{path}: Expected list, got {type(value)}")
  71. return [_convert_value(item_type, v, f"{path}[]", ignore_case_local) for v in value]
  72. # Tuple
  73. if origin in (tuple, Tuple):
  74. item_type = args[0] if args else Any
  75. if not isinstance(value, (list, tuple)):
  76. raise TypeError(f"{path}: Expected tuple/list, got {type(value)}")
  77. return tuple(_convert_value(item_type, v, f"{path}[]", ignore_case_local) for v in value)
  78. # Set / FrozenSet
  79. if origin in (set, Set):
  80. item_type = args[0] if args else Any
  81. if not isinstance(value, Iterable) or isinstance(value, str):
  82. raise TypeError(f"{path}: Expected iterable for set, got {type(value)}")
  83. return {_convert_value(item_type, v, f"{path}[]", ignore_case_local) for v in value}
  84. if origin in (frozenset, FrozenSet):
  85. item_type = args[0] if args else Any
  86. if not isinstance(value, Iterable) or isinstance(value, str):
  87. raise TypeError(f"{path}: Expected iterable for frozenset, got {type(value)}")
  88. return frozenset(_convert_value(item_type, v, f"{path}[]", ignore_case_local) for v in value)
  89. # Dict
  90. if origin in (dict, Dict):
  91. key_type, val_type = args if args else (Any, Any)
  92. if not isinstance(value, dict):
  93. raise TypeError(f"{path}: Expected dict, got {type(value)}")
  94. # 修复:在嵌套 Dict 时继续传递 ignore_case_local,保证内部 dataclass 解析依然大小写不敏感
  95. return {k: _convert_value(val_type, v, f"{path}[{k}]", ignore_case_local) for k, v in value.items()}
  96. # datetime
  97. if field_type is datetime and isinstance(value, str):
  98. return datetime.fromisoformat(value)
  99. # Decimal
  100. if field_type is Decimal:
  101. return Decimal(str(value))
  102. return value
  103. def _as_dataclass(cls_inner, data_inner, path="root", ignore_case_local: bool = True):
  104. if not is_dataclass(cls_inner):
  105. raise TypeError(f"{cls_inner} must be a dataclass")
  106. if not isinstance(data_inner, dict):
  107. raise TypeError(f"{path}: Expected dict, got {type(data_inner)}")
  108. data_map = {k.lower(): v for k, v in data_inner.items()} if ignore_case_local else data_inner
  109. kwargs = {}
  110. for field in _get_fields(cls_inner):
  111. name = field.name
  112. field_type = field.type
  113. key = name.lower() if ignore_case_local else name
  114. value = data_map.get(key, MISSING)
  115. if value is MISSING:
  116. if field.default is not MISSING:
  117. continue
  118. elif field.default_factory is not MISSING:
  119. kwargs[name] = field.default_factory()
  120. continue
  121. else:
  122. kwargs[name] = None
  123. continue
  124. if value is None:
  125. if field.default_factory is not MISSING:
  126. kwargs[name] = field.default_factory()
  127. else:
  128. kwargs[name] = None
  129. continue
  130. kwargs[name] = _convert_value(field_type, value, f"{path}.{name}", ignore_case_local)
  131. return cls_inner(**kwargs)
  132. # 默认简单类型转换
  133. register_type_converter(str, int, lambda v: int(v))
  134. register_type_converter(int, str, lambda v: str(v))
  135. register_type_converter(float, int, lambda v: round(v))
  136. # 新增:支持 List / Dict / Set 等容器类型作为顶层类型
  137. root_origin = get_origin(cls)
  138. if root_origin in (list, List, tuple, Tuple, set, Set, frozenset, FrozenSet, dict, Dict, Union):
  139. # 传递 ignore_case 参数,确保顶层容器内 dataclass 也能无视 key 大小写
  140. return _convert_value(cls, data, "root", ignore_case)
  141. return _as_dataclass(cls, data, "root", ignore_case)