Files
Esmart/jimeng.py
2025-09-28 15:42:56 +08:00

218 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import datetime
import hashlib
import hmac
import json
import os
import sys
import time
import urllib.parse

import requests
class VolcEngineAPIClient:
    """Client for the Volcengine visual (CV) OpenAPI.

    Requests are signed with an HMAC-SHA256, V4-style scheme (canonical
    request -> string to sign -> derived signing key) and POSTed as JSON.
    """

    def __init__(self, access_key: str, secret_key: str, service: str = 'cv',
                 region: str = 'cn-north-1', endpoint: str = 'https://visual.volcengineapi.com',
                 timeout: float = 60.0):
        """
        Initialise the API client.

        Args:
            access_key: Access key ID, placed in the ``Credential`` part of the
                Authorization header.
            secret_key: Secret access key, used to derive the signing key.
            service: Service name in the credential scope, default ``'cv'``.
            region: Region in the credential scope, default ``'cn-north-1'``.
            endpoint: Base URL requests are POSTed to, default
                ``'https://visual.volcengineapi.com'``.
            timeout: Seconds ``requests`` waits on each HTTP call (API calls and
                image downloads) before giving up.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.service = service
        self.region = region
        self.endpoint = endpoint
        # The signed ``host`` header must match the host actually contacted, so
        # derive it from the endpoint (the default endpoint still yields
        # 'visual.volcengineapi.com'); fall back to that value if no netloc parses.
        self.host = urllib.parse.urlparse(endpoint).netloc or 'visual.volcengineapi.com'
        self.method = 'POST'
        self.timeout = timeout

    @staticmethod
    def _sign(key: bytes, msg: str) -> bytes:
        """Return the raw HMAC-SHA256 digest of ``msg`` (UTF-8) keyed by ``key``."""
        return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()

    def _get_signature_key(self, key: str, date_stamp: str, region_name: str, service_name: str) -> bytes:
        """Derive the signing key by chaining HMACs over date, region, service and 'request'."""
        k_date = self._sign(key.encode('utf-8'), date_stamp)
        k_region = self._sign(k_date, region_name)
        k_service = self._sign(k_region, service_name)
        return self._sign(k_service, 'request')

    @staticmethod
    def _format_query(parameters: dict) -> str:
        """Build the canonical query string from ``parameters``.

        Keys are sorted; keys and values are converted with ``str`` and
        percent-encoded, leaving only unreserved characters (letters, digits,
        ``-_.~``) as-is, so plain values such as ``'CVSync2AsyncSubmitTask'`` are
        unchanged. An empty mapping yields ``''``. The same string is used both
        for signing and in the request URL, so the two always agree.
        """
        return '&'.join(
            urllib.parse.quote(str(key), safe='-_.~') + '='
            + urllib.parse.quote(str(parameters[key]), safe='-_.~')
            for key in sorted(parameters)
        )

    def _generate_signature_v4(self, req_query: str, req_body: str):
        """Compute the signed request headers for a POST with the given query and body.

        Args:
            req_query: Canonical query string (as produced by ``_format_query``).
            req_body: Exact request body text; its UTF-8 SHA-256 is signed.
        Returns:
            tuple: ``(headers, current_date)`` where ``headers`` holds
            ``X-Date``, ``Authorization``, ``X-Content-Sha256`` and
            ``Content-Type``, and ``current_date`` is the ``YYYYMMDDTHHMMSSZ``
            UTC timestamp that was signed.
        Raises:
            ValueError: If the access key or secret key is missing or empty.
        """
        if not self.access_key or not self.secret_key:
            raise ValueError('No access key is available.')
        now = datetime.datetime.now(datetime.timezone.utc)
        current_date = now.strftime('%Y%m%dT%H%M%SZ')
        datestamp = now.strftime('%Y%m%d')  # date only, used in the credential scope

        content_type = 'application/json'
        payload_hash = hashlib.sha256(req_body.encode('utf-8')).hexdigest()
        signed_headers = 'content-type;host;x-content-sha256;x-date'
        # Each canonical header line ends with '\n'; joining below adds one more,
        # producing the blank line that separates headers from the signed-header list.
        canonical_headers = (
            f'content-type:{content_type}\n'
            f'host:{self.host}\n'
            f'x-content-sha256:{payload_hash}\n'
            f'x-date:{current_date}\n'
        )
        canonical_request = '\n'.join([
            self.method, '/', req_query, canonical_headers, signed_headers, payload_hash,
        ])

        algorithm = 'HMAC-SHA256'
        credential_scope = f'{datestamp}/{self.region}/{self.service}/request'
        string_to_sign = '\n'.join([
            algorithm,
            current_date,
            credential_scope,
            hashlib.sha256(canonical_request.encode('utf-8')).hexdigest(),
        ])
        signing_key = self._get_signature_key(self.secret_key, datestamp, self.region, self.service)
        signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()
        authorization_header = (
            f'{algorithm} Credential={self.access_key}/{credential_scope}, '
            f'SignedHeaders={signed_headers}, Signature={signature}'
        )
        headers = {
            'X-Date': current_date,
            'Authorization': authorization_header,
            'X-Content-Sha256': payload_hash,
            'Content-Type': content_type,
        }
        return headers, current_date

    def send_request(self, query_params: dict, body_params: dict) -> dict:
        """
        Sign and POST a request to ``self.endpoint`` and return the decoded JSON.

        Args:
            query_params: Query-string parameters (e.g. ``Action`` and ``Version``).
            body_params: JSON-serialisable request body.
        Returns:
            dict: The parsed JSON response.
        Raises:
            ValueError: If credentials are missing (raised before any I/O).
            RuntimeError: If the HTTP status is not 200.
            Exception: Network/timeout/JSON errors from ``requests`` are printed
                and re-raised.
        """
        formatted_query = self._format_query(query_params)
        formatted_body = json.dumps(body_params)
        headers, _ = self._generate_signature_v4(formatted_query, formatted_body)
        request_url = self.endpoint + '?' + formatted_query
        try:
            # Send exactly the UTF-8 bytes whose SHA-256 was signed.
            response = requests.post(request_url, headers=headers,
                                     data=formatted_body.encode('utf-8'), timeout=self.timeout)
            if response.status_code != 200:
                raise RuntimeError(f'HTTP error! status: {response.status_code}, response: {response.text}')
            return response.json()
        except Exception as err:
            print(f'error occurred: {err}')
            raise

    def cv_process(self, req_key: str, **kwargs) -> dict:
        """
        Submit an asynchronous CV task (``CVSync2AsyncSubmitTask``).

        Args:
            req_key: Model/request key placed in the body as ``req_key``.
            **kwargs: Additional body fields (prompt, size, seed, ...).
        Returns:
            dict: The ``data`` field of the response (``_run`` reads
            ``task_id`` from it).
        """
        query_params = {
            'Action': 'CVSync2AsyncSubmitTask',
            'Version': '2022-08-31',
        }
        body_params = {"req_key": req_key, **kwargs}
        return self.send_request(query_params, body_params)["data"]

    def get_image_url(self, req_key: str, task_id: str) -> dict:
        """
        Query a submitted task's result (``CVSync2AsyncGetResult``), requesting URLs.

        Args:
            req_key: The same request key used to submit the task.
            task_id: Task ID returned by ``cv_process``.
        Returns:
            dict: The full JSON response; image URLs, when present, are under
            ``["data"]["image_urls"]``.
        """
        query_params = {
            'Action': 'CVSync2AsyncGetResult',
            'Version': '2022-08-31',
        }
        req_json = {"return_url": True}
        body_params = {
            "req_key": req_key,
            "task_id": task_id,
            # The API takes this field as a JSON-encoded string, not a nested object.
            "req_json": json.dumps(req_json, ensure_ascii=False),
        }
        return self.send_request(query_params, body_params)

    @staticmethod
    def _extract_image_urls(result: dict) -> list:
        """Return ``result["data"]["image_urls"]``, or ``[]`` if data/URLs are null or absent.

        A missing top-level ``"data"`` key still raises ``KeyError``.
        """
        data = result["data"] or {}
        return data.get("image_urls") or []

    def _download(self, url: str) -> bytes:
        """GET ``url`` and return the body bytes; raise ``RuntimeError`` on non-200."""
        resp = requests.get(url, timeout=self.timeout)
        if resp.status_code != 200:
            raise RuntimeError(f'HTTP error! status: {resp.status_code}, response: {resp.text}')
        return resp.content

    def get_image(self, req_key: str, task_id: str) -> bytes:
        """Fetch the first generated image of a task as bytes, or ``b""`` if none is ready."""
        image_urls = self._extract_image_urls(self.get_image_url(req_key, task_id))
        if not image_urls:
            return b""
        return self._download(image_urls[0])

    def _run(self, prompt: str, blocking_times: int = 5, wait_time: int = 5, model: str = "jimeng_t2i_v31",
             width: int = 936, height: int = 1664):
        """
        Submit a text-to-image task and poll for its first image URL.

        Args:
            prompt: Text prompt.
            blocking_times: Maximum number of result polls.
            wait_time: Seconds to sleep before each poll.
            model: Request key of the model to use.
            width: Requested image width (default 936).
            height: Requested image height (default 1664).
        Returns:
            str: The first image URL, or ``""`` if every poll came back empty
            or any error occurred (errors are printed, not raised).
        """
        try:
            task_info = self.cv_process(model,
                                        prompt=prompt,
                                        seed=-1,
                                        width=width,
                                        height=height,
                                        use_pre_llm=False)
            task_id = task_info["task_id"]
            for _ in range(blocking_times):
                time.sleep(wait_time)
                image_urls = self._extract_image_urls(self.get_image_url(model, task_id))
                if image_urls:
                    return image_urls[0]
            # Every poll came back empty: report the timeout and give up.
            print("生成时间过长")
            return ""
        except Exception as e:
            # Deliberate best-effort: callers treat "" as failure and may retry.
            print(f"请求失败: {e}")
            return ""

    def run(self, prompt: str, return_type: str = "content", try_times: int = 2,
            model: str = "jimeng_t2i_v31"):
        """
        Generate an image for ``prompt``, retrying up to ``try_times`` times.

        Args:
            prompt: Text prompt.
            return_type: ``"content"`` to return image bytes; any other value
                returns the image URL.
            try_times: Number of generation attempts.
            model: Request key of the model to use.
        Returns:
            bytes | str | None: Image bytes or URL, or ``None`` if all attempts failed.
        Raises:
            RuntimeError: If downloading the generated image returns a non-200 status.
        """
        for _ in range(try_times):
            url = self._run(prompt, model=model)
            if not url:
                print("生成失败")
                continue
            if return_type == "content":
                return self._download(url)
            return url
        return None