routes.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # @description:
  2. # @author: licanglong
  3. # @date: 2025/11/20 14:22
  4. import json
  5. import re
  6. import dirtyjson
  7. import openai
  8. from app.client.VectorStoreClient import vector_store_client
  9. from app.constants.vector_store import VectorStoreCollection
  10. from app.core import BizException, CTX
  11. from app.models.Result import SysResult
  12. from app.models.dto import FinalDecisionResult, RiskEvidenceResult, SimilarIdentificationResult
  13. from app.prompt import person_consumption_prompt, external_evidence_search_prompt, similar_identification_prompt
  14. from app.routes.risk import risk_router
  15. from app.service.llm_client import llm_call
  16. @risk_router.post('/decide')
  17. async def risk_decide(invoice_data: dict):
  18. """
  19. 发票风险裁决
  20. :return:
  21. """
  22. vector = vector_store_client.embedding.encode(f"""
  23. 特定业务类型:{invoice_data['tdywlx'] or ''}
  24. 购买方名称: {invoice_data['gmfmc'] or ''}
  25. 货物名称:{invoice_data['hwmc'] or ''}
  26. 规格型号:{invoice_data['ggxh'] or ''}
  27. 开票人:{invoice_data['kpr'] or ''}
  28. """)
  29. rules = await vector_store_client.client.query_points(
  30. collection_name=VectorStoreCollection.RULE_EMBED_STORE,
  31. query=vector.tolist(),
  32. limit=5,
  33. score_threshold=0.5
  34. )
  35. cases = await vector_store_client.client.query_points(
  36. collection_name=VectorStoreCollection.CASE_EMBED_STORE,
  37. query=vector.tolist(),
  38. limit=5,
  39. score_threshold=0.5
  40. )
  41. merchants = await vector_store_client.client.query_points(
  42. collection_name=VectorStoreCollection.MERCHANTS_EMBED_STORE,
  43. query=vector.tolist(),
  44. limit=5,
  45. score_threshold=0.5
  46. )
  47. edges = await vector_store_client.client.query_points(
  48. collection_name=VectorStoreCollection.EDGES_EMBED_STORE,
  49. query=vector.tolist(),
  50. limit=5,
  51. score_threshold=0.5
  52. )
  53. input_data = {
  54. "invoice_context": invoice_data,
  55. "rules": [hit.payload for hit in rules.points],
  56. "cases": [hit.payload for hit in cases.points],
  57. "industry": [hit.payload for hit in merchants.points],
  58. "signals": [hit.payload for hit in edges.points]
  59. }
  60. final_user_prompt = person_consumption_prompt.get_person_consumption_user_prompt(
  61. json.dumps(input_data, ensure_ascii=False))
  62. final_user_prompt = final_user_prompt.replace("{{input_data_desc}}", "")
  63. client = openai.AsyncOpenAI(
  64. api_key=CTX.ENV.getprop("llm.qwen.api_key", raise_error=True),
  65. base_url=CTX.ENV.getprop("llm.qwen.base_url", raise_error=True),
  66. )
  67. completion = await client.chat.completions.create(
  68. model="qwen-plus",
  69. messages=[{'role': 'system', 'content': person_consumption_prompt.system_prompt},
  70. {'role': 'user', 'content': final_user_prompt}]
  71. )
  72. if not completion.choices:
  73. raise BizException("LLM响应异常")
  74. generate_content = completion.choices[0].message.content
  75. decision_result: FinalDecisionResult = FinalDecisionResult.model_validate(dirtyjson.loads(generate_content))
  76. return SysResult.success(data=decision_result)
  77. @risk_router.post('/evidence')
  78. async def evidence_replenish(invoice_data: dict):
  79. input_data = {
  80. "invoice_context": invoice_data
  81. }
  82. final_external_evidence_user_prompt = external_evidence_search_prompt.get_external_evidence_user_prompt(
  83. json.dumps(input_data, ensure_ascii=False))
  84. tools = [
  85. {
  86. "type": "function",
  87. "function": {
  88. "name": "ali_search_tool",
  89. "description": "当需要从互联网获取额外信息时使用",
  90. "parameters": {
  91. "type": "object",
  92. "properties": {
  93. "keyword": {
  94. "type": "string",
  95. "description": "搜索关键词,如果需要限定搜索源可以在结尾加上 <+ 平台名称 >,例如: 如何判断一个企业的经营范围? + 税务局"
  96. }
  97. },
  98. "required": ["keyword"]
  99. }
  100. }
  101. }
  102. ]
  103. generate_content = await llm_call(tools=tools, messages=[
  104. {'role': 'system', 'content': external_evidence_search_prompt.external_evidence_system_prompt},
  105. {'role': 'user', 'content': final_external_evidence_user_prompt}])
  106. evidence_result: RiskEvidenceResult = RiskEvidenceResult.model_validate(dirtyjson.loads(generate_content))
  107. return SysResult.success(data=evidence_result)
  108. @risk_router.post('/similar')
  109. async def similar_identification(invoice_data: dict):
  110. hwmc = invoice_data["hwmc"]
  111. hw_type = None
  112. type_match = re.match(r'\*([^*]+)\*', hwmc)
  113. if type_match:
  114. hw_type = type_match[0]
  115. info = re.sub(r'\*([^*]+)\*', "", hwmc)
  116. if not hw_type or not info:
  117. return SysResult.fail(msg="货物信息不符合规范")
  118. system_prompt = similar_identification_prompt.similar_identification_system_prompt
  119. user_prompt = similar_identification_prompt.get_similar_identification_user_prompt(info, hw_type)
  120. client = openai.AsyncOpenAI(
  121. api_key=CTX.ENV.getprop("llm.qwen.api_key", raise_error=True),
  122. base_url=CTX.ENV.getprop("llm.qwen.base_url", raise_error=True),
  123. )
  124. tools = [
  125. {
  126. "type": "function",
  127. "function": {
  128. "name": "ali_search_tool",
  129. "description": "当需要从互联网获取额外信息时使用",
  130. "parameters": {
  131. "type": "object",
  132. "properties": {
  133. "keyword": {
  134. "type": "string",
  135. "description": "搜索关键词,如果需要限定搜索源可以在结尾加上 <+ 平台名称 >,例如: 行业指标 + 税务局"
  136. }
  137. },
  138. "required": ["keyword"]
  139. }
  140. }
  141. }
  142. ]
  143. generate_content = await llm_call(tools=tools, messages=[
  144. {'role': 'system', 'content': system_prompt},
  145. {'role': 'user', 'content': user_prompt}])
  146. identification_result: SimilarIdentificationResult = SimilarIdentificationResult.model_validate(
  147. dirtyjson.loads(generate_content))
  148. return SysResult.success(data=identification_result)