网易首页 > 网易号 > 正文 申请入驻

Lag-Llama:第一个时间序列预测的开源基础模型介绍和性能测试

0
分享至

2023年10月,我们发表了一篇关于TimeGPT的文章,TimeGPT是时间序列预测的第一个基础模型之一,具有零样本推理、异常检测和共形预测能力。

虽然TimeGPT是一个专有模型,只能通过API访问。但是它还是引发了对时间序列基础模型的更多研究。到了2024年2月,已经有了一个用于时间序列预测的开源基础模型:laglllama。

在原论文《Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting》中,模型作为单变量概率预测的通用基础模型提出。它是由来自不同机构的大型团队开发的,这些机构包括Morgan Stanley, ServiceNow, Université de Montréal, Mila-Quebec, 和McGill University.

在本文中,我们将探讨Lag-Llama的架构、功能以及训练方式。还会将lagllama应用于一个预测项目中,并将其与其他深度学习方法Temporal Fusion Transformer (TFT) 和DeepAR进行性能比较。

Lag-Llama

lagllama是为单变量概率预测而构建的。它使用不依赖于频率的通用方法来标记时间序列数据。这样模型可以很好地推广到不可见的频率。

它利用Transformer体系结构和分布头来解析输入令牌,并将它们映射到具有置信区间的未来预测。

1、具有滞后特征的标记

laglllama的标记策略是使用一组指定的滞后来构造序列的滞后特征。

它将从这个列表中为给定的数据集选择所有合适的频率:

季度、月、周、天、小时、秒

也就是说,如果以每日频率提供数据集,lag - llama将尝试使用每日滞后(t-1),每周滞后(t-7),每月滞后(t-30)等构建特征。

策略如下图所示。

从上图中,我们还可以看到模型构建了其他静态协变量,例如秒/分、小时/天等等,直到季度/年。虽然这可以很好地推广到所有类型的时间序列,但它有一个致命的缺点:由于固定的滞后指数列表,输入令牌可能会变得非常大。

例如,查看每小时数据的每月频率需要730个时间步。这意味着除了所有静态协变量之外,输入令牌的长度至少为730。

2、Lag-Llama架构

Lag-Llama是一个基于transformer的纯解码器模型,其灵感来自大型语言模型LLaMA的体系结构。

架构的示意图如下所示。

从图中可以看到输入标记是滞后时间步长和静态协变量的拼接。输入序列通过线性投影层将特征映射到解码器内部注意力模块的隐藏维度。另外就是在最后的输出,序列被发送到一个分布头负责输出一个概率分布。

在推理过程中,输入序列生成下一个时间点的分布。然后通过自回归,模型逐个生成剩余的预测序列,直到达到设置的长度。

生成预测的自回归过程有效地允许模型为其预测生成不确定性区间。但是这里的问题就是如果序列很长,自回归的方式会将错误扩大。

3、Lag-Llama分布头

Lag-Llama的分布头负责输出概率分布。这样模型就能够生成预测区间。

在模型的迭代中,最后一层使用Student 's t分布来构造不确定性区间。从理论上讲不同的分布头可以组合在一起,但是论文并没有做这样的实验,可能是想在以后在做吧。

4、Lag-Llama的训练

作为一个基础模型,Lag-Llama显然是在大量的时间序列数据语料库上训练的,因此该模型可以很好地泛化未见过的时间序列并进行零样本预测。

论文中说:Lag-Llama在来自不同领域的27个时间序列数据集上进行了训练,如能源、交通、经济等。

数据包含7965个单变量时间序列,总计约3.52亿个令牌。

所有数据集都是开源的,包括ethth, Exchange和Weather等。

Lag-Llama测试

因为代码已经开源,所以我们可以直接测试,我们首先使用Lag-Llama的零样本预测能力,并将其性能与特定数据模型(如TFT和DeepAR)进行比较。

Lag-Llama的实现是建立在GluonTS之上的,所以我们还需要安装这个库。实验使用了澳大利亚电力需求数据集,该数据集包含五个单变量时间序列,以半小时的频率跟踪能源需求。

这里有个说明:Lag-Llama目前的实现是初期阶段。并且存还在积极开发中,后面可能还会有很大的调整,因为目前还没加入微调的功能。

1、环境设置

!git clone https://github.com/time-series-foundation-models/lag-llama/
cd lag-llama
pip install -r requirements.txt --quiet

然后需要我们从HuggingFace下载模型的权重。

!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama

2、加载数据集

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import torch
from itertools import islice
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from lag_llama.gluon.estimator import LagLlamaEstimator

可以直接从GluonTS加载数据集。

dataset = get_dataset("australian_electricity_demand")
backtest_dataset = dataset.test prediction_length = dataset.metadata.prediction_length
context_length = 3 * prediction_length

3、使用Lag-Llama预测

