# -*- coding: utf-8 -*-
"""
Main program.

For every brand word it requests a storyboard, generates the first frame,
the last frame and a video for each shot, then creates a JianYing draft.
"""
|
||
import warnings

warnings.filterwarnings(
    action="ignore", category=UserWarning, module="volcenginesdkarkruntime.*"
)

from base64 import b64decode, b64encode
import json
import os
from pathlib import Path
import sys
from time import sleep
from typing import Any, Dict, List
from uuid import uuid4

from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.responses import (
    ResponseOutputMessage,
    ResponseOutputText,
)

sys.path.append(Path(__file__).parent.parent.as_posix())

from utils.request import Request
from create_draft import JianYingDraftGenerator
|
||
|
||
|
||
# Initialise the Volcengine Ark client.
# SECURITY: the API key used to be available only as a literal in this file.
# The ARK_API_KEY environment variable now takes precedence; the literal is
# kept solely as a backward-compatible fallback.
# TODO(review): rotate this key and delete the fallback once every
# environment that runs this script sets ARK_API_KEY.
ark_client = Ark(
    base_url="https://ark.cn-beijing.volces.com/api/v3",
    api_key=os.environ.get("ARK_API_KEY") or "2c28ab07-888c-45be-84a2-fc4b2cb5f3f2",
)

# Initialise the request client (used below to download generated videos).
request_client = Request()
|
||
|
||
|
||
def get_brand_words() -> List[str]:
    """
    Load the brand words from ``brand_words.txt`` located next to this module.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped.

    :return: the non-empty brand words, in file order
    :raises FileNotFoundError: if ``brand_words.txt`` does not exist
    :raises AssertionError: if the file contains no brand words
    """
    # Build the path from this file's location instead of relying on the
    # working directory (the Trae IDE needs an explicit path, unlike PyCharm).
    brand_words_path = Path(__file__).parent / "brand_words.txt"
    try:
        with open(file=brand_words_path, mode="r", encoding="utf-8") as file:
            brand_words = [word for word in map(str.strip, file) if word]
    except FileNotFoundError as error:
        raise FileNotFoundError("未找到品牌词文件") from error

    # Raised explicitly rather than through ``assert`` so that the check is
    # still performed when Python runs with ``-O``; the exception type is kept.
    if not brand_words:
        raise AssertionError("品牌词为空")
    return brand_words
|
||
|
||
|
||
def generate_task_id() -> str:
    """
    Build a fresh task identifier.

    :return: 32 upper-case hexadecimal characters taken from a random UUID4
    """
    # ``UUID.hex`` contains no dashes, so upper-casing is the only step needed.
    task_uuid = uuid4()
    return task_uuid.hex.upper()
|
||
|
||
|
||
def get_storyboard(brand_word: str) -> Dict[str, Any]:
    """
    Ask the language model for the storyboard of one brand word.

    The prompt is read from ``storyboard_prompt_template.txt`` next to this
    module, with every ``{{品牌词}}`` placeholder replaced by ``brand_word``.
    The text of the model's reply is parsed as JSON. The request is attempted
    up to three times in total; any failure (request, response shape, JSON
    parsing, non-object JSON) triggers the next attempt.

    :param brand_word: brand word substituted into the prompt template
    :return: the storyboard parsed from the model's JSON reply
    :raises FileNotFoundError: if the prompt template file does not exist
    :raises Exception: if all three attempts fail; chained to the last error
    """
    template_path = Path(__file__).parent / "storyboard_prompt_template.txt"
    try:
        with open(file=template_path, mode="r", encoding="utf-8") as file:
            prompt_template = file.read()
    except FileNotFoundError as error:
        raise FileNotFoundError("未找到分镜脚本的提示词模板文件") from error

    # Fill the brand word into the prompt template.
    storyboard_prompt = prompt_template.replace("{{品牌词}}", brand_word)

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            # Call the doubao-seed language model with the storyboard prompt.
            response = ark_client.responses.create(
                model="doubao-seed-2-0-pro-260215",
                input=[
                    {
                        "role": "user",
                        "content": [{"type": "input_text", "text": storyboard_prompt}],
                    }
                ],
            )

            # The reply text is the first text part of the first message item;
            # a missing message or text part raises IndexError and is retried.
            messages = [
                item
                for item in response.output
                if isinstance(item, ResponseOutputMessage)
            ]
            text_parts = [
                part
                for part in messages[0].content
                if isinstance(part, ResponseOutputText)
            ]
            storyboard = json.loads(text_parts[0].text)

            # Callers index the storyboard by key, so anything other than a
            # JSON object counts as a failed attempt instead of being returned.
            if not isinstance(storyboard, dict):
                raise TypeError(
                    f"分镜脚本应为 JSON 对象,实际为 {type(storyboard).__name__}"
                )
            return storyboard
        except Exception as exception:
            if attempt == max_attempts:
                raise Exception(
                    f"获取分镜脚本发生异常,{str(exception)}"
                ) from exception

    # The loop always returns or raises on its final attempt.
    raise AssertionError("unreachable")
