Source code for my_application.repositories.CompetitionRepository


# Copyright 2020 Nedeljko Radulovic, Dihia Boulegane, Albert Bifet
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, DateTime, ForeignKey, func
from sqlalchemy.orm import relationship
from sqlalchemy import UniqueConstraint
from werkzeug.security import generate_password_hash, check_password_hash
from sqlalchemy.types import Integer, String, Boolean
from datetime import datetime
from sqlalchemy import and_

_BASE = declarative_base()


class Competition(_BASE):
    """
    ORM model for the ``competition`` table.

    Each competition references one row of ``datastream`` through
    ``datastream_id``; ``name`` is unique across competitions.
    """

    __tablename__ = 'competition'

    competition_id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(64))
    datastream_id = Column(Integer, ForeignKey('datastream.datastream_id'))
    initial_batch_size = Column(Integer)
    initial_training_time = Column(Integer)
    batch_size = Column(Integer)
    time_interval = Column(Integer)
    start_date = Column(DateTime)
    end_date = Column(DateTime)
    predictions_time_interval = Column(Integer)
    target_class = Column(String(32))
    file_path = Column(String(255))
    description = Column(String(255))
    code = Column(String(10))
    __table_args__ = (UniqueConstraint('name'), )

    # Many-to-one side of Datastream.competitions.
    datastream = relationship("Datastream", back_populates="competitions")

    def __init__(self, competition_id, name, datastream_id, initial_batch_size,
                 initial_training_time, batch_size, time_interval, start_date,
                 end_date, target_class, file_path, predictions_time_interval,
                 description, code):
        """
        Build a Competition row; every column value is passed explicitly.

        :param competition_id: primary key (autoincremented when left as None)
        :param name: unique competition name (up to 64 characters)
        :param datastream_id: ID of the Datastream the competition is based on
        :param initial_batch_size: value for the ``initial_batch_size`` column
        :param initial_training_time: value for the ``initial_training_time`` column
        :param batch_size: value for the ``batch_size`` column
        :param time_interval: value for the ``time_interval`` column
        :param start_date: datetime at which the competition starts
        :param end_date: datetime at which the competition ends
        :param target_class: name of the target (up to 32 characters)
        :param file_path: value for the ``file_path`` column (up to 255 characters)
        :param predictions_time_interval: value for the ``predictions_time_interval`` column
        :param description: free-text description (up to 255 characters)
        :param code: Competition code (up to 10 characters)
        """
        # Identity.
        self.competition_id = competition_id
        self.name = name
        self.code = code
        self.description = description
        # Data source.
        self.datastream_id = datastream_id
        self.file_path = file_path
        self.target_class = target_class
        # Batching / timing parameters.
        self.initial_batch_size = initial_batch_size
        self.initial_training_time = initial_training_time
        self.batch_size = batch_size
        self.time_interval = time_interval
        self.predictions_time_interval = predictions_time_interval
        # Schedule.
        self.start_date = start_date
        self.end_date = end_date

    def serialize(self):
        """
        Return the competition as a dict keyed by column name.

        ``file_path`` is not included; dates are returned as the values held
        on the instance (not formatted).
        """
        exported = ('competition_id', 'name', 'datastream_id',
                    'initial_batch_size', 'initial_training_time',
                    'batch_size', 'time_interval', 'start_date', 'end_date',
                    'target_class', 'predictions_time_interval',
                    'description', 'code')
        return {field: getattr(self, field) for field in exported}
class Datastream(_BASE):
    """
    ORM model for the ``datastream`` table.

    ``name`` is unique. ``competitions`` lists the competitions that use this
    datastream.
    """

    __tablename__ = 'datastream'

    datastream_id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(64))
    file_path = Column(String(255))
    description = Column(String(255))
    # lazy='dynamic': accessing this attribute yields a query, not a loaded list.
    competitions = relationship("Competition", back_populates='datastream', lazy='dynamic')
    __table_args__ = (UniqueConstraint('name'),)

    def __init__(self, datastream_id, name, description, file_path):
        """
        Build a Datastream row.

        :param datastream_id: primary key (autoincremented when left as None)
        :param name: unique datastream name (up to 64 characters)
        :param description: free-text description (up to 255 characters)
        :param file_path: value for the ``file_path`` column (up to 255 characters)
        """
        self.datastream_id, self.name = datastream_id, name
        self.file_path, self.description = file_path, description

    def serialize(self):
        """Return id, name and description as a dict; ``file_path`` is omitted."""
        return {key: getattr(self, key) for key in ('datastream_id', 'name', 'description')}
class User(_BASE):
    """
    ORM model for the ``USERS`` table.

    Only a Werkzeug hash of the password is stored; ``email`` is unique.
    """

    __tablename__ = 'USERS'

    user_id = Column(Integer, primary_key=True, autoincrement=True)
    first_name = Column(String(32))
    last_name = Column(String(32))
    email = Column(String(32))
    password_hash = Column(String(256))
    role = Column(String(32))
    confirmed = Column(Boolean, nullable=False, default=False)
    confirmed_on = Column(DateTime, nullable=True)
    __table_args__ = (UniqueConstraint('email'),)

    def __init__(self, user_id, first_name, last_name, email, password, role, confirmed, confirmed_on):
        """
        Build a USERS row.

        :param user_id: primary key (autoincremented when left as None)
        :param first_name: first name (up to 32 characters)
        :param last_name: last name (up to 32 characters)
        :param email: unique email address (up to 32 characters)
        :param password: plain-text password; only its hash is stored
        :param role: role name (up to 32 characters)
        :param confirmed: whether the registration is confirmed
        :param confirmed_on: datetime of confirmation, or None
        """
        self.user_id = user_id
        self.first_name, self.last_name = first_name, last_name
        self.email = email
        self.role = role
        self.confirmed, self.confirmed_on = confirmed, confirmed_on
        # Store the hash, never the plain-text password.
        self.set_password(password)

    def set_password(self, password):
        """Hash ``password`` with Werkzeug and keep the result in ``password_hash``."""
        hashed = generate_password_hash(password)
        self.password_hash = hashed

    def check_password(self, password):
        """Return True when ``password`` matches the stored hash."""
        stored = self.password_hash
        return check_password_hash(stored, password)

    def serialize(self):
        """Public view of the user; the password hash and confirmation fields are excluded."""
        return dict(
            user_id=self.user_id,
            firstName=self.first_name,
            lastName=self.last_name,
            email=self.email,
            role=self.role,
        )
class Subscription(_BASE):
    """
    ORM model for the ``subscriptions`` table: links a user to a competition.

    The (competition_id, user_id) pair is unique, so a user can subscribe to a
    given competition at most once.
    """

    __tablename__ = 'subscriptions'

    # The attribute is subscription_id but the database column is named "id".
    subscription_id = Column('id', Integer, primary_key=True, autoincrement=True)
    competition_id = Column(Integer, ForeignKey('competition.competition_id'))
    user_id = Column(Integer, ForeignKey('USERS.user_id'))
    # When not supplied, filled with the database's now() at insert time.
    time_create = Column(DateTime, nullable=False, default=func.now())
    # Both backrefs are called "memberships" (Competition.memberships, User.memberships).
    competition = relationship("Competition", backref="memberships")
    user = relationship("User", backref="memberships")
    __table_args__ = (UniqueConstraint('competition_id', 'user_id'),)

    def __init__(self, subscription_id, competition_id, user_id):
        """
        Build a subscriptions row.

        :param subscription_id: primary key (autoincremented when left as None)
        :param competition_id: ID of the competition subscribed to
        :param user_id: ID of the subscribing user
        """
        self.subscription_id, self.competition_id, self.user_id = subscription_id, competition_id, user_id
class BaseRepository():
    """
    Repository base class. Implements the methods to write or delete rows in the table.
    --------------------------------------------------------------------
    insert_one(): inserts one row in the table
    insert_many(): inserts multiple rows in the table
    delete_one(): deletes one row from the table
    cleanup(): closes the session and disposes of the engine

    The write methods are best-effort. On failure they print the error, roll
    the session back and return False. They return True after a successful commit.
    """

    # Kept for backward compatibility; nothing in this class sets it to a non-None value.
    instance = None

    def __init__(self, host, dbname):
        """
        Create the engine, make sure the tables exist, and open a session.

        :param host: URL prefix for ``create_engine``; ``dbname`` is appended to it as-is
        :param dbname: database name appended to ``host`` to form the engine URL
        """
        self.instance = None
        self.engine = create_engine(host + dbname)
        self.sessionmaker = sessionmaker()
        self.sessionmaker.configure(bind=self.engine)
        self.Base = _BASE
        # Creates every table declared on _BASE that does not exist yet.
        self.Base.metadata.create_all(self.engine)
        # Each repository instance works with its own session.
        self.session = self.sessionmaker()

    def insert_one(self, row):
        """
        Add ``row`` to the session and commit.

        :return: True on success, False if the operation failed and was rolled back
        """
        return self._apply_and_commit(self.session.add, (row,))

    def insert_many(self, rows):
        """
        Add every row in ``rows`` and commit them in a single transaction.

        :return: True on success, False if the operation failed and was rolled
            back (then none of the rows are committed)
        """
        return self._apply_and_commit(self.session.add, rows)

    def delete_one(self, row):
        """
        Delete ``row`` and commit.

        :return: True on success, False if the operation failed and was rolled back
        """
        return self._apply_and_commit(self.session.delete, (row,))

    def cleanup(self):
        """Close the session and dispose of the engine's connection pool."""
        self.session.close()
        self.engine.dispose()

    def _apply_and_commit(self, operation, rows):
        """
        Apply ``operation`` to each row, then commit once.

        Any exception is printed and the session is rolled back so it remains
        usable for later calls.
        """
        try:
            for row in rows:
                operation(row)
            self.session.commit()
        except Exception as e:
            print(e)
            self.session.rollback()
            return False
        return True
class CompetitionRepository(BaseRepository):
    """
    Competition repository class. Implements the methods to retrieve competitions by different conditions.
    ---------------------------------------------------------------------------
    get_competition_by_id(): Retrieve competition by its ID
    set_competition_code(): Update the competition code
    get_all_competitions(): Retrieve competitions filtered by status, optionally paginated
    get_competition_by_code(): Retrieve the competition by code
    get_competitions_by_user(): Retrieve the competitions a given user is subscribed to

    Valid ``status`` values for the listing methods: 'all', 'active', 'coming', 'finished'.
    """

    # strftime pattern for the dates returned by the listing methods.
    _DATE_FORMAT = "%Y-%m-%d %H:%M"

    def __init__(self, host, dbname):
        BaseRepository.__init__(self, host, dbname)

    def get_competition_by_id(self, competition_id):
        """
        Retrieve a competition by its primary key.

        :param competition_id: ID of the competition
        :return: the matching Competition, or None if none exists or the query fails
        """
        try:
            return self.session.query(Competition).filter_by(competition_id=competition_id).first()
        except Exception:
            # Keep the session usable after a failed query.
            self.session.rollback()
            return None

    def set_competition_code(self, competition_id, code):
        """
        Set the ``code`` column of one competition and commit.

        :param competition_id: ID of the competition to update
        :param code: new competition code
        :raises: any database error, after rolling the session back
        """
        try:
            self.session.query(Competition).filter_by(competition_id=competition_id).update({"code": code})
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    def get_all_competitions(self, status=None, page=None, step=None):
        """
        List competitions matching ``status``, optionally paginated.

        :param status: one of 'all', 'active', 'coming', 'finished'
        :param page: 1-based page number; only applied together with ``step``
        :param step: page size; when falsy no limit/offset is applied
        :return: ``{'data': [summary, ...], 'total': <count ignoring pagination>}``
        :raises ValueError: if ``status`` is not one of the values above (including None)
        """
        query = self._filter_by_status(self.session.query(Competition), status, datetime.now())
        return self._build_page(query, page, step)

    def get_competition_by_code(self, code):
        """
        Retrieve a competition by its code.

        :param code: competition code
        :return: the matching Competition, or None if none exists or the query fails
        """
        try:
            return self.session.query(Competition).filter_by(code=code).first()
        except Exception:
            self.session.rollback()
            return None

    def get_competitions_by_user(self, user_id, status, page, step):
        """
        List the competitions ``user_id`` is subscribed to, filtered by status.

        :param user_id: ID of the subscribed user
        :param status: one of 'all', 'active', 'coming', 'finished'
        :param page: 1-based page number; only applied together with ``step``
        :param step: page size; when falsy no limit/offset is applied
        :return: ``{'data': [summary, ...], 'total': <count ignoring pagination>}``
        :raises ValueError: if ``status`` is not one of the values above (including None)
        """
        subscriptions = self.session.query(Subscription).filter_by(user_id=user_id).subquery()
        query = self._filter_by_status(self.session.query(Competition), status, datetime.now())
        query = query.join(subscriptions, subscriptions.c.competition_id == Competition.competition_id)
        return self._build_page(query, page, step)

    @staticmethod
    def _filter_by_status(query, status, now):
        """Restrict a Competition query to the requested status relative to ``now``."""
        if status == 'all':
            return query
        if status == 'active':
            return query.filter(and_(Competition.end_date > now, Competition.start_date < now))
        if status == 'coming':
            return query.filter(Competition.start_date >= now)
        if status == 'finished':
            return query.filter(Competition.end_date <= now)
        # str() so that a None status raises ValueError rather than TypeError.
        raise ValueError('Unknown type ' + str(status))

    @staticmethod
    def _paginate(query, page, step):
        """Apply LIMIT ``step`` and, when ``page`` is also given, the matching OFFSET."""
        if step:
            query = query.limit(step)
            if page:
                query = query.offset((page - 1) * step)
        return query

    @classmethod
    def _format_date(cls, value):
        """Format a datetime with _DATE_FORMAT; None is passed through unchanged."""
        return value.strftime(cls._DATE_FORMAT) if value is not None else None

    @classmethod
    def _summarize(cls, competition):
        """Build the listing entry for one competition."""
        return {'name': competition.name,
                'id': competition.competition_id,
                'description': competition.description,
                'start_date': cls._format_date(competition.start_date),
                'end_date': cls._format_date(competition.end_date)}

    def _build_page(self, query, page, step):
        """Execute the paginated query and the total count; roll back and re-raise on DB errors."""
        try:
            rows = self._paginate(query, page, step).all()
            total = query.count()
        except Exception:
            self.session.rollback()
            raise
        return {'data': [self._summarize(r) for r in rows], 'total': total}
class DatastreamRepository(BaseRepository):
    """
    Datastream repository class. Implements the methods to retrieve datastreams.
    ---------------------------------------------------------------------------
    get_datastream_by_id(): Retrieve a datastream by its ID
    get_all_datastreams(): List datastreams, optionally paginated
    """

    def __init__(self, host, dbname):
        BaseRepository.__init__(self, host, dbname)

    def get_datastream_by_id(self, datastream_id):
        """
        Retrieve a datastream by its primary key.

        :param datastream_id: ID of the datastream
        :return: the matching Datastream, or None if none exists or the query fails
        """
        try:
            return self.session.query(Datastream).filter_by(datastream_id=datastream_id).first()
        except Exception:
            # Keep the session usable after a failed query.
            self.session.rollback()
            return None

    def get_all_datastreams(self, page=None, step=None):
        """
        List datastreams, optionally paginated.

        :param page: 1-based page number; only applied together with ``step``
        :param step: page size; when falsy all datastreams are returned
        :return: ``{'data': [serialized datastream, ...], 'total': <count ignoring pagination>}``
        :raises: any database error, after rolling the session back
        """
        query = self.session.query(Datastream)
        paged = query
        if step:
            paged = paged.limit(step)
            # OFFSET only makes sense with a page size; (page - 1) * None would fail.
            if page:
                paged = paged.offset((page - 1) * step)
        try:
            data = [row.serialize() for row in paged]
            total = query.count()
        except Exception:
            self.session.rollback()
            raise
        return {'data': data, 'total': total}
class UserRepository(BaseRepository):
    """
    User repository class. Implements the methods to query the USERS table.
    ---------------------------------------------------------------------------
    get_user_by_id(): Retrieve user information based on his ID
    get_user_by_email(): Retrieve user information based on his email
    get_all_users(): List all users
    delete_many(): Delete several users (by email) at the same time
    confirm_user(): Manually confirm user's registration
    """

    def __init__(self, host, dbname):
        BaseRepository.__init__(self, host, dbname)

    def get_user_by_id(self, id):
        """
        Retrieve a user by primary key.

        :param id: user ID (the name shadows the builtin but is kept for keyword callers)
        :return: the matching User, or None if none exists or the query fails
        """
        return self._first_user(user_id=id)

    def get_user_by_email(self, email):
        """
        Retrieve a user by email address.

        :param email: email address
        :return: the matching User, or None if none exists or the query fails
        """
        return self._first_user(email=email)

    def get_all_users(self):
        """
        Return a query over all users.

        The query is lazy and runs when the caller iterates it. None is returned
        only if building the query itself fails.
        """
        try:
            return self.session.query(User)
        except Exception:
            self.session.rollback()
            return None

    def delete_many(self, users):
        """
        Delete every user whose email address is in ``users`` and commit.

        :param users: collection of email addresses
        :return: number of rows deleted
        :raises: any database error, after rolling the session back
        """
        if not users:
            # Nothing to delete; avoid issuing an empty IN () query.
            return 0
        try:
            # Filter first: Query.delete() executes immediately on whatever
            # criteria the query already has.
            deleted = (self.session.query(User)
                       .filter(User.email.in_(users))
                       .delete(synchronize_session='fetch'))
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
        return deleted

    def confirm_user(self, user):
        """
        Mark ``user`` as confirmed now and commit.

        NOTE(review): ``user`` must belong to this repository's session for the
        commit below to persist the change.

        :param user: User instance to confirm
        :raises: any database error, after rolling the session back
        """
        user.confirmed = True
        user.confirmed_on = datetime.now()
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    def _first_user(self, **criteria):
        """Return the first User matching the ``filter_by`` criteria, or None on error."""
        try:
            return self.session.query(User).filter_by(**criteria).first()
        except Exception:
            self.session.rollback()
            return None
class SubscriptionRepository(BaseRepository):
    """
    Subscription repository class. Implements the methods to retrieve subscriptions by different conditions.
    ---------------------------------------------------------------------------
    get_competition_subscribers(): Get the subscriptions (one per subscribed user) of a competition
    check_subscription(): Check if a given user is subscribed to a given competition
    get_subscription(): Retrieve the subscription for a user
    """

    def __init__(self, host, dbname):
        BaseRepository.__init__(self, host, dbname)

    def get_competition_subscribers(self, competition_id):
        """
        Return a lazy query of Subscription rows for ``competition_id``.

        Each row exposes the subscriber through ``user_id`` / ``user``. None is
        returned only if building the query itself fails.
        """
        try:
            return self.session.query(Subscription).filter_by(competition_id=competition_id)
        except Exception:
            self.session.rollback()
            return None

    def check_subscription(self, competition_id, user_id):
        """
        Tell whether ``user_id`` is subscribed to ``competition_id``.

        :return: True if subscribed, False if not, None if the query failed
        """
        try:
            # first() fetches at most one row instead of loading every match.
            return self._subscription_query(competition_id, user_id).first() is not None
        except Exception:
            self.session.rollback()
            return None

    def get_subscription(self, competition_id, user_id):
        """
        Retrieve the subscription of ``user_id`` to ``competition_id``.

        :return: the Subscription, or None if none exists or the query fails
        """
        try:
            return self._subscription_query(competition_id, user_id).first()
        except Exception:
            self.session.rollback()
            return None

    def _subscription_query(self, competition_id, user_id):
        """Query for the subscription of ``user_id`` to ``competition_id``."""
        return self.session.query(Subscription).filter_by(competition_id=competition_id, user_id=user_id)