Skip to content

Commit

Permalink
feat: support custom args
Browse files Browse the repository at this point in the history
  • Loading branch information
aircloud committed Jul 5, 2023
1 parent 639d160 commit a3909c4
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 46 deletions.
38 changes: 25 additions & 13 deletions jupyterlab_tensorboard_pro/api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
import logging

from tornado import web
from notebook.base.handlers import APIHandler
Expand Down Expand Up @@ -41,26 +42,37 @@ def get(self):
'reload_interval': entry.reload_interval,
'enable_multi_log': entry.enable_multi_log,
'logdir': _trim_notebook_dir(entry.logdir, entry.enable_multi_log),
'additional_args': entry.additional_args,
} for entry in
self.settings["tensorboard_manager"].values()
]
self.finish(json.dumps(terms))

@web.authenticated
def post(self):
data = self.get_json_body()
reload_interval = data.get("reload_interval", None)
enable_multi_log = data.get("enable_multi_log", False)
entry = (
self.settings["tensorboard_manager"]
.new_instance(data["logdir"], reload_interval=reload_interval, enable_multi_log=enable_multi_log)
)
self.finish(json.dumps({
'name': entry.name,
'reload_interval': entry.reload_interval,
'enable_multi_log': entry.enable_multi_log,
'logdir': _trim_notebook_dir(entry.logdir, entry.enable_multi_log),
}))
try:
data = self.get_json_body()
reload_interval = data.get("reload_interval", None)
enable_multi_log = data.get("enable_multi_log", False)
additional_args = data.get("additional_args", '')
entry = (
self.settings["tensorboard_manager"]
.new_instance(data["logdir"], reload_interval=reload_interval, enable_multi_log=enable_multi_log, additional_args=additional_args)
)
self.finish(json.dumps({
'name': entry.name,
'reload_interval': entry.reload_interval,
'enable_multi_log': entry.enable_multi_log,
'additional_args': entry.additional_args,
'logdir': _trim_notebook_dir(entry.logdir, entry.enable_multi_log),
}))
except SystemExit:
logging.error("[Tensorboard Error] mostly parse args error")
raise web.HTTPError(
500, "Tensorboard Error: mostly parse args error")
except Exception as e:
logging.error("[Tensorboard Error] catch exception: {e}")
print('[Tensorboard Error]', e)


class TbInstanceHandler(APIHandler):
Expand Down
24 changes: 14 additions & 10 deletions jupyterlab_tensorboard_pro/tensorboard_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ def get_plugins():
# TensorBoard 1.10 or above series
from tensorboard import program

def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log):
def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log, additional_args):
argv = [
"",
"--logdir", logdir,
"--reload_interval", str(reload_interval),
"--purge_orphaned_data", str(purge_orphaned_data),
]

# example: "--samples_per_plugin", "images=1"
additional_args_arr = additional_args.split()
argv += additional_args_arr

if enable_multi_log:
argv[1] = "--logdir_spec"

Expand All @@ -75,15 +79,15 @@ def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
return application.TensorBoardWSGIApp(flags, plugin_loaders, ingester.data_provider,
assets_zip_provider, ingester.deprecated_multiplexer)

return manager.add_instance(logdir, reload_interval, enable_multi_log, standard_tensorboard_wsgi(
return manager.add_instance(logdir, reload_interval, enable_multi_log, additional_args, standard_tensorboard_wsgi(
tensorboard.flags,
tensorboard.plugin_loaders,
tensorboard.assets_zip_provider))
else:
logging.debug("TensorBoard 0.4.x series detected")

def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log):
return manager.add_instance(logdir, reload_interval, enable_multi_log, application.standard_tensorboard_wsgi(
def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log, additional_args):
return manager.add_instance(logdir, reload_interval, enable_multi_log, additional_args, application.standard_tensorboard_wsgi(
logdir=logdir, reload_interval=reload_interval,
purge_orphaned_data=purge_orphaned_data,
plugins=default.get_plugins()))
Expand Down Expand Up @@ -115,7 +119,7 @@ def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log
profile_plugin.ProfilePlugin,
]

def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log):
def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log, additional_args):
return application.standard_tensorboard_wsgi(
logdir=logdir, reload_interval=reload_interval,
purge_orphaned_data=purge_orphaned_data,
Expand All @@ -125,7 +129,7 @@ def create_tb_app(logdir, reload_interval, purge_orphaned_data, enable_multi_log
from .handlers import notebook_dir # noqa

TensorBoardInstance = namedtuple(
'TensorBoardInstance', ['name', 'logdir', 'reload_interval', 'enable_multi_log', 'tb_app'])
'TensorBoardInstance', ['name', 'logdir', 'reload_interval', 'enable_multi_log', 'additional_args', 'tb_app'])


class TensorboardManger(dict):
Expand Down Expand Up @@ -164,7 +168,7 @@ def format_dir(dir):

return ','.join(map(format_dir, dirs))

def new_instance(self, logdir, reload_interval, enable_multi_log):
def new_instance(self, logdir, reload_interval, enable_multi_log, additional_args):
if not enable_multi_log and not os.path.isabs(logdir) and notebook_dir and not logdir.startswith("s3://"):
logdir = os.path.join(notebook_dir, logdir)

Expand All @@ -175,14 +179,14 @@ def new_instance(self, logdir, reload_interval, enable_multi_log):
logdir = self.format_multi_dir_path(logdir)
create_tb_app(
logdir=logdir, reload_interval=reload_interval,
purge_orphaned_data=purge_orphaned_data, enable_multi_log=enable_multi_log)
purge_orphaned_data=purge_orphaned_data, enable_multi_log=enable_multi_log, additional_args=additional_args)

