import json import os import time import logging from typing import Dict, Any, Optional import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry from jimeng import VolcEngineAPIClient class AIAgent: """AI代理类,整合了各种AI功能""" def __init__(self, deepseek_api_key: str = "sk-c2f04b4d635b4df0aace0432f3c6d6f4", max_retries: int = 3, timeout: int = 30): """初始化AI助手 Args: deepseek_api_key: DeepSeek API密钥 max_retries: 最大重试次数 timeout: 请求超时时间(秒) """ self.deepseek_api_key = deepseek_api_key self.timeout = timeout # 配置日志 self._setup_logging() # 创建带重试机制的会话 self.session = self._create_session(max_retries) # 配置常量 self.BASE_REPORT_URL = "https://a.cmdp.cn/basiceg/v1/json/tpldev" self.CACHE_DIR = "cache" def _setup_logging(self): """配置日志""" logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) self.logger = logging.getLogger(__name__) @staticmethod def _create_session( max_retries: int) -> requests.Session: """创建带重试机制的会话 Args: max_retries: 最大重试次数 Returns: 配置好的会话对象 """ session = requests.Session() # 配置重试策略 retry_strategy = Retry( total=max_retries, status_forcelist=[429, 500, 502, 503, 504], method_whitelist=["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"], backoff_factor=1 ) # 为HTTP和HTTPS适配器添加重试机制 adapter = HTTPAdapter(max_retries=retry_strategy) session.mount("http://", adapter) session.mount("https://", adapter) return session def _make_request(self, method: str, url: str, **kwargs) -> Optional[Dict[str, Any]]: """发送HTTP请求(带错误处理) Args: method: HTTP方法 url: 请求URL **kwargs: 其他请求参数 Returns: JSON响应或None """ try: # 设置默认超时 if 'timeout' not in kwargs: kwargs['timeout'] = self.timeout response = self.session.request(method, url, **kwargs) response.raise_for_status() self.logger.info(f"请求成功: {url}") return response.json() except requests.exceptions.RequestException as e: self.logger.error(f"请求失败: {url} - {str(e)}") return None except json.JSONDecodeError as e: self.logger.error(f"JSON解析失败: {url} - {str(e)}") return None def chat_deepseek(self, message: str, role: str = "", model: str = "deepseek-chat") -> str: """与DeepSeek聊天 Args: message: 用户消息 role: 系统角色 model: 使用的模型 Returns: AI回复内容 """ url = "https://api.deepseek.com/chat/completions" headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.deepseek_api_key}" } messages = [] if role: messages.append({"role": "system", "content": role}) messages.append({"role": "user", "content": message}) body = { "model": model, "messages": messages, "stream": False } result = self._make_request("POST", url, json=body, headers=headers) if result and "choices" in result: return result['choices'][0]['message']['content'] self.logger.warning("DeepSeek API调用失败或返回空结果") return "" def generate_report_cache_data(self, year: str) -> Dict[str, Any]: """生成报告缓存数据 Args: year: 年份 Returns: 报告数据 """ self.logger.info(f"开始生成{year}年报告缓存数据") base_url = f"{self.BASE_REPORT_URL}/disclosure_assessment_report_data/{year}/" urls = { "rank_a": base_url + "Rank_A.json", "rank_summary": base_url + "Rank_summary.json", "stats": base_url + "stats.json" } # 并行获取数据 data_sources = {} for key, url in urls.items(): self.logger.info(f"获取{key}数据: {url}") data = self._make_request("GET", url) if data: data_sources[key] = data else: self.logger.error(f"获取{key}数据失败") return {} # 处理数据 rank_a_data = data_sources["rank_a"] rank_summary_data = data_sources["rank_summary"] stats_data = data_sources["stats"] # 构建数据结构 data = {} for item in rank_a_data: data[item["gsdm"]] = { "name": item["gsjc"], "num": item["num"], "ranks": {} } # 填充排名数据 for item in rank_summary_data: security_code = item["gsdm"] if security_code in data: year_key = item["kpnd"] rank = item["kpjg"] data[security_code]["ranks"][year_key] = rank # 添加统计信息 data["stats"] = stats_data # 保存到缓存 self._save_cache(year, data) self.logger.info(f"{year}年报告缓存数据生成完成") return data def _save_cache(self, year: str, data: Dict[str, Any]) -> bool: """保存数据到缓存 Args: year: 年份 data: 要缓存的数据 Returns: 是否保存成功 """ try: os.makedirs(self.CACHE_DIR, exist_ok=True) cache_file = os.path.join(self.CACHE_DIR, f"{year}.json") with open(cache_file, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) self.logger.info(f"缓存已保存: {cache_file}") return True except Exception as e: self.logger.error(f"保存缓存失败: {str(e)}") return False def _load_cache(self, year: str) -> Optional[Dict[str, Any]]: """从缓存加载数据 Args: year: 年份 Returns: 缓存数据或None """ cache_file = os.path.join(self.CACHE_DIR, f"{year}.json") if not os.path.exists(cache_file): return None try: with open(cache_file, "r", encoding="utf-8") as f: data = json.load(f) self.logger.info(f"缓存已加载: {cache_file}") return data except Exception as e: self.logger.error(f"加载缓存失败: {str(e)}") return None def get_report_data(self, security_code: str, year: str) -> Dict[str, Any]: """获取报告数据 Args: security_code: 证券代码 year: 年份 Returns: 报告数据 """ # 尝试从缓存加载 cached_data = self._load_cache(year) if cached_data is None: # 缓存不存在,生成新数据 cached_data = self.generate_report_cache_data(year) # 返回特定证券的数据 result = cached_data.get(security_code, {}) result["stats"] = cached_data.get("stats", {}) return result def get_image_prompt(self, message: str, security_code: str, year: str) -> str: """获取图像生成提示词 Args: message: 绘画主题 security_code: 证券代码 year: 年份 Returns: 图像生成提示词 """ self.logger.info(f"生成图像提示词: {message}") with open("cache/painting.prompt", "r", encoding="utf-8") as f: role = f.read() # 从缓存中获取特定证券的数据 security_data = self.get_report_data(security_code, year) security_name = security_data.get("name") rank = security_data.get("ranks",{}).get(year, "N/A") num = security_data.get("num", 0) stats = json.dumps(security_data.get("stats", {}), ensure_ascii=False) prompt_message = f"## 相关信息:证券名称是{security_name},{year}年排名是{rank},已经连续{num}年考评A级。\n{stats}\n\n## 绘画主题:{message}。" prompt = self.chat_deepseek(prompt_message, role) if prompt: self.logger.info("图像提示词生成成功") else: self.logger.warning("图像提示词生成失败") return prompt def get_report_template(self, security_code: str, year: str) -> str: """获取报告模板 Args: security_code: 证券代码 year: 年份 Returns: 报告模板内容 """ url = f"{self.BASE_REPORT_URL}/disclosure_assessment_report_prompt/{year}/{security_code}.txt" try: response = self.session.get(url, timeout=self.timeout) response.raise_for_status() self.logger.info(f"报告模板获取成功: {security_code} - {year}") return response.text except requests.exceptions.RequestException as e: self.logger.error(f"获取报告模板失败: {security_code} - {year} - {str(e)}") return "" def get_report(self, security_code: str, year: str) -> str: """获取报告 Args: security_code: 证券代码 year: 年份 Returns: 生成的报告内容 """ self.logger.info(f"开始生成报告: {security_code} - {year}") template = self.get_report_template(security_code, year) if not template: self.logger.warning("报告模板为空") return "" result = self.chat_deepseek(template, "") if result: self.logger.info("报告生成成功") else: self.logger.warning("报告生成失败") return result def test(): """测试函数""" # 创建AI助手实例,可配置重试次数和超时时间 assistant = AIAgent(max_retries=3, timeout=30) try: # 测试获取报告数据 print("测试获取报告数据...") data = assistant.get_report_data("000001", "2023") print(f"获取到的数据: {data}") # 测试AI聊天 print("\n测试AI聊天...") response = assistant.chat_deepseek("你好,介绍一下你自己", "你是一个有用的助手") print(f"AI回复: {response}") # 测试图像提示词生成 print("\n测试图像提示词生成...") prompt = assistant.get_image_prompt("春天的花园") print(f"图像提示词: {prompt}") # 测试报告生成 print("\n测试报告生成...") report = assistant.get_report("000001", "2023") print(f"报告内容: {report}") except Exception as e: print(f"测试过程中发生错误: {str(e)}") assistant.logger.error(f"主函数执行错误: {str(e)}") agent = AIAgent(max_retries=1, timeout=180)