简单地初始化模型并使用LagLlamaEstimator对象。

ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0'))
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
estimator = LagLlamaEstimator( ckpt_path="lag-llama.ckpt",
prediction_length=prediction_length,
context_length=context_length,
input_size=estimator_args["input_size"],
n_layer=estimator_args["n_layer"],
n_embd_per_head=estimator_args["n_embd_per_head"],
n_head=estimator_args["n_head"],
scaling=estimator_args["scaling"],
time_feat=estimator_args["time_feat"])
lightning_module = estimator.create_lightning_module()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(transformation, lightning_module)

使用make_evaluation_predictions函数生成零样本的预测。

forecast_it, ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=predictor)

这个函数返回生成器。我们需要把它们转换成列表。

forecasts = list(forecast_it)
tss = list(ts_it)

4、评估

GluonTS可以使用Evaluator对象方便地计算不同的性能指标。

evaluator = Evaluator()
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))

RMSE为481.57。

我们还可以随意地将预测可视化。

plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})
for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 4):
ax = plt.subplot(2, 2, idx+1)
plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target")
forecast.plot( color='g')
plt.xticks(rotation=60)
ax.xaxis.set_major_formatter(date_formater)
ax.set_title(forecast.item_id)
plt.gcf().tight_layout()
plt.legend()
plt.show()

上图可以看到模型对数据做出了合理的预测,尽管它在第四个序列(图的右下角)上确实存在问题。

另外由于 Lag-Llama实现了概率预测,可以得到预测的不确定性区间。

5、与TFT和DeepAR相比

我们在数据集上训练TFT和DeepAR模型,看看它们是否能表现得更好。

为了节省时间,我们将训练设置为5个epoch。

from gluonts.torch import TemporalFusionTransformerEstimator, DeepAREstimator
tft_estimator = TemporalFusionTransformerEstimator(
prediction_length=prediction_length,
context_length=context_length,
freq="30min",
trainer_kwargs={"max_epochs": 5})
deepar_estimator = DeepAREstimator(
prediction_length=prediction_length,
context_length=context_length,
freq="30min",
trainer_kwargs={"max_epochs": 5})

训练过程。

tft_predictor = tft_estimator.train(dataset.train)
deepar_predictor = deepar_estimator.train(dataset.train)

训练完成后,生成预测并计算RMSE。


tft_forecast_it, tft_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=tft_predictor)
deepar_forecast_it, deepar_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=deepar_predictor)
tft_forecasts = list(tft_forecast_it)
tft_tss = list(tft_ts_it)
deepar_forecasts = list(deepar_forecast_it)
deepar_tss = list(deepar_ts_it)
# Get evaluation metrics
tft_agg_metrics, tft_ts_metrics = evaluator(iter(tft_tss), iter(tft_forecasts))
deepar_agg_metrics, deepar_ts_metrics = evaluator(iter(deepar_tss), iter(deepar_forecasts))

下表突出显示了性能最好的模型。

可以看到TFT是目前表现最好的模型,DeepAR的表现也优于laglama。

虽然laglllama的表现似乎不尽如人意,但该模型没有经过微调,而且零样本测本身就比较困难。

有趣的是,只训练了5个epoch这两个模型都取得了比Lag-Llama更好的结果。虽然样本预测可以节省时间,但训练五个epoch在时间和计算能力方面的要求应该不是很苛刻。所以目前可能零样本学习方面还需要很大的提升。

总结

在尝试了TimeGPT和Lag-Llama之后,Lag-Llama算是构建开源预测模型的第一步,但与TimeGPT相比,它在功能方面存在不足。

TimeGPT可以处理多变量时间序列、不规则时间戳,并实现共形预测,与使用laglama等固定分布相比,这是一种更稳健的量化不确定性的方式。

laglllama是一个开源的基础模型,只用于单变量概率预测,并且我觉得它训练的数据有点少了。我相信在不久的将来会看到更多的开源预测模型出现。他们的表现可能会得到改善,这代表了该领域的一个重大转变。

最后论文地址:

Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting by K. Rasul, A. Ashok, A. Williams, H. Ghonia, R. Bhagwatkar, A. Khorasani, M. Bayazi, G. Adamopoulos, R. Riachi, N. Hassen, M. Bilos, S. Garg, A. Schneider, N. Chapados, A. Drouin, V. Zantedeschi, Y. Nevmyvaka, I. Rish

https://avoid.overfit.cn/post/8a9120d3cf074c1ba0de0a7a247993c9

作者:Marco Peixeiro

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
她和撒贝宁同居多年,却转身投入富豪怀抱,没想到现在竟沦落至此

她和撒贝宁同居多年,却转身投入富豪怀抱,没想到现在竟沦落至此

民间平哥
2026-01-02 00:25:44
腿粗、胯宽、屁股大的女生怎么穿?吊带背心搭配瑜伽裤,藏肉显瘦

腿粗、胯宽、屁股大的女生怎么穿?吊带背心搭配瑜伽裤,藏肉显瘦

灼灼小齐
2026-01-01 18:51:07
周深一晚“赶”5场跨年演唱会,创下了艺人单日跨年曝光纪录:以不同形式在五台跨年晚会中亮相

周深一晚“赶”5场跨年演唱会,创下了艺人单日跨年曝光纪录:以不同形式在五台跨年晚会中亮相