return self._logdir_dict[logdir]

def add_instance(self, logdir, reload_interval, enable_multi_log, tb_application):
def add_instance(self, logdir, reload_interval, enable_multi_log, additional_args, tb_application):
name = self._next_available_name()
instance = TensorBoardInstance(
name, logdir, reload_interval, enable_multi_log, tb_application)
name, logdir, reload_interval, enable_multi_log, additional_args, tb_application)
self[name] = instance
self._logdir_dict[logdir] = instance
return tb_application
Expand Down
76 changes: 67 additions & 9 deletions src/biz/tab.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,24 @@ import { Loading } from './loading';
import { Tensorboard } from '../tensorboard';
import { TensorboardManager } from '../manager';
import { DEFAULT_REFRESH_INTERVAL } from '../consts';
import { copyToClipboard } from '../utils/copy';

export interface TensorboardCreatorProps {
disable: boolean;
getCWD: () => string;
openDoc: () => void;
startTensorBoard: (logDir: string, reloadInterval: number, enableMultiLog: boolean) => void;
startTensorBoard: (
logDir: string,
reloadInterval: number,
enableMultiLog: boolean,
additionalArgs: string
) => void;
}

const TensorboardCreator = (props: TensorboardCreatorProps): JSX.Element => {
const [logDir, setLogDir] = useState(props.getCWD());
const [reloadInterval, setReloadInterval] = useState(DEFAULT_REFRESH_INTERVAL);
const [additionalArgs, setAdditionalArgs] = useState('');
const [enableReloadInterval, setEnableReloadInterval] = useState(false);
const [enableMultiLog, setEnableMultiLog] = useState(false);

Expand Down Expand Up @@ -86,6 +93,17 @@ const TensorboardCreator = (props: TensorboardCreatorProps): JSX.Element => {
)}
</div>
</div>
<InputGroup
className={classNames('additional-config-input', {
'with-content': !!additionalArgs.length
})}
small={true}
placeholder="Custom Args..."
value={additionalArgs}
onChange={e => {
setAdditionalArgs(e.target.value);
}}
/>
<div className="tensorboard-ng-ops create">
<Button
small={true}
Expand All @@ -95,7 +113,8 @@ const TensorboardCreator = (props: TensorboardCreatorProps): JSX.Element => {
props.startTensorBoard(
logDir,
enableReloadInterval ? reloadInterval : 0,
enableMultiLog
enableMultiLog,
additionalArgs
);
}}
disabled={props.disable}
Expand Down Expand Up @@ -132,6 +151,7 @@ export interface TensorboardTabReactProps {
logdir: string,
refreshInterval: number,
enableMultiLog: boolean,
additionalArgs: string,
options?: Tensorboard.IOptions
) => Promise<Tensorboard.ITensorboard>;
}
Expand Down Expand Up @@ -219,7 +239,12 @@ export const TensorboardTabReact = (props: TensorboardTabReactProps): JSX.Elemen
});
};

