扫码登录,获取cookies

This commit is contained in:
2026-03-09 16:10:29 +08:00
parent 754e720ba7
commit 8229208165
7775 changed files with 1150053 additions and 208 deletions

View File

@@ -0,0 +1 @@
/_scm_version.py

View File

@@ -0,0 +1,69 @@
"""
aiomysql: A pure-Python MySQL client library for asyncio.
Copyright (c) 2010, 2013-2014 PyMySQL contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from pymysql.converters import escape_dict, escape_sequence, escape_string
from pymysql.err import (Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError,
InternalError,
NotSupportedError, ProgrammingError, MySQLError)
from .connection import Connection, connect
from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor
from .pool import create_pool, Pool
from ._version import version
__version__ = version

# Public API of the aiomysql package; mirrors PyMySQL's exports plus the
# asyncio-specific Pool/create_pool helpers.
__all__ = [
    # Errors
    'Error',
    'DataError',
    'DatabaseError',
    'IntegrityError',
    'InterfaceError',
    'InternalError',
    'MySQLError',
    'NotSupportedError',
    'OperationalError',
    'ProgrammingError',
    'Warning',

    'escape_dict',
    'escape_sequence',
    'escape_string',

    'Connection',
    'Pool',
    'connect',
    'create_pool',

    'Cursor',
    'SSCursor',
    'DictCursor',
    'SSDictCursor'
]

# Reference the re-exported names so static checkers do not flag the
# imports above as unused.
(Connection, Pool, connect, create_pool, Cursor, SSCursor, DictCursor,
 SSDictCursor)  # pyflakes

View File

@@ -0,0 +1,3 @@
# This stub file is necessary because `_scm_version.py` is
# autogenerated at build time and therefore absent when mypy runs.
version: str

View File

@@ -0,0 +1,4 @@
# `_scm_version.py` is generated by the build (setuptools-scm style);
# fall back to a placeholder when running from a plain source checkout
# where the generated file does not exist.
try:
    from ._scm_version import version
except ImportError:
    version = "unknown"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,706 @@
import re
import json
import warnings
import contextlib
from pymysql.err import (
Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError)
from .log import logger
from .connection import FIELD_TYPE
# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18

#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
#:
#: Capture groups: (1) statement prefix up to and including VALUES,
#: (2) the parenthesized %s / %(name)s placeholder tuple that gets
#: repeated per row, (3) the optional "ON DUPLICATE ..." suffix.
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    re.IGNORECASE | re.DOTALL)
class Cursor:
    """Cursor is used to interact with the database."""

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet -
    #: packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection, echo=False):
        """Do not create an instance of a Cursor yourself. Call
        connections.Connection.cursor().
        """
        self._connection = connection
        self._loop = self._connection.loop
        # Result-set state; (re)populated by _do_get_result() after each
        # query and reset by _clear_result().
        self._description = None
        self._rownumber = 0
        self._rowcount = -1
        self._arraysize = 1
        self._executed = None
        self._result = None
        self._rows = None
        self._lastrowid = None
        self._echo = echo

    @property
    def connection(self):
        """This read-only attribute return a reference to the Connection
        object on which the cursor was created."""
        return self._connection

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences is a collections.namedtuple containing
        information describing one result column:

        0. name: the name of the column returned.
        1. type_code: the type of the column.
        2. display_size: the actual length of the column in bytes.
        3. internal_size: the size in bytes of the column associated to
           this column on the server.
        4. precision: total number of significant digits in columns of
           type NUMERIC. None for other types.
        5. scale: count of decimal digits in the fractional part in
           columns of type NUMERIC. None for other types.
        6. null_ok: always None as not easy to retrieve from the libpq.

        This attribute will be None for operations that do not
        return rows or if the cursor has not had an operation invoked
        via the execute() method yet.
        """
        return self._description

    @property
    def rowcount(self):
        """Returns the number of rows that has been produced or affected.

        This read-only attribute specifies the number of rows that the
        last :meth:`execute` produced (for Data Query Language
        statements like SELECT) or affected (for Data Manipulation
        Language statements like UPDATE or INSERT).

        The attribute is -1 in case no .execute() has been performed
        on the cursor or the row count of the last operation if it
        can't be determined by the interface.
        """
        return self._rowcount

    @property
    def rownumber(self):
        """Row index.

        This read-only attribute provides the current 0-based index of the
        cursor in the result set or ``None`` if the index cannot be
        determined.
        """
        return self._rownumber

    @property
    def arraysize(self):
        """How many rows will be returned by fetchmany() call.

        This read/write attribute specifies the number of rows to
        fetch at a time with fetchmany(). It defaults to
        1 meaning to fetch a single row at a time.
        """
        return self._arraysize

    @arraysize.setter
    def arraysize(self, val):
        """How many rows will be returned by fetchmany() call.

        This read/write attribute specifies the number of rows to
        fetch at a time with fetchmany(). It defaults to
        1 meaning to fetch a single row at a time.
        """
        self._arraysize = val

    @property
    def lastrowid(self):
        """This read-only property returns the value generated for an
        AUTO_INCREMENT column by the previous INSERT or UPDATE statement
        or None when there is no such value available. For example,
        if you perform an INSERT into a table that contains an AUTO_INCREMENT
        column, lastrowid returns the AUTO_INCREMENT value for the new row.
        """
        return self._lastrowid

    @property
    def echo(self):
        """Return echo mode status."""
        return self._echo

    @property
    def closed(self):
        """The readonly property that returns ``True`` if connections was
        detached from current cursor
        """
        return True if not self._connection else False

    async def close(self):
        """Closing a cursor just exhausts all remaining data."""
        conn = self._connection
        if conn is None:
            return
        try:
            # Drain any pending result sets so the connection is left in a
            # clean state for the next cursor.
            while (await self.nextset()):
                pass
        finally:
            self._connection = None

    def _get_db(self):
        # Return the live connection or raise if the cursor was closed.
        if not self._connection:
            raise ProgrammingError("Cursor closed")
        return self._connection

    def _check_executed(self):
        # Guard for fetch*/scroll: a query must have been executed first.
        if not self._executed:
            raise ProgrammingError("execute() first")

    def _conv_row(self, row):
        # Hook point for mixins (dict rows, JSON deserialization); the base
        # cursor returns rows unchanged.
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    async def nextset(self):
        """Get the next query set

        Returns ``True`` when another result set was advanced to, or
        ``None`` when there are no more result sets.
        """
        conn = self._get_db()
        current_result = self._result
        # Only advance if our result is the connection's current one and
        # the server signalled more results to come.
        if current_result is None or current_result is not conn._result:
            return
        if not current_result.has_next:
            return
        self._result = None
        self._clear_result()
        await conn.next_result()
        await self._do_get_result()
        return True

    def _escape_args(self, args, conn):
        # Escape query parameters using the connection's charset-aware
        # escaping, preserving the container shape (tuple vs dict).
        if isinstance(args, (tuple, list)):
            return tuple(conn.escape(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: conn.escape(val) for (key, val) in args.items()}
        else:
            # If it's not a dictionary let's try escaping it anyways.
            # Worst case it will throw a Value error
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """ Returns the exact string that is sent to the database by calling
        the execute() method. This method follows the extension to the DB
        API 2.0 followed by Psycopg.

        :param query: ``str`` sql statement
        :param args: ``tuple`` or ``list`` of arguments for sql query
        """
        conn = self._get_db()
        if args is not None:
            query = query % self._escape_args(args, conn)
        return query

    async def execute(self, query, args=None):
        """Executes the given operation

        Executes the given operation substituting any markers with
        the given parameters.

        For example, getting all rows where id is 5:
            cursor.execute("SELECT * FROM t1 WHERE id = %s", (5,))

        :param query: ``str`` sql statement
        :param args: ``tuple`` or ``list`` of arguments for sql query
        :returns: ``int``, number of rows that has been produced or affected
        """
        conn = self._get_db()
        # Drain leftover result sets from a previous multi-statement query.
        while (await self.nextset()):
            pass
        if args is not None:
            query = query % self._escape_args(args, conn)
        await self._query(query)
        self._executed = query
        if self._echo:
            logger.info(query)
            logger.info("%r", args)
        return self._rowcount

    async def executemany(self, query, args):
        """Execute the given operation multiple times

        The executemany() method will execute the operation iterating
        over the list of parameters in seq_params.

        Example: Inserting 3 new employees and their phone number

        data = [
            ('Jane','555-001'),
            ('Joe', '555-001'),
            ('John', '555-003')
            ]
        stmt = "INSERT INTO employees (name, phone) VALUES ('%s','%s')"
        await cursor.executemany(stmt, data)

        INSERT or REPLACE statements are optimized by batching the data,
        that is using the MySQL multiple rows syntax.

        :param query: `str`, sql statement
        :param args: ``tuple`` or ``list`` of arguments for sql query
        """
        if not args:
            return
        if self._echo:
            logger.info("CALL %s", query)
            logger.info("%r", args)

        m = RE_INSERT_VALUES.match(query)
        if m:
            # Simple bulk INSERT/REPLACE: rewrite into a single multi-row
            # statement instead of executing one statement per row.
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ''
            assert q_values[0] == '(' and q_values[-1] == ')'
            return (await self._do_execute_many(
                q_prefix, q_values, q_postfix, args, self.max_stmt_length,
                self._get_db().encoding))
        else:
            rows = 0
            for arg in args:
                await self.execute(query, arg)
                rows += self._rowcount
            self._rowcount = rows
        return self._rowcount

    async def _do_execute_many(self, prefix, values, postfix, args,
                               max_stmt_length, encoding):
        # Build multi-row statements, flushing whenever appending the next
        # row group would exceed max_stmt_length.
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        args = iter(args)
        v = values % escape(next(args), conn)
        if isinstance(v, str):
            v = v.encode(encoding, 'surrogateescape')
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, 'surrogateescape')
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                r = await self.execute(sql + postfix)
                rows += r
                sql = bytearray(prefix)
            else:
                sql += b','
            sql += v
        r = await self.execute(sql + postfix)
        rows += r
        self._rowcount = rows
        return rows

    async def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.

        :param procname: ``str``, name of procedure to execute on server
        :param args: sequence of parameters to use with procedure
        :returns: the original args.
        """
        conn = self._get_db()
        if self._echo:
            logger.info("CALL %s", procname)
            logger.info("%r", args)

        # Stash each argument in a server-side user variable first, so the
        # procedure can be invoked with @_procname_N references.
        for index, arg in enumerate(args):
            q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
            await self._query(q)
            await self.nextset()

        _args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args)))
        q = f"CALL {procname}({_args})"
        await self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row """
        self._check_executed()
        fut = self._loop.create_future()

        if self._rows is None or self._rownumber >= len(self._rows):
            fut.set_result(None)
            return fut
        result = self._rows[self._rownumber]
        self._rownumber += 1

        # NOTE(review): a second future is created here although the one
        # above is still unused — appears redundant; confirm before changing.
        fut = self._loop.create_future()
        fut.set_result(result)
        return fut

    def fetchmany(self, size=None):
        """Returns the next set of rows of a query result, returning a
        list of tuples. When no more rows are available, it returns an
        empty list.

        The number of rows returned can be specified using the size argument,
        which defaults to one

        :param size: ``int`` number of rows to return
        :returns: ``list`` of fetched rows
        """
        self._check_executed()
        fut = self._loop.create_future()
        if self._rows is None:
            fut.set_result([])
            return fut
        end = self._rownumber + (size or self._arraysize)
        result = self._rows[self._rownumber:end]
        self._rownumber = min(end, len(self._rows))

        fut.set_result(result)
        return fut

    def fetchall(self):
        """Returns all rows of a query result set

        :returns: ``list`` of fetched rows
        """
        self._check_executed()
        fut = self._loop.create_future()
        if self._rows is None:
            fut.set_result([])
            return fut

        if self._rownumber:
            result = self._rows[self._rownumber:]
        else:
            result = self._rows
        self._rownumber = len(self._rows)

        fut.set_result(result)
        return fut

    def scroll(self, value, mode='relative'):
        """Scroll the cursor in the result set to a new position according
        to mode.

        If mode is relative (default), value is taken as offset to the
        current position in the result set, if set to absolute, value
        states an absolute target position. An IndexError should be raised in
        case a scroll operation would leave the result set. In this case,
        the cursor position is left undefined (ideal would be to
        not move the cursor at all).

        :param int value: move cursor to next position according to mode.
        :param str mode: scroll mode, possible modes: `relative` and `absolute`
        """
        self._check_executed()
        if mode == 'relative':
            r = self._rownumber + value
        elif mode == 'absolute':
            r = value
        else:
            raise ProgrammingError("unknown scroll mode %s" % mode)

        if not (0 <= r < len(self._rows)):
            raise IndexError("out of range")

        self._rownumber = r

        fut = self._loop.create_future()
        fut.set_result(None)
        return fut

    async def _query(self, q):
        # Send the query on the connection and refresh cursor state from
        # the connection's result object.
        conn = self._get_db()
        self._last_executed = q
        self._clear_result()
        await conn.query(q)
        await self._do_get_result()

    def _clear_result(self):
        # Reset all per-result state before running a new query.
        self._rownumber = 0
        self._result = None

        self._rowcount = 0
        self._description = None
        self._lastrowid = None
        self._rows = None

    async def _do_get_result(self):
        # Copy the connection's current result metadata/rows into the
        # cursor; emit warnings if the server reported any.
        conn = self._get_db()
        self._rownumber = 0
        self._result = result = conn._result
        self._rowcount = result.affected_rows
        self._description = result.description
        self._lastrowid = result.insert_id
        self._rows = result.rows

        if result.warning_count > 0:
            await self._show_warnings(conn)

    async def _show_warnings(self, conn):
        # Skip while more result sets are pending; SHOW WARNINGS would
        # disturb the multi-result protocol state.
        if self._result and self._result.has_next:
            return
        ws = await conn.show_warnings()
        if ws is None:
            return
        for w in ws:
            msg = w[-1]
            warnings.warn(str(msg), Warning, 4)

    # DB-API 2.0 extension: expose the exception classes on the cursor so
    # calling code can reference them without importing the driver module.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    InternalError = InternalError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError

    def __aiter__(self):
        # Cursors are async-iterable, yielding one row per iteration.
        return self

    async def __anext__(self):
        ret = await self.fetchone()
        if ret is not None:
            return ret
        else:
            raise StopAsyncIteration  # noqa

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Async context manager support: close on exit, propagate exceptions.
        await self.close()
        return
