diff --git a/.ci/appveyor.yml b/.ci/appveyor.yml
deleted file mode 100644
index 7a2ec05d..00000000
--- a/.ci/appveyor.yml
+++ /dev/null
@@ -1,58 +0,0 @@
-services:
- - postgresql95
-
-environment:
- global:
- PGINSTALLATION: C:\\Program Files\\PostgreSQL\\9.5\\bin
- S3_UPLOAD_USERNAME: oss-ci-bot
- S3_UPLOAD_BUCKET: magicstack-oss-releases
- S3_UPLOAD_ACCESSKEY:
- secure: 1vmOqSXq5zDN8UdezZ3H4l0A9LUJiTr7Wuy9whCdffE=
- S3_UPLOAD_SECRET:
- secure: XudOvV6WtY9yRoqKahXMswFth8SF1UTnSXws4UBjeqzQUjOx2V2VRvIdpPfiqUKt
-
- matrix:
- - PYTHON: "C:\\Python35\\python.exe"
- - PYTHON: "C:\\Python35-x64\\python.exe"
- - PYTHON: "C:\\Python36\\python.exe"
- - PYTHON: "C:\\Python36-x64\\python.exe"
- - PYTHON: "C:\\Python37\\python.exe"
- - PYTHON: "C:\\Python37-x64\\python.exe"
-
-branches:
- # Avoid building PR branches.
- only:
- - master
- - ci
- - releases
-
-install:
- - git submodule update --init --recursive
- - "%PYTHON% -m pip install --upgrade pip wheel setuptools"
-
-build_script:
- - "%PYTHON% setup.py build_ext --inplace --cython-always"
-
-test_script:
- - "%PYTHON% setup.py --verbose test"
-
-after_test:
- - "%PYTHON% setup.py bdist_wheel"
-
-artifacts:
- - path: dist\*
-
-deploy_script:
- - ps: |
- if ($env:appveyor_repo_branch -eq 'releases') {
- & "$env:PYTHON" -m pip install -U -r ".ci/requirements-publish.txt"
- $PACKAGE_VERSION = & "$env:PYTHON" ".ci/package-version.py"
- $PYPI_VERSION = & "$env:PYTHON" ".ci/pypi-check.py" "asyncpg"
-
- if ($PACKAGE_VERSION -eq $PYPI_VERSION) {
- Write-Error "asyncpg-$PACKAGE_VERSION is already published on PyPI"
- exit 1
- }
-
- & "$env:PYTHON" ".ci/s3-upload.py" dist\*.whl
- }
diff --git a/.ci/build-manylinux-wheels.sh b/.ci/build-manylinux-wheels.sh
deleted file mode 100755
index 87496685..00000000
--- a/.ci/build-manylinux-wheels.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-set -e -x
-
-# Compile wheels
-PYTHON="/opt/python/${PYTHON_VERSION}/bin/python"
-PIP="/opt/python/${PYTHON_VERSION}/bin/pip"
-${PIP} install --upgrade setuptools pip wheel~=0.31.1
-cd /io
-make clean
-${PYTHON} setup.py bdist_wheel
-
-# Bundle external shared libraries into the wheels.
-for whl in /io/dist/*.whl; do
- auditwheel repair $whl -w /io/dist/
- rm /io/dist/*-linux_*.whl
-done
-
-${PIP} install ${PYMODULE}[test] -f "file:///io/dist"
-
-# Grab docker host, where Postgres should be running.
-export PGHOST=$(ip route | awk '/default/ { print $3 }' | uniq)
-export PGUSER="postgres"
-
-rm -rf /io/tests/__pycache__
-make -C /io PYTHON="${PYTHON}" testinstalled
-rm -rf /io/tests/__pycache__
diff --git a/.ci/package-version.py b/.ci/package-version.py
deleted file mode 100755
index 59d864fe..00000000
--- a/.ci/package-version.py
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/usr/bin/env python3
-
-
-import os.path
-import sys
-
-
-def main():
- version_file = os.path.join(
- os.path.dirname(os.path.dirname(__file__)), 'asyncpg', '__init__.py')
-
- with open(version_file, 'r') as f:
- for line in f:
- if line.startswith('__version__ ='):
- _, _, version = line.partition('=')
- print(version.strip(" \n'\""))
- return 0
-
- print('could not find package version in asyncpg/__init__.py',
- file=sys.stderr)
- return 1
-
-
-if __name__ == '__main__':
- sys.exit(main())
diff --git a/.ci/push_key.enc b/.ci/push_key.enc
deleted file mode 100644
index ae261920..00000000
Binary files a/.ci/push_key.enc and /dev/null differ
diff --git a/.ci/pypi-check.py b/.ci/pypi-check.py
deleted file mode 100755
index 1b9c11c4..00000000
--- a/.ci/pypi-check.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#!/usr/bin/env python3
-
-
-import argparse
-import sys
-import xmlrpc.client
-
-
-def main():
- parser = argparse.ArgumentParser(description='PyPI package checker')
- parser.add_argument('package_name', metavar='PACKAGE-NAME')
-
- parser.add_argument(
- '--pypi-index-url',
- help=('PyPI index URL.'),
- default='https://pypi.python.org/pypi')
-
- args = parser.parse_args()
-
- pypi = xmlrpc.client.ServerProxy(args.pypi_index_url)
- releases = pypi.package_releases(args.package_name)
-
- if releases:
- print(next(iter(sorted(releases, reverse=True))))
-
- return 0
-
-
-if __name__ == '__main__':
- sys.exit(main())
diff --git a/.ci/requirements-publish.txt b/.ci/requirements-publish.txt
deleted file mode 100644
index 403ef596..00000000
--- a/.ci/requirements-publish.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-tinys3
-twine
diff --git a/.ci/s3-download-release.py b/.ci/s3-download-release.py
deleted file mode 100755
index 223f7f17..00000000
--- a/.ci/s3-download-release.py
+++ /dev/null
@@ -1,74 +0,0 @@
-#!/usr/bin/env python3
-
-
-import argparse
-import os
-import os.path
-import sys
-import urllib.request
-
-import tinys3
-
-
-def main():
- parser = argparse.ArgumentParser(description='S3 File Uploader')
- parser.add_argument(
- '--s3-bucket',
- help=('S3 bucket name (defaults to $S3_UPLOAD_BUCKET)'),
- default=os.environ.get('S3_UPLOAD_BUCKET'))
- parser.add_argument(
- '--s3-region',
- help=('S3 region (defaults to $S3_UPLOAD_REGION)'),
- default=os.environ.get('S3_UPLOAD_REGION'))
- parser.add_argument(
- '--s3-username',
- help=('S3 username (defaults to $S3_UPLOAD_USERNAME)'),
- default=os.environ.get('S3_UPLOAD_USERNAME'))
- parser.add_argument(
- '--s3-key',
- help=('S3 access key (defaults to $S3_UPLOAD_ACCESSKEY)'),
- default=os.environ.get('S3_UPLOAD_ACCESSKEY'))
- parser.add_argument(
- '--s3-secret',
- help=('S3 secret (defaults to $S3_UPLOAD_SECRET)'),
- default=os.environ.get('S3_UPLOAD_SECRET'))
- parser.add_argument(
- '--destdir',
- help='Destination directory.')
- parser.add_argument(
- 'package', metavar='PACKAGE',
- help='Package name and version to download.')
-
- args = parser.parse_args()
-
- if args.s3_region:
- endpoint = 's3-{}.amazonaws.com'.format(args.s3_region.lower())
- else:
- endpoint = 's3.amazonaws.com'
-
- conn = tinys3.Connection(
- access_key=args.s3_key,
- secret_key=args.s3_secret,
- default_bucket=args.s3_bucket,
- tls=True,
- endpoint=endpoint,
- )
-
- files = []
-
- for entry in conn.list(args.package):
- files.append(entry['key'])
-
- destdir = args.destdir or os.getpwd()
-
- for file in files:
- print('Downloading {}...'.format(file))
- url = 'https://{}/{}/{}'.format(endpoint, args.s3_bucket, file)
- target = os.path.join(destdir, file)
- urllib.request.urlretrieve(url, target)
-
- return 0
-
-
-if __name__ == '__main__':
- sys.exit(main())
diff --git a/.ci/s3-upload.py b/.ci/s3-upload.py
deleted file mode 100755
index 92479afe..00000000
--- a/.ci/s3-upload.py
+++ /dev/null
@@ -1,62 +0,0 @@
-#!/usr/bin/env python3
-
-
-import argparse
-import glob
-import os
-import os.path
-import sys
-
-import tinys3
-
-
-def main():
- parser = argparse.ArgumentParser(description='S3 File Uploader')
- parser.add_argument(
- '--s3-bucket',
- help=('S3 bucket name (defaults to $S3_UPLOAD_BUCKET)'),
- default=os.environ.get('S3_UPLOAD_BUCKET'))
- parser.add_argument(
- '--s3-region',
- help=('S3 region (defaults to $S3_UPLOAD_REGION)'),
- default=os.environ.get('S3_UPLOAD_REGION'))
- parser.add_argument(
- '--s3-username',
- help=('S3 username (defaults to $S3_UPLOAD_USERNAME)'),
- default=os.environ.get('S3_UPLOAD_USERNAME'))
- parser.add_argument(
- '--s3-key',
- help=('S3 access key (defaults to $S3_UPLOAD_ACCESSKEY)'),
- default=os.environ.get('S3_UPLOAD_ACCESSKEY'))
- parser.add_argument(
- '--s3-secret',
- help=('S3 secret (defaults to $S3_UPLOAD_SECRET)'),
- default=os.environ.get('S3_UPLOAD_SECRET'))
- parser.add_argument(
- 'files', nargs='+', metavar='FILE', help='Files to upload')
-
- args = parser.parse_args()
-
- if args.s3_region:
- endpoint = 's3-{}.amazonaws.com'.format(args.s3_region.lower())
- else:
- endpoint = 's3.amazonaws.com'
-
- conn = tinys3.Connection(
- access_key=args.s3_key,
- secret_key=args.s3_secret,
- default_bucket=args.s3_bucket,
- tls=True,
- endpoint=endpoint,
- )
-
- for pattern in args.files:
- for fn in glob.iglob(pattern):
- with open(fn, 'rb') as f:
- conn.upload(os.path.basename(fn), f)
-
- return 0
-
-
-if __name__ == '__main__':
- sys.exit(main())
diff --git a/.ci/travis-before-install.sh b/.ci/travis-before-install.sh
deleted file mode 100755
index 3dd6b8a8..00000000
--- a/.ci/travis-before-install.sh
+++ /dev/null
@@ -1,51 +0,0 @@
-#!/bin/bash
-
-set -e -x
-
-if [ -z "${PGVERSION}" ]; then
- echo "Missing PGVERSION environment variable."
- exit 1
-fi
-
-if [[ "${TRAVIS_OS_NAME}" == "linux" && "${BUILD}" == *wheels* ]]; then
- sudo service postgresql stop ${PGVERSION}
-
- echo "port = 5432" | \
- sudo tee --append /etc/postgresql/${PGVERSION}/main/postgresql.conf
-
- if [[ "${BUILD}" == *wheels* ]]; then
- # Allow docker guests to connect to the database
- echo "listen_addresses = '*'" | \
- sudo tee --append /etc/postgresql/${PGVERSION}/main/postgresql.conf
- echo "host all all 172.17.0.0/16 trust" | \
- sudo tee --append /etc/postgresql/${PGVERSION}/main/pg_hba.conf
-
- if [ "${PGVERSION}" -ge "11" ]; then
- # Disable JIT to avoid unpredictable timings in tests.
- echo "jit = off" | \
- sudo tee --append /etc/postgresql/${PGVERSION}/main/postgresql.conf
- fi
- fi
-
- sudo service postgresql start ${PGVERSION}
-fi
-
-if [ "${TRAVIS_OS_NAME}" == "osx" ]; then
- brew update >/dev/null
- brew upgrade pyenv
- eval "$(pyenv init -)"
-
- if ! (pyenv versions | grep "${PYTHON_VERSION}$"); then
- pyenv install ${PYTHON_VERSION}
- fi
- pyenv global ${PYTHON_VERSION}
- pyenv rehash
-
- # Install PostgreSQL
- if brew ls --versions postgresql > /dev/null; then
- brew remove --force --ignore-dependencies postgresql
- fi
-
- brew install postgresql@${PGVERSION}
- brew services start postgresql@${PGVERSION}
-fi
diff --git a/.ci/travis-build-docs.sh b/.ci/travis-build-docs.sh
deleted file mode 100755
index 1716e330..00000000
--- a/.ci/travis-build-docs.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/bash
-
-set -e -x
-
-if [[ "${BUILD}" != *docs* ]]; then
- echo "Skipping documentation build."
- exit 0
-fi
-
-pip install -U -e .[docs]
-make htmldocs SPHINXOPTS="-q -W -j4"
diff --git a/.ci/travis-build-wheels.sh b/.ci/travis-build-wheels.sh
deleted file mode 100755
index 751a22ae..00000000
--- a/.ci/travis-build-wheels.sh
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/bin/bash
-
-set -e -x
-
-
-if [[ "${TRAVIS_BRANCH}" != "releases" || "${BUILD}" != *wheels* ]]; then
- # Not a release
- exit 0
-fi
-
-
-if [ "${TRAVIS_OS_NAME}" == "osx" ]; then
- PYENV_ROOT="$HOME/.pyenv"
- PATH="$PYENV_ROOT/bin:$PATH"
- eval "$(pyenv init -)"
-fi
-
-PACKAGE_VERSION=$(python ".ci/package-version.py")
-PYPI_VERSION=$(python ".ci/pypi-check.py" "${PYMODULE}")
-
-if [ "${PACKAGE_VERSION}" == "${PYPI_VERSION}" ]; then
- echo "${PYMODULE}-${PACKAGE_VERSION} is already published on PyPI"
- exit 1
-fi
-
-
-_root="${TRAVIS_BUILD_DIR}"
-
-
-_upload_wheels() {
- python "${_root}/.ci/s3-upload.py" "${_root}/dist"/*.whl
- sudo rm -rf "${_root}/dist"/*.whl
-}
-
-
-pip install -U -r ".ci/requirements-publish.txt"
-
-
-if [ "${TRAVIS_OS_NAME}" == "linux" ]; then
- for pyver in ${RELEASE_PYTHON_VERSIONS}; do
- ML_PYTHON_VERSION=$(python3 -c \
- "print('cp{maj}{min}-cp{maj}{min}m'.format( \
- maj='${pyver}'.split('.')[0], \
- min='${pyver}'.split('.')[1]))")
-
- for arch in x86_64 i686; do
- ML_IMAGE="quay.io/pypa/manylinux1_${arch}"
- docker pull "${ML_IMAGE}"
- docker run --rm \
- -v "${_root}":/io \
- -e "PYMODULE=${PYMODULE}" \
- -e "PYTHON_VERSION=${ML_PYTHON_VERSION}" \
- -e "ASYNCPG_VERSION=${PACKAGE_VERSION}" \
- "${ML_IMAGE}" /io/.ci/build-manylinux-wheels.sh
-
- _upload_wheels
- done
- done
-
-elif [ "${TRAVIS_OS_NAME}" == "osx" ]; then
- export PGINSTALLATION="/usr/local/opt/postgresql@${PGVERSION}/bin"
-
- make clean
- python setup.py bdist_wheel
-
- pip install ${PYMODULE}[test] -f "file:///${_root}/dist"
- make -C "${_root}" ASYNCPG_VERSION="${PACKAGE_VERSION}" testinstalled
-
- _upload_wheels
-
-else
- echo "Cannot build on ${TRAVIS_OS_NAME}."
-fi
diff --git a/.ci/travis-install.sh b/.ci/travis-install.sh
deleted file mode 100755
index b5124eb8..00000000
--- a/.ci/travis-install.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/bash
-
-set -e -x
-
-if [ "${TRAVIS_OS_NAME}" == "osx" ]; then
- PYENV_ROOT="$HOME/.pyenv"
- PATH="$PYENV_ROOT/bin:$PATH"
- eval "$(pyenv init -)"
-fi
-
-pip install --upgrade setuptools pip wheel
-pip download --dest=/tmp/deps .[test]
-pip install -U --no-index --find-links=/tmp/deps /tmp/deps/*
diff --git a/.ci/travis-publish-docs.sh b/.ci/travis-publish-docs.sh
deleted file mode 100755
index 95e55c79..00000000
--- a/.ci/travis-publish-docs.sh
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/bin/bash
-
-# Based on https://gist.github.com/domenic/ec8b0fc8ab45f39403dd
-
-set -e -x
-
-SOURCE_BRANCH="master"
-TARGET_BRANCH="gh-pages"
-DOC_BUILD_DIR="_build/html/"
-
-if [ "${TRAVIS_PULL_REQUEST}" != "false" ]; then
- echo "Skipping documentation deploy."
- exit 0
-fi
-
-pip install -U .[dev]
-make htmldocs
-
-git config --global user.email "infra@magic.io"
-git config --global user.name "Travis CI"
-
-PACKAGE_VERSION=$(python ".ci/package-version.py")
-REPO=$(git config remote.origin.url)
-SSH_REPO=${REPO/https:\/\/github.com\//git@github.com:}
-COMMITISH=$(git rev-parse --verify HEAD)
-AUTHOR=$(git show --quiet --format="%aN <%aE>" "${COMMITISH}")
-
-git clone "${REPO}" docs/gh-pages
-cd docs/gh-pages
-git checkout "${TARGET_BRANCH}" || git checkout --orphan "${TARGET_BRANCH}"
-cd ..
-
-if [[ ${PACKAGE_VERSION} = *"dev"* ]]; then
- VERSION="devel"
-else
- VERSION="current"
-fi
-
-rm -r "gh-pages/${VERSION}/"
-rsync -a "${DOC_BUILD_DIR}/" "gh-pages/${VERSION}/"
-
-cd gh-pages
-
-if git diff --quiet --exit-code; then
- echo "No changes to documentation."
- exit 0
-fi
-
-git add --all .
-git commit -m "Automatic documentation update" --author="${AUTHOR}"
-
-set +x
-echo "Decrypting push key..."
-ENCRYPTED_KEY_VAR="encrypted_${DOCS_PUSH_KEY_LABEL}_key"
-ENCRYPTED_IV_VAR="encrypted_${DOCS_PUSH_KEY_LABEL}_iv"
-ENCRYPTED_KEY=${!ENCRYPTED_KEY_VAR}
-ENCRYPTED_IV=${!ENCRYPTED_IV_VAR}
-openssl aes-256-cbc -K "${ENCRYPTED_KEY}" -iv "${ENCRYPTED_IV}" \
- -in "${TRAVIS_BUILD_DIR}/.ci/push_key.enc" \
- -out "${TRAVIS_BUILD_DIR}/.ci/push_key" -d
-set -x
-chmod 600 "${TRAVIS_BUILD_DIR}/.ci/push_key"
-eval `ssh-agent -s`
-ssh-add "${TRAVIS_BUILD_DIR}/.ci/push_key"
-
-git push "${SSH_REPO}" "${TARGET_BRANCH}"
-rm "${TRAVIS_BUILD_DIR}/.ci/push_key"
-
-cd "${TRAVIS_BUILD_DIR}"
-rm -rf docs/gh-pages
diff --git a/.ci/travis-release.sh b/.ci/travis-release.sh
deleted file mode 100755
index c9c1e936..00000000
--- a/.ci/travis-release.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/bin/bash
-
-set -e -x
-
-if [ -z "${TRAVIS_TAG}" ]; then
- # Not a tagged commit.
- exit 0
-fi
-
-pip install -U -r ".ci/requirements-publish.txt"
-
-PACKAGE_VERSION=$(python ".ci/package-version.py")
-PYPI_VERSION=$(python ".ci/pypi-check.py" "${PYMODULE}")
-
-if [ "${PACKAGE_VERSION}" == "${PYPI_VERSION}" ]; then
- echo "${PYMODULE}-${PACKAGE_VERSION} is already published on PyPI"
- exit 0
-fi
-
-# Check if all expected wheels have been built and uploaded.
-release_platforms=(
- "macosx_10_??_x86_64"
- "manylinux1_i686"
- "manylinux1_x86_64"
- "win32"
- "win_amd64"
-)
-
-P="${PYMODULE}-${PACKAGE_VERSION}"
-expected_wheels=()
-
-for pyver in ${RELEASE_PYTHON_VERSIONS}; do
- pyver="${pyver//./}"
- abitag="cp${pyver}-cp${pyver}m"
- for plat in "${release_platforms[@]}"; do
- expected_wheels+=("${P}-${abitag}-${plat}.whl")
- done
-done
-
-rm -rf dist/*.whl dist/*.tar.*
-python setup.py sdist
-python ".ci/s3-download-release.py" --destdir=dist/ "${P}"
-
-_file_exists() { [[ -f $1 ]]; }
-
-for distfile in "${expected_wheels[@]}"; do
- if ! _file_exists dist/${distfile}; then
- echo "Expected wheel ${distfile} not found."
- exit 1
- fi
-done
-
-python -m twine upload dist/*.whl dist/*.tar.*
diff --git a/.ci/travis-tests.sh b/.ci/travis-tests.sh
deleted file mode 100755
index 397616c5..00000000
--- a/.ci/travis-tests.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/bin/bash
-
-set -e -x
-
-if [[ "${BUILD}" != *tests* ]]; then
- echo "Skipping tests."
- exit 0
-fi
-
-if [ "${TRAVIS_OS_NAME}" == "osx" ]; then
- PYENV_ROOT="$HOME/.pyenv"
- PATH="$PYENV_ROOT/bin:$PATH"
- eval "$(pyenv init -)"
-fi
-
-# Make sure we test with the correct PostgreSQL version.
-if [ "${TRAVIS_OS_NAME}" == "osx" ]; then
- export PGINSTALLATION="/usr/local/opt/postgresql@${PGVERSION}/bin"
-else
- export PGINSTALLATION="/usr/lib/postgresql/${PGVERSION}/bin"
-fi
-
-if [[ "${BUILD}" == *quicktests* ]]; then
- make && make quicktest
-else
- make && make test
- make clean && make debug && make test
-fi
diff --git a/.clang-format b/.clang-format
new file mode 100644
index 00000000..b2bb93db
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,17 @@
+# A clang-format style that approximates Python's PEP 7
+BasedOnStyle: Google
+AlwaysBreakAfterReturnType: All
+AllowShortIfStatementsOnASingleLine: false
+AlignAfterOpenBracket: Align
+BreakBeforeBraces: Stroustrup
+ColumnLimit: 95
+DerivePointerAlignment: false
+IndentWidth: 4
+Language: Cpp
+PointerAlignment: Right
+ReflowComments: true
+SpaceBeforeParens: ControlStatements
+SpacesInParentheses: false
+TabWidth: 4
+UseTab: Never
+SortIncludes: false
diff --git a/.clangd b/.clangd
new file mode 100644
index 00000000..6c88d686
--- /dev/null
+++ b/.clangd
@@ -0,0 +1,4 @@
+Diagnostics:
+ Includes:
+ IgnoreHeader:
+ - "pythoncapi_compat.*\\.h"
diff --git a/.coveragerc b/.coveragerc
deleted file mode 100644
index 081835d3..00000000
--- a/.coveragerc
+++ /dev/null
@@ -1,12 +0,0 @@
-[run]
-branch = True
-plugins = Cython.Coverage
-source =
- asyncpg/
- tests/
-omit =
- *.pxd
-
-[paths]
-source =
- asyncpg
diff --git a/.flake8 b/.flake8
index bfc97a81..d4e76b7a 100644
--- a/.flake8
+++ b/.flake8
@@ -1,3 +1,5 @@
[flake8]
-ignore = E402,E731
-exclude = .git,__pycache__,build,dist,.eggs,.github,.local
+select = C90,E,F,W,Y0
+ignore = E402,E731,W503,W504,E252
+exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox
+per-file-ignores = *.pyi: F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504
diff --git a/.github/release_log.py b/.github/release_log.py
index 0e3ee7f4..717cd6f6 100755
--- a/.github/release_log.py
+++ b/.github/release_log.py
@@ -45,10 +45,7 @@ def main():
print(f'* {first_line}')
print(f' (by {username} in {sha}', end='')
- if issue_num:
- print(f' for #{issue_num})')
- else:
- print(')')
+ print(')')
print()
diff --git a/.github/workflows/install-krb5.sh b/.github/workflows/install-krb5.sh
new file mode 100755
index 00000000..bdb5744d
--- /dev/null
+++ b/.github/workflows/install-krb5.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+set -Eexuo pipefail
+shopt -s nullglob
+
+if [[ $OSTYPE == linux* ]]; then
+ if [ "$(id -u)" = "0" ]; then
+ SUDO=
+ else
+ SUDO=sudo
+ fi
+
+ if [ -e /etc/os-release ]; then
+ source /etc/os-release
+ elif [ -e /etc/centos-release ]; then
+ ID="centos"
+ VERSION_ID=$(cat /etc/centos-release | cut -f3 -d' ' | cut -f1 -d.)
+ else
+ echo "install-krb5.sh: cannot determine which Linux distro this is" >&2
+ exit 1
+ fi
+
+ if [ "${ID}" = "debian" -o "${ID}" = "ubuntu" ]; then
+ export DEBIAN_FRONTEND=noninteractive
+
+ $SUDO apt-get update
+ $SUDO apt-get install -y --no-install-recommends \
+ libkrb5-dev krb5-user krb5-kdc krb5-admin-server
+ elif [ "${ID}" = "almalinux" ]; then
+ $SUDO dnf install -y krb5-server krb5-workstation krb5-libs krb5-devel
+ elif [ "${ID}" = "centos" ]; then
+ $SUDO yum install -y krb5-server krb5-workstation krb5-libs krb5-devel
+ elif [ "${ID}" = "alpine" ]; then
+ $SUDO apk add krb5 krb5-server krb5-dev
+ else
+        echo "install-krb5.sh: unsupported Linux distro: ${ID}" >&2
+ exit 1
+ fi
+else
+ echo "install-krb5.sh: unsupported OS: ${OSTYPE}" >&2
+ exit 1
+fi
diff --git a/.github/workflows/install-postgres.sh b/.github/workflows/install-postgres.sh
new file mode 100755
index 00000000..733c7033
--- /dev/null
+++ b/.github/workflows/install-postgres.sh
@@ -0,0 +1,62 @@
+#!/bin/bash
+
+set -Eexuo pipefail
+shopt -s nullglob
+
+if [[ $OSTYPE == linux* ]]; then
+ PGVERSION=${PGVERSION:-12}
+
+ if [ -e /etc/os-release ]; then
+ source /etc/os-release
+ elif [ -e /etc/centos-release ]; then
+ ID="centos"
+ VERSION_ID=$(cat /etc/centos-release | cut -f3 -d' ' | cut -f1 -d.)
+ else
+ echo "install-postgres.sh: cannot determine which Linux distro this is" >&2
+ exit 1
+ fi
+
+ if [ "${ID}" = "debian" -o "${ID}" = "ubuntu" ]; then
+ export DEBIAN_FRONTEND=noninteractive
+
+ apt-get install -y --no-install-recommends curl gnupg ca-certificates
+ curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
+ mkdir -p /etc/apt/sources.list.d/
+ echo "deb https://apt.postgresql.org/pub/repos/apt/ ${VERSION_CODENAME}-pgdg main" \
+ >> /etc/apt/sources.list.d/pgdg.list
+ apt-get update
+ apt-get install -y --no-install-recommends \
+ "postgresql-${PGVERSION}" \
+ "postgresql-contrib-${PGVERSION}"
+ elif [ "${ID}" = "almalinux" ]; then
+ yum install -y \
+ "postgresql-server" \
+ "postgresql-devel" \
+ "postgresql-contrib"
+ elif [ "${ID}" = "centos" ]; then
+ el="EL-${VERSION_ID%.*}-$(arch)"
+ baseurl="https://download.postgresql.org/pub/repos/yum/reporpms"
+ yum install -y "${baseurl}/${el}/pgdg-redhat-repo-latest.noarch.rpm"
+ if [ ${VERSION_ID%.*} -ge 8 ]; then
+ dnf -qy module disable postgresql
+ fi
+ yum install -y \
+ "postgresql${PGVERSION}-server" \
+ "postgresql${PGVERSION}-contrib"
+ ln -s "/usr/pgsql-${PGVERSION}/bin/pg_config" "/usr/local/bin/pg_config"
+ elif [ "${ID}" = "alpine" ]; then
+ apk add shadow postgresql postgresql-dev postgresql-contrib
+ else
+        echo "install-postgres.sh: unsupported Linux distro: ${ID}" >&2
+ exit 1
+ fi
+
+ useradd -m -s /bin/bash apgtest
+
+elif [[ $OSTYPE == darwin* ]]; then
+ brew install postgresql
+
+else
+ echo "install-postgres.sh: unsupported OS: ${OSTYPE}" >&2
+ exit 1
+fi
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 00000000..353ed824
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,253 @@
+name: Release
+
+on:
+ pull_request:
+ branches:
+ - "master"
+ - "ci"
+ - "[0-9]+.[0-9x]+*"
+ paths:
+ - "asyncpg/_version.py"
+
+jobs:
+ validate-release-request:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Validate release PR
+ uses: edgedb/action-release/validate-pr@master
+ id: checkver
+ with:
+ require_team: Release Managers
+ require_approval: no
+ github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
+ version_file: asyncpg/_version.py
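+        # Extract the release version from asyncpg/_version.py; the
+        # [[:PEP440:]] token in the pattern below stands for a PEP 440
+        # version string.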
+ version_line_pattern: |
+ __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
+
+ - name: Stop if not approved
+ if: steps.checkver.outputs.approved != 'true'
+ run: |
+ echo ::error::PR is not approved yet.
+ exit 1
+
+ - name: Store release version for later use
+ env:
+ VERSION: ${{ steps.checkver.outputs.version }}
+ run: |
+ mkdir -p dist/
+ echo "${VERSION}" > dist/VERSION
+
+ - uses: actions/upload-artifact@v4
+ with:
+ name: dist-version
+ path: dist/VERSION
+
+ build-sdist:
+ needs: validate-release-request
+ runs-on: ubuntu-latest
+
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: 1
+
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 50
+ submodules: true
+ persist-credentials: false
+
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: "3.x"
+
+ - name: Build source distribution
+ run: |
+ pip install -U setuptools wheel pip
+ python setup.py sdist
+
+ - uses: actions/upload-artifact@v4
+ with:
+ name: dist-sdist
+ path: dist/*.tar.*
+
+ build-wheels-matrix:
+ needs: validate-release-request
+ runs-on: ubuntu-latest
+ outputs:
+ include: ${{ steps.set-matrix.outputs.include }}
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ persist-credentials: false
+ - uses: actions/setup-python@v6
+ with:
+ python-version: "3.x"
+ - run: pip install cibuildwheel==3.3.0
+ - id: set-matrix
+ run: |
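+        # Ask cibuildwheel for the CPython build identifiers on each platform
+        # and emit a JSON array of {only, os} objects for the build-wheels matrix.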
+ MATRIX_INCLUDE=$(
+ {
+ cibuildwheel --print-build-identifiers --platform linux --archs x86_64,aarch64 | grep cp | jq -nRc '{"only": inputs, "os": "ubuntu-latest"}' \
+ && cibuildwheel --print-build-identifiers --platform macos --archs x86_64,arm64 | grep cp | jq -nRc '{"only": inputs, "os": "macos-latest"}' \
+ && cibuildwheel --print-build-identifiers --platform windows --archs x86,AMD64 | grep cp | jq -nRc '{"only": inputs, "os": "windows-latest"}'
+ } | jq -sc
+ )
+ echo "include=$MATRIX_INCLUDE" >> $GITHUB_OUTPUT
+
+ build-wheels:
+ needs: build-wheels-matrix
+ runs-on: ${{ matrix.os }}
+ name: Build ${{ matrix.only }}
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include: ${{ fromJson(needs.build-wheels-matrix.outputs.include) }}
+
+ defaults:
+ run:
+ shell: bash
+
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: 1
+
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 50
+ submodules: true
+ persist-credentials: false
+
+ - name: Set up QEMU
+ if: runner.os == 'Linux'
+ uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0
+
+ - uses: pypa/cibuildwheel@63fd63b352a9a8bdcc24791c9dbee952ee9a8abc # v3.3.0
+ with:
+ only: ${{ matrix.only }}
+ env:
+ CIBW_BUILD_VERBOSITY: 1
+
+ - uses: actions/upload-artifact@v4
+ with:
+ name: dist-wheels-${{ matrix.only }}
+ path: wheelhouse/*.whl
+
+ merge-artifacts:
+ runs-on: ubuntu-latest
+ needs: [build-sdist, build-wheels]
+ steps:
+ - name: Merge Artifacts
+ uses: actions/upload-artifact/merge@v4
+ with:
+ name: dist
+ delete-merged: true
+
+ publish-docs:
+ needs: [build-sdist, build-wheels]
+ runs-on: ubuntu-latest
+
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: 1
+
+ steps:
+ - name: Checkout source
+ uses: actions/checkout@v5
+ with:
+ fetch-depth: 5
+ submodules: true
+ persist-credentials: false
+
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: "3.x"
+
+ - name: Build docs
+ run: |
+ pip install --group docs
+ pip install -e .
+ make htmldocs
+
+ - name: Checkout gh-pages
+ uses: actions/checkout@v5
+ with:
+ fetch-depth: 5
+ ref: gh-pages
+ path: docs/gh-pages
+ persist-credentials: false
+
+ - name: Sync docs
+ run: |
+ rsync -a docs/_build/html/ docs/gh-pages/current/
+
+ - name: Commit and push docs
+ uses: magicstack/gha-commit-and-push@master
+ with:
+ target_branch: gh-pages
+ workdir: docs/gh-pages
+ commit_message: Automatic documentation update
+ github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
+ ssh_key: ${{ secrets.RELEASE_BOT_SSH_KEY }}
+ gpg_key: ${{ secrets.RELEASE_BOT_GPG_KEY }}
+ gpg_key_id: "5C468778062D87BF!"
+
+ publish:
+ needs: [build-sdist, build-wheels, publish-docs]
+ runs-on: ubuntu-latest
+
+ environment:
+ name: pypi
+ url: https://pypi.org/p/asyncpg
+ permissions:
+ id-token: write
+ attestations: write
+ contents: write
+ deployments: write
+
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 5
+ submodules: false
+ persist-credentials: false
+
+ - uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist/
+
+ - name: Extract Release Version
+ id: relver
+ run: |
+ set -e
+ echo "version=$(cat dist/VERSION)" >> $GITHUB_OUTPUT
+ rm dist/VERSION
+
+ - name: Merge and tag the PR
+ uses: edgedb/action-release/merge@master
+ with:
+ github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
+ ssh_key: ${{ secrets.RELEASE_BOT_SSH_KEY }}
+ gpg_key: ${{ secrets.RELEASE_BOT_GPG_KEY }}
+ gpg_key_id: "5C468778062D87BF!"
+ tag_name: v${{ steps.relver.outputs.version }}
+
+ - name: Publish Github Release
+ uses: elprans/gh-action-create-release@master
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ tag_name: v${{ steps.relver.outputs.version }}
+ release_name: v${{ steps.relver.outputs.version }}
+ target: ${{ github.event.pull_request.base.ref }}
+ body: ${{ github.event.pull_request.body }}
+
+ - run: |
+ ls -al dist/
+
+ - name: Upload to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ attestations: true
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 00000000..77e63738
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,158 @@
+name: Tests
+
+on:
+ push:
+ branches:
+ - master
+ - ci
+ pull_request:
+ branches:
+ - master
+
+jobs:
+ test-platforms:
+ # NOTE: this matrix is for testing various combinations of Python and OS
+ # versions on the system-installed PostgreSQL version (which is usually
+ # fairly recent). For a PostgreSQL version matrix see the test-postgres
+ # job.
+ strategy:
+ matrix:
+ python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"]
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ loop: [asyncio, uvloop]
+ exclude:
+      # uvloop does not support Windows
+ - loop: uvloop
+ os: windows-latest
+
+ runs-on: ${{ matrix.os }}
+
+ permissions: {}
+
+ defaults:
+ run:
+ shell: bash
+
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: 1
+
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 50
+ submodules: true
+ persist-credentials: false
+
+ - name: Check if release PR.
+ uses: edgedb/action-release/validate-pr@master
+ id: release
+ with:
+ github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
+ missing_version_ok: yes
+ version_file: asyncpg/_version.py
+ version_line_pattern: |
+ __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
+
+ - name: Setup PostgreSQL
+ if: "!steps.release.outputs.is_release && matrix.os == 'macos-latest'"
+ run: |
+ POSTGRES_FORMULA="postgresql@18"
+ brew install "$POSTGRES_FORMULA"
+ echo "$(brew --prefix "$POSTGRES_FORMULA")/bin" >> $GITHUB_PATH
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v6
+ if: "!steps.release.outputs.is_release"
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install Python Deps
+ if: "!steps.release.outputs.is_release"
+ run: |
+ [ "$RUNNER_OS" = "Linux" ] && .github/workflows/install-krb5.sh
+ python -m pip install -U pip setuptools wheel
+ python -m pip install --group test
+ python -m pip install -e .
+
+ - name: Test
+ if: "!steps.release.outputs.is_release"
+ env:
+ LOOP_IMPL: ${{ matrix.loop }}
+ run: |
+ if [ "${LOOP_IMPL}" = "uvloop" ]; then
+ env USE_UVLOOP=1 python -m unittest -v tests.suite
+ else
+ python -m unittest -v tests.suite
+ fi
+
+ test-postgres:
+ strategy:
+ matrix:
+ postgres-version: ["9.5", "9.6", "10", "11", "12", "13", "14", "15", "16", "17", "18"]
+
+ runs-on: ubuntu-latest
+
+ permissions: {}
+
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: 1
+
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 50
+ submodules: true
+ persist-credentials: false
+
+ - name: Check if release PR.
+ uses: edgedb/action-release/validate-pr@master
+ id: release
+ with:
+ github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
+ missing_version_ok: yes
+ version_file: asyncpg/_version.py
+ version_line_pattern: |
+ __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
+
+ - name: Set up PostgreSQL
+ if: "!steps.release.outputs.is_release"
+ env:
+ PGVERSION: ${{ matrix.postgres-version }}
+ DISTRO_NAME: focal
+ run: |
+ sudo env DISTRO_NAME="${DISTRO_NAME}" PGVERSION="${PGVERSION}" \
+ .github/workflows/install-postgres.sh
+ echo PGINSTALLATION="/usr/lib/postgresql/${PGVERSION}/bin" \
+ >> "${GITHUB_ENV}"
+
+      - name: Set up Python
+ uses: actions/setup-python@v6
+ if: "!steps.release.outputs.is_release"
+ with:
+ python-version: "3.x"
+
+ - name: Install Python Deps
+ if: "!steps.release.outputs.is_release"
+ run: |
+ [ "$RUNNER_OS" = "Linux" ] && .github/workflows/install-krb5.sh
+ python -m pip install -U pip setuptools wheel
+ python -m pip install --group test
+ python -m pip install -e .
+
+ - name: Test
+ if: "!steps.release.outputs.is_release"
+ env:
+ PGVERSION: ${{ matrix.postgres-version }}
+ run: |
+ python -m unittest -v tests.suite
+
+ # This job exists solely to act as the test job aggregate to be
+ # targeted by branch policies.
+ regression-tests:
+ name: "Regression Tests"
+ needs: [test-platforms, test-postgres]
+ runs-on: ubuntu-latest
+ permissions: {}
+
+ steps:
+ - run: echo OK
diff --git a/.gitignore b/.gitignore
index 21286094..ec9c96ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,8 @@ docs/_build
/.pytest_cache/
/.eggs
/.vscode
+/.zed
/.mypy_cache
+/.venv*
+/.tox
+/compile_commands.json
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 63ca04d5..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,175 +0,0 @@
-language: generic
-
-env:
- global:
- - PYMODULE=asyncpg
- - RELEASE_PYTHON_VERSIONS="3.5 3.6 3.7"
-
- - S3_UPLOAD_USERNAME=oss-ci-bot
- - S3_UPLOAD_BUCKET=magicstack-oss-releases
- # S3_UPLOAD_ACCESSKEY:
- - secure: "iU37gukuyeaYM69StkR/aUTNgolblBdw2is034evvrm/SG0bKyzVVSrcK/dts9jolkCxJi+01VfpzxIBu2PF11QnCN1exUILb+XfmR+dVxUnNY2M1qqjILHvQ92rFJ9f2TlbYa2AlwgKynZlY4+edVSACSWwD/+TbWGAQEp0WInalA8ohljir+EPueXaYyC8mmH55cNQIa5WdDA2Vpg5ahRDdhVyD2J+/fLg78syLV7FGlnpXtASo9XiQKmRpPyHIT23yQB444kVh9xcjvuiB3aUBP5bGC2H4unElGYhCvfQvb1GoWvDqyvfzZvTOjlHqnG4AvIPoSCgEu/9cu8Cm/9OxWtqtWy7dECM8ZUIlOi3oPcvwUYDpNYAdATbTr1T6FRCBEp2eOi3sKoeE+nUDgQaN4r+ple4BKYnjrsSllXhI5W8ZqDNoUSsoGu+z6GFn6Dszrj6jbq8JHV4mZT9RCfR1y6inXWYGmaNRlwzm8wPHTav2RbW2O6bbwkkATWwYpyRB2FRlwMX6BB06druZWNOzx09RS8pTHnqcKOXW2mENNMgrA03OJUEV30UG/ncLZELYTpBARSJwymxjmmTK7vEI/HfxHkPrKcLLPPn2GoWym7mF2Lkh+jp81FkCGYrLTquyKPaoeUsofYukWMbGwE99ePL5dLocVDqTzatAoU="
- # S3_UPLOAD_SECRET:
- - secure: "uCcM67fmT3ODhaHrUKhuid/DJZzy9aMtaVdRCkAkjaF/X0Dgj1Dq9bV/PPLn+HjVIRiloHKK4fk6eMLmnI/PWPnTD7hVYse2ECLEeTxZGtoBmTdBPzdf1+aHpl18o8xbN/XfO02SVk6pTHDSjZIwR8ZeiZ2Q75aY462BW9WBgV0JOL9xK5yJd3TjODsDJpwcK0P4TMwi1j2qKIpXMUJaZkyUPafZIykil2CbcePd2Y5AUfDN2BAqaJZqM9vVCeRRs7ECzCamBPsj2WUmXqs621IH3Ml/sSONCzeQoUlgUsG2a7b+Jic92sVsFHyLVqG56G5urIAsXm+Jc/8Ly/dTk1M3ER/OdvsB0z21mhQfaVHwROixPk6HPCbvTl3PITEauaU+wLwCIduiEbb6fcpoB11n3oRzgiLY5e4+QDA86LBNySDhBE8WIq1VKphgTp7ojgM/mHJg4VBZX3m+89JruUOLi49VPx1cK/CiWEBj3gWHZMNDL9agS5N/fwl6UnD5DAklTZtqlA5M2FZ8/aPN8/FgW4jTEgBBU87Ko2rTvVRmKZeCVEkIBS2lYsRDTG3ZmlyJuh2AGGReUzCh524pNAsonIF2ydCOzLv4DlTZSthOwbdnX0EMBRYuPEa436dgkVUUVP6ds5859IPZeXcN6JKJWPWQkzFWFwzoK9ttQLc="
-
- - TWINE_USERNAME: magicstack-ci
- # TWINE_PASSWORD:
- - secure: "jyc9xHK3VjGPxvBZKx8Mcf5nfVvfIyGn6b4atcrmwVdJsV1bBLdKoAjUX3RGjNGyAHpNYOEKOdNfeZs+Wziwg5NK7ucC5qybaBK3MOTEOInCzaO0QJpcxThaHBQkkDxVtn8Qu1Gk3S/hXcXWjT2UEYJvQ84diaXn/XYRxfzOYTZX8eUroAWOMnUCYxlPxGzXTAtmuQSiJkL7P7veZTsWsGCOHtCpdAx7dgGb113CD8QheeUoZlH9Ml6jd3fGFteYmuFp7cR6fa3VYVzxp5BFsdEJqSI4VqDvBOpUoLkbpRRKMjosHKtphfi0PAzbkJw6UdKcrqQ/Ca4nGmWk0PIf3LTsJrv44p4ZTPVI8b3lihXMm72QUE28e11yu9SIZRe0hMgmvWlivXEJw3C3YT1N5w+JM3Y5dIWp/YLoiRXVkIzNJQMN3YeWvKEFf/xO1AD2BO3jjU9oBZfKQpxCJ58gPsQrRt6qM3Y6zYuF8s4B+llpwM/ex2xnNwrTbNkp4ARyXyCujX+ixhjiBLtElfGoHPP1jOaIkJhGje9DxaptddfFBDLAdq0/3Q+LHOmwdQcH5+libUy3HnyP7jf51kjjWE3XEJGSchHI2ewEAn9UZRH8h0UNRXutBzUVvKgC6K1lUvqzEreKVxvrYe6zgbZc/DiUvLgIzJBiJgP9rdZYpDQ="
-
- - DOCS_PUSH_KEY_LABEL=0760b951e99c
-
-addons:
- apt:
- sources:
- - sourceline: 'deb https://apt.postgresql.org/pub/repos/apt/ trusty-pgdg main'
- key_url: 'https://www.postgresql.org/media/keys/ACCC4CF8.asc'
-
-branches:
- # Avoid building PR branches.
- only:
- - master
- - ci
- - releases
- - /^v\d+(\.\d+)*$/
-
-matrix:
- fast_finish:
- true
-
- include:
- # Do quick test runs for each supported version of PostgreSQL
- # minus the latest.
- - os: linux
- dist: xenial
- language: python
- python: "3.6"
- env: BUILD=quicktests PGVERSION=9.2
- addons:
- apt: {packages: [postgresql-9.2, postgresql-contrib-9.2]}
-
- - os: linux
- dist: xenial
- language: python
- python: "3.6"
- env: BUILD=quicktests PGVERSION=9.3
- addons:
- apt: {packages: [postgresql-9.3, postgresql-contrib-9.3]}
-
- - os: linux
- dist: xenial
- language: python
- python: "3.6"
- env: BUILD=quicktests PGVERSION=9.4
- addons:
- apt: {packages: [postgresql-9.4, postgresql-contrib-9.4]}
-
- - os: linux
- dist: xenial
- language: python
- python: "3.6"
- env: BUILD=quicktests PGVERSION=9.5
- addons:
- apt: {packages: [postgresql-9.5, postgresql-contrib-9.5]}
-
- - os: linux
- dist: xenial
- language: python
- python: "3.6"
- env: BUILD=quicktests PGVERSION=9.6
- addons:
- apt: {packages: [postgresql-9.6, postgresql-contrib-9.6]}
-
- - os: linux
- dist: xenial
- language: python
- python: "3.6"
- env: BUILD=quicktests PGVERSION=10
- addons:
- apt: {packages: [postgresql-10]}
-
- - os: linux
- dist: xenial
- language: python
- python: "3.6"
- env: BUILD=quicktests PGVERSION=11
- addons:
- apt: {packages: [postgresql-11]}
-
- # Do a full test run on the latest supported version of PostgreSQL
- # on each supported version of Python.
- - os: linux
- dist: xenial
- sudo: required
- language: python
- python: "3.5"
- env: BUILD=tests PGVERSION=12
- addons:
- apt: {packages: [postgresql-12]}
-
- - os: linux
- dist: xenial
- sudo: required
- language: python
- python: "3.6"
- env: BUILD=tests PGVERSION=12
- addons:
- apt: {packages: [postgresql-12]}
-
- - os: linux
- dist: xenial
- sudo: true
- language: python
- python: "3.7"
- env: BUILD=tests PGVERSION=12
- addons:
- apt: {packages: [postgresql-12]}
-
- # Build manylinux wheels. Each wheel will be tested,
- # so there is no need for BUILD=tests here.
- # Also use this job to publish the releases and build
- # the documentation.
- - os: linux
- dist: xenial
- sudo: required
- language: python
- python: "3.6"
- env: BUILD=wheels,docs,release PGVERSION=12
- services: [docker]
- addons:
- apt: {packages: [postgresql-12]}
-
- - os: osx
- env: BUILD=tests,wheels PYTHON_VERSION=3.5.7 PGVERSION=10
-
- - os: osx
- env: BUILD=tests,wheels PYTHON_VERSION=3.6.9 PGVERSION=10
-
- - os: osx
- env: BUILD=tests,wheels PYTHON_VERSION=3.7.4 PGVERSION=10
-
-cache:
- pip
-
-before_install:
- - .ci/travis-before-install.sh
-
-install:
- - .ci/travis-install.sh
-
-script:
- - .ci/travis-tests.sh
- - .ci/travis-build-docs.sh
- - .ci/travis-build-wheels.sh
-
-deploy:
- - provider: script
- script: .ci/travis-release.sh
- on:
- tags: true
- condition: '"${BUILD}" == *release*'
-
- - provider: script
- script: .ci/travis-publish-docs.sh
- on:
- branch: master
- condition: '"${BUILD}" == *docs*'
diff --git a/MANIFEST.in b/MANIFEST.in
index 08be0d4b..3eac0565 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,5 @@
-recursive-include docs *.py *.rst
+recursive-include docs *.py *.rst Makefile *.css
recursive-include examples *.py
recursive-include tests *.py *.pem
-recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h
+recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h
include LICENSE README.rst Makefile performance.png .flake8
diff --git a/Makefile b/Makefile
index 698f011a..67417a3f 100644
--- a/Makefile
+++ b/Makefile
@@ -12,6 +12,7 @@ clean:
rm -fr dist/ doc/_build/
rm -fr asyncpg/pgproto/*.c asyncpg/pgproto/*.html
rm -fr asyncpg/pgproto/codecs/*.html
+ rm -fr asyncpg/pgproto/*.so
rm -fr asyncpg/protocol/*.c asyncpg/protocol/*.html
rm -fr asyncpg/protocol/*.so build *.egg-info
rm -fr asyncpg/protocol/codecs/*.html
@@ -19,28 +20,26 @@ clean:
compile:
- $(PYTHON) setup.py build_ext --inplace --cython-always
+ env ASYNCPG_BUILD_CYTHON_ALWAYS=1 $(PYTHON) -m pip install -e .
debug:
- ASYNCPG_DEBUG=1 $(PYTHON) setup.py build_ext --inplace
-
+ env ASYNCPG_DEBUG=1 $(PYTHON) -m pip install -e .
test:
- PYTHONASYNCIODEBUG=1 $(PYTHON) setup.py test
- $(PYTHON) setup.py test
- USE_UVLOOP=1 $(PYTHON) setup.py test
+ PYTHONASYNCIODEBUG=1 $(PYTHON) -m unittest -v tests.suite
+ $(PYTHON) -m unittest -v tests.suite
+ USE_UVLOOP=1 $(PYTHON) -m unittest -v tests.suite
testinstalled:
- cd /tmp && $(PYTHON) $(ROOT)/tests/__init__.py
- cd /tmp && USE_UVLOOP=1 $(PYTHON) $(ROOT)/tests/__init__.py
+ cd "$${HOME}" && $(PYTHON) $(ROOT)/tests/__init__.py
quicktest:
- $(PYTHON) setup.py test
+ $(PYTHON) -m unittest -v tests.suite
htmldocs:
- $(PYTHON) setup.py build_ext --inplace
+ $(PYTHON) -m pip install -e .[docs]
$(MAKE) -C docs html
diff --git a/README.rst b/README.rst
index fac5744a..1a37296d 100644
--- a/README.rst
+++ b/README.rst
@@ -1,14 +1,11 @@
asyncpg -- A fast PostgreSQL Database Client Library for Python/asyncio
=======================================================================
-.. image:: https://travis-ci.org/MagicStack/asyncpg.svg?branch=master
- :target: https://travis-ci.org/MagicStack/asyncpg
-
-.. image:: https://ci.appveyor.com/api/projects/status/9rwppnxphgc8bqoj/branch/master?svg=true
- :target: https://ci.appveyor.com/project/magicstack/asyncpg
-
+.. image:: https://github.com/MagicStack/asyncpg/workflows/Tests/badge.svg
+ :target: https://github.com/MagicStack/asyncpg/actions?query=workflow%3ATests+branch%3Amaster
+ :alt: GitHub Actions status
.. image:: https://img.shields.io/pypi/v/asyncpg.svg
- :target: https://pypi.python.org/pypi/asyncpg
+ :target: https://pypi.python.org/pypi/asyncpg
**asyncpg** is a database interface library designed specifically for
PostgreSQL and Python/asyncio. asyncpg is an efficient, clean implementation
@@ -16,8 +13,10 @@ of PostgreSQL server binary protocol for use with Python's ``asyncio``
framework. You can read more about asyncpg in an introductory
`blog post `_.
-asyncpg requires Python 3.5 or later and is supported for PostgreSQL
-versions 9.2 to 12.
+asyncpg requires Python 3.9 or later and is supported for PostgreSQL
+versions 9.5 to 18. Other PostgreSQL versions or other databases
+implementing the PostgreSQL protocol *may* work, but are not being
+actively tested.
Documentation
@@ -30,14 +29,14 @@ The project documentation can be found
Performance
-----------
-In our testing asyncpg is, on average, **3x** faster than psycopg2
-(and its asyncio variant -- aiopg).
+In our testing asyncpg is, on average, **5x** faster than psycopg3.
-.. image:: performance.png
- :target: http://magic.io/blog/asyncpg-1m-rows-from-postgres-to-python/
+.. image:: https://raw.githubusercontent.com/MagicStack/asyncpg/master/performance.png?fddca40ab0
+ :target: https://gistpreview.github.io/?0ed296e93523831ea0918d42dd1258c2
The above results are a geometric mean of benchmarks obtained with PostgreSQL
-`client driver benchmarking toolbench `_.
+`client driver benchmarking toolbench `_
+in June 2023 (click on the chart to see full details).
Features
@@ -60,11 +59,18 @@ This enables asyncpg to have easy-to-use support for:
Installation
------------
-asyncpg is available on PyPI and has no dependencies.
-Use pip to install::
+asyncpg is available on PyPI. When not using GSSAPI/SSPI authentication it
+has no dependencies. Use pip to install::
$ pip install asyncpg
+If you need GSSAPI/SSPI authentication, use::
+
+ $ pip install 'asyncpg[gssauth]'
+
+For more details, please `see the documentation
+`_.
+
Basic Usage
-----------
@@ -77,11 +83,13 @@ Basic Usage
async def run():
conn = await asyncpg.connect(user='user', password='password',
database='database', host='127.0.0.1')
- values = await conn.fetch('''SELECT * FROM mytable''')
+ values = await conn.fetch(
+ 'SELECT * FROM mytable WHERE id = $1',
+ 10,
+ )
await conn.close()
- loop = asyncio.get_event_loop()
- loop.run_until_complete(run())
+ asyncio.run(run())
License
diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py
index 791e7959..e8811a9d 100644
--- a/asyncpg/__init__.py
+++ b/asyncpg/__init__.py
@@ -4,31 +4,21 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
-from .pool import create_pool # NOQA
+from .pool import create_pool, Pool # NOQA
from .protocol import Record # NOQA
from .types import * # NOQA
-__all__ = ('connect', 'create_pool', 'Record', 'Connection') + \
- exceptions.__all__ # NOQA
+from ._version import __version__ # NOQA
+
+from . import exceptions
-# The rules of changing __version__:
-#
-# In a release revision, __version__ must be set to 'x.y.z',
-# and the release revision tagged with the 'vx.y.z' tag.
-# For example, asyncpg release 0.15.0 should have
-# __version__ set to '0.15.0', and tagged with 'v0.15.0'.
-#
-# In between releases, __version__ must be set to
-# 'x.y+1.0.dev0', so asyncpg revisions between 0.15.0 and
-# 0.16.0 should have __version__ set to '0.16.0.dev0' in
-# the source.
-#
-# Source and wheel distributions built from development
-# snapshots will automatically include the git revision
-# in __version__, for example: '0.16.0.dev0+ge06ad03'
-__version__ = '0.19.0'
+__all__: tuple[str, ...] = (
+ 'connect', 'create_pool', 'Pool', 'Record', 'Connection'
+)
+__all__ += exceptions.__all__ # NOQA
diff --git a/asyncpg/_asyncio_compat.py b/asyncpg/_asyncio_compat.py
new file mode 100644
index 00000000..a211d0a9
--- /dev/null
+++ b/asyncpg/_asyncio_compat.py
@@ -0,0 +1,94 @@
+# Backports from Python/Lib/asyncio for older Pythons
+#
+# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved
+#
+# SPDX-License-Identifier: PSF-2.0
+
+from __future__ import annotations
+
+import asyncio
+import functools
+import sys
+import typing
+
+if typing.TYPE_CHECKING:
+ from . import compat
+
+if sys.version_info < (3, 11):
+ from async_timeout import timeout as timeout_ctx
+else:
+ from asyncio import timeout as timeout_ctx
+
+_T = typing.TypeVar('_T')
+
+
+async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
+ """Wait for the single Future or coroutine to complete, with timeout.
+
+ Coroutine will be wrapped in Task.
+
+ Returns result of the Future or coroutine. When a timeout occurs,
+ it cancels the task and raises TimeoutError. To avoid the task
+ cancellation, wrap it in shield().
+
+ If the wait is cancelled, the task is also cancelled.
+
+    If the task suppresses the cancellation and returns a value instead,
+ that value is returned.
+
+ This function is a coroutine.
+ """
+ # The special case for timeout <= 0 is for the following case:
+ #
+ # async def test_waitfor():
+ # func_started = False
+ #
+ # async def func():
+ # nonlocal func_started
+ # func_started = True
+ #
+ # try:
+ # await asyncio.wait_for(func(), 0)
+ # except asyncio.TimeoutError:
+ # assert not func_started
+ # else:
+ # assert False
+ #
+ # asyncio.run(test_waitfor())
+
+ if timeout is not None and timeout <= 0:
+ fut = asyncio.ensure_future(fut)
+
+ if fut.done():
+ return fut.result()
+
+ await _cancel_and_wait(fut)
+ try:
+ return fut.result()
+ except asyncio.CancelledError as exc:
+ raise TimeoutError from exc
+
+ async with timeout_ctx(timeout):
+ return await fut
+
+
+async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
+ """Cancel the *fut* future or task and wait until it completes."""
+
+ loop = asyncio.get_running_loop()
+ waiter = loop.create_future()
+ cb = functools.partial(_release_waiter, waiter)
+ fut.add_done_callback(cb)
+
+ try:
+ fut.cancel()
+ # We cannot wait on *fut* directly to make
+ # sure _cancel_and_wait itself is reliably cancellable.
+ await waiter
+ finally:
+ fut.remove_done_callback(cb)
+
+
+def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
+ if not waiter.done():
+ waiter.set_result(None)
diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py
index 2aecb0fe..95775e11 100644
--- a/asyncpg/_testbase/__init__.py
+++ b/asyncpg/_testbase/__init__.py
@@ -19,6 +19,7 @@
import unittest
+import asyncpg
from asyncpg import cluster as pg_cluster
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool
@@ -80,7 +81,7 @@ def wrapper(self, *args, __meth__=meth, **kwargs):
coro = __meth__(self, *args, **kwargs)
timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
if timeout:
- coro = asyncio.wait_for(coro, timeout, loop=self.loop)
+ coro = asyncio.wait_for(coro, timeout)
try:
self.loop.run_until_complete(coro)
except asyncio.TimeoutError:
@@ -116,10 +117,22 @@ def setUp(self):
self.__unhandled_exceptions = []
def tearDown(self):
- if self.__unhandled_exceptions:
+ excs = []
+ for exc in self.__unhandled_exceptions:
+ if isinstance(exc, ConnectionResetError):
+ texc = traceback.TracebackException.from_exception(
+ exc, lookup_lines=False)
+ if texc.stack[-1].name == "_call_connection_lost":
+ # On Windows calling socket.shutdown may raise
+ # ConnectionResetError, which happens in the
+ # finally block of _call_connection_lost.
+ continue
+ excs.append(exc)
+
+ if excs:
formatted = []
- for i, context in enumerate(self.__unhandled_exceptions):
+ for i, context in enumerate(excs):
formatted.append(self._format_loop_exception(context, i + 1))
self.fail(
@@ -213,13 +226,6 @@ def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
return cluster
-def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
- initdb_options=None):
- cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
- cluster.start(port='dynamic', server_settings=server_settings)
- return cluster
-
-
def _get_initdb_options(initdb_options=None):
if not initdb_options:
initdb_options = {}
@@ -243,8 +249,12 @@ def _init_default_cluster(initdb_options=None):
_default_cluster = pg_cluster.RunningCluster()
else:
_default_cluster = _init_cluster(
- pg_cluster.TempCluster, cluster_kwargs={},
- initdb_options=_get_initdb_options(initdb_options))
+ pg_cluster.TempCluster,
+ cluster_kwargs={
+ "data_dir_suffix": ".apgtest",
+ },
+ initdb_options=_get_initdb_options(initdb_options),
+ )
return _default_cluster
@@ -261,19 +271,28 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
+ connect=None,
setup=None,
init=None,
loop=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
+ record_class=asyncpg.Record,
**connect_kwargs):
return pool_class(
dsn,
- min_size=min_size, max_size=max_size,
- max_queries=max_queries, loop=loop, setup=setup, init=init,
+ min_size=min_size,
+ max_size=max_size,
+ max_queries=max_queries,
+ loop=loop,
+ connect=connect,
+ setup=setup,
+ init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
- **connect_kwargs)
+ record_class=record_class,
+ **connect_kwargs,
+ )
class ClusterTestCase(TestCase):
@@ -327,8 +346,10 @@ def tearDownClass(cls):
@classmethod
def get_connection_spec(cls, kwargs={}):
conn_spec = cls.cluster.get_connection_spec()
+ if kwargs.get('dsn'):
+ conn_spec.pop('host')
conn_spec.update(kwargs)
- if not os.environ.get('PGHOST'):
+ if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
if 'database' not in conn_spec:
conn_spec['database'] = 'postgres'
if 'user' not in conn_spec:
@@ -430,3 +451,93 @@ def tearDown(self):
self.con = None
finally:
super().tearDown()
+
+
+class HotStandbyTestCase(ClusterTestCase):
+
+ @classmethod
+ def setup_cluster(cls):
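+        # Bring up a primary cluster with WAL streaming enabled, create a
+        # replication role on it, and then start a hot-standby cluster that
+        # replicates from the primary.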
+ cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
+ cls.start_cluster(
+ cls.master_cluster,
+ server_settings={
+ 'max_wal_senders': 10,
+ 'wal_level': 'hot_standby'
+ }
+ )
+
+ con = None
+
+ try:
+ con = cls.loop.run_until_complete(
+ cls.master_cluster.connect(
+ database='postgres', user='postgres', loop=cls.loop))
+
+ cls.loop.run_until_complete(
+ con.execute('''
+ CREATE ROLE replication WITH LOGIN REPLICATION
+ '''))
+
+ cls.master_cluster.trust_local_replication_by('replication')
+
+ conn_spec = cls.master_cluster.get_connection_spec()
+
+ cls.standby_cluster = cls.new_cluster(
+ pg_cluster.HotStandbyCluster,
+ cluster_kwargs={
+ 'master': conn_spec,
+ 'replication_user': 'replication'
+ }
+ )
+ cls.start_cluster(
+ cls.standby_cluster,
+ server_settings={
+ 'hot_standby': True
+ }
+ )
+
+ finally:
+ if con is not None:
+ cls.loop.run_until_complete(con.close())
+
+ @classmethod
+ def get_cluster_connection_spec(cls, cluster, kwargs={}):
+ conn_spec = cluster.get_connection_spec()
+ if kwargs.get('dsn'):
+ conn_spec.pop('host')
+ conn_spec.update(kwargs)
+ if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
+ if 'database' not in conn_spec:
+ conn_spec['database'] = 'postgres'
+ if 'user' not in conn_spec:
+ conn_spec['user'] = 'postgres'
+ return conn_spec
+
+ @classmethod
+ def get_connection_spec(cls, kwargs={}):
+ primary_spec = cls.get_cluster_connection_spec(
+ cls.master_cluster, kwargs
+ )
+ standby_spec = cls.get_cluster_connection_spec(
+ cls.standby_cluster, kwargs
+ )
+ return {
+ 'host': [primary_spec['host'], standby_spec['host']],
+ 'port': [primary_spec['port'], standby_spec['port']],
+ 'database': primary_spec['database'],
+ 'user': primary_spec['user'],
+ **kwargs
+ }
+
+ @classmethod
+ def connect_primary(cls, **kwargs):
+ conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
+ return pg_connection.connect(**conn_spec, loop=cls.loop)
+
+ @classmethod
+ def connect_standby(cls, **kwargs):
+ conn_spec = cls.get_cluster_connection_spec(
+ cls.standby_cluster,
+ kwargs
+ )
+ return pg_connection.connect(**conn_spec, loop=cls.loop)
diff --git a/asyncpg/_testbase/fuzzer.py b/asyncpg/_testbase/fuzzer.py
index 649e5770..88745646 100644
--- a/asyncpg/_testbase/fuzzer.py
+++ b/asyncpg/_testbase/fuzzer.py
@@ -36,15 +36,13 @@ def __init__(self, *, listening_addr: str='127.0.0.1',
self.listen_task = None
async def _wait(self, work):
- work_task = asyncio.ensure_future(work, loop=self.loop)
- stop_event_task = asyncio.ensure_future(self.stop_event.wait(),
- loop=self.loop)
+ work_task = asyncio.ensure_future(work)
+ stop_event_task = asyncio.ensure_future(self.stop_event.wait())
try:
await asyncio.wait(
[work_task, stop_event_task],
- return_when=asyncio.FIRST_COMPLETED,
- loop=self.loop)
+ return_when=asyncio.FIRST_COMPLETED)
if self.stop_event.is_set():
raise StopServer()
@@ -58,7 +56,8 @@ async def _wait(self, work):
def start(self):
started = threading.Event()
- self.thread = threading.Thread(target=self._start, args=(started,))
+ self.thread = threading.Thread(
+ target=self._start_thread, args=(started,))
self.thread.start()
if not started.wait(timeout=2):
raise RuntimeError('fuzzer proxy failed to start')
@@ -70,13 +69,14 @@ def stop(self):
def _stop(self):
self.stop_event.set()
- def _start(self, started_event):
+ def _start_thread(self, started_event):
self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
- self.connectivity = asyncio.Event(loop=self.loop)
+ self.connectivity = asyncio.Event()
self.connectivity.set()
- self.connectivity_loss = asyncio.Event(loop=self.loop)
- self.stop_event = asyncio.Event(loop=self.loop)
+ self.connectivity_loss = asyncio.Event()
+ self.stop_event = asyncio.Event()
if self.listening_port is None:
self.listening_port = cluster.find_available_port()
@@ -92,7 +92,7 @@ def _start(self, started_event):
self.loop.close()
async def _main(self, started_event):
- self.listen_task = asyncio.ensure_future(self.listen(), loop=self.loop)
+ self.listen_task = asyncio.ensure_future(self.listen())
# Notify the main thread that we are ready to go.
started_event.set()
try:
@@ -100,7 +100,7 @@ async def _main(self, started_event):
finally:
for c in list(self.connections):
c.close()
- await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01)
if hasattr(self.loop, 'remove_reader'):
self.loop.remove_reader(self.sock.fileno())
self.sock.close()
@@ -145,6 +145,10 @@ def _close_connection(self, connection):
if conn_task is not None:
conn_task.cancel()
+ def close_all_connections(self):
+ for conn in list(self.connections):
+ self.loop.call_soon_threadsafe(self._close_connection, conn)
+
class Connection:
def __init__(self, client_sock, backend_sock, proxy):
@@ -176,17 +180,23 @@ def close(self):
async def handle(self):
self.proxy_to_backend_task = asyncio.ensure_future(
- self.proxy_to_backend(), loop=self.loop)
+ self.proxy_to_backend())
self.proxy_from_backend_task = asyncio.ensure_future(
- self.proxy_from_backend(), loop=self.loop)
+ self.proxy_from_backend())
try:
await asyncio.wait(
[self.proxy_to_backend_task, self.proxy_from_backend_task],
- loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
+ return_when=asyncio.FIRST_COMPLETED)
finally:
+ if self.proxy_to_backend_task is not None:
+ self.proxy_to_backend_task.cancel()
+
+ if self.proxy_from_backend_task is not None:
+ self.proxy_from_backend_task.cancel()
+
# Asyncio fails to properly remove the readers and writers
# when the task doing recv() or send() is cancelled, so
# we must remove the readers and writers manually before
@@ -201,49 +211,47 @@ async def handle(self):
async def _read(self, sock, n):
read_task = asyncio.ensure_future(
- self.loop.sock_recv(sock, n),
- loop=self.loop)
+ self.loop.sock_recv(sock, n))
conn_event_task = asyncio.ensure_future(
- self.connectivity_loss.wait(),
- loop=self.loop)
+ self.connectivity_loss.wait())
try:
await asyncio.wait(
[read_task, conn_event_task],
- return_when=asyncio.FIRST_COMPLETED,
- loop=self.loop)
+ return_when=asyncio.FIRST_COMPLETED)
if self.connectivity_loss.is_set():
return None
else:
return read_task.result()
finally:
- if not read_task.done():
- read_task.cancel()
- if not conn_event_task.done():
- conn_event_task.cancel()
+ if not self.loop.is_closed():
+ if not read_task.done():
+ read_task.cancel()
+ if not conn_event_task.done():
+ conn_event_task.cancel()
async def _write(self, sock, data):
write_task = asyncio.ensure_future(
- self.loop.sock_sendall(sock, data), loop=self.loop)
+ self.loop.sock_sendall(sock, data))
conn_event_task = asyncio.ensure_future(
- self.connectivity_loss.wait(), loop=self.loop)
+ self.connectivity_loss.wait())
try:
await asyncio.wait(
[write_task, conn_event_task],
- return_when=asyncio.FIRST_COMPLETED,
- loop=self.loop)
+ return_when=asyncio.FIRST_COMPLETED)
if self.connectivity_loss.is_set():
return None
else:
return write_task.result()
finally:
- if not write_task.done():
- write_task.cancel()
- if not conn_event_task.done():
- conn_event_task.cancel()
+ if not self.loop.is_closed():
+ if not write_task.done():
+ write_task.cancel()
+ if not conn_event_task.done():
+ conn_event_task.cancel()
async def proxy_to_backend(self):
buf = None
@@ -268,7 +276,8 @@ async def proxy_to_backend(self):
pass
finally:
- self.loop.call_soon(self.close)
+ if not self.loop.is_closed():
+ self.loop.call_soon(self.close)
async def proxy_from_backend(self):
buf = None
@@ -293,4 +302,5 @@ async def proxy_from_backend(self):
pass
finally:
- self.loop.call_soon(self.close)
+ if not self.loop.is_closed():
+ self.loop.call_soon(self.close)
diff --git a/asyncpg/_version.py b/asyncpg/_version.py
new file mode 100644
index 00000000..738da168
--- /dev/null
+++ b/asyncpg/_version.py
@@ -0,0 +1,17 @@
+# This file MUST NOT contain anything but the __version__ assignment.
+#
+# When making a release, change the value of __version__
+# to an appropriate value, and open a pull request against
+# the correct branch (master if making a new feature release).
+# The commit message MUST contain a properly formatted release
+# log, and the commit must be signed.
+#
+# The release automation will: build and test the packages for the
+# supported platforms, publish the packages on PyPI, merge the PR
+# to the target branch, and create a Git tag pointing to the commit.
+
+from __future__ import annotations
+
+import typing
+
+__version__: typing.Final = '0.32.0.dev0'
diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py
index 47699351..606c2eae 100644
--- a/asyncpg/cluster.py
+++ b/asyncpg/cluster.py
@@ -6,7 +6,6 @@
import asyncio
-import errno
import os
import os.path
import platform
@@ -14,6 +13,7 @@
import re
import shutil
import socket
+import string
import subprocess
import sys
import tempfile
@@ -36,29 +36,38 @@ def platform_exe(name):
return name
-def find_available_port(port_range=(49152, 65535), max_tries=1000):
- low, high = port_range
-
- port = low
- try_no = 0
-
- while try_no < max_tries:
- try_no += 1
- port = random.randint(low, high)
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- try:
- sock.bind(('127.0.0.1', port))
- except socket.error as e:
- if e.errno == errno.EADDRINUSE:
- continue
- finally:
- sock.close()
-
- break
+def find_available_port():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ try:
+ sock.bind(('127.0.0.1', 0))
+ return sock.getsockname()[1]
+ except Exception:
+ return None
+ finally:
+ sock.close()
+
+
+def _world_readable_mkdtemp(suffix=None, prefix=None, dir=None):
+ name = "".join(random.choices(string.ascii_lowercase, k=8))
+ if dir is None:
+ dir = tempfile.gettempdir()
+ if prefix is None:
+ prefix = tempfile.gettempprefix()
+ if suffix is None:
+ suffix = ""
+ fn = os.path.join(dir, prefix + name + suffix)
+ os.mkdir(fn, 0o755)
+ return fn
+
+
+def _mkdtemp(suffix=None, prefix=None, dir=None):
+ if _system == 'Windows' and os.environ.get("GITHUB_ACTIONS"):
+ # Due to mitigations introduced in python/cpython#118486,
+ # tempfile.mkdtemp creates directories that are not accessible
+ # when Python runs in a session created via an SSH connection.
+ return _world_readable_mkdtemp(suffix, prefix, dir)
else:
- port = None
-
- return port
+ return tempfile.mkdtemp(suffix, prefix, dir)
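The rewritten find_available_port() above simply binds to port 0 so the operating system hands out a free ephemeral port, while _mkdtemp() works around inaccessible temp directories on Windows GitHub Actions runners. A minimal standalone sketch of the port-0 trick (illustrative only, not part of the patch):

    import socket
    import typing

    def pick_free_port() -> typing.Optional[int]:
        # Binding to port 0 asks the OS for an unused ephemeral port.
        # Note the inherent race: the port may be taken again between
        # closing this socket and the caller actually using the number.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind(('127.0.0.1', 0))
            return sock.getsockname()[1]
        except OSError:
            return None
        finally:
            sock.close()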
class ClusterError(Exception):
@@ -69,7 +78,10 @@ class Cluster:
def __init__(self, data_dir, *, pg_config_path=None):
self._data_dir = data_dir
self._pg_config_path = pg_config_path
- self._pg_bin_dir = os.environ.get('PGINSTALLATION')
+ self._pg_bin_dir = (
+ os.environ.get('PGINSTALLATION')
+ or os.environ.get('PGBIN')
+ )
self._pg_ctl = None
self._daemon_pid = None
self._daemon_process = None
@@ -124,6 +136,10 @@ def init(self, **settings):
'cluster in {!r} has already been initialized'.format(
self._data_dir))
+ settings = dict(settings)
+ if 'encoding' not in settings:
+ settings['encoding'] = 'UTF-8'
+
if settings:
settings_args = ['--{}={}'.format(k, v)
for k, v in settings.items()]
@@ -131,9 +147,13 @@ def init(self, **settings):
else:
extra_args = []
+ os.makedirs(self._data_dir, exist_ok=True)
process = subprocess.run(
[self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
- stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ cwd=self._data_dir,
+ )
output = process.stdout
@@ -164,8 +184,8 @@ def start(self, wait=60, *, server_settings={}, **opts):
sockdir = server_settings.get('unix_socket_directories')
if sockdir is None:
sockdir = server_settings.get('unix_socket_directory')
- if sockdir is None:
- sockdir = '/tmp'
+ if sockdir is None and _system != 'Windows':
+ sockdir = tempfile.gettempdir()
ssl_key = server_settings.get('ssl_key_file')
if ssl_key:
@@ -176,12 +196,13 @@ def start(self, wait=60, *, server_settings={}, **opts):
server_settings = server_settings.copy()
server_settings['ssl_key_file'] = keyfile
- if self._pg_version < (9, 3):
- sockdir_opt = 'unix_socket_directory'
- else:
- sockdir_opt = 'unix_socket_directories'
+ if sockdir is not None:
+ if self._pg_version < (9, 3):
+ sockdir_opt = 'unix_socket_directory'
+ else:
+ sockdir_opt = 'unix_socket_directories'
- server_settings[sockdir_opt] = sockdir
+ server_settings[sockdir_opt] = sockdir
for k, v in server_settings.items():
extra_args.extend(['-c', '{}={}'.format(k, v)])
@@ -193,13 +214,24 @@ def start(self, wait=60, *, server_settings={}, **opts):
# privileges.
if os.getenv('ASYNCPG_DEBUG_SERVER'):
stdout = sys.stdout
+ print(
+ 'asyncpg.cluster: Running',
+ ' '.join([
+ self._pg_ctl, 'start', '-D', self._data_dir,
+ '-o', ' '.join(extra_args)
+ ]),
+ file=sys.stderr,
+ )
else:
stdout = subprocess.DEVNULL
process = subprocess.run(
[self._pg_ctl, 'start', '-D', self._data_dir,
'-o', ' '.join(extra_args)],
- stdout=stdout, stderr=subprocess.STDOUT)
+ stdout=stdout,
+ stderr=subprocess.STDOUT,
+ cwd=self._data_dir,
+ )
if process.returncode != 0:
if process.stderr:
@@ -218,7 +250,10 @@ def start(self, wait=60, *, server_settings={}, **opts):
self._daemon_process = \
subprocess.Popen(
[self._postgres, '-D', self._data_dir, *extra_args],
- stdout=stdout, stderr=subprocess.STDOUT)
+ stdout=stdout,
+ stderr=subprocess.STDOUT,
+ cwd=self._data_dir,
+ )
self._daemon_pid = self._daemon_process.pid
@@ -232,7 +267,10 @@ def reload(self):
process = subprocess.run(
[self._pg_ctl, 'reload', '-D', self._data_dir],
- stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ cwd=self._data_dir,
+ )
stderr = process.stderr
@@ -245,7 +283,10 @@ def stop(self, wait=60):
process = subprocess.run(
[self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
'-m', 'fast'],
- stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ cwd=self._data_dir,
+ )
stderr = process.stderr
@@ -521,7 +562,10 @@ def _run_pg_config(self, pg_config_path):
def _find_pg_config(self, pg_config_path):
if pg_config_path is None:
- pg_install = os.environ.get('PGINSTALLATION')
+ pg_install = (
+ os.environ.get('PGINSTALLATION')
+ or os.environ.get('PGBIN')
+ )
if pg_install:
pg_config_path = platform_exe(
os.path.join(pg_install, 'pg_config'))
@@ -580,9 +624,9 @@ class TempCluster(Cluster):
def __init__(self, *,
data_dir_suffix=None, data_dir_prefix=None,
data_dir_parent=None, pg_config_path=None):
- self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
- prefix=data_dir_prefix,
- dir=data_dir_parent)
+ self._data_dir = _mkdtemp(suffix=data_dir_suffix,
+ prefix=data_dir_prefix,
+ dir=data_dir_parent)
super().__init__(self._data_dir, pg_config_path=pg_config_path)
@@ -623,7 +667,7 @@ def init(self, **settings):
'pg_basebackup init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))
- if self._pg_version <= (11, 0):
+ if self._pg_version < (12, 0):
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
f.write(textwrap.dedent("""\
standby_mode = 'on'
@@ -642,7 +686,7 @@ def start(self, wait=60, *, server_settings={}, **opts):
if self._pg_version >= (12, 0):
server_settings = server_settings.copy()
server_settings['primary_conninfo'] = (
- 'host={host} port={port} user={user}'.format(
+ '"host={host} port={port} user={user}"'.format(
host=self._master['host'],
port=self._master['port'],
user=self._repl_user,
diff --git a/asyncpg/compat.py b/asyncpg/compat.py
index ff4f27b4..57eec650 100644
--- a/asyncpg/compat.py
+++ b/asyncpg/compat.py
@@ -4,60 +4,26 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
-import asyncio
-import functools
-import os
+import enum
import pathlib
import platform
+import typing
import sys
+if typing.TYPE_CHECKING:
+ import asyncio
-PY_36 = sys.version_info >= (3, 6)
-PY_37 = sys.version_info >= (3, 7)
-SYSTEM = platform.uname().system
+SYSTEM: typing.Final = platform.uname().system
-if sys.version_info < (3, 5, 2):
- def aiter_compat(func):
- @functools.wraps(func)
- async def wrapper(self):
- return func(self)
- return wrapper
-else:
- def aiter_compat(func):
- return func
-
-
-if PY_36:
- fspath = os.fspath
-else:
- def fspath(path):
- fsp = getattr(path, '__fspath__', None)
- if fsp is not None and callable(fsp):
- path = fsp()
- if not isinstance(path, (str, bytes)):
- raise TypeError(
- 'expected {}() to return str or bytes, not {}'.format(
- fsp.__qualname__, type(path).__name__
- ))
- return path
- elif isinstance(path, (str, bytes)):
- return path
- else:
- raise TypeError(
- 'expected str, bytes or path-like object, not {}'.format(
- type(path).__name__
- )
- )
-
-
-if SYSTEM == 'Windows':
+if sys.platform == 'win32':
import ctypes.wintypes
- CSIDL_APPDATA = 0x001a
+ CSIDL_APPDATA: typing.Final = 0x001a
- def get_pg_home_directory() -> pathlib.Path:
+ def get_pg_home_directory() -> pathlib.Path | None:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
@@ -69,13 +35,54 @@ def get_pg_home_directory() -> pathlib.Path:
return pathlib.Path(buf.value) / 'postgresql'
else:
- def get_pg_home_directory() -> pathlib.Path:
- return pathlib.Path.home()
+ def get_pg_home_directory() -> pathlib.Path | None:
+ try:
+ return pathlib.Path.home()
+ except (RuntimeError, KeyError):
+ return None
+
+
+async def wait_closed(stream: asyncio.StreamWriter) -> None:
+ # Not all asyncio versions have StreamWriter.wait_closed().
+ if hasattr(stream, 'wait_closed'):
+ try:
+ await stream.wait_closed()
+ except ConnectionResetError:
+ # On Windows, wait_closed() sometimes propagates a
+ # ConnectionResetError that is safe to ignore here.
+ pass
+
+
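A hedged usage sketch for the wait_closed() helper above (illustrative only): it lets callers close a stream uniformly without worrying about older asyncio versions or the spurious ConnectionResetError seen on Windows.

    import asyncio

    async def close_stream(writer: asyncio.StreamWriter) -> None:
        # Assumes the wait_closed() compat helper defined above.
        writer.close()
        await wait_closed(writer)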
+if sys.version_info < (3, 12):
+ def markcoroutinefunction(c): # type: ignore
+ pass
+else:
+ from inspect import markcoroutinefunction # noqa: F401
+
+
+if sys.version_info < (3, 12):
+ from ._asyncio_compat import wait_for as wait_for # noqa: F401
+else:
+ from asyncio import wait_for as wait_for # noqa: F401
+
+if sys.version_info < (3, 11):
+ from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
+else:
+ from asyncio import timeout as timeout # noqa: F401
+
+if sys.version_info < (3, 9):
+ from typing import ( # noqa: F401
+ Awaitable as Awaitable,
+ )
+else:
+ from collections.abc import ( # noqa: F401
+ Awaitable as Awaitable,
+ )
-if PY_37:
- def current_asyncio_task(loop):
- return asyncio.current_task(loop)
+if sys.version_info < (3, 11):
+ class StrEnum(str, enum.Enum):
+ __str__ = str.__str__
+ __repr__ = enum.Enum.__repr__
else:
- def current_asyncio_task(loop):
- return asyncio.Task.current_task(loop)
+ from enum import StrEnum as StrEnum # noqa: F401
diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py
index 26fdec59..07c4fdde 100644
--- a/asyncpg/connect_utils.py
+++ b/asyncpg/connect_utils.py
@@ -4,29 +4,55 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
import asyncio
+import configparser
import collections
+from collections.abc import Callable
+import enum
import functools
import getpass
import os
import pathlib
import platform
+import random
import re
import socket
import ssl as ssl_module
import stat
import struct
-import time
+import sys
import typing
import urllib.parse
import warnings
+import inspect
from . import compat
from . import exceptions
from . import protocol
+class SSLMode(enum.IntEnum):
+ disable = 0
+ allow = 1
+ prefer = 2
+ require = 3
+ verify_ca = 4
+ verify_full = 5
+
+ @classmethod
+ def parse(cls, sslmode):
+ if isinstance(sslmode, cls):
+ return sslmode
+ return getattr(cls, sslmode.replace('-', '_'))
+
+
+class SSLNegotiation(compat.StrEnum):
+ postgres = "postgres"
+ direct = "direct"
+
+
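The SSLMode enum above mirrors libpq's sslmode levels and is an IntEnum so that modes can be compared by strictness. A quick illustration of the intended behavior (illustrative only, not part of the patch):

    # 'verify-full' maps to the verify_full member via parse().
    assert SSLMode.parse('verify-full') is SSLMode.verify_full
    assert SSLMode.parse(SSLMode.require) is SSLMode.require
    # Numeric ordering allows strictness comparisons:
    assert SSLMode.prefer < SSLMode.require < SSLMode.verify_ca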
_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
@@ -34,9 +60,12 @@
'password',
'database',
'ssl',
- 'ssl_is_advisory',
- 'connect_timeout',
+ 'sslmode',
+ 'ssl_negotiation',
'server_settings',
+ 'target_session_attrs',
+ 'krbsrvname',
+ 'gsslib',
])
@@ -59,6 +88,9 @@
PGPASSFILE = '.pgpass'
+PG_SERVICEFILE = '.pg_service.conf'
+
+
def _read_password_file(passfile: pathlib.Path) \
-> typing.List[typing.Tuple[str, ...]]:
@@ -140,13 +172,15 @@ def _read_password_from_pgpass(
def _validate_port_spec(hosts, port):
- if isinstance(port, list):
+ if isinstance(port, list) and len(port) > 1:
# If there is a list of ports, its length must
# match that of the host list.
if len(port) != len(hosts):
- raise exceptions.InterfaceError(
+ raise exceptions.ClientConfigurationError(
'could not match {} port numbers to {} hosts'.format(
len(port), len(hosts)))
+ elif isinstance(port, list) and len(port) == 1:
+ port = [port[0] for _ in range(len(hosts))]
else:
port = [port for _ in range(len(hosts))]
@@ -179,11 +213,25 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
port = _validate_port_spec(hostspecs, port)
for i, hostspec in enumerate(hostspecs):
- if not hostspec.startswith('/'):
- addr, _, hostspec_port = hostspec.partition(':')
- else:
+ if hostspec[0] == '/':
+ # Unix socket
addr = hostspec
hostspec_port = ''
+ elif hostspec[0] == '[':
+ # IPv6 address
+ m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
+ if m:
+ addr = m.group(1)
+ hostspec_port = m.group(2)
+ else:
+ raise exceptions.ClientConfigurationError(
+ 'invalid IPv6 address in the connection URI: {!r}'.format(
+ hostspec
+ )
+ )
+ else:
+ # IPv4 address
+ addr, _, hostspec_port = hostspec.partition(':')
if unquote:
addr = urllib.parse.unquote(addr)
@@ -203,18 +251,71 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
return hosts, port
+def _parse_tls_version(tls_version):
+ if tls_version.startswith('SSL'):
+ raise exceptions.ClientConfigurationError(
+ f"Unsupported TLS version: {tls_version}"
+ )
+ try:
+ return ssl_module.TLSVersion[tls_version.replace('.', '_')]
+ except KeyError:
+ raise exceptions.ClientConfigurationError(
+ f"No such TLS version: {tls_version}"
+ )
+
+
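_parse_tls_version() above translates libpq-style protocol names into ssl.TLSVersion members by swapping the dot for an underscore and rejecting SSLv* outright. For example (illustrative only):

    import ssl

    # 'TLSv1.2' -> 'TLSv1_2' -> ssl.TLSVersion.TLSv1_2
    assert _parse_tls_version('TLSv1.2') is ssl.TLSVersion.TLSv1_2
    # _parse_tls_version('SSLv3') raises ClientConfigurationError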
+def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
+ try:
+ homedir = pathlib.Path.home()
+ except (RuntimeError, KeyError):
+ return None
+
+ return (homedir / '.postgresql' / filename).resolve()
+
+
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
- connect_timeout, server_settings):
+ service, servicefile,
+ direct_tls, server_settings,
+ target_session_attrs, krbsrvname, gsslib):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
+ sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
+ ssl_min_protocol_version = ssl_max_protocol_version = None
+ sslnegotiation = None
if dsn:
parsed = urllib.parse.urlparse(dsn)
+ query = None
+ if parsed.query:
+ query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
+ for key, val in query.items():
+ if isinstance(val, list):
+ query[key] = val[-1]
+
+ if 'service' in query:
+ val = query.pop('service')
+ if not service and val:
+ service = val
+
+ connection_service_file = servicefile
+
+ if connection_service_file is None:
+ connection_service_file = os.getenv('PGSERVICEFILE')
+
+ if connection_service_file is None:
+ homedir = compat.get_pg_home_directory()
+ if homedir:
+ connection_service_file = homedir / PG_SERVICEFILE
+ else:
+ connection_service_file = None
+ else:
+ connection_service_file = pathlib.Path(connection_service_file)
+
if parsed.scheme not in {'postgresql', 'postgres'}:
- raise ValueError(
+ raise exceptions.ClientConfigurationError(
'invalid DSN: scheme is expected to be either '
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))
@@ -247,11 +348,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if password is None and dsn_password:
password = urllib.parse.unquote(dsn_password)
- if parsed.query:
- query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
- for key, val in query.items():
- if isinstance(val, list):
- query[key] = val[-1]
+ if query:
if 'port' in query:
val = query.pop('port')
@@ -293,12 +390,169 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None:
ssl = val
+ if 'sslcert' in query:
+ sslcert = query.pop('sslcert')
+
+ if 'sslkey' in query:
+ sslkey = query.pop('sslkey')
+
+ if 'sslrootcert' in query:
+ sslrootcert = query.pop('sslrootcert')
+
+ if 'sslnegotiation' in query:
+ sslnegotiation = query.pop('sslnegotiation')
+
+ if 'sslcrl' in query:
+ sslcrl = query.pop('sslcrl')
+
+ if 'sslpassword' in query:
+ sslpassword = query.pop('sslpassword')
+
+ if 'ssl_min_protocol_version' in query:
+ ssl_min_protocol_version = query.pop(
+ 'ssl_min_protocol_version'
+ )
+
+ if 'ssl_max_protocol_version' in query:
+ ssl_max_protocol_version = query.pop(
+ 'ssl_max_protocol_version'
+ )
+
+ if 'target_session_attrs' in query:
+ dsn_target_session_attrs = query.pop(
+ 'target_session_attrs'
+ )
+ if target_session_attrs is None:
+ target_session_attrs = dsn_target_session_attrs
+
+ if 'krbsrvname' in query:
+ val = query.pop('krbsrvname')
+ if krbsrvname is None:
+ krbsrvname = val
+
+ if 'gsslib' in query:
+ val = query.pop('gsslib')
+ if gsslib is None:
+ gsslib = val
+
+ if 'service' in query:
+ val = query.pop('service')
+ if service is None:
+ service = val
+
if query:
if server_settings is None:
server_settings = query
else:
server_settings = {**query, **server_settings}
+ if connection_service_file is not None and service is not None:
+ pg_service = configparser.ConfigParser()
+ pg_service.read(connection_service_file)
+ if service in pg_service.sections():
+ service_params = pg_service[service]
+ if 'port' in service_params:
+ val = service_params.pop('port')
+ if not port and val:
+ port = [int(p) for p in val.split(',')]
+
+ if 'host' in service_params:
+ val = service_params.pop('host')
+ if not host and val:
+ host, port = _parse_hostlist(val, port)
+
+ if 'dbname' in service_params:
+ val = service_params.pop('dbname')
+ if database is None:
+ database = val
+
+ if 'database' in service_params:
+ val = service_params.pop('database')
+ if database is None:
+ database = val
+
+ if 'user' in service_params:
+ val = service_params.pop('user')
+ if user is None:
+ user = val
+
+ if 'password' in service_params:
+ val = service_params.pop('password')
+ if password is None:
+ password = val
+
+ if 'passfile' in service_params:
+ val = service_params.pop('passfile')
+ if passfile is None:
+ passfile = val
+
+ if 'sslmode' in service_params:
+ val = service_params.pop('sslmode')
+ if ssl is None:
+ ssl = val
+
+ if 'sslcert' in service_params:
+ val = service_params.pop('sslcert')
+ if sslcert is None:
+ sslcert = val
+
+ if 'sslkey' in service_params:
+ val = service_params.pop('sslkey')
+ if sslkey is None:
+ sslkey = val
+
+ if 'sslrootcert' in service_params:
+ val = service_params.pop('sslrootcert')
+ if sslrootcert is None:
+ sslrootcert = val
+
+ if 'sslnegotiation' in service_params:
+ val = service_params.pop('sslnegotiation')
+ if sslnegotiation is None:
+ sslnegotiation = val
+
+ if 'sslcrl' in service_params:
+ val = service_params.pop('sslcrl')
+ if sslcrl is None:
+ sslcrl = val
+
+ if 'sslpassword' in service_params:
+ val = service_params.pop('sslpassword')
+ if sslpassword is None:
+ sslpassword = val
+
+ if 'ssl_min_protocol_version' in service_params:
+ val = service_params.pop(
+ 'ssl_min_protocol_version'
+ )
+ if ssl_min_protocol_version is None:
+ ssl_min_protocol_version = val
+
+ if 'ssl_max_protocol_version' in service_params:
+ val = service_params.pop(
+ 'ssl_max_protocol_version'
+ )
+ if ssl_max_protocol_version is None:
+ ssl_max_protocol_version = val
+
+ if 'target_session_attrs' in service_params:
+ dsn_target_session_attrs = service_params.pop(
+ 'target_session_attrs'
+ )
+ if target_session_attrs is None:
+ target_session_attrs = dsn_target_session_attrs
+
+ if 'krbsrvname' in service_params:
+ val = service_params.pop('krbsrvname')
+ if krbsrvname is None:
+ krbsrvname = val
+
+ if 'gsslib' in service_params:
+ val = service_params.pop('gsslib')
+ if gsslib is None:
+ gsslib = val
+ if not service:
+ service = os.environ.get('PGSERVICE')
if not host:
hostspec = os.environ.get('PGHOST')
if hostspec:
@@ -313,7 +567,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
host = ['/run/postgresql', '/var/run/postgresql',
'/tmp', '/private/tmp', 'localhost']
- if not isinstance(host, list):
+ if not isinstance(host, (list, tuple)):
host = [host]
if auth_hosts is None:
@@ -352,11 +606,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
database = user
if user is None:
- raise exceptions.InterfaceError(
+ raise exceptions.ClientConfigurationError(
'could not determine user name to connect with')
if database is None:
- raise exceptions.InterfaceError(
+ raise exceptions.ClientConfigurationError(
'could not determine database name to connect to')
if password is None:
@@ -379,6 +633,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
passfile=passfile)
addrs = []
+ have_tcp_addrs = False
for h, p in zip(host, port):
if h.startswith('/'):
# UNIX socket name
@@ -388,84 +643,230 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
else:
# TCP host/port
addrs.append((h, p))
+ have_tcp_addrs = True
if not addrs:
- raise ValueError(
+ raise exceptions.InternalClientError(
'could not determine the database address to connect to')
if ssl is None:
ssl = os.getenv('PGSSLMODE')
- # ssl_is_advisory is only allowed to come from the sslmode parameter.
- ssl_is_advisory = None
- if isinstance(ssl, str):
- SSLMODES = {
- 'disable': 0,
- 'allow': 1,
- 'prefer': 2,
- 'require': 3,
- 'verify-ca': 4,
- 'verify-full': 5,
- }
+ if ssl is None and have_tcp_addrs:
+ ssl = 'prefer'
+
+ if direct_tls is not None:
+ sslneg = (
+ SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
+ )
+ else:
+ if sslnegotiation is None:
+ sslnegotiation = os.environ.get("PGSSLNEGOTIATION")
+
+ if sslnegotiation is not None:
+ try:
+ sslneg = SSLNegotiation(sslnegotiation)
+ except ValueError:
+ modes = ', '.join(
+ m.name.replace('_', '-')
+ for m in SSLNegotiation
+ )
+ raise exceptions.ClientConfigurationError(
+ f'`sslnegotiation` parameter must be one of: {modes}'
+ ) from None
+ else:
+ sslneg = SSLNegotiation.postgres
+
+ if isinstance(ssl, (str, SSLMode)):
try:
- sslmode = SSLMODES[ssl]
- except KeyError:
- modes = ', '.join(SSLMODES.keys())
- raise exceptions.InterfaceError(
- '`sslmode` parameter must be one of: {}'.format(modes))
-
- # sslmode 'allow' is currently handled as 'prefer' because we're
- # missing the "retry with SSL" behavior for 'allow', but do have the
- # "retry without SSL" behavior for 'prefer'.
- # Not changing 'allow' to 'prefer' here would be effectively the same
- # as changing 'allow' to 'disable'.
- if sslmode == SSLMODES['allow']:
- sslmode = SSLMODES['prefer']
+ sslmode = SSLMode.parse(ssl)
+ except AttributeError:
+ modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
+ raise exceptions.ClientConfigurationError(
+ '`sslmode` parameter must be one of: {}'.format(modes)
+ ) from None
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
- # Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
- if sslmode <= SSLMODES['allow']:
+ if sslmode < SSLMode.allow:
ssl = False
- ssl_is_advisory = sslmode >= SSLMODES['allow']
else:
- ssl = ssl_module.create_default_context()
- ssl.check_hostname = sslmode >= SSLMODES['verify-full']
- ssl.verify_mode = ssl_module.CERT_REQUIRED
- if sslmode <= SSLMODES['require']:
+ ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
+ ssl.check_hostname = sslmode >= SSLMode.verify_full
+ if sslmode < SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
- ssl_is_advisory = sslmode <= SSLMODES['prefer']
-
- if ssl:
- for addr in addrs:
- if isinstance(addr, str):
- # UNIX socket
- raise exceptions.InterfaceError(
- '`ssl` parameter can only be enabled for TCP addresses, '
- 'got a UNIX socket path: {!r}'.format(addr))
+ else:
+ if sslrootcert is None:
+ sslrootcert = os.getenv('PGSSLROOTCERT')
+ if sslrootcert:
+ ssl.load_verify_locations(cafile=sslrootcert)
+ ssl.verify_mode = ssl_module.CERT_REQUIRED
+ else:
+ try:
+ sslrootcert = _dot_postgresql_path('root.crt')
+ if sslrootcert is not None:
+ ssl.load_verify_locations(cafile=sslrootcert)
+ else:
+ raise exceptions.ClientConfigurationError(
+ 'cannot determine location of user '
+ 'PostgreSQL configuration directory'
+ )
+ except (
+ exceptions.ClientConfigurationError,
+ FileNotFoundError,
+ NotADirectoryError,
+ ):
+ if sslmode > SSLMode.require:
+ if sslrootcert is None:
+ sslrootcert = '~/.postgresql/root.crt'
+ detail = (
+ 'Could not determine location of user '
+ 'home directory (HOME is either unset, '
+ 'inaccessible, or does not point to a '
+ 'valid directory)'
+ )
+ else:
+ detail = None
+ raise exceptions.ClientConfigurationError(
+ f'root certificate file "{sslrootcert}" does '
+ f'not exist or cannot be accessed',
+ hint='Provide the certificate file directly '
+ f'or make sure "{sslrootcert}" '
+ 'exists and is readable.',
+ detail=detail,
+ )
+ elif sslmode == SSLMode.require:
+ ssl.verify_mode = ssl_module.CERT_NONE
+ else:
+ assert False, 'unreachable'
+ else:
+ ssl.verify_mode = ssl_module.CERT_REQUIRED
+
+ if sslcrl is None:
+ sslcrl = os.getenv('PGSSLCRL')
+ if sslcrl:
+ ssl.load_verify_locations(cafile=sslcrl)
+ ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
+ else:
+ sslcrl = _dot_postgresql_path('root.crl')
+ if sslcrl is not None:
+ try:
+ ssl.load_verify_locations(cafile=sslcrl)
+ except (
+ FileNotFoundError,
+ NotADirectoryError,
+ ):
+ pass
+ else:
+ ssl.verify_flags |= \
+ ssl_module.VERIFY_CRL_CHECK_CHAIN
+
+ if sslkey is None:
+ sslkey = os.getenv('PGSSLKEY')
+ if not sslkey:
+ sslkey = _dot_postgresql_path('postgresql.key')
+ if sslkey is not None and not sslkey.exists():
+ sslkey = None
+ if not sslpassword:
+ sslpassword = ''
+ if sslcert is None:
+ sslcert = os.getenv('PGSSLCERT')
+ if sslcert:
+ ssl.load_cert_chain(
+ sslcert, keyfile=sslkey, password=lambda: sslpassword
+ )
+ else:
+ sslcert = _dot_postgresql_path('postgresql.crt')
+ if sslcert is not None:
+ try:
+ ssl.load_cert_chain(
+ sslcert,
+ keyfile=sslkey,
+ password=lambda: sslpassword
+ )
+ except (FileNotFoundError, NotADirectoryError):
+ pass
+
+ # OpenSSL 1.1.1 keylog file, copied from create_default_context()
+ if hasattr(ssl, 'keylog_filename'):
+ keylogfile = os.environ.get('SSLKEYLOGFILE')
+ if keylogfile and not sys.flags.ignore_environment:
+ ssl.keylog_filename = keylogfile
+
+ if ssl_min_protocol_version is None:
+ ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
+ if ssl_min_protocol_version:
+ ssl.minimum_version = _parse_tls_version(
+ ssl_min_protocol_version
+ )
+ else:
+ ssl.minimum_version = _parse_tls_version('TLSv1.2')
+
+ if ssl_max_protocol_version is None:
+ ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
+ if ssl_max_protocol_version:
+ ssl.maximum_version = _parse_tls_version(
+ ssl_max_protocol_version
+ )
+
+ elif ssl is True:
+ ssl = ssl_module.create_default_context()
+ sslmode = SSLMode.verify_full
+ else:
+ sslmode = SSLMode.disable
if server_settings is not None and (
not isinstance(server_settings, dict) or
not all(isinstance(k, str) for k in server_settings) or
not all(isinstance(v, str) for v in server_settings.values())):
- raise ValueError(
+ raise exceptions.ClientConfigurationError(
'server_settings is expected to be None or '
'a Dict[str, str]')
+ if target_session_attrs is None:
+ target_session_attrs = os.getenv(
+ "PGTARGETSESSIONATTRS", SessionAttribute.any
+ )
+ try:
+ target_session_attrs = SessionAttribute(target_session_attrs)
+ except ValueError:
+ raise exceptions.ClientConfigurationError(
+ "target_session_attrs is expected to be one of "
+ "{!r}"
+ ", got {!r}".format(
+ list(SessionAttribute.__members__.values()), target_session_attrs
+ )
+ ) from None
+
+ if krbsrvname is None:
+ krbsrvname = os.getenv('PGKRBSRVNAME')
+
+ if gsslib is None:
+ gsslib = os.getenv('PGGSSLIB')
+ if gsslib is None:
+ gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
+ if gsslib not in {'gssapi', 'sspi'}:
+ raise exceptions.ClientConfigurationError(
+ "gsslib parameter must be either 'gssapi' or 'sspi'"
+ ", got {!r}".format(gsslib))
+
params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
- ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
- server_settings=server_settings)
+ sslmode=sslmode, ssl_negotiation=sslneg,
+ server_settings=server_settings,
+ target_session_attrs=target_session_attrs,
+ krbsrvname=krbsrvname, gsslib=gsslib)
return addrs, params
def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
- database, timeout, command_timeout,
+ database, command_timeout,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
- ssl, server_settings):
-
+ ssl, direct_tls, server_settings,
+ target_session_attrs, krbsrvname, gsslib,
+ service, servicefile):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
@@ -492,8 +893,11 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
addrs, params = _parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
- database=database, connect_timeout=timeout,
- server_settings=server_settings)
+ direct_tls=direct_tls, database=database,
+ server_settings=server_settings,
+ target_session_attrs=target_session_attrs,
+ krbsrvname=krbsrvname, gsslib=gsslib,
+ service=service, servicefile=servicefile)
config = _ClientConfiguration(
command_timeout=command_timeout,
@@ -504,143 +908,386 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
return addrs, params, config
-async def _connect_addr(*, addr, loop, timeout, params, config,
- connection_class):
+class TLSUpgradeProto(asyncio.Protocol):
+ def __init__(
+ self,
+ loop: asyncio.AbstractEventLoop,
+ host: str,
+ port: int,
+ ssl_context: ssl_module.SSLContext,
+ ssl_is_advisory: bool,
+ ) -> None:
+ self.on_data = _create_future(loop)
+ self.host = host
+ self.port = port
+ self.ssl_context = ssl_context
+ self.ssl_is_advisory = ssl_is_advisory
+
+ def data_received(self, data: bytes) -> None:
+ if data == b'S':
+ self.on_data.set_result(True)
+ elif (self.ssl_is_advisory and
+ self.ssl_context.verify_mode == ssl_module.CERT_NONE and
+ data == b'N'):
+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
+ # since the only way to get ssl_is_advisory is from
+ # sslmode=prefer. But be extra sure to disallow insecure
+ # connections when the ssl context asks for real security.
+ self.on_data.set_result(False)
+ else:
+ self.on_data.set_exception(
+ ConnectionError(
+ 'PostgreSQL server at "{host}:{port}" '
+ 'rejected SSL upgrade'.format(
+ host=self.host, port=self.port)))
+
+ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
+ if not self.on_data.done():
+ if exc is None:
+ exc = ConnectionError('unexpected connection_lost() call')
+ self.on_data.set_exception(exc)
+
+
+_ProtocolFactoryR = typing.TypeVar(
+ "_ProtocolFactoryR", bound=asyncio.protocols.Protocol
+)
+
+
+async def _create_ssl_connection(
+ # TODO: The return type is a specific combination of subclasses of
+ # asyncio.protocols.Protocol that we can't express. For now, having the
+ # return type be dependent on the signature of the factory is an improvement.
+ protocol_factory: Callable[[], _ProtocolFactoryR],
+ host: str,
+ port: int,
+ *,
+ loop: asyncio.AbstractEventLoop,
+ ssl_context: ssl_module.SSLContext,
+ ssl_is_advisory: bool = False,
+) -> typing.Tuple[asyncio.Transport, _ProtocolFactoryR]:
+
+ tr, pr = await loop.create_connection(
+ lambda: TLSUpgradeProto(loop, host, port,
+ ssl_context, ssl_is_advisory),
+ host, port)
+
+ tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
+
+ try:
+ do_ssl_upgrade = await pr.on_data
+ except (Exception, asyncio.CancelledError):
+ tr.close()
+ raise
+
+ if hasattr(loop, 'start_tls'):
+ if do_ssl_upgrade:
+ try:
+ new_tr = await loop.start_tls(
+ tr, pr, ssl_context, server_hostname=host)
+ assert new_tr is not None
+ except (Exception, asyncio.CancelledError):
+ tr.close()
+ raise
+ else:
+ new_tr = tr
+
+ pg_proto = protocol_factory()
+ pg_proto.is_ssl = do_ssl_upgrade
+ pg_proto.connection_made(new_tr)
+ new_tr.set_protocol(pg_proto)
+
+ return new_tr, pg_proto
+ else:
+ conn_factory = functools.partial(
+ loop.create_connection, protocol_factory)
+
+ if do_ssl_upgrade:
+ conn_factory = functools.partial(
+ conn_factory, ssl=ssl_context, server_hostname=host)
+
+ sock = _get_socket(tr)
+ sock = sock.dup()
+ _set_nodelay(sock)
+ tr.close()
+
+ try:
+ new_tr, pg_proto = await conn_factory(sock=sock)
+ pg_proto.is_ssl = do_ssl_upgrade
+ return new_tr, pg_proto
+ except (Exception, asyncio.CancelledError):
+ sock.close()
+ raise
+
+
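_create_ssl_connection() above implements the PostgreSQL SSLRequest handshake: the client sends an 8-byte packet carrying the magic code 80877103, and the server answers with a single b'S' (TLS accepted) or b'N' (TLS refused) before any TLS negotiation begins. A minimal synchronous sketch of the same probe (illustrative only, not part of the patch):

    import socket
    import struct

    def server_accepts_ssl(host: str, port: int) -> bool:
        # Send the SSLRequest startup packet and read the one-byte answer.
        with socket.create_connection((host, port)) as sock:
            sock.sendall(struct.pack('!ll', 8, 80877103))
            reply = sock.recv(1)
        return reply == b'S'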
+async def _connect_addr(
+ *,
+ addr,
+ loop,
+ params,
+ config,
+ connection_class,
+ record_class
+):
assert loop is not None
- if timeout <= 0:
- raise asyncio.TimeoutError
+ params_input = params
+ if callable(params.password):
+ password = params.password()
+ if inspect.isawaitable(password):
+ password = await password
+
+ params = params._replace(password=password)
+ args = (addr, loop, config, connection_class, record_class, params_input)
+
+ # Prepare the parameters for up to two attempts (with/without SSL).
+ if params.sslmode == SSLMode.allow:
+ params_retry = params
+ params = params._replace(ssl=None)
+ elif params.sslmode == SSLMode.prefer:
+ params_retry = params._replace(ssl=None)
+ else:
+ # skip retry if we don't have to
+ return await __connect_addr(params, False, *args)
+
+ # first attempt
+ try:
+ return await __connect_addr(params, True, *args)
+ except _RetryConnectSignal:
+ pass
+
+ # second attempt
+ return await __connect_addr(params_retry, False, *args)
+
+
+class _RetryConnectSignal(Exception):
+ pass
+
+async def __connect_addr(
+ params,
+ retry,
+ addr,
+ loop,
+ config,
+ connection_class,
+ record_class,
+ params_input,
+):
connected = _create_future(loop)
+
proto_factory = lambda: protocol.Protocol(
- addr, connected, params, loop)
+ addr, connected, params, record_class, loop)
if isinstance(addr, str):
# UNIX socket
- assert not params.ssl
connector = loop.create_unix_connection(proto_factory, addr)
+
+ elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
+ # If SSL is enabled and ssl_negotiation is `direct`, skip the
+ # SSLRequest handshake and open a direct SSL connection.
+ connector = loop.create_connection(
+ proto_factory, *addr, ssl=params.ssl
+ )
+
elif params.ssl:
connector = _create_ssl_connection(
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
- ssl_is_advisory=params.ssl_is_advisory)
+ ssl_is_advisory=params.sslmode == SSLMode.prefer)
else:
connector = loop.create_connection(proto_factory, *addr)
- before = time.monotonic()
- tr, pr = await asyncio.wait_for(
- connector, timeout=timeout, loop=loop)
- timeout -= time.monotonic() - before
+ tr, pr = await connector
try:
- if timeout <= 0:
- raise asyncio.TimeoutError
- await asyncio.wait_for(connected, loop=loop, timeout=timeout)
- except Exception:
+ await connected
+ except (
+ exceptions.InvalidAuthorizationSpecificationError,
+ exceptions.ConnectionDoesNotExistError, # seen on Windows
+ ):
+ tr.close()
+
+ # Checking retry here is technically redundant, but it guards
+ # against accidentally raising the internal _RetryConnectSignal
+ # to the user.
+ if retry and (
+ params.sslmode == SSLMode.allow and not pr.is_ssl or
+ params.sslmode == SSLMode.prefer and pr.is_ssl
+ ):
+ # Trigger retry when:
+ # 1. First attempt with sslmode=allow, ssl=None failed
+ # 2. First attempt with sslmode=prefer, ssl=ctx failed while the
+ # server claimed to support SSL (returning "S" for SSLRequest)
+ # (likely because pg_hba.conf rejected the connection)
+ raise _RetryConnectSignal()
+
+ else:
+ # but will NOT retry if:
+ # 1. First attempt with sslmode=prefer failed but the server
+ # doesn't support SSL (returning 'N' for SSLRequest), because
+ # we already tried to connect without SSL through ssl_is_advisory
+ # 2. Second attempt with sslmode=prefer, ssl=None failed
+ # 3. Second attempt with sslmode=allow, ssl=ctx failed
+ # 4. Any other sslmode
+ raise
+
+ except (Exception, asyncio.CancelledError):
tr.close()
raise
- con = connection_class(pr, tr, loop, addr, config, params)
+ con = connection_class(pr, tr, loop, addr, config, params_input)
pr.set_connection(con)
return con
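To summarize the retry logic above: sslmode=prefer attempts TLS first and retries in plaintext, sslmode=allow attempts plaintext first and retries with TLS, and every other mode gets a single attempt. A hypothetical helper expressing the same ordering (not part of the patch):

    def _attempt_plan(sslmode, ssl_ctx):
        # Hypothetical illustration of the attempt ordering implemented
        # by _connect_addr()/__connect_addr() above.
        if sslmode == SSLMode.prefer:
            return [ssl_ctx, None]   # TLS first, then retry without TLS
        if sslmode == SSLMode.allow:
            return [None, ssl_ctx]   # plaintext first, then retry with TLS
        return [ssl_ctx if sslmode >= SSLMode.require else None]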
-async def _connect(*, loop, timeout, connection_class, **kwargs):
- if loop is None:
- loop = asyncio.get_event_loop()
+class SessionAttribute(str, enum.Enum):
+ any = 'any'
+ primary = 'primary'
+ standby = 'standby'
+ prefer_standby = 'prefer-standby'
+ read_write = "read-write"
+ read_only = "read-only"
- addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
- last_error = None
- addr = None
- for addr in addrs:
- before = time.monotonic()
- try:
- con = await _connect_addr(
- addr=addr, loop=loop, timeout=timeout,
- params=params, config=config,
- connection_class=connection_class)
- except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
- last_error = ex
+def _accept_in_hot_standby(should_be_in_hot_standby: bool):
+ """
+ If the server didn't report "in_hot_standby" at startup, we must determine
+ the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
+ If the server allows a connection and states that it is in recovery,
+ it must be a replica/standby server.
+ """
+ async def can_be_used(connection):
+ settings = connection.get_settings()
+ hot_standby_status = getattr(settings, 'in_hot_standby', None)
+ if hot_standby_status is not None:
+ is_in_hot_standby = hot_standby_status == 'on'
else:
- return con
- finally:
- timeout -= time.monotonic() - before
+ is_in_hot_standby = await connection.fetchval(
+ "SELECT pg_catalog.pg_is_in_recovery()"
+ )
+ return is_in_hot_standby == should_be_in_hot_standby
- raise last_error
+ return can_be_used
-async def _negotiate_ssl_connection(host, port, conn_factory, *, loop, ssl,
- server_hostname, ssl_is_advisory=False):
- # Note: ssl_is_advisory only affects behavior when the server does not
- # accept SSLRequests. If the SSLRequest is accepted but either the SSL
- # negotiation fails or the PostgreSQL user isn't permitted to use SSL,
- # there's nothing that would attempt to reconnect with a non-SSL socket.
- reader, writer = await asyncio.open_connection(host, port, loop=loop)
+def _accept_read_only(should_be_read_only: bool):
+ """
+ Verify that default_transaction_read_only matches the requested attribute.
+ """
+ async def can_be_used(connection):
+ settings = connection.get_settings()
+ is_readonly = getattr(settings, 'default_transaction_read_only', 'off')
- tr = writer.transport
- try:
- sock = _get_socket(tr)
- _set_nodelay(sock)
+ if is_readonly == "on":
+ return should_be_read_only
- writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
- await writer.drain()
- resp = await reader.readexactly(1)
+ return await _accept_in_hot_standby(should_be_read_only)(connection)
+ return can_be_used
- if resp == b'S':
- conn_factory = functools.partial(
- conn_factory, ssl=ssl, server_hostname=server_hostname)
- elif (ssl_is_advisory and
- ssl.verify_mode == ssl_module.CERT_NONE and
- resp == b'N'):
- # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
- # since the only way to get ssl_is_advisory is from sslmode=prefer
- # (or sslmode=allow). But be extra sure to disallow insecure
- # connections when the ssl context asks for real security.
- pass
- else:
- raise ConnectionError(
- 'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
- host, port))
- sock = sock.dup() # Must come before tr.close()
- finally:
- tr.close()
+async def _accept_any(_):
+ return True
+
+
+target_attrs_check = {
+ SessionAttribute.any: _accept_any,
+ SessionAttribute.primary: _accept_in_hot_standby(False),
+ SessionAttribute.standby: _accept_in_hot_standby(True),
+ SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
+ SessionAttribute.read_write: _accept_read_only(False),
+ SessionAttribute.read_only: _accept_read_only(True),
+}
+
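With the target_session_attrs machinery above, a multi-host DSN can be steered toward a primary or standby server; candidate connections that fail the check are closed in the background. A hedged usage sketch (illustrative only, hostnames are made up):

    import asyncpg

    async def connect_to_primary():
        # Hosts are tried in order; the first connection satisfying the
        # target attribute check is returned, the rest are discarded.
        return await asyncpg.connect(
            'postgresql://db1.example.com:5432,db2.example.com:5432/app'
            '?target_session_attrs=primary',
            user='app',
        )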
+async def _can_use_connection(connection, attr: SessionAttribute):
+ can_use = target_attrs_check[attr]
+ return await can_use(connection)
+
+
+async def _connect(*, loop, connection_class, record_class, **kwargs):
+ if loop is None:
+ loop = asyncio.get_event_loop()
+
+ addrs, params, config = _parse_connect_arguments(**kwargs)
+ target_attr = params.target_session_attrs
+
+ candidates = []
+ chosen_connection = None
+ last_error = None
try:
- return await conn_factory(sock=sock) # Must come after tr.close()
- except Exception:
- sock.close()
- raise
+ for addr in addrs:
+ try:
+ conn = await _connect_addr(
+ addr=addr,
+ loop=loop,
+ params=params,
+ config=config,
+ connection_class=connection_class,
+ record_class=record_class,
+ )
+ candidates.append(conn)
+ if await _can_use_connection(conn, target_attr):
+ chosen_connection = conn
+ break
+ except OSError as ex:
+ last_error = ex
+ else:
+ if target_attr == SessionAttribute.prefer_standby and candidates:
+ chosen_connection = random.choice(candidates)
+ finally:
+
+ async def _close_candidates(conns, chosen):
+ await asyncio.gather(
+ *(c.close() for c in conns if c is not chosen),
+ return_exceptions=True
+ )
+ if candidates:
+ asyncio.create_task(
+ _close_candidates(candidates, chosen_connection))
+ if chosen_connection:
+ return chosen_connection
-async def _create_ssl_connection(protocol_factory, host, port, *,
- loop, ssl_context, ssl_is_advisory=False):
- return await _negotiate_ssl_connection(
- host, port,
- functools.partial(loop.create_connection, protocol_factory),
- loop=loop,
- ssl=ssl_context,
- server_hostname=host,
- ssl_is_advisory=ssl_is_advisory)
+ raise last_error or exceptions.TargetServerAttributeNotMatched(
+ 'None of the hosts match the target attribute requirement '
+ '{!r}'.format(target_attr)
+ )
-async def _open_connection(*, loop, addr, params: _ConnectionParameters):
+async def _cancel(*, loop, addr, params: _ConnectionParameters,
+ backend_pid, backend_secret):
+
+ class CancelProto(asyncio.Protocol):
+
+ def __init__(self):
+ self.on_disconnect = _create_future(loop)
+ self.is_ssl = False
+
+ def connection_lost(self, exc):
+ if not self.on_disconnect.done():
+ self.on_disconnect.set_result(True)
+
if isinstance(addr, str):
- r, w = await asyncio.open_unix_connection(addr, loop=loop)
+ tr, pr = await loop.create_unix_connection(CancelProto, addr)
else:
- if params.ssl:
- r, w = await _negotiate_ssl_connection(
+ if params.ssl and params.sslmode != SSLMode.allow:
+ tr, pr = await _create_ssl_connection(
+ CancelProto,
*addr,
- functools.partial(asyncio.open_connection, loop=loop),
loop=loop,
- ssl=params.ssl,
- server_hostname=addr[0],
- ssl_is_advisory=params.ssl_is_advisory)
+ ssl_context=params.ssl,
+ ssl_is_advisory=params.sslmode == SSLMode.prefer)
else:
- r, w = await asyncio.open_connection(*addr, loop=loop)
- _set_nodelay(_get_socket(w.transport))
+ tr, pr = await loop.create_connection(
+ CancelProto, *addr)
+ _set_nodelay(_get_socket(tr))
- return r, w
+ # Pack a CancelRequest message
+ msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)
+
+ try:
+ tr.write(msg)
+ await pr.on_disconnect
+ finally:
+ tr.close()
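The CancelRequest packet written above follows the frontend/backend protocol: an int32 length of 16, the int32 cancel code 80877102, and the int32 PID and secret key obtained from BackendKeyData at connect time. For reference (illustrative sketch, not part of the patch):

    import struct

    def build_cancel_request(backend_pid: int, backend_secret: int) -> bytes:
        # int32 length (16) | int32 code 80877102 | int32 pid | int32 secret
        return struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)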
def _get_socket(transport):
diff --git a/asyncpg/connection.py b/asyncpg/connection.py
index 8e841871..71fb04f8 100644
--- a/asyncpg/connection.py
+++ b/asyncpg/connection.py
@@ -9,12 +9,17 @@
import asyncpg
import collections
import collections.abc
+import contextlib
+import functools
import itertools
-import struct
+import inspect
+import os
import sys
import time
import traceback
+import typing
import warnings
+import weakref
from . import compat
from . import connect_utils
@@ -44,14 +49,15 @@ class Connection(metaclass=ConnectionMeta):
__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
+ '_stmt_cache_enabled',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
- '_log_listeners', '_cancellations', '_source_traceback',
- '__weakref__')
+ '_log_listeners', '_termination_listeners', '_cancellations',
+ '_source_traceback', '_query_loggers', '__weakref__')
def __init__(self, protocol, transport, loop,
- addr: (str, int) or str,
+ addr,
config: connect_utils._ClientConfiguration,
params: connect_utils._ConnectionParameters):
self._protocol = protocol
@@ -71,14 +77,18 @@ def __init__(self, protocol, transport, loop,
self._stmt_cache = _StatementCache(
loop=loop,
max_size=config.statement_cache_size,
- on_remove=self._maybe_gc_stmt,
+ on_remove=functools.partial(
+ _weak_maybe_gc_stmt, weakref.ref(self)),
max_lifetime=config.max_cached_statement_lifetime)
self._stmts_to_close = set()
+ self._stmt_cache_enabled = config.statement_cache_size > 0
self._listeners = {}
self._log_listeners = set()
self._cancellations = set()
+ self._termination_listeners = set()
+ self._query_loggers = set()
settings = self._protocol.get_settings()
ver_string = settings.server_version
@@ -88,7 +98,10 @@ def __init__(self, protocol, transport, loop,
self._server_caps = _detect_server_capabilities(
self._server_version, settings)
- self._intro_query = introspection.INTRO_LOOKUP_TYPES
+ if self._server_version < (14, 0):
+ self._intro_query = introspection.INTRO_LOOKUP_TYPES_13
+ else:
+ self._intro_query = introspection.INTRO_LOOKUP_TYPES
self._reset_query = None
self._proxy = None
@@ -129,17 +142,21 @@ async def add_listener(self, channel, callback):
:param str channel: Channel to listen on.
:param callable callback:
- A callable receiving the following arguments:
+ A callable or a coroutine function receiving the following
+ arguments:
**connection**: a Connection the callback is registered with;
**pid**: PID of the Postgres server that sent the notification;
**channel**: name of the channel the notification was sent to;
**payload**: the payload.
+
+ .. versionchanged:: 0.24.0
+ The ``callback`` argument may be a coroutine function.
"""
self._check_open()
if channel not in self._listeners:
await self.fetch('LISTEN {}'.format(utils._quote_ident(channel)))
self._listeners[channel] = set()
- self._listeners[channel].add(callback)
+ self._listeners[channel].add(_Callback.from_callable(callback))
async def remove_listener(self, channel, callback):
"""Remove a listening callback on the specified channel."""
@@ -147,9 +164,10 @@ async def remove_listener(self, channel, callback):
return
if channel not in self._listeners:
return
- if callback not in self._listeners[channel]:
+ cb = _Callback.from_callable(callback)
+ if cb not in self._listeners[channel]:
return
- self._listeners[channel].remove(callback)
+ self._listeners[channel].remove(cb)
if not self._listeners[channel]:
del self._listeners[channel]
await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel)))
@@ -162,22 +180,74 @@ def add_log_listener(self, callback):
DEBUG, INFO, or LOG.
:param callable callback:
- A callable receiving the following arguments:
+ A callable or a coroutine function receiving the following
+ arguments:
**connection**: a Connection the callback is registered with;
**message**: the `exceptions.PostgresLogMessage` message.
.. versionadded:: 0.12.0
+
+ .. versionchanged:: 0.24.0
+ The ``callback`` argument may be a coroutine function.
"""
if self.is_closed():
raise exceptions.InterfaceError('connection is closed')
- self._log_listeners.add(callback)
+ self._log_listeners.add(_Callback.from_callable(callback))
def remove_log_listener(self, callback):
"""Remove a listening callback for log messages.
.. versionadded:: 0.12.0
"""
- self._log_listeners.discard(callback)
+ self._log_listeners.discard(_Callback.from_callable(callback))
+
+ def add_termination_listener(self, callback):
+ """Add a listener that will be called when the connection is closed.
+
+ :param callable callback:
+ A callable or a coroutine function receiving one argument:
+ **connection**: a Connection the callback is registered with.
+
+ .. versionadded:: 0.21.0
+
+ .. versionchanged:: 0.24.0
+ The ``callback`` argument may be a coroutine function.
+ """
+ self._termination_listeners.add(_Callback.from_callable(callback))
+
+ def remove_termination_listener(self, callback):
+ """Remove a listening callback for connection termination.
+
+ :param callable callback:
+ The callable or coroutine function that was passed to
+ :meth:`Connection.add_termination_listener`.
+
+ .. versionadded:: 0.21.0
+ """
+ self._termination_listeners.discard(_Callback.from_callable(callback))
+
+ def add_query_logger(self, callback):
+ """Add a logger that will be called when queries are executed.
+
+ :param callable callback:
+ A callable or a coroutine function receiving one argument:
+ **record**, a LoggedQuery containing `query`, `args`, `timeout`,
+ `elapsed`, `exception`, `conn_addr`, and `conn_params`.
+
+ .. versionadded:: 0.29.0
+ """
+ self._query_loggers.add(_Callback.from_callable(callback))
+
+ def remove_query_logger(self, callback):
+ """Remove a query logger callback.
+
+ :param callable callback:
+ The callable or coroutine function that was passed to
+ :meth:`Connection.add_query_logger`.
+
+ .. versionadded:: 0.29.0
+ """
+ self._query_loggers.discard(_Callback.from_callable(callback))
def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
@@ -206,7 +276,7 @@ def get_settings(self):
"""
return self._protocol.get_settings()
- def transaction(self, *, isolation='read_committed', readonly=False,
+ def transaction(self, *, isolation=None, readonly=False,
deferrable=False):
"""Create a :class:`~transaction.Transaction` object.
@@ -215,7 +285,9 @@ def transaction(self, *, isolation='read_committed', readonly=False,
:param isolation: Transaction isolation mode, can be one of:
`'serializable'`, `'repeatable_read'`,
- `'read_committed'`.
+ `'read_uncommitted'`, `'read_committed'`. If not
+ specified, the behavior is up to the server and
+ session, which is usually ``read_committed``.
:param readonly: Specifies whether or not this transaction is
read-only.
@@ -239,7 +311,12 @@ def is_in_transaction(self):
"""
return self._protocol.is_in_transaction()
- async def execute(self, query: str, *args, timeout: float=None) -> str:
+ async def execute(
+ self,
+ query: str,
+ *args,
+ timeout: typing.Optional[float]=None,
+ ) -> str:
"""Execute an SQL command (or commands).
This method can execute many SQL commands at once, when no arguments
@@ -270,12 +347,29 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
self._check_open()
if not args:
- return await self._protocol.query(query, timeout)
-
- _, status, _ = await self._execute(query, args, 0, timeout, True)
+ if self._query_loggers:
+ with self._time_and_log(query, args, timeout):
+ result = await self._protocol.query(query, timeout)
+ else:
+ result = await self._protocol.query(query, timeout)
+ return result
+
+ _, status, _ = await self._execute(
+ query,
+ args,
+ 0,
+ timeout,
+ return_status=True,
+ )
return status.decode()
- async def executemany(self, command: str, args, *, timeout: float=None):
+ async def executemany(
+ self,
+ command: str,
+ args,
+ *,
+ timeout: typing.Optional[float]=None,
+ ):
"""Execute an SQL *command* for each sequence of arguments in *args*.
Example:
@@ -291,42 +385,68 @@ async def executemany(self, command: str, args, *, timeout: float=None):
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
- .. note::
-
- When inserting a large number of rows,
- use :meth:`Connection.copy_records_to_table()` instead,
- it is much more efficient for this purpose.
-
.. versionadded:: 0.7.0
.. versionchanged:: 0.11.0
`timeout` became a keyword-only parameter.
+
+ .. versionchanged:: 0.22.0
+ ``executemany()`` is now an atomic operation, which means that
+ either all executions succeed, or none at all. This is in contrast
+ to prior versions, where the effect of already-processed iterations
+ would remain in place when an error has occurred, unless
+ ``executemany()`` was called in a transaction.
"""
self._check_open()
return await self._executemany(command, args, timeout)
- async def _get_statement(self, query, timeout, *, named: bool=False,
- use_cache: bool=True):
+ async def _get_statement(
+ self,
+ query,
+ timeout,
+ *,
+ named: typing.Union[str, bool, None] = False,
+ use_cache=True,
+ ignore_custom_codec=False,
+ record_class=None
+ ):
+ if record_class is None:
+ record_class = self._protocol.get_record_class()
+ else:
+ _check_record_class(record_class)
+
if use_cache:
- statement = self._stmt_cache.get(query)
+ statement = self._stmt_cache.get(
+ (query, record_class, ignore_custom_codec)
+ )
if statement is not None:
return statement
# Only use the cache when:
# * `statement_cache_size` is greater than 0;
# * query size is less than `max_cacheable_statement_size`.
- use_cache = self._stmt_cache.get_max_size() > 0
- if (use_cache and
- self._config.max_cacheable_statement_size and
- len(query) > self._config.max_cacheable_statement_size):
- use_cache = False
+ use_cache = (
+ self._stmt_cache_enabled
+ and (
+ not self._config.max_cacheable_statement_size
+ or len(query) <= self._config.max_cacheable_statement_size
+ )
+ )
- if use_cache or named:
+ if isinstance(named, str):
+ stmt_name = named
+ elif use_cache or named:
stmt_name = self._get_unique_id('stmt')
else:
stmt_name = ''
- statement = await self._protocol.prepare(stmt_name, query, timeout)
+ statement = await self._protocol.prepare(
+ stmt_name,
+ query,
+ timeout,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
need_reprepare = False
types_with_missing_codecs = statement._init_types()
tries = 0
@@ -360,12 +480,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False,
# for the statement.
statement._init_codecs()
- if need_reprepare:
- await self._protocol.prepare(
- stmt_name, query, timeout, state=statement)
+ if (
+ need_reprepare
+ or (not statement.name and not self._stmt_cache_enabled)
+ ):
+ # Mark this anonymous prepared statement as "unprepared",
+ # causing it to get re-Parsed in the next bind_execute.
+ # We always do this when stmt_cache_size is set to 0, assuming
+ # that PgBouncer is in use and is mishandling implicit
+ # transactions.
+ statement.mark_unprepared()
if use_cache:
- self._stmt_cache.put(query, statement)
+ self._stmt_cache.put(
+ (query, record_class, ignore_custom_codec), statement)
# If we've just created a new statement object, check if there
# are any statements for GC.
@@ -375,50 +503,201 @@ async def _get_statement(self, query, timeout, *, named: bool=False,
return statement
async def _introspect_types(self, typeoids, timeout):
- return await self.__execute(
- self._intro_query, (list(typeoids),), 0, timeout)
+ if self._server_caps.jit:
+ try:
+ cfgrow, _ = await self.__execute(
+ """
+ SELECT
+ current_setting('jit') AS cur,
+ set_config('jit', 'off', false) AS new
+ """,
+ (),
+ 0,
+ timeout,
+ ignore_custom_codec=True,
+ )
+ jit_state = cfgrow[0]['cur']
+ except exceptions.UndefinedObjectError:
+ jit_state = 'off'
+ else:
+ jit_state = 'off'
+
+ result = await self.__execute(
+ self._intro_query,
+ (list(typeoids),),
+ 0,
+ timeout,
+ ignore_custom_codec=True,
+ )
+
+ if jit_state != 'off':
+ await self.__execute(
+ """
+ SELECT
+ set_config('jit', $1, false)
+ """,
+ (jit_state,),
+ 0,
+ timeout,
+ ignore_custom_codec=True,
+ )
+
+ return result
+
+ async def _introspect_type(self, typename, schema):
+ if schema == 'pg_catalog' and not typename.endswith("[]"):
+ typeoid = protocol.BUILTIN_TYPE_NAME_MAP.get(typename.lower())
+ if typeoid is not None:
+ return introspection.TypeRecord((typeoid, None, b"b"))
+
+ rows = await self._execute(
+ introspection.TYPE_BY_NAME,
+ [typename, schema],
+ limit=1,
+ timeout=None,
+ ignore_custom_codec=True,
+ )
+
+ if not rows:
+ raise ValueError(
+ 'unknown type: {}.{}'.format(schema, typename))
- def cursor(self, query, *args, prefetch=None, timeout=None):
+ return rows[0]
+
+ def cursor(
+ self,
+ query,
+ *args,
+ prefetch=None,
+ timeout=None,
+ record_class=None
+ ):
"""Return a *cursor factory* for the specified query.
- :param args: Query arguments.
- :param int prefetch: The number of rows the *cursor iterator*
- will prefetch (defaults to ``50``.)
- :param float timeout: Optional timeout in seconds.
+ :param args:
+ Query arguments.
+ :param int prefetch:
+ The number of rows the *cursor iterator*
+ will prefetch (defaults to ``50``).
+ :param float timeout:
+ Optional timeout in seconds.
+ :param type record_class:
+ If specified, the class to use for records returned by this cursor.
+ Must be a subclass of :class:`~asyncpg.Record`. If not specified,
+ a per-connection *record_class* is used.
+
+ :return:
+ A :class:`~cursor.CursorFactory` object.
- :return: A :class:`~cursor.CursorFactory` object.
+ .. versionchanged:: 0.22.0
+ Added the *record_class* parameter.
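+
+ Example (a minimal sketch; cursors must be used inside a transaction,
+ and the query shown is illustrative):
+
+ .. code-block:: pycon
+
+ >>> async with con.transaction():
+ ... async for row in con.cursor('SELECT generate_series(0, 99)'):
+ ... print(row)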
"""
self._check_open()
- return cursor.CursorFactory(self, query, None, args,
- prefetch, timeout)
+ return cursor.CursorFactory(
+ self,
+ query,
+ None,
+ args,
+ prefetch,
+ timeout,
+ record_class,
+ )
- async def prepare(self, query, *, timeout=None):
+ async def prepare(
+ self,
+ query,
+ *,
+ name=None,
+ timeout=None,
+ record_class=None,
+ ):
"""Create a *prepared statement* for the specified query.
- :param str query: Text of the query to create a prepared statement for.
- :param float timeout: Optional timeout value in seconds.
+ :param str query:
+ Text of the query to create a prepared statement for.
+ :param str name:
+ Optional name of the returned prepared statement. If not
+ specified, the name is auto-generated.
+ :param float timeout:
+ Optional timeout value in seconds.
+ :param type record_class:
+ If specified, the class to use for records returned by the
+ prepared statement. Must be a subclass of
+ :class:`~asyncpg.Record`. If not specified, a per-connection
+ *record_class* is used.
+
+ :return:
+ A :class:`~prepared_stmt.PreparedStatement` instance.
- :return: A :class:`~prepared_stmt.PreparedStatement` instance.
+ .. versionchanged:: 0.22.0
+ Added the *record_class* parameter.
+
+ .. versionchanged:: 0.25.0
+ Added the *name* parameter.
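+
+ Example (a minimal sketch):
+
+ .. code-block:: pycon
+
+ >>> stmt = await con.prepare('SELECT $1::int + $2::int')
+ >>> await stmt.fetchval(10, 20)
+ 30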
"""
- return await self._prepare(query, timeout=timeout, use_cache=False)
+ return await self._prepare(
+ query,
+ name=name,
+ timeout=timeout,
+ record_class=record_class,
+ )
- async def _prepare(self, query, *, timeout=None, use_cache: bool=False):
+ async def _prepare(
+ self,
+ query,
+ *,
+ name: typing.Union[str, bool, None] = None,
+ timeout=None,
+ use_cache: bool=False,
+ record_class=None
+ ):
self._check_open()
- stmt = await self._get_statement(query, timeout, named=True,
- use_cache=use_cache)
+ if name is None:
+ name = self._stmt_cache_enabled
+ stmt = await self._get_statement(
+ query,
+ timeout,
+ named=name,
+ use_cache=use_cache,
+ record_class=record_class,
+ )
return prepared_stmt.PreparedStatement(self, query, stmt)
- async def fetch(self, query, *args, timeout=None) -> list:
+ async def fetch(
+ self,
+ query,
+ *args,
+ timeout=None,
+ record_class=None
+ ) -> list:
"""Run a query and return the results as a list of :class:`Record`.
- :param str query: Query text.
- :param args: Query arguments.
- :param float timeout: Optional timeout value in seconds.
+ :param str query:
+ Query text.
+ :param args:
+ Query arguments.
+ :param float timeout:
+ Optional timeout value in seconds.
+ :param type record_class:
+ If specified, the class to use for records returned by this method.
+ Must be a subclass of :class:`~asyncpg.Record`. If not specified,
+ a per-connection *record_class* is used.
- :return list: A list of :class:`Record` instances.
+ :return list:
+ A list of :class:`~asyncpg.Record` instances. If specified, the
+ actual type of list elements will be *record_class*.
+
+ .. versionchanged:: 0.22.0
+ Added the *record_class* parameter.
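+
+ Example (a minimal sketch):
+
+ .. code-block:: pycon
+
+ >>> rows = await con.fetch('SELECT generate_series(0, 2) AS n')
+ >>> [r['n'] for r in rows]
+ [0, 1, 2]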
"""
self._check_open()
- return await self._execute(query, args, 0, timeout)
+ return await self._execute(
+ query,
+ args,
+ 0,
+ timeout,
+ record_class=record_class,
+ )
async def fetchval(self, query, *args, column=0, timeout=None):
"""Run a query and return a value in the first row.
@@ -441,22 +720,89 @@ async def fetchval(self, query, *args, column=0, timeout=None):
return None
return data[0][column]
- async def fetchrow(self, query, *args, timeout=None):
+ async def fetchrow(
+ self,
+ query,
+ *args,
+ timeout=None,
+ record_class=None
+ ):
"""Run a query and return the first row.
- :param str query: Query text
- :param args: Query arguments
- :param float timeout: Optional timeout value in seconds.
-
- :return: The first row as a :class:`Record` instance, or None if
- no records were returned by the query.
+ :param str query:
+ Query text
+ :param args:
+ Query arguments
+ :param float timeout:
+ Optional timeout value in seconds.
+ :param type record_class:
+ If specified, the class to use for the value returned by this
+ method. Must be a subclass of :class:`~asyncpg.Record`.
+ If not specified, a per-connection *record_class* is used.
+
+ :return:
+ The first row as a :class:`~asyncpg.Record` instance, or None if
+ no records were returned by the query. If specified,
+ *record_class* is used as the type for the result value.
+
+ .. versionchanged:: 0.22.0
+ Added the *record_class* parameter.
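+
+ Example (a minimal sketch):
+
+ .. code-block:: pycon
+
+ >>> row = await con.fetchrow('SELECT 1 AS x, 2 AS y')
+ >>> (row['x'], row['y'])
+ (1, 2)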
"""
self._check_open()
- data = await self._execute(query, args, 1, timeout)
+ data = await self._execute(
+ query,
+ args,
+ 1,
+ timeout,
+ record_class=record_class,
+ )
if not data:
return None
return data[0]
+ async def fetchmany(
+ self,
+ query,
+ args,
+ *,
+ timeout: typing.Optional[float]=None,
+ record_class=None,
+ ):
+ """Run a query for each sequence of arguments in *args*
+ and return the results as a list of :class:`Record`.
+
+ :param query:
+ Query to execute.
+ :param args:
+ An iterable containing sequences of arguments for the query.
+ :param float timeout:
+ Optional timeout value in seconds.
+ :param type record_class:
+ If specified, the class to use for records returned by this method.
+ Must be a subclass of :class:`~asyncpg.Record`. If not specified,
+ a per-connection *record_class* is used.
+
+ :return list:
+ A list of :class:`~asyncpg.Record` instances. If specified, the
+ actual type of list elements will be *record_class*.
+
+ Example:
+
+ .. code-block:: pycon
+
+ >>> rows = await con.fetchmany('''
+ ... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a;
+ ... ''', [('x', 1), ('y', 2), ('z', 3)])
+ >>> rows
+ [<Record a='x'>, <Record a='y'>, <Record a='z'>]
+
+ .. versionadded:: 0.30.0
+ """
+ self._check_open()
+ return await self._executemany(
+ query, args, timeout, return_rows=True, record_class=record_class
+ )
+
async def copy_from_table(self, table_name, *, output,
columns=None, schema_name=None, timeout=None,
format=None, oids=None, delimiter=None,
@@ -500,7 +846,7 @@ async def copy_from_table(self, table_name, *, output,
... output='file.csv', format='csv')
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 100'
.. _`COPY statement documentation`:
@@ -569,7 +915,7 @@ async def copy_from_query(self, query, *args, output,
... output='file.csv', format='csv')
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 10'
.. _`COPY statement documentation`:
@@ -597,7 +943,7 @@ async def copy_to_table(self, table_name, *, source,
delimiter=None, null=None, header=None,
quote=None, escape=None, force_quote=None,
force_not_null=None, force_null=None,
- encoding=None):
+ encoding=None, where=None):
"""Copy data to the specified table.
:param str table_name:
@@ -616,6 +962,15 @@ async def copy_to_table(self, table_name, *, source,
:param str schema_name:
An optional schema name to qualify the table.
+ :param str where:
+ An optional SQL expression used to filter rows when copying.
+
+ .. note::
+
+ Usage of this parameter requires support for the
+ ``COPY FROM ... WHERE`` syntax, introduced in
+ PostgreSQL version 12.
+
:param float timeout:
Optional timeout value in seconds.
@@ -636,13 +991,16 @@ async def copy_to_table(self, table_name, *, source,
... 'mytable', source='datafile.tbl')
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 140000'
.. _`COPY statement documentation`:
https://www.postgresql.org/docs/current/static/sql-copy.html
.. versionadded:: 0.11.0
+
+ .. versionadded:: 0.29.0
+ Added the *where* parameter.
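+
+ Example of a filtered copy (a sketch; assumes PostgreSQL 12 or later
+ and an integer column ``a`` in ``mytable``):
+
+ .. code-block:: pycon
+
+ >>> await con.copy_to_table(
+ ... 'mytable', source='datafile.tbl', where='a > 10')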
"""
tabname = utils._quote_ident(table_name)
if schema_name:
@@ -654,6 +1012,7 @@ async def copy_to_table(self, table_name, *, source,
else:
cols = ''
+ cond = self._format_copy_where(where)
opts = self._format_copy_opts(
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
null=null, header=header, quote=quote, escape=escape,
@@ -661,14 +1020,14 @@ async def copy_to_table(self, table_name, *, source,
encoding=encoding
)
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
- tab=tabname, cols=cols, opts=opts)
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
+ tab=tabname, cols=cols, opts=opts, cond=cond)
return await self._copy_in(copy_stmt, source, timeout)
async def copy_records_to_table(self, table_name, *, records,
columns=None, schema_name=None,
- timeout=None):
+ timeout=None, where=None):
"""Copy a list of records to the specified table using binary COPY.
:param str table_name:
@@ -676,6 +1035,8 @@ async def copy_records_to_table(self, table_name, *, records,
:param records:
An iterable returning row tuples to copy into the table.
+ Asynchronous iterables are also supported.
:param list columns:
An optional list of column names to copy.
@@ -683,6 +1044,16 @@ async def copy_records_to_table(self, table_name, *, records,
:param str schema_name:
An optional schema name to qualify the table.
+ :param str where:
+ An optional SQL expression used to filter rows when copying.
+
+ .. note::
+
+ Usage of this parameter requires support for the
+ ``COPY FROM ... WHERE`` syntax, introduced in
+ PostgreSQL version 12.
+
+
:param float timeout:
Optional timeout value in seconds.
@@ -702,10 +1073,34 @@ async def copy_records_to_table(self, table_name, *, records,
... (2, 'ham', 'spam')])
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 2'
+ Asynchronous record iterables are also supported:
+
+ .. code-block:: pycon
+
+ >>> import asyncpg
+ >>> import asyncio
+ >>> async def run():
+ ... con = await asyncpg.connect(user='postgres')
+ ... async def record_gen(size):
+ ... for i in range(size):
+ ... yield (i,)
+ ... result = await con.copy_records_to_table(
+ ... 'mytable', records=record_gen(100))
+ ... print(result)
+ ...
+ >>> asyncio.run(run())
+ 'COPY 100'
+
.. versionadded:: 0.11.0
+
+ .. versionchanged:: 0.24.0
+ The ``records`` argument may be an asynchronous iterable.
+
+ .. versionadded:: 0.29.0
+ Added the *where* parameter.
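+
+ Example of a filtered record copy (a sketch; assumes PostgreSQL 12 or
+ later and that the first column of ``mytable`` is named ``a``):
+
+ .. code-block:: pycon
+
+ >>> await con.copy_records_to_table(
+ ... 'mytable', records=[(1, 'foo', 'bar'), (2, 'ham', 'spam')],
+ ... where='a = 1')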
"""
tabname = utils._quote_ident(table_name)
if schema_name:
@@ -721,15 +1116,28 @@ async def copy_records_to_table(self, table_name, *, records,
intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format(
tab=tabname, cols=col_list)
- intro_ps = await self._prepare(intro_query, use_cache=True)
+ intro_ps = await self.prepare(intro_query)
+ cond = self._format_copy_where(where)
opts = '(FORMAT binary)'
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
- tab=tabname, cols=cols, opts=opts)
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
+ tab=tabname, cols=cols, opts=opts, cond=cond)
+
+ return await self._protocol.copy_in(
+ copy_stmt, None, None, records, intro_ps._state, timeout)
- return await self._copy_in_records(
- copy_stmt, records, intro_ps._state, timeout)
+ def _format_copy_where(self, where):
+ if where and not self._server_caps.sql_copy_from_where:
+ raise exceptions.UnsupportedServerFeatureError(
+ 'the `where` parameter requires PostgreSQL 12 or later')
+
+ if where:
+ where_clause = 'WHERE ' + where
+ else:
+ where_clause = ''
+
+ return where_clause
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
delimiter=None, null=None, header=None, quote=None,
@@ -762,7 +1170,7 @@ def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
async def _copy_out(self, copy_stmt, output, timeout):
try:
- path = compat.fspath(output)
+ path = os.fspath(output)
except TypeError:
# output is not a path-like object
path = None
@@ -801,7 +1209,7 @@ async def _writer(data):
async def _copy_in(self, copy_stmt, source, timeout):
try:
- path = compat.fspath(source)
+ path = os.fspath(source)
except TypeError:
# source is not a path-like object
path = None
@@ -814,13 +1222,16 @@ async def _copy_in(self, copy_stmt, source, timeout):
if path is not None:
# a path
- f = await run_in_executor(None, open, path, 'wb')
+ f = await run_in_executor(None, open, path, 'rb')
opened_by_us = True
elif hasattr(source, 'read'):
# file-like
f = source
elif isinstance(source, collections.abc.AsyncIterable):
# assuming calling output returns an awaitable.
+ # copy_in() is designed to handle very large amounts of data, and
+ # the source async iterable is allowed to return an arbitrary
+ # amount of data on every iteration.
reader = source
else:
# assuming source is an instance supporting the buffer protocol.
@@ -829,7 +1240,6 @@ async def _copy_in(self, copy_stmt, source, timeout):
if f is not None:
# Copying from a file-like object.
class _Reader:
- @compat.aiter_compat
def __aiter__(self):
return self
@@ -849,10 +1259,6 @@ async def __anext__(self):
if opened_by_us:
await run_in_executor(None, f.close)
- async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
- return await self._protocol.copy_in(
- copy_stmt, None, None, records, intro_stmt, timeout)
-
async def set_type_codec(self, typename, *,
schema='public', encoder, decoder,
format='text'):
@@ -904,6 +1310,9 @@ async def set_type_codec(self, typename, *,
| ``time with | (``microseconds``, |
| time zone`` | ``time zone offset in seconds``) |
+-----------------+---------------------------------------------+
+ | any composite | Composite value elements |
+ | type | |
+ +-----------------+---------------------------------------------+
:param encoder:
Callable accepting a Python object as a single argument and
@@ -942,7 +1351,7 @@ async def set_type_codec(self, typename, *,
... print(result)
... print(datetime.datetime(2002, 1, 1) + result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
relativedelta(years=+2, months=+3, days=+1)
2004-04-02 00:00:00
@@ -957,22 +1366,56 @@ async def set_type_codec(self, typename, *,
.. versionchanged:: 0.13.0
The ``binary`` keyword argument was removed in favor of
``format``.
- """
- self._check_open()
- typeinfo = await self.fetchrow(
- introspection.TYPE_BY_NAME, typename, schema)
- if not typeinfo:
- raise ValueError('unknown type: {}.{}'.format(schema, typename))
+ .. versionchanged:: 0.29.0
+ Custom codecs for composite types are now supported with
+ ``format='tuple'``.
- if not introspection.is_scalar_type(typeinfo):
- raise ValueError(
- 'cannot use custom codec on non-scalar type {}.{}'.format(
- schema, typename))
+ .. note::
+
+ It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
+ whenever possible and if the underlying type supports it. Asyncpg
+ currently does not support text I/O for composite and range types,
+ and some other functionality, such as
+ :meth:`Connection.copy_to_table`, does not support types with text
+ codecs.
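+
+ Example of a tuple-format codec for a composite type (a sketch;
+ assumes a type created as ``CREATE TYPE mycomplex AS (r float8,
+ i float8)``):
+
+ .. code-block:: pycon
+
+ >>> await con.set_type_codec(
+ ... 'mycomplex',
+ ... encoder=lambda x: (x.real, x.imag),
+ ... decoder=lambda t: complex(t[0], t[1]),
+ ... format='tuple',
+ ... )
+ >>> await con.fetchval('SELECT $1::mycomplex', complex(1, 2))
+ (1+2j)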
+ """
+ self._check_open()
+ settings = self._protocol.get_settings()
+ typeinfo = await self._introspect_type(typename, schema)
+ full_typeinfos = []
+ if introspection.is_scalar_type(typeinfo):
+ kind = 'scalar'
+ elif introspection.is_composite_type(typeinfo):
+ if format != 'tuple':
+ raise exceptions.UnsupportedClientFeatureError(
+ 'only tuple-format codecs can be used on composite types',
+ hint="Use `set_type_codec(..., format='tuple')` and "
+ "pass/interpret data as a Python tuple. See an "
+ "example at https://magicstack.github.io/asyncpg/"
+ "current/usage.html#example-decoding-complex-types",
+ )
+ kind = 'composite'
+ full_typeinfos, _ = await self._introspect_types(
+ (typeinfo['oid'],), 10)
+ else:
+ raise exceptions.InterfaceError(
+ f'cannot use custom codec on type {schema}.{typename}: '
+ f'it is neither a scalar type nor a composite type'
+ )
+ if introspection.is_domain_type(typeinfo):
+ raise exceptions.UnsupportedClientFeatureError(
+ 'custom codecs on domain types are not supported',
+ hint='Set the codec on the base type.',
+ detail=(
+ 'PostgreSQL does not distinguish domains from '
+ 'their base types in query results at the protocol level.'
+ )
+ )
oid = typeinfo['oid']
- self._protocol.get_settings().add_python_codec(
- oid, typename, schema, 'scalar',
+ settings.add_python_codec(
+ oid, typename, schema, full_typeinfos, kind,
encoder, decoder, format)
# Statement cache is no longer valid due to codec changes.
@@ -991,15 +1434,9 @@ async def reset_type_codec(self, typename, *, schema='public'):
.. versionadded:: 0.12.0
"""
- typeinfo = await self.fetchrow(
- introspection.TYPE_BY_NAME, typename, schema)
- if not typeinfo:
- raise ValueError('unknown type: {}.{}'.format(schema, typename))
-
- oid = typeinfo['oid']
-
+ typeinfo = await self._introspect_type(typename, schema)
self._protocol.get_settings().remove_python_codec(
- oid, typename, schema)
+ typeinfo['oid'], typename, schema)
# Statement cache is no longer valid due to codec changes.
self._drop_local_statement_cache()
@@ -1040,13 +1477,7 @@ async def set_builtin_type_codec(self, typename, *,
core data type. Added the *format* keyword argument.
"""
self._check_open()
-
- typeinfo = await self.fetchrow(
- introspection.TYPE_BY_NAME, typename, schema)
- if not typeinfo:
- raise exceptions.InterfaceError(
- 'unknown type: {}.{}'.format(schema, typename))
-
+ typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
raise exceptions.InterfaceError(
'cannot alias non-scalar type {}.{}'.format(
@@ -1080,7 +1511,7 @@ async def close(self, *, timeout=None):
try:
if not self.is_closed():
await self._protocol.close(timeout)
- except Exception:
+ except (Exception, asyncio.CancelledError):
# If we fail to close gracefully, abort the connection.
self._abort()
raise
@@ -1093,11 +1524,10 @@ def terminate(self):
self._abort()
self._cleanup()
- async def reset(self, *, timeout=None):
+ async def _reset(self):
self._check_open()
self._listeners.clear()
self._log_listeners.clear()
- reset_query = self._get_reset_query()
if self._protocol.is_in_transaction() or self._top_xact is not None:
if self._top_xact is None or not self._top_xact._managed:
@@ -1109,10 +1539,36 @@ async def reset(self, *, timeout=None):
})
self._top_xact = None
- reset_query = 'ROLLBACK;\n' + reset_query
+ await self.execute("ROLLBACK")
+
+ async def reset(self, *, timeout=None):
+ """Reset the connection state.
+
+ Calling this will reset the connection session state to a state
+ resembling that of a newly obtained connection. Namely, an open
+ transaction (if any) is rolled back, open cursors are closed,
+ all ``LISTEN``
+ registrations are removed, all session configuration
+ variables are reset to their default values, and all advisory locks
+ are released.
+
+ Note that the above describes the default query returned by
+ :meth:`Connection.get_reset_query`. If that method is overridden
+ in a ``Connection`` subclass, :meth:`Connection.reset` will execute
+ whatever query the override returns, except that open transactions
+ are always terminated and any callbacks registered by
+ :meth:`Connection.add_listener` or :meth:`Connection.add_log_listener`
+ are removed.
- if reset_query:
- await self.execute(reset_query, timeout=timeout)
+ :param float timeout:
+ A timeout for resetting the connection. If not specified, defaults
+ to no timeout.
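+
+ Example (a minimal sketch; the session variable shown is
+ illustrative):
+
+ .. code-block:: pycon
+
+ >>> await con.execute('SET search_path TO myschema')
+ >>> await con.reset() # search_path is back to its default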
+ """
+ async with compat.timeout(timeout):
+ await self._reset()
+ reset_query = self.get_reset_query()
+ if reset_query:
+ await self.execute(reset_query)
def _abort(self):
# Put the connection into the aborted state.
@@ -1121,6 +1577,7 @@ def _abort(self):
self._protocol = None
def _cleanup(self):
+ self._call_termination_listeners()
# Free the resources associated with this connection.
# This must be called when a connection is terminated.
@@ -1132,6 +1589,7 @@ def _cleanup(self):
self._mark_stmts_as_closed()
self._listeners.clear()
self._log_listeners.clear()
+ self._query_loggers.clear()
self._clean_tasks()
def _clean_tasks(self):
@@ -1162,7 +1620,13 @@ def _mark_stmts_as_closed(self):
self._stmts_to_close.clear()
def _maybe_gc_stmt(self, stmt):
- if stmt.refs == 0 and not self._stmt_cache.has(stmt.query):
+ if (
+ stmt.refs == 0
+ and stmt.name
+ and not self._stmt_cache.has(
+ (stmt.query, stmt.record_class, stmt.ignore_custom_codec)
+ )
+ ):
# If low-level `stmt` isn't referenced from any high-level
# `PreparedStatement` object and is not in the `_stmt_cache`:
#
@@ -1186,24 +1650,16 @@ async def _cleanup_stmts(self):
await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
async def _cancel(self, waiter):
- r = w = None
-
try:
# Open new connection to the server
- r, w = await connect_utils._open_connection(
- loop=self._loop, addr=self._addr, params=self._params)
-
- # Pack CancelRequest message
- msg = struct.pack('!llll', 16, 80877102,
- self._protocol.backend_pid,
- self._protocol.backend_secret)
-
- w.write(msg)
- await r.read() # Wait until EOF
+ await connect_utils._cancel(
+ loop=self._loop, addr=self._addr, params=self._params,
+ backend_pid=self._protocol.backend_pid,
+ backend_secret=self._protocol.backend_secret)
except ConnectionResetError as ex:
# On some systems Postgres will reset the connection
# after processing the cancellation command.
- if r is None and not waiter.done():
+ if not waiter.done():
waiter.set_exception(ex)
except asyncio.CancelledError:
# There are two scenarios in which the cancellation
@@ -1213,16 +1669,14 @@ async def _cancel(self, waiter):
# the CancelledError, and don't want the loop to warn about
# an unretrieved exception.
pass
- except Exception as ex:
+ except (Exception, asyncio.CancelledError) as ex:
if not waiter.done():
waiter.set_exception(ex)
finally:
self._cancellations.discard(
- compat.current_asyncio_task(self._loop))
+ asyncio.current_task(self._loop))
if not waiter.done():
waiter.set_result(None)
- if w is not None:
- w.close()
def _cancel_current_command(self, waiter):
self._cancellations.add(self._loop.create_task(self._cancel(waiter)))
@@ -1235,18 +1689,23 @@ def _process_log_message(self, fields, last_query):
con_ref = self._unwrap()
for cb in self._log_listeners:
- self._loop.call_soon(
- self._call_log_listener, cb, con_ref, message)
+ if cb.is_async:
+ self._loop.create_task(cb.cb(con_ref, message))
+ else:
+ self._loop.call_soon(cb.cb, con_ref, message)
- def _call_log_listener(self, cb, con_ref, message):
- try:
- cb(con_ref, message)
- except Exception as ex:
- self._loop.call_exception_handler({
- 'message': 'Unhandled exception in asyncpg log message '
- 'listener callback {!r}'.format(cb),
- 'exception': ex
- })
+ def _call_termination_listeners(self):
+ if not self._termination_listeners:
+ return
+
+ con_ref = self._unwrap()
+ for cb in self._termination_listeners:
+ if cb.is_async:
+ self._loop.create_task(cb.cb(con_ref))
+ else:
+ self._loop.call_soon(cb.cb, con_ref)
+
+ self._termination_listeners.clear()
def _process_notification(self, pid, channel, payload):
if channel not in self._listeners:
@@ -1254,18 +1713,10 @@ def _process_notification(self, pid, channel, payload):
con_ref = self._unwrap()
for cb in self._listeners[channel]:
- self._loop.call_soon(
- self._call_listener, cb, con_ref, pid, channel, payload)
-
- def _call_listener(self, cb, con_ref, pid, channel, payload):
- try:
- cb(con_ref, pid, channel, payload)
- except Exception as ex:
- self._loop.call_exception_handler({
- 'message': 'Unhandled exception in asyncpg notification '
- 'listener callback {!r}'.format(cb),
- 'exception': ex
- })
+ if cb.is_async:
+ self._loop.create_task(cb.cb(con_ref, pid, channel, payload))
+ else:
+ self._loop.call_soon(cb.cb, con_ref, pid, channel, payload)
def _unwrap(self):
if self._proxy is None:
@@ -1278,7 +1729,15 @@ def _unwrap(self):
con_ref = self._proxy
return con_ref
- def _get_reset_query(self):
+ def get_reset_query(self):
+ """Return the query sent to server on connection release.
+
+ The query returned by this method is used by :meth:`Connection.reset`,
+ which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
+ the connection available to another acquirer.
+
+ .. versionadded:: 0.30.0
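+
+ Example of overriding the reset query in a subclass (a sketch;
+ ``QuietConnection`` is an illustrative name):
+
+ .. code-block:: python
+
+ import asyncpg
+
+ class QuietConnection(asyncpg.Connection):
+ def get_reset_query(self):
+ # Drop LISTEN registrations only; skip the rest of the
+ # default reset query.
+ return 'UNLISTEN *;'
+
+ Such a subclass can be passed to :func:`connect` via the
+ *connection_class* argument.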
+ """
if self._reset_query is not None:
return self._reset_query
@@ -1290,16 +1749,7 @@ def _get_reset_query(self):
if caps.sql_close_all:
_reset_query.append('CLOSE ALL;')
if caps.notifications and caps.plpgsql:
- _reset_query.append('''
- DO $$
- BEGIN
- PERFORM * FROM pg_listening_channels() LIMIT 1;
- IF FOUND THEN
- UNLISTEN *;
- END IF;
- END;
- $$;
- ''')
+ _reset_query.append('UNLISTEN *;')
if caps.sql_reset:
_reset_query.append('RESET ALL;')
@@ -1401,40 +1851,179 @@ async def reload_schema_state(self):
... await con.execute('LOCK TABLE tbl')
... await change_type(con)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
.. versionadded:: 0.14.0
"""
self._drop_global_type_cache()
self._drop_global_statement_cache()
- async def _execute(self, query, args, limit, timeout, return_status=False):
+ async def _execute(
+ self,
+ query,
+ args,
+ limit,
+ timeout,
+ *,
+ return_status=False,
+ ignore_custom_codec=False,
+ record_class=None
+ ):
with self._stmt_exclusive_section:
result, _ = await self.__execute(
- query, args, limit, timeout, return_status=return_status)
+ query,
+ args,
+ limit,
+ timeout,
+ return_status=return_status,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
return result
- async def __execute(self, query, args, limit, timeout,
- return_status=False):
+ @contextlib.contextmanager
+ def query_logger(self, callback):
+ """Context manager that adds `callback` to the list of query loggers,
+ and removes it upon exit.
+
+ :param callable callback:
+ A callable or a coroutine function receiving one argument:
+ **record**, a LoggedQuery containing `query`, `args`, `timeout`,
+ `elapsed`, `exception`, `conn_addr`, and `conn_params`.
+
+ Example:
+
+ .. code-block:: pycon
+
+ >>> class QuerySaver:
+ def __init__(self):
+ self.queries = []
+ def __call__(self, record):
+ self.queries.append(record.query)
+ >>> log = QuerySaver()
+ >>> with con.query_logger(log):
+ ... await con.execute("SELECT 1")
+ >>> print(log.queries)
+ ['SELECT 1']
+
+ .. versionadded:: 0.29.0
+ """
+ self.add_query_logger(callback)
+ try:
+ yield
+ finally:
+ # Ensure the callback is removed even if the block raises.
+ self.remove_query_logger(callback)
+
+ @contextlib.contextmanager
+ def _time_and_log(self, query, args, timeout):
+ start = time.monotonic()
+ exception = None
+ try:
+ yield
+ except BaseException as ex:
+ exception = ex
+ raise
+ finally:
+ elapsed = time.monotonic() - start
+ record = LoggedQuery(
+ query=query,
+ args=args,
+ timeout=timeout,
+ elapsed=elapsed,
+ exception=exception,
+ conn_addr=self._addr,
+ conn_params=self._params,
+ )
+ for cb in self._query_loggers:
+ if cb.is_async:
+ self._loop.create_task(cb.cb(record))
+ else:
+ self._loop.call_soon(cb.cb, record)
+
+ async def __execute(
+ self,
+ query,
+ args,
+ limit,
+ timeout,
+ *,
+ return_status=False,
+ ignore_custom_codec=False,
+ record_class=None
+ ):
executor = lambda stmt, timeout: self._protocol.bind_execute(
- stmt, args, '', limit, return_status, timeout)
+ state=stmt,
+ args=args,
+ portal_name='',
+ limit=limit,
+ return_extra=return_status,
+ timeout=timeout,
+ )
timeout = self._protocol._get_timeout(timeout)
- return await self._do_execute(query, executor, timeout)
+ if self._query_loggers:
+ with self._time_and_log(query, args, timeout):
+ result, stmt = await self._do_execute(
+ query,
+ executor,
+ timeout,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
+ else:
+ result, stmt = await self._do_execute(
+ query,
+ executor,
+ timeout,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
+ return result, stmt
- async def _executemany(self, query, args, timeout):
+ async def _executemany(
+ self,
+ query,
+ args,
+ timeout,
+ return_rows=False,
+ record_class=None,
+ ):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
- stmt, args, '', timeout)
+ state=stmt,
+ args=args,
+ portal_name='',
+ timeout=timeout,
+ return_rows=return_rows,
+ )
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
- result, _ = await self._do_execute(query, executor, timeout)
+ with self._time_and_log(query, args, timeout):
+ result, _ = await self._do_execute(
+ query, executor, timeout, record_class=record_class
+ )
return result
- async def _do_execute(self, query, executor, timeout, retry=True):
+ async def _do_execute(
+ self,
+ query,
+ executor,
+ timeout,
+ retry=True,
+ *,
+ ignore_custom_codec=False,
+ record_class=None
+ ):
if timeout is None:
- stmt = await self._get_statement(query, None)
+ stmt = await self._get_statement(
+ query,
+ None,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
else:
before = time.monotonic()
- stmt = await self._get_statement(query, timeout)
+ stmt = await self._get_statement(
+ query,
+ timeout,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
after = time.monotonic()
timeout -= after - before
before = after
@@ -1494,6 +2083,8 @@ async def _do_execute(self, query, executor, timeout, retry=True):
async def connect(dsn=None, *,
host=None, port=None,
user=None, password=None, passfile=None,
+ service=None,
+ servicefile=None,
database=None,
loop=None,
timeout=60,
@@ -1502,8 +2093,13 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
+ direct_tls=None,
connection_class=Connection,
- server_settings=None):
+ record_class=protocol.Record,
+ server_settings=None,
+ target_session_attrs=None,
+ krbsrvname=None,
+ gsslib=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
The connection parameters may be specified either as a connection
@@ -1511,7 +2107,7 @@ async def connect(dsn=None, *,
If both *dsn* and keyword arguments are specified, the latter
override the corresponding values parsed from the connection URI.
The default values for the majority of arguments can be specified
- using `environment variables `_.
+ using `environment variables `_.
Returns a new :class:`~asyncpg.connection.Connection` object.
@@ -1519,10 +2115,22 @@ async def connect(dsn=None, *,
Connection arguments specified using as a single string in the
`libpq connection URI format`_:
``postgres://user:password@host:port/database?option=value``.
- The following options are recognized by asyncpg: host, port,
- user, database (or dbname), password, passfile, sslmode.
- Unlike libpq, asyncpg will treat unrecognized options
- as `server settings`_ to be used for the connection.
+ The following options are recognized by asyncpg: ``host``,
+ ``port``, ``user``, ``database`` (or ``dbname``), ``password``,
+ ``passfile``, ``sslmode``, ``sslcert``, ``sslkey``, ``sslrootcert``,
+ and ``sslcrl``. Unlike libpq, asyncpg will treat unrecognized
+ options as `server settings`_ to be used for the connection.
+
+ .. note::
+
+ The URI must be *valid*, which means that all components must
+ be properly quoted with :py:func:`urllib.parse.quote_plus`, and
+ any literal IPv6 addresses must be enclosed in square brackets.
+ For example:
+
+ .. code-block:: text
+
+ postgres://dbuser@[fe80::1ff:fe23:4567:890a%25eth0]/dbname
:param host:
Database host address as one of the following:
@@ -1567,7 +2175,7 @@ async def connect(dsn=None, *,
If not specified, the value parsed from the *dsn* argument is used,
or the value of the ``PGDATABASE`` environment variable, or the
- operating system name of the user running the application.
+ computed value of the *user* argument.
:param password:
Password to be used for authentication, if the server requires
@@ -1577,11 +2185,23 @@ async def connect(dsn=None, *,
other users and applications may be able to read it without needing
specific privileges. It is recommended to use *passfile* instead.
+ Password may be either a string, or a callable that returns a string.
+ If a callable is provided, it will be called each time a new connection
+ is established.
+
:param passfile:
The name of the file used to store passwords
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
on Windows).
+ :param service:
+ The name of the PostgreSQL connection service, as stored in the
+ connection service file.
+
+ :param servicefile:
+ The location of the connection service file used to store
+ connection parameters.
+
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
@@ -1611,17 +2231,118 @@ async def connect(dsn=None, *,
Pass ``True`` or an `ssl.SSLContext `_ instance to
require an SSL connection. If ``True``, a default SSL context
returned by `ssl.create_default_context() `_
- will be used.
+ will be used. The value can also be one of the following strings:
+
+ - ``'disable'`` - SSL is disabled (equivalent to ``False``)
+ - ``'prefer'`` - try SSL first, fallback to non-SSL connection
+ if SSL connection fails
+ - ``'allow'`` - try without SSL first, then retry with SSL if the first
+ attempt fails.
+ - ``'require'`` - only try an SSL connection. Certificate
+ verification errors are ignored
+ - ``'verify-ca'`` - only try an SSL connection, and verify
+ that the server certificate is issued by a trusted certificate
+ authority (CA)
+ - ``'verify-full'`` - only try an SSL connection, verify
+ that the server certificate is issued by a trusted CA and
+ that the requested server host name matches that in the
+ certificate.
+
+ The default is ``'prefer'``: try an SSL connection and fallback to
+ non-SSL connection if that fails.
+
+ .. note::
+
+ *ssl* is ignored for Unix domain socket communication.
+
+ Example of programmatic SSL context configuration that is equivalent
+ to ``sslmode=verify-full&sslcert=..&sslkey=..&sslrootcert=..``:
+
+ .. code-block:: pycon
+
+ >>> import asyncpg
+ >>> import asyncio
+ >>> import ssl
+ >>> async def main():
+ ... # Load CA bundle for server certificate verification,
+ ... # equivalent to sslrootcert= in DSN.
+ ... sslctx = ssl.create_default_context(
+ ... ssl.Purpose.SERVER_AUTH,
+ ... cafile="path/to/ca_bundle.pem")
+ ... # If True, equivalent to sslmode=verify-full, if False:
+ ... # sslmode=verify-ca.
+ ... sslctx.check_hostname = True
+ ... # Load client certificate and private key for client
+ ... # authentication, equivalent to sslcert= and sslkey= in
+ ... # DSN.
+ ... sslctx.load_cert_chain(
+ ... "path/to/client.cert",
+ ... keyfile="path/to/client.key",
+ ... )
+ ... con = await asyncpg.connect(user='postgres', ssl=sslctx)
+ ... await con.close()
+ >>> asyncio.run(main())
+
+ Example of programmatic SSL context configuration that is equivalent
+ to ``sslmode=require`` (no server certificate or host verification):
+
+ .. code-block:: pycon
+
+ >>> import asyncpg
+ >>> import asyncio
+ >>> import ssl
+ >>> async def main():
+ ... sslctx = ssl.create_default_context(
+ ... ssl.Purpose.SERVER_AUTH)
+ ... sslctx.check_hostname = False
+ ... sslctx.verify_mode = ssl.CERT_NONE
+ ... con = await asyncpg.connect(user='postgres', ssl=sslctx)
+ ... await con.close()
+ >>> asyncio.run(main())
+
+ :param bool direct_tls:
+ Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct
+ SSL connection. Must be used together with the ``ssl`` parameter.
:param dict server_settings:
An optional dict of server runtime parameters. Refer to
PostgreSQL documentation for
- a `list of supported options `_.
+ a `list of supported options `_.
- :param Connection connection_class:
+ :param type connection_class:
Class of the returned connection object. Must be a subclass of
:class:`~asyncpg.connection.Connection`.
+ :param type record_class:
+ If specified, the class to use for records returned by queries on
+ this connection object. Must be a subclass of
+ :class:`~asyncpg.Record`.
+
+ :param SessionAttribute target_session_attrs:
+ If specified, check that the host has the correct attribute.
+ Can be one of:
+
+ - ``"any"`` - the first successfully connected host
+ - ``"primary"`` - the host must NOT be in hot standby mode
+ - ``"standby"`` - the host must be in hot standby mode
+ - ``"read-write"`` - the host must allow writes
+ - ``"read-only"`` - the host most NOT allow writes
+ - ``"prefer-standby"`` - first try to find a standby host, but if
+ none of the listed hosts is a standby server,
+ return any of them.
+
+ If not specified, the value parsed from the *dsn* argument is used,
+ or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
+ or ``"any"`` if neither is specified.
+
+ :param str krbsrvname:
+ Kerberos service name to use when authenticating with GSSAPI. This
+ must match the server configuration. Defaults to 'postgres'.
+
+ :param str gsslib:
+ GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi'
+ or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise.
+
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
@@ -1635,7 +2356,7 @@ async def connect(dsn=None, *,
... types = await con.fetch('SELECT * FROM pg_type')
... print(types)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
[ '_Callback':
+ if inspect.iscoroutinefunction(cb):
+ is_async = True
+ elif callable(cb):
+ is_async = False
+ else:
+ raise exceptions.InterfaceError(
+ 'expected a callable or an `async def` function,'
+ 'got {!r}'.format(cb)
+ )
+
+ return cls(cb, is_async)
+
+
class _Atomic:
__slots__ = ('_acquired',)
@@ -1861,10 +2659,17 @@ class _ConnectionProxy:
__slots__ = ()
+LoggedQuery = collections.namedtuple(
+ 'LoggedQuery',
+ ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr',
+ 'conn_params'])
+LoggedQuery.__doc__ = 'Log record of an executed query.'
+
+
ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
- 'sql_close_all'])
+ 'sql_close_all', 'sql_copy_from_where', 'jit'])
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'
@@ -1876,6 +2681,8 @@ def _detect_server_capabilities(server_version, connection_settings):
plpgsql = False
sql_reset = True
sql_close_all = False
+ jit = False
+ sql_copy_from_where = False
elif hasattr(connection_settings, 'crdb_version'):
# CockroachDB detected.
advisory_locks = False
@@ -1883,6 +2690,8 @@ def _detect_server_capabilities(server_version, connection_settings):
plpgsql = False
sql_reset = False
sql_close_all = False
+ jit = False
+ sql_copy_from_where = False
elif hasattr(connection_settings, 'crate_version'):
# CrateDB detected.
advisory_locks = False
@@ -1890,6 +2699,8 @@ def _detect_server_capabilities(server_version, connection_settings):
plpgsql = False
sql_reset = False
sql_close_all = False
+ jit = False
+ sql_copy_from_where = False
else:
# Standard PostgreSQL server assumed.
advisory_locks = True
@@ -1897,13 +2708,17 @@ def _detect_server_capabilities(server_version, connection_settings):
plpgsql = True
sql_reset = True
sql_close_all = True
+ jit = server_version >= (11, 0)
+ sql_copy_from_where = server_version.major >= 12
return ServerCapabilities(
advisory_locks=advisory_locks,
notifications=notifications,
plpgsql=plpgsql,
sql_reset=sql_reset,
- sql_close_all=sql_close_all
+ sql_close_all=sql_close_all,
+ sql_copy_from_where=sql_copy_from_where,
+ jit=jit,
)
@@ -1928,4 +2743,31 @@ def _extract_stack(limit=10):
return ''.join(traceback.format_list(stack))
+def _check_record_class(record_class):
+ if record_class is protocol.Record:
+ pass
+ elif (
+ isinstance(record_class, type)
+ and issubclass(record_class, protocol.Record)
+ ):
+ if (
+ record_class.__new__ is not protocol.Record.__new__
+ or record_class.__init__ is not protocol.Record.__init__
+ ):
+ raise exceptions.InterfaceError(
+ 'record_class must not redefine __new__ or __init__'
+ )
+ else:
+ raise exceptions.InterfaceError(
+ 'record_class is expected to be a subclass of '
+ 'asyncpg.Record, got {!r}'.format(record_class)
+ )
+
+
+def _weak_maybe_gc_stmt(weak_ref, stmt):
+ self = weak_ref()
+ if self is not None:
+ self._maybe_gc_stmt(stmt)
+
+
_uid = 0
diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py
index 030def0e..b4abeed1 100644
--- a/asyncpg/cursor.py
+++ b/asyncpg/cursor.py
@@ -7,7 +7,6 @@
import collections
-from . import compat
from . import connresource
from . import exceptions
@@ -19,34 +18,60 @@ class CursorFactory(connresource.ConnectionResource):
results of a large query.
"""
- __slots__ = ('_state', '_args', '_prefetch', '_query', '_timeout')
-
- def __init__(self, connection, query, state, args, prefetch, timeout):
+ __slots__ = (
+ '_state',
+ '_args',
+ '_prefetch',
+ '_query',
+ '_timeout',
+ '_record_class',
+ )
+
+ def __init__(
+ self,
+ connection,
+ query,
+ state,
+ args,
+ prefetch,
+ timeout,
+ record_class
+ ):
super().__init__(connection)
self._args = args
self._prefetch = prefetch
self._query = query
self._timeout = timeout
self._state = state
+ self._record_class = record_class
if state is not None:
state.attach()
- @compat.aiter_compat
@connresource.guarded
def __aiter__(self):
prefetch = 50 if self._prefetch is None else self._prefetch
- return CursorIterator(self._connection,
- self._query, self._state,
- self._args, prefetch,
- self._timeout)
+ return CursorIterator(
+ self._connection,
+ self._query,
+ self._state,
+ self._args,
+ self._record_class,
+ prefetch,
+ self._timeout,
+ )
@connresource.guarded
def __await__(self):
if self._prefetch is not None:
raise exceptions.InterfaceError(
'prefetch argument can only be specified for iterable cursor')
- cursor = Cursor(self._connection, self._query,
- self._state, self._args)
+ cursor = Cursor(
+ self._connection,
+ self._query,
+ self._state,
+ self._args,
+ self._record_class,
+ )
return cursor._init(self._timeout).__await__()
def __del__(self):
@@ -57,9 +82,16 @@ def __del__(self):
class BaseCursor(connresource.ConnectionResource):
- __slots__ = ('_state', '_args', '_portal_name', '_exhausted', '_query')
+ __slots__ = (
+ '_state',
+ '_args',
+ '_portal_name',
+ '_exhausted',
+ '_query',
+ '_record_class',
+ )
- def __init__(self, connection, query, state, args):
+ def __init__(self, connection, query, state, args, record_class):
super().__init__(connection)
self._args = args
self._state = state
@@ -68,6 +100,7 @@ def __init__(self, connection, query, state, args):
self._portal_name = None
self._exhausted = False
self._query = query
+ self._record_class = record_class
def _check_ready(self):
if self._state is None:
@@ -125,6 +158,17 @@ async def _exec(self, n, timeout):
self._state, self._portal_name, n, True, timeout)
return buffer
+ async def _close_portal(self, timeout):
+ self._check_ready()
+
+ if not self._portal_name:
+ raise exceptions.InterfaceError(
+ 'cursor does not have an open portal')
+
+ protocol = self._connection._protocol
+ await protocol.close_portal(self._portal_name, timeout)
+ self._portal_name = None
+
def __repr__(self):
attrs = []
if self._exhausted:
@@ -151,8 +195,17 @@ class CursorIterator(BaseCursor):
__slots__ = ('_buffer', '_prefetch', '_timeout')
- def __init__(self, connection, query, state, args, prefetch, timeout):
- super().__init__(connection, query, state, args)
+ def __init__(
+ self,
+ connection,
+ query,
+ state,
+ args,
+ record_class,
+ prefetch,
+ timeout
+ ):
+ super().__init__(connection, query, state, args, record_class)
if prefetch <= 0:
raise exceptions.InterfaceError(
@@ -162,7 +215,6 @@ def __init__(self, connection, query, state, args, prefetch, timeout):
self._prefetch = prefetch
self._timeout = timeout
- @compat.aiter_compat
@connresource.guarded
def __aiter__(self):
return self
@@ -171,10 +223,14 @@ def __aiter__(self):
async def __anext__(self):
if self._state is None:
self._state = await self._connection._get_statement(
- self._query, self._timeout, named=True)
+ self._query,
+ self._timeout,
+ named=True,
+ record_class=self._record_class,
+ )
self._state.attach()
- if not self._portal_name:
+ if not self._portal_name and not self._exhausted:
buffer = await self._bind_exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
@@ -182,6 +238,9 @@ async def __anext__(self):
buffer = await self._exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
+ if self._portal_name and self._exhausted:
+ await self._close_portal(self._timeout)
+
if self._buffer:
return self._buffer.popleft()
@@ -196,7 +255,11 @@ class Cursor(BaseCursor):
async def _init(self, timeout):
if self._state is None:
self._state = await self._connection._get_statement(
- self._query, timeout, named=True)
+ self._query,
+ timeout,
+ named=True,
+ record_class=self._record_class,
+ )
self._state.attach()
self._check_ready()
await self._bind(timeout)
diff --git a/asyncpg/exceptions/__init__.py b/asyncpg/exceptions/__init__.py
index 446a71a8..752fd007 100644
--- a/asyncpg/exceptions/__init__.py
+++ b/asyncpg/exceptions/__init__.py
@@ -121,6 +121,10 @@ class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError):
sqlstate = '0Z002'
+class InvalidArgumentForXqueryError(_base.PostgresError):
+ sqlstate = '10608'
+
+
class CaseNotFoundError(_base.PostgresError):
sqlstate = '20000'
@@ -337,6 +341,10 @@ class DuplicateJsonObjectKeyValueError(DataError):
sqlstate = '22030'
+class InvalidArgumentForSQLJsonDatetimeFunctionError(DataError):
+ sqlstate = '22031'
+
+
class InvalidJsonTextError(DataError):
sqlstate = '22032'
@@ -393,6 +401,10 @@ class SQLJsonScalarRequiredError(DataError):
sqlstate = '2203F'
+class SQLJsonItemCannotBeCastToTargetTypeError(DataError):
+ sqlstate = '2203G'
+
+
class IntegrityConstraintViolationError(_base.PostgresError):
sqlstate = '23000'
@@ -477,6 +489,10 @@ class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError):
sqlstate = '25P03'
+class TransactionTimeoutError(InvalidTransactionStateError):
+ sqlstate = '25P04'
+
+
class InvalidSQLStatementNameError(_base.PostgresError):
sqlstate = '26000'
@@ -872,6 +888,10 @@ class DatabaseDroppedError(OperatorInterventionError):
sqlstate = '57P04'
+class IdleSessionTimeoutError(OperatorInterventionError):
+ sqlstate = '57P05'
+
+
class PostgresSystemError(_base.PostgresError):
sqlstate = '58000'
@@ -888,6 +908,10 @@ class DuplicateFileError(PostgresSystemError):
sqlstate = '58P02'
+class FileNameTooLongError(PostgresSystemError):
+ sqlstate = '58P03'
+
+
class SnapshotTooOldError(_base.PostgresError):
sqlstate = '72000'
@@ -1040,7 +1064,7 @@ class IndexCorruptedError(InternalServerError):
sqlstate = 'XX002'
-__all__ = _base.__all__ + (
+__all__ = (
'ActiveSQLTransactionError', 'AdminShutdownError',
'AmbiguousAliasError', 'AmbiguousColumnError',
'AmbiguousFunctionError', 'AmbiguousParameterError',
@@ -1083,11 +1107,11 @@ class IndexCorruptedError(InternalServerError):
'FDWTableNotFoundError', 'FDWTooManyHandlesError',
'FDWUnableToCreateExecutionError', 'FDWUnableToCreateReplyError',
'FDWUnableToEstablishConnectionError', 'FeatureNotSupportedError',
- 'ForeignKeyViolationError', 'FunctionExecutedNoReturnStatementError',
- 'GeneratedAlwaysError', 'GroupingError',
- 'HeldCursorRequiresSameIsolationLevelError',
- 'IdleInTransactionSessionTimeoutError', 'ImplicitZeroBitPadding',
- 'InFailedSQLTransactionError',
+ 'FileNameTooLongError', 'ForeignKeyViolationError',
+ 'FunctionExecutedNoReturnStatementError', 'GeneratedAlwaysError',
+ 'GroupingError', 'HeldCursorRequiresSameIsolationLevelError',
+ 'IdleInTransactionSessionTimeoutError', 'IdleSessionTimeoutError',
+ 'ImplicitZeroBitPadding', 'InFailedSQLTransactionError',
'InappropriateAccessModeForBranchTransactionError',
'InappropriateIsolationLevelForBranchTransactionError',
'IndeterminateCollationError', 'IndeterminateDatatypeError',
@@ -1098,7 +1122,9 @@ class IndexCorruptedError(InternalServerError):
'InvalidArgumentForNthValueFunctionError',
'InvalidArgumentForNtileFunctionError',
'InvalidArgumentForPowerFunctionError',
+ 'InvalidArgumentForSQLJsonDatetimeFunctionError',
'InvalidArgumentForWidthBucketFunctionError',
+ 'InvalidArgumentForXqueryError',
'InvalidAuthorizationSpecificationError',
'InvalidBinaryRepresentationError', 'InvalidCachedStatementError',
'InvalidCatalogNameError', 'InvalidCharacterValueForCastError',
@@ -1154,6 +1180,7 @@ class IndexCorruptedError(InternalServerError):
'ReadingExternalRoutineSQLDataNotPermittedError',
'ReadingSQLDataNotPermittedError', 'ReservedNameError',
'RestrictViolationError', 'SQLJsonArrayNotFoundError',
+ 'SQLJsonItemCannotBeCastToTargetTypeError',
'SQLJsonMemberNotFoundError', 'SQLJsonNumberNotFoundError',
'SQLJsonObjectNotFoundError', 'SQLJsonScalarRequiredError',
'SQLRoutineError', 'SQLStatementNotYetCompleteError',
@@ -1170,9 +1197,9 @@ class IndexCorruptedError(InternalServerError):
'TooManyJsonObjectMembersError', 'TooManyRowsError',
'TransactionIntegrityConstraintViolationError',
'TransactionResolutionUnknownError', 'TransactionRollbackError',
- 'TriggerProtocolViolatedError', 'TriggeredActionError',
- 'TriggeredDataChangeViolationError', 'TrimError',
- 'UndefinedColumnError', 'UndefinedFileError',
+ 'TransactionTimeoutError', 'TriggerProtocolViolatedError',
+ 'TriggeredActionError', 'TriggeredDataChangeViolationError',
+ 'TrimError', 'UndefinedColumnError', 'UndefinedFileError',
'UndefinedFunctionError', 'UndefinedObjectError',
'UndefinedParameterError', 'UndefinedTableError',
'UniqueViolationError', 'UnsafeNewEnumValueUsageError',
@@ -1180,3 +1207,5 @@ class IndexCorruptedError(InternalServerError):
'WindowingError', 'WithCheckOptionViolationError',
'WrongObjectTypeError', 'ZeroLengthCharacterStringError'
)
+
+__all__ += _base.__all__
diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py
index 3e6ef812..00e9699a 100644
--- a/asyncpg/exceptions/_base.py
+++ b/asyncpg/exceptions/_base.py
@@ -12,7 +12,10 @@
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
- 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError')
+ 'ClientConfigurationError',
+ 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
+ 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
+ 'UnsupportedServerFeatureError')
def _is_asyncpg_class(cls):
@@ -142,7 +145,8 @@ def _make_constructor(cls, fields, query=None):
purpose;
* if you have no option of avoiding the use of pgbouncer,
- then you must switch pgbouncer's pool_mode to "session".
+ then you can set statement_cache_size to 0 when creating
+ the asyncpg connection object.
""")
dct['hint'] = hint
@@ -208,11 +212,32 @@ def __init__(self, msg, *, detail=None, hint=None):
InterfaceMessage.__init__(self, detail=detail, hint=hint)
Exception.__init__(self, msg)
+ def with_msg(self, msg):
+ return type(self)(
+ msg,
+ detail=self.detail,
+ hint=self.hint,
+ ).with_traceback(
+ self.__traceback__
+ )
+
+
+class ClientConfigurationError(InterfaceError, ValueError):
+ """An error caused by improper client configuration."""
+
class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""
+class UnsupportedClientFeatureError(InterfaceError):
+ """Requested feature is unsupported by asyncpg."""
+
+
+class UnsupportedServerFeatureError(InterfaceError):
+ """Requested feature is unsupported by PostgreSQL server."""
+
+
class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""
@@ -229,6 +254,10 @@ class ProtocolError(InternalClientError):
"""Unexpected condition in the handling of PostgreSQL protocol input."""
+class TargetServerAttributeNotMatched(InternalClientError):
+ """Could not find a host that satisfies the target attribute requirement"""
+
+
class OutdatedSchemaCacheError(InternalClientError):
"""A value decoding error caused by a schema change before row fetching."""
diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py
index 201f4341..c3b4e60c 100644
--- a/asyncpg/introspection.py
+++ b/asyncpg/introspection.py
@@ -4,8 +4,16 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
-_TYPEINFO = '''\
+import typing
+from .protocol.protocol import _create_record # type: ignore
+
+if typing.TYPE_CHECKING:
+ from . import protocol
+
+
+_TYPEINFO_13: typing.Final = '''\
(
SELECT
t.oid AS oid,
@@ -37,23 +45,130 @@
ELSE NULL
END) AS basetype,
- t.typreceive::oid != 0 AND t.typsend::oid != 0
- AS has_bin_io,
t.typelem AS elemtype,
elem_t.typdelim AS elemdelim,
range_t.rngsubtype AS range_subtype,
- (CASE WHEN t.typtype = 'r' THEN
+ (CASE WHEN t.typtype = 'c' THEN
(SELECT
- range_elem_t.typreceive::oid != 0 AND
- range_elem_t.typsend::oid != 0
+ array_agg(ia.atttypid ORDER BY ia.attnum)
FROM
- pg_catalog.pg_type AS range_elem_t
+ pg_attribute ia
+ INNER JOIN pg_class c
+ ON (ia.attrelid = c.oid)
WHERE
- range_elem_t.oid = range_t.rngsubtype)
- ELSE
- elem_t.typreceive::oid != 0 AND
- elem_t.typsend::oid != 0
- END) AS elem_has_bin_io,
+ ia.attnum > 0 AND NOT ia.attisdropped
+ AND c.reltype = t.oid)
+
+ ELSE NULL
+ END) AS attrtypoids,
+ (CASE WHEN t.typtype = 'c' THEN
+ (SELECT
+ array_agg(ia.attname::text ORDER BY ia.attnum)
+ FROM
+ pg_attribute ia
+ INNER JOIN pg_class c
+ ON (ia.attrelid = c.oid)
+ WHERE
+ ia.attnum > 0 AND NOT ia.attisdropped
+ AND c.reltype = t.oid)
+
+ ELSE NULL
+ END) AS attrnames
+ FROM
+ pg_catalog.pg_type AS t
+ INNER JOIN pg_catalog.pg_namespace ns ON (
+ ns.oid = t.typnamespace)
+ LEFT JOIN pg_type elem_t ON (
+ t.typlen = -1 AND
+ t.typelem != 0 AND
+ t.typelem = elem_t.oid
+ )
+ LEFT JOIN pg_range range_t ON (
+ t.oid = range_t.rngtypid
+ )
+ )
+'''
+
+
+INTRO_LOOKUP_TYPES_13 = '''\
+WITH RECURSIVE typeinfo_tree(
+ oid, ns, name, kind, basetype, elemtype, elemdelim,
+ range_subtype, attrtypoids, attrnames, depth)
+AS (
+ SELECT
+ ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
+ ti.elemtype, ti.elemdelim, ti.range_subtype,
+ ti.attrtypoids, ti.attrnames, 0
+ FROM
+ {typeinfo} AS ti
+ WHERE
+ ti.oid = any($1::oid[])
+
+ UNION ALL
+
+ SELECT
+ ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
+ ti.elemtype, ti.elemdelim, ti.range_subtype,
+ ti.attrtypoids, ti.attrnames, tt.depth + 1
+ FROM
+ {typeinfo} ti,
+ typeinfo_tree tt
+ WHERE
+ (tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
+ OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
+ OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
+ OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
+)
+
+SELECT DISTINCT
+ *,
+ basetype::regtype::text AS basetype_name,
+ elemtype::regtype::text AS elemtype_name,
+ range_subtype::regtype::text AS range_subtype_name
+FROM
+ typeinfo_tree
+ORDER BY
+ depth DESC
+'''.format(typeinfo=_TYPEINFO_13)
+
+
+_TYPEINFO: typing.Final = '''\
+ (
+ SELECT
+ t.oid AS oid,
+ ns.nspname AS ns,
+ t.typname AS name,
+ t.typtype AS kind,
+ (CASE WHEN t.typtype = 'd' THEN
+ (WITH RECURSIVE typebases(oid, depth) AS (
+ SELECT
+ t2.typbasetype AS oid,
+ 0 AS depth
+ FROM
+ pg_type t2
+ WHERE
+ t2.oid = t.oid
+
+ UNION ALL
+
+ SELECT
+ t2.typbasetype AS oid,
+ tb.depth + 1 AS depth
+ FROM
+ pg_type t2,
+ typebases tb
+ WHERE
+ tb.oid = t2.oid
+ AND t2.typbasetype != 0
+ ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)
+
+ ELSE NULL
+ END) AS basetype,
+ t.typelem AS elemtype,
+ elem_t.typdelim AS elemdelim,
+ COALESCE(
+ range_t.rngsubtype,
+ multirange_t.rngsubtype) AS range_subtype,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.atttypid ORDER BY ia.attnum)
@@ -92,18 +207,21 @@
LEFT JOIN pg_range range_t ON (
t.oid = range_t.rngtypid
)
+ LEFT JOIN pg_range multirange_t ON (
+ t.oid = multirange_t.rngmultitypid
+ )
)
'''
INTRO_LOOKUP_TYPES = '''\
WITH RECURSIVE typeinfo_tree(
- oid, ns, name, kind, basetype, has_bin_io, elemtype, elemdelim,
- range_subtype, elem_has_bin_io, attrtypoids, attrnames, depth)
+ oid, ns, name, kind, basetype, elemtype, elemdelim,
+ range_subtype, attrtypoids, attrnames, depth)
AS (
SELECT
- ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
- ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
+ ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
+ ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, 0
FROM
{typeinfo} AS ti
@@ -113,8 +231,8 @@
UNION ALL
SELECT
- ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
- ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
+ ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
+ ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, tt.depth + 1
FROM
{typeinfo} ti,
@@ -123,10 +241,14 @@
(tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
+ OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)
SELECT DISTINCT
- *
+ *,
+ basetype::regtype::text AS basetype_name,
+ elemtype::regtype::text AS elemtype_name,
+ range_subtype::regtype::text AS range_subtype_name
FROM
typeinfo_tree
ORDER BY
@@ -134,7 +256,7 @@
'''.format(typeinfo=_TYPEINFO)
-TYPE_BY_NAME = '''\
+TYPE_BY_NAME: typing.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
@@ -147,12 +269,28 @@
'''
+def TypeRecord(
+ rec: typing.Tuple[int, typing.Optional[int], bytes],
+) -> protocol.Record:
+ assert len(rec) == 3
+ return _create_record( # type: ignore
+ {"oid": 0, "elemtype": 1, "kind": 2}, rec)
+
+
# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
-def is_scalar_type(typeinfo) -> bool:
+def is_scalar_type(typeinfo: protocol.Record) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)
+
+
+def is_domain_type(typeinfo: protocol.Record) -> bool:
+ return typeinfo['kind'] == b'd' # type: ignore[no-any-return]
+
+
+def is_composite_type(typeinfo: protocol.Record) -> bool:
+ return typeinfo['kind'] == b'c' # type: ignore[no-any-return]
diff --git a/asyncpg/pgproto b/asyncpg/pgproto
index 6079e5b2..a29a6f6a 160000
--- a/asyncpg/pgproto
+++ b/asyncpg/pgproto
@@ -1 +1 @@
-Subproject commit 6079e5b2addf7717aabbdcdb7825d6b68c731409
+Subproject commit a29a6f6aaa09013cb33ffadb8dd57e21d671ab55
diff --git a/asyncpg/pool.py b/asyncpg/pool.py
index 64f4071e..5c7ea9ca 100644
--- a/asyncpg/pool.py
+++ b/asyncpg/pool.py
@@ -4,17 +4,22 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
import asyncio
+from collections.abc import Awaitable, Callable
import functools
import inspect
import logging
import time
+from types import TracebackType
+from typing import Any, Optional, Type
import warnings
+from . import compat
from . import connection
-from . import connect_utils
from . import exceptions
+from . import protocol
logger = logging.getLogger(__name__)
@@ -22,7 +27,14 @@
class PoolConnectionProxyMeta(type):
- def __new__(mcls, name, bases, dct, *, wrap=False):
+ def __new__(
+ mcls,
+ name: str,
+ bases: tuple[Type[Any], ...],
+ dct: dict[str, Any],
+ *,
+ wrap: bool = False,
+ ) -> PoolConnectionProxyMeta:
if wrap:
for attrname in dir(connection.Connection):
if attrname.startswith('_') or attrname in dct:
@@ -32,7 +44,8 @@ def __new__(mcls, name, bases, dct, *, wrap=False):
if not inspect.isfunction(meth):
continue
- wrapper = mcls._wrap_connection_method(attrname)
+ iscoroutine = inspect.iscoroutinefunction(meth)
+ wrapper = mcls._wrap_connection_method(attrname, iscoroutine)
wrapper = functools.update_wrapper(wrapper, meth)
dct[attrname] = wrapper
@@ -41,13 +54,11 @@ def __new__(mcls, name, bases, dct, *, wrap=False):
return super().__new__(mcls, name, bases, dct)
- def __init__(cls, name, bases, dct, *, wrap=False):
- # Needed for Python 3.5 to handle `wrap` class keyword argument.
- super().__init__(name, bases, dct)
-
@staticmethod
- def _wrap_connection_method(meth_name):
- def call_con_method(self, *args, **kwargs):
+ def _wrap_connection_method(
+ meth_name: str, iscoroutine: bool
+ ) -> Callable[..., Any]:
+ def call_con_method(self: Any, *args: Any, **kwargs: Any) -> Any:
# This method will be owned by PoolConnectionProxy class.
if self._con is None:
raise exceptions.InterfaceError(
@@ -58,6 +69,9 @@ def call_con_method(self, *args, **kwargs):
meth = getattr(self._con.__class__, meth_name)
return meth(self._con, *args, **kwargs)
+ if iscoroutine:
+ compat.markcoroutinefunction(call_con_method)
+
return call_con_method
@@ -67,17 +81,18 @@ class PoolConnectionProxy(connection._ConnectionProxy,
__slots__ = ('_con', '_holder')
- def __init__(self, holder: 'PoolConnectionHolder',
- con: connection.Connection):
+ def __init__(
+ self, holder: PoolConnectionHolder, con: connection.Connection
+ ) -> None:
self._con = con
self._holder = holder
con._set_proxy(self)
- def __getattr__(self, attr):
+ def __getattr__(self, attr: str) -> Any:
# Proxy all unresolved attributes to the wrapped Connection object.
return getattr(self._con, attr)
- def _detach(self) -> connection.Connection:
+ def _detach(self) -> Optional[connection.Connection]:
if self._con is None:
return
@@ -85,7 +100,7 @@ def _detach(self) -> connection.Connection:
con._set_proxy(None)
return con
- def __repr__(self):
+ def __repr__(self) -> str:
if self._con is None:
return '<{classname} [released] {id:#x}>'.format(
classname=self.__class__.__name__, id=id(self))
@@ -102,21 +117,34 @@ class PoolConnectionHolder:
'_inactive_callback', '_timeout',
'_generation')
- def __init__(self, pool, *, max_queries, setup, max_inactive_time):
+ def __init__(
+ self,
+ pool: "Pool",
+ *,
+ max_queries: float,
+ setup: Optional[Callable[[PoolConnectionProxy], Awaitable[None]]],
+ max_inactive_time: float,
+ ) -> None:
self._pool = pool
- self._con = None
- self._proxy = None
+ self._con: Optional[connection.Connection] = None
+ self._proxy: Optional[PoolConnectionProxy] = None
self._max_queries = max_queries
self._max_inactive_time = max_inactive_time
self._setup = setup
- self._inactive_callback = None
- self._in_use = None # type: asyncio.Future
- self._timeout = None
- self._generation = None
+ self._inactive_callback: Optional[Callable] = None
+ self._in_use: Optional[asyncio.Future] = None
+ self._timeout: Optional[float] = None
+ self._generation: Optional[int] = None
+
+ def is_connected(self) -> bool:
+ return self._con is not None and not self._con.is_closed()
+
+ def is_idle(self) -> bool:
+ return not self._in_use
- async def connect(self):
+ async def connect(self) -> None:
if self._con is not None:
raise exceptions.InternalClientError(
'PoolConnectionHolder.connect() called while another '
@@ -146,7 +174,7 @@ async def acquire(self) -> PoolConnectionProxy:
if self._setup is not None:
try:
await self._setup(proxy)
- except Exception as ex:
+ except (Exception, asyncio.CancelledError) as ex:
# If a user-defined `setup` function fails, we don't
# know if the connection is safe for re-use, hence
# we close it. A new connection will be created
@@ -164,7 +192,7 @@ async def acquire(self) -> PoolConnectionProxy:
return proxy
- async def release(self, timeout):
+ async def release(self, timeout: Optional[float]) -> None:
if self._in_use is None:
raise exceptions.InternalClientError(
'PoolConnectionHolder.release() called on '
@@ -197,14 +225,19 @@ async def release(self, timeout):
# If the connection is in cancellation state,
# wait for the cancellation
started = time.monotonic()
- await asyncio.wait_for(
+ await compat.wait_for(
self._con._protocol._wait_for_cancellation(),
- budget, loop=self._pool._loop)
+ budget)
if budget is not None:
budget -= time.monotonic() - started
- await self._con.reset(timeout=budget)
- except Exception as ex:
+ if self._pool._reset is not None:
+ async with compat.timeout(budget):
+ await self._con._reset()
+ await self._pool._reset(self._con)
+ else:
+ await self._con.reset(timeout=budget)
+ except (Exception, asyncio.CancelledError) as ex:
# If the `reset` call failed, terminate the connection.
# A new one will be created when `acquire` is called
# again.
@@ -222,25 +255,25 @@ async def release(self, timeout):
# Rearm the connection inactivity timer.
self._setup_inactive_callback()
- async def wait_until_released(self):
+ async def wait_until_released(self) -> None:
if self._in_use is None:
return
else:
await self._in_use
- async def close(self):
+ async def close(self) -> None:
if self._con is not None:
# Connection.close() will call _release_on_close() to
# finish holder cleanup.
await self._con.close()
- def terminate(self):
+ def terminate(self) -> None:
if self._con is not None:
# Connection.terminate() will call _release_on_close() to
# finish holder cleanup.
self._con.terminate()
- def _setup_inactive_callback(self):
+ def _setup_inactive_callback(self) -> None:
if self._inactive_callback is not None:
raise exceptions.InternalClientError(
'pool connection inactivity timer already exists')
@@ -249,12 +282,12 @@ def _setup_inactive_callback(self):
self._inactive_callback = self._pool._loop.call_later(
self._max_inactive_time, self._deactivate_inactive_connection)
- def _maybe_cancel_inactive_callback(self):
+ def _maybe_cancel_inactive_callback(self) -> None:
if self._inactive_callback is not None:
self._inactive_callback.cancel()
self._inactive_callback = None
- def _deactivate_inactive_connection(self):
+ def _deactivate_inactive_connection(self) -> None:
if self._in_use is not None:
raise exceptions.InternalClientError(
'attempting to deactivate an acquired connection')
@@ -268,12 +301,12 @@ def _deactivate_inactive_connection(self):
# so terminate() above will not call the below.
self._release_on_close()
- def _release_on_close(self):
+ def _release_on_close(self) -> None:
self._maybe_cancel_inactive_callback()
self._release()
self._con = None
- def _release(self):
+ def _release(self) -> None:
"""Release this connection holder."""
if self._in_use is None:
# The holder is not checked out.
@@ -304,21 +337,26 @@ class Pool:
Pools are created by calling :func:`~asyncpg.pool.create_pool`.
"""
- __slots__ = ('_queue', '_loop', '_minsize', '_maxsize',
- '_init', '_connect_args', '_connect_kwargs',
- '_working_addr', '_working_config', '_working_params',
- '_holders', '_initialized', '_initializing', '_closing',
- '_closed', '_connection_class', '_generation')
+ __slots__ = (
+ '_queue', '_loop', '_minsize', '_maxsize',
+ '_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
+ '_holders', '_initialized', '_initializing', '_closing',
+ '_closed', '_connection_class', '_record_class', '_generation',
+ '_setup', '_max_queries', '_max_inactive_connection_lifetime'
+ )
def __init__(self, *connect_args,
min_size,
max_size,
max_queries,
max_inactive_connection_lifetime,
- setup,
- init,
+ connect=None,
+ setup=None,
+ init=None,
+ reset=None,
loop,
connection_class,
+ record_class,
**connect_kwargs):
if len(connect_args) > 1:
@@ -356,40 +394,41 @@ def __init__(self, *connect_args,
'connection_class is expected to be a subclass of '
'asyncpg.Connection, got {!r}'.format(connection_class))
+ if not issubclass(record_class, protocol.Record):
+ raise TypeError(
+ 'record_class is expected to be a subclass of '
+ 'asyncpg.Record, got {!r}'.format(record_class))
+
self._minsize = min_size
self._maxsize = max_size
self._holders = []
self._initialized = False
self._initializing = False
- self._queue = asyncio.LifoQueue(maxsize=self._maxsize, loop=self._loop)
-
- self._working_addr = None
- self._working_config = None
- self._working_params = None
+ self._queue = None
self._connection_class = connection_class
+ self._record_class = record_class
self._closing = False
self._closed = False
self._generation = 0
- self._init = init
+
+ self._connect = connect if connect is not None else connection.connect
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs
- for _ in range(max_size):
- ch = PoolConnectionHolder(
- self,
- max_queries=max_queries,
- max_inactive_time=max_inactive_connection_lifetime,
- setup=setup)
+ self._setup = setup
+ self._init = init
+ self._reset = reset
- self._holders.append(ch)
- self._queue.put_nowait(ch)
+ self._max_queries = max_queries
+ self._max_inactive_connection_lifetime = \
+ max_inactive_connection_lifetime
async def _async__init__(self):
if self._initialized:
- return
+ return self
if self._initializing:
raise exceptions.InterfaceError(
'pool is being initialized in another task')
@@ -404,15 +443,25 @@ async def _async__init__(self):
self._initialized = True
async def _initialize(self):
+ self._queue = asyncio.LifoQueue(maxsize=self._maxsize)
+ for _ in range(self._maxsize):
+ ch = PoolConnectionHolder(
+ self,
+ max_queries=self._max_queries,
+ max_inactive_time=self._max_inactive_connection_lifetime,
+ setup=self._setup)
+
+ self._holders.append(ch)
+ self._queue.put_nowait(ch)
+
if self._minsize:
# Since we use a LIFO queue, the first items in the queue will be
# the last ones in `self._holders`. We want to pre-connect the
# first few connections in the queue, therefore we want to walk
# `self._holders` in reverse.
- # Connect the first connection holder in the queue so that it
- # can record `_working_addr` and `_working_opts`, which will
- # speed up successive connection attempts.
+ # Connect the first connection holder in the queue so that
+ # any connection issues are visible early.
first_ch = self._holders[-1] # type: PoolConnectionHolder
await first_ch.connect()
@@ -424,7 +473,42 @@ async def _initialize(self):
break
connect_tasks.append(ch.connect())
- await asyncio.gather(*connect_tasks, loop=self._loop)
+ await asyncio.gather(*connect_tasks)
+
+ def is_closing(self):
+ """Return ``True`` if the pool is closing or is closed.
+
+ .. versionadded:: 0.28.0
+ """
+ return self._closed or self._closing
+
+ def get_size(self):
+ """Return the current number of connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return sum(h.is_connected() for h in self._holders)
+
+ def get_min_size(self):
+ """Return the minimum number of connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return self._minsize
+
+ def get_max_size(self):
+ """Return the maximum allowed number of connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return self._maxsize
+
+ def get_idle_size(self):
+ """Return the current number of idle connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return sum(h.is_connected() and h.is_idle() for h in self._holders)
def set_connect_args(self, dsn=None, **connect_kwargs):
r"""Set the new connection arguments for this pool.
@@ -449,38 +533,32 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
self._connect_args = [dsn]
self._connect_kwargs = connect_kwargs
- self._working_addr = None
- self._working_config = None
- self._working_params = None
async def _get_new_connection(self):
- if self._working_addr is None:
- # First connection attempt on this pool.
- con = await connection.connect(
- *self._connect_args,
- loop=self._loop,
- connection_class=self._connection_class,
- **self._connect_kwargs)
-
- self._working_addr = con._addr
- self._working_config = con._config
- self._working_params = con._params
-
- else:
- # We've connected before and have a resolved address,
- # and parsed options and config.
- con = await connect_utils._connect_addr(
- loop=self._loop,
- addr=self._working_addr,
- timeout=self._working_params.connect_timeout,
- config=self._working_config,
- params=self._working_params,
- connection_class=self._connection_class)
+ con = await self._connect(
+ *self._connect_args,
+ loop=self._loop,
+ connection_class=self._connection_class,
+ record_class=self._record_class,
+ **self._connect_kwargs,
+ )
+ if not isinstance(con, self._connection_class):
+ good = self._connection_class
+ good_n = f'{good.__module__}.{good.__name__}'
+ bad = type(con)
+ if bad.__module__ == "builtins":
+ bad_n = bad.__name__
+ else:
+ bad_n = f'{bad.__module__}.{bad.__name__}'
+ raise exceptions.InterfaceError(
+ "expected pool connect callback to return an instance of "
+ f"'{good_n}', got " f"'{bad_n}'"
+ )
if self._init is not None:
try:
await self._init(con)
- except Exception as ex:
+ except (Exception, asyncio.CancelledError) as ex:
# If a user-defined `init` function fails, we don't
# know if the connection is safe for re-use, hence
# we close it. A new connection will be created
@@ -496,48 +574,72 @@ async def _get_new_connection(self):
return con
- async def execute(self, query: str, *args, timeout: float=None) -> str:
+ async def execute(
+ self,
+ query: str,
+ *args,
+ timeout: Optional[float]=None,
+ ) -> str:
"""Execute an SQL command (or commands).
Pool performs this operation using one of its connections. Other than
that, it behaves identically to
- :meth:`Connection.execute() <connection.Connection.execute>`.
+ :meth:`Connection.execute() <asyncpg.connection.Connection.execute>`.
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
return await con.execute(query, *args, timeout=timeout)
- async def executemany(self, command: str, args, *, timeout: float=None):
+ async def executemany(
+ self,
+ command: str,
+ args,
+ *,
+ timeout: Optional[float]=None,
+ ):
"""Execute an SQL *command* for each sequence of arguments in *args*.
Pool performs this operation using one of its connections. Other than
that, it behaves identically to
- :meth:`Connection.executemany() <connection.Connection.executemany>`.
+ :meth:`Connection.executemany()
+ <asyncpg.connection.Connection.executemany>`.
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
return await con.executemany(command, args, timeout=timeout)
- async def fetch(self, query, *args, timeout=None) -> list:
+ async def fetch(
+ self,
+ query,
+ *args,
+ timeout=None,
+ record_class=None
+ ) -> list:
"""Run a query and return the results as a list of :class:`Record`.
Pool performs this operation using one of its connections. Other than
that, it behaves identically to
- :meth:`Connection.fetch() <connection.Connection.fetch>`.
+ :meth:`Connection.fetch() <asyncpg.connection.Connection.fetch>`.
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
- return await con.fetch(query, *args, timeout=timeout)
+ return await con.fetch(
+ query,
+ *args,
+ timeout=timeout,
+ record_class=record_class
+ )
async def fetchval(self, query, *args, column=0, timeout=None):
"""Run a query and return a value in the first row.
Pool performs this operation using one of its connections. Other than
that, it behaves identically to
- :meth:`Connection.fetchval() <connection.Connection.fetchval>`.
+ :meth:`Connection.fetchval()
+ <asyncpg.connection.Connection.fetchval>`.
.. versionadded:: 0.10.0
"""
@@ -545,17 +647,207 @@ async def fetchval(self, query, *args, column=0, timeout=None):
return await con.fetchval(
query, *args, column=column, timeout=timeout)
- async def fetchrow(self, query, *args, timeout=None):
+ async def fetchrow(self, query, *args, timeout=None, record_class=None):
"""Run a query and return the first row.
Pool performs this operation using one of its connections. Other than
that, it behaves identically to
- :meth:`Connection.fetchrow() <connection.Connection.fetchrow>`.
+ :meth:`Connection.fetchrow() <asyncpg.connection.Connection.fetchrow>`.
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
- return await con.fetchrow(query, *args, timeout=timeout)
+ return await con.fetchrow(
+ query,
+ *args,
+ timeout=timeout,
+ record_class=record_class
+ )
+
+ async def fetchmany(self, query, args, *, timeout=None, record_class=None):
+ """Run a query for each sequence of arguments in *args*
+ and return the results as a list of :class:`Record`.
+
+ Pool performs this operation using one of its connections. Other than
+ that, it behaves identically to
+ :meth:`Connection.fetchmany()
+ <asyncpg.connection.Connection.fetchmany>`.
+
+ .. versionadded:: 0.30.0
+ """
+ async with self.acquire() as con:
+ return await con.fetchmany(
+ query, args, timeout=timeout, record_class=record_class
+ )
+
+ async def copy_from_table(
+ self,
+ table_name,
+ *,
+ output,
+ columns=None,
+ schema_name=None,
+ timeout=None,
+ format=None,
+ oids=None,
+ delimiter=None,
+ null=None,
+ header=None,
+ quote=None,
+ escape=None,
+ force_quote=None,
+ encoding=None
+ ):
+ """Copy table contents to a file or file-like object.
+
+ Pool performs this operation using one of its connections. Other than
+ that, it behaves identically to
+ :meth:`Connection.copy_from_table()
+ <asyncpg.connection.Connection.copy_from_table>`.
+
+ .. versionadded:: 0.24.0
+ """
+ async with self.acquire() as con:
+ return await con.copy_from_table(
+ table_name,
+ output=output,
+ columns=columns,
+ schema_name=schema_name,
+ timeout=timeout,
+ format=format,
+ oids=oids,
+ delimiter=delimiter,
+ null=null,
+ header=header,
+ quote=quote,
+ escape=escape,
+ force_quote=force_quote,
+ encoding=encoding
+ )
+
+ async def copy_from_query(
+ self,
+ query,
+ *args,
+ output,
+ timeout=None,
+ format=None,
+ oids=None,
+ delimiter=None,
+ null=None,
+ header=None,
+ quote=None,
+ escape=None,
+ force_quote=None,
+ encoding=None
+ ):
+ """Copy the results of a query to a file or file-like object.
+
+ Pool performs this operation using one of its connections. Other than
+ that, it behaves identically to
+ :meth:`Connection.copy_from_query()
+ <asyncpg.connection.Connection.copy_from_query>`.
+
+ .. versionadded:: 0.24.0
+ """
+ async with self.acquire() as con:
+ return await con.copy_from_query(
+ query,
+ *args,
+ output=output,
+ timeout=timeout,
+ format=format,
+ oids=oids,
+ delimiter=delimiter,
+ null=null,
+ header=header,
+ quote=quote,
+ escape=escape,
+ force_quote=force_quote,
+ encoding=encoding
+ )
+
+ async def copy_to_table(
+ self,
+ table_name,
+ *,
+ source,
+ columns=None,
+ schema_name=None,
+ timeout=None,
+ format=None,
+ oids=None,
+ freeze=None,
+ delimiter=None,
+ null=None,
+ header=None,
+ quote=None,
+ escape=None,
+ force_quote=None,
+ force_not_null=None,
+ force_null=None,
+ encoding=None,
+ where=None
+ ):
+ """Copy data to the specified table.
+
+ Pool performs this operation using one of its connections. Other than
+ that, it behaves identically to
+ :meth:`Connection.copy_to_table()
+ <asyncpg.connection.Connection.copy_to_table>`.
+
+ .. versionadded:: 0.24.0
+ """
+ async with self.acquire() as con:
+ return await con.copy_to_table(
+ table_name,
+ source=source,
+ columns=columns,
+ schema_name=schema_name,
+ timeout=timeout,
+ format=format,
+ oids=oids,
+ freeze=freeze,
+ delimiter=delimiter,
+ null=null,
+ header=header,
+ quote=quote,
+ escape=escape,
+ force_quote=force_quote,
+ force_not_null=force_not_null,
+ force_null=force_null,
+ encoding=encoding,
+ where=where
+ )
+
+ async def copy_records_to_table(
+ self,
+ table_name,
+ *,
+ records,
+ columns=None,
+ schema_name=None,
+ timeout=None,
+ where=None
+ ):
+ """Copy a list of records to the specified table using binary COPY.
+
+ Pool performs this operation using one of its connections. Other than
+ that, it behaves identically to
+ :meth:`Connection.copy_records_to_table()
+ <asyncpg.connection.Connection.copy_records_to_table>`.
+
+ .. versionadded:: 0.24.0
+ """
+ async with self.acquire() as con:
+ return await con.copy_records_to_table(
+ table_name,
+ records=records,
+ columns=columns,
+ schema_name=schema_name,
+ timeout=timeout,
+ where=where
+ )
def acquire(self, *, timeout=None):
"""Acquire a database connection from the pool.
@@ -587,7 +879,7 @@ async def _acquire_impl():
ch = await self._queue.get() # type: PoolConnectionHolder
try:
proxy = await ch.acquire() # type: PoolConnectionProxy
- except Exception:
+ except (Exception, asyncio.CancelledError):
self._queue.put_nowait(ch)
raise
else:
@@ -603,8 +895,8 @@ async def _acquire_impl():
if timeout is None:
return await _acquire_impl()
else:
- return await asyncio.wait_for(
- _acquire_impl(), timeout=timeout, loop=self._loop)
+ return await compat.wait_for(
+ _acquire_impl(), timeout=timeout)
async def release(self, connection, *, timeout=None):
"""Release a database connection back to the pool.
@@ -642,7 +934,7 @@ async def release(self, connection, *, timeout=None):
# Use asyncio.shield() to guarantee that task cancellation
# does not prevent the connection from being returned to the
# pool properly.
- return await asyncio.shield(ch.release(timeout), loop=self._loop)
+ return await asyncio.shield(ch.release(timeout))
async def close(self):
"""Attempt to gracefully close all connections in the pool.
@@ -673,13 +965,13 @@ async def close(self):
release_coros = [
ch.wait_until_released() for ch in self._holders]
- await asyncio.gather(*release_coros, loop=self._loop)
+ await asyncio.gather(*release_coros)
close_coros = [
ch.close() for ch in self._holders]
- await asyncio.gather(*close_coros, loop=self._loop)
+ await asyncio.gather(*close_coros)
- except Exception:
+ except (Exception, asyncio.CancelledError):
self.terminate()
raise
@@ -752,7 +1044,7 @@ class PoolAcquireContext:
__slots__ = ('timeout', 'connection', 'done', 'pool')
- def __init__(self, pool, timeout):
+ def __init__(self, pool: Pool, timeout: Optional[float]) -> None:
self.pool = pool
self.timeout = timeout
self.connection = None
@@ -764,7 +1056,12 @@ async def __aenter__(self):
self.connection = await self.pool._acquire(self.timeout)
return self.connection
- async def __aexit__(self, *exc):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]] = None,
+ exc_val: Optional[BaseException] = None,
+ exc_tb: Optional[TracebackType] = None,
+ ) -> None:
self.done = True
con = self.connection
self.connection = None
@@ -780,23 +1077,39 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
+ connect=None,
setup=None,
init=None,
+ reset=None,
loop=None,
connection_class=connection.Connection,
+ record_class=protocol.Record,
**connect_kwargs):
r"""Create a connection pool.
Can be used either with an ``async with`` block:
+ .. code-block:: python
+
+ async with asyncpg.create_pool(user='postgres',
+ command_timeout=60) as pool:
+ await pool.fetch('SELECT 1')
+
+ Or to perform multiple operations on a single connection:
+
.. code-block:: python
async with asyncpg.create_pool(user='postgres',
command_timeout=60) as pool:
async with pool.acquire() as con:
+ await con.execute('''
+ CREATE TABLE names (
+ id serial PRIMARY KEY,
+ name VARCHAR (255) NOT NULL)
+ ''')
await con.fetch('SELECT 1')
- Or directly with ``await``:
+ Or directly with ``await`` (not recommended):
.. code-block:: python
@@ -809,12 +1122,12 @@ def create_pool(dsn=None, *,
.. warning::
Prepared statements and cursors returned by
- :meth:`Connection.prepare() <connection.Connection.prepare>` and
- :meth:`Connection.cursor() <connection.Connection.cursor>` become
- invalid once the connection is released. Likewise, all notification
- and log listeners are removed, and ``asyncpg`` will issue a warning
- if there are any listener callbacks registered on a connection that
- is being released to the pool.
+ :meth:`Connection.prepare() <asyncpg.connection.Connection.prepare>`
+ and :meth:`Connection.cursor() <asyncpg.connection.Connection.cursor>`
+ become invalid once the connection is released. Likewise, all
+ notification and log listeners are removed, and ``asyncpg`` will
+ issue a warning if there are any listener callbacks registered on a
+ connection that is being released to the pool.
:param str dsn:
Connection arguments specified using as a single string in
@@ -829,6 +1142,11 @@ def create_pool(dsn=None, *,
The class to use for connections. Must be a subclass of
:class:`~asyncpg.connection.Connection`.
+ :param type record_class:
+ If specified, the class to use for records returned by queries on
+ the connections in this pool. Must be a subclass of
+ :class:`~asyncpg.Record`.
+
:param int min_size:
Number of connections the pool will be initialized with.
@@ -843,9 +1161,16 @@ def create_pool(dsn=None, *,
Number of seconds after which inactive connections in the
pool will be closed. Pass ``0`` to disable this mechanism.
+ :param coroutine connect:
+ A coroutine that is called instead of
+ :func:`~asyncpg.connection.connect` whenever the pool needs to make a
+ new connection. Must return an instance of type specified by
+ *connection_class* or :class:`~asyncpg.connection.Connection` if
+ *connection_class* was not specified.
+
:param coroutine setup:
A coroutine to prepare a connection right before it is returned
- from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
+ from :meth:`Pool.acquire()`. An example use
case would be to automatically set up notifications listeners for
all connections of a pool.
@@ -857,6 +1182,25 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.
+ :param coroutine reset:
+ A coroutine to reset a connection before it is returned to the pool by
+ :meth:`Pool.release()`. The function is supposed
+ to reset any changes made to the database session so that the next
+ acquirer gets the connection in a well-defined state.
+
+ The default implementation calls :meth:`Connection.reset() <\
+ asyncpg.connection.Connection.reset>`, which runs the following::
+
+ SELECT pg_advisory_unlock_all();
+ CLOSE ALL;
+ UNLISTEN *;
+ RESET ALL;
+
+ The exact reset query is determined by detected server capabilities,
+ and a custom *reset* implementation can obtain the default query
+ by calling :meth:`Connection.get_reset_query() <\
+ asyncpg.connection.Connection.get_reset_query>`.
+
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
@@ -875,15 +1219,30 @@ def create_pool(dsn=None, *,
.. versionchanged:: 0.13.0
An :exc:`~asyncpg.exceptions.InterfaceWarning` will be produced
if there are any active listeners (added via
- :meth:`Connection.add_listener() <connection.Connection.add_listener>`
+ :meth:`Connection.add_listener()
+ <asyncpg.connection.Connection.add_listener>`
or :meth:`Connection.add_log_listener()
- <connection.Connection.add_log_listener>`) present on the connection
- at the moment of its release to the pool.
+ <asyncpg.connection.Connection.add_log_listener>`) present on the
+ connection at the moment of its release to the pool.
+
+ .. versionchanged:: 0.22.0
+ Added the *record_class* parameter.
+
+ .. versionchanged:: 0.30.0
+ Added the *connect* and *reset* parameters.
"""
return Pool(
dsn,
connection_class=connection_class,
- min_size=min_size, max_size=max_size,
- max_queries=max_queries, loop=loop, setup=setup, init=init,
+ record_class=record_class,
+ min_size=min_size,
+ max_size=max_size,
+ max_queries=max_queries,
+ loop=loop,
+ connect=connect,
+ setup=setup,
+ init=init,
+ reset=reset,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
- **connect_kwargs)
+ **connect_kwargs,
+ )
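Taken together, the pool changes above replace the cached ``_working_addr`` reconnect fast path with a user-overridable ``connect`` callback, add a ``reset`` hook that runs in :meth:`Pool.release()` instead of the default :meth:`Connection.reset()`, thread ``record_class`` through to every query method, and expose pool sizing via ``get_size()``/``get_idle_size()``. A minimal usage sketch, assuming a reachable PostgreSQL server; the DSN and application name are placeholders:

.. code-block:: python

    import asyncio
    import asyncpg

    async def connect_with_app_name(*args, **kwargs):
        # Invoked by the pool instead of asyncpg.connect(); it must return
        # an instance of the pool's connection_class.
        kwargs.setdefault("server_settings", {"application_name": "my-service"})
        return await asyncpg.connect(*args, **kwargs)

    async def reset_connection(con):
        # Invoked by Pool.release() in place of Connection.reset(); here we
        # simply run the default reset query explicitly.
        await con.execute(con.get_reset_query())

    async def main():
        pool = await asyncpg.create_pool(
            "postgresql://localhost/mydb",   # placeholder DSN
            min_size=2,
            max_size=10,
            connect=connect_with_app_name,
            reset=reset_connection,
        )
        print(pool.get_size(), pool.get_idle_size())
        async with pool.acquire() as con:
            await con.fetchval("SELECT 1")
        await pool.close()

    asyncio.run(main())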
diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py
index 09a0a2ec..0c2d335e 100644
--- a/asyncpg/prepared_stmt.py
+++ b/asyncpg/prepared_stmt.py
@@ -6,6 +6,7 @@
import json
+import typing
from . import connresource
from . import cursor
@@ -24,6 +25,14 @@ def __init__(self, connection, query, state):
state.attach()
self._last_status = None
+ @connresource.guarded
+ def get_name(self) -> str:
+ """Return the name of this prepared statement.
+
+ .. versionadded:: 0.25.0
+ """
+ return self._state.name
+
@connresource.guarded
def get_query(self) -> str:
"""Return the text of the query for this prepared statement.
@@ -103,9 +112,15 @@ def cursor(self, *args, prefetch=None,
:return: A :class:`~cursor.CursorFactory` object.
"""
- return cursor.CursorFactory(self._connection, self._query,
- self._state, args, prefetch,
- timeout)
+ return cursor.CursorFactory(
+ self._connection,
+ self._query,
+ self._state,
+ args,
+ prefetch,
+ timeout,
+ self._state.record_class,
+ )
@connresource.guarded
async def explain(self, *args, analyze=False):
@@ -133,8 +148,8 @@ async def explain(self, *args, analyze=False):
# will discard any output that a SELECT would return, other
# side effects of the statement will happen as usual. If you
# wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE,
- # CREATE TABLE AS, or EXECUTE statement without letting the
- # command affect your data, use this approach:
+ # MERGE, CREATE TABLE AS, or EXECUTE statement without letting
+ # the command affect your data, use this approach:
# BEGIN;
# EXPLAIN ANALYZE ...;
# ROLLBACK;
@@ -196,11 +211,50 @@ async def fetchrow(self, *args, timeout=None):
return None
return data[0]
- async def __bind_execute(self, args, limit, timeout):
+ @connresource.guarded
+ async def fetchmany(self, args, *, timeout=None):
+ """Execute the statement and return a list of :class:`Record` objects.
+
+ :param args: Query arguments.
+ :param float timeout: Optional timeout value in seconds.
+
+ :return: A list of :class:`Record` instances.
+
+ .. versionadded:: 0.30.0
+ """
+ return await self.__do_execute(
+ lambda protocol: protocol.bind_execute_many(
+ self._state,
+ args,
+ portal_name='',
+ timeout=timeout,
+ return_rows=True,
+ )
+ )
+
+ @connresource.guarded
+ async def executemany(self, args, *, timeout: typing.Optional[float]=None):
+ """Execute the statement for each sequence of arguments in *args*.
+
+ :param args: An iterable containing sequences of arguments.
+ :param float timeout: Optional timeout value in seconds.
+ :return None: This method discards the results of the operations.
+
+ .. versionadded:: 0.22.0
+ """
+ return await self.__do_execute(
+ lambda protocol: protocol.bind_execute_many(
+ self._state,
+ args,
+ portal_name='',
+ timeout=timeout,
+ return_rows=False,
+ ))
+
+ async def __do_execute(self, executor):
protocol = self._connection._protocol
try:
- data, status, _ = await protocol.bind_execute(
- self._state, args, '', limit, True, timeout)
+ return await executor(protocol)
except exceptions.OutdatedSchemaCacheError:
await self._connection.reload_schema_state()
# We can not find all manually created prepared statements, so just
@@ -209,6 +263,11 @@ async def __bind_execute(self, args, limit, timeout):
# invalidate themselves (unfortunately, clearing caches again).
self._state.mark_closed()
raise
+
+ async def __bind_execute(self, args, limit, timeout):
+ data, status, _ = await self.__do_execute(
+ lambda protocol: protocol.bind_execute(
+ self._state, args, '', limit, True, timeout))
self._last_status = status
return data
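``PreparedStatement`` gains two batching methods that share the new ``__do_execute`` error-handling path: ``executemany()`` (0.22), which runs the statement once per argument sequence and discards the results, and ``fetchmany()`` (0.30), which does the same but returns the produced rows as :class:`Record` instances. A hedged usage sketch; the temp table and values are illustrative and ``con`` is assumed to be an open connection:

.. code-block:: python

    async def demo(con):
        await con.execute(
            "CREATE TEMP TABLE tally (name text PRIMARY KEY, hits int)"
        )
        stmt = await con.prepare(
            "INSERT INTO tally (name, hits) VALUES ($1, $2) RETURNING name, hits"
        )
        # executemany(): one execution per argument tuple, results discarded.
        await stmt.executemany([("a", 1), ("b", 2)])
        # fetchmany(): same batching, but the rows produced by the executions
        # are returned as Record instances.
        rows = await stmt.fetchmany([("c", 3), ("d", 4)])
        print([tuple(r) for r in rows])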
diff --git a/asyncpg/protocol/__init__.py b/asyncpg/protocol/__init__.py
index e872e2fa..043454db 100644
--- a/asyncpg/protocol/__init__.py
+++ b/asyncpg/protocol/__init__.py
@@ -4,5 +4,9 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+# flake8: NOQA
-from .protocol import Protocol, Record, NO_TIMEOUT # NOQA
+from __future__ import annotations
+
+from .protocol import Protocol, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
+from .record import Record
diff --git a/asyncpg/protocol/codecs/array.pyx b/asyncpg/protocol/codecs/array.pyx
index 46bdb736..f8f9b8dd 100644
--- a/asyncpg/protocol/codecs/array.pyx
+++ b/asyncpg/protocol/codecs/array.pyx
@@ -209,7 +209,9 @@ cdef _write_textarray_data(ConnectionSettings settings, object obj,
try:
if not apg_strcasecmp_char(elem_str, b'NULL'):
- array_data.write_bytes(b'"NULL"')
+ array_data.write_byte(b'"')
+ array_data.write_cstr(elem_str, 4)
+ array_data.write_byte(b'"')
else:
quoted_elem_len = elem_len
need_quoting = False
@@ -286,16 +288,21 @@ cdef inline array_decode(ConnectionSettings settings, FRBuffer *buf,
Codec elem_codec
if ndims == 0:
- result = cpython.PyList_New(0)
- return result
+ return []
if ndims > ARRAY_MAXDIM:
raise exceptions.ProtocolError(
'number of array dimensions ({}) exceeds the maximum expected ({})'.
format(ndims, ARRAY_MAXDIM))
+ elif ndims < 0:
+ raise exceptions.ProtocolError(
+ 'unexpected array dimensions value: {}'.format(ndims))
for i in range(ndims):
dims[i] = hton.unpack_int32(frb_read(buf, 4))
+ if dims[i] < 0:
+ raise exceptions.ProtocolError(
+ 'unexpected array dimension size: {}'.format(dims[i]))
# Ignore the lower bound information
frb_read(buf, 4)
@@ -340,14 +347,18 @@ cdef _nested_array_decode(ConnectionSettings settings,
# An array of current positions at each array level.
int32_t indexes[ARRAY_MAXDIM]
- if PG_DEBUG:
- if ndims <= 0:
- raise exceptions.ProtocolError(
- 'unexpected ndims value: {}'.format(ndims))
-
for i in range(ndims):
array_len *= dims[i]
indexes[i] = 0
+ strides[i] = NULL
+
+ if array_len == 0:
+ # A multidimensional array with a zero-sized dimension?
+ return []
+
+ elif array_len < 0:
+ # Array length overflow
+ raise exceptions.ProtocolError('array length overflow')
for i in range(array_len):
# Decode the element.
@@ -847,19 +858,7 @@ cdef arraytext_decode(ConnectionSettings settings, FRBuffer *buf):
return array_decode(settings, buf, &text_decode_ex, NULL)
-cdef anyarray_decode(ConnectionSettings settings, FRBuffer *buf):
- # Instances of anyarray (or any other polymorphic pseudotype) are
- # never supposed to be returned from actual queries.
- raise exceptions.ProtocolError(
- 'unexpected instance of \'anyarray\' type')
-
-
cdef init_array_codecs():
- register_core_codec(ANYARRAYOID,
- NULL,
- &anyarray_decode,
- PG_FORMAT_BINARY)
-
# oid[] and text[] are registered as core codecs
# to make type introspection query work
#
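The array-codec changes harden binary array decoding against malformed headers: a negative dimension count or dimension size now raises ``ProtocolError`` instead of relying on a ``PG_DEBUG``-only assertion, a zero-sized dimension decodes to ``[]``, and the total element count is guarded against overflow. The following stand-alone Python model (not asyncpg code; ``ARRAY_MAXDIM`` mirrors PostgreSQL's limit of 6) illustrates the checks:

.. code-block:: python

    ARRAY_MAXDIM = 6

    def validate_array_header(ndims, dims):
        """Return the total element count, or raise on a malformed header."""
        if ndims == 0:
            return 0  # an empty array decodes to []
        if ndims < 0 or ndims > ARRAY_MAXDIM:
            raise ValueError(f"unexpected array dimensions value: {ndims}")
        total = 1
        for d in dims[:ndims]:
            if d < 0:
                raise ValueError(f"unexpected array dimension size: {d}")
            total *= d
            if total > 2**31 - 1:  # models the int32 overflow guard
                raise ValueError("array length overflow")
        return total  # 0 if any dimension is zero-sized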
diff --git a/asyncpg/protocol/codecs/base.pxd b/asyncpg/protocol/codecs/base.pxd
index be1f0a3f..f5492590 100644
--- a/asyncpg/protocol/codecs/base.pxd
+++ b/asyncpg/protocol/codecs/base.pxd
@@ -22,13 +22,26 @@ ctypedef object (*codec_decode_func)(Codec codec,
FRBuffer *buf)
+cdef class CodecMap:
+ cdef:
+ void** binary_codec_map
+ void** text_codec_map
+ dict extra_codecs
+
+ cdef inline void *get_binary_codec_ptr(self, uint32_t idx)
+ cdef inline void set_binary_codec_ptr(self, uint32_t idx, void *ptr)
+ cdef inline void *get_text_codec_ptr(self, uint32_t idx)
+ cdef inline void set_text_codec_ptr(self, uint32_t idx, void *ptr)
+
+
cdef enum CodecType:
- CODEC_UNDEFINED = 0
- CODEC_C = 1
- CODEC_PY = 2
- CODEC_ARRAY = 3
- CODEC_COMPOSITE = 4
- CODEC_RANGE = 5
+ CODEC_UNDEFINED = 0
+ CODEC_C = 1
+ CODEC_PY = 2
+ CODEC_ARRAY = 3
+ CODEC_COMPOSITE = 4
+ CODEC_RANGE = 5
+ CODEC_MULTIRANGE = 6
cdef enum ServerDataFormat:
@@ -56,6 +69,7 @@ cdef class Codec:
encode_func c_encoder
decode_func c_decoder
+ Codec base_codec
object py_encoder
object py_decoder
@@ -78,6 +92,7 @@ cdef class Codec:
CodecType type, ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder, decode_func c_decoder,
+ Codec base_codec,
object py_encoder, object py_decoder,
Codec element_codec, tuple element_type_oids,
object element_names, list element_codecs,
@@ -95,6 +110,9 @@ cdef class Codec:
cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
+ cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
+ object obj)
+
cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
@@ -109,6 +127,8 @@ cdef class Codec:
cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf)
+ cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf)
+
cdef decode_composite(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_in_python(self, ConnectionSettings settings, FRBuffer *buf)
@@ -139,6 +159,12 @@ cdef class Codec:
str schema,
Codec element_codec)
+ @staticmethod
+ cdef Codec new_multirange_codec(uint32_t oid,
+ str name,
+ str schema,
+ Codec element_codec)
+
@staticmethod
cdef Codec new_composite_codec(uint32_t oid,
str name,
@@ -157,6 +183,7 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
+ Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat)
@@ -166,5 +193,7 @@ cdef class DataCodecConfig:
dict _derived_type_codecs
dict _custom_type_codecs
- cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
- cdef inline Codec get_any_local_codec(self, uint32_t oid)
+ cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
+ bint ignore_custom_codec=*)
+ cdef inline Codec get_custom_codec(self, uint32_t oid,
+ ServerDataFormat format)
diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx
index 5d3ccc4b..009598a8 100644
--- a/asyncpg/protocol/codecs/base.pyx
+++ b/asyncpg/protocol/codecs/base.pyx
@@ -7,12 +7,37 @@
from collections.abc import Mapping as MappingABC
+import asyncpg
from asyncpg import exceptions
-cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2]
-cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2]
-cdef dict EXTRA_CODECS = {}
+# The class indirection is needed because Cython
+# does not (as of 3.1.0) store global cdef variables
+# in module state.
+@cython.final
+cdef class CodecMap:
+
+ def __cinit__(self):
+ self.extra_codecs = {}
+ self.binary_codec_map = <void **>cpython.PyMem_Calloc(
+ (MAXSUPPORTEDOID + 1) * 2, sizeof(void *))
+ self.text_codec_map = <void **>cpython.PyMem_Calloc(
+ (MAXSUPPORTEDOID + 1) * 2, sizeof(void *))
+
+ cdef inline void *get_binary_codec_ptr(self, uint32_t idx):
+ return self.binary_codec_map[idx]
+
+ cdef inline void set_binary_codec_ptr(self, uint32_t idx, void *ptr):
+ self.binary_codec_map[idx] = ptr
+
+ cdef inline void *get_text_codec_ptr(self, uint32_t idx):
+ return self.text_codec_map[idx]
+
+ cdef inline void set_text_codec_ptr(self, uint32_t idx, void *ptr):
+ self.text_codec_map[idx] = ptr
+
+
+codec_map = CodecMap()
@cython.final
@@ -22,14 +47,25 @@ cdef class Codec:
self.oid = oid
self.type = CODEC_UNDEFINED
- cdef init(self, str name, str schema, str kind,
- CodecType type, ServerDataFormat format,
- ClientExchangeFormat xformat,
- encode_func c_encoder, decode_func c_decoder,
- object py_encoder, object py_decoder,
- Codec element_codec, tuple element_type_oids,
- object element_names, list element_codecs,
- Py_UCS4 element_delimiter):
+ cdef init(
+ self,
+ str name,
+ str schema,
+ str kind,
+ CodecType type,
+ ServerDataFormat format,
+ ClientExchangeFormat xformat,
+ encode_func c_encoder,
+ decode_func c_decoder,
+ Codec base_codec,
+ object py_encoder,
+ object py_decoder,
+ Codec element_codec,
+ tuple element_type_oids,
+ object element_names,
+ list element_codecs,
+ Py_UCS4 element_delimiter,
+ ):
self.name = name
self.schema = schema
@@ -39,6 +75,7 @@ cdef class Codec:
self.xformat = xformat
self.c_encoder = c_encoder
self.c_decoder = c_decoder
+ self.base_codec = base_codec
self.py_encoder = py_encoder
self.py_decoder = py_decoder
self.element_codec = element_codec
@@ -47,8 +84,14 @@ cdef class Codec:
self.element_delimiter = element_delimiter
self.element_names = element_names
+ if base_codec is not None:
+ if c_encoder != NULL or c_decoder != NULL:
+ raise exceptions.InternalClientError(
+ 'base_codec is mutually exclusive with c_encoder/c_decoder'
+ )
+
if element_names is not None:
- self.record_desc = record.ApgRecordDesc_New(
+ self.record_desc = RecordDescriptor(
element_names, tuple(element_names))
else:
self.record_desc = None
@@ -65,14 +108,21 @@ cdef class Codec:
self.decoder = &self.decode_array_text
elif type == CODEC_RANGE:
if format != PG_FORMAT_BINARY:
- raise NotImplementedError(
+ raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'range types is not supported'.format(schema, name))
self.encoder = &self.encode_range
self.decoder = &self.decode_range
+ elif type == CODEC_MULTIRANGE:
+ if format != PG_FORMAT_BINARY:
+ raise exceptions.UnsupportedClientFeatureError(
+ 'cannot decode type "{}"."{}": text encoding of '
+ 'range types is not supported'.format(schema, name))
+ self.encoder = &self.encode_multirange
+ self.decoder = &self.decode_multirange
elif type == CODEC_COMPOSITE:
if format != PG_FORMAT_BINARY:
- raise NotImplementedError(
+ raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'composite types is not supported'.format(schema, name))
self.encoder = &self.encode_composite
@@ -90,7 +140,7 @@ cdef class Codec:
codec = Codec(self.oid)
codec.init(self.name, self.schema, self.kind,
self.type, self.format, self.xformat,
- self.c_encoder, self.c_decoder,
+ self.c_encoder, self.c_decoder, self.base_codec,
self.py_encoder, self.py_decoder,
self.element_codec,
self.element_type_oids, self.element_names,
@@ -121,6 +171,12 @@ cdef class Codec:
codec_encode_func_ex,
<void*>(<cpython.PyObject*>self.element_codec))
+ cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
+ object obj):
+ multirange_encode(settings, buf, obj, self.element_codec.oid,
+ codec_encode_func_ex,
+ <void*>(<cpython.PyObject*>self.element_codec))
+
cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
cdef:
@@ -182,7 +238,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
- self.c_encoder(settings, buf, data)
+ if self.base_codec is not None:
+ self.base_codec.encode(settings, buf, data)
+ else:
+ self.c_encoder(settings, buf, data)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
@@ -208,6 +267,10 @@ cdef class Codec:
return range_decode(settings, buf, codec_decode_func_ex,
<void*>(<cpython.PyObject*>self.element_codec))
+ cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf):
+ return multirange_decode(settings, buf, codec_decode_func_ex,
+ <void*>(<cpython.PyObject*>self.element_codec))
+
cdef decode_composite(self, ConnectionSettings settings,
FRBuffer *buf):
cdef:
@@ -232,7 +295,7 @@ cdef class Codec:
schema=self.schema,
data_type=self.name,
)
- result = record.ApgRecord_New(self.record_desc, elem_count)
+ result = self.record_desc.make_record(asyncpg.Record, elem_count)
for i in range(elem_count):
elem_typ = self.element_type_oids[i]
received_elem_typ = hton.unpack_int32(frb_read(buf, 4))
@@ -262,7 +325,7 @@ cdef class Codec:
settings, frb_slice_from(&elem_buf, buf, elem_len))
cpython.Py_INCREF(elem)
- record.ApgRecord_SET_ITEM(result, i, elem)
+ recordcapi.ApgRecord_SET_ITEM(result, i, elem)
return result
@@ -277,7 +340,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
- data = self.c_decoder(settings, buf)
+ if self.base_codec is not None:
+ data = self.base_codec.decode(settings, buf)
+ else:
+ data = self.c_decoder(settings, buf)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
@@ -293,7 +359,11 @@ cdef class Codec:
if self.c_encoder is not NULL or self.py_encoder is not None:
return True
- elif self.type == CODEC_ARRAY or self.type == CODEC_RANGE:
+ elif (
+ self.type == CODEC_ARRAY
+ or self.type == CODEC_RANGE
+ or self.type == CODEC_MULTIRANGE
+ ):
return self.element_codec.has_encoder()
elif self.type == CODEC_COMPOSITE:
@@ -311,7 +381,11 @@ cdef class Codec:
if self.c_decoder is not NULL or self.py_decoder is not None:
return True
- elif self.type == CODEC_ARRAY or self.type == CODEC_RANGE:
+ elif (
+ self.type == CODEC_ARRAY
+ or self.type == CODEC_RANGE
+ or self.type == CODEC_MULTIRANGE
+ ):
return self.element_codec.has_decoder()
elif self.type == CODEC_COMPOSITE:
@@ -341,8 +415,8 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format,
- PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
- None, None, None, element_delimiter)
+ PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
+ element_codec, None, None, None, element_delimiter)
return codec
@staticmethod
@@ -353,8 +427,20 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format,
- PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
- None, None, None, 0)
+ PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
+ element_codec, None, None, None, 0)
+ return codec
+
+ @staticmethod
+ cdef Codec new_multirange_codec(uint32_t oid,
+ str name,
+ str schema,
+ Codec element_codec):
+ cdef Codec codec
+ codec = Codec(oid)
+ codec.init(name, schema, 'multirange', CODEC_MULTIRANGE,
+ element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None,
+ None, None, element_codec, None, None, None, 0)
return codec
@staticmethod
@@ -369,7 +455,7 @@ cdef class Codec:
codec = Codec(oid)
codec.init(name, schema, 'composite', CODEC_COMPOSITE,
format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
- element_type_oids, element_names, element_codecs, 0)
+ None, element_type_oids, element_names, element_codecs, 0)
return codec
@staticmethod
@@ -381,12 +467,13 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
+ Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, kind, CODEC_PY, format, xformat,
- c_encoder, c_decoder, encoder, decoder,
+ c_encoder, c_decoder, base_codec, encoder, decoder,
None, None, None, None, 0)
return codec
@@ -420,7 +507,7 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
cdef class DataCodecConfig:
- def __init__(self, cache_key):
+ def __init__(self):
# Codec instance cache for derived types:
# composites, arrays, ranges, domains and their combinations.
self._derived_type_codecs = {}
@@ -439,14 +526,7 @@ cdef class DataCodecConfig:
for ti in types:
oid = ti['oid']
- if not ti['has_bin_io']:
- format = PG_FORMAT_TEXT
- else:
- format = PG_FORMAT_BINARY
-
- has_text_elements = False
-
- if self.get_codec(oid, format) is not None:
+ if self.get_codec(oid, PG_FORMAT_ANY) is not None:
continue
name = ti['name']
@@ -467,54 +547,50 @@ cdef class DataCodecConfig:
name = name[1:]
name = '{}[]'.format(name)
- if ti['elem_has_bin_io']:
- elem_format = PG_FORMAT_BINARY
- else:
- elem_format = PG_FORMAT_TEXT
-
- elem_codec = self.get_codec(array_element_oid, elem_format)
+ elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
if elem_codec is None:
- elem_format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
- array_element_oid, name, schema)
+ array_element_oid, ti['elemtype_name'], schema)
elem_delim = ti['elemdelim'][0]
- self._derived_type_codecs[oid, elem_format] = \
+ self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_array_codec(
oid, name, schema, elem_codec, elem_delim)
elif ti['kind'] == b'c':
+ # Composite type
+
if not comp_type_attrs:
raise exceptions.InternalClientError(
- 'type record missing field types for '
- 'composite {}'.format(oid))
-
- # Composite type
+ f'type record missing field types for composite {oid}')
comp_elem_codecs = []
+ has_text_elements = False
for typoid in comp_type_attrs:
- elem_codec = self.get_codec(typoid, PG_FORMAT_BINARY)
- if elem_codec is None:
- elem_codec = self.get_codec(typoid, PG_FORMAT_TEXT)
- has_text_elements = True
+ elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
if elem_codec is None:
raise exceptions.InternalClientError(
- 'no codec for composite attribute type {}'.format(
- typoid))
+ f'no codec for composite attribute type {typoid}')
+ if elem_codec.format is PG_FORMAT_TEXT:
+ has_text_elements = True
comp_elem_codecs.append(elem_codec)
element_names = collections.OrderedDict()
for i, attrname in enumerate(ti['attrnames']):
element_names[attrname] = i
+ # If at least one element is text-encoded, we must
+ # encode the whole composite as text.
if has_text_elements:
- format = PG_FORMAT_TEXT
+ elem_format = PG_FORMAT_TEXT
+ else:
+ elem_format = PG_FORMAT_BINARY
- self._derived_type_codecs[oid, format] = \
+ self._derived_type_codecs[oid, elem_format] = \
Codec.new_composite_codec(
- oid, name, schema, format, comp_elem_codecs,
+ oid, name, schema, elem_format, comp_elem_codecs,
comp_type_attrs, element_names)
elif ti['kind'] == b'd':
@@ -522,39 +598,45 @@ cdef class DataCodecConfig:
if not base_type:
raise exceptions.InternalClientError(
- 'type record missing base type for domain {}'.format(
- oid))
+ f'type record missing base type for domain {oid}')
- elem_codec = self.get_codec(base_type, format)
+ elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
if elem_codec is None:
- format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
- base_type, name, schema)
+ base_type, ti['basetype_name'], schema)
- self._derived_type_codecs[oid, format] = elem_codec
+ self._derived_type_codecs[oid, elem_codec.format] = elem_codec
elif ti['kind'] == b'r':
# Range type
if not range_subtype_oid:
raise exceptions.InternalClientError(
- 'type record missing base type for range {}'.format(
- oid))
+ f'type record missing base type for range {oid}')
- if ti['elem_has_bin_io']:
- elem_format = PG_FORMAT_BINARY
- else:
- elem_format = PG_FORMAT_TEXT
-
- elem_codec = self.get_codec(range_subtype_oid, elem_format)
+ elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
if elem_codec is None:
- elem_format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
- range_subtype_oid, name, schema)
+ range_subtype_oid, ti['range_subtype_name'], schema)
- self._derived_type_codecs[oid, elem_format] = \
+ self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_range_codec(oid, name, schema, elem_codec)
+ elif ti['kind'] == b'm':
+ # Multirange type
+
+ if not range_subtype_oid:
+ raise exceptions.InternalClientError(
+ f'type record missing base type for multirange {oid}')
+
+ elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
+ if elem_codec is None:
+ elem_codec = self.declare_fallback_codec(
+ range_subtype_oid, ti['range_subtype_name'], schema)
+
+ self._derived_type_codecs[oid, elem_codec.format] = \
+ Codec.new_multirange_codec(oid, name, schema, elem_codec)
+
elif ti['kind'] == b'e':
# Enum types are essentially text
self._set_builtin_type_codec(oid, name, schema, 'scalar',
@@ -563,17 +645,21 @@ cdef class DataCodecConfig:
self.declare_fallback_codec(oid, name, schema)
def add_python_codec(self, typeoid, typename, typeschema, typekind,
- encoder, decoder, format, xformat):
+ typeinfos, encoder, decoder, format, xformat):
cdef:
- Codec core_codec
+ Codec core_codec = None
encode_func c_encoder = NULL
decode_func c_decoder = NULL
+ Codec base_codec = None
uint32_t oid = pylong_as_oid(typeoid)
bint codec_set = False
# Clear all previous overrides (this also clears type cache).
self.remove_python_codec(typeoid, typename, typeschema)
+ if typeinfos:
+ self.add_types(typeinfos)
+
if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
else:
@@ -581,16 +667,21 @@ cdef class DataCodecConfig:
for fmt in formats:
if xformat == PG_XFORMAT_TUPLE:
- core_codec = get_core_codec(oid, fmt, xformat)
- if core_codec is None:
- continue
- c_encoder = core_codec.c_encoder
- c_decoder = core_codec.c_decoder
+ if typekind == "scalar":
+ core_codec = get_core_codec(oid, fmt, xformat)
+ if core_codec is None:
+ continue
+ c_encoder = core_codec.c_encoder
+ c_decoder = core_codec.c_decoder
+ elif typekind == "composite":
+ base_codec = self.get_codec(oid, fmt)
+ if base_codec is None:
+ continue
self._custom_type_codecs[typeoid, fmt] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
- fmt, xformat)
+ base_codec, fmt, xformat)
codec_set = True
if not codec_set:
@@ -664,19 +755,14 @@ cdef class DataCodecConfig:
def declare_fallback_codec(self, uint32_t oid, str name, str schema):
cdef Codec codec
- codec = self.get_codec(oid, PG_FORMAT_TEXT)
- if codec is not None:
- return codec
-
if oid <= MAXBUILTINOID:
# This is a BKI type, for which asyncpg has no
# defined codec. This should only happen for newly
# added builtin types, for which this version of
# asyncpg is lacking support.
#
- raise NotImplementedError(
- 'unhandled standard data type {!r} (OID {})'.format(
- name, oid))
+ raise exceptions.UnsupportedClientFeatureError(
+ f'unhandled standard data type {name!r} (OID {oid})')
else:
# This is a non-BKI type, and as such, has no
# stable OID, so no possibility of a builtin codec.
@@ -691,36 +777,53 @@ cdef class DataCodecConfig:
return codec
- cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
+ cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
+ bint ignore_custom_codec=False):
cdef Codec codec
- codec = self.get_any_local_codec(oid)
- if codec is not None:
- if codec.format != format:
- # The codec for this OID has been overridden by
- # set_{builtin}_type_codec with a different format.
- # We must respect that and not return a core codec.
- return None
- else:
- return codec
-
- codec = get_core_codec(oid, format)
- if codec is not None:
+ if format == PG_FORMAT_ANY:
+ codec = self.get_codec(
+ oid, PG_FORMAT_BINARY, ignore_custom_codec)
+ if codec is None:
+ codec = self.get_codec(
+ oid, PG_FORMAT_TEXT, ignore_custom_codec)
return codec
else:
- try:
- return self._derived_type_codecs[oid, format]
- except KeyError:
- return None
+ if not ignore_custom_codec:
+ codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
+ if codec is not None:
+ if codec.format != format:
+ # The codec for this OID has been overridden by
+ # set_{builtin}_type_codec with a different format.
+ # We must respect that and not return a core codec.
+ return None
+ else:
+ return codec
+
+ codec = get_core_codec(oid, format)
+ if codec is not None:
+ return codec
+ else:
+ try:
+ return self._derived_type_codecs[oid, format]
+ except KeyError:
+ return None
- cdef inline Codec get_any_local_codec(self, uint32_t oid):
+ cdef inline Codec get_custom_codec(
+ self,
+ uint32_t oid,
+ ServerDataFormat format
+ ):
cdef Codec codec
- codec = self._custom_type_codecs.get((oid, PG_FORMAT_BINARY))
- if codec is None:
- return self._custom_type_codecs.get((oid, PG_FORMAT_TEXT))
+ if format == PG_FORMAT_ANY:
+ codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
+ if codec is None:
+ codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
else:
- return codec
+ codec = self._custom_type_codecs.get((oid, format))
+
+ return codec
cdef inline Codec get_core_codec(
@@ -732,9 +835,9 @@ cdef inline Codec get_core_codec(
if oid > MAXSUPPORTEDOID:
return None
if format == PG_FORMAT_BINARY:
- ptr = binary_codec_map[oid * xformat]
+ ptr = (codec_map).get_binary_codec_ptr(oid * xformat)
elif format == PG_FORMAT_TEXT:
- ptr = text_codec_map[oid * xformat]
+ ptr = (codec_map).get_text_codec_ptr(oid * xformat)
if ptr is NULL:
return None
@@ -760,7 +863,10 @@ cdef inline Codec get_any_core_codec(
cdef inline int has_core_codec(uint32_t oid):
- return binary_codec_map[oid] != NULL or text_codec_map[oid] != NULL
+ return (
+ (codec_map).get_binary_codec_ptr(oid) != NULL
+ or (codec_map).get_text_codec_ptr(oid) != NULL
+ )
cdef register_core_codec(uint32_t oid,
@@ -784,13 +890,13 @@ cdef register_core_codec(uint32_t oid,
codec = Codec(oid)
codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat,
- encode, decode, None, None, None, None, None, None, 0)
+ encode, decode, None, None, None, None, None, None, None, 0)
cpython.Py_INCREF(codec) # immortalize
if format == PG_FORMAT_BINARY:
- binary_codec_map[oid * xformat] = codec
+ (codec_map).set_binary_codec_ptr(oid * xformat, codec)
elif format == PG_FORMAT_TEXT:
- text_codec_map[oid * xformat] = codec
+ (codec_map).set_text_codec_ptr(oid * xformat, codec)
else:
raise exceptions.InternalClientError(
'invalid data format: {}'.format(format))
@@ -808,9 +914,9 @@ cdef register_extra_codec(str name,
codec = Codec(INVALIDOID)
codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
- encode, decode, None, None, None, None, None, None, 0)
- EXTRA_CODECS[name, format] = codec
+ encode, decode, None, None, None, None, None, None, None, 0)
+ (codec_map).extra_codecs[name, format] = codec
cdef inline Codec get_extra_codec(str name, ServerDataFormat format):
- return EXTRA_CODECS.get((name, format))
+ return (codec_map).extra_codecs.get((name, format))
diff --git a/asyncpg/protocol/codecs/pgproto.pyx b/asyncpg/protocol/codecs/pgproto.pyx
index ea9c15ac..51d650d0 100644
--- a/asyncpg/protocol/codecs/pgproto.pyx
+++ b/asyncpg/protocol/codecs/pgproto.pyx
@@ -180,6 +180,10 @@ cdef init_json_codecs():
pgproto.jsonb_encode,
pgproto.jsonb_decode,
PG_FORMAT_BINARY)
+ register_core_codec(JSONPATHOID,
+ pgproto.jsonpath_encode,
+ pgproto.jsonpath_decode,
+ PG_FORMAT_BINARY)
cdef init_int_codecs():
@@ -229,6 +233,17 @@ cdef init_pseudo_codecs():
pgproto.uint4_decode,
PG_FORMAT_BINARY)
+ # 64-bit OID types
+ oid8_types = [
+ XID8OID,
+ ]
+
+ for oid_type in oid8_types:
+ register_core_codec(oid_type,
+ pgproto.uint8_encode,
+ pgproto.uint8_decode,
+ PG_FORMAT_BINARY)
+
# reg* types -- these are really system catalog OIDs, but
# allow the catalog object name as an input. We could just
# decode these as OIDs, but handling them as text seems more
@@ -237,7 +252,7 @@ cdef init_pseudo_codecs():
reg_types = [
REGPROCOID, REGPROCEDUREOID, REGOPEROID, REGOPERATOROID,
REGCLASSOID, REGTYPEOID, REGCONFIGOID, REGDICTIONARYOID,
- REGNAMESPACEOID, REGROLEOID, REFCURSOROID
+ REGNAMESPACEOID, REGROLEOID, REFCURSOROID, REGCOLLATIONOID,
]
for reg_type in reg_types:
@@ -256,8 +271,11 @@ cdef init_pseudo_codecs():
no_io_types = [
ANYOID, TRIGGEROID, EVENT_TRIGGEROID, LANGUAGE_HANDLEROID,
FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID,
- ANYELEMENTOID, ANYNONARRAYOID, PG_DDL_COMMANDOID,
- INDEX_AM_HANDLEROID,
+ ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID,
+ ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID,
+ ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID,
+ ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID,
+ PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID,
]
register_core_codec(ANYENUMOID,
@@ -306,6 +324,26 @@ cdef init_pseudo_codecs():
pgproto.text_decode,
PG_FORMAT_TEXT)
+ # pg_mcv_list is a special type used in pg_statistic_ext_data
+ # system catalog
+ register_core_codec(PG_MCV_LISTOID,
+ pgproto.bytea_encode,
+ pgproto.bytea_decode,
+ PG_FORMAT_BINARY)
+
+ # These two are internal to BRIN index support and are unlikely
+ # to be sent, but since I/O functions for these exist, add decoders
+ # nonetheless.
+ register_core_codec(PG_BRIN_BLOOM_SUMMARYOID,
+ NULL,
+ pgproto.bytea_decode,
+ PG_FORMAT_BINARY)
+
+ register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID,
+ NULL,
+ pgproto.bytea_decode,
+ PG_FORMAT_BINARY)
+
cdef init_text_codecs():
textoids = [
@@ -337,8 +375,13 @@ cdef init_tid_codecs():
cdef init_txid_codecs():
register_core_codec(TXID_SNAPSHOTOID,
- pgproto.txid_snapshot_encode,
- pgproto.txid_snapshot_decode,
+ pgproto.pg_snapshot_encode,
+ pgproto.pg_snapshot_decode,
+ PG_FORMAT_BINARY)
+
+ register_core_codec(PG_SNAPSHOTOID,
+ pgproto.pg_snapshot_encode,
+ pgproto.pg_snapshot_decode,
PG_FORMAT_BINARY)
@@ -382,12 +425,12 @@ cdef init_numeric_codecs():
cdef init_network_codecs():
register_core_codec(CIDROID,
pgproto.cidr_encode,
- pgproto.net_decode,
+ pgproto.cidr_decode,
PG_FORMAT_BINARY)
register_core_codec(INETOID,
pgproto.inet_encode,
- pgproto.net_decode,
+ pgproto.inet_decode,
PG_FORMAT_BINARY)
register_core_codec(MACADDROID,
diff --git a/asyncpg/protocol/codecs/range.pyx b/asyncpg/protocol/codecs/range.pyx
index 2f598c1b..1038c18d 100644
--- a/asyncpg/protocol/codecs/range.pyx
+++ b/asyncpg/protocol/codecs/range.pyx
@@ -7,6 +7,8 @@
from asyncpg import types as apg_types
+from collections.abc import Sequence as SequenceABC
+
# defined in postgresql/src/include/utils/rangetypes.h
DEF RANGE_EMPTY = 0x01 # range is empty
DEF RANGE_LB_INC = 0x02 # lower bound is inclusive
@@ -139,11 +141,67 @@ cdef range_decode(ConnectionSettings settings, FRBuffer *buf,
empty=(flags & RANGE_EMPTY) != 0)
-cdef init_range_codecs():
- register_core_codec(ANYRANGEOID,
- NULL,
- pgproto.text_decode,
- PG_FORMAT_TEXT)
+cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf,
+ object obj, uint32_t elem_oid,
+ encode_func_ex encoder, const void *encoder_arg):
+ cdef:
+ WriteBuffer elem_data
+ ssize_t elem_data_len
+ ssize_t elem_count
+
+ if not isinstance(obj, SequenceABC):
+ raise TypeError(
+ 'expected a sequence (got type {!r})'.format(type(obj).__name__)
+ )
+
+ elem_data = WriteBuffer.new()
+
+ for elem in obj:
+ range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg)
+ elem_count = len(obj)
+ if elem_count > INT32_MAX:
+ raise OverflowError(f'too many elements in multirange value')
+
+ elem_data_len = elem_data.len()
+ if elem_data_len > INT32_MAX - 4:
+ raise OverflowError(
+ f'size of encoded multirange datum exceeds the maximum allowed'
+ f' {INT32_MAX - 4} bytes')
+
+ # Datum length
+ buf.write_int32(4 + elem_data_len)
+ # Number of elements in multirange
+ buf.write_int32(elem_count)
+ buf.write_buffer(elem_data)
+
+
+cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf,
+ decode_func_ex decoder, const void *decoder_arg):
+ cdef:
+ int32_t nelems = hton.unpack_int32(frb_read(buf, 4))
+ FRBuffer elem_buf
+ int32_t elem_len
+ int i
+ list result
+
+ if nelems == 0:
+ return []
+
+ if nelems < 0:
+ raise exceptions.ProtocolError(
+ 'unexpected multirange size value: {}'.format(nelems))
+
+ result = cpython.PyList_New(nelems)
+ for i in range(nelems):
+ elem_len = hton.unpack_int32(frb_read(buf, 4))
+ if elem_len == -1:
+ raise exceptions.ProtocolError(
+ 'unexpected NULL element in multirange value')
+ else:
+ frb_slice_from(&elem_buf, buf, elem_len)
+ elem = range_decode(settings, &elem_buf, decoder, decoder_arg)
+ cpython.Py_INCREF(elem)
+ cpython.PyList_SET_ITEM(result, i, elem)
-init_range_codecs()
+ return result
diff --git a/asyncpg/protocol/codecs/record.pyx b/asyncpg/protocol/codecs/record.pyx
index 5326a8c6..6446f2da 100644
--- a/asyncpg/protocol/codecs/record.pyx
+++ b/asyncpg/protocol/codecs/record.pyx
@@ -51,9 +51,20 @@ cdef anonymous_record_decode(ConnectionSettings settings, FRBuffer *buf):
return result
+cdef anonymous_record_encode(ConnectionSettings settings, WriteBuffer buf, obj):
+ raise exceptions.UnsupportedClientFeatureError(
+ 'input of anonymous composite types is not supported',
+ hint=(
+ 'Consider declaring an explicit composite type and '
+ 'using it to cast the argument.'
+ ),
+ detail='PostgreSQL does not implement anonymous composite type input.'
+ )
+
+
cdef init_record_codecs():
register_core_codec(RECORDOID,
- NULL,
+ anonymous_record_encode,
anonymous_record_decode,
PG_FORMAT_BINARY)
diff --git a/asyncpg/protocol/consts.pxi b/asyncpg/protocol/consts.pxi
index 97cbbf35..e1f8726e 100644
--- a/asyncpg/protocol/consts.pxi
+++ b/asyncpg/protocol/consts.pxi
@@ -8,3 +8,5 @@
DEF _MAXINT32 = 2**31 - 1
DEF _COPY_BUFFER_SIZE = 524288
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
+DEF _EXECUTE_MANY_BUF_NUM = 4
+DEF _EXECUTE_MANY_BUF_SIZE = 32768
diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd
index c96b1fa5..34c7c712 100644
--- a/asyncpg/protocol/coreproto.pxd
+++ b/asyncpg/protocol/coreproto.pxd
@@ -51,16 +51,6 @@ cdef enum AuthenticationMessage:
AUTH_SASL_FINAL = 12
-AUTH_METHOD_NAME = {
- AUTH_REQUIRED_KERBEROS: 'kerberosv5',
- AUTH_REQUIRED_PASSWORD: 'password',
- AUTH_REQUIRED_PASSWORDMD5: 'md5',
- AUTH_REQUIRED_GSS: 'gss',
- AUTH_REQUIRED_SASL: 'scram-sha-256',
- AUTH_REQUIRED_SSPI: 'sspi',
-}
-
-
cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
@@ -96,10 +86,13 @@ cdef class CoreProtocol:
object transport
+ object address
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
+ # Instance of gssapi.SecurityContext or sspilib.SecurityContext
+ object gss_ctx
readonly int32_t backend_pid
readonly int32_t backend_secret
@@ -114,6 +107,7 @@ cdef class CoreProtocol:
# True - completed, False - suspended
bint result_execute_completed
+ cpdef is_in_transaction(self)
cdef _process__auth(self, char mtype)
cdef _process__prepare(self, char mtype)
cdef _process__bind_execute(self, char mtype)
@@ -144,8 +138,13 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
+ cdef _auth_gss_init_gssapi(self)
+ cdef _auth_gss_init_sspi(self, bint negotiate)
+ cdef _auth_gss_get_service(self)
+ cdef _auth_gss_step(self, bytes server_response)
cdef _write(self, buf)
+ cdef _writelines(self, list buffers)
cdef _read_server_messages(self)
@@ -155,19 +154,26 @@ cdef class CoreProtocol:
cdef _ensure_connected(self)
+ cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
WriteBuffer bind_data)
+ cdef WriteBuffer _build_empty_bind_data(self)
+ cdef WriteBuffer _build_execute_message(self, str portal_name,
+ int32_t limit)
cdef _connect(self)
- cdef _prepare(self, str stmt_name, str query)
+ cdef _prepare_and_describe(self, str stmt_name, str query)
+ cdef _send_parse_message(self, str stmt_name, str query)
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
- cdef _bind_execute_many(self, str portal_name, str stmt_name,
- object bind_data)
+ cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
+ object bind_data, bint return_rows)
+ cdef bint _bind_execute_many_more(self, bint first=*)
+ cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
WriteBuffer bind_data)
cdef _execute(self, str portal_name, int32_t limit)
diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx
index a44bc5ad..da96c412 100644
--- a/asyncpg/protocol/coreproto.pyx
+++ b/asyncpg/protocol/coreproto.pyx
@@ -5,15 +5,26 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
-from hashlib import md5 as hashlib_md5 # for MD5 authentication
+import hashlib
include "scram.pyx"
+AUTH_METHOD_NAME = {
+ AUTH_REQUIRED_KERBEROS: 'kerberosv5',
+ AUTH_REQUIRED_PASSWORD: 'password',
+ AUTH_REQUIRED_PASSWORDMD5: 'md5',
+ AUTH_REQUIRED_GSS: 'gss',
+ AUTH_REQUIRED_SASL: 'scram-sha-256',
+ AUTH_REQUIRED_SSPI: 'sspi',
+}
+
+
cdef class CoreProtocol:
- def __init__(self, con_params):
+ def __init__(self, addr, con_params):
+ self.address = addr
# type of `con_params` is `_ConnectionParameters`
self.buffer = ReadBuffer()
self.user = con_params.user
@@ -26,14 +37,17 @@ cdef class CoreProtocol:
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
self.scram = None
-
- # executemany support data
- self._execute_iter = None
- self._execute_portal_name = None
- self._execute_stmt_name = None
+ # type of `gss_ctx` is `gssapi.SecurityContext` or
+ # `sspilib.SecurityContext`
+ self.gss_ctx = None
self._reset_result()
+ cpdef is_in_transaction(self):
+ # PQTRANS_INTRANS = idle, within transaction block
+ # PQTRANS_INERROR = idle, within failed transaction
+ return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
+
cdef _read_server_messages(self):
cdef:
char mtype
@@ -118,6 +132,7 @@ cdef class CoreProtocol:
self.result = apg_exc.InternalClientError(
'unknown error in protocol implementation')
+ self._parse_msg_ready_for_query()
self._push_result()
else:
@@ -149,15 +164,28 @@ cdef class CoreProtocol:
cdef _process__auth(self, char mtype):
if mtype == b'R':
# Authentication...
- self._parse_msg_authentication()
- if self.result_type != RESULT_OK:
+ try:
+ self._parse_msg_authentication()
+ except Exception as ex:
+ # Exception in authentication parsing code
+ # is usually either malformed authentication data
+ # or missing support for cryptographic primitives
+ # in the hashlib module.
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InternalClientError(
+ f"unexpected error while performing authentication: {ex}")
+ self.result.__cause__ = ex
self.con_status = CONNECTION_BAD
self._push_result()
+ else:
+ if self.result_type != RESULT_OK:
+ self.con_status = CONNECTION_BAD
+ self._push_result()
- elif self.auth_msg is not None:
- # Server wants us to send auth data, so do that.
- self._write(self.auth_msg)
- self.auth_msg = None
+ elif self.auth_msg is not None:
+ # Server wants us to send auth data, so do that.
+ self._write(self.auth_msg)
+ self.auth_msg = None
elif mtype == b'K':
# BackendKeyData
@@ -187,19 +215,23 @@ cdef class CoreProtocol:
elif mtype == b'T':
# Row description
self.result_row_desc = self.buffer.consume_message()
+ self._push_result()
elif mtype == b'E':
# ErrorResponse
self._parse_msg_error_response(True)
-
- elif mtype == b'Z':
- # ReadyForQuery
- self._parse_msg_ready_for_query()
- self._push_result()
+ # we don't send a sync during the parse/describe sequence
+ # but send a FLUSH instead. If an error happens we need to
+ # send a SYNC explicitly in order to mark the end of the transaction.
+ # this effectively clears the error and we then wait until we get a
+ # ready for new query message
+ self._write(SYNC_MESSAGE)
+ self.state = PROTOCOL_ERROR_CONSUME
elif mtype == b'n':
# NoData
self.buffer.discard_message()
+ self._push_result()
cdef _process__bind_execute(self, char mtype):
if mtype == b'D':
@@ -219,6 +251,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)
+ elif mtype == b'1':
+ # ParseComplete, in case `_bind_execute()` is reparsing
+ self.buffer.discard_message()
+
elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
@@ -251,6 +287,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)
+ elif mtype == b'1':
+ # ParseComplete, in case `_bind_execute_many()` is reparsing
+ self.buffer.discard_message()
+
elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
@@ -258,27 +298,16 @@ cdef class CoreProtocol:
elif mtype == b'Z':
# ReadyForQuery
self._parse_msg_ready_for_query()
- if self.result_type == RESULT_FAILED:
- self._push_result()
- else:
- try:
- buf = next(self._execute_iter)
- except StopIteration:
- self._push_result()
- except Exception as e:
- self.result_type = RESULT_FAILED
- self.result = e
- self._push_result()
- else:
- # Next iteration over the executemany() arg sequence
- self._send_bind_message(
- self._execute_portal_name, self._execute_stmt_name,
- buf, 0)
+ self._push_result()
elif mtype == b'I':
# EmptyQueryResponse
self.buffer.discard_message()
+ elif mtype == b'1':
+ # ParseComplete
+ self.buffer.discard_message()
+
cdef _process__bind(self, char mtype):
if mtype == b'E':
# ErrorResponse
@@ -604,22 +633,35 @@ cdef class CoreProtocol:
'could not verify server signature for '
'SCRAM authentciation: scram-sha-256',
)
+ self.scram = None
- elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
- AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
- AUTH_REQUIRED_SSPI):
- self.result_type = RESULT_FAILED
- self.result = apg_exc.InterfaceError(
- 'unsupported authentication method requested by the '
- 'server: {!r}'.format(AUTH_METHOD_NAME[status]))
+ elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI):
+ # AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that
+ # it uses protocol negotiation with SSPI clients. Both methods use
+ # AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps.
+ if self.gss_ctx is not None:
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InterfaceError(
+ 'duplicate GSSAPI/SSPI authentication request')
+ else:
+ if self.con_params.gsslib == 'gssapi':
+ self._auth_gss_init_gssapi()
+ else:
+ self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI)
+ self.auth_msg = self._auth_gss_step(None)
+
+ elif status == AUTH_REQUIRED_GSS_CONTINUE:
+ server_response = self.buffer.consume_message()
+ self.auth_msg = self._auth_gss_step(server_response)
else:
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'unsupported authentication method requested by the '
- 'server: {}'.format(status))
+ 'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))
- if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
+ if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
+ AUTH_REQUIRED_GSS_CONTINUE):
self.buffer.discard_message()
cdef _auth_password_message_cleartext(self):
@@ -627,7 +669,7 @@ cdef class CoreProtocol:
WriteBuffer msg
msg = WriteBuffer.new_message(b'p')
- msg.write_bytestring(self.password.encode('ascii'))
+ msg.write_bytestring(self.password.encode(self.encoding))
msg.end_message()
return msg
@@ -639,11 +681,11 @@ cdef class CoreProtocol:
msg = WriteBuffer.new_message(b'p')
# 'md5' + md5(md5(password + username) + salt))
- userpass = ((self.password or '') + (self.user or '')).encode('ascii')
- hash = hashlib_md5(hashlib_md5(userpass).hexdigest().\
- encode('ascii') + salt).hexdigest().encode('ascii')
+ userpass = (self.password or '') + (self.user or '')
+ md5_1 = hashlib.md5(userpass.encode(self.encoding)).hexdigest()
+ md5_2 = hashlib.md5(md5_1.encode('ascii') + salt).hexdigest()
- msg.write_bytestring(b'md5' + hash)
+ msg.write_bytestring(b'md5' + md5_2.encode('ascii'))
msg.end_message()
return msg
@@ -676,6 +718,59 @@ cdef class CoreProtocol:
return msg
+ cdef _auth_gss_init_gssapi(self):
+ try:
+ import gssapi
+ except ModuleNotFoundError:
+ raise apg_exc.InterfaceError(
+ 'gssapi module not found; please install asyncpg[gssauth] to '
+ 'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
+ ) from None
+
+ service_name, host = self._auth_gss_get_service()
+ self.gss_ctx = gssapi.SecurityContext(
+ name=gssapi.Name(
+ f'{service_name}@{host}', gssapi.NameType.hostbased_service),
+ usage='initiate')
+
+ cdef _auth_gss_init_sspi(self, bint negotiate):
+ try:
+ import sspilib
+ except ModuleNotFoundError:
+ raise apg_exc.InterfaceError(
+ 'sspilib module not found; please install asyncpg[gssauth] to '
+ 'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
+ ) from None
+
+ service_name, host = self._auth_gss_get_service()
+ self.gss_ctx = sspilib.ClientSecurityContext(
+ target_name=f'{service_name}/{host}',
+ credential=sspilib.UserCredential(
+ protocol='Negotiate' if negotiate else 'Kerberos'))
+
+ cdef _auth_gss_get_service(self):
+ service_name = self.con_params.krbsrvname or 'postgres'
+ if isinstance(self.address, str):
+ raise apg_exc.InternalClientError(
+ 'GSSAPI/SSPI authentication is only supported for TCP/IP '
+ 'connections')
+
+ return service_name, self.address[0]
+
+ cdef _auth_gss_step(self, bytes server_response):
+ cdef:
+ WriteBuffer msg
+
+ token = self.gss_ctx.step(server_response)
+ if not token:
+ self.gss_ctx = None
+ return None
+ msg = WriteBuffer.new_message(b'p')
+ msg.write_bytes(token)
+ msg.end_message()
+
+ return msg
+
cdef _parse_msg_ready_for_query(self):
cdef char status = self.buffer.read_byte()
@@ -725,6 +820,11 @@ cdef class CoreProtocol:
self.result_execute_completed = False
self._discard_data = False
+ # executemany support data
+ self._execute_iter = None
+ self._execute_portal_name = None
+ self._execute_stmt_name = None
+
cdef _set_state(self, ProtocolState new_state):
if new_state == PROTOCOL_IDLE:
if self.state == PROTOCOL_FAILED:
@@ -775,6 +875,17 @@ cdef class CoreProtocol:
if self.con_status != CONNECTION_OK:
raise apg_exc.InternalClientError('not connected')
+ cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
+ cdef WriteBuffer buf
+
+ buf = WriteBuffer.new_message(b'P')
+ buf.write_str(stmt_name, self.encoding)
+ buf.write_str(query, self.encoding)
+ buf.write_int16(0)
+
+ buf.end_message()
+ return buf
+
cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
WriteBuffer bind_data):
@@ -790,6 +901,25 @@ cdef class CoreProtocol:
buf.end_message()
return buf
+ cdef WriteBuffer _build_empty_bind_data(self):
+ cdef WriteBuffer buf
+ buf = WriteBuffer.new()
+ buf.write_int16(0) # The number of parameter format codes
+ buf.write_int16(0) # The number of parameter values
+ buf.write_int16(0) # The number of result-column format codes
+ return buf
+
+ cdef WriteBuffer _build_execute_message(self, str portal_name,
+ int32_t limit):
+ cdef WriteBuffer buf
+
+ buf = WriteBuffer.new_message(b'E')
+ buf.write_str(portal_name, self.encoding) # name of the portal
+ buf.write_int32(limit) # number of rows to return; 0 - all
+
+ buf.end_message()
+ return buf
+
# API for subclasses
cdef _connect(self):
@@ -832,7 +962,15 @@ cdef class CoreProtocol:
outbuf.write_buffer(buf)
self._write(outbuf)
- cdef _prepare(self, str stmt_name, str query):
+ cdef _send_parse_message(self, str stmt_name, str query):
+ cdef:
+ WriteBuffer msg
+
+ self._ensure_connected()
+ msg = self._build_parse_message(stmt_name, query)
+ self._write(msg)
+
+ cdef _prepare_and_describe(self, str stmt_name, str query):
cdef:
WriteBuffer packet
WriteBuffer buf
@@ -840,12 +978,7 @@ cdef class CoreProtocol:
self._ensure_connected()
self._set_state(PROTOCOL_PREPARE)
- buf = WriteBuffer.new_message(b'P')
- buf.write_str(stmt_name, self.encoding)
- buf.write_str(query, self.encoding)
- buf.write_int16(0)
- buf.end_message()
- packet = buf
+ packet = self._build_parse_message(stmt_name, query)
buf = WriteBuffer.new_message(b'D')
buf.write_byte(b'S')
@@ -853,7 +986,7 @@ cdef class CoreProtocol:
buf.end_message()
packet.write_buffer(buf)
- packet.write_bytes(SYNC_MESSAGE)
+ packet.write_bytes(FLUSH_MESSAGE)
self._write(packet)
@@ -867,10 +1000,7 @@ cdef class CoreProtocol:
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
packet = buf
- buf = WriteBuffer.new_message(b'E')
- buf.write_str(portal_name, self.encoding) # name of the portal
- buf.write_int32(limit) # number of rows to return; 0 - all
- buf.end_message()
+ buf = self._build_execute_message(portal_name, limit)
packet.write_buffer(buf)
packet.write_bytes(SYNC_MESSAGE)
@@ -889,30 +1019,102 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
- cdef _bind_execute_many(self, str portal_name, str stmt_name,
- object bind_data):
-
- cdef WriteBuffer buf
-
+ cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
+ object bind_data, bint return_rows):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
- self.result = None
- self._discard_data = True
+ self.result = [] if return_rows else None
+ self._discard_data = not return_rows
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name
+ return self._bind_execute_many_more(True)
- try:
- buf = next(bind_data)
- except StopIteration:
- self._push_result()
- except Exception as e:
- self.result_type = RESULT_FAILED
- self.result = e
+ cdef bint _bind_execute_many_more(self, bint first=False):
+ cdef:
+ WriteBuffer packet
+ WriteBuffer buf
+ list buffers = []
+
+ # as we keep sending, the server may return an error early
+ if self.result_type == RESULT_FAILED:
+ self._write(SYNC_MESSAGE)
+ return False
+
+ # collect up to four 32KB buffers to send
+ # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
+ while len(buffers) < _EXECUTE_MANY_BUF_NUM:
+ packet = WriteBuffer.new()
+
+ # fill one 32KB buffer
+ while packet.len() < _EXECUTE_MANY_BUF_SIZE:
+ try:
+ # grab one item from the input
+ buf = next(self._execute_iter)
+
+ # reached the end of the input
+ except StopIteration:
+ if first:
+ # if we never send anything, simply set the result
+ self._push_result()
+ else:
+ # otherwise, append SYNC and send the buffers
+ packet.write_bytes(SYNC_MESSAGE)
+ buffers.append(memoryview(packet))
+ self._writelines(buffers)
+ return False
+
+ # error in input, give up the buffers and cleanup
+ except Exception as ex:
+ self._bind_execute_many_fail(ex, first)
+ return False
+
+ # all good, write to the buffer
+ first = False
+ packet.write_buffer(
+ self._build_bind_message(
+ self._execute_portal_name,
+ self._execute_stmt_name,
+ buf,
+ )
+ )
+ packet.write_buffer(
+ self._build_execute_message(self._execute_portal_name, 0,
+ )
+ )
+
+ # collected one buffer
+ buffers.append(memoryview(packet))
+
+ # write to the wire, and signal the caller for more to send
+ self._writelines(buffers)
+ return True
+
+ cdef _bind_execute_many_fail(self, object error, bint first=False):
+ cdef WriteBuffer buf
+
+ self.result_type = RESULT_FAILED
+ self.result = error
+ if first:
self._push_result()
+ elif self.is_in_transaction():
+ # we're in an explicit transaction, just SYNC
+ self._write(SYNC_MESSAGE)
else:
- self._send_bind_message(portal_name, stmt_name, buf, 0)
+ # In an implicit transaction, if `ignore_till_sync` is set,
+ # `ROLLBACK` will be ignored and `Sync` will restore the state;
+ # or the transaction will be rolled back with a warning saying
+ # that there was no transaction, but rollback is done anyway,
+ # so we could safely ignore this warning.
+ # GOTCHA: cannot use simple query message here, because it is
+ # ignored if `ignore_till_sync` is set.
+ buf = self._build_parse_message('', 'ROLLBACK')
+ buf.write_buffer(self._build_bind_message(
+ '', '', self._build_empty_bind_data()))
+ buf.write_buffer(self._build_execute_message('', 0))
+ buf.write_bytes(SYNC_MESSAGE)
+ self._write(buf)
cdef _execute(self, str portal_name, int32_t limit):
cdef WriteBuffer buf
@@ -922,10 +1124,7 @@ cdef class CoreProtocol:
self.result = []
- buf = WriteBuffer.new_message(b'E')
- buf.write_str(portal_name, self.encoding) # name of the portal
- buf.write_int32(limit) # number of rows to return; 0 - all
- buf.end_message()
+ buf = self._build_execute_message(portal_name, limit)
buf.write_bytes(SYNC_MESSAGE)
@@ -1008,6 +1207,9 @@ cdef class CoreProtocol:
cdef _write(self, buf):
raise NotImplementedError
+ cdef _writelines(self, list buffers):
+ raise NotImplementedError
+
cdef _decode_row(self, const char* buf, ssize_t buf_len):
pass
@@ -1027,4 +1229,5 @@ cdef class CoreProtocol:
pass
-cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
+SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
+FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message())
diff --git a/asyncpg/protocol/encodings.pyx b/asyncpg/protocol/encodings.pyx
index dcd692b7..1463dbe4 100644
--- a/asyncpg/protocol/encodings.pyx
+++ b/asyncpg/protocol/encodings.pyx
@@ -10,7 +10,7 @@
https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE
'''
-cdef dict ENCODINGS_MAP = {
+ENCODINGS_MAP = {
'abc': 'cp1258',
'alt': 'cp866',
'euc_cn': 'euccn',
diff --git a/asyncpg/protocol/pgtypes.pxi b/asyncpg/protocol/pgtypes.pxi
index 14db69df..86f8e663 100644
--- a/asyncpg/protocol/pgtypes.pxi
+++ b/asyncpg/protocol/pgtypes.pxi
@@ -10,7 +10,7 @@
DEF INVALIDOID = 0
DEF MAXBUILTINOID = 9999
-DEF MAXSUPPORTEDOID = 4096
+DEF MAXSUPPORTEDOID = 5080
DEF BOOLOID = 16
DEF BYTEAOID = 17
@@ -30,6 +30,7 @@ DEF JSONOID = 114
DEF XMLOID = 142
DEF PG_NODE_TREEOID = 194
DEF SMGROID = 210
+DEF TABLE_AM_HANDLEROID = 269
DEF INDEX_AM_HANDLEROID = 325
DEF POINTOID = 600
DEF LSEGOID = 601
@@ -96,17 +97,36 @@ DEF REGDICTIONARYOID = 3769
DEF JSONBOID = 3802
DEF ANYRANGEOID = 3831
DEF EVENT_TRIGGEROID = 3838
+DEF JSONPATHOID = 4072
DEF REGNAMESPACEOID = 4089
DEF REGROLEOID = 4096
+DEF REGCOLLATIONOID = 4191
+DEF ANYMULTIRANGEOID = 4537
+DEF ANYCOMPATIBLEMULTIRANGEOID = 4538
+DEF PG_BRIN_BLOOM_SUMMARYOID = 4600
+DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601
+DEF PG_MCV_LISTOID = 5017
+DEF PG_SNAPSHOTOID = 5038
+DEF XID8OID = 5069
+DEF ANYCOMPATIBLEOID = 5077
+DEF ANYCOMPATIBLEARRAYOID = 5078
+DEF ANYCOMPATIBLENONARRAYOID = 5079
+DEF ANYCOMPATIBLERANGEOID = 5080
-cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,)
+ARRAY_TYPES = {_TEXTOID, _OIDOID}
BUILTIN_TYPE_OID_MAP = {
ABSTIMEOID: 'abstime',
ACLITEMOID: 'aclitem',
ANYARRAYOID: 'anyarray',
+ ANYCOMPATIBLEARRAYOID: 'anycompatiblearray',
+ ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange',
+ ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray',
+ ANYCOMPATIBLEOID: 'anycompatible',
+ ANYCOMPATIBLERANGEOID: 'anycompatiblerange',
ANYELEMENTOID: 'anyelement',
ANYENUMOID: 'anyenum',
+ ANYMULTIRANGEOID: 'anymultirange',
ANYNONARRAYOID: 'anynonarray',
ANYOID: 'any',
ANYRANGEOID: 'anyrange',
@@ -135,6 +155,7 @@ BUILTIN_TYPE_OID_MAP = {
INTERVALOID: 'interval',
JSONBOID: 'jsonb',
JSONOID: 'json',
+ JSONPATHOID: 'jsonpath',
LANGUAGE_HANDLEROID: 'language_handler',
LINEOID: 'line',
LSEGOID: 'lseg',
@@ -146,16 +167,21 @@ BUILTIN_TYPE_OID_MAP = {
OIDOID: 'oid',
OPAQUEOID: 'opaque',
PATHOID: 'path',
+ PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary',
+ PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary',
PG_DDL_COMMANDOID: 'pg_ddl_command',
PG_DEPENDENCIESOID: 'pg_dependencies',
PG_LSNOID: 'pg_lsn',
+ PG_MCV_LISTOID: 'pg_mcv_list',
PG_NDISTINCTOID: 'pg_ndistinct',
PG_NODE_TREEOID: 'pg_node_tree',
+ PG_SNAPSHOTOID: 'pg_snapshot',
POINTOID: 'point',
POLYGONOID: 'polygon',
RECORDOID: 'record',
REFCURSOROID: 'refcursor',
REGCLASSOID: 'regclass',
+ REGCOLLATIONOID: 'regcollation',
REGCONFIGOID: 'regconfig',
REGDICTIONARYOID: 'regdictionary',
REGNAMESPACEOID: 'regnamespace',
@@ -167,6 +193,7 @@ BUILTIN_TYPE_OID_MAP = {
REGTYPEOID: 'regtype',
RELTIMEOID: 'reltime',
SMGROID: 'smgr',
+ TABLE_AM_HANDLEROID: 'table_am_handler',
TEXTOID: 'text',
TIDOID: 'tid',
TIMEOID: 'time',
@@ -184,6 +211,7 @@ BUILTIN_TYPE_OID_MAP = {
VARBITOID: 'varbit',
VARCHAROID: 'varchar',
VOIDOID: 'void',
+ XID8OID: 'xid8',
XIDOID: 'xid',
XMLOID: 'xml',
_OIDOID: 'oid[]',
@@ -216,5 +244,23 @@ BUILTIN_TYPE_NAME_MAP['double precision'] = \
BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamptz']
+BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
+ BUILTIN_TYPE_NAME_MAP['timestamp']
+
BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timetz']
+
+BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
+ BUILTIN_TYPE_NAME_MAP['time']
+
+BUILTIN_TYPE_NAME_MAP['char'] = \
+ BUILTIN_TYPE_NAME_MAP['bpchar']
+
+BUILTIN_TYPE_NAME_MAP['character'] = \
+ BUILTIN_TYPE_NAME_MAP['bpchar']
+
+BUILTIN_TYPE_NAME_MAP['character varying'] = \
+ BUILTIN_TYPE_NAME_MAP['varchar']
+
+BUILTIN_TYPE_NAME_MAP['bit varying'] = \
+ BUILTIN_TYPE_NAME_MAP['varbit']
diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd
index 0d3f8d3b..369db733 100644
--- a/asyncpg/protocol/prepared_stmt.pxd
+++ b/asyncpg/protocol/prepared_stmt.pxd
@@ -10,7 +10,11 @@ cdef class PreparedStatementState:
readonly str name
readonly str query
readonly bint closed
+ readonly bint prepared
readonly int refs
+ readonly type record_class
+ readonly bint ignore_custom_codec
+
list row_desc
list parameters_desc
@@ -26,7 +30,7 @@ cdef class PreparedStatementState:
bint have_text_cols
tuple rows_codecs
- cdef _encode_bind_msg(self, args)
+ cdef _encode_bind_msg(self, args, int seqno = ?)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx
index b69f76be..4145c664 100644
--- a/asyncpg/protocol/prepared_stmt.pyx
+++ b/asyncpg/protocol/prepared_stmt.pyx
@@ -11,7 +11,14 @@ from asyncpg import exceptions
@cython.final
cdef class PreparedStatementState:
- def __cinit__(self, str name, str query, BaseProtocol protocol):
+ def __cinit__(
+ self,
+ str name,
+ str query,
+ BaseProtocol protocol,
+ type record_class,
+ bint ignore_custom_codec
+ ):
self.name = name
self.query = query
self.settings = protocol.settings
@@ -20,7 +27,10 @@ cdef class PreparedStatementState:
self.args_num = self.cols_num = 0
self.cols_desc = None
self.closed = False
+ self.prepared = True
self.refs = 0
+ self.record_class = record_class
+ self.ignore_custom_codec = ignore_custom_codec
def _get_parameters(self):
cdef Codec codec
@@ -92,12 +102,31 @@ cdef class PreparedStatementState:
def mark_closed(self):
self.closed = True
- cdef _encode_bind_msg(self, args):
+ def mark_unprepared(self):
+ if self.name:
+ raise exceptions.InternalClientError(
+ "named prepared statements cannot be marked unprepared")
+ self.prepared = False
+
+ cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
WriteBuffer writer
Codec codec
+ if not cpython.PySequence_Check(args):
+ if seqno >= 0:
+ raise exceptions.DataError(
+ f'invalid input in executemany() argument sequence '
+ f'element #{seqno}: expected a sequence, got '
+ f'{type(args).__name__}'
+ )
+ else:
+ # Non executemany() callers do not pass user input directly,
+ # so bad input is a bug.
+ raise exceptions.InternalClientError(
+ f'Bind: expected a sequence, got {type(args).__name__}')
+
if len(args) > 32767:
raise exceptions.InterfaceError(
'the number of query arguments cannot exceed 32767')
@@ -113,7 +142,7 @@ cdef class PreparedStatementState:
# that the user tried to parametrize a statement that does
# not support parameters.
hint += (r' Note that parameters are supported only in'
- r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
+ r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES'
r' statements, and will *not* work in statements '
r' like CREATE VIEW or DECLARE CURSOR.')
@@ -129,7 +158,7 @@ cdef class PreparedStatementState:
writer.write_int16(self.args_num)
for idx in range(self.args_num):
codec = <Codec>(self.args_codecs[idx])
- writer.write_int16(codec.format)
+ writer.write_int16(<int16_t>codec.format)
else:
# All arguments are in binary format
writer.write_int32(0x00010001)
@@ -147,26 +176,41 @@ cdef class PreparedStatementState:
except (AssertionError, exceptions.InternalClientError):
# These are internal errors and should raise as-is.
raise
- except exceptions.InterfaceError:
- # This is already a descriptive error.
- raise
+ except exceptions.InterfaceError as e:
+ # This is already a descriptive error, but annotate
+ # with argument name for clarity.
+ pos = f'${idx + 1}'
+ if seqno >= 0:
+ pos = (
+ f'{pos} in element #{seqno} of'
+ f' executemany() sequence'
+ )
+ raise e.with_msg(
+ f'query argument {pos}: {e.args[0]}'
+ ) from None
except Exception as e:
# Everything else is assumed to be an encoding error
# due to invalid input.
+ pos = f'${idx + 1}'
+ if seqno >= 0:
+ pos = (
+ f'{pos} in element #{seqno} of'
+ f' executemany() sequence'
+ )
value_repr = repr(arg)
if len(value_repr) > 40:
value_repr = value_repr[:40] + '...'
raise exceptions.DataError(
- 'invalid input for query argument'
- ' ${n}: {v} ({msg})'.format(
- n=idx + 1, v=value_repr, msg=e)) from e
+ f'invalid input for query argument'
+ f' {pos}: {value_repr} ({e})'
+ ) from e
if self.have_text_cols:
writer.write_int16(self.cols_num)
for idx in range(self.cols_num):
codec = <Codec>(self.rows_codecs[idx])
- writer.write_int16(codec.format)
+ writer.write_int16(<int16_t>codec.format)
else:
# All columns are in binary format
writer.write_int32(0x00010001)
@@ -186,7 +230,7 @@ cdef class PreparedStatementState:
return
if self.cols_num == 0:
- self.cols_desc = record.ApgRecordDesc_New({}, ())
+ self.cols_desc = RecordDescriptor({}, ())
return
cols_mapping = collections.OrderedDict()
@@ -198,7 +242,8 @@ cdef class PreparedStatementState:
cols_mapping[col_name] = i
cols_names.append(col_name)
oid = row[3]
- codec = self.settings.get_data_codec(oid)
+ codec = self.settings.get_data_codec(
+ oid, ignore_custom_codec=self.ignore_custom_codec)
if codec is None or not codec.has_decoder():
raise exceptions.InternalClientError(
'no decoder for OID {}'.format(oid))
@@ -207,7 +252,7 @@ cdef class PreparedStatementState:
codecs.append(codec)
- self.cols_desc = record.ApgRecordDesc_New(
+ self.cols_desc = RecordDescriptor(
cols_mapping, tuple(cols_names))
self.rows_codecs = tuple(codecs)
@@ -223,7 +268,8 @@ cdef class PreparedStatementState:
for i from 0 <= i < self.args_num:
p_oid = self.parameters_desc[i]
- codec = self.settings.get_data_codec(p_oid)
+ codec = self.settings.get_data_codec(
+ p_oid, ignore_custom_codec=self.ignore_custom_codec)
if codec is None or not codec.has_encoder():
raise exceptions.InternalClientError(
'no encoder for OID {}'.format(p_oid))
@@ -264,7 +310,7 @@ cdef class PreparedStatementState:
'different from what was described ({})'.format(
fnum, self.cols_num))
- dec_row = record.ApgRecord_New(self.cols_desc, fnum)
+ dec_row = self.cols_desc.make_record(self.record_class, fnum)
for i in range(fnum):
flen = hton.unpack_int32(frb_read(&rbuf, 4))
@@ -287,7 +333,7 @@ cdef class PreparedStatementState:
frb_set_len(&rbuf, bl - flen)
cpython.Py_INCREF(val)
- record.ApgRecord_SET_ITEM(dec_row, i, val)
+ recordcapi.ApgRecord_SET_ITEM(dec_row, i, val)
if frb_get_len(&rbuf) != 0:
raise BufferError('unexpected trailing {} bytes in buffer'.format(
diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd
index 14a7ecc6..cd221fbb 100644
--- a/asyncpg/protocol/protocol.pxd
+++ b/asyncpg/protocol/protocol.pxd
@@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol):
cdef:
object loop
- object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter
@@ -39,9 +38,8 @@ cdef class BaseProtocol(CoreProtocol):
bint return_extra
object create_future
object timeout_handle
- object timeout_callback
- object completed_callback
object conref
+ type record_class
bint is_reading
str last_query
@@ -51,6 +49,8 @@ cdef class BaseProtocol(CoreProtocol):
readonly uint64_t queries_count
+ bint _is_ssl
+
PreparedStatementState statement
cdef get_connection(self)
diff --git a/asyncpg/protocol/protocol.pyi b/asyncpg/protocol/protocol.pyi
new file mode 100644
index 00000000..34db6440
--- /dev/null
+++ b/asyncpg/protocol/protocol.pyi
@@ -0,0 +1,282 @@
+import asyncio
+import asyncio.protocols
+import hmac
+from codecs import CodecInfo
+from collections.abc import Callable, Iterable, Sequence
+from hashlib import md5, sha256
+from typing import (
+ Any,
+ ClassVar,
+ Final,
+ Generic,
+ Literal,
+ NewType,
+ TypeVar,
+ final,
+ overload,
+)
+from typing_extensions import TypeAlias
+
+import asyncpg.pgproto.pgproto
+
+from ..connect_utils import _ConnectionParameters
+from ..pgproto.pgproto import WriteBuffer
+from ..types import Attribute, Type
+from .record import Record
+
+_Record = TypeVar('_Record', bound=Record)
+_OtherRecord = TypeVar('_OtherRecord', bound=Record)
+_PreparedStatementState = TypeVar(
+ '_PreparedStatementState', bound=PreparedStatementState[Any]
+)
+
+_NoTimeoutType = NewType('_NoTimeoutType', object)
+_TimeoutType: TypeAlias = float | None | _NoTimeoutType
+
+BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
+BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
+NO_TIMEOUT: Final[_NoTimeoutType]
+
+hashlib_md5 = md5
+
+@final
+class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
+ __pyx_vtable__: Any
+ def __init__(self, conn_key: object) -> None: ...
+ def add_python_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typeinfos: Iterable[object],
+ typekind: str,
+ encoder: Callable[[Any], Any],
+ decoder: Callable[[Any], Any],
+ format: object,
+ ) -> Any: ...
+ def clear_type_cache(self) -> None: ...
+ def get_data_codec(
+ self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
+ ) -> Any: ...
+ def get_text_codec(self) -> CodecInfo: ...
+ def register_data_types(self, types: Iterable[object]) -> None: ...
+ def remove_python_codec(
+ self, typeoid: int, typename: str, typeschema: str
+ ) -> None: ...
+ def set_builtin_type_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typekind: str,
+ alias_to: str,
+ format: object = ...,
+ ) -> Any: ...
+ def __getattr__(self, name: str) -> Any: ...
+ def __reduce__(self) -> Any: ...
+
+@final
+class PreparedStatementState(Generic[_Record]):
+ closed: bool
+ prepared: bool
+ name: str
+ query: str
+ refs: int
+ record_class: type[_Record]
+ ignore_custom_codec: bool
+ __pyx_vtable__: Any
+ def __init__(
+ self,
+ name: str,
+ query: str,
+ protocol: BaseProtocol[Any],
+ record_class: type[_Record],
+ ignore_custom_codec: bool,
+ ) -> None: ...
+ def _get_parameters(self) -> tuple[Type, ...]: ...
+ def _get_attributes(self) -> tuple[Attribute, ...]: ...
+ def _init_types(self) -> set[int]: ...
+ def _init_codecs(self) -> None: ...
+ def attach(self) -> None: ...
+ def detach(self) -> None: ...
+ def mark_closed(self) -> None: ...
+ def mark_unprepared(self) -> None: ...
+ def __reduce__(self) -> Any: ...
+
+class CoreProtocol:
+ backend_pid: Any
+ backend_secret: Any
+ __pyx_vtable__: Any
+ def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
+ def is_in_transaction(self) -> bool: ...
+ def __reduce__(self) -> Any: ...
+
+class BaseProtocol(CoreProtocol, Generic[_Record]):
+ queries_count: Any
+ is_ssl: bool
+ __pyx_vtable__: Any
+ def __init__(
+ self,
+ addr: object,
+ connected_fut: object,
+ con_params: _ConnectionParameters,
+ record_class: type[_Record],
+ loop: object,
+ ) -> None: ...
+ def set_connection(self, connection: object) -> None: ...
+ def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
+ def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
+ def get_record_class(self) -> type[_Record]: ...
+ def abort(self) -> None: ...
+ async def bind(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ timeout: _TimeoutType,
+ ) -> Any: ...
+ @overload
+ async def bind_execute(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ limit: int,
+ return_extra: Literal[False],
+ timeout: _TimeoutType,
+ ) -> list[_OtherRecord]: ...
+ @overload
+ async def bind_execute(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ limit: int,
+ return_extra: Literal[True],
+ timeout: _TimeoutType,
+ ) -> tuple[list[_OtherRecord], bytes, bool]: ...
+ @overload
+ async def bind_execute(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ limit: int,
+ return_extra: bool,
+ timeout: _TimeoutType,
+ ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
+ async def bind_execute_many(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Iterable[Sequence[object]],
+ portal_name: str,
+ timeout: _TimeoutType,
+ ) -> None: ...
+ async def close(self, timeout: _TimeoutType) -> None: ...
+ def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
+ def _is_cancelling(self) -> bool: ...
+ async def _wait_for_cancellation(self) -> None: ...
+ async def close_statement(
+ self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
+ ) -> Any: ...
+ async def copy_in(self, *args: object, **kwargs: object) -> str: ...
+ async def copy_out(self, *args: object, **kwargs: object) -> str: ...
+ async def execute(self, *args: object, **kwargs: object) -> Any: ...
+ def is_closed(self, *args: object, **kwargs: object) -> Any: ...
+ def is_connected(self, *args: object, **kwargs: object) -> Any: ...
+ def data_received(self, data: object) -> None: ...
+ def connection_made(self, transport: object) -> None: ...
+ def connection_lost(self, exc: Exception | None) -> None: ...
+ def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
+ @overload
+ async def prepare(
+ self,
+ stmt_name: str,
+ query: str,
+ timeout: float | None = ...,
+ *,
+ state: _PreparedStatementState,
+ ignore_custom_codec: bool = ...,
+ record_class: None,
+ ) -> _PreparedStatementState: ...
+ @overload
+ async def prepare(
+ self,
+ stmt_name: str,
+ query: str,
+ timeout: float | None = ...,
+ *,
+ state: None = ...,
+ ignore_custom_codec: bool = ...,
+ record_class: type[_OtherRecord],
+ ) -> PreparedStatementState[_OtherRecord]: ...
+ async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
+ async def query(self, *args: object, **kwargs: object) -> str: ...
+ def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
+ def __reduce__(self) -> Any: ...
+
+@final
+class Codec:
+ __pyx_vtable__: Any
+ def __reduce__(self) -> Any: ...
+
+class DataCodecConfig:
+ __pyx_vtable__: Any
+ def __init__(self) -> None: ...
+ def add_python_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typekind: str,
+ typeinfos: Iterable[object],
+ encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
+ decoder: Callable[..., object],
+ format: object,
+ xformat: object,
+ ) -> Any: ...
+ def add_types(self, types: Iterable[object]) -> Any: ...
+ def clear_type_cache(self) -> None: ...
+ def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
+ def remove_python_codec(
+ self, typeoid: int, typename: str, typeschema: str
+ ) -> Any: ...
+ def set_builtin_type_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typekind: str,
+ alias_to: str,
+ format: object = ...,
+ ) -> Any: ...
+ def __reduce__(self) -> Any: ...
+
+class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...
+
+class Timer:
+ def __init__(self, budget: float | None) -> None: ...
+ def __enter__(self) -> None: ...
+ def __exit__(self, et: object, e: object, tb: object) -> None: ...
+ def get_remaining_budget(self) -> float: ...
+ def has_budget_greater_than(self, amount: float) -> bool: ...
+
+@final
+class SCRAMAuthentication:
+ AUTHENTICATION_METHODS: ClassVar[list[str]]
+ DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
+ DIGEST = sha256
+ REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
+ REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
+ SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
+ authentication_method: bytes
+ authorization_message: bytes | None
+ client_channel_binding: bytes
+ client_first_message_bare: bytes | None
+ client_nonce: bytes | None
+ client_proof: bytes | None
+ password_salt: bytes | None
+ password_iterations: int
+ server_first_message: bytes | None
+ server_key: hmac.HMAC | None
+ server_nonce: bytes | None
diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx
index ac653bd0..acce4e9f 100644
--- a/asyncpg/protocol/protocol.pyx
+++ b/asyncpg/protocol/protocol.pyx
@@ -13,7 +13,7 @@ cimport cpython
import asyncio
import builtins
import codecs
-import collections
+import collections.abc
import socket
import time
import weakref
@@ -34,11 +34,11 @@ from asyncpg.pgproto.pgproto cimport (
from asyncpg.pgproto cimport pgproto
from asyncpg.protocol cimport cpythonx
-from asyncpg.protocol cimport record
+from asyncpg.protocol cimport recordcapi
from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
int32_t, uint32_t, int64_t, uint64_t, \
- UINT32_MAX
+ INT32_MAX, UINT32_MAX
from asyncpg.exceptions import _base as apg_exc_base
from asyncpg import compat
@@ -46,6 +46,7 @@ from asyncpg import types as apg_types
from asyncpg import exceptions as apg_exc
from asyncpg.pgproto cimport hton
+from asyncpg.protocol.record import Record, RecordDescriptor
include "consts.pxi"
@@ -73,9 +74,9 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
- def __init__(self, addr, connected_fut, con_params, loop):
+ def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
# type of `con_params` is `_ConnectionParameters`
- CoreProtocol.__init__(self, con_params)
+ CoreProtocol.__init__(self, addr, con_params)
self.loop = loop
self.transport = None
@@ -83,8 +84,8 @@ cdef class BaseProtocol(CoreProtocol):
self.cancel_waiter = None
self.cancel_sent_waiter = None
- self.address = addr
- self.settings = ConnectionSettings((self.address, con_params.database))
+ self.settings = ConnectionSettings((addr, con_params.database))
+ self.record_class = record_class
self.statement = None
self.return_extra = False
@@ -93,15 +94,15 @@ cdef class BaseProtocol(CoreProtocol):
self.closing = False
self.is_reading = True
- self.writing_allowed = asyncio.Event(loop=self.loop)
+ self.writing_allowed = asyncio.Event()
self.writing_allowed.set()
self.timeout_handle = None
- self.timeout_callback = self._on_timeout
- self.completed_callback = self._on_waiter_completed
self.queries_count = 0
+ self._is_ssl = False
+
try:
self.create_future = loop.create_future
except AttributeError:
@@ -122,10 +123,8 @@ cdef class BaseProtocol(CoreProtocol):
def get_settings(self):
return self.settings
- def is_in_transaction(self):
- # PQTRANS_INTRANS = idle, within transaction block
- # PQTRANS_INERROR = idle, within failed transaction
- return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
+ def get_record_class(self):
+ return self.record_class
cdef inline resume_reading(self):
if not self.is_reading:
@@ -137,9 +136,11 @@ cdef class BaseProtocol(CoreProtocol):
self.is_reading = False
self.transport.pause_reading()
- @cython.iterable_coroutine
async def prepare(self, stmt_name, query, timeout,
- PreparedStatementState state=None):
+ *,
+ PreparedStatementState state=None,
+ ignore_custom_codec=False,
+ record_class):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
@@ -151,10 +152,11 @@ cdef class BaseProtocol(CoreProtocol):
waiter = self._new_waiter(timeout)
try:
- self._prepare(stmt_name, query) # network op
+ self._prepare_and_describe(stmt_name, query) # network op
self.last_query = query
if state is None:
- state = PreparedStatementState(stmt_name, query, self)
+ state = PreparedStatementState(
+ stmt_name, query, self, record_class, ignore_custom_codec)
self.statement = state
except Exception as ex:
waiter.set_exception(ex)
@@ -162,11 +164,15 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
- async def bind_execute(self, PreparedStatementState state, args,
- str portal_name, int limit, return_extra,
- timeout):
-
+ async def bind_execute(
+ self,
+ state: PreparedStatementState,
+ args,
+ portal_name: str,
+ limit: int,
+ return_extra: bool,
+ timeout,
+ ):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
@@ -179,6 +185,9 @@ cdef class BaseProtocol(CoreProtocol):
waiter = self._new_waiter(timeout)
try:
+ if not state.prepared:
+ self._send_parse_message(state.name, state.query)
+
self._bind_execute(
portal_name,
state.name,
@@ -195,10 +204,14 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
- async def bind_execute_many(self, PreparedStatementState state, args,
- str portal_name, timeout):
-
+ async def bind_execute_many(
+ self,
+ state: PreparedStatementState,
+ args,
+ portal_name: str,
+ timeout,
+ return_rows: bool,
+ ):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
@@ -207,31 +220,51 @@ cdef class BaseProtocol(CoreProtocol):
self._check_state()
timeout = self._get_timeout_impl(timeout)
+ timer = Timer(timeout)
# Make sure the argument sequence is encoded lazily with
# this generator expression to keep the memory pressure under
# control.
- data_gen = (state._encode_bind_msg(b) for b in args)
+ data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args))
arg_bufs = iter(data_gen)
waiter = self._new_waiter(timeout)
try:
- self._bind_execute_many(
+ if not state.prepared:
+ self._send_parse_message(state.name, state.query)
+
+ more = self._bind_execute_many(
portal_name,
state.name,
- arg_bufs) # network op
+ arg_bufs,
+ return_rows) # network op
self.last_query = state.query
self.statement = state
self.return_extra = False
self.queries_count += 1
+
+ while more:
+ with timer:
+ await compat.wait_for(
+ self.writing_allowed.wait(),
+ timeout=timer.get_remaining_budget())
+ # On Windows the above event somehow won't allow context
+ # switch, so forcing one with sleep(0) here
+ await asyncio.sleep(0)
+ if not timer.has_budget_greater_than(0):
+ raise asyncio.TimeoutError
+ more = self._bind_execute_many_more() # network op
+
+ except asyncio.TimeoutError as e:
+ self._bind_execute_many_fail(e) # network op
+
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter
- @cython.iterable_coroutine
async def bind(self, PreparedStatementState state, args,
str portal_name, timeout):
@@ -260,7 +293,6 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
async def execute(self, PreparedStatementState state,
str portal_name, int limit, return_extra,
timeout):
@@ -290,7 +322,28 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
+ async def close_portal(self, str portal_name, timeout):
+
+ if self.cancel_waiter is not None:
+ await self.cancel_waiter
+ if self.cancel_sent_waiter is not None:
+ await self.cancel_sent_waiter
+ self.cancel_sent_waiter = None
+
+ self._check_state()
+ timeout = self._get_timeout_impl(timeout)
+
+ waiter = self._new_waiter(timeout)
+ try:
+ self._close(
+ portal_name,
+ True) # network op
+ except Exception as ex:
+ waiter.set_exception(ex)
+ self._coreproto_error()
+ finally:
+ return await waiter
+
async def query(self, query, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -315,7 +368,6 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
async def copy_out(self, copy_stmt, sink, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -346,11 +398,10 @@ cdef class BaseProtocol(CoreProtocol):
if buffer:
try:
with timer:
- await asyncio.wait_for(
+ await compat.wait_for(
sink(buffer),
- timeout=timer.get_remaining_budget(),
- loop=self.loop)
- except Exception as ex:
+ timeout=timer.get_remaining_budget())
+ except (Exception, asyncio.CancelledError) as ex:
# Abort the COPY operation on any error in
# output sink.
self._request_cancel()
@@ -370,7 +421,6 @@ cdef class BaseProtocol(CoreProtocol):
return status_msg
- @cython.iterable_coroutine
async def copy_in(self, copy_stmt, reader, data,
records, PreparedStatementState record_stmt, timeout):
cdef:
@@ -417,23 +467,44 @@ cdef class BaseProtocol(CoreProtocol):
'no binary format encoder for '
'type {} (OID {})'.format(codec.name, codec.oid))
- for row in records:
- # Tuple header
- wbuf.write_int16(num_cols)
- # Tuple data
- for i in range(num_cols):
- item = row[i]
- if item is None:
- wbuf.write_int32(-1)
- else:
- codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
- codec.encode(settings, wbuf, item)
-
- if wbuf.len() >= _COPY_BUFFER_SIZE:
- with timer:
- await self.writing_allowed.wait()
- self._write_copy_data_msg(wbuf)
- wbuf = WriteBuffer.new()
+ if isinstance(records, collections.abc.AsyncIterable):
+ async for row in records:
+ # Tuple header
+ wbuf.write_int16(num_cols)
+ # Tuple data
+ for i in range(num_cols):
+ item = row[i]
+ if item is None:
+ wbuf.write_int32(-1)
+ else:
+ codec = cpython.PyTuple_GET_ITEM(
+ codecs, i)
+ codec.encode(settings, wbuf, item)
+
+ if wbuf.len() >= _COPY_BUFFER_SIZE:
+ with timer:
+ await self.writing_allowed.wait()
+ self._write_copy_data_msg(wbuf)
+ wbuf = WriteBuffer.new()
+ else:
+ for row in records:
+ # Tuple header
+ wbuf.write_int16(num_cols)
+ # Tuple data
+ for i in range(num_cols):
+ item = row[i]
+ if item is None:
+ wbuf.write_int32(-1)
+ else:
+ codec = cpython.PyTuple_GET_ITEM(
+ codecs, i)
+ codec.encode(settings, wbuf, item)
+
+ if wbuf.len() >= _COPY_BUFFER_SIZE:
+ with timer:
+ await self.writing_allowed.wait()
+ self._write_copy_data_msg(wbuf)
+ wbuf = WriteBuffer.new()
# End of binary copy.
wbuf.write_int16(-1)
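
Note: by branching on collections.abc.AsyncIterable, the binary COPY encoder now accepts an asynchronous source of records in addition to a plain iterable. At the public API level this is what allows passing an async generator to copy_records_to_table(); a hedged usage sketch (table and column names are illustrative):

    import asyncio
    import asyncpg

    async def rows(n):
        for i in range(n):
            # Rows can be produced asynchronously, e.g. fetched from
            # another service or stream.
            yield (i, 'name-{}'.format(i))

    async def main():
        con = await asyncpg.connect()
        try:
            await con.copy_records_to_table(
                'mytable', records=rows(1000), columns=('id', 'name'))
        finally:
            await con.close()

    asyncio.run(main())
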
@@ -454,10 +525,9 @@ cdef class BaseProtocol(CoreProtocol):
with timer:
await self.writing_allowed.wait()
with timer:
- chunk = await asyncio.wait_for(
+ chunk = await compat.wait_for(
iterator.__anext__(),
- timeout=timer.get_remaining_budget(),
- loop=self.loop)
+ timeout=timer.get_remaining_budget())
self._write_copy_data_msg(chunk)
except builtins.StopAsyncIteration:
pass
@@ -476,7 +546,7 @@ cdef class BaseProtocol(CoreProtocol):
else:
raise apg_exc.InternalClientError('TimoutError was not raised')
- except Exception as e:
+ except (Exception, asyncio.CancelledError) as e:
self._write_copy_fail_msg(str(e))
self._request_cancel()
# Make asyncio shut up about unretrieved QueryCanceledError
@@ -489,7 +559,6 @@ cdef class BaseProtocol(CoreProtocol):
return status_msg
- @cython.iterable_coroutine
async def close_statement(self, PreparedStatementState state, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -528,8 +597,8 @@ cdef class BaseProtocol(CoreProtocol):
self._handle_waiter_on_connection_lost(None)
self._terminate()
self.transport.abort()
+ self.transport = None
- @cython.iterable_coroutine
async def close(self, timeout):
if self.closing:
return
@@ -568,7 +637,7 @@ cdef class BaseProtocol(CoreProtocol):
pass
finally:
self.waiter = None
- self.transport.abort()
+ self.transport.abort()
def _request_cancel(self):
self.cancel_waiter = self.create_future()
@@ -588,6 +657,13 @@ cdef class BaseProtocol(CoreProtocol):
})
self.abort()
+ if self.state == PROTOCOL_PREPARE:
+ # We need to send a SYNC to the server if we cancel during the
+ # PREPARE phase, because the PREPARE sequence does not send a SYNC
+ # itself.  We must not send this extra SYNC outside the PREPARE
+ # phase: we would then issue two SYNCs and receive two ReadyForQuery
+ # replies, which the current state machine implementation cannot handle.
+ self._write(SYNC_MESSAGE)
self._set_state(PROTOCOL_CANCELLED)
def _on_timeout(self, fut):
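
Note: as the comment explains, cancelling in the middle of the Parse sequence leaves the server without a Sync, so one extra SYNC_MESSAGE is written to get a single ReadyForQuery back. For reference, the frontend Sync message is just a tag byte plus a length field; a sketch of its wire form:

    import struct

    # Sync message: tag 'S' followed by an Int32 length of 4
    # (the length counts itself and the message has no body).
    SYNC_MESSAGE = b'S' + struct.pack('!i', 4)
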
@@ -599,12 +675,12 @@ cdef class BaseProtocol(CoreProtocol):
self.waiter.set_exception(asyncio.TimeoutError())
def _on_waiter_completed(self, fut):
+ if self.timeout_handle:
+ self.timeout_handle.cancel()
+ self.timeout_handle = None
if fut is not self.waiter or self.cancel_waiter is not None:
return
if fut.cancelled():
- if self.timeout_handle:
- self.timeout_handle.cancel()
- self.timeout_handle = None
self._request_cancel()
def _create_future_fallback(self):
@@ -665,7 +741,6 @@ cdef class BaseProtocol(CoreProtocol):
self.cancel_sent_waiter is not None
)
- @cython.iterable_coroutine
async def _wait_for_cancellation(self):
if self.cancel_sent_waiter is not None:
await self.cancel_sent_waiter
@@ -691,8 +766,8 @@ cdef class BaseProtocol(CoreProtocol):
self.waiter = self.create_future()
if timeout is not None:
self.timeout_handle = self.loop.call_later(
- timeout, self.timeout_callback, self.waiter)
- self.waiter.add_done_callback(self.completed_callback)
+ timeout, self._on_timeout, self.waiter)
+ self.waiter.add_done_callback(self._on_waiter_completed)
return self.waiter
cdef _on_result__connect(self, object waiter):
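
Note: _new_waiter() above now wires the freshly created future directly to the bound methods _on_timeout and _on_waiter_completed instead of cached callbacks. The general shape of that pattern with plain asyncio (names are illustrative, not asyncpg's internals):

    import asyncio

    def make_waiter(loop, timeout, on_timeout, on_completed):
        waiter = loop.create_future()
        handle = None
        if timeout is not None:
            # Fail the waiter if the deadline passes first.
            handle = loop.call_later(timeout, on_timeout, waiter)
        # Clean up the pending timeout as soon as the waiter completes.
        waiter.add_done_callback(lambda fut: on_completed(handle, fut))
        return waiter
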
@@ -880,6 +955,9 @@ cdef class BaseProtocol(CoreProtocol):
cdef _write(self, buf):
self.transport.write(memoryview(buf))
+ cdef _writelines(self, list buffers):
+ self.transport.writelines(buffers)
+
# asyncio callbacks:
def data_received(self, data):
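
Note: the new _writelines() helper hands a list of pre-encoded buffers to the transport in one call instead of writing them one by one, letting asyncio batch the data. A generic sketch of the idea against any asyncio.WriteTransport (the helper name is illustrative):

    def flush_buffers(transport, buffers):
        # 'buffers' is a list of bytes-like objects already framed as
        # protocol messages; writelines() queues them for a single flush.
        if len(buffers) == 1:
            transport.write(buffers[0])
        else:
            transport.writelines(buffers)
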
@@ -915,6 +993,14 @@ cdef class BaseProtocol(CoreProtocol):
def resume_writing(self):
self.writing_allowed.set()
+ @property
+ def is_ssl(self):
+ return self._is_ssl
+
+ @is_ssl.setter
+ def is_ssl(self, value):
+ self._is_ssl = value
+
class Timer:
def __init__(self, budget):
@@ -932,6 +1018,13 @@ class Timer:
def get_remaining_budget(self):
return self._budget
+ def has_budget_greater_than(self, amount):
+ if self._budget is None:
+ # Unlimited budget.
+ return True
+ else:
+ return self._budget > amount
+
class Protocol(BaseProtocol, asyncio.Protocol):
pass
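
Note: Timer.has_budget_greater_than() treats a budget of None as unlimited, which the paged executemany loop relies on when no timeout was given. A self-contained sketch of the budget-tracking idea (not asyncpg's exact Timer class):

    import time

    class BudgetTimer:
        def __init__(self, budget):
            self._budget = budget            # None means "no deadline at all"

        def __enter__(self):
            self._start = time.monotonic()
            return self

        def __exit__(self, *exc):
            if self._budget is not None:
                self._budget -= time.monotonic() - self._start

        def get_remaining_budget(self):
            return self._budget

        def has_budget_greater_than(self, amount):
            if self._budget is None:
                return True                  # unlimited budget
            return self._budget > amount
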
@@ -945,17 +1038,14 @@ def _create_record(object mapping, tuple elems):
int32_t i
if mapping is None:
- desc = record.ApgRecordDesc_New({}, ())
+ desc = RecordDescriptor({}, ())
else:
- desc = record.ApgRecordDesc_New(
+ desc = RecordDescriptor(
mapping, tuple(mapping) if mapping else ())
- rec = record.ApgRecord_New(desc, len(elems))
+ rec = desc.make_record(Record, len(elems))
for i in range(len(elems)):
elem = elems[i]
cpython.Py_INCREF(elem)
- record.ApgRecord_SET_ITEM(rec, i, elem)
+ recordcapi.ApgRecord_SET_ITEM(rec, i, elem)
return rec
-
-
-Record =