(2024,一致性模型,强化学习,MDP,DDPO)一致性模型的强化学习:更快的奖励引导文本到图像生成

RL for Consistency Models: Faster Reward Guided Text-to-Image Generation

公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
部分图像上传缓慢,可看原论文或在 EDPJ 查看 

目录

0. 摘要

3. 基础

3.1 强化学习

3.2 扩散模型与一致性模型

3.3 用于扩散模型的强化学习

4. 一致性模型的强化学习

5. 实验 

6. 结论和未来方向


0. 摘要

强化学习(Reinforcement learning,RL)通过直接优化捕捉图像质量、美学和指令遵循能力的奖励,改进了扩散模型引导图像生成。然而,由此产生的生成策略继承了扩散模型的相同迭代采样过程,导致生成速度缓慢。为了克服这一限制,一致性模型提出了学习一种新的生成模型,直接将噪声映射到数据,从而产生一种可以在至少一个采样迭代中生成图像的模型。在这项工作中,为了针对任务特定的奖励优化文本到图像的生成模型,并实现快速训练和推断,我们提出了一个通过 RL进行细化的一致性模型的框架。我们的框架,称为一致性模型的强化学习(Reinforcement Learning for Consistency Model,RLCM),将一致性模型的迭代推理过程构建为 RL 过程。RLCM 在文本到图像生成能力上改进了 RL 细化的扩散模型,并在推理时交换计算以获得样本质量。在实验中,我们展示了 RLCM 能够将文本到图像一致性模型调整到使用提示难以表达的目标,例如图像可压缩性,以及从人类反馈中得出的目标,例如美学质量。与 RL 细化的扩散模型相比,RLCM 的训练速度显著更快,根据奖励目标测量的生成质量得到了提高,并通过在仅两个推理步骤中生成高质量图像加速了推理过程。

项目页面:https://rlcm.owenoertell.com/

3. 基础

3.1 强化学习

我们将我们的序贯决策过程建模为有限时间段的马尔可夫决策过程(MDP),M = (S, A, P, R, μ, H)。在这个元组中,我们定义了我们的状态空间 S,动作空间 A,转移函数 P: S × A → Δ(S),奖励函数 R: S × A → R,初始状态分布 μ 和时间段(horizon) H。在每个时间步 t,代理观察到一个状态 s_t ∈ S,根据策略 π(a_t | s_t) 采取一个动作,并过渡到下一个状态 s_(t+1) ∼ P(s_(t+1) | s_t, a_t)。经过 H 个时间步后,代理生成一个轨迹,作为状态和动作序列 τ = (s_0, a_0, s_1, a_1, . . . , s_H, a_H)。我们的目标是学习一个策略 π,最大化从 π 中采样的轨迹上的期望累积奖励。

3.2 扩散模型与一致性模型

生成模型旨在将模型与数据分布匹配,这样我们就可以通过从分布中采样来随意合成新的数据点。扩散模型属于一种新颖的生成模型类型,它使用分数函数而不是密度函数来描述概率分布。具体而言,它通过逐渐修改数据分布然后通过连续去噪步骤从噪声中生成样本来产生数据。更正式地说,我们从数据分布 p_data(x) 开始,根据随机微分方程(SDE)(Song 等人,2020年)将其与噪声混合:

对于给定的 t ∈ [0, T],固定常数 T > 0,并且漂移(drift)系数 μ(·, ·)、扩散系数 σ(·),{w}_(t∈[0,T]) 是布朗运动。令 p_0(x) = p_data(x),p_t(x) 为由上述 SDE 引起的时间 t 的边际分布,如 Song 等人 (2020) 所示,存在一个 ODE(也称为概率流),其在时间 t 的引起(induced)分布也是 p_t(x)。特别地: 

∇log pt(xt) 也被称为得分函数(Song & Ermon,2019年;Song 等人,2020年)。在这种设置下训练扩散模型时,使用一种称为得分匹配的技术(Dinh 等人,2016年;Vincent,2011年),其中训练一个网络来近似得分函数,然后使用 ODE 求解器采样轨迹。一旦我们学习到这样一个近似得分函数的神经网络,我们可以通过从 T 到 0 向后时间积分上述 ODE 来生成图像,其中 xT ∼ pT,这通常是一个可处理的分布(例如,在大多数扩散模型的公式中是高斯分布)。