class _DeserializationCursorMixin:
    """Mixin that JSON-decodes values of columns whose MySQL field type
    is JSON; cooperates with Cursor via ``super()`` calls."""

    async def _do_get_result(self):
        await super()._do_get_result()
        if self._rows:
            self._rows = [self._deserialization_row(r) for r in self._rows]

    def _deserialization_row(self, row):
        """Return *row* with JSON-typed columns decoded via json.loads.

        Accepts either a dict row (from _DictCursorMixin) or a sequence;
        the original container kind is preserved in the return value.
        """
        if row is None:
            return None
        if isinstance(row, dict):
            dict_flag = True
        else:
            row = list(row)
            dict_flag = False
        for index, (name, field_type, *n) in enumerate(self._description):
            if field_type == FIELD_TYPE.JSON:
                point = name if dict_flag else index
                # Leave the raw value untouched if it is not valid JSON.
                with contextlib.suppress(ValueError, TypeError):
                    row[point] = json.loads(row[point])
        if dict_flag:
            return row
        else:
            return tuple(row)

    def _conv_row(self, row):
        if row is None:
            return None
        row = super()._conv_row(row)
        return self._deserialization_row(row)
class DeserializationCursor(_DeserializationCursorMixin, Cursor):
    """A cursor with automatic deserialization of JSON type fields."""
class _DictCursorMixin:
    """Mixin that converts each row into a mapping of column name -> value."""

    # You can override this to use OrderedDict or other dict-like types.
    dict_type = dict

    async def _do_get_result(self):
        await super()._do_get_result()
        fields = []
        if self._description:
            for f in self._result.fields:
                name = f.name
                # Disambiguate duplicate column names (e.g. from joins) by
                # prefixing the table name.
                if name in fields:
                    name = f.table_name + '.' + name
                fields.append(name)
            self._fields = fields

        if fields and self._rows:
            self._rows = [self._conv_row(r) for r in self._rows]

    def _conv_row(self, row):
        if row is None:
            return None
        row = super()._conv_row(row)
        return self.dict_type(zip(self._fields, row))
class DictCursor(_DictCursorMixin, Cursor):
    """A cursor which returns results as a dictionary."""
