Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support recurrent with no states. #1113

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions blocks/bricks/recurrent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def initial_states(self, batch_size, *args, **kwargs):
The keyword arguments of the application call.

"""
if not hasattr(self, 'apply') or not self.apply.states:
return

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain how it works? I cannot immediately see it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when some subclass call the default initial_states function in the BaseRecurrent class. This line would check whether it is necessary to return the initial states. If the subclass does not have an apply method or its apply method does not contain states, the initial_states would not return anything.
This line would make it to support recurrent class with no apply function or with no states.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you want to have a class without apply? It's a mistake if a user forgot to define apply and the best is to crash soon.

In a case if apply.states is empty, initial_states would return an empty list before this change, why is it wrong?

Copy link
Author

@Beronx86 Beronx86 Jun 16, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this line is added, the above code, which contains a recurrent brick with no apply method, would run well.
But, you are right about the apply method. The Brick subclass should follow some design rules. The problem is no code checks whether there is an apply method in a Brick subclass at present.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Beronx86 , checking apply.states in BaseRecurrent.initial_states is not a solution. There are quite a few places in Blocks-dependent code where initial_states method is overloaded. Instead, like in your previous solution, initial_states should not be called if application does not have states. Can you please revert back to the previous version of your fix?

Copy link
Author

@Beronx86 Beronx86 Jun 20, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rizar I think this check could be carried out in Brick.__init__ method. So we can make sure all Brick subclasses contain apply methods. I reverted back the changes in BaseRecurrent.

result = []
for state in self.apply.states:
dim = self.get_dim(state)
Expand Down
13 changes: 10 additions & 3 deletions blocks/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Annotated computation graph management."""
import logging
from collections import OrderedDict
from collections import OrderedDict, deque
from itertools import chain
import warnings

Expand Down Expand Up @@ -103,8 +103,15 @@ def auxiliary_variables(self):

@property
def scan_variables(self):
"""Variables of Scan ops."""
return list(chain(*[g.variables for g in self._scan_graphs]))
"""Variables of Scan ops. Breadth-first search"""
sg_que = deque(self._scan_graphs)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code supposed that no recurrent class is nested. #1115

var_list = []
while sg_que:
g = sg_que.popleft()
var_list.append(g.variables)
if g._scan_graphs:
sg_que.extend(g._scan_graphs)
return list(chain(*var_list))
Copy link
Contributor

@rizar rizar Jun 14, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a desirable change, but I have a few concerns:

  • modifying the list that you iterate over is hard to understand and can very likely cause errors. Please try to simplify the code
  • a test is required


def _get_variables(self):
"""Collect variables, updates and auxiliary variables.
Expand Down
44 changes: 43 additions & 1 deletion tests/bricks/test_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from blocks.utils import is_shared_variable
from blocks.bricks.base import application
from blocks.bricks import Tanh
from blocks.bricks import Brick, Tanh
from blocks.bricks.recurrent import (
recurrent, BaseRecurrent, GatedRecurrent,
SimpleRecurrent, Bidirectional, LSTM,
Expand Down Expand Up @@ -71,6 +71,48 @@ def test(self):
assert_allclose(h2 * 10, out_2_eval)


class RecurrentWrapperNoStatesClass(BaseRecurrent):
def __init__(self, dim, ** kwargs):
super(RecurrentWrapperNoStatesClass, self).__init__(self, ** kwargs)
self.dim = dim

def get_dim(self, name):
if name in ['inputs', 'outputs', 'outputs_2']:
return self.dim
if name == 'mask':
return 0
return super(RecurrentWrapperNoStatesClass, self).get_dim(name)

@recurrent(sequences=['inputs', 'mask'], states=[],
outputs=['outputs', 'outputs_2'], contexts=[])
def apply(self, inputs=None, mask=None):
outputs = inputs * 10
outputs_2 = tensor.sqr(inputs)
if mask:
outputs *= mask
outputs_2 *= mask
return outputs, outputs_2


class TestRecurrentWrapperNoStates(unittest.TestCase):
def setUp(self):
self.recurrent_examples = RecurrentWrapperNoStatesClass(dim=1)

def test(self):
X = tensor.tensor3('X')
out, out_2 = self.recurrent_examples.apply(
inputs=X, mask=None)

x_val = numpy.random.uniform(size=(5, 1, 1))
x_val = numpy.asarray(x_val, dtype=theano.config.floatX)

out_eval = out.eval({X: x_val})
out_2_eval = out_2.eval({X: x_val})

assert_allclose(x_val * 10, out_eval)
assert_allclose(numpy.square(x_val), out_2_eval)


class RecurrentBrickWithBugInInitialStates(BaseRecurrent):

@recurrent(sequences=[], contexts=[],
Expand Down
52 changes: 51 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import theano
import warnings
from numpy.testing import assert_allclose
from theano import tensor
from theano import function, tensor
from theano.sandbox.rng_mrg import MRG_RandomStreams

from blocks.bricks import MLP, Identity, Logistic, Tanh
Expand Down Expand Up @@ -74,6 +74,56 @@ def test_computation_graph():
assert all(v in cg6.scan_variables for v in scan.inputs + scan.outputs)


def test_computation_graph_nested_scan():
inner_x = tensor.matrix('inner_x')
outer_x = tensor.matrix('outer_x')
factor = tensor.matrix('factor')

def inner_scan(inner_x_, outer_x_one_step):
inner_o, _ = theano.scan(fn=lambda inp, ctx: inp + ctx,
sequences=inner_x_,
non_sequences=outer_x_one_step)
return inner_o.sum(axis=0)

outer_o, _ = theano.scan(fn=lambda inp, ctx: inner_scan(ctx, inp),
sequences=outer_x,
non_sequences=inner_x)

outs = outer_o * factor

nested_scan = outs.owner.inputs[0].owner.op
cg = ComputationGraph(outer_o)

assert cg.scans == [nested_scan]
assert all(var in cg.scan_variables
for var in nested_scan.inputs + nested_scan.outputs)

func = function(inputs=[inner_x, outer_x, factor], outputs=outs,
allow_input_downcast=True)

in_len = 9
out_len = 7
dim = 3

floatX = theano.config.floatX
x_val = numpy.asarray(numpy.random.uniform(size=(in_len, dim)),
dtype=floatX)
y_val = numpy.asarray(numpy.random.uniform(size=(out_len, dim)),
dtype=floatX)
factor_val = numpy.asarray(numpy.random.uniform(size=(out_len, dim)),
dtype=floatX)

results = func(x_val, y_val, factor_val)

results2 = numpy.zeros(shape=(out_len, dim))
for i, y in enumerate(y_val):
for x in x_val:
results2[i] += (x + y)
results2 = results2 * factor_val

assert_allclose(results, results2)


def test_computation_graph_variable_duplicate():
# Test if ComputationGraph.variables contains duplicates if some outputs
# are part of the computation graph
Expand Down