|
||
|
||
|
||
def generate_frame(frame_prompt: str, frame_name: str) -> None:
    """
    Generate one video frame image and store it under ``frames/``.

    Nothing is generated when ``frames/<frame_name>`` already exists.
    Otherwise the image model is called, the base64 payload of the first
    returned image is decoded and the bytes are saved. Generation and saving
    are attempted up to three times in total.

    :param frame_prompt: prompt describing the frame
    :param frame_name: file name of the frame inside ``frames/``
    :return: None
    :raises Exception: if all three attempts fail; chained to the last error
    """
    frame_path = Path(__file__).parent / "frames" / frame_name
    # An existing frame is reused as-is.
    if frame_path.exists():
        return

    # Create the output directory up front: without it every attempt would
    # call the image model and then fail when opening the file.
    frame_path.parent.mkdir(parents=True, exist_ok=True)
    # The image is written to a sibling ".part" file and renamed afterwards,
    # so an interrupted write never leaves a file that the existence check
    # above would accept as a finished frame.
    partial_path = frame_path.with_name(frame_path.name + ".part")

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            # Call the doubao-seedream image model with the frame prompt.
            response = ark_client.images.generate(
                model="doubao-seedream-4-5-251128",
                prompt=frame_prompt,
                sequential_image_generation="disabled",  # no grouped image output
                size="1600x2848",
                watermark=False,  # no watermark
                response_format="b64_json",  # image returned as base64 data
            )
            # Decode the first image of the response and save it locally.
            frame_bytes = b64decode(response.data[0].b64_json)
            partial_path.write_bytes(frame_bytes)
            partial_path.replace(frame_path)
            return
        except Exception as exception:
            if attempt == max_attempts:
                raise Exception(
                    f"生成视频帧发生异常,{str(exception)}"
                ) from exception
