不见春山
骑马倚斜桥,满楼红袖招。
Home
Categories
Archives
Tags
About
Home
Tensorflow 调用 Matlab 生成的 .mat 文件
Tensorflow 调用 Matlab 生成的 .mat 文件
取消
Tensorflow 调用 Matlab 生成的 .mat 文件
由
ctaoist
发布于 2021-10-22
·
最后更新:2022-05-12
1
matlab 处理好的数据想送进 Tensorflow 的神经网络中,在数据量极大的时候,全部读进内存也不是太理想,综合 `tfrecord`,自己构建 `tfds` 数据集等方式,还是生成 `tf.data.Dataset` 会更便捷一些。 在数据量极大的时候,期望的是 Tensorflow 在使用数据的时候才读取相应的数据,则要求Matlab保存数据的时候不能将所有数据保存到一个大的 `.mat` 文件,而是应该分开保存: ``` path_to_data └───────0.mat 1.mat 2.mat ... ``` 如果保存的`.mat`文件的版本低于`v7.3`,可以用`scipy.io.laodmat`来读取,反之可以使用`mat73`这个库来读取文件: ![Matlab_mat_7.3版](http://blog.qiniu.ctaoist.cn/Matlab_mat_7.3版本.png) 代码如下: ```py mat_paths = ["path_to_data/0.mat", "path_to_data/1.mat", ...] label = [] # 定义读取数据的函数 # https://www.tensorflow.org/tutorials/load_data/images?hl=zh-cn#%E6%9E%84%E5%BB%BA%E4%B8%80%E4%B8%AA_tfdatadataset def load_mat(path, label): return np.expand_dims(mat73.loadmat(path)['data_xx'].astype(np.float32), axis = -1), label #return mat73.loadmat(path)['data_xx'].astype(np.float32) # path 传递进来的是张量,需要用 tf.numpy_function/tf.py_function 包一下,方便取出张量中的值,在训练时不支持用 .numpy()方法 def tf_env_load_mat(path: tf.Tensor, label): fn = tf.numpy_function(load_mat, [path, label], [tf.float32, tf.int32]) # 经由 tf.numpy_function 处理后的函数不能自动识别shape,需要手动指定 fn[0].set_shape((400,400,1)) fn[1].set_shape(()) return fn ``` 生成数据集: ```py # 首先生成 path 的 Dataset path_label_ds = tf.data.Dataset.from_tensor_slices((mat_paths, label)) # 生成数据的 Dataset path_label_ds = path_label_ds.shuffle(1000) # shuffle 放在 map 之前,否则shuffle会非常慢 mat_label_ds = path_label_ds.map(tf_env_load_mat, num_parallel_calls=tf.data.experimental.AUTOTUNE) ``` 后续的一些处理: ```py mat_label_ds = mat_label_ds.batch(batch_size) # 当模型在训练的时候,`prefetch` 使数据集在后台取得 batch。 ds_train = ds_train.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) ``` >注: 调用`.batch()`后,数据的`shape`才能对得上,形如: ` <PrefetchDataset shapes: ((None, 400, 400, 1), (None,)), types: (tf.float32, tf.int32)> `
机器学习
该博客文章由作者通过
CC BY 4.0
进行授权。
分享
最近更新
群晖升级 ARPL 笔记
本地部署大语言模型
WireGuard 搭建组网教程
LVM 管理
HK1 RBOX X4 电视盒子折腾笔记
热门标签
机器学习
Tensorflow
Linux
VPN
虚拟组网
Router
ROS
嵌入式
C++
C
文章目录
Tinc 搭建教程
残差网络(ResNet)