网易首页 > 网易号 > 正文 申请入驻

简单介绍tensorflow2 自定义损失函数使用的隐藏坑

0
分享至

本文主要介绍了tensorflow2 自定义损失函数使用的隐藏坑,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制。当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型。

这里以tensorflow2官网中的例子来说明:

import numpy as np import tensorflow as tf from tensorflow import keras x = np.random.random((1000, 32)) y = np.random.random((1000, 1)) class CustomModel(keras.Model): tf.random.set_seed(100) def train_step(self, data): # Unpack the data. Its structure depends on your model and # on what you pass to `fit()`. x, y = data with tf.GradientTape() as tape: y_pred = self(x, training=True) # Forward pass # Compute the loss value # (the loss function is configured in `compile()`) loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) # Compute gradients trainable_vars = self.trainable_variables gradients = tape.gradient(loss, trainable_vars) # Update weights self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # Update metrics (includes the metric that tracks the loss) self.compiled_metrics.update_state(y, y_pred) # Return a dict mapping metric names to current value return {m.name: m.result() for m in self.metrics} # Construct and compile an instance of CustomModel inputs = keras.Input(shape=(32,)) outputs = keras.layers.Dense(1)(inputs) model = CustomModel(inputs, outputs) model.compile(optimizer="adam", loss=tf.losses.MSE, metrics=["mae"]) # Just use `fit` as usual model.fit(x, y, epochs=1, shuffle=False) 32/32 [==============================] - 0s 1ms/step - loss: 0.2783 - mae: 0.4257

这里的loss是tensorflow库中实现了的损失函数,如果想自定义损失函数,然后将损失函数传入model.compile中,能正常按我们预想的work吗?

答案竟然是否定的,而且没有错误提示,只是loss计算不会符合我们的预期。

def custom_mse(y_true, y_pred): return tf.reduce_mean((y_true - y_pred)**2, axis=-1) a_true = tf.constant([1., 1.5, 1.2]) a_pred = tf.constant([1., 2, 1.5]) custom_mse(a_true, a_pred)tf.losses.MSE(a_true, a_pred)

以上结果证实了我们自定义loss的正确性,下面我们直接将自定义的loss置入compile中的loss参数中,看看会发生什么。

my_model = CustomModel(inputs, outputs) my_model.compile(optimizer="adam", loss=custom_mse, metrics=["mae"]) my_model.fit(x, y, epochs=1, shuffle=False) 32/32 [==============================] - 0s 820us/step - loss: 0.1628 - mae: 0.3257

我们看到,这里的loss与我们与标准的tf.losses.MSE明显不同。这说明我们自定义的loss以这种方式直接传递进model.compile中,是完全错误的操作。

正确运用自定义loss的姿势是什么呢?下面揭晓。

loss_tracker = keras.metrics.Mean(name="loss") mae_metric = keras.metrics.MeanAbsoluteError(name="mae") class MyCustomModel(keras.Model): tf.random.set_seed(100) def train_step(self, data): # Unpack the data. Its structure depends on your model and # on what you pass to `fit()`. x, y = data with tf.GradientTape() as tape: y_pred = self(x, training=True) # Forward pass # Compute the loss value # (the loss function is configured in `compile()`) loss = custom_mse(y, y_pred) # loss += self.losses # Compute gradients trainable_vars = self.trainable_variables gradients = tape.gradient(loss, trainable_vars) # Update weights self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # Compute our own metrics loss_tracker.update_state(loss) mae_metric.update_state(y, y_pred) return {"loss": loss_tracker.result(), "mae": mae_metric.result()} @property def metrics(self): # We list our `Metric` objects here so that `reset_states()` can be # called automatically at the start of each epoch # or at the start of `evaluate()`. # If you don't implement this property, you have to call # `reset_states()` yourself at the time of your choosing. return [loss_tracker, mae_metric] # Construct and compile an instance of CustomModel inputs = keras.Input(shape=(32,)) outputs = keras.layers.Dense(1)(inputs) my_model_beta = MyCustomModel(inputs, outputs) my_model_beta.compile(optimizer="adam") # Just use `fit` as usual my_model_beta.fit(x, y, epochs=1, shuffle=False) 32/32 [==============================] - 0s 960us/step - loss: 0.2783 - mae: 0.4257


终于,通过跳过在 compile() 中传递损失函数,而在 train_step 中手动完成所有计算内容,我们获得了与之前默认tf.losses.MSE完全一致的输出,这才是我们想要的结果。

总结

当我们在模型中想用自定义的损失函数,不能直接传入fit函数,而是需要在train_step中手动传入,完成计算过程。到此这篇关于tensorflow2 自定义损失函数使用的隐藏坑的文章就介绍到这了。

原文来自:https://www.jb51.net/article/218240.htm
本文地址:https://www.linuxprobe.com/tensorflow-linux-five.html
Linux命令大全:https://www.linuxcool.com/

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
一直以为退圈,原来是去《墨雨云间》给顶流演爸妈了,美人迟暮啊

一直以为退圈,原来是去《墨雨云间》给顶流演爸妈了,美人迟暮啊

兰子记
2024-06-16 20:08:32
倪妮被偶遇造型惨不忍睹,小腹隆起像怀孕,皮肤很黄显老仪态变差

