Skip to content

Commit

Permalink
Find the open port when running MultiProcessExecutor. (#137)
Browse files Browse the repository at this point in the history
* Find the open port when running MultiProcessExecutor.

* Added apis/utils/common_utils.py

* Added error checking for get_open_ports.
  • Loading branch information
yhwen authored Jan 21, 2022
1 parent 8d90432 commit aba6943
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 15 deletions.
36 changes: 36 additions & 0 deletions nvflare/apis/utils/common_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket


def get_open_ports(number):
    """Get the requested number of distinct open ports from the system.

    Each port is obtained by binding a TCP socket to port 0, which lets the
    operating system choose a free port. Every probe socket is kept open until
    all ports have been collected, so one call does not receive the same port
    twice.

    Note: the probe sockets are closed before this function returns, so
    another process may still claim a returned port before the caller binds it.

    Args:
        number: number of ports

    Returns:
        list of open ports

    Raises:
        RuntimeError: if the requested number of ports could not be obtained.
    """
    ports = []
    probes = []
    try:
        for _ in range(number):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Track the socket before bind() so it is closed even if bind fails.
            probes.append(s)
            s.bind(("", 0))
            s.listen(1)
            port = s.getsockname()[1]
            if port > 0:
                ports.append(port)
    finally:
        # Release the ports only after all of them have been picked.
        for s in probes:
            s.close()
    if len(ports) != number:
        raise RuntimeError("Could not get enough open ports from the system.")
    return ports
16 changes: 2 additions & 14 deletions nvflare/app_common/executors/multi_process_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import logging
import os
import shlex
import socket
import subprocess
import threading
import time
Expand All @@ -29,24 +28,13 @@
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.common_utils import get_open_ports
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.fuel.common.multi_process_executor_constants import CommunicateData, CommunicationMetaData
from nvflare.fuel.utils.class_utils import ModuleScanner
from nvflare.fuel.utils.component_builder import ComponentBuilder


def _get_open_ports(number):
    """Return a list of `number` distinct open TCP ports chosen by the system.

    Each port is obtained by binding a socket to port 0. The probe sockets are
    all kept open until the last port has been picked, so one call does not
    receive the same port twice, and they are closed even when a socket
    operation raises.

    Args:
        number: number of ports to find; a value <= 0 yields an empty list.

    Returns:
        list of port numbers, released again before this function returns.
    """
    ports = []
    probes = []
    try:
        for _ in range(number):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Track the socket before bind() so it is closed even if bind fails.
            probes.append(s)
            s.bind(("", 0))
            s.listen(1)
            ports.append(s.getsockname()[1])
    finally:
        for s in probes:
            s.close()
    return ports


class WorkerComponentBuilder(ComponentBuilder):
FL_PACKAGES = ["nvflare"]
FL_MODULES = ["client", "app"]
Expand Down Expand Up @@ -153,7 +141,7 @@ def initialize(self, fl_ctx: FLContext):
def _initialize_multi_process(self, fl_ctx: FLContext):

try:
self.open_ports = _get_open_ports(self.num_of_processes * 3)
self.open_ports = get_open_ports(self.num_of_processes * 3)

command = (
self.get_multi_process_command()
Expand Down
4 changes: 3 additions & 1 deletion nvflare/app_common/pt/pt_multi_process_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import sys

from nvflare.apis.utils.common_utils import get_open_ports
from nvflare.app_common.executors.multi_process_executor import MultiProcessExecutor


Expand All @@ -26,5 +27,6 @@ def get_multi_process_command(self) -> str:
f"{sys.executable} -m torch.distributed.run --nproc_per_node="
+ str(self.num_of_processes)
+ " --nnodes=1 --node_rank=0"
+ ' --master_addr="localhost" --master_port=1234'
+ ' --master_addr="localhost" --master_port='
+ str(get_open_ports(1)[0])
)

0 comments on commit aba6943

Please sign in to comment.