这种技术明显受到一个事实的制约,在生成过程中,必须对 ODE 求解器进行大量步骤的反向(从 T 到 0)才能获得具有竞争力的样本(Song 等人,2023年)。为了缓解这个问题,Song 等人(2023年)提出了一致性模型,其目标是直接将噪声样本映射到数据。目标变成了在给定概率流上学习一致性函数。这个函数的目的是对于任意的 t,t′ ∈ [ϵ,T],概率流 ODE 上的两个样本,它们通过一致性函数映射到相同的图像:f_θ(xt, t) = f_θ(xt′ , t′) = x_ϵ,其中 x_ϵ 是时间 ϵ 处 ODE 的解。在高层次上,这个一致性函数通过取两个相邻的时间步长并在某个图像距离度量下最小化一致性损失 d(fθ(xt, t), fθ(xt′ , t′)) 来训练。为了避免一个恒定的平凡解,我们还将初始条件设置为 fθ(xϵ, ϵ) = xϵ。

一致性模型中的推理:在模型训练后,可以使用附录 A 算法 2 中给出的多步推理过程来将推理时间交换为生成质量。在高层次上,多步一致性采样算法首先将概率流分成 H + 1 个点(T = τ0 > τ1 > τ2 . . . > τH = ϵ)。给定一个样本 xT ∼ pT,然后在(xT,T)处应用一致性函数 fθ,得到 ^x0。为了进一步提高 ^x0 的质量,可以使用方程

再次将噪声添加到 ^x0 中,然后在(^x_(τn),τn)处应用一致性函数,得到 ^x0。可以重复这个过程几步,直到生成质量满意为止。在本文的其余部分,我们将引用使用多步程序进行采样。我们在稍后介绍 RLCM 时也会提供更多细节。

3.3 用于扩散模型的强化学习

Black 等人(2024年)和 Fan 等人(2023年)将条件扩散概率模型(Sohl-Dickstein 等人,2015年;Ho 等人,2020年)的训练和微调形式化为一个马尔可夫决策过程(MDP)。Black 等人(2024年)定义了一类算法,称为去噪扩散策略优化(DDPO),该算法优化任意奖励函数以改进使用 RL 对扩散模型进行引导微调。

扩散模型去噪作为 MDP 的条件扩散概率模型在上下文 c(在文本到图像生成的情况下,为提示)上进行条件。如 DDPO 所介绍的,我们将迭代去噪过程映射到 MDP M = (S, A, P, R, μ, H)。让 r(s, c) 为任务奖励函数。另外,注意概率流从 xT → x0 进行。将 T = τ0 > τ1 > τ2 . . . > τH = ϵ 划分为概率流间隔:

其中 δy 是在 y 处非零密度的狄拉克 delta 分布。换句话说,我们将图像映射为状态,并将去噪流中下一个状态的预测作为动作。此外,我们可以将确定性动态看作是让下一个状态成为策略选择的动作。最后,我们可以认为每个状态的奖励直到轨迹结束时都为 0,然后根据任务奖励函数评估最终图像。这种表述允许以下损失项: 

其中使用修剪来确保当我们优化 pθ 时,新策略保持接近 pθold,这是一个由著名算法 Proximal Policy Optimization (PPO)(Schulman 等人,2017年)推广的技巧。

在扩散模型中,通常将 horizon H 设置为 50 或更大,时间 T 设置为 1000。选择小的步长用于 ODE 求解器以最小化误差,确保生成高质量图像,正如 Ho 等人 (2020年) 所展示的那样。由于长时间跨度和稀疏奖励,使用强化学习训练扩散模型可能具有挑战性。

4. 一致性模型的强化学习

为了解决在扩散模型的 MDP 制定过程中发生的长期推理时间跨度,我们将一致性模型重新构建为一个 MDP。我们也让 H 表示此 MDP 的时间跨度。就像我们对 DDPO 所做的一样,我们将整个概率流 ([0, T]) 划分为段,T = τ0 > τ1 > . . . > τH = ϵ。在本节中,我们将 t 表示为 MDP 中的离散时间步长,即 t ∈ {0, 1, . . . ,H},而 τt 是连续时间区间 [0, T] 中的相应时间。我们现在提出一致性模型 MDP 公式。

一致性模型推理作为MDP。我们将一致性模型中的多步推理过程(算法2)重新制定为 MDP:

其中 Z 是算法 2 中第 5 行的噪声。此外,r(·, ·) 是我们用来对齐模型的奖励函数,RH 是时间步 H 的奖励。在其他时间步上,我们让奖励为 0。我们可以在图 2 中可视化从多步推理到 MDP 的转换。将 MDP 建模为策略 

