218 lines
7.6 KiB
Python
218 lines
7.6 KiB
Python
import base64
import datetime
import hashlib
import hmac
import json
import os
import sys
import time
from urllib.parse import quote, urlparse

import requests
|
||
|
||
|
||
class VolcEngineAPIClient:
    """Minimal client for the Volcengine visual (CV) OpenAPI.

    Every request is a JSON POST to ``endpoint`` signed with the HMAC-SHA256
    scheme implemented in ``_generate_signature_v4``.
    """

    # Default per-request timeout (seconds) for every HTTP call made by this client.
    DEFAULT_TIMEOUT = 60

    def __init__(self, access_key: str, secret_key: str, service: str = 'cv',
                 region: str = 'cn-north-1', endpoint: str = 'https://visual.volcengineapi.com',
                 timeout: float = DEFAULT_TIMEOUT):
        """Initialise the API client.

        Args:
            access_key: Access key ID (sent in the ``Credential`` part of the auth header).
            secret_key: Secret key used to derive the signing key.
            service: Service name used in the credential scope, defaults to 'cv'.
            region: Region used in the credential scope, defaults to 'cn-north-1'.
            endpoint: API endpoint URL, defaults to 'https://visual.volcengineapi.com'.
            timeout: Timeout in seconds applied to every HTTP request.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.service = service
        self.region = region
        self.endpoint = endpoint
        # The signed ``host`` header must equal the Host the request is actually
        # sent to, so derive it from the endpoint instead of hard-coding it.
        self.host = urlparse(endpoint).netloc or 'visual.volcengineapi.com'
        self.method = 'POST'
        self.timeout = timeout

    @staticmethod
    def _sign(key: bytes, msg: str) -> bytes:
        """Return the raw HMAC-SHA256 digest of ``msg`` keyed with ``key``."""
        return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()

    def _get_signature_key(self, key: str, date_stamp: str, region_name: str, service_name: str) -> bytes:
        """Derive the signing key by chaining HMACs over date, region, service and 'request'."""
        k_date = self._sign(key.encode('utf-8'), date_stamp)
        k_region = self._sign(k_date, region_name)
        k_service = self._sign(k_region, service_name)
        return self._sign(k_service, 'request')

    @staticmethod
    def _format_query(parameters: dict) -> str:
        """Build the canonical query string: keys sorted, keys and values percent-encoded.

        The same string is both signed and appended to the request URL, so
        encoding here keeps the two consistent when values contain reserved
        characters. Plain alphanumeric values (e.g. Action/Version) are unchanged.
        """
        def enc(value) -> str:
            return quote(str(value), safe='-_.~')

        return '&'.join(f'{enc(key)}={enc(parameters[key])}' for key in sorted(parameters))

    def _generate_signature_v4(self, req_query: str, req_body: str):
        """Build the signed request headers.

        Args:
            req_query: Canonical query string (as produced by ``_format_query``).
            req_body: Exact request body that will be sent.

        Returns:
            tuple: ``(headers, current_date)``; ``current_date`` is the UTC
            ``X-Date`` timestamp (``%Y%m%dT%H%M%SZ``) used in the signature.

        Raises:
            ValueError: If the access key or secret key is missing or empty.
        """
        if not self.access_key or not self.secret_key:
            raise ValueError('No access key is available.')

        # Timezone-aware UTC; ``datetime.utcnow()`` is deprecated since Python 3.12.
        t = datetime.datetime.now(datetime.timezone.utc)
        current_date = t.strftime('%Y%m%dT%H%M%SZ')
        datestamp = t.strftime('%Y%m%d')  # Date without time, used in the credential scope.

        canonical_uri = '/'
        signed_headers = 'content-type;host;x-content-sha256;x-date'
        payload_hash = hashlib.sha256(req_body.encode('utf-8')).hexdigest()
        content_type = 'application/json'

        # Each canonical header line ends with '\n', so joining below yields the
        # blank line that separates the header list from the signed-header names.
        canonical_headers = (
            f'content-type:{content_type}\n'
            f'host:{self.host}\n'
            f'x-content-sha256:{payload_hash}\n'
            f'x-date:{current_date}\n'
        )
        canonical_request = '\n'.join([
            self.method, canonical_uri, req_query, canonical_headers, signed_headers, payload_hash,
        ])

        algorithm = 'HMAC-SHA256'
        credential_scope = f'{datestamp}/{self.region}/{self.service}/request'
        string_to_sign = '\n'.join([
            algorithm,
            current_date,
            credential_scope,
            hashlib.sha256(canonical_request.encode('utf-8')).hexdigest(),
        ])

        signing_key = self._get_signature_key(self.secret_key, datestamp, self.region, self.service)
        signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()

        authorization_header = (
            f'{algorithm} Credential={self.access_key}/{credential_scope}, '
            f'SignedHeaders={signed_headers}, Signature={signature}'
        )
        headers = {
            'X-Date': current_date,
            'Authorization': authorization_header,
            'X-Content-Sha256': payload_hash,
            'Content-Type': content_type,
        }
        return headers, current_date

    def send_request(self, query_params: dict, body_params: dict) -> dict:
        """Sign and POST a request, returning the decoded JSON response.

        Args:
            query_params: Query-string parameters (e.g. ``Action`` / ``Version``).
            body_params: Parameters serialised as the JSON request body.

        Returns:
            dict: The decoded JSON response body.

        Raises:
            Exception: On a non-200 HTTP status. Errors raised by ``requests``
                (network, timeout) or by JSON decoding propagate unchanged.
                Every error is printed before being re-raised.
        """
        formatted_query = self._format_query(query_params)
        # Serialise once: the signed payload hash must match the exact bytes sent.
        # json.dumps' default ensure_ascii=True keeps the body pure ASCII.
        formatted_body = json.dumps(body_params)
        headers, _ = self._generate_signature_v4(formatted_query, formatted_body)
        request_url = f'{self.endpoint}?{formatted_query}'

        try:
            response = requests.post(request_url, headers=headers, data=formatted_body,
                                     timeout=self.timeout)
            if response.status_code != 200:
                raise Exception(f'HTTP error! status: {response.status_code}, response: {response.text}')
            return response.json()
        except Exception as err:
            print(f'error occurred: {err}')
            raise

    def cv_process(self, req_key: str, **kwargs) -> dict:
        """Submit an asynchronous CV task (``CVSync2AsyncSubmitTask``).

        Args:
            req_key: Request key identifying the model / algorithm.
            **kwargs: Additional body parameters forwarded verbatim.

        Returns:
            dict: The ``data`` field of the response (callers read ``task_id`` from it).
        """
        query_params = {
            'Action': 'CVSync2AsyncSubmitTask',
            'Version': '2022-08-31',
        }
        body_params = {
            "req_key": req_key,
            **kwargs,
        }
        return self.send_request(query_params, body_params)["data"]

    def get_image_url(self, req_key: str, task_id: str) -> dict:
        """Query a submitted task (``CVSync2AsyncGetResult``), asking for image URLs.

        Returns:
            dict: The full decoded JSON response.
        """
        query_params = {
            'Action': 'CVSync2AsyncGetResult',
            'Version': '2022-08-31',
        }
        req_json = {"return_url": True}
        body_params = {
            "req_key": req_key,
            "task_id": task_id,
            # The API expects req_json as a JSON-encoded string, not a nested object.
            "req_json": json.dumps(req_json, ensure_ascii=False),
        }
        return self.send_request(query_params, body_params)

    @staticmethod
    def _extract_image_urls(result: dict) -> list:
        """Return ``result['data']['image_urls']``, or ``[]`` when either level is missing or null."""
        data = result.get("data") or {}
        return data.get("image_urls") or []

    def get_image(self, req_key: str, task_id: str) -> bytes:
        """Download the first generated image of a task.

        Returns:
            bytes: The image bytes, or ``b""`` when the task has no image URL yet.

        Raises:
            Exception: If downloading the image returns a non-200 HTTP status.
        """
        image_urls = self._extract_image_urls(self.get_image_url(req_key, task_id))
        if not image_urls:
            return b""
        response = requests.get(image_urls[0], timeout=self.timeout)
        if response.status_code != 200:
            raise Exception(f'HTTP error! status: {response.status_code}, response: {response.text}')
        return response.content

    def _run(self, prompt: str, blocking_times: int = 5, wait_time: int = 5,
             model: str = "jimeng_t2i_v31", width: int = 936, height: int = 1664):
        """Submit a text-to-image task and poll until an image URL is available.

        Args:
            prompt: Text prompt.
            blocking_times: Maximum number of polls.
            wait_time: Seconds to sleep before each poll.
            model: ``req_key`` of the model to use.
            width: Requested image width.
            height: Requested image height.

        Returns:
            str: The first image URL, or ``""`` on timeout or on any error.
            This method is best-effort: errors are printed, not raised.
        """
        try:
            task_info = self.cv_process(model,
                                        prompt=prompt,
                                        seed=-1,
                                        width=width,
                                        height=height,
                                        use_pre_llm=False)
            task_id = task_info["task_id"]

            for _ in range(blocking_times):
                time.sleep(wait_time)
                image_urls = self._extract_image_urls(self.get_image_url(model, task_id))
                if image_urls:
                    return image_urls[0]

            # All polls exhausted without a result.
            print("生成时间过长")
            return ""
        except Exception as e:
            print(f"请求失败: {e}")
            return ""

    def run(self, prompt: str, return_type: str = "content", try_times: int = 2,
            model: str = "jimeng_t2i_v31"):
        """Generate an image, retrying up to ``try_times`` times on failure.

        Args:
            prompt: Text prompt.
            return_type: ``"content"`` returns the image bytes; any other value
                returns the image URL.
            try_times: Maximum number of generation attempts.
            model: ``req_key`` of the model to use.

        Returns:
            bytes | str | None: Image bytes or URL, or ``None`` if every attempt failed.

        Raises:
            Exception: If downloading the image returns a non-200 HTTP status.
        """
        for _ in range(try_times):
            url = self._run(prompt, model=model)
            if not url:
                print("生成失败")
                continue
            if return_type != "content":
                return url
            resp = requests.get(url, timeout=self.timeout)
            if resp.status_code != 200:
                raise Exception(f'HTTP error! status: {resp.status_code}, response: {resp.text}')
            return resp.content
        return None
|
||
|
||
|