Source code for django_mysql.test.utils

from __future__ import annotations

import uuid
from functools import wraps
from types import TracebackType
from typing import Any

from django.db import connections
from django.db.utils import DEFAULT_DB_ALIAS


class override_mysql_variables:
    """
    Temporarily override MySQL variables on one Django database connection.

    Based on Django's override_settings, but for connection settings. Give a
    connection alias in ``using`` and ``variable=value`` pairs to set on that
    connection. The old values are kept MySQL-side in session variables of
    the form ``@overridden_<prefix>_<variable>`` and are written back when the
    override is disabled.

    Acts as either a decorator or a context manager. As a decorator it takes
    a function and returns a wrapped function, or takes a
    ``unittest.TestCase`` subclass and wraps each of its callable ``test_*``
    attributes in place. As a context manager it is used with the ``with``
    statement. In either case :meth:`enable` runs before, and :meth:`disable`
    runs after, the function/block is executed.
    """

    def __init__(self, using: str = DEFAULT_DB_ALIAS, **options: str | int) -> None:
        """
        Store the connection alias and the variables to override.

        :param using: alias of the connection in ``django.db.connections``.
        :param options: MySQL variable names mapped to their temporary values.
        """
        self.db = using
        self.options = options
        # Per-instance prefix for the MySQL user variables that hold the old
        # values. ``UUID.hex`` contains no dashes, so it is sliced directly.
        self.prefix: str = uuid.uuid1().hex[:16]

    def __enter__(self) -> None:
        self.enable()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        # Always restore; returning None lets any exception propagate.
        self.disable()

    def __call__(self, test_func: Any) -> Any:
        """
        Decorate a function or a ``unittest.TestCase`` subclass.

        :param test_func: a callable to wrap, or a ``TestCase`` subclass whose
            test methods should be wrapped.
        :returns: the same class object (modified in place) when given a
            class, otherwise a new function that runs ``test_func`` inside
            this override.
        :raises TypeError: if given a class that is not a ``TestCase``
            subclass.
        """
        from unittest import TestCase

        if isinstance(test_func, type):
            if not issubclass(test_func, TestCase):
                raise TypeError(
                    f"{self.__class__.__name__} only works with TestCase classes."
                )
            self.wrap_class(test_func)
            return test_func

        @wraps(test_func)
        def inner(*args: Any, **kwargs: Any) -> Any:
            with self:
                return test_func(*args, **kwargs)

        return inner

    def wrap_class(self, klass: type[Any]) -> None:
        """
        Wrap every callable ``test_*`` attribute of ``klass`` in place.

        Each method gets its own fresh instance of this class, built from the
        same alias and options, so every wrapped method has its own prefix.

        :param klass: the class whose test methods are replaced by wrappers.
        """
        kwargs: dict[str, Any] = {"using": self.db, **self.options}

        for name in dir(klass):
            if not name.startswith("test_"):
                continue

            method = getattr(klass, name)
            if not callable(method):
                # A data attribute that merely shares the ``test_`` prefix.
                # Wrapping it would turn it into a function, which unittest
                # would then collect and fail on.
                continue

            # Reconstruct self over and over on each method
            wrapped = self.__class__(**kwargs)(method)
            setattr(klass, name, wrapped)

    def enable(self) -> None:
        """
        Save each variable's current value and apply its override.

        If a statement fails part-way through, the variables that were
        already overridden are restored before the error is re-raised.
        Without this they would stay overridden, because ``__exit__`` is not
        called when ``__enter__`` raises.
        """
        applied: list[str] = []
        with connections[self.db].cursor() as cursor:
            try:
                for key, value in self.options.items():
                    cursor.execute(
                        """SET @overridden_{prefix}_{name} = @@{name}, @@{name} = %s """.format(
                            prefix=self.prefix, name=key
                        ),
                        (value,),
                    )
                    # Recorded only after the SET completed without raising.
                    applied.append(key)
            except Exception:
                self._restore(cursor, applied)
                raise

    def disable(self) -> None:
        """Write the saved value of every overridden variable back."""
        with connections[self.db].cursor() as cursor:
            self._restore(cursor, list(self.options))

    def _restore(self, cursor: Any, names: list[str]) -> None:
        """Restore ``names`` from their saved copies and clear those copies."""
        for key in names:
            cursor.execute(
                """SET @@{name} = @overridden_{prefix}_{name}, @overridden_{prefix}_{name} = NULL """.format(
                    name=key, prefix=self.prefix
                )
            )