不见春山
骑马倚斜桥,满楼红袖招。
Home
Categories
Archives
Tags
About
Home
Tensorflow 调用 Matlab 生成的 .mat 文件
Tensorflow 调用 Matlab 生成的 .mat 文件
取消
Tensorflow 调用 Matlab 生成的 .mat 文件
由
ctaoist
发布于 2021-10-22
·
最后更新:2022-05-12
1
matlab 处理好的数据想送进 Tensorflow 的神经网络中,在数据量极大的时候,全部读进内存也不是太理想,综合 `tfrecord`,自己构建 `tfds` 数据集等方式,还是生成 `tf.data.Dataset` 会更便捷一些。 在数据量极大的时候,期望的是 Tensorflow 在使用数据的时候才读取相应的数据,则要求Matlab保存数据的时候不能将所有数据保存到一个大的 `.mat` 文件,而是应该分开保存: ``` path_to_data └───────0.mat 1.mat 2.mat ... ``` 如果保存的`.mat`文件的版本低于`v7.3`,可以用`scipy.io.laodmat`来读取,反之可以使用`mat73`这个库来读取文件: ![Matlab_mat_7.3版](http://blog.qiniu.ctaoist.cn/Matlab_mat_7.3版本.png) 代码如下: ```py mat_paths = ["path_to_data/0.mat", "path_to_data/1.mat", ...] label = [] # 定义读取数据的函数 # https://www.tensorflow.org/tutorials/load_data/images?hl=zh-cn#%E6%9E%84%E5%BB%BA%E4%B8%80%E4%B8%AA_tfdatadataset def load_mat(path, label): return np.expand_dims(mat73.loadmat(path)['data_xx'].astype(np.float32), axis = -1), label #return mat73.loadmat(path)['data_xx'].astype(np.float32) # path 传递进来的是张量,需要用 tf.numpy_function/tf.py_function 包一下,方便取出张量中的值,在训练时不支持用 .numpy()方法 def tf_env_load_mat(path: tf.Tensor, label): fn = tf.numpy_function(load_mat, [path, label], [tf.float32, tf.int32]) # 经由 tf.numpy_function 处理后的函数不能自动识别shape,需要手动指定 fn[0].set_shape((400,400,1)) fn[1].set_shape(()) return fn ``` 生成数据集: ```py # 首先生成 path 的 Dataset path_label_ds = tf.data.Dataset.from_tensor_slices((mat_paths, label)) # 生成数据的 Dataset path_label_ds = path_label_ds.shuffle(1000) # shuffle 放在 map 之前,否则shuffle会非常慢 mat_label_ds = path_label_ds.map(tf_env_load_mat, num_parallel_calls=tf.data.experimental.AUTOTUNE) ``` 后续的一些处理: ```py mat_label_ds = mat_label_ds.batch(batch_size) # 当模型在训练的时候,`prefetch` 使数据集在后台取得 batch。 ds_train = ds_train.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) ``` >注: 调用`.batch()`后,数据的`shape`才能对得上,形如: ` <PrefetchDataset shapes: ((None, 400, 400, 1), (None,)), types: (tf.float32, tf.int32)> `
机器学习
该博客文章由作者通过
CC BY 4.0
进行授权。
分享
最近更新
群晖升级 ARPL 笔记
本地部署大语言模型
LVM 管理
HK1 RBOX X4 电视盒子折腾笔记
使用usbip网络转发usb设备到远程主机
热门标签
机器学习
Linux
Router
ROS
Tensorflow
VPN
虚拟组网
ARM
Latex
zerotier
文章目录
Tinc 搭建教程
残差网络(ResNet)