台州交通广播
2026-01-01 21:27:51
一路走好!距离跨年仅1天,就传来3位名人离世消息,最小才51岁

一路走好!距离跨年仅1天,就传来3位名人离世消息,最小才51岁

社会日日鲜
2026-01-01 14:27:40
一心想骑在人民头上作威作福?“全职考公”这种歪风应当遏制

一心想骑在人民头上作威作福?“全职考公”这种歪风应当遏制

北欧模式
2026-01-01 21:32:49
长春"威哥"后续:已被拘留,知情人曝身份,坑人不是一天两天了

长春"威哥"后续:已被拘留,知情人曝身份,坑人不是一天两天了

奇思妙想草叶君
2026-01-01 09:09:38
杜兰特超越滑翔机成火箭队史单场贡献20分10助攻的最高龄球员

杜兰特超越滑翔机成火箭队史单场贡献20分10助攻的最高龄球员

北青网-北京青年报
2026-01-02 12:34:04
“美女副区长”于媛媛,整容脸,皮肤白皙,严重违反生活纪律

“美女副区长”于媛媛,整容脸,皮肤白皙,严重违反生活纪律

李昕言温度空间
2026-01-01 22:38:33
朱德的女儿,为了生存曾经装哑四年,一生不肯原谅当了叛徒的母亲

朱德的女儿,为了生存曾经装哑四年,一生不肯原谅当了叛徒的母亲

历史龙元阁
2026-01-01 13:00:08
到了年纪就会上演“嘴唇消失术”?嘴唇的厚薄对颜值的影响好直观

到了年纪就会上演“嘴唇消失术”?嘴唇的厚薄对颜值的影响好直观

上官晚安
2026-01-02 00:27:14
新年第一天,乌克兰战场传来好消息

新年第一天,乌克兰战场传来好消息

难得君
2026-01-01 21:40:23
Manus肖弘的20个人生关键细节

Manus肖弘的20个人生关键细节

新浪财经
2025-12-31 12:46:44
泽连斯基新年贺词,相当于提前宣判了俄罗斯军事行动的失败!

泽连斯基新年贺词,相当于提前宣判了俄罗斯军事行动的失败!

娱宙观
2026-01-01 10:03:50
DeepSeek 元旦扔出王炸!CEO 梁文锋亲自署名,要动 AI 用了 10 年的“承重墙”?

DeepSeek 元旦扔出王炸!CEO 梁文锋亲自署名,要动 AI 用了 10 年的“承重墙”?

AI范儿
2026-01-01 20:56:39
车价腰斩、房企崩盘,你以为捡了便宜?其实是资本收割的开始!

车价腰斩、房企崩盘,你以为捡了便宜?其实是资本收割的开始!

流苏晚晴
2025-11-29 17:10:03
阿门-汤普森23分杜兰特22+11 火箭轻取篮网四连胜

阿门-汤普森23分杜兰特22+11 火箭轻取篮网四连胜

北青网-北京青年报
2026-01-02 12:38:05
用力过猛!51岁林志玲打扮“日系”现身上海,网友:又老又年轻

用力过猛!51岁林志玲打扮“日系”现身上海,网友:又老又年轻

完善法
2025-12-31 11:05:31
一碗水端平,中方欲向泰国提供2000万援助

一碗水端平,中方欲向泰国提供2000万援助

跟着老李看世界
2026-01-01 21:01:58
分卫界助攻天花板+控卫界得分天花板,哈登无愧于史上最强双能卫

分卫界助攻天花板+控卫界得分天花板,哈登无愧于史上最强双能卫

大眼瞄世界
2026-01-01 22:09:10
69场造49球仍被弃!31岁外援炮轰申花:不尊重我和家人,高层业余

69场造49球仍被弃!31岁外援炮轰申花:不尊重我和家人,高层业余

我爱英超
2026-01-02 07:15:08
2026-01-02 13:28:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1877文章数 1440关注度
往期回顾 全部

科技要闻

新势力年榜:零跑险胜华为,蔚来小鹏新高

头条要闻

东风-5C打击范围覆盖全球成核威慑王牌 军事专家释疑

头条要闻

东风-5C打击范围覆盖全球成核威慑王牌 军事专家释疑

体育要闻

英超离谱夜?4战全平3场0-0 曼城红军翻车

娱乐要闻

武林外传开播20年,郭芙蓉打工期结束

财经要闻

8200亿扩产潮下的锂电供应链之战

汽车要闻

奇瑞汽车12月销量超23万辆 全年超263万辆

态度原创

房产
游戏
本地
艺术
公开课

房产要闻

封关红利爆发!三亚主城大盘 2.2 万 /㎡起,性价比直接封神!

《逃离塔科夫》开年不利:服务器故障难以登陆!

本地新闻

即将过去的2025年,对重庆的影响竟然如此深远

艺术要闻

雷蒙多·德·马德拉索:定义“美丽时代”的肖像大师

公开课

李玫瑾:为什么性格比能力更重要?

无障碍浏览 进入关怀版