import os
import subprocess
import time
from collections import defaultdict
from queue import Empty, Queue
from threading import Lock, Thread
from weakref import WeakKeyDictionary

import django
from django.db import DEFAULT_DB_ALIAS
from django.db import connection as default_connection
from django.db import connections

class WeightedAverageRate:
    """
    Adapted from percona-toolkit - provides a weighted average counter to keep
    at a certain rate of activity (row iterations etc.).
    """

    def __init__(self, target_t, weight=0.75):
        """
        target_t - Target time for t in update()
        weight - Weight of previous n/t values
        """
        self.target_t = target_t
        self.avg_n = 0.0
        self.avg_t = 0.0
        self.weight = weight

    def update(self, n, t):
        """
        Update weighted average rate.  Param n is generic; it's how many of
        whatever the caller is doing (rows, checksums, etc.).  Param t is how
        long this n took, in seconds (hi-res or not).

        Parameters:
            n - Number of operations (rows, etc.)
            t - Amount of time in seconds that n took

        Returns:
            n adjusted to meet target_t based on weighted decaying avg rate
        """
        if self.avg_n and self.avg_t:
            # Decay the running totals before folding in the new sample.
            self.avg_n = (self.avg_n * self.weight) + n
            self.avg_t = (self.avg_t * self.weight) + t
        else:
            # No usable history yet (either total is zero): start from this
            # sample alone.
            self.avg_n = n
            self.avg_t = t

        new_n = int(self.avg_rate * self.target_t)
        return new_n

    @property
    def avg_rate(self):
        """
        Current weighted average rate, i.e. avg_n / avg_t.
        """
        try:
            return self.avg_n / self.avg_t
        except ZeroDivisionError:
            # Assume a small amount of time, not 0
            return self.avg_n / 0.001

class StopWatch:
    """
    Context manager for timing a block

    On entry ``start_time`` is recorded; on exit ``end_time`` and
    ``total_time`` (their difference, in seconds) are set on the instance.
    """

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, *args, **kwargs):
        # Returns None, so an exception raised inside the block propagates.
        self.end_time = time.time()
        self.total_time = self.end_time - self.start_time

def format_duration(total_seconds):
    """
    Render a number of seconds compactly, e.g. 3661 -> "1h1m1s".

    The hours part appears only when it is positive, the minutes part
    whenever hours or minutes are non-zero, and the seconds part always.
    """
    whole_hours = total_seconds // 3600
    whole_minutes = (total_seconds % 3600) // 60
    leftover_seconds = total_seconds % 60

    text = str(leftover_seconds) + "s"
    if whole_hours or whole_minutes:
        text = str(whole_minutes) + "m" + text
    if whole_hours > 0:
        text = str(whole_hours) + "h" + text
    return text

if django.VERSION >= (3, 0):

    def connection_is_mariadb(connection):
        """
        Return whether the given connection is to a MariaDB server, using
        the connection's own ``mysql_is_mariadb`` attribute.
        """
        return connection.vendor == "mysql" and connection.mysql_is_mariadb


else:

    # Per-connection memo of the server check below; weak keys so cached
    # entries do not keep connection objects alive.
    _is_mariadb_cache = WeakKeyDictionary()

    def connection_is_mariadb(connection):
        """
        Return whether the given connection is to a MariaDB server, by
        looking for "MariaDB" in the server info string. The result is
        cached per connection.
        """
        if connection.vendor != "mysql":
            return False

        # NOTE(review): presumably default_connection is a proxy object, so
        # it is swapped for the real default connection before being used
        # as a cache key - confirm.
        if connection is default_connection:
            connection = connections[DEFAULT_DB_ALIAS]

        try:
            return _is_mariadb_cache[connection]
        except KeyError:
            with connection.temporary_connection():
                server_info = connection.connection.get_server_info()
            is_mariadb = "MariaDB" in server_info
            _is_mariadb_cache[connection] = is_mariadb
            return is_mariadb
