Verified Commit 458bf98e authored by Benjamin "Ziirish" SANS's avatar Benjamin "Ziirish" SANS
Browse files

limit the number of records to treat by each task in order to limit CPU usage (#310)

parent ae96896f
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@
.. moduleauthor:: Ziirish <hi+burpui@ziirish.me>

"""
# Agent do not need "real" HTTP errors
# Agent does not need "real" HTTP errors
try:
    from werkzeug.exceptions import HTTPException
    WERKZEUG = True
@@ -29,3 +29,8 @@ class BUIserverException(HTTPException):

    def __str__(self):
        return self.description


class TooManyRecordsException(Exception):
    """Raised when there are too many records to treat."""
    pass
+37 −7
Original line number Diff line number Diff line
@@ -135,17 +135,38 @@ class SessionManager(object):
        # don't need to store it since it is not managed anyway
        return True

    def delete_session(self):
    def delete_session(self, commit=True):
        """Remove the session"""
        self.delete_session_by_id(getattr(session, 'sid', None))
        self.delete_session_by_id(getattr(session, 'sid', None), commit)

    def delete_session_by_id(self, id):
    def delete_session_by_id(self, id, commit=True):
        """Remove a session by id"""
        if self.session_managed():
            from .ext.sql import db
            from .models import Session
            try:
                Session.query.filter_by(uuid=id).delete()
                if commit:
                    db.session.commit()
            except:
                if commit:
                    db.session.rollback()

    def bulk_session_delete_by_id(self, bucket):
        """Remove all sessions matching the bucket IDs"""
        if self.session_managed():
            from .ext.sql import db
            from .models import Session
            try:
                Session.query.filter(Session.uuid.in_(bucket)).delete(synchronize_session=False)
                db.session.commit()
            except:
                db.session.rollback()

    def commit(self):
        if self.session_managed():
            from .ext.sql import db
            try:
                db.session.commit()
            except:
                db.session.rollback()
@@ -169,17 +190,26 @@ class SessionManager(object):
            return getattr(session, 'sid', str(uuid.uuid4()))
        return None

    def get_expired_sessions(self):
    def get_expired_sessions(self, maxret=-1, count=False):
        """Return all expired sessions"""
        if self.session_managed():
            from .models import Session
            inactive = self.app.config['SESSION_INACTIVE']
            if inactive and inactive.days > 0:
                limit = datetime.datetime.utcnow() - inactive
                return Session.query.filter(
                query = Session.query.filter(
                    Session.timestamp <= limit
                ).all()
        return []
                )
                if count:
                    return query.count()
                if maxret < 0:
                    return query.all()
                else:
                    return query.limit(maxret)
        return [] if not count else 0

    def get_expired_sessions_count(self):
        return self.get_expired_sessions(count=True)

    def get_user_sessions(self, user):
        """Return all sessions of a given user"""
+43 −17
Original line number Diff line number Diff line
@@ -17,20 +17,18 @@ from celery.schedules import crontab
from celery.utils.log import get_task_logger
from time import gmtime, strftime, sleep

# Try to load modules from our current env first
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), '..'))

from burpui._compat import to_unicode  # noqa
from burpui.config import config  # noqa
from burpui.ext.tasks import celery  # noqa
from burpui.ext.cache import cache  # noqa
from burpui.sessions import session_manager  # noqa
from burpui.engines.server import BUIServer  # noqa
from burpui.exceptions import BUIserverException  # noqa
from burpui.exceptions import BUIserverException, TooManyRecordsException  # noqa
from burpui.api.client import ClientTreeAll  # noqa
from burpui.utils import NOTIF_ERROR

try:
    from .ext.ws import socketio
    from burpui.ext.ws import socketio
    WS_AVAILABLE = True
except ImportError:
    WS_AVAILABLE = False
@@ -187,20 +185,36 @@ def get_all_clients_reports(self):
        release_lock(self.name)


@celery.task(ignore_result=True)
@celery.task(ignore_result=True, max_retries=5,
             autoretry_for=(TooManyRecordsException,), retry_backoff=4)
def cleanup_expired_sessions():
    bucket = []

    def expires(sess):
        ret = session_manager.invalidate_session_by_id(sess.uuid)
        if ret:
            session_manager.delete_session_by_id(sess.uuid)
            bucket.append(sess.uuid)
        return ret
    list(map(expires, session_manager.get_expired_sessions()))

    # remove expired sessions, limit to 10000 per batch
    list(map(expires, session_manager.get_expired_sessions(10000)))
    session_manager.bulk_session_delete_by_id(bucket)

    # if we still have more than 10000 expired session, schedule a new cleanup soon
    # unless we already ran 5 successive cleanups (in which case we will just wait
    # for the next schedule to trigger the task).
    if session_manager.get_expired_sessions_count() >= 10000:
        # we raise an exception so celery knows it has to restart this tasks
        # anytime soon.
        raise TooManyRecordsException

@celery.task(ignore_result=True)

@celery.task(ignore_result=True, max_retries=3,
             autoretry_for=(TooManyRecordsException,), retry_backoff=True)
def cleanup_restore():
    tasks = db.session.query(Task).filter(Task.task == 'perform_restore').filter(datetime.utcnow() > Task.expire).all()
    for rec in tasks:
    bucket = []
    query = Task.query.filter(Task.task == 'perform_restore', Task.expire <= datetime.utcnow())
    for rec in query.limit(100):
        logger.info('Task expired: {}'.format(rec))
        task = perform_restore.AsyncResult(rec.uuid)
        try:
@@ -223,12 +237,15 @@ def cleanup_restore():
                    if os.path.isfile(path):
                        os.unlink(path)
        finally:
            bucket.append(rec.uuid)
            task.revoke()
    try:
                db.session.delete(rec)
        Task.query.filter(Task.uuid.in_(bucket)).delete(synchronize_session=False)
        db.session.commit()
    except:
        db.session.rollback()
            task.revoke()
    if query.count() > 100:
        raise TooManyRecordsException


@celery.task(bind=True)
@@ -236,7 +253,8 @@ def perform_restore(self, client, backup,
                    files, strip, fmt, passwd, server=None, user=None,
                    admin=False, room=None, expire=timedelta(minutes=60)):
    ret = None
    lock_name = '{}-{}'.format(self.name, server)
    # we can have only one restore per server-client at the time
    lock_name = '{}-{}-{}'.format(self.name, server, client)

    # TODO: maybe do something with old_lock someday
    wait_for(lock_name, self.request.id)
@@ -300,6 +318,14 @@ def perform_restore(self, client, backup,
    return ret


@celery.task(bind=True)
def delete_client(self, client, keepconf, delcert, revoke, template, delete, server):
    parser = bui.client.get_parser(agent=server)
    self.update_state(state='STARTED', meta={'step': 'doing'})

    ret = parser.remove_client(client, keepconf, delcert, revoke, template, delete)


@celery.task(bind=True)
def load_all_tree(self, client, backup, server=None, user=None):
    key = 'load_all_tree-{}-{}-{}'.format(client, backup, server)