class SSCursor(Cursor):
    """Unbuffered Cursor, mainly useful for queries that return a lot of
    data, or for connections to remote servers over a slow network.

    Instead of copying every row of data into a buffer, this will fetch
    rows as needed. The upside of this, is the client uses much less memory,
    and rows are returned much faster when traveling over a slow network,
    or if the result set is very big.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.
    """

    async def close(self):
        conn = self._connection
        if conn is None:
            return

        # Finish draining the unbuffered result before exhausting result
        # sets; otherwise the connection would be left mid-stream.
        if self._result is not None and self._result is conn._result:
            await self._result._finish_unbuffered_query()

        try:
            while (await self.nextset()):
                pass
        finally:
            self._connection = None

    async def _query(self, q):
        # Same as Cursor._query but runs the query unbuffered, so rows are
        # streamed from the server instead of being fetched up front.
        conn = self._get_db()
        self._last_executed = q
        await conn.query(q, unbuffered=True)
        await self._do_get_result()
        return self._rowcount

    async def _read_next(self):
        """Read next row from the unbuffered result and convert it."""
        row = await self._result._read_rowdata_packet_unbuffered()
        row = self._conv_row(row)
        return row

    async def fetchone(self):
        """ Fetch next row """
        self._check_executed()
        row = await self._read_next()
        if row is None:
            return
        self._rownumber += 1
        return row

    async def fetchall(self):
        """Fetch all, as per MySQLdb. Pretty useless for large queries, as
        it is buffered.
        """
        rows = []
        while True:
            row = await self.fetchone()
            if row is None:
                break
            rows.append(row)
        return rows

    async def fetchmany(self, size=None):
        """Returns the next set of rows of a query result, returning a
        list of tuples. When no more rows are available, it returns an
        empty list.

        The number of rows returned can be specified using the size argument,
        which defaults to one

        :param size: ``int`` number of rows to return
        :returns: ``list`` of fetched rows
        """
        self._check_executed()
        if size is None:
            size = self._arraysize
        rows = []
        for i in range(size):
            row = await self._read_next()
            if row is None:
                break
            rows.append(row)
            self._rownumber += 1
        return rows

    async def scroll(self, value, mode='relative'):
        """Scroll the cursor in the result set to a new position
        according to mode . Same as :meth:`Cursor.scroll`, but move cursor
        on server side one by one row. If you want to move 20 rows forward
        scroll will make 20 queries to move cursor. Currently only forward
        scrolling is supported.

        :param int value: move cursor to next position according to mode.
        :param str mode: scroll mode, possible modes: `relative` and `absolute`
        """
        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise NotSupportedError("Backwards scrolling not supported "
                                        "by this cursor")

            # Forward-only: discard rows one by one until the target.
            for _ in range(value):
                await self._read_next()
            self._rownumber += value
        elif mode == 'absolute':
            if value < self._rownumber:
                raise NotSupportedError(
                    "Backwards scrolling not supported by this cursor")

            end = value - self._rownumber
            for _ in range(end):
                await self._read_next()
            self._rownumber = value
        else:
            raise ProgrammingError("unknown scroll mode %s" % mode)