def settings_to_cmd_args(settings_dict):
    """
    Copied from django 1.8 MySQL backend DatabaseClient - where the runshell
    commandline creation has been extracted and made callable like so.

    settings_dict - dict with "NAME", "USER", "PASSWORD", "HOST", "PORT" and
        "OPTIONS" keys. Values in "OPTIONS" ("db", "user", "passwd", "host",
        "port", "ssl" -> "ca", "read_default_file") take precedence over the
        corresponding top-level values.

    Returns the argument list for the command, starting with "mysql".
    """
    args = ["mysql"]
    db = settings_dict["OPTIONS"].get("db", settings_dict["NAME"])
    user = settings_dict["OPTIONS"].get("user", settings_dict["USER"])
    passwd = settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"])
    host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"])
    port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"])
    cert = settings_dict["OPTIONS"].get("ssl", {}).get("ca")
    defaults_file = settings_dict["OPTIONS"].get("read_default_file")
    # Seems to be no good way to set sql_mode with CLI.

    if defaults_file:
        args += ["--defaults-file=%s" % defaults_file]
    if user:
        args += ["--user=%s" % user]
    if passwd:
        # NOTE: the password ends up as a plain command line argument.
        args += ["--password=%s" % passwd]
    if host:
        # A host containing a slash is treated as a unix socket path.
        if "/" in host:
            args += ["--socket=%s" % host]
        else:
            args += ["--host=%s" % host]
    if port:
        args += ["--port=%s" % port]
    if cert:
        args += ["--ssl-ca=%s" % cert]
    if db:
        args += [db]
    return args


# Memo of have_program() results, keyed by program name.
programs_memo = {}


def have_program(program_name):
    """
    Return True if ``which program_name`` exits with status 0. The answer is
    remembered in programs_memo, so each program is only looked up once.
    """
    if program_name not in programs_memo:
        # Only the exit status matters; the output is discarded rather than
        # sent to a pipe nobody reads.
        status = subprocess.call(["which", program_name], stdout=subprocess.DEVNULL)
        programs_memo[program_name] = status == 0

    return programs_memo[program_name]
def pt_fingerprint(query):
    """
    Takes a query (in a string) and returns its 'fingerprint'

    Raises OSError if pt-fingerprint does not appear to be installed.
    """
    if not have_program("pt-fingerprint"):  # pragma: no cover
        raise OSError("pt-fingerprint doesn't appear to be installed")

    # Hand the query to the shared background thread and block until its
    # answer arrives on the output queue.
    thread = PTFingerprintThread.get_thread()
    thread.in_queue.put(query)
    return thread.out_queue.get()
class PTFingerprintThread(Thread):
    """
    Class for a singleton background thread to pass queries to pt-fingerprint
    and get their fingerprints back. This is done because the process launch
    time is relatively expensive and it's useful to be able to fingerprint
    queries quickly.

    The get_thread() class method returns the singleton thread - either
    instantiating it or returning the existing one.

    The thread launches pt-fingerprint with subprocess and then takes queries
    from an input queue, passes them to the subprocess and returns the
    fingerprint to an output queue. If it receives no queries in
    PROCESS_LIFETIME seconds, it closes the subprocess and itself - so you
    don't have processes hanging around.
    """

    the_thread = None
    life_lock = Lock()

    PROCESS_LIFETIME = 60.0  # seconds

    @classmethod
    def get_thread(cls):
        """
        Return the singleton thread, creating and starting it if there is
        none at the moment.
        """
        with cls.life_lock:
            if cls.the_thread is None:
                in_queue = Queue()
                out_queue = Queue()
                thread = cls(in_queue, out_queue)
                thread.daemon = True
                thread.in_queue = in_queue
                thread.out_queue = out_queue
                thread.start()
                cls.the_thread = thread

            return cls.the_thread

    def __init__(self, in_queue, out_queue, **kwargs):
        self.in_queue = in_queue
        self.out_queue = out_queue
        super().__init__(**kwargs)

    def run(self):
        # pty is unix/linux only
        import pty  # noqa

        master, slave = pty.openpty()
        proc = subprocess.Popen(
            ["pt-fingerprint"], stdin=subprocess.PIPE, stdout=slave, close_fds=True
        )
        stdin = proc.stdin
        stdout = os.fdopen(master)

        while True:
            try:
                query = self.in_queue.get(timeout=self.PROCESS_LIFETIME)
            except Empty:
                self.life_lock.acquire()
                if (
                    self.__class__.the_thread is self and self.in_queue.qsize()
                ):  # pragma: no cover
                    # We timed out, but there was something put into the
                    # queue since - keep serving. The lock must be given back
                    # here and the loop continued; breaking out would release
                    # the lock a second time below and drop the query.
                    self.life_lock.release()
                    continue
                # Die - life_lock stays held until the_thread is cleared.
                break

            stdin.write(query.encode("utf-8"))
            if not query.endswith(";"):
                stdin.write(b";")
            stdin.write(b"\n")
            stdin.flush()
            fingerprint = stdout.readline()
            self.out_queue.put(fingerprint.strip())

        stdin.close()
        self.__class__.the_thread = None
        self.life_lock.release()

        # With the lock released, reap the child and give back both ends of
        # the pty so repeated thread lifetimes do not leak descriptors.
        proc.wait()
        stdout.close()
        os.close(slave)


