1. 检测gpu 

pip install tensorflow-gpu==2.5
import tensorflow as tf

# 打印TensorFlow是否使用GPU加速
print(tf.test.gpu_device_name())

2.版本问题

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

3.Keras问题

一般不升级Keras会有问题产生,需要将Keras升级到Tensorflow同样的版本。另外有些函数需要进行改变。
如:修改前(此图是安装完2.x高版本的tensorflow-gpu后,改回原来的代码,SGD报错原因——keras在tensorflow里面,低版本的keras与高版本的tensorflow不兼容,产生报错)

安装tensorflow-gpu,修改后:

 4.程序中调用GPU

 注意:原本代码不含蓝色一块,由于调用gpu的原因添加,然后可成功调用服务器的gpu(附代码)

import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)


更多推荐