Skip to content
56 changes: 56 additions & 0 deletions mava/core_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,59 @@ def launch(
are primarily for debugging
name : name of the system
"""


class SystemBuilder(abc.ABC):
    """Abstract system builder.

    Declares the interface a concrete builder has to provide: a data server,
    a parameter server, executors, trainers, and the `build` / `launch`
    steps. Every method is abstract, so a subclass can only be instantiated
    once it overrides all of them; otherwise instantiation raises `TypeError`.
    """

    @abc.abstractmethod
    def data_server(self) -> List[Any]:
        """Data server to store and serve transition data from and to system.

        Returns:
            System data server, returned as a list
        """

    @abc.abstractmethod
    def parameter_server(self) -> Any:
        """Parameter server to store and serve system network parameters.

        Returns:
            System parameter server
        """

    @abc.abstractmethod
    def executor(
        self, executor_id: str, data_server_client: Any, parameter_server_client: Any
    ) -> Any:
        """Executor, a collection of agents in an environment to gather experience.

        Args:
            executor_id : id to identify the executor process for logging purposes
            data_server_client : data server client for pushing transition data
            parameter_server_client : parameter server client for pulling parameters

        Returns:
            System executor
        """

    @abc.abstractmethod
    def trainer(
        self, trainer_id: str, data_server_client: Any, parameter_server_client: Any
    ) -> Any:
        """Trainer, a system process for updating agent specific network parameters.

        Args:
            trainer_id : id to identify the trainer process for logging purposes
            data_server_client : data server client for pulling transition data
            parameter_server_client : parameter server client for pushing parameters

        Returns:
            System trainer
        """

    @abc.abstractmethod
    def build(self) -> None:
        """Construct program nodes."""

    @abc.abstractmethod
    def launch(self) -> None:
        """Run the graph program."""
25 changes: 22 additions & 3 deletions mava/core_jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

"""Tests for core Mava interfaces for Jax systems."""

from typing import Any
from typing import Any, List

import pytest

from mava.core_jax import BaseSystem
from mava.core_jax import BaseSystem, SystemBuilder


def test_exception_for_incomplete_child_class() -> None:
def test_exception_for_incomplete_child_system_class() -> None:
"""Test if error is thrown for missing abstract class overwrites."""
with pytest.raises(TypeError):

Expand All @@ -41,3 +41,22 @@ def configure(self, **kwargs: Any) -> None:
pass

TestIncompleteDummySystem() # type: ignore


def test_exception_for_incomplete_child_builder_class() -> None:
    """Test that a builder missing abstract method overrides cannot be built."""

    # Only `data_server` and `executor` are overridden here; the remaining
    # abstract methods of `SystemBuilder` are deliberately left out.
    class TestIncompleteDummySystemBuilder(SystemBuilder):
        def data_server(self) -> List[Any]:
            pass

        def executor(
            self,
            executor_id: str,
            data_server_client: Any,
            parameter_server_client: Any,
        ) -> Any:
            pass

    # Keep only the instantiation inside the context so that a TypeError
    # raised while *defining* the class cannot make this test pass, and match
    # on the message so an unrelated TypeError is not accepted either.
    with pytest.raises(TypeError, match="abstract"):
        TestIncompleteDummySystemBuilder()  # type: ignore