const startTensorBoard = (logDir: string, reloadInterval: number, enableMultiLog: boolean) => {
const startTensorBoard = (
logDir: string,
reloadInterval: number,
enableMultiLog: boolean,
additionalArgs: string
) => {
if (Number.isNaN(reloadInterval) || reloadInterval < 0) {
return showDialog({
title: 'Param Check Failed',
Expand All @@ -230,7 +255,7 @@ export const TensorboardTabReact = (props: TensorboardTabReactProps): JSX.Elemen
updateCreatePending(true);
const currentName = currentTensorBoard?.name;
props
.startNew(logDir, reloadInterval, enableMultiLog)
.startNew(logDir, reloadInterval, enableMultiLog, additionalArgs)
.then(tb => {
if (currentName === tb.model.name) {
showDialog({
Expand All @@ -251,10 +276,27 @@ export const TensorboardTabReact = (props: TensorboardTabReactProps): JSX.Elemen
setShowNewRow(false);
})
.catch(e => {
showDialog({
body: 'Start TensorBoard internal error',
buttons: [Dialog.okButton()]
});
updateCreatePending(false);

const getMessage = () =>
e.response.json().then((json: any) => {
return json.message as string;
});
const defaultMessage = 'Start TensorBoard internal error';

getMessage()
.then((msg: string) => {
showDialog({
body: msg || defaultMessage,
buttons: [Dialog.okButton()]
});
})
.catch(() => {
showDialog({
body: defaultMessage,
buttons: [Dialog.okButton()]
});
});
});
};

Expand All @@ -272,6 +314,7 @@ export const TensorboardTabReact = (props: TensorboardTabReactProps): JSX.Elemen
: DEFAULT_REFRESH_INTERVAL;
const currentLogDir = currentTensorBoard.logdir;
const enableMultiLog = currentTensorBoard.enable_multi_log;
const additionalArgs = currentTensorBoard.additional_args;

const errorCallback = (e: any) => {
showDialog({
Expand All @@ -287,7 +330,7 @@ export const TensorboardTabReact = (props: TensorboardTabReactProps): JSX.Elemen
.shutdown(currentTensorBoard.name)
.then(res => {
props.tensorboardManager
.startNew(currentLogDir, reloadInterval, enableMultiLog)
.startNew(currentLogDir, reloadInterval, enableMultiLog, additionalArgs)
.then(res => {
refreshRunning();
updateReloadPending(false);
Expand Down Expand Up @@ -448,6 +491,21 @@ export const TensorboardTabReact = (props: TensorboardTabReactProps): JSX.Elemen
reload interval(s): {currentTensorBoard?.reload_interval || 'Never'}
</p>
)}
{currentTensorBoard?.additional_args && (
<>
<p title={currentTensorBoard?.additional_args} className="custom-args-tip">
{currentTensorBoard?.additional_args}
</p>
<Button
small={true}
minimal
icon="duplicate"
onClick={() => {
copyToClipboard(currentTensorBoard?.additional_args);
}}
/>
</>
)}
</div>
</div>
<div className="tensorboard-ng-expand" />
Expand Down
9 changes: 8 additions & 1 deletion src/biz/widget.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,17 @@ export class TensorboardTabReactWidget extends ReactWidget {
logdir: string,
refreshInterval: number,
enableMultiLog: boolean,
additionalArgs: string,
options?: Tensorboard.IOptions
): Promise<Tensorboard.ITensorboard> => {
this.currentLogDir = logdir;
return this.tensorboardManager.startNew(logdir, refreshInterval, enableMultiLog, options);
return this.tensorboardManager.startNew(
logdir,
refreshInterval,
enableMultiLog,
additionalArgs,
options
);
};

setWidgetName = (name: string): void => {
Expand Down
13 changes: 7 additions & 6 deletions src/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,22 @@ export class TensorboardManager implements Tensorboard.IManager {
*
* @returns A promise that resolves with the tensorboard instance.
*/
startNew(
async startNew(
logdir: string,
refreshInterval: number = DEFAULT_REFRESH_INTERVAL,
enableMultiLog: boolean = DEFAULT_ENABLE_MULTI_LOG,
additionalArgs = '',
options?: Tensorboard.IOptions
): Promise<Tensorboard.ITensorboard> {
return Tensorboard.startNew(
const tensorboard = await Tensorboard.startNew(
logdir,
refreshInterval,
enableMultiLog,
additionalArgs,
this._getOptions(options)
).then(tensorboard => {
this._onStarted(tensorboard);
return tensorboard;
});
);
this._onStarted(tensorboard);
return tensorboard;
}

/**
Expand Down
Loading

0 comments on commit a3909c4

Please sign in to comment.