|
63 | 63 | from ._version import kernel_protocol_version
|
64 | 64 |
|
65 | 65 |
|
66 |
| -def _accepts_cell_id(meth): |
| 66 | +def _accepts_parameters(meth, param_names): |
67 | 67 | parameters = inspect.signature(meth).parameters
|
68 |
| - cid_param = parameters.get("cell_id") |
69 |
| - return (cid_param and cid_param.kind == cid_param.KEYWORD_ONLY) or any( |
70 |
| - p.kind == p.VAR_KEYWORD for p in parameters.values() |
71 |
| - ) |
| 68 | + accepts = {param: False for param in param_names} |
| 69 | + |
| 70 | + for param in param_names: |
| 71 | + param_spec = parameters.get(param) |
| 72 | + accepts[param] = ( |
| 73 | + param_spec |
| 74 | + and param_spec.kind in [param_spec.KEYWORD_ONLY, param_spec.POSITIONAL_OR_KEYWORD] |
| 75 | + ) or any(p.kind == p.VAR_KEYWORD for p in parameters.values()) |
| 76 | + |
| 77 | + return accepts |
72 | 78 |
|
73 | 79 |
|
74 | 80 | class Kernel(SingletonConfigurable):
|
@@ -735,25 +741,28 @@ async def execute_request(self, stream, ident, parent):
|
735 | 741 | self.execution_count += 1
|
736 | 742 | self._publish_execute_input(code, parent, self.execution_count)
|
737 | 743 |
|
738 |
| - cell_id = (parent.get("metadata") or {}).get("cellId") |
| 744 | + cell_meta = parent.get("metadata", {}) |
| 745 | + cell_id = metadata.get("cellId") |
739 | 746 |
|
740 |
| - if _accepts_cell_id(self.do_execute): |
741 |
| - reply_content = self.do_execute( |
742 |
| - code, |
743 |
| - silent, |
744 |
| - store_history, |
745 |
| - user_expressions, |
746 |
| - allow_stdin, |
747 |
| - cell_id=cell_id, |
748 |
| - ) |
749 |
| - else: |
750 |
| - reply_content = self.do_execute( |
751 |
| - code, |
752 |
| - silent, |
753 |
| - store_history, |
754 |
| - user_expressions, |
755 |
| - allow_stdin, |
756 |
| - ) |
| 747 | + # Check which parameters do_execute can accept |
| 748 | + accepts_params = _accepts_parameters(self.do_execute, ["metadata", "cell_id"]) |
| 749 | + |
| 750 | + # Arguments based on the do_execute signature |
| 751 | + do_execute_args = { |
| 752 | + "code": code, |
| 753 | + "silent": silent, |
| 754 | + "store_history": store_history, |
| 755 | + "user_expressions": user_expressions, |
| 756 | + "allow_stdin": allow_stdin, |
| 757 | + } |
| 758 | + |
| 759 | + if accepts_params["metadata"]: |
| 760 | + do_execute_args["metadata"] = cell_meta |
| 761 | + if accepts_params["cell_id"]: |
| 762 | + do_execute_args["cell_id"] = cell_id |
| 763 | + |
| 764 | + # Call do_execute with the appropriate arguments |
| 765 | + reply_content = self.do_execute(**do_execute_args) |
757 | 766 |
|
758 | 767 | if inspect.isawaitable(reply_content):
|
759 | 768 | reply_content = await reply_content
|
@@ -793,6 +802,7 @@ def do_execute(
|
793 | 802 | user_expressions=None,
|
794 | 803 | allow_stdin=False,
|
795 | 804 | *,
|
| 805 | + cell_meta=None, |
796 | 806 | cell_id=None,
|
797 | 807 | ):
|
798 | 808 | """Execute user code. Must be overridden by subclasses."""
|
|
0 commit comments