而不是将 π(·) 定义为一致性函数本身,这有一个重要的好处,即这使我们得到一个随机策略而不是确定性算法(例如 DPG (Silver 等人,2014年),我们发现这种算法不稳定且通常不是无偏的。因此,策略由两部分组成:一致性函数和加入高斯噪声。一致性函数采用图 2 中红色箭头的形式,而噪声则是绿色箭头。换句话说,我们的策略是一个高斯策略,其均值由一致性函数 fθ 模拟,方差为 (τ^2_t −ϵ^2)·I(这里的 I 是单位矩阵)。注意,根据算法 2 中的采样过程,我们只对轨迹的一部分加入噪声。请注意,轨迹的最后一步略有不同。特别地,为了计算最终奖励,我们只需使用一致性函数进行过渡(红/黄色箭头),然后在那里获得最终奖励。

策略梯度 RLCM。我们可以使用策略梯度优化器来实例化 RLCM,与 Black等人(2024年);Fan等人(2023年)的精神相一致。我们的算法描述如算法 1 所示。在实践中,我们会对每个提示的奖励进行归一化。也就是说,我们为每个提示创建一个运行均值和标准差,并将其用作归一化器,而不是在每批次中计算。这是因为在某些奖励模型下,每个提示的平均分数可能会有很大的变化。

5. 实验 

6. 结论和未来方向

我们提出了 RLCM,这是一个快速高效的 RL 框架,可以直接优化各种奖励来训练一致性模型。我们在实证上展示了 RLCM 在大多数任务上都比扩散模型 RL 基线 DDPO 表现更好,同时享受一致性模型的快速训练和推理时间优势。最后,我们提供了微调模型的定性结果,并测试了它们的下游泛化能力。

仍然有一些未探索的方向,我们将其留给未来的工作。特别是,所提出的特定策略梯度方法使用了稀疏奖励。可以考虑使用密集奖励,利用一致性模型始终预测到 x0 的属性。另一个未来的方向是创建一个进一步强化一致性属性的损失,进一步提高 RLCM 策略的推理时间能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/583597.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024/4/29 英语每日一段

Many have turned to cheaper, hand-rolled tobacco instead of normal cigarettes, with young women telling The Times that the habit was a social way to get rid of “anxious energy”. The news comes as the government voted on Tuesday to phase out smoking in Br…

RCE复习(ctfhub下)

先了解一下命令注入的知识点: 知识点 1、常见的拼接符 A ; B 先执行A,再执行BA & B 简单的拼接A | B 显示B的执行结果A&&B A执行成功之后才会执行BA || B A执行失败之后才会执行B , 在特殊情况下可代替空格…

pytorch 实现语义分割 PSPNet

语意分割是指一张图片上包含多个物体,通过语义分割可以识别物体分类、物体名称、像素识别的任务。和物体检测不同,他不会将物体框出来,而是根据像素的归属把物体标注出来。PSPNet 的输入是一张图片,例如300500,那么输出…

Redis基本數據結構 ― List

Redis基本數據結構 ― List 介紹常用命令範例1. 將元素推入List中2. 取得List內容3. 彈出元素 介紹 Redis中的List結構是一個雙向鏈表。 LPUSH LPOP StackLPUSH RPOP QueueLPUSH BRPOP Queue(消息隊列) 常用命令 命令功能LPUSH將元素推入列表左端RPUSH將元素推入列表右…

特别推荐一个学习开发编程的网站

http://www.somecore.cn/ 为开发人员提供一系列好看的技术备忘单,方便开发过程中速查基本语法、快捷键、命令,节省查找时间,提高开发效率。 【人生苦短,抓住重点】

Java 面向对象—重载和重写/覆盖(面试)

重载和重写/覆盖: 重载(overload): Java重载是发生在本类中的,允许同一个类中,有多个同名方法存在,方法名可以相同,方法参数的个数和类型不同,即要求形参列表不一致。重载…

有趣的 CSS 图标整合技术!sprites精灵图,css贴图定位

你好,我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生,一枚程序媛,感谢关注。回复 “前端基础题”,可免费获得前端基础 100 题汇总,回复 “前端工具”,可获取 Web 开发工具合…

【C语言进阶】程序编译中的预处理操作

📚作者简介:爱编程的小马,正在学习C/C,Linux及MySQL.. 📚以后会将数据结构收录为一个系列,敬请期待 ● 本期内容讲解C语言中程序预处理要做的事情 目录 1.1 预处理符号 1.2 #define 1.2.1 #define定义标识…

数据结构(01)——链表OJ

目录 移除链表元素 思路1 不创建虚拟头节点 思路2 创建虚拟头节点 反转链表 寻找链表中间节点 判断链表是否相交 回文链表 环形链表 环形链表|| 移除链表元素 . - 力扣(LeetCode) 要想移除链表的元素,那么只需要将目标节点的前一…

07_for循环返回值while循环

文章目录 1.循环返回值2.yield接收for返回值3.scala调用yield方法创建线程对象4.scala中的while循环5.scala中的流程控制 1.循环返回值 for循环返回值是Unit 原因是防止产生歧义; 2.yield接收for返回值 // 2.yield关键字打破循环,可以使for循环输出…

智慧农业设备——虫情监测系统

随着科技的不断进步和农业生产的日益现代化,智慧农业成为了新时代农业发展的重要方向。其中,虫情监测系统作为智慧农业的重要组成部分,正逐渐受到广大农户和农业专家的关注。 虫情监测系统是一种基于现代传感技术、图像识别技术和大数据分析技…

面试笔记——线程池

线程池的核心参数&#xff08;原理&#xff09; public ThreadPoolExecutor(int corePoolSize,int maximumPoolSize,long keepAliveTime,TimeUnit unit,BlockingQueue<Runnable> workQueue,ThreadFactory threadFactory,RejectedExecutionHandler handler)corePoolSize …

25计算机考研院校数据分析 | 四川大学

四川大学(Sichuan University)简称“川大”&#xff0c;由中华人民共和国教育部直属&#xff0c;中央直管副部级建制&#xff0c;是世界一流大学建设高校、985工程”、"211工程"重点建设的高水平综合性全国重点大学&#xff0c;入选”2011计划"、"珠峰计划…

PostgreSQL的学习心得和知识总结(一百四十)|深入理解PostgreSQL数据库 psql工具 \set 变量内部及HOOK机制

目录结构 注&#xff1a;提前言明 本文借鉴了以下博主、书籍或网站的内容&#xff0c;其列表如下&#xff1a; 1、参考书籍&#xff1a;《PostgreSQL数据库内核分析》 2、参考书籍&#xff1a;《数据库事务处理的艺术&#xff1a;事务管理与并发控制》 3、PostgreSQL数据库仓库…

【能力展现】魔改ZXING源码实现商业级DM码检测能力

学习《OpenCV应用开发&#xff1a;入门、进阶与工程化实践》一书 做真正的OpenCV开发者&#xff0c;从入门到入职&#xff0c;一步到位&#xff01; 什么是DM码 dataMatrix是一种二维码&#xff0c;原名datacode&#xff0c;由美国国际资料公司于1989年发明。dataMatrix二维码…

GuildFi升级为Zentry的背后 链游公会的探索与转型

​链游即区块链游戏&#xff0c;指依托区块链技术构建的游戏产品。其与传统游戏的最大区别在于区块链的去中心化特性对玩家的资产有着天然的确权行为&#xff0c;因此玩家在链游中的资产是作为玩家的个人资产存在。较于 GameFi 来说&#xff0c;链游的包含范围更大&#xff0c;…

吴恩达机器学习笔记:第 8 周-14降维(Dimensionality Reduction) 14.3-14.5

目录 第 8 周 14、 降维(Dimensionality Reduction)14.3 主成分分析问题14.4 主成分分析算法14.5 选择主成分的数量 第 8 周 14、 降维(Dimensionality Reduction) 14.3 主成分分析问题 主成分分析(PCA)是最常见的降维算法。 在 PCA 中&#xff0c;我们要做的是找到一个方向…

【高校科研前沿】华东师大白开旭教授博士研究生李珂为一作在RSE发表团队最新成果:基于波谱特征优化的全球大气甲烷智能反演技术

文章简介 论文名称&#xff1a;Developing unbiased estimation of atmospheric methane via machine learning and multiobjective programming based on TROPOMI and GOSAT data&#xff08;基于TROPOMI和GOSAT数据&#xff0c;通过机器学习和多目标规划实现大气甲烷的无偏估…

OS复习笔记ch5-1

引言 讲解完进程和线程之后&#xff0c;我们就要来到进程的并发控制这里&#xff0c;这一章和下一章是考试喜欢考察的点&#xff0c;有可能会出大题&#xff0c;面试也有可能会被频繁问到&#xff0c;所以章节内容较多。请小伙伴们慢慢食用&#xff0c;看完之后多思考加强消化…

【JPE】顶刊测算-工业智能化数据(附stata代码)

数据来源&#xff1a;国家TJ局、CEC2008、IFR数据 时间跨度&#xff1a;2006-2019年 数据范围&#xff1a;各省、地级市 数据指标&#xff1a; 本数据集展示了2006-2019年各省、各地级市的共工业智能化水平的数据。本数据集包含三种构建工业机器人密度来反映工业智能化水平的方…
最新文章