def collapse_spaces(string):
    """
    Replace newlines with spaces and squeeze runs of spaces down to one,
    dropping leading and trailing spaces.
    """
    bits = string.replace("\n", " ").split(" ")
    return " ".join(filter(None, bits))


def index_name(model, *field_names, **kwargs):
    """
    Returns the name of the index existing on field_names, or raises KeyError
    if no such index exists.

    The only supported keyword argument is 'using', the database alias to
    query. Raises ValueError if no field names are given, if any other
    keyword argument is passed, or if a named field is not on the model.
    """
    if not field_names:
        raise ValueError("At least one field name required")
    using = kwargs.pop("using", DEFAULT_DB_ALIAS)
    if kwargs:
        raise ValueError("The only supported keyword argument is 'using'")

    existing_fields = {field.name: field for field in model._meta.fields}
    fields = [existing_fields[name] for name in field_names if name in existing_fields]

    if len(fields) != len(field_names):
        unfound_names = set(field_names) - {field.name for field in fields}
        # Sorted so the message does not depend on set ordering.
        raise ValueError("Fields do not exist: " + ",".join(sorted(unfound_names)))
    column_names = tuple(field.column for field in fields)
    list_sql = get_list_sql(column_names)

    # NOTE(review): only rows for the requested columns are selected, so an
    # index that also covers other columns looks the same as one on exactly
    # these columns - confirm this is intended.
    with connections[using].cursor() as cursor:
        cursor.execute(
            """SELECT `INDEX_NAME`, `SEQ_IN_INDEX`, `COLUMN_NAME`
               FROM INFORMATION_SCHEMA.STATISTICS
               WHERE TABLE_SCHEMA = DATABASE() AND
                     TABLE_NAME = %s AND
                     COLUMN_NAME IN {list_sql}
               ORDER BY `INDEX_NAME`, `SEQ_IN_INDEX` ASC
            """.format(
                list_sql=list_sql
            ),
            (model._meta.db_table,) + column_names,
        )
        indexes = defaultdict(list)
        for name, _, column_name in cursor.fetchall():
            indexes[name].append(column_name)

    indexes_by_columns = {tuple(v): k for k, v in indexes.items()}
    try:
        return indexes_by_columns[column_names]
    except KeyError:
        raise KeyError("There is no index on (" + ",".join(field_names) + ")")


def get_list_sql(sequence):
    """
    Return a parenthesised, comma-separated list of "%s" placeholders, one
    per item in sequence.
    """
    return "({})".format(",".join("%s" for x in sequence))


def mysql_connections():
    """
    Yield (alias, connection) for every configured connection whose vendor
    is "mysql", trying the default alias first.
    """
    conn_names = [DEFAULT_DB_ALIAS] + list(set(connections) - {DEFAULT_DB_ALIAS})
    for alias in conn_names:
        connection = connections[alias]
        if connection.vendor != "mysql":
            continue

        yield alias, connection