diff --git a/aepsych/benchmark/pathos_benchmark.py b/aepsych/benchmark/pathos_benchmark.py index f5e4a6643..ab8ae90ae 100644 --- a/aepsych/benchmark/pathos_benchmark.py +++ b/aepsych/benchmark/pathos_benchmark.py @@ -25,7 +25,7 @@ ctx._force_start_method("spawn") # fixes problems with CUDA and fork -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() class PathosBenchmark(Benchmark): diff --git a/aepsych/database/db.py b/aepsych/database/db.py index c9c9cc65d..9bf16bd05 100644 --- a/aepsych/database/db.py +++ b/aepsych/database/db.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import datetime +import io import json import logging import os @@ -19,7 +20,7 @@ from aepsych.config import Config from aepsych.strategy import Strategy from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm.session import close_all_sessions logger = logging.getLogger() @@ -45,33 +46,22 @@ def __init__(self, db_path: Optional[str] = None, update: bool = True) -> None: else: logger.info(f"No DB found at {db_path}, creating a new DB!") - self._engine = self.get_engine() + self._full_db_path = Path(self._db_dir) + self._full_db_path.mkdir(parents=True, exist_ok=True) + self._full_db_path = self._full_db_path.joinpath(self._db_name) - if update and self.is_update_required(): - self.perform_updates() - - def get_engine(self) -> sessionmaker: - """Get the engine for the database. - - Returns: - sessionmaker: The sessionmaker object for the database. - """ - if not hasattr(self, "_engine") or self._engine is None: - self._full_db_path = Path(self._db_dir) - self._full_db_path.mkdir(parents=True, exist_ok=True) - self._full_db_path = self._full_db_path.joinpath(self._db_name) - - self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}") + self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}") - # create the table metadata and tables - tables.Base.metadata.create_all(self._engine) + # create the table metadata and tables + tables.Base.metadata.create_all(self._engine) - # create an ongoing session to be used. Provides a conduit - # to the db so the instantiated objects work properly. - Session = sessionmaker(bind=self.get_engine()) - self._session = Session() + # Create a session to be start and closed on each use + self.session = scoped_session( + sessionmaker(bind=self._engine, expire_on_commit=False) + ) - return self._engine + if update and self.is_update_required(): + self.perform_updates() def delete_db(self) -> None: """Delete the database.""" @@ -106,21 +96,6 @@ def perform_updates(self) -> None: tables.DbParamTable.update(self._engine) tables.DbOutcomeTable.update(self._engine) - @contextmanager - def session_scope(self): - """Provide a transactional scope around a series of operations.""" - Session = sessionmaker(bind=self.get_engine()) - session = Session() - try: - yield session - session.commit() - except Exception as err: - logger.error(f"db session use failed: {err}") - session.rollback() - raise - finally: - session.close() - # @retry(stop_max_attempt_number=8, wait_exponential_multiplier=1.8) def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]: """Execute an arbitrary query written in sql. @@ -132,7 +107,7 @@ def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]: Returns: List[Any]: The results of the query. """ - with self.session_scope() as session: + with self.session() as session: return session.execute(query, vals).all() def get_master_records(self) -> List[tables.DBMasterTable]: @@ -141,7 +116,8 @@ def get_master_records(self) -> List[tables.DBMasterTable]: Returns: List[tables.DBMasterTable]: The list of master records. """ - records = self._session.query(tables.DBMasterTable).all() + with self.session() as session: + records = session.query(tables.DBMasterTable).all() return records def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]: @@ -153,11 +129,12 @@ def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]: Returns: tables.DBMasterTable or None: The master record or None if it doesn't exist. """ - records = ( - self._session.query(tables.DBMasterTable) - .filter(tables.DBMasterTable.unique_id == master_id) - .all() - ) + with self.session() as session: + records = ( + session.query(tables.DBMasterTable) + .filter(tables.DBMasterTable.unique_id == master_id) + .all() + ) if 0 < len(records): return records[0] @@ -259,11 +236,7 @@ def get_params_for(self, master_id: int) -> List[List[tables.DbParamTable]]: raw_record = self.get_raw_for(master_id) if raw_record is not None: - return [ - rec.children_param - for rec in self.get_raw_for(master_id) - if rec is not None - ] + return [raw.children_param for raw in raw_record] return [] @@ -282,14 +255,19 @@ def get_outcomes_for(self, master_id: int) -> List[List[tables.DbParamTable]]: raw_record = self.get_raw_for(master_id) if raw_record is not None: - return [ - rec.children_outcome - for rec in self.get_raw_for(master_id) - if rec is not None - ] + return [raw.children_outcome for raw in raw_record] return [] + @staticmethod + def _add_commit(session, obj): + # Helps guarantee duplicated objects across session can still be written + merged = session.merge(obj) + session.add(merged) + session.commit() + session.refresh(merged) + return merged + def record_setup( self, description: str = None, @@ -312,34 +290,36 @@ def record_setup( Returns: str: The experiment id. """ - self.get_engine() - - master_table = tables.DBMasterTable() - master_table.experiment_description = description - master_table.experiment_name = name - master_table.experiment_id = exp_id if exp_id is not None else str(uuid.uuid4()) - master_table.participant_id = ( - par_id if par_id is not None else str(uuid.uuid4()) - ) - master_table.extra_metadata = extra_metadata - self._session.add(master_table) + with self.session() as session: + master_table = tables.DBMasterTable() + master_table.experiment_description = description + master_table.experiment_name = name + master_table.experiment_id = ( + exp_id if exp_id is not None else str(uuid.uuid4()) + ) + master_table.participant_id = ( + par_id if par_id is not None else str(uuid.uuid4()) + ) + master_table.extra_metadata = extra_metadata + + master_table = self._add_commit(session, master_table) - logger.debug(f"record_setup = [{master_table}]") + logger.debug(f"record_setup = [{master_table}]") - record = tables.DbReplayTable() - record.message_type = "setup" - record.message_contents = request + record = tables.DbReplayTable() + record.message_type = "setup" + record.message_contents = request - if request is not None and "extra_info" in request: - record.extra_info = request["extra_info"] + if request is not None and "extra_info" in request: + record.extra_info = request["extra_info"] - record.timestamp = datetime.datetime.now() - record.parent = master_table - logger.debug(f"record_setup = [{record}]") + record.timestamp = datetime.datetime.now() + record.parent = master_table + logger.debug(f"replay_record_setup = [{record}]") - self._session.add(record) - self._session.commit() + self._add_commit(session, record) + master_table # return the master table if it has a link to the list of child rows # tis needs to be passed into all future calls to link properly return master_table @@ -354,19 +334,19 @@ def record_message( type (str): The type of the message. request (Dict[str, Any]): The request. """ - # create a linked setup table - record = tables.DbReplayTable() - record.message_type = type - record.message_contents = request + with self.session() as session: + # create a linked setup table + record = tables.DbReplayTable() + record.message_type = type + record.message_contents = request - if "extra_info" in request: - record.extra_info = request["extra_info"] + if "extra_info" in request: + record.extra_info = request["extra_info"] - record.timestamp = datetime.datetime.now() - record.parent = master_table + record.timestamp = datetime.datetime.now() + record.parent = master_table - self._session.add(record) - self._session.commit() + self._add_commit(session, record) def record_raw( self, @@ -386,19 +366,19 @@ def record_raw( Returns: tables.DbRawTable: The raw entry. """ - raw_entry = tables.DbRawTable() - raw_entry.model_data = model_data + with self.session() as session: + raw_entry = tables.DbRawTable() + raw_entry.model_data = model_data - if timestamp is None: - raw_entry.timestamp = datetime.datetime.now() - else: - raw_entry.timestamp = timestamp - raw_entry.parent = master_table + if timestamp is None: + raw_entry.timestamp = datetime.datetime.now() + else: + raw_entry.timestamp = timestamp + raw_entry.parent = master_table - raw_entry.extra_data = json.dumps(extra_data) + raw_entry.extra_data = json.dumps(extra_data) - self._session.add(raw_entry) - self._session.commit() + raw_entry = self._add_commit(session, raw_entry) return raw_entry @@ -412,14 +392,14 @@ def record_param( param_name (str): The parameter name. param_value (str): The parameter value. """ - param_entry = tables.DbParamTable() - param_entry.param_name = param_name - param_entry.param_value = param_value + with self.session() as session: + param_entry = tables.DbParamTable() + param_entry.param_name = param_name + param_entry.param_value = param_value - param_entry.parent = raw_table + param_entry.parent = raw_table - self._session.add(param_entry) - self._session.commit() + self._add_commit(session, param_entry) def record_outcome( self, raw_table: tables.DbRawTable, outcome_name: str, outcome_value: float @@ -431,29 +411,31 @@ def record_outcome( outcome_name (str): The outcome name. outcome_value (float): The outcome value. """ - outcome_entry = tables.DbOutcomeTable() - outcome_entry.outcome_name = outcome_name - outcome_entry.outcome_value = outcome_value + with self.session() as session: + outcome_entry = tables.DbOutcomeTable() + outcome_entry.outcome_name = outcome_name + outcome_entry.outcome_value = outcome_value - outcome_entry.parent = raw_table + outcome_entry.parent = raw_table - self._session.add(outcome_entry) - self._session.commit() + self._add_commit(session, outcome_entry) - def record_strat(self, master_table: tables.DBMasterTable, strat: Strategy) -> None: + def record_strat( + self, master_table: tables.DBMasterTable, strat: io.BytesIO + ) -> None: """Record a strategy in the database. Args: master_table (tables.DBMasterTable): The master table. - strat (Strategy): The strategy. + strat (BytesIO): The strategy in buffer form. """ - strat_entry = tables.DbStratTable() - strat_entry.strat = strat - strat_entry.timestamp = datetime.datetime.now() - strat_entry.parent = master_table + with self.session() as session: + strat_entry = tables.DbStratTable() + strat_entry.strat = strat + strat_entry.timestamp = datetime.datetime.now() + strat_entry.parent = master_table - self._session.add(strat_entry) - self._session.commit() + self._add_commit(session, strat_entry) def record_config(self, master_table: tables.DBMasterTable, config: Config) -> None: """Record a config in the database. @@ -462,13 +444,13 @@ def record_config(self, master_table: tables.DBMasterTable, config: Config) -> N master_table (tables.DBMasterTable): The master table. config (Config): The config. """ - config_entry = tables.DbConfigTable() - config_entry.config = config - config_entry.timestamp = datetime.datetime.now() - config_entry.parent = master_table + with self.session() as session: + config_entry = tables.DbConfigTable() + config_entry.config = config + config_entry.timestamp = datetime.datetime.now() + config_entry.parent = master_table - self._session.add(config_entry) - self._session.commit() + self._add_commit(session, config_entry) def summarize_experiments(self) -> pd.DataFrame: """Provides a summary of the experiments contained in the database as a pandas dataframe. diff --git a/aepsych/database/tables.py b/aepsych/database/tables.py index ca0087516..d38038b77 100644 --- a/aepsych/database/tables.py +++ b/aepsych/database/tables.py @@ -49,10 +49,18 @@ class DBMasterTable(Base): extra_metadata = Column(String(4096)) # JSON-formatted metadata - children_replay = relationship("DbReplayTable", back_populates="parent") - children_strat = relationship("DbStratTable", back_populates="parent") - children_config = relationship("DbConfigTable", back_populates="parent") - children_raw = relationship("DbRawTable", back_populates="parent") + children_replay = relationship( + "DbReplayTable", lazy="selectin", join_depth=1, back_populates="parent" + ) + children_strat = relationship( + "DbStratTable", lazy="selectin", join_depth=1, back_populates="parent" + ) + children_config = relationship( + "DbConfigTable", lazy="selectin", join_depth=1, back_populates="parent" + ) + children_raw = relationship( + "DbRawTable", lazy="selectin", join_depth=1, back_populates="parent" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DBMasterTable": @@ -185,7 +193,9 @@ class DbReplayTable(Base): extra_info = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_replay") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_replay" + ) __mapper_args__ = {} @@ -297,7 +307,9 @@ class DbStratTable(Base): strat = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_strat") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_strat" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbStratTable": @@ -356,7 +368,9 @@ class DbConfigTable(Base): config = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_config") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_config" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbConfigTable": @@ -420,9 +434,15 @@ class DbRawTable(Base): extra_data = Column(PickleType(pickler=pickle)) master_table_id = Column(Integer, ForeignKey("master.unique_id")) - parent = relationship("DBMasterTable", back_populates="children_raw") - children_param = relationship("DbParamTable", back_populates="parent") - children_outcome = relationship("DbOutcomeTable", back_populates="parent") + parent = relationship( + "DBMasterTable", lazy="joined", join_depth=1, back_populates="children_raw" + ) + children_param = relationship( + "DbParamTable", lazy="joined", join_depth=1, back_populates="parent" + ) + children_outcome = relationship( + "DbOutcomeTable", lazy="joined", join_depth=1, back_populates="parent" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbRawTable": @@ -527,6 +547,9 @@ def update(db: Any, engine: Engine) -> None: param_value=float(param_value), ) + # Refresh the raw + db_raw_record = db.get_raw_for(master_table.unique_id)[-1] + if isinstance(outcomes, Iterable) and type(outcomes) != str: for j, outcome_value in enumerate(outcomes): if ( @@ -551,23 +574,25 @@ def update(db: Any, engine: Engine) -> None: outcome_value=float(outcomes), ) else: # Raws are already in, so we just need to update it - for master_table in db.get_master_records(): - unique_id = master_table.unique_id - raws = db.get_raw_for(unique_id) - tells = [ - message - for message in db.get_replay_for(unique_id) - if message.message_type == "tell" - ] - - if len(raws) == len(tells): - for raw, tell in zip(raws, tells): - if tell.extra_info is not None and len(tell.extra_info) > 0: - raw.extra_data = tell.extra_info - else: - logger.warning( - f"Tried to update raw table for experiment unique ID {unique_id}, but the number of tells and raws were not the same." - ) + with db.session() as session: + for master_table in db.get_master_records(): + unique_id = master_table.unique_id + raws = db.get_raw_for(unique_id) + tells = [ + message + for message in db.get_replay_for(unique_id) + if message.message_type == "tell" + ] + + if len(raws) == len(tells): + for raw, tell in zip(raws, tells): + if tell.extra_info is not None and len(tell.extra_info) > 0: + raw.extra_data = tell.extra_info + db._add_commit(session, raw) + else: + logger.warning( + f"Tried to update raw table for experiment unique ID {unique_id}, but the number of tells and raws were not the same." + ) @staticmethod def requires_update(engine: Engine) -> bool: @@ -654,7 +679,9 @@ class DbParamTable(Base): param_value = Column(String(50)) iteration_id = Column(Integer, ForeignKey("raw_data.unique_id")) - parent = relationship("DbRawTable", back_populates="children_param") + parent = relationship( + "DbRawTable", lazy="immediate", join_depth=1, back_populates="children_param" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbParamTable": @@ -720,7 +747,9 @@ class DbOutcomeTable(Base): outcome_value = Column(Float) iteration_id = Column(Integer, ForeignKey("raw_data.unique_id")) - parent = relationship("DbRawTable", back_populates="children_outcome") + parent = relationship( + "DbRawTable", lazy="immediate", join_depth=1, back_populates="children_outcome" + ) @classmethod def from_sqlite(cls, row: Dict[str, Any]) -> "DbOutcomeTable": diff --git a/aepsych/server/__init__.py b/aepsych/server/__init__.py index ae3552278..7783362c9 100644 --- a/aepsych/server/__init__.py +++ b/aepsych/server/__init__.py @@ -5,6 +5,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from .server import AEPsychServer +from .server import AEPsychBackgroundServer, AEPsychServer -__all__ = ["AEPsychServer"] +__all__ = ["AEPsychServer", "AEPsychBackgroundServer"] diff --git a/aepsych/server/message_handlers/handle_ask.py b/aepsych/server/message_handlers/handle_ask.py index 7f1a21389..ae2ec4e80 100644 --- a/aepsych/server/message_handlers/handle_ask.py +++ b/aepsych/server/message_handlers/handle_ask.py @@ -10,7 +10,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_ask(server, request): @@ -18,7 +18,7 @@ def handle_ask(server, request): "config" -- dictionary with config (keys are strings, values are floats) "is_finished" -- bool, true if the strat is finished """ - logger.debug("got ask message!") + logger.info("got ask message!") if server._pregen_asks: params = server._pregen_asks.pop() else: diff --git a/aepsych/server/message_handlers/handle_can_model.py b/aepsych/server/message_handlers/handle_can_model.py index 32fa2fb18..94e9de6d4 100644 --- a/aepsych/server/message_handlers/handle_can_model.py +++ b/aepsych/server/message_handlers/handle_can_model.py @@ -9,13 +9,13 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_can_model(server, request): # Check if the strategy has finished initialization; i.e., # if it has a model and data to fit (strat.can_fit) - logger.debug("got can_model message!") + logger.info("got can_model message!") if not server.is_performing_replay: server.db.record_message( master_table=server._db_master_record, type="can_model", request=request diff --git a/aepsych/server/message_handlers/handle_exit.py b/aepsych/server/message_handlers/handle_exit.py index b654558b3..c73f6ad13 100644 --- a/aepsych/server/message_handlers/handle_exit.py +++ b/aepsych/server/message_handlers/handle_exit.py @@ -9,7 +9,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_exit(server, request): diff --git a/aepsych/server/message_handlers/handle_get_config.py b/aepsych/server/message_handlers/handle_get_config.py index 1a347dbae..0f186860a 100644 --- a/aepsych/server/message_handlers/handle_get_config.py +++ b/aepsych/server/message_handlers/handle_get_config.py @@ -8,7 +8,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_get_config(server, request): diff --git a/aepsych/server/message_handlers/handle_info.py b/aepsych/server/message_handlers/handle_info.py index 910aac720..99251ac82 100644 --- a/aepsych/server/message_handlers/handle_info.py +++ b/aepsych/server/message_handlers/handle_info.py @@ -10,7 +10,7 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_info(server, request: Dict[str, Any]) -> Dict[str, Any]: @@ -22,7 +22,7 @@ def handle_info(server, request: Dict[str, Any]) -> Dict[str, Any]: Returns: Dict[str, Any]: Returns dictionary containing the current state of the experiment """ - logger.debug("got info message!") + logger.info("got info message!") ret_val = info(server) diff --git a/aepsych/server/message_handlers/handle_params.py b/aepsych/server/message_handlers/handle_params.py index 525ae2daf..ad4b7181d 100644 --- a/aepsych/server/message_handlers/handle_params.py +++ b/aepsych/server/message_handlers/handle_params.py @@ -9,11 +9,11 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_params(server, request): - logger.debug("got parameters message!") + logger.info("got parameters message!") if not server.is_performing_replay: server.db.record_message( master_table=server._db_master_record, type="parameters", request=request diff --git a/aepsych/server/message_handlers/handle_query.py b/aepsych/server/message_handlers/handle_query.py index 2ba9f4d83..65263e0b4 100644 --- a/aepsych/server/message_handlers/handle_query.py +++ b/aepsych/server/message_handlers/handle_query.py @@ -11,11 +11,11 @@ import numpy as np import torch -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_query(server, request): - logger.debug("got query message!") + logger.info("got query message!") if not server.is_performing_replay: server.db.record_message( master_table=server._db_master_record, type="query", request=request diff --git a/aepsych/server/message_handlers/handle_resume.py b/aepsych/server/message_handlers/handle_resume.py index 4da5a14ca..e7196648d 100644 --- a/aepsych/server/message_handlers/handle_resume.py +++ b/aepsych/server/message_handlers/handle_resume.py @@ -9,11 +9,11 @@ import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def handle_resume(server, request): - logger.debug("got resume message!") + logger.info("got resume message!") strat_id = int(request["message"]["strat_id"]) server.strat_id = strat_id if not server.is_performing_replay: diff --git a/aepsych/server/message_handlers/handle_setup.py b/aepsych/server/message_handlers/handle_setup.py index c7a7df2d7..e8bc8c991 100644 --- a/aepsych/server/message_handlers/handle_setup.py +++ b/aepsych/server/message_handlers/handle_setup.py @@ -14,7 +14,7 @@ from aepsych.strategy import SequentialStrategy from aepsych.version import __version__ -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def _configure(server, config): @@ -68,7 +68,7 @@ def configure(server, config=None, **config_args): def handle_setup(server, request): - logger.debug("got setup message!") + logger.info("got setup message!") ### make a temporary config object to derive parameters because server handles config after table if ( "config_str" in request["message"].keys() diff --git a/aepsych/server/message_handlers/handle_tell.py b/aepsych/server/message_handlers/handle_tell.py index 3b22e33fe..1903e41e6 100644 --- a/aepsych/server/message_handlers/handle_tell.py +++ b/aepsych/server/message_handlers/handle_tell.py @@ -15,13 +15,13 @@ import pandas as pd import torch -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() DEFAULT_DESC = "default description" DEFAULT_NAME = "default name" def handle_tell(server, request): - logger.debug("got tell message!") + logger.info("got tell message!") if not server.is_performing_replay: server.db.record_message( diff --git a/aepsych/server/replay.py b/aepsych/server/replay.py index d338900cc..9fda596fe 100644 --- a/aepsych/server/replay.py +++ b/aepsych/server/replay.py @@ -12,7 +12,7 @@ import pandas as pd from aepsych.server.message_handlers.handle_tell import flatten_tell_record -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def replay(server, uuid_to_replay, skip_computations=False): diff --git a/aepsych/server/server.py b/aepsych/server/server.py index d0a16ba92..8ff711a48 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -1,28 +1,29 @@ #!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. and its affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - import argparse +import asyncio +import concurrent import io +import json import logging import os -import sys -import threading import traceback import warnings -from typing import Dict, Union +from typing import Any, Dict, List, Optional, Union -import aepsych.database.db as db -import aepsych.utils_logging as utils_logging import dill import numpy as np +import pandas as pd import torch -from aepsych import version +from aepsych import utils_logging, version +from aepsych.config import Config +from aepsych.database import db +from aepsych.database.tables import DBMasterTable from aepsych.server.message_handlers import MESSAGE_MAP -from aepsych.server.message_handlers.handle_ask import ask from aepsych.server.message_handlers.handle_setup import configure from aepsych.server.replay import ( get_dataframe_from_replay, @@ -30,12 +31,10 @@ get_strats_from_replay, replay, ) -from aepsych.server.sockets import BAD_REQUEST, DummySocket, PySocket -from aepsych.utils import promote_0d +from aepsych.strategy import SequentialStrategy, Strategy +from multiprocess import Process -logger = utils_logging.getLogger(logging.INFO) -DEFAULT_DESC = "default description" -DEFAULT_NAME = "default name" +logger = utils_logging.getLogger() def get_next_filename(folder, fname, ext): @@ -44,191 +43,83 @@ def get_next_filename(folder, fname, ext): return f"{folder}/{fname}_{n + 1}.{ext}" -class AEPsychServer(object): - def __init__(self, socket=None, database_path=None): - """Server for doing black box optimization using gaussian processes. - Keyword Arguments: - socket -- socket object that implements `send` and `receive` for json - messages (default: DummySocket()). - TODO actually make an abstract interface to subclass from here - """ - if socket is None: - self.socket = DummySocket() - else: - self.socket = socket - self.db = None +class AEPsychServer: + def __init__( + self, + host: str = "0.0.0.0", + port: int = 5555, + database_path: str = "./databases/default.db", + ): + self.host = host + self.port = port + self.clients_connected = 0 + self.db: db.Database = db.Database(database_path) self.is_performing_replay = False self.exit_server_loop = False self._db_raw_record = None - self.db: db.Database = db.Database(database_path) self.skip_computations = False self.strat_names = None self.extensions = None + self._strats: List[SequentialStrategy] = [] + self._parnames: List[List[str]] = [] + self._configs: List[Config] = [] + self._master_records: List[DBMasterTable] = [] + self.strat_id = -1 + self.outcome_names: List[str] = [] if self.db.is_update_required(): self.db.perform_updates() - self._strats = [] - self._parnames = [] - self._configs = [] - self._master_records = [] - self.strat_id = -1 - self._pregen_asks = [] - self.enable_pregen = False - self.outcome_names = [] - - self.debug = False - self.receive_thread = threading.Thread( - target=self._receive_send, args=(self.exit_server_loop,), daemon=True - ) - - self.queue = [] - - def cleanup(self): - """Close the socket and terminate connection to the server. - - Returns: - None - """ - self.socket.close() - - def _receive_send(self, is_exiting: bool) -> None: - """Receive messages from the client. - - Args: - is_exiting (bool): True to terminate reception of new messages from the client, False otherwise. - - Returns: - None - """ - while True: - request = self.socket.receive(is_exiting) - if request != BAD_REQUEST: - self.queue.append(request) - if self.exit_server_loop: - break - logger.info("Terminated input thread") - - def _handle_queue(self) -> None: - """Handles the queue of messages received by the server. - - Returns: - None - """ - if self.queue: - request = self.queue.pop(0) - try: - result = self.handle_request(request) - except Exception as e: - error_message = f"Request '{request}' raised error '{e}'!" - result = f"server_error, {error_message}" - logger.error(f"{error_message}! Full traceback follows:") - logger.error(traceback.format_exc()) - self.socket.send(result) - else: - if self.can_pregen_ask and (len(self._pregen_asks) == 0): - self._pregen_asks.append(ask(self)) - - def serve(self) -> None: - """Run the server. Note that all configuration outside of socket type and port - happens via messages from the client. The server simply forwards messages from - the client to its `setup`, `ask` and `tell` methods, and responds with either - acknowledgment or other response as needed. To understand the server API, see - the docs on the methods in this class. - - Returns: - None - - Raises: - RuntimeError: if a request from a client has no request type - RuntimeError: if a request from a client has no known request type - TODO make things a little more robust to bad messages from client; this - requires resetting the req/rep queue status. - - """ - logger.info("Server up, waiting for connections!") - logger.info("Ctrl-C to quit!") - # yeah we're not sanitizing input at all - - # Start the method to accept a client connection - self.socket.accept_client() - self.receive_thread.start() - while True: - self._handle_queue() - if self.exit_server_loop: - break - # Close the socket and terminate with code 0 - self.cleanup() - sys.exit(0) - - def _unpack_strat_buffer(self, strat_buffer): - if isinstance(strat_buffer, io.BytesIO): - strat = torch.load(strat_buffer, pickle_module=dill) - strat_buffer.seek(0) - elif isinstance(strat_buffer, bytes): - warnings.warn( - "Strat buffer is not in bytes format!" - + " This is a deprecated format, loading using dill.loads.", - DeprecationWarning, - ) - strat = dill.loads(strat_buffer) - else: - raise RuntimeError("Trying to load strat in unknown format!") - return strat - - ### Properties that are set on a per-strat basis + #### Properties #### @property - def strat(self): + def strat(self) -> Optional[SequentialStrategy]: if self.strat_id == -1: return None else: return self._strats[self.strat_id] @strat.setter - def strat(self, s): + def strat(self, s: SequentialStrategy): self._strats.append(s) @property - def config(self): + def config(self) -> Optional[Config]: if self.strat_id == -1: return None else: return self._configs[self.strat_id] @config.setter - def config(self, s): + def config(self, s: Config): self._configs.append(s) @property - def parnames(self): + def parnames(self) -> List[str]: if self.strat_id == -1: return [] else: return self._parnames[self.strat_id] @parnames.setter - def parnames(self, s): + def parnames(self, s: List[str]): self._parnames.append(s) @property - def _db_master_record(self): + def _db_master_record(self) -> Optional[DBMasterTable]: if self.strat_id == -1: return None else: return self._master_records[self.strat_id] @_db_master_record.setter - def _db_master_record(self, s): + def _db_master_record(self, s: DBMasterTable): self._master_records.append(s) @property - def n_strats(self): + def n_strats(self) -> int: return len(self._strats) - @property - def can_pregen_ask(self): - return self.strat is not None and self.enable_pregen - + #### Methods to handle parameter configs #### def _tensor_to_config(self, next_x): stim_per_trial = self.strat.stimuli_per_trial dim = self.strat.dim @@ -280,8 +171,11 @@ def _config_to_tensor(self, config): return x - def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]): + def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]) -> Dict[int, Any]: # Given a dictionary of fixed parameters, turn the parameters names into indices + if self.strat is None: + raise ValueError("No strategy is set, cannot convert fixed parameters.") + dummy = np.zeros(len(self.parnames)).astype("O") for key, value in fixed.items(): idx = self.parnames.index(key) @@ -297,14 +191,208 @@ def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]): return fixed_features - def __getstate__(self): - # nuke the socket since it's not pickleble - state = self.__dict__.copy() - del state["socket"] - del state["db"] - return state + #### Methods to handle replay #### + def replay(self, uuid_to_replay: int, skip_computations: bool = False) -> None: + """Replay an experiment with a specific unique ID. This will leave the + server state at the end of the replay. + + Args: + uuid_to_replay (int): Unique ID of the experiment to replay. This is + the primary key of the experiment's master table. + skip_computations (bool): If True, skip computations during the replay. + Defaults to False. + """ + return replay(self, uuid_to_replay, skip_computations) + + def get_strats_from_replay( + self, uuid_of_replay: Optional[int] = None, force_replay: bool = False + ) -> List[Strategy]: + """Replay an experiment then return the strategies from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + force_replay (bool): If True, force a replay. Defaults to False. + + Returns: + List[Union[SequentialStrategy, Strategy]]: List of strategies from + the replay. + """ + return get_strats_from_replay(self, uuid_of_replay, force_replay) + + def get_strat_from_replay( + self, uuid_of_replay: Optional[int] = None, strat_id: int = -1 + ) -> Strategy: + """Replay an experiment then return a strategy from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + strat_id (int): ID of the strategy to return. Defaults to -1, which + returns the last strategy. + + Returns: + Strategy: The strategy from the replay. + """ + return get_strat_from_replay(self, uuid_of_replay, strat_id) + + def get_dataframe_from_replay( + self, uuid_of_replay: Optional[int] = None, force_replay: bool = False + ) -> pd.DataFrame: + """Replay an experiment then return the dataframe from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + force_replay (bool): If True, force a replay. Defaults to False. + + Returns: + pd.DataFrame: Dataframe from the replay. + """ + return get_dataframe_from_replay(self, uuid_of_replay, force_replay) - def write_strats(self, termination_type): + def _unpack_strat_buffer(self, strat_buffer): + # Unpacks a strategy buffer from the database. + if isinstance(strat_buffer, io.BytesIO): + strat = torch.load(strat_buffer, pickle_module=dill) + strat_buffer.seek(0) + elif isinstance(strat_buffer, bytes): + warnings.warn( + "Strat buffer is not in bytes format!" + + " This is a deprecated format, loading using dill.loads.", + DeprecationWarning, + ) + strat = dill.loads(strat_buffer) + else: + raise RuntimeError("Trying to load strat in unknown format!") + return strat + + #### Method to handle async server #### + def start_blocking(self) -> None: + """Starts the server in a blocking state in the main thread. Used by the + command line interface to start the server for a client in another + process or machine.""" + asyncio.run(self.serve()) + + async def serve(self) -> None: + """Serves the server on the set IP and port. This creates a coroutine + for asyncio to handle requests asyncronously. + """ + self.server = await asyncio.start_server( + self.handle_client, self.host, self.port + ) + self.loop = asyncio.get_running_loop() + pool = concurrent.futures.ThreadPoolExecutor() + self.loop.set_default_executor(pool) + + async with self.server: + logging.info(f"Serving on {self.host}:{self.port}") + try: + await self.server.serve_forever() + except asyncio.CancelledError: + raise + except KeyboardInterrupt: + exception_type = "CTRL+C" + dump_type = "dump" + self.write_strats(exception_type) + self.generate_debug_info(exception_type, dump_type) + except RuntimeError as e: + exception_type = "RuntimeError" + dump_type = "crashdump" + self.write_strats(exception_type) + self.generate_debug_info(exception_type, dump_type) + raise RuntimeError(e) + + async def handle_client(self, reader, writer): + """Coroutine for handling a client connection. This will read messages + from the connected client and dispatch a task to handle the request on + another thread such that its blocking state does not block the server. + This coroutine will end if the client closes the connection. + + Args: + reader: asyncio.StreamReader: The stream reader for the client. + writer: asyncio.StreamWriter: The stream writer for the client. + """ + addr = writer.get_extra_info("peername") + logger.info(f"Connected to {addr}") + self.clients_connected += 1 + + try: + while True: + if self.exit_server_loop: + self.server.close() + break + rcv = await reader.read(1024 * 512) + try: + message = json.loads(rcv) + except UnicodeDecodeError as e: + logger.error(f"Malformed message: {rcv}") + logger.error(traceback.format_exc()) + result = {"error": str(e)} + return_msg = json.dumps(self._simplify_arrays(result)).encode() + writer.write(return_msg) + continue + + future = self.loop.run_in_executor(None, self.handle_request, message) + try: + result = await future + except Exception as e: + logger.error(f"Error handling message: {message}") + logger.error(traceback.format_exc()) + # Some exceptions turned into string are meaningless, so we use repr + result = {"error": e.__repr__()} + if isinstance(result, dict): + return_msg = json.dumps(self._simplify_arrays(result)).encode() + writer.write(return_msg) + else: + writer.write(str(result).encode()) + + await writer.drain() + except asyncio.CancelledError: + pass + finally: + logger.info(f"Connection closed for {addr}") + writer.close() + await writer.wait_closed() + self.clients_connected -= 1 + + def handle_request(self, message: Dict[str, Any]) -> Union[Dict[str, Any], str]: + """Given a message, dispatch the correct handler and return the result. + + Args: + message (Dict[str, Any]): The message to handle. + + Returns: + Union[Dict[str, Any], str]: The result of handling the message. + """ + type_ = message["type"] + result = MESSAGE_MAP[type_](self, message) + return result + + def _simplify_arrays(self, message): + # Simplify arrays for encoding and sending a message to the client + return { + k: ( + v.tolist() + if type(v) == np.ndarray + else self._simplify_arrays(v) + if type(v) is dict + else v + ) + for k, v in message.items() + } + + #### Methods to handle exiting #### + def write_strats(self, termination_type: str) -> None: + """Pickle the stats and records them into the database. + + Args: + termination_type (str): The type of termination. This only affects + the log message. + """ if self._db_master_record is not None and self.strat is not None: logger.info(f"Dumping strats to DB due to {termination_type}.") for strat in self._strats: @@ -313,77 +401,86 @@ def write_strats(self, termination_type): buffer.seek(0) self.db.record_strat(master_table=self._db_master_record, strat=buffer) - def generate_debug_info(self, exception_type, dumptype): + def generate_debug_info(self, exception_type: str, dumptype: str) -> None: + """Generate a debug info file for the server. This will pickle the server + and save it to a file. + + Args: + exception_type (str): The type of exception that caused the server + to terminate. This only affects the log message. + dump_type (str): The type of dump. This only affects the log file. + """ fname = get_next_filename(".", dumptype, "pkl") logger.exception(f"Got {exception_type}, exiting! Server dump in {fname}") dill.dump(self, open(fname, "wb")) - def handle_request(self, request): - if "type" not in request.keys(): - raise RuntimeError(f"Request {request} contains no request type!") - else: - type = request["type"] - if type in MESSAGE_MAP.keys(): - logger.info(f"Received msg [{type}]") - ret_val = MESSAGE_MAP[type](self, request) - return ret_val - - else: - exception_message = ( - f"unknown type: {type}. Allowed types [{MESSAGE_MAP.keys()}]" - ) + def __getstate__(self): + # Called when the server is pickled, we can't pickle the DB. + state = self.__dict__.copy() + del state["db"] + return state - raise RuntimeError(exception_message) - def replay(self, uuid_to_replay, skip_computations=False): - return replay(self, uuid_to_replay, skip_computations) +class AEPsychBackgroundServer(AEPsychServer): + """A class to handle the server in a background thread. Unlike the normal + AEPsychServer, this does not create the db right away until the server is + started. When starting this server, it'll be sent to another process, a db + will be initialized, then the server will be served. This server should then + be interacted with by the main thread via a client.""" + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 5555, + database_path: str = "./databases/default.db", + ): + self.host = host + self.port = port + self.database_path = database_path + self.clients_connected = 0 + self.is_performing_replay = False + self.exit_server_loop = False + self._db_raw_record = None + self.skip_computations = False + self.background_process = None + self.strat_names = None + self.extensions = None + self._strats = [] + self._parnames = [] + self._configs = [] + self._master_records = [] + self.strat_id = -1 + self.outcome_names = [] - def get_strats_from_replay(self, uuid_of_replay=None, force_replay=False): - return get_strats_from_replay(self, uuid_of_replay, force_replay) + def _start_server(self) -> None: + self.db: db.Database = db.Database(self.database_path) + if self.db.is_update_required(): + self.db.perform_updates() - def get_strat_from_replay(self, uuid_of_replay=None, strat_id=-1): - return get_strat_from_replay(self, uuid_of_replay, strat_id) + super().start_blocking() - def get_dataframe_from_replay(self, uuid_of_replay=None, force_replay=False): - return get_dataframe_from_replay(self, uuid_of_replay, force_replay) + def start(self): + """Starts the server in a background thread. Used by the client to start + the server for a client in another process or machine.""" + self.background_process = Process(target=self._start_server, daemon=True) + self.background_process.start() + def stop(self): + """Stops the server and closes the background process.""" + self.exit_server_loop = True + self.background_process.terminate() + self.background_process.join() + self.background_process.close() + self.background_process = None -#! THIS IS WHAT START THE SERVER -def startServerAndRun( - server_class, socket=None, database_path=None, config_path=None, id_of_replay=None -): - server = server_class(socket=socket, database_path=database_path) - try: - if config_path is not None: - with open(config_path) as f: - config_str = f.read() - configure(server, config_str=config_str) - - if socket is not None: - if id_of_replay is not None: - server.replay(id_of_replay, skip_computations=True) - server.serve() - else: - if config_path is not None: - logger.info( - "You have passed in a config path but this is a replay. If there's a config in the database it will be used instead of the passed in config path." - ) - server.replay(id_of_replay) - except KeyboardInterrupt: - exception_type = "CTRL+C" - dump_type = "dump" - server.write_strats(exception_type) - server.generate_debug_info(exception_type, dump_type) - except RuntimeError as e: - exception_type = "RuntimeError" - dump_type = "crashdump" - server.write_strats(exception_type) - server.generate_debug_info(exception_type, dump_type) - raise RuntimeError(e) + def __getstate__(self): + # Override parent's __getstate__ to not worry about the db + state = self.__dict__.copy() + return state def parse_argument(): - parser = argparse.ArgumentParser(description="AEPsych Server!") + parser = argparse.ArgumentParser(description="AEPsych Server") parser.add_argument( "--port", metavar="N", type=int, default=5555, help="port to serve on" ) @@ -415,72 +512,54 @@ def parse_argument(): "--db", type=str, help="The database to use if not the default (./databases/default.db).", - default=None, + default="./databases/default.db", ) parser.add_argument( - "-r", "--replay", type=str, help="Unique id of the experiment to replay." - ) - - parser.add_argument( - "-m", "--resume", action="store_true", help="Resume server after replay." + "-r", + "--resume", + type=str, + help="Unique id of the experiment to replay and resume the server from.", ) args = parser.parse_args() return args -def start_server(server_class, args): - logger.info("Starting the AEPsychServer") +def main(): + logger = utils_logging.getLogger() + logger.info("Starting AEPsychServer") logger.info(f"AEPsych Version: {version.__version__}") - try: - if "db" in args and args.db is not None: - database_path = args.db - if "replay" in args and args.replay is not None: - logger.info(f"Attempting to replay {args.replay}") - if args.resume is True: - sock = PySocket(port=args.port) - logger.info(f"Will resume {args.replay}") - else: - sock = None - startServerAndRun( - server_class, - socket=sock, - database_path=database_path, - uuid_of_replay=args.replay, - config_path=args.stratconfig, - ) - else: - logger.info(f"Setting the database path {database_path}") - sock = PySocket(port=args.port) - startServerAndRun( - server_class, - database_path=database_path, - socket=sock, - config_path=args.stratconfig, - ) - else: - sock = PySocket(port=args.port) - startServerAndRun(server_class, socket=sock, config_path=args.stratconfig) - - except (KeyboardInterrupt, SystemExit): - logger.exception("Got Ctrl+C, exiting!") - sys.exit() - except RuntimeError as e: - fname = get_next_filename(".", "dump", "pkl") - logger.exception(f"CRASHING!! dump in {fname}") - raise RuntimeError(e) - -def main(server_class=AEPsychServer): args = parse_argument() if args.logs: # overide logger path log_path = args.logs - logger = utils_logging.getLogger(logging.DEBUG, log_path) - logger.info(f"Saving logs to path: {log_path}") - start_server(server_class, args) + logger = utils_logging.getLogger(log_path) + logger.info(f"Saving logs to path: {log_path}") + + server = AEPsychServer( + host=args.ip, + port=args.port, + database_path=args.db, + ) + + if args.stratconfig is not None and args.resume is not None: + raise ValueError( + "Cannot configure the server with a config file and a resume from a replay at the same time." + ) + + elif args.stratconfig is not None: + configure(server, config_str=args.stratconfig) + + elif args.resume is not None: + if args.db is None: + raise ValueError("Cannot resume from a replay if no database is given.") + server.replay(args.resume, skip_computations=True) + + # Starts the server in a blocking state + server.start_blocking() if __name__ == "__main__": - main(AEPsychServer) + main() diff --git a/aepsych/server/sockets.py b/aepsych/server/sockets.py index 12b3e640d..8aaf79257 100644 --- a/aepsych/server/sockets.py +++ b/aepsych/server/sockets.py @@ -14,7 +14,7 @@ import aepsych.utils_logging as utils_logging import numpy as np -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() BAD_REQUEST = "bad request" diff --git a/aepsych/server/utils.py b/aepsych/server/utils.py index 27e6db063..027ddd30f 100644 --- a/aepsych/server/utils.py +++ b/aepsych/server/utils.py @@ -13,7 +13,7 @@ import aepsych.database.db as db import aepsych.utils_logging as utils_logging -logger = utils_logging.getLogger(logging.INFO) +logger = utils_logging.getLogger() def get_next_filename(folder, fname, ext): diff --git a/aepsych/strategy/strategy.py b/aepsych/strategy/strategy.py index 097a2f44c..182b55e04 100644 --- a/aepsych/strategy/strategy.py +++ b/aepsych/strategy/strategy.py @@ -8,6 +8,7 @@ from __future__ import annotations import warnings +from copy import deepcopy from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import numpy as np @@ -56,6 +57,7 @@ def __init__( name: str = "", run_indefinitely: bool = False, transforms: ChainedInputTransform = ChainedInputTransform(**{}), + copy_model: bool = False, ) -> None: """Initialize the strategy object. @@ -90,6 +92,9 @@ def __init__( should be defined in raw parameter space for initialization. However, if the lb/ub attribute are access from an initialized Strategy object, it will be returned in transformed space. + copy_model (bool): Whether to do any model-related methods on a + copy or the original. Used for multi-client strategies. Defaults + to False. """ self.is_finished = False @@ -160,6 +165,7 @@ def __init__( self.min_total_outcome_occurrences = min_total_outcome_occurrences self.max_asks = max_asks or generator.max_asks self.keep_most_recent = keep_most_recent + self.copy_model = copy_model self.transforms = transforms if self.transforms is not None: @@ -267,7 +273,8 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor: self.model.to(self.generator_device) # type: ignore self._count = self._count + num_points - points = self.generator.gen(num_points, self.model, **kwargs) + model = deepcopy(self.model) if self.copy_model else self.model + points = self.generator.gen(num_points, model, **kwargs) if original_device is not None: self.model.to(original_device) # type: ignore @@ -295,9 +302,9 @@ def get_max( self.model is not None ), "model is None! Cannot get the max without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = get_max( - self.model, + model, self.bounds, locked_dims=constraints, probability_space=probability_space, @@ -324,9 +331,9 @@ def get_min( self.model is not None ), "model is None! Cannot get the min without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = get_min( - self.model, + model, self.bounds, locked_dims=constraints, probability_space=probability_space, @@ -358,9 +365,9 @@ def inv_query( self.model is not None ), "model is None! Cannot get the inv_query without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = inv_query( - model=self.model, + model=model, y=y, bounds=self.bounds, locked_dims=constraints, @@ -385,7 +392,8 @@ def predict( """ assert self.model is not None, "model is None! Cannot predict without a model!" self.model.to(self.model_device) - return self.model.predict(x=x, probability_space=probability_space) + model = deepcopy(self.model) if self.copy_model else self.model + return model.predict(x=x, probability_space=probability_space) @ensure_model_is_fresh def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor: @@ -400,7 +408,8 @@ def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor: """ assert self.model is not None, "model is None! Cannot sample without a model!" self.model.to(self.model_device) - return self.model.sample(x, num_samples=num_samples) + model = deepcopy(self.model) if self.copy_model else self.model + return model.sample(x, num_samples=num_samples) def finish(self) -> None: """Finish the strategy.""" @@ -442,7 +451,8 @@ def finished(self) -> bool: assert ( self.model is not None ), "model is None! Cannot predict without a model!" - fmean, _ = self.model.predict(self.eval_grid, probability_space=True) + model = deepcopy(self.model) if self.copy_model else self.model + fmean, _ = model.predict(self.eval_grid, probability_space=True) meets_post_range = bool( ((fmean.max() - fmean.min()) >= self.min_post_range).item() ) @@ -504,9 +514,10 @@ def fit(self) -> None: """Fit the model.""" if self.can_fit: self.model.to(self.model_device) # type: ignore + model = deepcopy(self.model) if self.copy_model else self.model if self.keep_most_recent is not None: try: - self.model.fit( # type: ignore + model.fit( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore ) @@ -516,11 +527,12 @@ def fit(self) -> None: ) else: try: - self.model.fit(self.x, self.y) # type: ignore + model.fit(self.x, self.y) # type: ignore except ModelFittingError: logger.warning( "Failed to fit model! Predictions may not be accurate!" ) + self.model = model else: warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning) @@ -528,9 +540,10 @@ def update(self) -> None: """Update the model.""" if self.can_fit: self.model.to(self.model_device) # type: ignore + model = deepcopy(self.model) if self.copy_model else self.model if self.keep_most_recent is not None: try: - self.model.update( # type: ignore + model.update( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore ) @@ -540,11 +553,13 @@ def update(self) -> None: ) else: try: - self.model.update(self.x, self.y) # type: ignore + model.update(self.x, self.y) # type: ignore except ModelFittingError: logger.warning( "Failed to fit model! Predictions may not be accurate!" ) + + self.model = model else: warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning) diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index ce1b54a12..eabe73ca7 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -477,6 +477,10 @@ def eval(self): f"{self._base_obj.__class__.__name__} has no attribute 'eval'" ) + def __reduce__(self): + # Helps pickle work (not dill) + return (ParameterTransformedGenerator, (self._base_obj, self.transforms)) + @classmethod def get_config_options( cls, @@ -725,6 +729,10 @@ def eval(self): f"{self._base_obj.__class__.__name__} has no attribute 'eval'" ) + def __reduce__(self): + # Helps pickle work (not dill) + return (ParameterTransformedModel, (self._base_obj, self.transforms)) + @classmethod def get_config_options( cls, diff --git a/aepsych/utils_logging.py b/aepsych/utils_logging.py index 9a5aef693..9c0eb37b5 100644 --- a/aepsych/utils_logging.py +++ b/aepsych/utils_logging.py @@ -35,7 +35,7 @@ def format(self, record): return formatter.format(record) -def getLogger(level=logging.INFO, log_path: str = "logs") -> logging.Logger: +def getLogger(log_path: str = "logs") -> logging.Logger: """Get a logger with the specified level and log path. Args: @@ -53,7 +53,7 @@ def getLogger(level=logging.INFO, log_path: str = "logs") -> logging.Logger: "formatters": {"standard": {"()": ColorFormatter}}, "handlers": { "default": { - "level": level, + "level": logging.INFO, "class": "logging.StreamHandler", "formatter": "standard", }, @@ -65,7 +65,11 @@ def getLogger(level=logging.INFO, log_path: str = "logs") -> logging.Logger: }, }, "loggers": { - "": {"handlers": ["default", "file"], "level": level, "propagate": False}, + "": { + "handlers": ["default", "file"], + "level": logging.DEBUG, + "propagate": False, + }, }, } diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 494e5d72d..96425b481 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -34,8 +34,6 @@ def setUp(self): ) def tearDown(self): - self.s.cleanup() - # cleanup the db if self.s.db is not None: self.s.db.delete_db() diff --git a/tests/models/test_pairwise_probit.py b/tests/models/test_pairwise_probit.py index 9b2b6c138..60e22adef 100644 --- a/tests/models/test_pairwise_probit.py +++ b/tests/models/test_pairwise_probit.py @@ -495,14 +495,12 @@ def test_hyperparam_consistency(self): class PairwiseProbitModelServerTest(unittest.TestCase): def setUp(self): # setup logger - server.logger = utils_logging.getLogger(logging.DEBUG, "logs") + server.logger = utils_logging.getLogger("logs") # random datebase path name without dashes database_path = "./{}.db".format(str(uuid.uuid4().hex)) self.s = server.AEPsychServer(database_path=database_path) def tearDown(self): - self.s.cleanup() - # cleanup the db if self.s.db is not None: self.s.db.delete_db() diff --git a/tests/server/message_handlers/test_ask_handlers.py b/tests/server/message_handlers/test_ask_handlers.py index 9d3fa5c6b..30d773f90 100644 --- a/tests/server/message_handlers/test_ask_handlers.py +++ b/tests/server/message_handlers/test_ask_handlers.py @@ -8,7 +8,7 @@ import unittest -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase dummy_config = """ [common] @@ -69,7 +69,7 @@ """ -class AskHandlerTestCase(BaseServerTestCase): +class AskHandlerTestCase(AsyncServerTestBase): def test_handle_ask(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_can_model.py b/tests/server/message_handlers/test_can_model.py index 01b3c4b8a..6f5b090e0 100644 --- a/tests/server/message_handlers/test_can_model.py +++ b/tests/server/message_handlers/test_can_model.py @@ -7,10 +7,10 @@ import unittest -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class StratCanModelTestCase(BaseServerTestCase): +class StratCanModelTestCase(AsyncServerTestBase): def test_strat_can_model(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_handle_exit.py b/tests/server/message_handlers/test_handle_exit.py index 5b548bc25..bd8ff89b7 100644 --- a/tests/server/message_handlers/test_handle_exit.py +++ b/tests/server/message_handlers/test_handle_exit.py @@ -5,24 +5,30 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import asyncio import unittest -from unittest.mock import MagicMock -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase, dummy_config -class HandleExitTestCase(BaseServerTestCase): - def test_handle_exit(self): +class HandleExitTestCase(AsyncServerTestBase): + async def test_handle_exit(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + + await self.mock_client(setup_request) + request = {} request["type"] = "exit" - self.s.socket.accept_client = MagicMock() - self.s.socket.receive = MagicMock(return_value=request) - self.s.dump = MagicMock() + await self.mock_client(request) - with self.assertRaises(SystemExit) as cm: - self.s.serve() + with self.assertRaises(ConnectionRefusedError): + await asyncio.open_connection(self.s.host, self.s.port) - self.assertEqual(cm.exception.code, 0) + self.assertTrue(self.s.exit_server_loop) if __name__ == "__main__": diff --git a/tests/server/message_handlers/test_handle_finish_strategy.py b/tests/server/message_handlers/test_handle_finish_strategy.py index 9efffdb20..729421bdb 100644 --- a/tests/server/message_handlers/test_handle_finish_strategy.py +++ b/tests/server/message_handlers/test_handle_finish_strategy.py @@ -7,10 +7,10 @@ import unittest -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class ResumeTestCase(BaseServerTestCase): +class ResumeTestCase(AsyncServerTestBase): def test_handle_finish_strategy(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_handle_get_config.py b/tests/server/message_handlers/test_handle_get_config.py index d79c0697f..b173d22e5 100644 --- a/tests/server/message_handlers/test_handle_get_config.py +++ b/tests/server/message_handlers/test_handle_get_config.py @@ -9,10 +9,10 @@ from aepsych.config import Config -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class HandleExitTestCase(BaseServerTestCase): +class HandleExitTestCase(AsyncServerTestBase): def test_get_config(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_query_handlers.py b/tests/server/message_handlers/test_query_handlers.py index 2f8eaff2a..3ff9f618e 100644 --- a/tests/server/message_handlers/test_query_handlers.py +++ b/tests/server/message_handlers/test_query_handlers.py @@ -7,12 +7,12 @@ import unittest -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase # Smoke test to make sure nothing breaks. This should really be combined with # the individual query tests -class QueryHandlerTestCase(BaseServerTestCase): +class QueryHandlerTestCase(AsyncServerTestBase): def test_strat_query(self): # Annoying and complex model and output shapes config_str = """ diff --git a/tests/server/message_handlers/test_tell_handlers.py b/tests/server/message_handlers/test_tell_handlers.py index 4128b4ed6..7f68e84f5 100644 --- a/tests/server/message_handlers/test_tell_handlers.py +++ b/tests/server/message_handlers/test_tell_handlers.py @@ -9,10 +9,10 @@ import unittest from unittest.mock import MagicMock -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class MessageHandlerTellTests(BaseServerTestCase): +class MessageHandlerTellTests(AsyncServerTestBase): def test_tell(self): setup_request = { "type": "setup", diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 57d59da50..4345f4ef7 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -5,18 +5,17 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import asyncio import json import logging -import select import time import unittest import uuid from pathlib import Path -from unittest.mock import MagicMock +from typing import Any, Dict import aepsych.server as server import aepsych.utils_logging as utils_logging -from aepsych.server.sockets import BAD_REQUEST dummy_config = """ [common] @@ -46,6 +45,7 @@ generator = OptimizeAcqfGenerator model = GPClassificationModel min_total_outcome_occurrences = 0 +copy_model = True [OptimizeAcqfGenerator] acqf = MCPosteriorVariance @@ -77,28 +77,52 @@ """ -class BaseServerTestCase(unittest.TestCase): - # so that this can be overridden for tests that require specific databases. +class AsyncServerTestBase(unittest.IsolatedAsyncioTestCase): @property def database_path(self): return "./{}_test_server.db".format(str(uuid.uuid4().hex)) - def setUp(self): + async def asyncSetUp(self): + self.ip = "127.0.0.1" + self.port = 5555 + # setup logger - server.logger = utils_logging.getLogger(logging.DEBUG, "logs") - # random port - socket = server.sockets.PySocket(port=0) + self.logger = utils_logging.getLogger("unittests") + # random datebase path name without dashes database_path = self.database_path - self.s = server.AEPsychServer(socket=socket, database_path=database_path) + self.s = server.AEPsychServer( + database_path=database_path, host=self.ip, port=self.port + ) self.db_name = database_path.split("/")[1] self.db_path = database_path - def tearDown(self): - self.s.cleanup() + try: + self.server_task = asyncio.create_task(self.s.serve()) + except OSError: + # Try 0.0.0.0 after waiting + time.sleep(5) + self.ip = "0.0.0.0" + self.s = server.AEPsychServer( + database_path=database_path, host=self.ip, port=self.port + ) + self.server_task = asyncio.create_task(self.s.serve()) + await asyncio.sleep(0.1) + + self.reader, self.writer = await asyncio.open_connection(self.ip, self.port) + + async def asyncTearDown(self): + # Stops the client + self.writer.close() - # sleep to ensure db is closed - time.sleep(0.2) + # Stops the server + self.server_task.cancel() + try: + await self.server_task + except asyncio.CancelledError: + pass + + await asyncio.sleep(0.2) # cleanup the db if self.s.db is not None: @@ -107,46 +131,18 @@ def tearDown(self): except PermissionError as e: print("Failed to deleted database: ", e) - def dummy_create_setup(self, server, request=None): - request = request or {"test": "test request"} - server._db_master_record = server.db.record_setup( - description="default description", name="default name", request=request - ) + async def mock_client(self, request: Dict[str, Any]) -> Any: + self.writer.write(json.dumps(request).encode()) + await self.writer.drain() + response = await self.reader.read(1024 * 512) + return response.decode() -class ServerTestCase(BaseServerTestCase): - def test_final_strat_serialization(self): - setup_request = { - "type": "setup", - "version": "0.01", - "message": {"config_str": dummy_config}, - } - ask_request = {"type": "ask", "message": ""} - tell_request = { - "type": "tell", - "message": {"config": {"x": [0.5]}, "outcome": 1}, - } - self.s.handle_request(setup_request) - while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) - unique_id = self.s.db.get_master_records()[-1].unique_id - stored_strat = self.s.get_strat_from_replay(unique_id) - # just some spot checks that the strat's the same - # same data. We do this twice to make sure buffers are - # in a good state and we can load twice without crashing - for _ in range(2): - stored_strat = self.s.get_strat_from_replay(unique_id) - self.assertTrue((stored_strat.x == self.s.strat.x).all()) - self.assertTrue((stored_strat.y == self.s.strat.y).all()) - # same lengthscale and outputscale - self.assertEqual( - stored_strat.model.covar_module.lengthscale, - self.s.strat.model.covar_module.lengthscale, - ) +class AsyncServerTestCase(AsyncServerTestBase): + """Server functions are all async""" - def test_pandadf_dump_single(self): + async def test_pandadf_dump_single(self): setup_request = { "type": "setup", "version": "0.01", @@ -158,20 +154,22 @@ def test_pandadf_dump_single(self): "message": {"config": {"x": [0.5]}, "outcome": 1}, "extra_info": {}, } - self.s.handle_request(setup_request) + + await self.mock_client(setup_request) + expected_x = [0, 1, 2, 3] expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = [expected_x[i]] tell_request["message"]["config"]["z"] = [expected_z[i]] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -183,7 +181,38 @@ def test_pandadf_dump_single(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_pandadf_dump_multistrat(self): + async def test_final_strat_serialization(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + ask_request = {"type": "ask", "message": ""} + tell_request = { + "type": "tell", + "message": {"config": {"x": [0.5]}, "outcome": 1}, + } + await self.mock_client(setup_request) + while not self.s.strat.finished: + await self.mock_client(ask_request) + await self.mock_client(tell_request) + + unique_id = self.s.db.get_master_records()[-1].unique_id + stored_strat = self.s.get_strat_from_replay(unique_id) + # just some spot checks that the strat's the same + # same data. We do this twice to make sure buffers are + # in a good state and we can load twice without crashing + for _ in range(2): + stored_strat = self.s.get_strat_from_replay(unique_id) + self.assertTrue((stored_strat.x == self.s.strat.x).all()) + self.assertTrue((stored_strat.y == self.s.strat.y).all()) + # same lengthscale and outputscale + self.assertEqual( + stored_strat.model.covar_module.lengthscale, + self.s.strat.model.covar_module.lengthscale, + ) + + async def test_pandadf_dump_multistrat(self): setup_request = { "type": "setup", "version": "0.01", @@ -199,16 +228,16 @@ def test_pandadf_dump_multistrat(self): expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = [expected_x[i]] tell_request["message"]["config"]["z"] = [expected_z[i]] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -221,7 +250,7 @@ def test_pandadf_dump_multistrat(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_pandadf_dump_flat(self): + async def test_pandadf_dump_flat(self): """ This test handles the case where the config values are flat scalars and not lists @@ -237,20 +266,20 @@ def test_pandadf_dump_flat(self): "message": {"config": {"x": [0.5]}, "outcome": 1}, "extra_info": {}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) expected_x = [0, 1, 2, 3] expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = expected_x[i] tell_request["message"]["config"]["z"] = expected_z[i] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -262,52 +291,7 @@ def test_pandadf_dump_flat(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_receive(self): - """test_receive - verifies the receive is working when server receives unexpected messages""" - - message1 = b"\x16\x03\x01\x00\xaf\x01\x00\x00\xab\x03\x03\xa9\x80\xcc" # invalid message - message2 = b"\xec\xec\x14M\xfb\xbd\xac\xe7jF\xbe\xf9\x9bM\x92\x15b\xb5" # invalid message - message3 = {"message": {"target": "test request"}} # valid message - message_list = [message1, message2, json.dumps(message3)] - - self.s.socket.conn = MagicMock() - - for i, message in enumerate(message_list): - select.select = MagicMock(return_value=[[self.s.socket.conn], [], []]) - self.s.socket.conn.recv = MagicMock(return_value=message) - if i != 2: - self.assertEqual(self.s.socket.receive(False), BAD_REQUEST) - else: - self.assertEqual(self.s.socket.receive(False), message3) - - def test_error_handling(self): - # double brace escapes, single brace to substitute, so we end up with 3 braces - request = f"{{{BAD_REQUEST}}}" - - expected_error = f"server_error, Request '{request}' raised error ''str' object has no attribute 'keys''!" - - self.s.socket.accept_client = MagicMock() - - self.s.socket.receive = MagicMock(return_value=request) - self.s.socket.send = MagicMock() - self.s.exit_server_loop = True - with self.assertRaises(SystemExit): - self.s.serve() - self.s.socket.send.assert_called_once_with(expected_error) - - def test_queue(self): - """Test to see that the queue is being handled correctly""" - - self.s.socket.accept_client = MagicMock() - ask_request = {"type": "ask", "message": ""} - self.s.socket.receive = MagicMock(return_value=ask_request) - self.s.socket.send = MagicMock() - self.s.exit_server_loop = True - with self.assertRaises(SystemExit): - self.s.serve() - assert len(self.s.queue) == 0 - - def test_replay(self): + async def test_replay(self): exp_config = """ [common] lb = [0] @@ -341,15 +325,14 @@ def test_replay(self): } exit_request = {"message": "", "type": "exit"} - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) + await self.mock_client(ask_request) + await self.mock_client(tell_request) - self.s.handle_request(exit_request) + await self.mock_client(exit_request) - socket = server.sockets.PySocket(port=0) - serv = server.AEPsychServer(socket=socket, database_path=self.db_path) + serv = server.AEPsychServer(database_path=self.db_path) exp_ids = [rec.unique_id for rec in serv.db.get_master_records()] serv.replay(exp_ids[-1], skip_computations=True) @@ -359,7 +342,7 @@ def test_replay(self): self.assertTrue(strat.finished) self.assertTrue(strat.x.shape[0] == 4) - def test_string_parameter(self): + async def test_string_parameter(self): string_config = """ [common] parnames = [x, y, z] @@ -405,16 +388,17 @@ def test_string_parameter(self): "type": "tell", "message": {"config": {"x": [0.5], "y": ["blue"], "z": [50]}, "outcome": 1}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - response = self.s.handle_request(ask_request) + response = await self.mock_client(ask_request) + response = json.loads(response) self.assertTrue(response["config"]["y"][0] == "blue") - self.s.handle_request(tell_request) + await self.mock_client(tell_request) self.assertTrue(len(self.s.strat.lb) == 2) self.assertTrue(len(self.s.strat.ub) == 2) - def test_metadata(self): + async def test_metadata(self): setup_request = { "type": "setup", "version": "0.01", @@ -425,10 +409,10 @@ def test_metadata(self): "type": "tell", "message": {"config": {"x": [0.5]}, "outcome": 1}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) + await self.mock_client(ask_request) + await self.mock_client(tell_request) master_record = self.s.db.get_master_records()[-1] extra_metadata = json.loads(master_record.extra_metadata) @@ -443,7 +427,7 @@ def test_metadata(self): self.assertTrue(extra_metadata["extra"] == "data that is arbitrary") self.assertTrue("experiment_id" not in extra_metadata) - def test_extension_server(self): + async def test_extension_server(self): extension_path = Path(__file__).parent.parent.parent extension_path = extension_path / "extensions_example" / "new_objects.py" @@ -470,8 +454,8 @@ def test_extension_server(self): "message": {"config_str": config_str}, } - with self.assertLogs(level=logging.INFO) as logs: - self.s.handle_request(setup_request) + with self.assertLogs() as logs: + await self.mock_client(setup_request) outputs = ";".join(logs.output) self.assertTrue(str(extension_path) in outputs) @@ -481,6 +465,184 @@ def test_extension_server(self): self.assertTrue(one == 1) self.assertTrue(strat.generator._base_obj.__class__.__name__ == "OnesGenerator") + async def test_receive(self): + """test_receive - verifies the receive is working when server receives unexpected messages""" + + message1 = b"\x16\x03\x01\x00\xaf\x01\x00\x00\xab\x03\x03\xa9\x80\xcc" # invalid message + message2 = b"\xec\xec\x14M\xfb\xbd\xac\xe7jF\xbe\xf9\x9bM\x92\x15b\xb5" # invalid message + message3 = {"message": {"target": "test request"}} # valid message + message_list = [message1, message2, message3] + + for i, message in enumerate(message_list): + if isinstance(message, dict): + send = json.dumps(message).encode() + else: + send = message + self.writer.write(send) + await self.writer.drain() + + response = await self.reader.read(1024 * 512) + response = response.decode() + response = json.loads(response) + if i != 2: + self.assertTrue("error" in response) # Very generic error for malformed + else: + self.assertTrue("KeyError" in response["error"]) # Specific error + + async def test_multi_client(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + ask_request = {"type": "ask", "message": ""} + tell_request = { + "type": "tell", + "message": {"config": {"x": [0.5]}, "outcome": 1}, + "extra_info": {}, + } + + await self.mock_client(setup_request) + + # Create second client + reader2, writer2 = await asyncio.open_connection(self.ip, self.port) + + async def _mock_client2(request: Dict[str, Any]) -> Any: + writer2.write(json.dumps(request).encode()) + await writer2.drain() + + response = await reader2.read(1024 * 512) + return response.decode() + + for _ in range(2): # 2 loops should do it as we have 2 clients + tasks = [ + asyncio.create_task(self.mock_client(ask_request)), + asyncio.create_task(_mock_client2(ask_request)), + ] + await asyncio.gather(*tasks) + + tasks = [ + asyncio.create_task(self.mock_client(tell_request)), + asyncio.create_task(_mock_client2(tell_request)), + ] + await asyncio.gather(*tasks) + + self.assertTrue(self.s.strat.finished) + self.assertTrue(self.s.strat.x.numel() == 4) + self.assertTrue(self.s.clients_connected == 2) + + +class BackgroundServerTestCase(unittest.IsolatedAsyncioTestCase): + @property + def database_path(self): + return "./{}_test_server.db".format(str(uuid.uuid4().hex)) + + async def asyncSetUp(self): + self.ip = "127.0.0.1" + self.port = 5555 + + # setup logger + self.logger = utils_logging.getLogger("unittests") + + # random datebase path name without dashes + database_path = self.database_path + self.s = server.AEPsychBackgroundServer( + database_path=database_path, host=self.ip, port=self.port + ) + self.db_name = database_path.split("/")[1] + self.db_path = database_path + + # Writer will be made in tests + self.writer = None + + async def asyncTearDown(self): + # Stops the client + if self.writer is not None: + self.writer.close() + + time.sleep(0.1) + + # cleanup the db + db_path = Path(self.db_path) + try: + print(db_path) + db_path.unlink() + except PermissionError as e: + print("Failed to deleted database: ", e) + + async def test_background_server(self): + self.assertIsNone(self.s.background_process) + self.s.start() + self.assertTrue(self.s.background_process.is_alive()) + + # Make a client + try_again = True + attempts = 0 + while try_again: + try_again = False + attempts += 1 + try: + reader, self.writer = await asyncio.open_connection(self.ip, self.port) + except ConnectionRefusedError: + if attempts > 10: + raise ConnectionRefusedError + try_again = True + time.sleep(1) + + async def _mock_client(request: Dict[str, Any]) -> Any: + self.writer.write(json.dumps(request).encode()) + await self.writer.drain() + + response = await reader.read(1024 * 512) + return response.decode() + + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + ask_request = {"type": "ask", "message": ""} + tell_request = { + "type": "tell", + "message": {"config": {"x": [0.5]}, "outcome": 1}, + "extra_info": {}, + } + + await _mock_client(setup_request) + + expected_x = [0, 1, 2, 3] + expected_z = list(reversed(expected_x)) + expected_y = [x % 2 for x in expected_x] + i = 0 + while True: + response = await _mock_client(ask_request) + response = json.loads(response) + tell_request["message"]["config"]["x"] = [expected_x[i]] + tell_request["message"]["config"]["z"] = [expected_z[i]] + tell_request["message"]["outcome"] = expected_y[i] + tell_request["extra_info"]["e1"] = 1 + tell_request["extra_info"]["e2"] = 2 + i = i + 1 + await _mock_client(tell_request) + + if response["is_finished"]: + break + + self.s.stop() + self.assertIsNone(self.s.background_process) + + # Create a synchronous server to check db contents + s = server.AEPsychServer(database_path=self.db_path) + unique_id = s.db.get_master_records()[-1].unique_id + out_df = s.get_dataframe_from_replay(unique_id) + self.assertTrue((out_df.x == expected_x).all()) + self.assertTrue((out_df.z == expected_z).all()) + self.assertTrue((out_df.response == expected_y).all()) + self.assertTrue((out_df.e1 == [1] * 4).all()) + self.assertTrue((out_df.e2 == [2] * 4).all()) + self.assertTrue("post_mean" in out_df.columns) + self.assertTrue("post_var" in out_df.columns) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_datafetcher.py b/tests/test_datafetcher.py index ffe3f1980..af655ef36 100644 --- a/tests/test_datafetcher.py +++ b/tests/test_datafetcher.py @@ -96,10 +96,7 @@ def pre_seed_config( def setUp(self): # setup logger - server.logger = utils_logging.getLogger(logging.DEBUG, "logs") - - # random port - socket = server.sockets.PySocket(port=0) + server.logger = utils_logging.getLogger("logs") database_path = Path(__file__).parent / "test_databases" / "1000_outcome.db" @@ -109,7 +106,7 @@ def setUp(self): time.sleep(0.1) self.assertTrue(dst_db_path.is_file()) - self.s = server.AEPsychServer(socket=socket, database_path=dst_db_path) + self.s = server.AEPsychServer(database_path=dst_db_path) setup_message = { "type": "setup", @@ -125,8 +122,6 @@ def setUp(self): def tearDown(self): time.sleep(0.1) - - self.s.cleanup() self.s.db.delete_db() def test_create_from_config(self): diff --git a/tests/test_db.py b/tests/test_db.py index ec44dca58..e5f2cf2d3 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -31,11 +31,6 @@ def tearDown(self): time.sleep(0.1) self._database.delete_db() - def test_db_create(self): - engine = self._database.get_engine() - self.assertIsNotNone(engine) - self.assertIsNotNone(self._database._engine) - def test_record_setup_basic(self): master_table = self._database.record_setup( description="test description", @@ -115,7 +110,8 @@ def test_update_db(self): name="test name", request={"test": "this is a test request"}, ) - test_database._session.rollback() + with test_database.session() as session: + session.rollback() test_database.perform_updates() # retry adding rows @@ -174,38 +170,32 @@ def test_update_db_with_raw_data_tables(self): outcome_dict_expected[i]["outcome_1"] = outcomes[i - 1][1] # Check that the number of entries in each table is correct - n_iterations = ( - test_database.get_engine() - .execute("SELECT COUNT(*) FROM raw_data") - .fetchone()[0] - ) + n_iterations = test_database.session.execute( + "SELECT COUNT(*) FROM raw_data" + ).fetchone()[0] self.assertEqual(n_iterations, 7) - n_params = ( - test_database.get_engine() - .execute("SELECT COUNT(*) FROM param_data") - .fetchone()[0] - ) + n_params = test_database.session.execute( + "SELECT COUNT(*) FROM param_data" + ).fetchone()[0] self.assertEqual(n_params, 28) - n_outcomes = ( - test_database.get_engine() - .execute("SELECT COUNT(*) FROM outcome_data") - .fetchone()[0] - ) + n_outcomes = test_database.session.execute( + "SELECT COUNT(*) FROM outcome_data" + ).fetchone()[0] self.assertEqual(n_outcomes, 14) # Check that the data is correct - param_data = ( - test_database.get_engine().execute("SELECT * FROM param_data").fetchall() - ) + param_data = test_database.session.execute( + "SELECT * FROM param_data" + ).fetchall() param_dict = {x: {} for x in range(1, 8)} for param in param_data: param_dict[param.iteration_id][param.param_name] = float(param.param_value) self.assertEqual(param_dict, param_dict_expected) - outcome_data = ( - test_database.get_engine().execute("SELECT * FROM outcome_data").fetchall() - ) + outcome_data = test_database.session.execute( + "SELECT * FROM outcome_data" + ).fetchall() outcome_dict = {x: {} for x in range(1, 8)} for outcome in outcome_data: outcome_dict[outcome.iteration_id][outcome.outcome_name] = ( @@ -215,13 +205,9 @@ def test_update_db_with_raw_data_tables(self): self.assertEqual(outcome_dict, outcome_dict_expected) # Check if we have the extra_data column - pragma = ( - test_database.get_engine() - .execute( - "SELECT * FROM pragma_table_info('raw_data') WHERE name='extra_data'" - ) - .fetchall() - ) + pragma = test_database.session.execute( + "SELECT * FROM pragma_table_info('raw_data') WHERE name='extra_data'" + ).fetchall() self.assertTrue(len(pragma) == 1) # Make sure that update is no longer required @@ -245,10 +231,6 @@ def test_update_db_with_raw_extra_data(self): # open the new db test_database = db.Database(db_path=dst_db_path.as_posix(), update=False) - replay_tells = [ - row for row in test_database.get_replay_for(1) if row.message_type == "tell" - ] - # Make sure that update is required self.assertTrue(test_database.is_update_required())