@@ -196,6 +196,9 @@ class _TunnelProtocol(Protocol):
196196 def close (self ) -> None :
197197 """Close this tunnel"""
198198
199+ async def wait_closed (self ):
200+ """Wait for this tunnel to close"""
201+
199202class _TunnelConnectorProtocol (_TunnelProtocol , Protocol ):
200203 """Protocol to open a connection to tunnel an SSH connection over"""
201204
@@ -387,6 +390,11 @@ def pipe_connection_lost(self, fd: int,
387390
388391 self ._conn .connection_lost (exc )
389392
393+ def process_exited (self ):
394+ """Called when the child process has exited"""
395+
396+ self ._close_event .set ()
397+
390398 def write (self , data : bytes ) -> None :
391399 """Write data to this tunnel"""
392400
@@ -403,13 +411,20 @@ def close(self) -> None:
403411
404412 if self ._transport : # pragma: no cover
405413 self ._transport .close ()
414+ self ._transport = None
406415
407- self ._close_event .set ()
416+ async def wait_closed (self ):
417+ """Wait for this subprocess to exit"""
418+
419+ await self ._close_event .wait ()
408420
421+ _ , tunnel = await loop .subprocess_exec (_ProxyCommandTunnel , * command ,
422+ start_new_session = True )
409423
410- _ , tunnel = await loop .subprocess_exec (_ProxyCommandTunnel , * command )
424+ conn = cast (_Conn , cast (_ProxyCommandTunnel , tunnel ).get_conn ())
425+ conn .set_tunnel (tunnel )
411426
412- return cast ( _Conn , cast ( _ProxyCommandTunnel , tunnel ). get_conn ())
427+ return conn
413428
414429
415430async def _open_tunnel (tunnels : object , options : _Options ,
@@ -1090,15 +1105,15 @@ def _cleanup(self, exc: Optional[Exception]) -> None:
10901105
10911106 self ._owner = None
10921107
1108+ if self ._tunnel :
1109+ self ._tunnel .close ()
1110+ self ._tunnel = None
1111+
10931112 self ._cancel_login_timer ()
10941113 self ._close_event .set ()
10951114
10961115 self ._inpbuf = b''
10971116
1098- if self ._tunnel :
1099- self ._tunnel .close ()
1100- self ._tunnel = None
1101-
11021117 def _cancel_login_timer (self ) -> None :
11031118 """Cancel the login timer"""
11041119
@@ -2851,6 +2866,9 @@ async def wait_closed(self) -> None:
28512866 if self ._agent :
28522867 await self ._agent .wait_closed ()
28532868
2869+ if self ._tunnel :
2870+ await self ._tunnel .wait_closed ()
2871+
28542872 await self ._close_event .wait ()
28552873
28562874 def disconnect (self , code : int , reason : str ,
0 commit comments