倪妮被偶遇造型惨不忍睹,小腹隆起像怀孕,皮肤很黄显老仪态变差

娱乐白名单
2024-06-16 13:36:32
完美收官!中国女排迎三大好消息,1/4决赛对手日本迎来复仇机会

完美收官!中国女排迎三大好消息,1/4决赛对手日本迎来复仇机会

时刻侃体坛
2024-06-16 23:09:23
归化专家:目前全球活跃在各级联赛U10队伍以上华裔球员近200位

归化专家:目前全球活跃在各级联赛U10队伍以上华裔球员近200位

直播吧
2024-06-16 16:41:13
唉!又有一家大企业成功“结业”了!

唉!又有一家大企业成功“结业”了!

翻开历史和现实
2024-06-10 18:54:33
利空来袭!A股明天惨了?大盘3000点不保?

利空来袭!A股明天惨了?大盘3000点不保?

惜别的海岸
2024-06-16 23:03:27
退休金待遇比较高的人,一般都是什么类型的人呢?

退休金待遇比较高的人,一般都是什么类型的人呢?

社保小达人
2024-06-08 12:57:46
解放战争中,如果国民党获得胜利,今天的中国会是什么样

解放战争中,如果国民党获得胜利,今天的中国会是什么样

史诗长歌
2024-05-13 13:34:32
姜萍父亲发声,家庭困难住着烂房子,姐妹都是学霸,刘奔爆笑回应

姜萍父亲发声,家庭困难住着烂房子,姐妹都是学霸,刘奔爆笑回应

兰子记
2024-06-15 21:53:08
伦巴说:今天六子在叔叔那直接都吓怂了

伦巴说:今天六子在叔叔那直接都吓怂了

综艺拼盘汇
2024-06-16 23:31:05
中超半程积分榜:申花多赛1场仅1分优势领跑 8队陷入保级大乱斗

中超半程积分榜:申花多赛1场仅1分优势领跑 8队陷入保级大乱斗

直播吧
2024-06-16 22:32:10
传媒湃|中央政法委机关报《法治日报》社长、总编辑调整

传媒湃|中央政法委机关报《法治日报》社长、总编辑调整

澎湃新闻
2024-06-16 19:48:28
上海申花vs成都蓉城球员评分:罗慕洛9.2分,费利佩5.8分

上海申花vs成都蓉城球员评分:罗慕洛9.2分,费利佩5.8分

懂球帝
2024-06-16 22:10:14
囤农夫山泉日入3万,囤娃哈哈赔得精光…情怀终究败给了营销吗?

囤农夫山泉日入3万,囤娃哈哈赔得精光…情怀终究败给了营销吗?

我不叫阿哏
2024-06-16 23:07:28
儿媳照顾50岁农村公公,酒后公公行夫妻之事,公公:儿媳经验丰富

儿媳照顾50岁农村公公,酒后公公行夫妻之事,公公:儿媳经验丰富

魅老八足球
2024-05-13 13:49:37
回顾:湖北女子带娃回家,被前夫和现任打,一脚干翻俩:真解气

回顾:湖北女子带娃回家,被前夫和现任打,一脚干翻俩:真解气

佑宛故事汇
2024-06-15 17:12:07
A股这27家企业跌到历史底部了还要减持3%,就是一个局,赶紧远离

A股这27家企业跌到历史底部了还要减持3%,就是一个局,赶紧远离

股海风云大作手
2024-06-16 17:36:07
中国人不骗中国人,TikTok上建盏开窑直播间忽悠老外,评论笑死

中国人不骗中国人,TikTok上建盏开窑直播间忽悠老外,评论笑死

猫小狸同学
2024-06-16 17:42:44
[神吐槽]跟这些人物排一起,穆雷挺牛逼的!这两人的眼神挺像,带着一股清澈的愚蠢

[神吐槽]跟这些人物排一起,穆雷挺牛逼的!这两人的眼神挺像,带着一股清澈的愚蠢

篮球神吐槽
2024-06-17 00:10:21
无语!中国女排四连胜之夜,香港极端球迷拉波兰国旗搞事,败人品

无语!中国女排四连胜之夜,香港极端球迷拉波兰国旗搞事,败人品

二哥聊球
2024-06-16 22:37:03
2024-06-17 00:30:44
孙有匪
孙有匪
科技
1595文章数 2009关注度
往期回顾 全部

科技要闻

iPhone 16会杀死大模型APP吗?

头条要闻

欧洲猪肉业界:中国若限制进口将是梦魇

头条要闻

欧洲猪肉业界:中国若限制进口将是梦魇

体育要闻

没人永远年轻 但青春如此无敌还是离谱了些

娱乐要闻

上影节红毯:倪妮好松弛,娜扎吸睛

财经要闻

打断妻子多根肋骨 上市公司创始人被公诉

汽车要闻

售17.68万-21.68万元 极狐阿尔法S5正式上市

态度原创

旅游
房产
亲子
教育
公开课

旅游要闻

@毕业生,江苏这些景区可享免票或优惠

房产要闻

万华对面!海口今年首宗超百亩宅地,重磅挂出!

亲子要闻

玩这个游戏的都是勇士

教育要闻

有一类中考必考,分值不低,形式多样的物理题!你能满分吗?

公开课

近视只是视力差?小心并发症

无障碍浏览 进入关怀版