Skip to content

Commit

Permalink
test: fix serial console auth needed.
Browse files Browse the repository at this point in the history
  • Loading branch information
squirrelsc authored and LiliDeng committed Jun 26, 2024
1 parent 57055c4 commit 9581d80
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions lisa/sut_orchestrator/azure/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def write(self, data: str) -> None:
return
except websockets.ConnectionClosed as e: # type: ignore
# If the connection is closed, we need to reconnect
self._log.debug(f"Connection closed: {e}")
self._log.debug(f"Connection closed on read serial console: {e}")
self._ws = None
self._get_connection()
raise e
Expand All @@ -340,7 +340,7 @@ def read(self) -> str:
return output
except websockets.ConnectionClosed as e: # type: ignore
# If the connection is closed, we need to reconnect
self._log.debug(f"Connection closed: {e}")
self._log.debug(f"Connection closed on read serial console: {e}")
self._ws = None
self._get_connection()
raise e
Expand All @@ -360,10 +360,18 @@ def _get_connection(self) -> Any:
connection_str = self._get_connection_string()

# create websocket connection
self._ws = self._get_event_loop().run_until_complete(
ws = self._get_event_loop().run_until_complete(
websockets.connect(connection_str) # type: ignore
)

token = self._get_access_token()
# add to secret in case it's echo back.
add_secret(token)
# send token to auth
self._get_event_loop().run_until_complete(ws.send(token))

self._ws = ws

return self._ws

def _write(self, cmd: str) -> None:
Expand Down Expand Up @@ -402,6 +410,14 @@ def _read(self) -> str:

return output

def _get_access_token(self) -> str:
platform: AzurePlatform = self._platform # type: ignore
access_token = platform.credential.get_token(
"https://management.core.windows.net/.default"
).token

return access_token

def _get_console_log(self, saved_path: Optional[Path]) -> bytes:
platform: AzurePlatform = self._platform # type: ignore
return save_console_log(
Expand All @@ -414,19 +430,15 @@ def _get_console_log(self, saved_path: Optional[Path]) -> bytes:

def _get_connection_string(self) -> str:
# setup connection string
platform: AzurePlatform = self._platform # type: ignore
connection = self._serial_port_operations.connect(
resource_group_name=self._resource_group_name,
resource_provider_namespace=self.RESOURCE_PROVIDER_NAMESPACE,
parent_resource_type=self.PARENT_RESOURCE_TYPE,
parent_resource=self._vm_name,
serial_port=self._serial_port.name,
)
access_token = platform.credential.get_token(
"https://management.core.windows.net/.default"
).token
serial_port_connection_str = (
f"{connection.connection_string}?authorization={access_token}"
f"{connection.connection_string}?authorization={self._get_access_token()}"
)

return serial_port_connection_str
Expand Down Expand Up @@ -476,6 +488,8 @@ def _initialize_serial_console(self, port_id: int) -> None:
if int(serialport.name) == port_id
][0]

self._log.debug(f"Serial port {port_id} is enabled: {self._serial_port}")

# setup shared web socket connection variable
self._ws = None

Expand Down

0 comments on commit 9581d80

Please sign in to comment.