Introduce a somewhat usable "metatize" for TF helper functions #66

Status: Open · wants to merge 1 commit into master

Conversation
brandonwillard (Contributor) commented Sep 4, 2019

This addresses #56 in another way: it builds an intermediate/temporary TF graph that mirrors a given meta graph, with the meta tensor terms replaced by Placeholders. The temporary TF graph is passed to the TF function we want to metatize, the result is converted back into a meta graph (i.e. "metatized"), and the Placeholder stand-ins are replaced by the original meta tensors.

The reason this seems like a worthwhile approach: Placeholders have some flexibility for unknown shape and dtype information, so, when meta tensors use logic variables for those values, we have a workable mapping between meta tensors and valid TF tensors.

Naturally, this approach has its limits: some TF helper functions simply do not accept inputs with unknown shape and dtype information (i.e. a "variant" dtype). However, the better we are about inferring/specifying dtype and shape information for meta objects (when it's possible to do so), the better this approach will work.
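To make the round-trip concrete, here is a minimal, library-free Python sketch of the substitution scheme described above. All names in it (`Term`, `metatize_via_stand_ins`) are hypothetical and only illustrate the idea; they are not symbolic_pymc's actual API:

```python
class Term:
    """Stand-in for a meta tensor (e.g. one with logic-variable shape/dtype)."""
    def __init__(self, name):
        self.name = name


def metatize_via_stand_ins(fn, meta_inputs, make_stand_in):
    """Swap meta terms for stand-ins, apply `fn`, then swap back."""
    # 1. Replace each meta input with a concrete stand-in that `fn` accepts
    #    (in this PR, a TF Placeholder with matching shape/dtype info).
    stand_ins = {make_stand_in(m): m for m in meta_inputs}
    # 2. Run the helper function on the temporary inputs.
    outputs = fn(list(stand_ins))
    # 3. Substitute the original meta terms back into the result.
    return [stand_ins.get(out, out) for out in outputs]
```

For instance, `metatize_via_stand_ins(lambda xs: xs + xs, [Term("x")], lambda m: ("ph", m.name))` returns the original `Term` twice, with the `("ph", ...)` stand-in fully substituted away.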

Example

We start by making the TF graph we ultimately want in meta form:

import tensorflow as tf

from tensorflow.python.eager.context import graph_mode

from symbolic_pymc.tensorflow.meta import mt
from symbolic_pymc.tensorflow.printing import tf_dprint


with graph_mode():
    # Create an identity matrix with the number of rows derived from another
    # matrix's shape in TF.
    A_tf = tf.compat.v1.placeholder(tf.float64, name='A',
                                    shape=tf.TensorShape([None, None]))

    A_shape_tf = tf.shape(A_tf)
    A_rows_tf = A_shape_tf[0]
    # The TF function for this is `tf.eye`.
    I_A_tf = tf.eye(A_rows_tf)

In an ideal world, there would be an OpDef behind the function tf.eye; since there isn't, we have to build an equivalent meta graph by hand. The meta graph should mirror the TF graph for I_A_tf, so we can always inspect I_A_tf to see what tf.eye constructed from its inputs (i.e. A_rows_tf):

>>> tf_dprint(I_A_tf)
Tensor(MatrixDiag):0,	shape=[None, None]	"eye/MatrixDiag:0"
|  Op(MatrixDiag)	"eye/MatrixDiag"
|  |  Tensor(Fill):0,	shape=[None]	"eye/ones:0"
|  |  |  Op(Fill)	"eye/ones"
|  |  |  |  Tensor(ConcatV2):0,	shape=[1]	"eye/concat:0"
|  |  |  |  |  Op(ConcatV2)	"eye/concat"
|  |  |  |  |  |  Tensor(Const):0,	shape=[0]	"eye/shape:0"
|  |  |  |  |  |  Tensor(Pack):0,	shape=[1]	"eye/concat/values_1:0"
|  |  |  |  |  |  |  Op(Pack)	"eye/concat/values_1"
|  |  |  |  |  |  |  |  Tensor(Minimum):0,	shape=[]	"eye/Minimum:0"
|  |  |  |  |  |  |  |  |  Op(Minimum)	"eye/Minimum"
|  |  |  |  |  |  |  |  |  |  Tensor(StridedSlice):0,	shape=[]	"strided_slice:0"
|  |  |  |  |  |  |  |  |  |  |  Op(StridedSlice)	"strided_slice"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Shape):0,	shape=[2]	"Shape:0"
|  |  |  |  |  |  |  |  |  |  |  |  |  Op(Shape)	"Shape"
|  |  |  |  |  |  |  |  |  |  |  |  |  |  Tensor(Placeholder):0,	shape=[None, None]	"A:0"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Const):0,	shape=[1]	"strided_slice/stack:0"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Const):0,	shape=[1]	"strided_slice/stack_1:0"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Const):0,	shape=[1]	"strided_slice/stack_2:0"
|  |  |  |  |  |  |  |  |  |  Tensor(StridedSlice):0,	shape=[]	"strided_slice:0"
|  |  |  |  |  |  |  |  |  |  |  ...
|  |  |  |  |  |  Tensor(Const):0,	shape=[]	"eye/concat/axis:0"
|  |  |  |  Tensor(Const):0,	shape=[]	"eye/ones/Const:0"

Reconstructing graphs like these by hand amounts to reproducing, step by step, what the function tf.eye does internally.
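For a sense of what "reproducing the steps" means, the ops visible in the trace above (Minimum, Fill, MatrixDiag) correspond to something like the following NumPy computation for the non-batched case used in this example; this is only an illustration of tf.eye's internals, not symbolic_pymc code:

```python
import numpy as np


def eye_by_hand(num_rows, num_cols=None):
    """Mirror the ops in the tf.eye trace above, for the non-batched case."""
    if num_cols is None:
        num_cols = num_rows                       # square case, as in the example
    diag_len = min(num_rows, num_cols)            # the Minimum op ("eye/Minimum")
    diag = np.ones(diag_len)                      # the Fill op ("eye/ones")
    out = np.zeros((num_rows, num_cols))
    out[:diag_len, :diag_len] = np.diag(diag)     # the MatrixDiag op
    return out
```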

With the TF function "metatizing" in this PR, the process is much simpler:

with graph_mode():

    A_mt = mt(A_tf)
    A_shape_mt = mt.shape(A_mt)
    # There's still work to do to make things easier...
    A_rows_mt = mt.StridedSlice(A_shape_mt, [0], [1], [1], shrink_axis_mask=1)

    I_A_mt = mt.metatize_tf_function(tf.eye, A_rows_mt)

Now, if we convert the meta graph I_A_mt into a TF graph and print the results, we see essentially the same results as the original tf.eye, which verifies the correspondence between the two graphs:

>>> tf_dprint(I_A_mt.reify())
Tensor(MatrixDiag):0,	shape=[None, None]	"eye_5/MatrixDiag_1:0"
|  Op(MatrixDiag)	"eye_5/MatrixDiag_1"
|  |  Tensor(Fill):0,	shape=[None]	"eye_5/ones_1:0"
|  |  |  Op(Fill)	"eye_5/ones_1"
|  |  |  |  Tensor(ConcatV2):0,	shape=[1]	"eye_5/concat_1:0"
|  |  |  |  |  Op(ConcatV2)	"eye_5/concat_1"
|  |  |  |  |  |  Tensor(Const):0,	shape=[0]	"eye_5/shape:0"
|  |  |  |  |  |  Tensor(Pack):0,	shape=[1]	"eye_5/concat/values_1_1:0"
|  |  |  |  |  |  |  Op(Pack)	"eye_5/concat/values_1_1"
|  |  |  |  |  |  |  |  Tensor(Minimum):0,	shape=[]	"eye_5/Minimum_1:0"
|  |  |  |  |  |  |  |  |  Op(Minimum)	"eye_5/Minimum_1"
|  |  |  |  |  |  |  |  |  |  Tensor(StridedSlice):0,	shape=[]	"StridedSlice_4:0"
|  |  |  |  |  |  |  |  |  |  |  Op(StridedSlice)	"StridedSlice_4"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Shape):0,	shape=[2]	"Shape_5:0"
|  |  |  |  |  |  |  |  |  |  |  |  |  Op(Shape)	"Shape_5"
|  |  |  |  |  |  |  |  |  |  |  |  |  |  Tensor(Placeholder):0,	shape=[None, None]	"A_1:0"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Const):0,	shape=[1]	"StridedSlice_4/begin:0"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Const):0,	shape=[1]	"StridedSlice_4/end:0"
|  |  |  |  |  |  |  |  |  |  |  |  Tensor(Const):0,	shape=[1]	"StridedSlice_4/strides:0"
|  |  |  |  |  |  |  |  |  |  Tensor(StridedSlice):0,	shape=[]	"StridedSlice_4:0"
|  |  |  |  |  |  |  |  |  |  |  ...
|  |  |  |  |  |  Tensor(Const):0,	shape=[]	"eye_5/concat/axis:0"
|  |  |  |  Tensor(Const):0,	shape=[]	"eye_5/ones/Const:0"
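"Essentially the same" here means the two traces match up to TF's name uniquification (e.g. "eye_5/MatrixDiag_1" vs. "eye/MatrixDiag"). One way to compare node names mechanically is to strip those numeric suffixes first; normalize_name below is a hypothetical helper sketching that idea, not part of this PR. (Hand-built pieces, such as the StridedSlice node, keep whatever base name they were given, so a name comparison only covers the auto-generated parts.)

```python
import re


def normalize_name(tf_name):
    """Drop TF's uniquifying "_<n>" suffixes from each path component,
    e.g. "eye_5/MatrixDiag_1:0" -> "eye/MatrixDiag:0"."""
    base, sep, output_idx = tf_name.partition(":")
    parts = [re.sub(r"_\d+$", "", part) for part in base.split("/")]
    return "/".join(parts) + sep + output_idx
```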

brandonwillard added the "enhancement" label on Sep 5, 2019
brandonwillard self-assigned this on Sep 5, 2019
brandonwillard added the "TensorFlow" and "meta graph" labels on Mar 13, 2020