3232import pytensor .tensor as pt
3333import scipy .sparse as sps
3434
35- from pytensor .compile import DeepCopyOp , Function , get_mode
35+ from pytensor .compile import DeepCopyOp , Function , ProfileStats , get_mode
3636from pytensor .compile .sharedvalue import SharedVariable
3737from pytensor .graph .basic import Constant , Variable , ancestors , graph_inputs
3838from pytensor .tensor .random .op import RandomVariable
@@ -1657,7 +1657,15 @@ def compile_fn(
16571657 return PointFunc (fn )
16581658 return fn
16591659
1660- def profile (self , outs , * , n = 1000 , point = None , profile = True , ** kwargs ):
1660+ def profile (
1661+ self ,
1662+ outs ,
1663+ * ,
1664+ n = 1000 ,
1665+ point = None ,
1666+ profile = True ,
1667+ ** compile_fn_kwargs ,
1668+ ) -> ProfileStats :
16611669 """Compile and profile a PyTensor function which returns ``outs`` and takes values of model vars as a dict as an argument.
16621670
16631671 Parameters
@@ -1668,16 +1676,22 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
16681676 point : Point
16691677 Point to pass to the function
16701678 profile : True or ProfileStats
1671- args, kwargs
1672- Compilation args
1679+ compile_fn_kwargs
1680+ Compilation kwargs for :func:`pymc.model.core.Model.compile_fn`
16731681
16741682 Returns
16751683 -------
1676- ProfileStats
1684+ pytensor.compile.profiling. ProfileStats
16771685 Use .summary() to print stats.
16781686 """
1679- kwargs .setdefault ("on_unused_input" , "ignore" )
1680- f = self .compile_fn (outs , inputs = self .value_vars , point_fn = False , profile = profile , ** kwargs )
1687+ compile_fn_kwargs .setdefault ("on_unused_input" , "ignore" )
1688+ f = self .compile_fn (
1689+ outs ,
1690+ inputs = self .value_vars ,
1691+ point_fn = False ,
1692+ profile = profile ,
1693+ ** compile_fn_kwargs ,
1694+ )
16811695 if point is None :
16821696 point = self .initial_point ()
16831697
0 commit comments