class SSDictCursor(_DictCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a dictionary."""

View File

@@ -0,0 +1,5 @@
"""Logging configuration."""
import logging
# Name the logger after the package so applications can configure all
# aiomysql logging through the single "aiomysql" logger.
logger = logging.getLogger(__package__)

View File

@@ -0,0 +1,270 @@
# based on aiopg pool
# https://github.com/aio-libs/aiopg/blob/master/aiopg/pool.py
import asyncio
import collections
import warnings
from .connection import connect
from .utils import (_PoolContextManager, _PoolConnectionContextManager,
_PoolAcquireContextManager)
def create_pool(minsize=1, maxsize=10, echo=False, pool_recycle=-1,
                loop=None, **kwargs):
    """Create a connection pool.

    Returns a context-manager wrapper around the pool-creation coroutine,
    so it can be used both as ``pool = await create_pool(...)`` and as
    ``async with create_pool(...) as pool``. Extra keyword arguments are
    forwarded to :func:`connect` for each pooled connection.
    """
    return _PoolContextManager(
        _create_pool(minsize=minsize, maxsize=maxsize, echo=echo,
                     pool_recycle=pool_recycle, loop=loop, **kwargs))
async def _create_pool(minsize=1, maxsize=10, echo=False, pool_recycle=-1,
                       loop=None, **kwargs):
    """Actually build the Pool and pre-fill it with *minsize* connections."""
    if loop is None:
        loop = asyncio.get_event_loop()

    pool = Pool(minsize=minsize, maxsize=maxsize, echo=echo,
                pool_recycle=pool_recycle, loop=loop, **kwargs)
    if minsize > 0:
        # Open the initial connections under the pool's condition lock.
        async with pool._cond:
            await pool._fill_free_pool(False)
    return pool
class Pool(asyncio.AbstractServer):
    """Connection pool"""

    def __init__(self, minsize, maxsize, echo, pool_recycle, loop, **kwargs):
        if minsize < 0:
            raise ValueError("minsize should be zero or greater")
        if maxsize < minsize and maxsize != 0:
            raise ValueError("maxsize should be not less than minsize")
        self._minsize = minsize
        self._loop = loop
        self._conn_kwargs = kwargs
        # Number of connections currently being opened (counted toward size).
        self._acquiring = 0
        # Idle connections; deque maxlen doubles as the pool's max size
        # (maxsize == 0 means unbounded).
        self._free = collections.deque(maxlen=maxsize or None)
        self._cond = asyncio.Condition()
        self._used = set()
        self._terminated = set()
        self._closing = False
        self._closed = False
        self._echo = echo
        self._recycle = pool_recycle

    @property
    def echo(self):
        return self._echo

    @property
    def minsize(self):
        return self._minsize

    @property
    def maxsize(self):
        return self._free.maxlen

    @property
    def size(self):
        # Total connections: idle + checked out + currently being opened.
        return self.freesize + len(self._used) + self._acquiring

    @property
    def freesize(self):
        return len(self._free)

    async def clear(self):
        """Close all free connections in pool."""
        async with self._cond:
            while self._free:
                conn = self._free.popleft()
                await conn.ensure_closed()
            self._cond.notify()

    @property
    def closed(self):
        """
        The readonly property that returns ``True`` if connections is closed.
        """
        return self._closed

    def close(self):
        """Close pool.

        Mark all pool connections to be closed on getting back to pool.
        Closed pool doesn't allow to acquire new connections.
        """
        if self._closed:
            return
        self._closing = True

    def terminate(self):
        """Terminate pool.

        Close pool with instantly closing all acquired connections also.
        """
        self.close()

        for conn in list(self._used):
            conn.close()
            # Remember terminated connections so release() can accept them
            # without asserting they are still "used".
            self._terminated.add(conn)

        self._used.clear()

    async def wait_closed(self):
        """Wait for closing all pool's connections."""
        if self._closed:
            return
        if not self._closing:
            raise RuntimeError(".wait_closed() should be called "
                               "after .close()")

        while self._free:
            conn = self._free.popleft()
            conn.close()

        # Wait until every checked-out connection has been released.
        async with self._cond:
            while self.size > self.freesize:
                await self._cond.wait()

        self._closed = True

    def acquire(self):
        """Acquire free connection from the pool."""
        coro = self._acquire()
        return _PoolAcquireContextManager(coro, self)

    async def _acquire(self):
        if self._closing:
            raise RuntimeError("Cannot acquire connection after closing pool")
        async with self._cond:
            while True:
                await self._fill_free_pool(True)
                if self._free:
                    conn = self._free.popleft()
                    assert not conn.closed, conn
                    assert conn not in self._used, (conn, self._used)
                    self._used.add(conn)
                    return conn
                else:
                    # Pool exhausted and at maxsize; wait for a release().
                    await self._cond.wait()

    async def _fill_free_pool(self, override_min):
        # iterate over free connections and remove timed out ones
        free_size = len(self._free)
        n = 0
        while n < free_size:
            conn = self._free[-1]
            if conn._reader.at_eof() or conn._reader.exception():
                self._free.pop()
                conn.close()

            # On MySQL 8.0 a timed out connection sends an error packet before
            # closing the connection, preventing us from relying on at_eof().
            # This relies on our custom StreamReader, as eof_received is not
            # present in asyncio.StreamReader.
            elif conn._reader.eof_received:
                self._free.pop()
                conn.close()

            elif (self._recycle > -1 and
                  self._loop.time() - conn.last_usage > self._recycle):
                self._free.pop()
                conn.close()

            else:
                # Healthy connection: rotate it to the other end so the next
                # iteration inspects a different one.
                self._free.rotate()
            n += 1

        # Top up to minsize unconditionally.
        while self.size < self.minsize:
            self._acquiring += 1
            try:
                conn = await connect(echo=self._echo, loop=self._loop,
                                     **self._conn_kwargs)
                # raise exception if pool is closing
                self._free.append(conn)
                self._cond.notify()
            finally:
                self._acquiring -= 1
        if self._free:
            return

        # Open one extra connection on demand, up to maxsize.
        if override_min and (not self.maxsize or self.size < self.maxsize):
            self._acquiring += 1
            try:
                conn = await connect(echo=self._echo, loop=self._loop,
                                     **self._conn_kwargs)
                # raise exception if pool is closing
                self._free.append(conn)
                self._cond.notify()
            finally:
                self._acquiring -= 1

    async def _wakeup(self):
        async with self._cond:
            self._cond.notify()

    def release(self, conn):
        """Release free connection back to the connection pool.

        This is **NOT** a coroutine.
        """
        fut = self._loop.create_future()
        fut.set_result(None)

        if conn in self._terminated:
            assert conn.closed, conn
            self._terminated.remove(conn)
            return fut
        assert conn in self._used, (conn, self._used)
        self._used.remove(conn)
        if not conn.closed:
            # Never return a connection with an open transaction to the
            # pool; close it instead so state cannot leak between users.
            in_trans = conn.get_transaction_status()
            if in_trans:
                conn.close()
                return fut
            if self._closing:
                conn.close()
            else:
                self._free.append(conn)
            # Wake up any waiter blocked in _acquire().
            fut = self._loop.create_task(self._wakeup())
        return fut

    def __enter__(self):
        raise RuntimeError(
            '"yield from" should be used as context manager expression')

    def __exit__(self, *args):
        # This must exist because __enter__ exists, even though that
        # always raises; that's how the with-statement works.
        pass  # pragma: nocover

    def __iter__(self):
        # This is not a coroutine. It is meant to enable the idiom:
        #
        #     with (yield from pool) as conn:
        #         <block>
        #
        # as an alternative to:
        #
        #     conn = yield from pool.acquire()
        #     try:
        #         <block>
        #     finally:
        #         conn.release()
        conn = yield from self.acquire()
        return _PoolConnectionContextManager(self, conn)

    def __await__(self):
        # Deprecated legacy form: `with await pool as conn`.
        msg = "with await pool as conn deprecated, use" \
              "async with pool.acquire() as conn instead"
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        conn = yield from self.acquire()
        return _PoolConnectionContextManager(self, conn)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self.close()
        await self.wait_closed()

View File

@@ -0,0 +1,14 @@
"""Optional support for sqlalchemy.sql dynamic query generation."""
from .connection import SAConnection
from .engine import create_engine, Engine
from .exc import (Error, ArgumentError, InvalidRequestError,
NoSuchColumnError, ResourceClosedError)
__all__ = ('create_engine', 'SAConnection', 'Error',
           'ArgumentError', 'InvalidRequestError', 'NoSuchColumnError',
           'ResourceClosedError', 'Engine')

# Reference the re-exported names so pyflakes does not flag the imports
# above as unused.
(SAConnection, Error, ArgumentError, InvalidRequestError,
 NoSuchColumnError, ResourceClosedError, create_engine, Engine)

View File

@@ -0,0 +1,425 @@
# ported from:
# https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/connection.py
import weakref
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.dml import UpdateBase
from sqlalchemy.sql.ddl import DDLElement
from . import exc
from .result import create_result_proxy
from .transaction import (RootTransaction, Transaction,
NestedTransaction, TwoPhaseTransaction)
from ..utils import _TransactionContextManager, _SAConnectionContextManager
def noop(k):
    """Identity function; used as the default bind-parameter processor
    when the compiled statement defines none for a column."""
    return k
class SAConnection:
    def __init__(self, connection, engine, compiled_cache=None):
        """Wrap a raw aiomysql *connection* for SQLAlchemy-style execution.

        :param connection: underlying aiomysql connection
        :param engine: owning Engine (supplies the dialect)
        :param compiled_cache: optional mapping used to cache compiled
            statements keyed by the clause object; None disables caching
        """
        self._connection = connection
        self._transaction = None
        self._savepoint_seq = 0
        # Track ResultProxy instances weakly so they can be invalidated
        # without keeping them alive.
        self._weak_results = weakref.WeakSet()
        self._engine = engine
        self._dialect = engine.dialect
        self._compiled_cache = compiled_cache
    def execute(self, query, *multiparams, **params):
        """Executes a SQL query with optional parameters.

        query - a SQL query string or any sqlalchemy expression.

        *multiparams/**params - represent bound parameter values to be
        used in the execution. Typically, the format is a dictionary
        passed to *multiparams:

            await conn.execute(
                table.insert(),
                {"id":1, "value":"v1"},
            )

        ...or individual key/values interpreted by **params::

            await conn.execute(
                table.insert(), id=1, value="v1"
            )

        In the case that a plain SQL string is passed, a tuple or
        individual values in *multiparams may be passed::

            await conn.execute(
                "INSERT INTO table (id, value) VALUES (%d, %s)",
                (1, "v1")
            )

            await conn.execute(
                "INSERT INTO table (id, value) VALUES (%s, %s)",
                1, "v1"
            )

        Returns ResultProxy instance with results of SQL query
        execution.
        """
        # Not a coroutine itself: returns an awaitable context manager so
        # both `await conn.execute(...)` and `async with conn.execute(...)`
        # work.
        coro = self._execute(query, *multiparams, **params)
        return _SAConnectionContextManager(coro)
    def _base_params(self, query, dp, compiled, is_update):
        """Normalize bound parameters for a single compiled statement.

        Positional parameters are only accepted for UPDATE/INSERT-style
        statements, where they are zipped against the table's columns;
        each value is then run through the dialect's bind processors.
        """
        if dp and isinstance(dp, (list, tuple)):
            if is_update:
                dp = {c.key: pval for c, pval in zip(query.table.c, dp)}
            else:
                raise exc.ArgumentError(
                    "Don't mix sqlalchemy SELECT "
                    "clause with positional "
                    "parameters"
                )
        compiled_params = compiled.construct_params(dp)
        processors = compiled._bind_processors
        params = [{
            key: processors.get(key, noop)(compiled_params[key])
            for key in compiled_params
        }]
        post_processed_params = self._dialect.execute_sequence_format(params)
        return post_processed_params[0]
    async def _executemany(self, query, dps, cursor):
        """Run *query* once per parameter set in *dps* via executemany.

        Accepts a plain SQL string or a SQLAlchemy ClauseElement; DDL with
        parameters is rejected. Returns a ResultProxy over the cursor.
        """
        # result_map is only available for compiled SQLAlchemy clauses.
        result_map = None
        if isinstance(query, str):
            await cursor.executemany(query, dps)
        elif isinstance(query, DDLElement):
            raise exc.ArgumentError(
                "Don't mix sqlalchemy DDL clause "
                "and execution with parameters"
            )
        elif isinstance(query, ClauseElement):
            compiled = query.compile(dialect=self._dialect)
            params = []
            is_update = isinstance(query, UpdateBase)
            for dp in dps:
                params.append(
                    self._base_params(
                        query,
                        dp,
                        compiled,
                        is_update,
                    )
                )
            await cursor.executemany(str(compiled), params)
            result_map = compiled._result_columns
        else:
            raise exc.ArgumentError(
                "sql statement should be str or "
                "SQLAlchemy data "
                "selection/modification clause"
            )
        ret = await create_result_proxy(
            self,
            cursor,
            self._dialect,
            result_map
        )
        self._weak_results.add(ret)
        return ret
    async def _execute(self, query, *multiparams, **params):
        """Core execution path behind :meth:`execute`.

        Dispatches to executemany for multiple parameter sets, otherwise
        compiles (and optionally caches) SQLAlchemy clauses and executes
        on a fresh cursor, returning a ResultProxy.
        """
        cursor = await self._connection.cursor()
        # _distill_params normalizes *multiparams/**params into a list of
        # parameter sets (defined elsewhere in this module).
        dp = _distill_params(multiparams, params)
        if len(dp) > 1:
            return await self._executemany(query, dp, cursor)
        elif dp:
            dp = dp[0]

        result_map = None
        if isinstance(query, str):
            await cursor.execute(query, dp or None)
        elif isinstance(query, ClauseElement):
            if self._compiled_cache is not None:
                key = query
                compiled = self._compiled_cache.get(key)
                if not compiled:
                    compiled = query.compile(dialect=self._dialect)
                    if dp and dp.keys() == compiled.params.keys() \
                            or not (dp or compiled.params):
                        # we only want queries with bound params in cache
                        self._compiled_cache[key] = compiled
            else:
                compiled = query.compile(dialect=self._dialect)

            if not isinstance(query, DDLElement):
                post_processed_params = self._base_params(
                    query,
                    dp,
                    compiled,
                    isinstance(query, UpdateBase)
                )
                result_map = compiled._result_columns
            else:
                if dp:
                    raise exc.ArgumentError("Don't mix sqlalchemy DDL clause "
                                            "and execution with parameters")
                post_processed_params = compiled.construct_params()
                result_map = None
            await cursor.execute(str(compiled), post_processed_params)
        else:
            raise exc.ArgumentError("sql statement should be str or "
                                    "SQLAlchemy data "
                                    "selection/modification clause")

        ret = await create_result_proxy(
            self, cursor, self._dialect, result_map
        )
        self._weak_results.add(ret)
        return ret
async def scalar(self, query, *multiparams, **params):
    """Execute *query* and return the first column of its first row.

    Convenience wrapper around :meth:`execute` followed by
    ``ResultProxy.scalar()``.
    """
    result = await self.execute(query, *multiparams, **params)
    return await result.scalar()
@property
def closed(self):
    """``True`` once this SAConnection was closed or released."""
    conn = self._connection
    return conn is None or conn.closed
@property
def connection(self):
    """Return the underlying raw aiomysql connection (``None`` after
    :meth:`close`)."""
    return self._connection
def begin(self):
    """Begin a transaction and return a transaction handle.

    The returned object is a Transaction instance (wrapped in an async
    context manager) representing the "scope" of the transaction,
    which completes when either its .rollback or .commit method is
    called.

    Nested calls to .begin on the same SAConnection return new
    Transaction objects that emulate a transaction within the scope of
    the enclosing one::

        trans = await conn.begin()   # outermost transaction
        trans2 = await conn.begin()  # "nested"
        await trans2.commit()        # does nothing
        await trans.commit()         # actually commits

    Only the outermost Transaction's .commit has an effect, while
    .rollback on any of them rolls the whole transaction back.

    See also: .begin_nested (SAVEPOINT) and .begin_twophase (XA).
    """
    return _TransactionContextManager(self._begin())
async def _begin(self):
    """Create the root transaction on first call; while one is active,
    return an emulated nested Transaction tied to the root instead."""
    if self._transaction is None:
        self._transaction = RootTransaction(self)
        await self._begin_impl()
        return self._transaction
    else:
        # Emulated nested transaction: no SQL is issued here.
        return Transaction(self, self._transaction)
async def _begin_impl(self):
cur = await self._connection.cursor()
try:
await cur.execute('BEGIN')
finally:
await cur.close()
async def _commit_impl(self):
cur = await self._connection.cursor()
try:
await cur.execute('COMMIT')
finally:
await cur.close()
self._transaction = None
async def _rollback_impl(self):
cur = await self._connection.cursor()
try:
await cur.execute('ROLLBACK')
finally:
await cur.close()
self._transaction = None
async def begin_nested(self):
    """Begin a nested transaction and return a transaction handle.

    The returned object is an instance of :class:`.NestedTransaction`.

    Nested transactions require SAVEPOINT support in the
    underlying database.  Any transaction in the hierarchy may
    .commit() and .rollback(), however the outermost transaction
    still controls the overall .commit() or .rollback() of the
    transaction of a whole.
    """
    if self._transaction is None:
        # No enclosing transaction yet: start a real (root) one.
        self._transaction = RootTransaction(self)
        await self._begin_impl()
    else:
        # Inside a transaction: emulate nesting with a SAVEPOINT.
        self._transaction = NestedTransaction(self, self._transaction)
        self._transaction._savepoint = await self._savepoint_impl()
    return self._transaction
async def _savepoint_impl(self, name=None):
self._savepoint_seq += 1
name = 'aiomysql_sa_savepoint_%s' % self._savepoint_seq
cur = await self._connection.cursor()
try:
await cur.execute('SAVEPOINT ' + name)
return name
finally:
await cur.close()
async def _rollback_to_savepoint_impl(self, name, parent):
cur = await self._connection.cursor()
try:
await cur.execute('ROLLBACK TO SAVEPOINT ' + name)
finally:
await cur.close()
self._transaction = parent
async def _release_savepoint_impl(self, name, parent):
cur = await self._connection.cursor()
try:
await cur.execute('RELEASE SAVEPOINT ' + name)
finally:
await cur.close()
self._transaction = parent
async def begin_twophase(self, xid=None):
    """Begin a two-phase or XA transaction and return a transaction
    handle.

    The returned object is an instance of
    TwoPhaseTransaction, which in addition to the
    methods provided by Transaction, also provides a
    TwoPhaseTransaction.prepare() method.

    xid - the two phase transaction id.  If not supplied, a
    random id will be generated.
    """
    if self._transaction is not None:
        # XA transactions cannot be nested inside a regular one.
        raise exc.InvalidRequestError(
            "Cannot start a two phase transaction when a transaction "
            "is already in progress.")
    if xid is None:
        xid = self._dialect.create_xid()
    self._transaction = TwoPhaseTransaction(self, xid)
    # Unlike the other XA statements, the xid is passed here as a bound
    # parameter rather than interpolated into the SQL string.
    await self.execute("XA START %s", xid)
    return self._transaction
async def _prepare_twophase_impl(self, xid):
    """Issue ``XA END`` followed by ``XA PREPARE`` for *xid*."""
    await self.execute("XA END '%s'" % xid)
    await self.execute("XA PREPARE '%s'" % xid)
async def recover_twophase(self):
    """Return a list of prepared twophase transaction ids.

    Bug fix: the ResultProxy returned by :meth:`execute` is not
    synchronously iterable (it only implements ``__aiter__``), so the
    rows must be fetched explicitly before extracting the xid column.
    """
    result = await self.execute("XA RECOVER;")
    rows = await result.fetchall()
    return [row[0] for row in rows]
async def rollback_prepared(self, xid, *, is_prepared=True):
    """Rollback prepared twophase transaction."""
    if not is_prepared:
        # A not-yet-prepared XA transaction must be ended first.
        await self.execute("XA END '%s'" % xid)
    await self.execute("XA ROLLBACK '%s'" % xid)
async def commit_prepared(self, xid, *, is_prepared=True):
    """Commit prepared twophase transaction."""
    if not is_prepared:
        # A not-yet-prepared XA transaction must be ended first.
        await self.execute("XA END '%s'" % xid)
    await self.execute("XA COMMIT '%s'" % xid)
@property
def in_transaction(self):
    """Return True if a transaction is in progress."""
    # An inactive (already finished) transaction does not count.
    return self._transaction is not None and self._transaction.is_active
async def close(self):
    """Close this SAConnection.

    This results in a release of the underlying database
    resources, that is, the underlying connection referenced
    internally. The underlying connection is typically restored
    back to the connection-holding Pool referenced by the Engine
    that produced this SAConnection. Any transactional state
    present on the underlying connection is also unconditionally
    released via calling Transaction.rollback() method.

    After .close() is called, the SAConnection is permanently in a
    closed state, and will allow no further operations.
    """
    if self._connection is None:
        # Already closed: close() is idempotent.
        return
    if self._transaction is not None:
        await self._transaction.rollback()
        self._transaction = None
    # don't close underlying connection, it can be reused by pool
    # conn.close()
    self._engine.release(self)
    self._connection = None
    self._engine = None
async def __aenter__(self):
    """Enter ``async with`` block; the connection itself is the target."""
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Release the connection back to the engine's pool on block exit."""
    await self.close()
def _distill_params(multiparams, params):
"""Given arguments from the calling form *multiparams, **params,
return a list of bind parameter structures, usually a list of
dictionaries.
In the case of 'raw' execution which accepts positional parameters,
it may be a list of tuples or lists.
"""
if not multiparams:
if params:
return [params]
else:
return []
elif len(multiparams) == 1:
zero = multiparams[0]
if isinstance(zero, (list, tuple)):
if not zero or hasattr(zero[0], '__iter__') and \
not hasattr(zero[0], 'strip'):
# execute(stmt, [{}, {}, {}, ...])
# execute(stmt, [(), (), (), ...])
return zero
else:
# execute(stmt, ("value", "value"))
return [zero]
elif hasattr(zero, 'keys'):
# execute(stmt, {"key":"value"})
return [zero]
else:
# execute(stmt, "value")
return [[zero]]
else:
if (hasattr(multiparams[0], '__iter__') and
not hasattr(multiparams[0], 'strip')):
return multiparams
else:
return [multiparams]

View File

@@ -0,0 +1,235 @@
# ported from:
# https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/engine.py
import asyncio
import aiomysql
from .connection import SAConnection
from .exc import InvalidRequestError, ArgumentError
from ..utils import _PoolContextManager, _PoolAcquireContextManager
from ..cursors import (
Cursor, DeserializationCursor, DictCursor, SSCursor, SSDictCursor)
try:
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
from sqlalchemy.dialects.mysql.mysqldb import MySQLCompiler_mysqldb
except ImportError: # pragma: no cover
raise ImportError('aiomysql.sa requires sqlalchemy')
class MySQLCompiler_pymysql(MySQLCompiler_mysqldb):
    """MySQL statement compiler that also evaluates prefetched column
    defaults, so INSERT/UPDATE defaults work without a sync engine."""

    def construct_params(self, params=None, _group_number=None, _check=True):
        bound = super().construct_params(params, _group_number, _check)
        for column in self.prefetch:
            bound[column.key] = self._exec_default(column.default)
        return bound

    def _exec_default(self, default):
        # Callable defaults receive the dialect, scalar defaults are
        # used as-is.
        if default.is_callable:
            return default.arg(self.dialect)
        return default.arg
# Module-level default dialect shared by every engine created here:
# the pymysql-flavoured MySQL dialect using the 'pyformat' paramstyle,
# wired to the compiler above so column defaults are prefetched.
_dialect = MySQLDialect_pymysql(paramstyle='pyformat')
_dialect.statement_compiler = MySQLCompiler_pymysql
_dialect.default_paramstyle = 'pyformat'
def create_engine(minsize=1, maxsize=10, loop=None,
                  dialect=_dialect, pool_recycle=-1, compiled_cache=None,
                  **kwargs):
    """A coroutine for Engine creation.

    Returns an Engine instance with an embedded connection pool that
    keeps *minsize* opened connections to the MySQL server.
    """
    deprecated_cursor_classes = (
        DeserializationCursor, DictCursor, SSCursor, SSDictCursor,
    )
    cursorclass = kwargs.get('cursorclass', Cursor)
    unsupported = not issubclass(cursorclass, Cursor) or any(
        issubclass(cursorclass, deprecated)
        for deprecated in deprecated_cursor_classes
    )
    if unsupported:
        # The SA layer builds RowProxy objects itself and relies on
        # plain tuple rows from the driver.
        raise ArgumentError('SQLAlchemy engine does not support '
                            'this cursor class')
    coro = _create_engine(minsize=minsize, maxsize=maxsize, loop=loop,
                          dialect=dialect, pool_recycle=pool_recycle,
                          compiled_cache=compiled_cache, **kwargs)
    return _EngineContextManager(coro)
async def _create_engine(minsize=1, maxsize=10, loop=None,
                         dialect=_dialect, pool_recycle=-1,
                         compiled_cache=None, **kwargs):
    # Implementation coroutine behind create_engine().
    if loop is None:
        loop = asyncio.get_event_loop()
    pool = await aiomysql.create_pool(minsize=minsize, maxsize=maxsize,
                                      loop=loop,
                                      pool_recycle=pool_recycle, **kwargs)
    # Acquire (and immediately release) one connection before returning
    # the Engine - presumably so connection errors surface at creation
    # time rather than on first query; TODO confirm.
    conn = await pool.acquire()
    try:
        return Engine(dialect, pool, compiled_cache=compiled_cache, **kwargs)
    finally:
        pool.release(conn)
class Engine:
    """Connects a aiomysql.Pool and
    sqlalchemy.engine.interfaces.Dialect together to provide a
    source of database connectivity and behavior.

    An Engine object is instantiated publicly using the
    create_engine coroutine.
    """

    def __init__(self, dialect, pool, compiled_cache=None, **kwargs):
        self._dialect = dialect
        self._pool = pool
        self._compiled_cache = compiled_cache
        # Kept for introspection only; the pool was already built with
        # these arguments by _create_engine().
        self._conn_kw = kwargs

    @property
    def dialect(self):
        """A dialect for engine."""
        return self._dialect

    @property
    def name(self):
        """A name of the dialect."""
        return self._dialect.name

    @property
    def driver(self):
        """A driver of the dialect."""
        return self._dialect.driver

    @property
    def minsize(self):
        """Minimal number of connections kept by the pool."""
        return self._pool.minsize

    @property
    def maxsize(self):
        """Maximal number of connections allowed in the pool."""
        return self._pool.maxsize

    @property
    def size(self):
        """Current number of connections managed by the pool."""
        return self._pool.size

    @property
    def freesize(self):
        """Number of currently unused connections in the pool."""
        return self._pool.freesize

    def close(self):
        """Close engine.

        Mark all engine connections to be closed on getting back to pool.
        Closed engine doesn't allow to acquire new connections.
        """
        self._pool.close()

    def terminate(self):
        """Terminate engine.

        Terminate engine pool with instantly closing all acquired
        connections also.
        """
        self._pool.terminate()

    async def wait_closed(self):
        """Wait for closing all engine's connections."""
        await self._pool.wait_closed()

    def acquire(self):
        """Get a connection from pool."""
        coro = self._acquire()
        return _EngineAcquireContextManager(coro, self)

    async def _acquire(self):
        # Wrap the raw pooled connection into an SAConnection.
        raw = await self._pool.acquire()
        conn = SAConnection(raw, self, compiled_cache=self._compiled_cache)
        return conn

    def release(self, conn):
        """Revert back connection to pool."""
        if conn.in_transaction:
            # Releasing mid-transaction would leak transaction state to
            # the next user of the pooled connection.
            raise InvalidRequestError("Cannot release a connection with "
                                      "not finished transaction")
        raw = conn.connection
        return self._pool.release(raw)

    def __enter__(self):
        raise RuntimeError(
            '"yield from" should be used as context manager expression')

    def __exit__(self, *args):
        # This must exist because __enter__ exists, even though that
        # always raises; that's how the with-statement works.
        pass  # pragma: nocover

    def __iter__(self):
        # This is not a coroutine.  It is meant to enable the idiom:
        #
        #     with (yield from engine) as conn:
        #         <block>
        #
        # as an alternative to:
        #
        #     conn = yield from engine.acquire()
        #     try:
        #         <block>
        #     finally:
        #         engine.release(conn)
        conn = yield from self.acquire()
        return _ConnectionContextManager(self, conn)

    async def __aenter__(self):
        """Enter ``async with engine`` block."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the engine and wait until all connections are closed."""
        self.close()
        await self.wait_closed()
# Engine-level context managers are identical to the pool-level ones.
_EngineContextManager = _PoolContextManager
_EngineAcquireContextManager = _PoolAcquireContextManager
class _ConnectionContextManager:
"""Context manager.
This enables the following idiom for acquiring and releasing a
connection around a block:
with (yield from engine) as conn:
cur = yield from conn.cursor()
while failing loudly when accidentally using:
with engine:
<block>
"""
__slots__ = ('_engine', '_conn')
def __init__(self, engine, conn):
self._engine = engine
self._conn = conn
def __enter__(self):
assert self._conn is not None
return self._conn
def __exit__(self, *args):
try:
self._engine.release(self._conn)
finally:
self._engine = None
self._conn = None

View File

@@ -0,0 +1,28 @@
# ported from: https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/exc.py
class Error(Exception):
    """Generic error class; base for all aiomysql.sa exceptions."""
class ArgumentError(Error):
    """Raised when an invalid or conflicting function argument is supplied.

    This error generally corresponds to construction time state errors.
    """
class InvalidRequestError(ArgumentError):
    """aiomysql.sa was asked to do something it can't do.

    This error generally corresponds to runtime state errors.
    """
class NoSuchColumnError(KeyError, InvalidRequestError):
    """A nonexistent column is requested from a ``RowProxy``."""
class ResourceClosedError(InvalidRequestError):
    """An operation was requested from a connection, cursor, or other
    object that's in a closed state."""

View File

@@ -0,0 +1,458 @@
# ported from:
# https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/result.py
import weakref
from collections.abc import Mapping, Sequence
from sqlalchemy.sql import expression, sqltypes
from . import exc
async def create_result_proxy(connection, cursor, dialect, result_map):
    """Build a ResultProxy and run its asynchronous initialization."""
    proxy = ResultProxy(connection, cursor, dialect, result_map)
    await proxy._prepare()
    return proxy
class RowProxy(Mapping):
    """Read-only mapping view over a single result row.

    Values are run through the per-column result processors supplied by
    the owning ResultProxy's metadata.
    """

    __slots__ = ('_result_proxy', '_row', '_processors', '_keymap')

    def __init__(self, result_proxy, row, processors, keymap):
        """RowProxy objects are constructed by ResultProxy objects."""
        self._result_proxy = result_proxy
        self._row = row
        self._processors = processors
        self._keymap = keymap

    def __iter__(self):
        return iter(self._result_proxy.keys)

    def __len__(self):
        return len(self._row)

    def __getitem__(self, key):
        try:
            processor, obj, index = self._keymap[key]
        except KeyError:
            # Unknown key: try resolving via the metadata fallback
            # (string / ColumnElement lookup), which raises on failure.
            processor, obj, index = self._result_proxy._key_fallback(key)
        # Do we need slicing at all? RowProxy now is Mapping not Sequence
        # except TypeError:
        #     if isinstance(key, slice):
        #         l = []
        #         for processor, value in zip(self._processors[key],
        #                                     self._row[key]):
        #             if processor is None:
        #                 l.append(value)
        #             else:
        #                 l.append(processor(value))
        #         return tuple(l)
        #     else:
        #         raise
        if index is None:
            # A name-collision record (see ResultMetaData) carries no
            # index and is therefore ambiguous by name.
            raise exc.InvalidRequestError(
                "Ambiguous column name '%s' in result set! "
                "try 'use_labels' option on select statement." % key)
        if processor is not None:
            return processor(self._row[index])
        else:
            return self._row[index]

    def __getattr__(self, name):
        # Attribute-style access: row.col is equivalent to row['col'].
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(e.args[0])

    def __contains__(self, key):
        return self._result_proxy._has_key(self._row, key)

    __hash__ = None  # rows compare by value, hence unhashable

    def __eq__(self, other):
        if isinstance(other, RowProxy):
            return self.as_tuple() == other.as_tuple()
        elif isinstance(other, Sequence):
            return self.as_tuple() == other
        else:
            return NotImplemented

    def __ne__(self, other):
        return not self == other

    def as_tuple(self):
        # Tuple of (processed) values in column-key order.
        return tuple(self[k] for k in self)

    def __repr__(self):
        return repr(self.as_tuple())
class ResultMetaData:
    """Handle cursor.description, applying additional info from an execution
    context."""

    def __init__(self, result_proxy, metadata):
        self._processors = processors = []

        # Map column name -> type object from the compiled statement's
        # result map, when the statement provided one.
        result_map = {}
        if result_proxy._result_map:
            result_map = {elem[0]: elem[3] for elem in
                          result_proxy._result_map}

        # We do not strictly need to store the processor in the key mapping,
        # though it is faster in the Python version (probably because of the
        # saved attribute lookup self._processors)
        self._keymap = keymap = {}
        self.keys = []
        dialect = result_proxy.dialect

        # `dbapi_type_map` property removed in SQLAlchemy 1.2+.
        # Usage of `getattr` only needed for backward compatibility with
        # older versions of SQLAlchemy.
        typemap = getattr(dialect, 'dbapi_type_map', {})

        assert dialect.case_sensitive, \
            "Doesn't support case insensitive database connection"

        # high precedence key values.
        primary_keymap = {}

        assert not dialect.description_encoding, \
            "psycopg in py3k should not use this"

        for i, rec in enumerate(metadata):
            colname = rec[0]
            coltype = rec[1]

            # PostgreSQL doesn't require this.
            # if dialect.requires_name_normalize:
            #     colname = dialect.normalize_name(colname)

            name, obj, type_ = (
                colname,
                None,
                result_map.get(
                    colname,
                    typemap.get(coltype, sqltypes.NULLTYPE))
            )

            processor = type_._cached_result_processor(dialect, coltype)

            processors.append(processor)
            rec = (processor, obj, i)

            # indexes as keys. This is only needed for the Python version of
            # RowProxy (the C version uses a faster path for integer indexes).
            primary_keymap[i] = rec

            # populate primary keymap, looking for conflicts.
            if primary_keymap.setdefault(name, rec) is not rec:
                # place a record that doesn't have the "index" - this
                # is interpreted later as an AmbiguousColumnError,
                # but only when actually accessed.  Columns
                # colliding by name is not a problem if those names
                # aren't used; integer access is always
                # unambiguous.
                primary_keymap[name] = rec = (None, obj, None)

            self.keys.append(colname)
            if obj:
                for o in obj:
                    keymap[o] = rec
                    # technically we should be doing this but we
                    # are saving on callcounts by not doing so.
                    # if keymap.setdefault(o, rec) is not rec:
                    #     keymap[o] = (None, obj, None)

        # overwrite keymap values with those of the
        # high precedence keymap.
        keymap.update(primary_keymap)

    def _key_fallback(self, key, raiseerr=True):
        """Resolve *key* (string or ColumnElement) not found in the
        keymap; caches the resolution on success."""
        map = self._keymap
        result = None
        if isinstance(key, str):
            result = map.get(key)
        # fallback for targeting a ColumnElement to a textual expression
        # this is a rare use case which only occurs when matching text()
        # or colummn('name') constructs to ColumnElements, or after a
        # pickle/unpickle roundtrip
        elif isinstance(key, expression.ColumnElement):
            if (key._label and key._label in map):
                result = map[key._label]
            elif (hasattr(key, 'name') and key.name in map):
                # match is only on name.
                result = map[key.name]
            # search extra hard to make sure this
            # isn't a column/label name overlap.
            # this check isn't currently available if the row
            # was unpickled.
            if (result is not None and
                    result[1] is not None):
                for obj in result[1]:
                    if key._compare_name_for_result(obj):
                        break
                else:
                    result = None
        if result is None:
            if raiseerr:
                raise exc.NoSuchColumnError(
                    "Could not locate column in row for column '%s'" %
                    expression._string_or_unprintable(key))
            else:
                return None
        else:
            # Cache the fallback resolution for subsequent lookups.
            map[key] = result
        return result

    def _has_key(self, row, key):
        """Return True when *key* resolves to a column of this result."""
        if key in self._keymap:
            return True
        else:
            return self._key_fallback(key, False) is not None
class ResultProxy:
    """Wraps a DB-API cursor object to provide easier access to row columns.

    Individual columns may be accessed by their integer position,
    case-insensitive column name, or by sqlalchemy schema.Column
    object. e.g.:

      row = fetchone()

      col1 = row[0]    # access via integer position

      col2 = row['col2']   # access via name

      col3 = row[mytable.c.mycol] # access via Column object.

    ResultProxy also handles post-processing of result column
    data using sqlalchemy TypeEngine objects, which are referenced from
    the originating SQL statement that produced this result set.
    """

    def __init__(self, connection, cursor, dialect, result_map):
        self._dialect = dialect
        self._closed = False
        self._cursor = cursor
        self._connection = connection
        # Snapshot cursor attributes now - the cursor may be closed
        # before callers read them.
        self._rowcount = cursor.rowcount
        self._lastrowid = cursor.lastrowid
        self._result_map = result_map

    async def _prepare(self):
        """Asynchronous initialization: build the metadata for
        row-returning statements, or auto-close otherwise."""
        loop = self._connection.connection.loop
        cursor = self._cursor
        if cursor.description is not None:
            self._metadata = ResultMetaData(self, cursor.description)

            def callback(wr):
                # Close the cursor if the proxy is garbage-collected
                # without an explicit close().
                loop.create_task(cursor.close())
            self._weak = weakref.ref(self, callback)
        else:
            self._metadata = None
            await self.close()
            self._weak = None

    @property
    def dialect(self):
        """SQLAlchemy dialect."""
        return self._dialect

    @property
    def cursor(self):
        # Underlying DB-API cursor; ``None`` once closed.
        return self._cursor

    def keys(self):
        """Return the current set of string keys for rows."""
        if self._metadata:
            return tuple(self._metadata.keys)
        else:
            return ()

    @property
    def rowcount(self):
        """Return the 'rowcount' for this result.

        The 'rowcount' reports the number of rows *matched*
        by the WHERE criterion of an UPDATE or DELETE statement.

        .. note::

           Notes regarding .rowcount:

           * This attribute returns the number of rows *matched*,
             which is not necessarily the same as the number of rows
             that were actually *modified* - an UPDATE statement, for example,
             may have no net change on a given row if the SET values
             given are the same as those present in the row already.
             Such a row would be matched but not modified.

           * .rowcount is *only* useful in conjunction
             with an UPDATE or DELETE statement.  Contrary to what the Python
             DBAPI says, it does *not* return the
             number of rows available from the results of a SELECT statement
             as DBAPIs cannot support this functionality when rows are
             unbuffered.

           * Statements that use RETURNING may not return a correct
             rowcount.
        """
        return self._rowcount

    @property
    def lastrowid(self):
        """Returns the 'lastrowid' accessor on the DBAPI cursor.

        This is a DBAPI specific method and is only functional
        for those backends which support it, for statements
        where it is appropriate.
        """
        return self._lastrowid

    @property
    def returns_rows(self):
        """True if this ResultProxy returns rows.

        I.e. if it is legal to call the methods .fetchone(),
        .fetchmany() and .fetchall()`.
        """
        return self._metadata is not None

    @property
    def closed(self):
        # True after close() was called (explicitly or automatically).
        return self._closed

    async def close(self):
        """Close this ResultProxy.

        Closes the underlying DBAPI cursor corresponding to the execution.

        Note that any data cached within this ResultProxy is still available.
        For some types of results, this may include buffered rows.

        If this ResultProxy was generated from an implicit execution,
        the underlying Connection will also be closed (returns the
        underlying DBAPI connection to the connection pool.)

        This method is called automatically when:

        * all result rows are exhausted using the fetchXXX() methods.
        * cursor.description is None.
        """
        if not self._closed:
            self._closed = True
            await self._cursor.close()
            # allow consistent errors
            self._cursor = None
            self._weak = None

    # def __iter__(self):
    #     while True:
    #         row = yield from self.fetchone()
    #         if row is None:
    #             raise StopIteration
    #         else:
    #             yield row

    def _non_result(self):
        # Raise the appropriate "no rows available" error; the message
        # depends on whether the statement ever returned rows.
        if self._metadata is None:
            raise exc.ResourceClosedError(
                "This result object does not return rows. "
                "It has been closed automatically.")
        else:
            raise exc.ResourceClosedError("This result object is closed.")

    def _process_rows(self, rows):
        # Wrap raw driver rows into RowProxy objects sharing the
        # metadata's processors and keymap.
        process_row = RowProxy
        metadata = self._metadata
        keymap = metadata._keymap
        processors = metadata._processors
        return [process_row(metadata, row, processors, keymap)
                for row in rows]

    async def fetchall(self):
        """Fetch all rows, just like DB-API cursor.fetchall()."""
        try:
            rows = await self._cursor.fetchall()
        except AttributeError:
            # self._cursor is None after close().
            self._non_result()
        else:
            ret = self._process_rows(rows)
            await self.close()
            return ret

    async def fetchone(self):
        """Fetch one row, just like DB-API cursor.fetchone().

        If a row is present, the cursor remains open after this is called.
        Else the cursor is automatically closed and None is returned.
        """
        try:
            row = await self._cursor.fetchone()
        except AttributeError:
            # self._cursor is None after close().
            self._non_result()
        else:
            if row is not None:
                return self._process_rows([row])[0]
            else:
                await self.close()
                return None

    async def fetchmany(self, size=None):
        """Fetch many rows, just like DB-API
        cursor.fetchmany(size=cursor.arraysize).

        If rows are present, the cursor remains open after this is called.
        Else the cursor is automatically closed and an empty list is returned.
        """
        try:
            if size is None:
                rows = await self._cursor.fetchmany()
            else:
                rows = await self._cursor.fetchmany(size)
        except AttributeError:
            # self._cursor is None after close().
            self._non_result()
        else:
            ret = self._process_rows(rows)
            if len(ret) == 0:
                await self.close()
            return ret

    async def first(self):
        """Fetch the first row and then close the result set unconditionally.

        Returns None if no row is present.
        """
        if self._metadata is None:
            self._non_result()
        try:
            return (await self.fetchone())
        finally:
            await self.close()

    async def scalar(self):
        """Fetch the first column of the first row, and close the result set.

        Returns None if no row is present.
        """
        row = await self.first()
        if row is not None:
            return row[0]
        else:
            return None

    def __aiter__(self):
        return self

    async def __anext__(self):
        data = await self.fetchone()
        if data is not None:
            return data
        else:
            raise StopAsyncIteration  # noqa

View File

@@ -0,0 +1,169 @@
# ported from:
# https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/transaction.py
from . import exc
class Transaction:
    """Represent a database transaction in progress.

    The Transaction object is procured by
    calling the SAConnection.begin() method of
    SAConnection:

        with (yield from engine) as conn:
            trans = yield from conn.begin()
            try:
                yield from conn.execute("insert into x (a, b) values (1, 2)")
            except Exception:
                yield from trans.rollback()
            else:
                yield from trans.commit()

    The object provides .rollback() and .commit()
    methods in order to control transaction boundaries.

    See also:  SAConnection.begin(), SAConnection.begin_twophase(),
    SAConnection.begin_nested().
    """

    def __init__(self, connection, parent):
        self._connection = connection
        # The root transaction is its own parent.
        self._parent = parent or self
        self._is_active = True

    @property
    def is_active(self):
        """Return ``True`` if a transaction is active."""
        return self._is_active

    @property
    def connection(self):
        """Return transaction's connection (SAConnection instance)."""
        return self._connection

    async def close(self):
        """Close this transaction.

        If this transaction is the base transaction in a begin/commit
        nesting, the transaction will rollback().  Otherwise, the
        method returns.

        This is used to cancel a Transaction without affecting the scope of
        an enclosing transaction.
        """
        if not self._parent._is_active:
            return
        if self._parent is self:
            await self.rollback()
        else:
            self._is_active = False

    async def rollback(self):
        """Roll back this transaction."""
        if not self._parent._is_active:
            # Enclosing transaction already finished; nothing to do.
            return
        await self._do_rollback()
        self._is_active = False

    async def _do_rollback(self):
        # Default: delegate to the parent; subclasses override.
        await self._parent.rollback()

    async def commit(self):
        """Commit this transaction."""
        if not self._parent._is_active:
            raise exc.InvalidRequestError("This transaction is inactive")
        await self._do_commit()
        self._is_active = False

    async def _do_commit(self):
        # No-op for emulated nested transactions; subclasses override.
        pass

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Roll back on error, otherwise commit if still active.
        if exc_type:
            await self.rollback()
        else:
            if self._is_active:
                await self.commit()
class RootTransaction(Transaction):
    """Outermost transaction: commit/rollback issue real SQL."""

    def __init__(self, connection):
        super().__init__(connection, None)

    async def _do_rollback(self):
        await self._connection._rollback_impl()

    async def _do_commit(self):
        await self._connection._commit_impl()
class NestedTransaction(Transaction):
    """Represent a 'nested', or SAVEPOINT transaction.

    A new NestedTransaction object may be procured
    using the SAConnection.begin_nested() method.

    The interface is the same as that of Transaction class.
    """

    # Savepoint name; assigned by SAConnection.begin_nested().
    _savepoint = None

    def __init__(self, connection, parent):
        super().__init__(connection, parent)

    async def _do_rollback(self):
        assert self._savepoint is not None, "Broken transaction logic"
        if self._is_active:
            await self._connection._rollback_to_savepoint_impl(
                self._savepoint, self._parent)

    async def _do_commit(self):
        assert self._savepoint is not None, "Broken transaction logic"
        if self._is_active:
            await self._connection._release_savepoint_impl(
                self._savepoint, self._parent)
class TwoPhaseTransaction(Transaction):
    """Represent a two-phase transaction.

    A new TwoPhaseTransaction object may be procured
    using the SAConnection.begin_twophase() method.

    The interface is the same as that of Transaction class
    with the addition of the .prepare() method.
    """

    def __init__(self, connection, xid):
        super().__init__(connection, None)
        self._is_prepared = False
        self._xid = xid

    @property
    def xid(self):
        """Returns twophase transaction id."""
        return self._xid

    async def prepare(self):
        """Prepare this TwoPhaseTransaction.

        After a PREPARE, the transaction can be committed.
        """
        if not self._parent.is_active:
            raise exc.InvalidRequestError("This transaction is inactive")
        await self._connection._prepare_twophase_impl(self._xid)
        self._is_prepared = True

    async def _do_rollback(self):
        await self._connection.rollback_prepared(
            self._xid, is_prepared=self._is_prepared)

    async def _do_commit(self):
        await self._connection.commit_prepared(
            self._xid, is_prepared=self._is_prepared)

View File

@@ -0,0 +1,187 @@
from collections.abc import Coroutine
import struct
def _pack_int24(n):
return struct.pack("<I", n)[:3]
def _lenenc_int(i):
if i < 0:
raise ValueError(
"Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i
)
elif i < 0xFB:
return bytes([i])
elif i < (1 << 16):
return b"\xfc" + struct.pack("<H", i)
elif i < (1 << 24):
return b"\xfd" + struct.pack("<I", i)[:3]
elif i < (1 << 64):
return b"\xfe" + struct.pack("<Q", i)
else:
raise ValueError(
"Encoding %x is larger than %x - no representation in LengthEncodedInteger"
% (i, (1 << 64))
)
class _ContextManager(Coroutine):
    """Awaitable wrapper around a coroutine that also acts as an async
    context manager: the awaited object is stored and closed on exit.

    The Coroutine ABC methods delegate to the wrapped coroutine so the
    wrapper remains await-compatible.
    """

    __slots__ = ('_coro', '_obj')

    def __init__(self, coro):
        self._coro = coro
        self._obj = None

    def send(self, value):
        return self._coro.send(value)

    def throw(self, typ, val=None, tb=None):
        # Mirror the generator.throw() overloads.
        if val is None:
            return self._coro.throw(typ)
        elif tb is None:
            return self._coro.throw(typ, val)
        else:
            return self._coro.throw(typ, val, tb)

    def close(self):
        return self._coro.close()

    @property
    def gi_frame(self):
        return self._coro.gi_frame

    @property
    def gi_running(self):
        return self._coro.gi_running

    @property
    def gi_code(self):
        return self._coro.gi_code

    def __next__(self):
        return self.send(None)

    def __iter__(self):
        return self._coro.__await__()

    def __await__(self):
        return self._coro.__await__()

    async def __aenter__(self):
        self._obj = await self._coro
        return self._obj

    async def __aexit__(self, exc_type, exc, tb):
        await self._obj.close()
        self._obj = None
class _ConnectionContextManager(_ContextManager):
    # ``async with connect(...) as conn:`` - on error the connection is
    # dropped via close() (called without await), otherwise it is shut
    # down gracefully with an awaited ensure_closed().
    async def __aexit__(self, exc_type, exc, tb):
        if exc_type is not None:
            self._obj.close()
        else:
            await self._obj.ensure_closed()
        self._obj = None
class _PoolContextManager(_ContextManager):
    # ``async with create_pool(...) as pool:`` - close the pool and wait
    # for all its connections to be closed on exit.
    async def __aexit__(self, exc_type, exc, tb):
        self._obj.close()
        await self._obj.wait_closed()
        self._obj = None
class _SAConnectionContextManager(_ContextManager):
    # Adds async-iteration support on top of the awaitable/async-with
    # behaviour of _ContextManager.
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._obj is None:
            # Lazily await the wrapped coroutine on first iteration.
            self._obj = await self._coro

        try:
            return await self._obj.__anext__()
        except StopAsyncIteration:
            # Underlying iterator exhausted: close and re-raise.
            await self._obj.close()
            self._obj = None
            raise
class _TransactionContextManager(_ContextManager):
    # ``async with conn.begin():`` - roll back when the block raised,
    # otherwise commit (if the transaction is still active).
    async def __aexit__(self, exc_type, exc, tb):
        if exc_type:
            await self._obj.rollback()
        else:
            if self._obj.is_active:
                await self._obj.commit()
        self._obj = None
class _PoolAcquireContextManager(_ContextManager):
    # ``async with pool.acquire() as conn:`` - the connection is
    # released back to the pool on exit.

    __slots__ = ('_coro', '_conn', '_pool')

    def __init__(self, coro, pool):
        self._coro = coro
        self._conn = None
        self._pool = pool

    async def __aenter__(self):
        self._conn = await self._coro
        return self._conn

    async def __aexit__(self, exc_type, exc, tb):
        try:
            await self._pool.release(self._conn)
        finally:
            # Drop references so reuse after exit fails loudly.
            self._pool = None
            self._conn = None
class _PoolConnectionContextManager:
"""Context manager.
This enables the following idiom for acquiring and releasing a
connection around a block:
with (yield from pool) as conn:
cur = yield from conn.cursor()
while failing loudly when accidentally using:
with pool:
<block>
"""
__slots__ = ('_pool', '_conn')
def __init__(self, pool, conn):
self._pool = pool
self._conn = conn
def __enter__(self):
assert self._conn
return self._conn
def __exit__(self, exc_type, exc_val, exc_tb):
try:
self._pool.release(self._conn)
finally:
self._pool = None
self._conn = None
async def __aenter__(self):
assert not self._conn
self._conn = await self._pool.acquire()
return self._conn
async def __aexit__(self, exc_type, exc_val, exc_tb):
try:
await self._pool.release(self._conn)
finally:
self._pool = None
self._conn = None