TensorFlow的estimator类函数:tf.estimator.Estimator
tf.estimator.Estimator函数
Estimator类
定义在:tensorflow/python/estimator/estimator.py
estimator类对TensorFlow模型进行训练和计算.
Estimator对象包装由model_fn指定的模型,其中,给定输入和其他一些参数,返回需要进行训练、计算,或预测的操作.
所有输出(检查点,事件文件等)都被写入model_dir或其子目录.如果model_dir未设置,则使用临时目录.
可以通过RunConfig对象(包含了有关执行环境的信息)传递config参数.它被传递给model_fn,如果model_fn有一个名为“config”的参数(和输入函数以相同的方式).如果该config参数未被传递,则由Estimator进行实例化.不传递配置意味着使用对本地执行有用的默认值.Estimator使配置对模型可用(例如,允许根据可用的工作人员数量进行专业化),并且还使用其一些字段来控制内部,特别是关于检查点.
该params参数包含hyperparameter,如果model_fn有一个名为“PARAMS”的参数,并且以相同的方式传递给输入函数,则将它传递给 model_fn.Estimator只是沿着参数传递,并不检查它.因此,params的结构完全取决于开发人员.
不能在子类中重写任何Estimator方法(其构造函数强制执行此操作).子类应使用model_fn来配置基类,并且可以添加实现专门功能的方法.
Eager兼容性
estimator与eager执行不兼容.
属性
- config
- model_dir
- model_fn
返回绑定到self.params的model_fn.
返回:返回具有以下签名的model_fn: def model_fn(features, labels, mode, config) - params
返回绑定到self.params的model_fn.
返回:返回具有以下签名的model_fn: def model_fn(features, labels, mode, config)
方法
__init__
__init__(
model_fn,
model_dir=None,
config=None,
params=None,
warm_start_from=None
)
构造一个Estimator实例.
请参阅Estimator了解更多信息.启动一个Estimator的方法如下所示:
estimator = tf.estimator.DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
hidden_units=[1024, 512, 256],
warm_start_from="/path/to/checkpoint/dir")
有关warm-start启动配置的更多详细信息,请参阅WarmStartSettings.
参数:
- model_fn:模型函数,具有以下签名:
- ARGS.
- features:这是从input_fn传递给train、evaluate和predict返回的第一个项目.这应该是一个相同的单一的Tensor或dict.
- labels:这是从input_fn传递给train、evaluate和predict返回的第二个项目.这应该是相同的单个Tensor或dict(对于multi-head模型).如果模式是ModeKeys.PREDICT,则将传递labels=None.如果model_fn签名不接受mode,model_fn必须仍然能够处理labels=None.
- mode:可选的.指定train、evaluate和predict.参考ModeKeys.
- params:hyperparameters的可选字典.将在params参数中接收传递给Estimator的内容.这允许从hyperparameters调整来配置Estimator.
- config:可选配置对象.将收到传递给Estimator的config参数或默认值config.允许根据配置(如num_ps_replicas或model_dir)更新您的model_fn中的内容. 返回:EstimatorSpec
- model_dir:保存模型参数、图形等的目录.这也可用于将目录中的检查点加载到Estimator中,以继续训练以前保存的模型.如果为PathLike对象,则路径将被解析.如果为None,则将使用config中的model_dir(如果设置的话).如果两者都设置,则它们必须相同.如果两者都是None,则会使用临时目录.
- config:配置对象.
- params:dict将传递到model_fn中的hyperparameters.key是参数的名称,value是基本的Python类型.
- warm_start_from:可选的字符串文件路径,用于从warm-start的检查点;或tf.estimator.WarmStartSettings对象,用于完全配置warm-start.如果提供字符串文件路径而不是WarmStartSettings,则所有变量都是warm-start的,并且假定词汇表和张量名称未更改.
可能引发的异常:
- RuntimeError:如果eager执行已启用.
- ValueError:参数model_fn不匹配params.
- ValueError:如果这是通过子类调用的,并且该类重写了Estimator的一个成员.
evaluate
evaluate(
input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None
)
计算给定计算数据input_fn的模型.
对于每个步骤来说,调用input_fn返回一批数据.计算直到: -steps批处理被处理,或-input_fn引发输入结束异常(OutOfRangeError或StopIteration).
参数:
- input_fn:构造用于计算的输入数据的函数.有关更多信息,请参阅TensorFlow入门.该函数应该构造并返回下列选项之一:
- tf.data.Dataset对象:Dataset对象的输出必须是一个具有相同约束的元组(特征(features),标签(labels)),其约束条件与下面相同.
- tuple (features, labels):其中features是Tensor或者名为Tensor的字符串特征的字典,而labels是Tensor或者名为Tensor的字符串标签的字典.这两个特征和标签都由model_fn消耗.他们应该满足model_fn对输入的期望.
- steps:计算模型所需的步骤数.如果为None,则计算直到input_fn引发输入异常时结束.
- hooks:SessionRunHook子类实例列表.用于计算调用中的回调.
- checkpoint_path:计算特定检查点的路径.如果为None,则使用model_dir中的最新检查点.
- name:需要使用的计算的名称,如果用户需要在不同的数据集上运行多个计算(如培训数据和测试数据).不同计算的度量标准保存在单独的文件夹中,并单独出现在tensorboard中.
返回值:
返回一个包含按name为键的model_fn中指定的计算指标的词典,以及包含执行此技术的全局步骤的值的条目global_step.
可能引发的异常:
- ValueError:如果steps <= 0.
- ValueError:如果没有模型被训练,名为model_dir,或者给定checkpoint_path是空的.
export_savedmodel
export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
strip_default_attrs=False
)
将推理图作为SavedModel导出到给定的目录中.
该方法通过首先调用serving_input_receiver_fn来获取特征Tensors来构建一个新图,然后调用这个Estimator的model_fn来基于这些特征生成模型图.它在新的会话中将给定的检查点恢复到该图中.最后它会在给定的export_dir_base下面创建一个时间戳导出目录,并在其中写入一个SavedModel,其中包含从此会话保存的单个MetaGraphDef.
导出的MetaGraphDef将为从model_fn返回的export_outputs字典的每个元素提供一个SignatureDef,该字典使用相同的key命名.其中一个key始终为signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示在服务请求未指定签名时将提供哪个签名.对于每个签名,输出由相应的ExportOutputs提供,并且输入始终是由serving_input_receiver_fn提供的输入接收器.
额外的资产可以通过assets_extra参数写入SavedModel.这应该是一个字典,其中每个key给出与assets.extra目录相关的目标路径(包括文件名).相应的值给出了要复制的源文件的完整路径.例如,在不重命名的情况下复制单个文件的简单情况被指定为{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.
参数:
- export_dir_base:包含一个目录的字符串,在该目录中创建包含导出的SavedModels的时间戳子目录.
- serving_input_receiver_fn:一个不带参数并返回一个ServingInputReceiver的函数.
- assets_extra:指定如何在导出的SavedModel中填充assets.extra目录的字典,如果不需要额外的资产,则为 None.
- as_text:是否以文本格式编写SavedModel原型.
- checkpoint_path:要导出的检查点路径.如果None(默认),则选择在模型目录中找到的最近检查点.
- strip_default_attrs:布尔值.如果True,则将从NodeDefs中删除默认值属性.
返回值:
导出目录的字符串路径.
可能引发的异常:
- ValueError:如果未提供serving_input_receiver_fn,则不提供export_outputs,或者找不到检查点.
get_variable_names
get_variable_names()
返回此模型中所有变量名称的列表.
返回值:
返回名字列表.
可能引发的异常:
- ValueError:如果Estimator尚未产生检查点.
get_variable_value
get_variable_value(name)
返回由名称给出的变量的值.
参数:
- name:字符串或字符串列表,张量的名称.
返回值:
Numpy数组 - 张量的值.
可能引发的异常:
- ValueError:如果Estimator尚未产生检查点.
latest_checkpoint
latest_checkpoint()
查找model_dir中最新保存的检查点文件的文件名.
返回值:
返回最新检查点的完整路径或None(未找到检查点).
predict
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
对给定的features产生预测.
参数:
- input_fn:构造特征的函数.预测继续,直到input_fn引发输入端异常(OutOfRangeError或StopIteration).有关更多信息,请参阅TensorFlow入门.该函数应该构造并返回下列之一:
- tf.data.Dataset对象:Dataset对象的输出必须具有与下面相同的约束.
- features:一个Tensor或者名为Tensor的字符串特征的字典.feature被model_fn消耗.他们应该满足model_fn对输入的期望.
- 一个元组,在这种情况下,第一个项被提取为feature.
- predict_keys:str列表,要预测的键名称.如果EstimatorSpec.predictions是字典,则使用该方法.如果使用predict_keys,则剩余的预测将从字典中过滤.如果None,则返回全部.
- hooks:SessionRunHook子类实例列表.用于预测调用中的回调.
- checkpoint_path:要预测的特定检查点的路径.如果为None,则使用model_dir中的最新的检查点.
返回值:
predictions张量的计算值.
可能引发的异常:
- ValueError:在model_dir中找不到训练有素的模型.
- ValueError:如果批次的预测长度不相同.
- ValueError:如果predict_keys和predictions之间有冲突.例如,如果predict_keys不是None,但EstimatorSpec.predictions不是一个dict.
train
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)
训练给定训练数据input_fn的模型.
参数:
- input_fn:提供作为minibatches培训的输入数据的函数.有关更多信息,请参阅TensorFlow入门.该函数应该构造并返回下列之一:
- tf.data.Dataset对象:Dataset对象的输出必须是一个具有相同约束的元组(特征,标签)((features, labels)),其约束条件与下面相同.
- tuple (features, labels):其中features是一个Tensor或者名为Tensor的字符串特征的字典,labels是一个Tensor或者名为Tensor的字符串标签的字典.这两个特征和标签都由model_fn消耗.他们应该满足model_fn对输入的期望.
- hooks:SessionRunHook子类实例列表.用于训练循环内的回调.
- steps:训练模型的步骤数.如果为None,则永远训练或训练直到input_fn产生OutOfRange错误或StopIteration异常.“steps”逐步运作.如果您调用两次train(steps=10),则训练总共发生20个步骤.如果OutOfRange或StopIteration发生在中间,训练在20步之前停止.如果你不想有增量行为,请改为设置.如果设置max_steps,max_steps必须None.
- max_steps:训练模型的总步骤数.如果为None,则永远训练或训练直到input_fn产生OutOfRange错误或StopIteration异常.如果设置,steps必须None.如果OutOfRange或StopIteration发生在中间,训练在max_steps步骤之前停止.两次调用train(steps=100)意味着200次训练迭代.另一方面,两次调用train(max_steps=100)意味着第二次调用将不会做任何迭代,因为第一次调用完成了所有100个步骤.
- saving_listeners:CheckpointSaverListener对象列表.用于在检查点节省之前或之后立即执行的回调.
可能引发的异常:
- ValueError:如果steps和max_steps都不是None.
- ValueError:如果steps或max_steps其中之一小于等于0.
更多建议: