# encoding=utf-8 import base64 import calendar import contextlib import datetime import hmac import json import os import random import subprocess from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.orm.session import make_transient from sqlalchemy.orm.util import object_state from sqlalchemy.sql.expression import insert, select, delete, exists from sqlalchemy.sql.functions import func from sqlalchemy.sql.schema import Column, ForeignKey from sqlalchemy.sql.sqltypes import String, LargeBinary, Float, Boolean, Integer, \ DateTime from sqlalchemy.sql.type_api import TypeDecorator from terroroftinytown.client import VERSION from terroroftinytown.client.alphabet import str_to_int, int_to_str from terroroftinytown.tracker.errors import NoItemAvailable, FullClaim, UpdateClient, \ InvalidClaim, NoResourcesAvailable from terroroftinytown.tracker.stats import Stats # These overrides for major api changes MIN_VERSION_OVERRIDE = 55 # for terroroftinytown.client MIN_CLIENT_VERSION_OVERRIDE = 7 # for terrofoftinytown-client-grab/pipeline.py DEADMAN_MAX_ERROR_REPORTS = 4000 DEADMAN_MAX_RESULTS = 40000000 Base = declarative_base() Session = sessionmaker() @contextlib.contextmanager def new_session(): session = Session() try: yield session session.commit() except: session.rollback() raise finally: session.close() class JsonType(TypeDecorator): impl = String def process_bind_param(self, value, engine): return json.dumps(value) def process_result_value(self, value, engine): if value: return json.loads(value) else: return None class GlobalSetting(Base): __tablename__ = 'global_settings' key = Column(String, primary_key=True) value = Column(JsonType) AUTO_DELETE_ERROR_REPORTS = 'auto_delete_error_reports' @classmethod def set_value(cls, key, value): with new_session() as session: setting = session.query(GlobalSetting).filter_by(key=key).first() if setting: setting.value = value else: setting = GlobalSetting(key=key, value=value) session.add(setting) @classmethod def get_value(cls, key): with new_session() as session: setting = session.query(GlobalSetting).filter_by(key=key).first() if setting: return setting.value class User(Base): '''User accounts that manager the tracker.''' __tablename__ = 'users' username = Column(String, primary_key=True) salt = Column(LargeBinary, nullable=False) hash = Column(LargeBinary, nullable=False) def set_password(self, password): self.salt = new_salt() self.hash = make_hash(password, self.salt) def check_password(self, password): test_hash = make_hash(password, self.salt) return compare_digest(self.hash, test_hash) def get_token(self): return make_hash(self.username, self.salt) def check_token(self, test_token): token = self.get_token() return compare_digest(token, test_token) @classmethod def no_users_exist(cls): with new_session() as session: user = session.query(User).first() return user is None @classmethod def is_user_exists(cls, username): with new_session() as session: user = session.query(User).filter_by(username=username).first() return user is not None @classmethod def all_usernames(cls): with new_session() as session: users = session.query(User.username) return list([user.username for user in users]) @classmethod def save_new_user(cls, username, password): with new_session() as session: user = User(username=username) user.set_password(password) session.add(user) @classmethod def check_account(cls, username, password): with new_session() as session: user = session.query(User).filter_by(username=username).first() if user: return user.check_password(password) @classmethod def update_password(cls, username, password): with new_session() as session: user = session.query(User).filter_by(username=username).first() user.set_password(password) @classmethod def delete_user(cls, username): with new_session() as session: session.query(User).filter_by(username=username).delete() @classmethod def get_user_token(cls, username): with new_session() as session: return session.query(User).filter_by(username=username)\ .first().get_token() @classmethod def check_account_session(cls, username, token): with new_session() as session: user = session.query(User).filter_by(username=username).first() if not user: return return user.check_token(token) class Project(Base): '''Project settings.''' __tablename__ = 'projects' name = Column(String, primary_key=True) min_version = Column(Integer, default=VERSION, nullable=False) min_client_version = Column(Integer, default=MIN_CLIENT_VERSION_OVERRIDE, nullable=False) alphabet = Column(String, default='0123456789abcdefghijklmnopqrstuvwxyz' 'ABCDEFGHIJKLMNOPQRSTUVWXYZ', nullable=False) url_template = Column(String, default='http://example.com/{shortcode}', nullable=False) request_delay = Column(Float, default=0.5, nullable=False) redirect_codes = Column(JsonType, default=[301, 302, 303, 307], nullable=False) no_redirect_codes = Column(JsonType, default=[404], nullable=False) unavailable_codes = Column(JsonType, default=[200]) banned_codes = Column(JsonType, default=[403, 420, 429]) body_regex = Column(String) location_anti_regex = Column(String) method = Column(String, default='head', nullable=False) enabled = Column(Boolean, default=False) autoqueue = Column(Boolean, default=False) num_count_per_item = Column(Integer, default=50, nullable=False) max_num_items = Column(Integer, default=100, nullable=False) lower_sequence_num = Column(Integer, default=0, nullable=False) autorelease_time = Column(Integer, default=60 * 30) def to_dict(self, with_shortcode=False): ans = {x.key:x.value for x in object_state(self).attrs} if with_shortcode: ans['lower_shortcode'] = self.lower_shortcode() return ans def lower_shortcode(self): return int_to_str(self.lower_sequence_num, self.alphabet) @classmethod def all_project_names(cls): with new_session() as session: projects = session.query(Project.name) return list([project.name for project in projects]) @classmethod def all_project_infos(cls): with new_session() as session: projects = session.query(Project) return list([project.to_dict(with_shortcode=True) for project in projects]) @classmethod def new_project(cls, name): with new_session() as session: project = Project(name=name) session.add(project) @classmethod def get_plain(cls, name): with new_session() as session: project = session.query(Project).filter_by(name=name).first() make_transient(project) return project @classmethod @contextlib.contextmanager def get_session_object(cls, name): with new_session() as session: project = session.query(Project).filter_by(name=name).first() yield project @classmethod def delete_project(cls, name): # FIXME: need to cascade the deletes with new_session() as session: session.query(Project).filter_by(name=name).delete() class Item(Base): __tablename__ = 'items' id = Column(Integer, primary_key=True) project_id = Column(Integer, ForeignKey('projects.name'), nullable=False) project = relationship('Project') lower_sequence_num = Column(Integer, nullable=False) upper_sequence_num = Column(Integer, nullable=False) datetime_claimed = Column(DateTime) tamper_key = Column(String) username = Column(String) ip_address = Column(String) def to_dict(self, with_shortcode=False): ans = {x.key:x.value for x in object_state(self).attrs} ans.update({ 'project': self.project.to_dict(), 'datetime_claimed': calendar.timegm(self.datetime_claimed.utctimetuple()) if self.datetime_claimed else None, }) if with_shortcode: ans['lower_shortcode'] = int_to_str(self.lower_sequence_num, self.project.alphabet) ans['upper_shortcode'] = int_to_str(self.upper_sequence_num, self.project.alphabet) return ans @classmethod def get_items(cls, project_id): with new_session() as session: rows = session.query(Item).filter_by(project_id=project_id).order_by(Item.datetime_claimed) return list([item.to_dict(with_shortcode=True) for item in rows]) @classmethod def add_items(cls, project_id, sequence_list): with new_session() as session: query = insert(Item) query_args = [] for lower_num, upper_num in sequence_list: query_args.append({ 'project_id': project_id, 'lower_sequence_num': lower_num, 'upper_sequence_num': upper_num, }) session.execute(query, query_args) @classmethod def delete(cls, item_id): with new_session() as session: session.query(Item).filter_by(id=item_id).delete() @classmethod def release(cls, item_id): with new_session() as session: item = session.query(Item).filter_by(id=item_id).first() item.datetime_claimed = None item.ip_address = None item.username = None @classmethod def release_all(cls, project_id=None, old_date=None): with new_session() as session: query = session.query(Item) if project_id: query = query.filter_by(project_id=project_id) if old_date: query = query.filter(Item.datetime_claimed <= old_date) query.update({ 'datetime_claimed': None, 'ip_address': None, 'username': None, }) @classmethod def release_old(cls, project_id=None, autoqueue_only=False): with new_session() as session: # we could probably write this in one query # but it would be non-portable across SQL dialects projects = session.query(Project) \ .filter(Project.autorelease_time > 0) if project_id: projects = projects.filter_by(name=project_id) if autoqueue_only: projects = projects.filter_by(autoqueue=True) for project in projects: min_time = datetime.datetime.utcnow() - datetime.timedelta(seconds=project.autorelease_time) query = session.query(Item) \ .filter(Item.datetime_claimed <= min_time, Item.project == project) query.update({ 'datetime_claimed': None, 'ip_address': None, 'username': None, }) @classmethod def delete_all(cls, project_id): with new_session() as session: session.query(Item).filter_by(project_id=project_id).delete() class BlockedUser(Base): '''Blocked IP addresses or usernames.''' __tablename__ = 'blocked_users' username = Column(String, primary_key=True) note = Column(String) @classmethod def block_username(cls, username, note=None): with new_session() as session: session.add(BlockedUser(username=username, note=note)) @classmethod def unblock_username(cls, username): with new_session() as session: session.query(BlockedUser).filter_by(username=username).delete() @classmethod def is_username_blocked(cls, *username): with new_session() as session: query = select([BlockedUser.username])\ .where(BlockedUser.username.in_(username)) result = session.execute(query).first() if result: return True @classmethod def all_blocked_usernames(cls): with new_session() as session: names = session.query(BlockedUser.username) return list([row[0] for row in names]) class Result(Base): '''Unshortend URL.''' __tablename__ = 'results' id = Column(Integer, primary_key=True) project_id = Column(Integer, ForeignKey('projects.name'), nullable=False, index=True) project = relationship('Project') shortcode = Column(String, nullable=False) url = Column(String, nullable=False) encoding = Column(String, nullable=False) datetime = Column(DateTime) @classmethod def has_results(cls): with new_session() as session: result = session.query(Result.id).first() return bool(result) @classmethod def get_count(cls): with new_session() as session: return (session.query(func.max(Result.id)).scalar() or 0) \ - (session.query(func.min(Result.id)).scalar() or 0) @classmethod def get_results(cls, offset_id=0, limit=1000, project_id=None): with new_session() as session: if int(offset_id) == 0: offset_id = session.query(func.max(Result.id)).scalar() or 0 rows = session.query( Result.id, Result.project_id, Result.shortcode, Result.url, Result.encoding, Result.datetime ) \ .filter(Result.id <= int(offset_id)) if project_id is not None and project_id != 'None': rows = rows.filter(Result.project_id == project_id) alphabet = Project.get_plain(project_id).alphabet else: alphabet = None rows = rows.order_by(Result.id.desc()).limit(int(limit)) for row in rows: ans = { 'id': row[0], 'project_id': row[1], 'shortcode': row[2], 'url': row[3], 'encoding': row[4], 'datetime': row[5] } if alphabet: ans['seq_num'] = str_to_int(row[2], alphabet) yield ans class ErrorReport(Base): '''Error report.''' __tablename__ = 'error_reports' id = Column(Integer, primary_key=True) item_id = Column(Integer, ForeignKey('items.id'), nullable=False) item = relationship('Item') message = Column(String, nullable=False) datetime = Column(DateTime, nullable=False, default=datetime.datetime.utcnow) def to_dict(self): ans = {x.key:x.value for x in object_state(self).attrs} ans.update({ 'project': self.item.project_id if self.item else None, }) return ans @classmethod def get_count(cls): with new_session() as session: min_id = session.query(func.min(ErrorReport.id)).scalar() or 0 max_id = session.query(func.max(ErrorReport.id)).scalar() or 0 return max_id - min_id @classmethod def all_reports(cls, limit=100, offset_id=None, project_id=None): with new_session() as session: reports = session.query(ErrorReport) if offset_id: reports = reports.filter(ErrorReport.id > offset_id) if project_id is not None and project_id != 'None': reports = reports.join(Item).filter(Item.project_id == project_id) reports = reports.limit(limit) return list(report.to_dict() for report in reports) @classmethod def delete_all(cls): with new_session() as session: session.query(ErrorReport.id).delete() @classmethod def delete_one(cls, report_id): with new_session() as session: query = delete(ErrorReport).where(ErrorReport.id == report_id) session.execute(query) @classmethod def delete_orphaned(cls): with new_session() as session: subquery = select([ErrorReport.id])\ .where(ErrorReport.item_id == Item.id)\ .limit(1) query = delete(ErrorReport).where(~exists(subquery)) session.execute(query) class Budget(object): '''Budget calculator to help manage available items. Warning: This class assumes the application is single instance. ''' projects = {} @classmethod def calculate_budgets(cls): cls.projects = {} with new_session() as session: query = session.query( Project.name, Project.max_num_items, Project.min_client_version, Project.min_version, Project.max_num_items ).filter_by(enabled=True) for row in query: (name, max_num_items, min_client_version, min_version, max_num_items) = row cls.projects[name] = { 'max_num_items': max_num_items, 'min_client_version': min_client_version, 'min_version': min_version, 'items': 0, 'claims': 0, 'ip_addresses': set(), } query = session.query(Item.project_id, Item.ip_address) for row in query: project_id, ip_address = row if project_id not in cls.projects: continue project_info = cls.projects[project_id] project_info['items'] += 1 if ip_address: project_info['ip_addresses'].add(ip_address) project_info['claims'] += 1 @classmethod def get_available_project(cls, ip_address, version, client_version): project_names = list(cls.projects.keys()) random.shuffle(project_names) for project_id in project_names: project_info = cls.projects[project_id] if ip_address not in project_info['ip_addresses'] and \ version >= project_info['min_version'] and \ client_version >= project_info['min_client_version'] and \ project_info['claims'] <= project_info['items'] and \ project_info['claims'] < project_info['max_num_items']: return (project_id, project_info['claims'], project_info['items'], project_info['max_num_items']) @classmethod def is_client_outdated(cls, version, client_version): if not cls.projects: return max_version = max(project['min_version'] for project in cls.projects.values()) max_client_version = max(project['min_client_version'] for project in cls.projects.values()) if version < max_version or client_version < max_client_version: return max_version, max_client_version @classmethod def is_claims_full(cls, ip_address): return cls.projects and all(ip_address in project['ip_addresses'] for project in cls.projects.values()) @classmethod def check_out(cls, project_id, ip_address, new_item=False): assert project_id assert ip_address project_info = cls.projects[project_id] project_info['claims'] += 1 if new_item: project_info['items'] += 1 project_info['ip_addresses'].add(ip_address) @classmethod def check_in(cls, project_id, ip_address): assert project_id assert ip_address if project_id not in cls.projects: # Project was recently disabled but the job hasn't come back # yet. Should be safe to ignore. return project_info = cls.projects[project_id] project_info['claims'] -= 1 project_info['items'] -= 1 project_info['ip_addresses'].remove(ip_address) def make_hash(plaintext, salt): key = salt msg = plaintext.encode('ascii') # Yes, I know MD5 is bad but it was the silent default at the time return hmac.new(key, msg, digestmod='MD5').digest() def new_salt(): return os.urandom(16) def new_tamper_key(): return base64.b16encode(os.urandom(16)).decode('ascii') def deadman_checks(): if ErrorReport.get_count() > DEADMAN_MAX_ERROR_REPORTS: return '