Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ requirements:
- pydicom
- tensorflow # [not win]
- tensorflow-probability # [not win]
- tf-keras # [not win]

test:
source_files:
Expand Down
37 changes: 14 additions & 23 deletions .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ jobs:
with:
fetch-depth: 0

- uses: astral-sh/setup-uv@v7

- name: Build SDist
run: pipx run build --sdist
run: uv build --sdist

- uses: actions/upload-artifact@v4
with:
Expand All @@ -40,13 +42,15 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-13, macos-14]
os: [ubuntu-latest, windows-latest, macos-15-intel, macos-14]
fail-fast: true
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0

- uses: astral-sh/setup-uv@v7

# Use intel mpi on windows
- uses: mpi4py/setup-mpi@v1
if: ${{ contains(matrix.os, 'windows') }}
Expand Down Expand Up @@ -83,13 +87,8 @@ jobs:
cmake --build build --parallel
cmake --install build

- uses: actions/setup-python@v5

- name: Install cibuildwheel
run: python -m pip install cibuildwheel

- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
run: uvx cibuildwheel --output-dir wheelhouse

- uses: actions/upload-artifact@v4
with:
Expand All @@ -104,30 +103,23 @@ jobs:
shell: bash -leo pipefail {0}
strategy:
matrix:
os: [ ubuntu-latest, macos-13, macos-latest ]
python-version: [ '3.9', '3.10', '3.11', '3.12' ]
os: [ ubuntu-latest, macos-15-intel, macos-latest ]
python-version: [ '3.10', '3.11', '3.12' ]
fail-fast: false
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0

- uses: astral-sh/setup-uv@v7

# Can't figure out a way to get the package version from setuptools_scm inside the conda build
# We need to install setuptools_scm, call it as a module, and store the version in an environment variable
- name: Run setuptools_scm to get package version and store in environment variable BRAINIAK_VERSION (Linux\Mac)
if: ${{ !contains(matrix.os, 'windows') }}
# We need to call it as a module and store the version in an environment variable
- name: Run setuptools_scm to get package version and store in environment variable BRAINIAK_VERSION
run: |
python -m pip install setuptools_scm
export BRAINIAK_VERSION=$(python -m setuptools_scm)
BRAINIAK_VERSION=$(uvx --from setuptools-scm python -m setuptools_scm)
echo "BRAINIAK_VERSION=${BRAINIAK_VERSION}" >> "$GITHUB_ENV"

- name: Run setuptools_scm to get package version and store in environment variable BRAINIAK_VERSION (Windows)
if: ${{ contains(matrix.os, 'windows') }}
run: |
python -m pip install setuptools_scm
set BRAINIAK_VERSION=$(python -m setuptools_scm)
echo "BRAINIAK_VERSION=${BRAINIAK_VERSION}" >> "$GITHUB_ENV"

- name: Setup micromamba and boa
uses: mamba-org/setup-micromamba@v1
with:
Expand Down Expand Up @@ -213,4 +205,3 @@ jobs:
echo "Uploading $file"
anaconda upload "$file"
done

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ before-all = [

[tool.cibuildwheel.linux.environment]
PATH = "/usr/lib64/mpich/bin:$PATH"
LD_LIBRARY_PATH = "/usr/lib64/mpich/lib:$LD_LIBRARY_PATH"

[tool.coverage.run]
source = ["brainiak"]
Expand Down
10 changes: 6 additions & 4 deletions src/brainiak/factoranalysis/htfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,17 @@ def _map_update_posterior(self):
from_sym_2_tri(posterior_cov)

# widths
prior_width_mean_var = prior_widths_mean_var[k].item()
prior_width = prior_widths[k].item()
common = 1.0 /\
(prior_widths_mean_var[k] + self.global_widths_var_scaled)
(prior_width_mean_var + self.global_widths_var_scaled)
observation_mean = np.mean(next_widths)
tmp = common * self.global_widths_var_scaled
self.global_posterior_[self.map_offset[1].item() + k] = \
prior_widths_mean_var[k] * common * observation_mean +\
tmp * prior_widths[k]
prior_width_mean_var * common * observation_mean +\
tmp * prior_width
self.global_posterior_[self.map_offset[3].item() + k] = \
prior_widths_mean_var[k] * tmp
prior_width_mean_var * tmp

return self

Expand Down
75 changes: 45 additions & 30 deletions tests/pytest_mpiexec_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,44 @@ def consolidate_reports(nodeid, reports, style=ReportStyle.first_failure):
return reports


def _as_text(output):
if isinstance(output, bytes):
return output.decode("utf8", "replace")
return output


def _collect_reports(reportlog_dir, n):
reports = {}
for rank in range(n):
reportlog_file = os.path.join(reportlog_dir, f"reportlog-{rank}.jsonl")
if os.path.exists(reportlog_file):
with open(reportlog_file) as f:
for line in f:
report = json.loads(line)
if report["$report_type"] != "TestReport":
continue
report["_mpi_rank"] = rank
nodeid = report["nodeid"]
reports.setdefault(nodeid, []).append(report)

for nodeid, report_list in reports.items():
reports[nodeid] = consolidate_reports(
nodeid, report_list, REPORT_STYLE)

return reports


def _replay_reports(item, reports):
for report in chain(*reports.values()):
if report["$report_type"] == "TestReport":
# reconstruct and redisplay the report
r = item.config.hook.pytest_report_from_serializable(
config=item.config, data=report
)
item.config.hook.pytest_runtest_logreport(
config=item.config, report=r)


def mpi_runtest(item):
"""Replacement for runtest

Expand Down Expand Up @@ -261,50 +299,27 @@ def mpi_runtest(item):
timeout=timeout,
)
except subprocess.TimeoutExpired as e:
reports = _collect_reports(reportlog_dir, n)
if reports:
_replay_reports(item, reports)
if e.stdout:
item.add_report_section(
"mpiexec pytest", "stdout",
e.stdout.decode("utf8", "replace")
_as_text(e.stdout)
)
if e.stderr:
item.add_report_section(
"mpiexec pytest", "stderr",
e.stderr.decode("utf8", "replace")
_as_text(e.stderr)
)
pytest.fail(
f"mpi test did not complete in {timeout} seconds",
pytrace=False,
)

# Collect logs from all ranks
reports = {}
for rank in range(n):
reportlog_file = os.path.join(reportlog_dir,
f"reportlog-{rank}.jsonl")
if os.path.exists(reportlog_file):
with open(reportlog_file) as f:
for line in f:
report = json.loads(line)
if report["$report_type"] != "TestReport":
continue
report["_mpi_rank"] = rank
nodeid = report["nodeid"]
reports.setdefault(nodeid, []).append(report)

for nodeid, report_list in reports.items():
# consolidate reports according to config
reports[nodeid] = consolidate_reports(
nodeid, report_list, REPORT_STYLE)

# collect report items for the test
for report in chain(*reports.values()):
if report["$report_type"] == "TestReport":
# reconstruct and redisplay the report
r = item.config.hook.pytest_report_from_serializable(
config=item.config, data=report
)
item.config.hook.pytest_runtest_logreport(
config=item.config, report=r)
reports = _collect_reports(reportlog_dir, n)
_replay_reports(item, reports)

if p.returncode or not reports:
if p.stdout:
Expand Down
Loading