|
106 | 106 | from typing import Collection
|
107 | 107 |
|
108 | 108 | import psycopg
|
| 109 | +from psycopg import AsyncCursor as pg_async_cursor |
109 | 110 | from psycopg import Cursor as pg_cursor # pylint: disable=no-name-in-module
|
110 | 111 | from psycopg.sql import Composed # pylint: disable=no-name-in-module
|
111 | 112 |
|
@@ -151,9 +152,36 @@ def _instrument(self, **kwargs):
|
151 | 152 | commenter_options=commenter_options,
|
152 | 153 | )
|
153 | 154 |
|
| 155 | + dbapi.wrap_connect( |
| 156 | + __name__, |
| 157 | + psycopg.Connection, |
| 158 | + "connect", |
| 159 | + self._DATABASE_SYSTEM, |
| 160 | + self._CONNECTION_ATTRIBUTES, |
| 161 | + version=__version__, |
| 162 | + tracer_provider=tracer_provider, |
| 163 | + db_api_integration_factory=DatabaseApiIntegration, |
| 164 | + enable_commenter=enable_sqlcommenter, |
| 165 | + commenter_options=commenter_options, |
| 166 | + ) |
| 167 | + dbapi.wrap_connect( |
| 168 | + __name__, |
| 169 | + psycopg.AsyncConnection, |
| 170 | + "connect", |
| 171 | + self._DATABASE_SYSTEM, |
| 172 | + self._CONNECTION_ATTRIBUTES, |
| 173 | + version=__version__, |
| 174 | + tracer_provider=tracer_provider, |
| 175 | + db_api_integration_factory=DatabaseApiAsyncIntegration, |
| 176 | + enable_commenter=enable_sqlcommenter, |
| 177 | + commenter_options=commenter_options, |
| 178 | + ) |
| 179 | + |
154 | 180 | def _uninstrument(self, **kwargs):
|
155 | 181 | """ "Disable Psycopg instrumentation"""
|
156 | 182 | dbapi.unwrap_connect(psycopg, "connect")
|
| 183 | + dbapi.unwrap_connect(psycopg.Connection, "connect") |
| 184 | + dbapi.unwrap_connect(psycopg.AsyncConnection, "connect") |
157 | 185 |
|
158 | 186 | # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
|
159 | 187 | @staticmethod
|
@@ -204,6 +232,26 @@ def wrapped_connection(
|
204 | 232 | return connection
|
205 | 233 |
|
206 | 234 |
|
| 235 | +class DatabaseApiAsyncIntegration(dbapi.DatabaseApiIntegration): |
| 236 | + async def wrapped_connection( |
| 237 | + self, |
| 238 | + connect_method: typing.Callable[..., typing.Any], |
| 239 | + args: typing.Tuple[typing.Any, typing.Any], |
| 240 | + kwargs: typing.Dict[typing.Any, typing.Any], |
| 241 | + ): |
| 242 | + """Add object proxy to connection object.""" |
| 243 | + base_cursor_factory = kwargs.pop("cursor_factory", None) |
| 244 | + new_factory_kwargs = {"db_api": self} |
| 245 | + if base_cursor_factory: |
| 246 | + new_factory_kwargs["base_factory"] = base_cursor_factory |
| 247 | + kwargs["cursor_factory"] = _new_cursor_async_factory( |
| 248 | + **new_factory_kwargs |
| 249 | + ) |
| 250 | + connection = await connect_method(*args, **kwargs) |
| 251 | + self.get_connection_attributes(connection) |
| 252 | + return connection |
| 253 | + |
| 254 | + |
207 | 255 | class CursorTracer(dbapi.CursorTracer):
|
208 | 256 | def get_operation_name(self, cursor, args):
|
209 | 257 | if not args:
|
@@ -259,3 +307,36 @@ def callproc(self, *args, **kwargs):
|
259 | 307 | )
|
260 | 308 |
|
261 | 309 | return TracedCursorFactory
|
| 310 | + |
| 311 | + |
| 312 | +def _new_cursor_async_factory( |
| 313 | + db_api=None, base_factory=None, tracer_provider=None |
| 314 | +): |
| 315 | + if not db_api: |
| 316 | + db_api = DatabaseApiAsyncIntegration( |
| 317 | + __name__, |
| 318 | + Psycopg3Instrumentor._DATABASE_SYSTEM, |
| 319 | + connection_attributes=Psycopg3Instrumentor._CONNECTION_ATTRIBUTES, |
| 320 | + version=__version__, |
| 321 | + tracer_provider=tracer_provider, |
| 322 | + ) |
| 323 | + base_factory = base_factory or pg_async_cursor |
| 324 | + _cursor_tracer = CursorTracer(db_api) |
| 325 | + |
| 326 | + class TracedCursorAsyncFactory(base_factory): |
| 327 | + async def execute(self, *args, **kwargs): |
| 328 | + return await _cursor_tracer.traced_execution( |
| 329 | + self, super().execute, *args, **kwargs |
| 330 | + ) |
| 331 | + |
| 332 | + async def executemany(self, *args, **kwargs): |
| 333 | + return await _cursor_tracer.traced_execution( |
| 334 | + self, super().executemany, *args, **kwargs |
| 335 | + ) |
| 336 | + |
| 337 | + async def callproc(self, *args, **kwargs): |
| 338 | + return await _cursor_tracer.traced_execution( |
| 339 | + self, super().callproc, *args, **kwargs |
| 340 | + ) |
| 341 | + |
| 342 | + return TracedCursorAsyncFactory |
0 commit comments