Keras多线程机制与flask多线程冲突的解决方案

在使用flask部署Keras,tensorflow等框架时候,经常出现

或者

使用keras.backend.clear_session()可能会导致前后两处预测结果不一样,因为图发生了变化。以下是解决方案。

?

1

2

3

4

5

6

7

8

9

10

graph = tf.get_default_graph()

sess = tf.Session(graph=graph)

def modelpredict(content):

#keras.backend.clear_session()

global graph

global sess

with sess.as_default():

with graph.as_default():

keras.model.predict()

补充:Flask与keras结合的几个常见错误

在Flask中使用tensorflow的model,一在界面中调用 model.predict() 就报下面这个错误,不过在单独的 .py 文件中使用却不报错。

添加如下代码可以解决:

?

1

2

3

4

5

6

7

8

9

10

import tensorflow as tf

graph = tf.get_default_graph()

model = models.load_model(…………)

# 使用处添加:

global graph

global model

with graph.as_default():

model.predict()

# 执行预测函数

但是我当时测试时又报了另一个bug,但是这个bug也不好解决,试了很多方法也没解决,当然最终还是可以解决的,具体解决方式参考第三点。

后来经过N遍测试后找到了以下两种解决方式,仅供参考:

方法一:

在调用前加载model和graph,但是这样会导致程序每次调用都需要重新加载model,然后运行速度就会很慢,不过这种修改方式是最简单的。

?

1

2

3

4

graph = tf.get_default_graph()

model = models.load_model('./static/my_model2.h5')

with graph.as_default():

result = model.predict(tokens_pad)

方法二:

在创建model后,先使用一遍 model.predict(),参数的大小和真实大小一致,这个是真正解决之道,同时不影响使用速率。

?

1

2

3

4

5

6

7

8

9

# 使用前:

model = models.load_model('./static/my_model2.h5')

# a 矩阵大小和 tokens_pad 一致

a = np.ones((1, 220))

model.predict(a)

# 使用时:

global model

result = model.predict(tokens_pad)

但是在使用后又遇到了 The Session graph is empty…… 的错误即第二点,不过估摸着这个是个例,应该是程序问题。

?

1

2

3

4

graph = tf.get_default_graph()

with graph.as_default():

# 相关代码

# 本次测试中是需要把调用包含model.predict()方法的方法的代码放到这里

这个错误呢,也是TensorFlow和Flask结合使用时的常见错误,解决方式如下:

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

from tensorflow.python.keras.backend import set_session

# 程序开始时声明

sess = tf.Session()

graph = tf.get_default_graph()

# 在model加载前添加set_session

set_session(sess)

model = models.load_model(…………)

# 每次使用有关TensorFlow的请求时

# in each request (i.e. in each thread):

global sess

global graph

with graph.as_default():

set_session(sess)

model.predict(...)

————————————————

设置一下XLA_FLAGS指向你的cuda安装目录即可

?

1

os.environ["XLA_FLAGS"]="--xla_gpu_cuda_data_dir=/usr/local/cuda-10.0"

以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/weixin_40939578/article/details/100154100

本文链接:https://my.lmcjl.com/post/16917.html

展开阅读全文

4 评论

留下您的评论.