Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept bytes for Unix socket paths. #640

Merged
merged 7 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ Version history

This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``,
``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by
Lura Skye.)

**4.1.0**

- Adapted to API changes made in Trio v0.23:
Expand Down
4 changes: 2 additions & 2 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,7 +2229,7 @@ async def connect_tcp(
return SocketStream(transport, protocol)

@classmethod
async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
await cls.checkpoint()
loop = get_running_loop()
raw_socket = socket.socket(socket.AF_UNIX)
Expand Down Expand Up @@ -2282,7 +2282,7 @@ async def create_udp_socket(

@classmethod
async def create_unix_datagram_socket( # type: ignore[override]
cls, raw_socket: socket.socket, remote_path: str | None
cls, raw_socket: socket.socket, remote_path: str | bytes | None
) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
await cls.checkpoint()
loop = get_running_loop()
Expand Down
6 changes: 3 additions & 3 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ async def connect_tcp(
return SocketStream(trio_socket)

@classmethod
async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
trio_socket = trio.socket.socket(socket.AF_UNIX)
try:
await trio_socket.connect(path)
Expand Down Expand Up @@ -1006,13 +1006,13 @@ async def create_unix_datagram_socket(
@classmethod
@overload
async def create_unix_datagram_socket(
cls, raw_socket: socket.socket, remote_path: str
cls, raw_socket: socket.socket, remote_path: str | bytes
) -> abc.ConnectedUNIXDatagramSocket:
...

@classmethod
async def create_unix_datagram_socket(
cls, raw_socket: socket.socket, remote_path: str | None
cls, raw_socket: socket.socket, remote_path: str | bytes | None
) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
trio_socket = trio.socket.from_stdlib_socket(raw_socket)

Expand Down
41 changes: 27 additions & 14 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

import errno
import os
import socket
import ssl
import stat
import sys
from collections.abc import Awaitable
from ipaddress import IPv6Address, ip_address
from os import PathLike, chmod
from pathlib import Path
from socket import AddressFamily, SocketKind
from typing import Literal, cast, overload
from typing import Any, Literal, cast, overload

from .. import to_thread
from ..abc import (
Expand Down Expand Up @@ -245,7 +247,7 @@ async def try_connect(remote_host: str, event: Event) -> None:
return connected_stream


async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream:
async def connect_unix(path: str | bytes | PathLike[Any]) -> UNIXSocketStream:
"""
Connect to the given UNIX socket.

Expand All @@ -255,7 +257,7 @@ async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream:
:return: a socket stream object

"""
path = str(Path(path))
path = os.fspath(path)
return await get_async_backend().connect_unix(path)


Expand Down Expand Up @@ -340,7 +342,10 @@ async def create_tcp_listener(


async def create_unix_listener(
path: str | PathLike[str], *, mode: int | None = None, backlog: int = 65536
path: str | bytes | PathLike[Any],
*,
mode: int | None = None,
backlog: int = 65536,
) -> SocketListener:
"""
Create a UNIX socket listener.
Expand Down Expand Up @@ -466,7 +471,7 @@ async def create_connected_udp_socket(

async def create_unix_datagram_socket(
*,
local_path: None | str | PathLike[str] = None,
local_path: None | str | bytes | PathLike[Any] = None,
local_mode: int | None = None,
) -> UNIXDatagramSocket:
"""
Expand All @@ -493,9 +498,9 @@ async def create_unix_datagram_socket(


async def create_connected_unix_datagram_socket(
remote_path: str | PathLike[str],
remote_path: str | bytes | PathLike[Any],
*,
local_path: None | str | PathLike[str] = None,
local_path: None | str | bytes | PathLike[Any] = None,
local_mode: int | None = None,
) -> ConnectedUNIXDatagramSocket:
"""
Expand All @@ -516,7 +521,7 @@ async def create_connected_unix_datagram_socket(
:return: a connected UNIX datagram socket

"""
remote_path = str(Path(remote_path))
remote_path = os.fspath(remote_path)
raw_socket = await setup_unix_local_socket(
local_path, local_mode, socket.SOCK_DGRAM
)
Expand Down Expand Up @@ -665,7 +670,7 @@ def convert_ipv6_sockaddr(


async def setup_unix_local_socket(
path: None | str | PathLike[str],
path: None | str | bytes | PathLike[Any],
mode: int | None,
socktype: int,
) -> socket.socket:
Expand All @@ -680,11 +685,19 @@ async def setup_unix_local_socket(
:param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM

"""
path_str: str | bytes | None
if path is not None:
path_str = str(path)
path = Path(path)
if path.is_socket():
path.unlink()
path_str = os.fspath(path)

# Copied from pathlib...
try:
stat_result = os.stat(path)
except OSError as e:
if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP):
raise
else:
if stat.S_ISSOCK(stat_result.st_mode):
os.unlink(path)
else:
path_str = None

Expand Down
6 changes: 3 additions & 3 deletions src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ async def connect_tcp(

@classmethod
@abstractmethod
async def connect_unix(cls, path: str) -> UNIXSocketStream:
async def connect_unix(cls, path: str | bytes) -> UNIXSocketStream:
pass

@classmethod
Expand Down Expand Up @@ -299,14 +299,14 @@ async def create_unix_datagram_socket(
@classmethod
@overload
async def create_unix_datagram_socket(
cls, raw_socket: socket, remote_path: str
cls, raw_socket: socket, remote_path: str | bytes
) -> ConnectedUNIXDatagramSocket:
...

@classmethod
@abstractmethod
async def create_unix_datagram_socket(
cls, raw_socket: socket, remote_path: str | None
cls, raw_socket: socket, remote_path: str | bytes | None
) -> UNIXDatagramSocket | ConnectedUNIXDatagramSocket:
pass

Expand Down
44 changes: 44 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,24 @@ async def test_cannot_connect(self, socket_path: Path) -> None:
with pytest.raises(FileNotFoundError):
await connect_unix(socket_path)

async def test_connecting_using_bytes(
self, server_sock: socket.socket, socket_path: Path
) -> None:
async with await connect_unix(str(socket_path).encode()):
pass

@pytest.mark.skipif(
platform.system() == "Darwin", reason="macOS requires valid UTF-8 paths"
)
async def test_connecting_with_non_utf8(self, socket_path: Path) -> None:
actual_path = str(socket_path).encode() + b"\xF0"
server = socket.socket(socket.AF_UNIX)
server.bind(actual_path)
server.listen(1)

async with await connect_unix(actual_path):
pass


@pytest.mark.skipif(
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
Expand Down Expand Up @@ -1101,6 +1119,18 @@ async def test_bind_twice(self, socket_path: Path) -> None:
async with await create_unix_listener(socket_path):
pass

async def test_listening_bytes_path(self, socket_path: Path) -> None:
async with await create_unix_listener(str(socket_path).encode()):
pass

@pytest.mark.skipif(
platform.system() == "Darwin", reason="macOS requires valid UTF-8 paths"
)
async def test_listening_invalid_ascii(self, socket_path: Path) -> None:
real_path = str(socket_path).encode() + b"\xF0"
async with await create_unix_listener(real_path):
pass


async def test_multi_listener(tmp_path_factory: TempPathFactory) -> None:
async def handle(stream: SocketStream) -> None:
Expand Down Expand Up @@ -1495,6 +1525,20 @@ async def test_send_after_close(self, socket_path: Path) -> None:
with pytest.raises(ClosedResourceError):
await unix_dg.sendto(b"foo", path)

async def test_local_path_bytes(self, socket_path: Path) -> None:
async with await create_unix_datagram_socket(
local_path=str(socket_path).encode()
):
pass

@pytest.mark.skipif(
platform.system() == "Darwin", reason="macOS requires valid UTF-8 paths"
)
async def test_local_path_invalid_ascii(self, socket_path: Path) -> None:
real_path = str(socket_path).encode() + b"\xF0"
async with await create_unix_datagram_socket(local_path=real_path):
pass


@pytest.mark.skipif(
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
Expand Down