|
||
|
||
|
||
def generate_video(
|
||
video_name: str,
|
||
first_frame_name: str,
|
||
last_frame_name: str,
|
||
shot_prompt: str,
|
||
video_duration: int,
|
||
) -> None:
|
||
"""
|
||
生成视频
|
||
:param video_name: 视频名称
|
||
:param first_frame_name: 首帧名称
|
||
:param last_frame_name: 尾帧名称
|
||
:param shot_prompt: 运镜提示词
|
||
:param video_duration: 视频时长
|
||
:return: None
|
||
"""
|
||
# 若视频已存在则直接返回
|
||
if (Path(__file__).parent / "videos" / video_name).exists():
|
||
return
|
||
|
||
# 构建首帧路径
|
||
first_frame_path = Path(__file__).parent / "frames" / first_frame_name
|
||
if not first_frame_path.exists():
|
||
raise RuntimeError(f"首帧 {first_frame_name} 不存在")
|
||
with open(file=first_frame_path, mode="rb") as file:
|
||
first_frame_base64 = b64encode(file.read()).decode("utf-8")
|
||
|
||
# 构建尾帧路径
|
||
last_frame_path = Path(__file__).parent / "frames" / last_frame_name
|
||
if not last_frame_path.exists():
|
||
raise RuntimeError(f"尾帧 {last_frame_name} 不存在")
|
||
with open(file=last_frame_path, mode="rb") as file:
|
||
last_frame_base64 = b64encode(file.read()).decode("utf-8")
|
||
|
||
retries = 0 # 重试次数
|
||
while True:
|
||
try:
|
||
# 调用 doubao-seedrance 视频大模型,基于首尾帧和运镜提示词生成视频
|
||
# 创建视频生成任务
|
||
create_response = ark_client.content_generation.tasks.create(
|
||
model="doubao-seedance-1-5-pro-251215",
|
||
content=[
|
||
{"type": "text", "text": shot_prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{first_frame_base64}"
|
||
},
|
||
"role": "first_frame",
|
||
},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{last_frame_base64}"
|
||
},
|
||
"role": "last_frame",
|
||
},
|
||
],
|
||
ratio="9:16", # 视频的宽高比例
|
||
duration=video_duration, # 视频时长(doubao-seedance-1-5-pro-251215 有效范围为[4, 12])
|
||
watermark=False, # 关闭水印
|
||
)
|
||
|
||
# 轮询查询视频生成任务
|
||
while True:
|
||
# 查询视频生成任务
|
||
query_response = ark_client.content_generation.tasks.get(
|
||
task_id=create_response.id,
|
||
)
|
||
# 根据视频生成任务的状态匹配处理方法
|
||
match query_response.status:
|
||
case "succeeded":
|
||
video_url = query_response.content.video_url
|
||
# 下载视频
|
||
chunk_generator = request_client.download(
|
||
url=video_url,
|
||
stream_enabled=True, # 开启流式传输
|
||
)
|
||
# 本地保存视频
|
||
with open(
|
||
file=Path(__file__).parent / "videos" / video_name,
|
||
mode="wb",
|
||
) as file:
|
||
for chunk in chunk_generator:
|
||
file.write(chunk)
|
||
return
|
||
case "failed":
|
||
raise Exception(f"{query_response.error}")
|
||
case _:
|
||
sleep(5) # 避免频繁请求查询视频生成任务故显性等待
|
||
except Exception as exception:
|
||
retries += 1
|
||
if retries > 2:
|
||
raise Exception(f"生成视频发生异常,{str(exception)}")
|
||
continue
|
||
|
||
|
||
if __name__ == "__main__":
    # Per-shot video durations, looked up by the 1-based shot number.
    # NOTE(review): because shots are numbered from 1, index 0 (15) is never
    # read, and 15 lies outside the [4, 12] range noted for the video model.
    # It looks like a placeholder, but this may also be an off-by-one;
    # confirm the intended mapping before changing the lookup.
    video_durations = [15, 4, 8, 8, 4]

    # Process every brand word.
    for brand_word in get_brand_words():
        # One task identifier per brand word; it prefixes all generated files.
        task_id = generate_task_id()

        # Fetch the storyboard for this brand word.
        storyboard = get_storyboard(
            brand_word=brand_word,
        )
        shots = storyboard["分镜脚本"]

        # Fail fast, before any frame or video is generated, when the
        # storyboard has more shots than there are configured durations.
        # Previously this surfaced as a bare IndexError part-way through.
        if len(shots) >= len(video_durations):
            raise IndexError(
                f"分镜数量 {len(shots)} 超出已配置的视频时长数量 {len(video_durations) - 1}"
            )

        # Process every shot of the storyboard.
        for i, shot in enumerate(shots, start=1):
            # Generate the first frame of the shot.
            first_frame_name = f"{task_id}_{i:02d}_first.jpeg"
            generate_frame(
                frame_prompt=shot["首帧提示词"],  # first-frame prompt of the shot
                frame_name=first_frame_name,
            )

            # Generate the last frame of the shot.
            last_frame_name = f"{task_id}_{i:02d}_last.jpeg"
            generate_frame(
                frame_prompt=shot["尾帧提示词"],  # last-frame prompt of the shot
                frame_name=last_frame_name,
            )

            # Generate the video of the shot from both frames.
            generate_video(
                video_name=f"{task_id}_{i:02d}.mp4",  # video file name of the shot
                video_duration=video_durations[i],  # duration of the shot
                first_frame_name=first_frame_name,
                last_frame_name=last_frame_name,
                shot_prompt=shot["运镜提示词"],  # camera-movement prompt of the shot
            )

        # Create the JianYing draft from the voice-over text of every shot.
        draft_generator = JianYingDraftGenerator()
        draft_generator.create_draft(
            task_id=task_id,
            contents=[shot["口播"] for shot in shots],
        )
|