diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 938704b1..424506ff 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -63,6 +63,7 @@ requirements: - pydicom - tensorflow # [not win] - tensorflow-probability # [not win] + - tf-keras # [not win] test: source_files: diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 714f1911..f59149fa 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -27,8 +27,10 @@ jobs: with: fetch-depth: 0 + - uses: astral-sh/setup-uv@v7 + - name: Build SDist - run: pipx run build --sdist + run: uv build --sdist - uses: actions/upload-artifact@v4 with: @@ -40,13 +42,15 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-13, macos-14] + os: [ubuntu-latest, windows-latest, macos-15-intel, macos-14] fail-fast: true steps: - uses: actions/checkout@v4 with: fetch-depth: 0 + - uses: astral-sh/setup-uv@v7 + # Use intel mpi on windows - uses: mpi4py/setup-mpi@v1 if: ${{ contains(matrix.os, 'windows') }} @@ -83,13 +87,8 @@ jobs: cmake --build build --parallel cmake --install build - - uses: actions/setup-python@v5 - - - name: Install cibuildwheel - run: python -m pip install cibuildwheel - - name: Build wheels - run: python -m cibuildwheel --output-dir wheelhouse + run: uvx cibuildwheel --output-dir wheelhouse - uses: actions/upload-artifact@v4 with: @@ -104,30 +103,23 @@ jobs: shell: bash -leo pipefail {0} strategy: matrix: - os: [ ubuntu-latest, macos-13, macos-latest ] - python-version: [ '3.9', '3.10', '3.11', '3.12' ] + os: [ ubuntu-latest, macos-15-intel, macos-latest ] + python-version: [ '3.10', '3.11', '3.12' ] fail-fast: false steps: - uses: actions/checkout@v4 with: fetch-depth: 0 + - uses: astral-sh/setup-uv@v7 + # Can't figure out a way to get the package version from setuptools_scm inside the conda build - # We need to install setuptools_scm, call it as a module, and store the version in an environment variable - - name: Run setuptools_scm to get package version and store in environment variable BRAINIAK_VERSION (Linux\Mac) - if: ${{ !contains(matrix.os, 'windows') }} + # We need to call it as a module and store the version in an environment variable + - name: Run setuptools_scm to get package version and store in environment variable BRAINIAK_VERSION run: | - python -m pip install setuptools_scm - export BRAINIAK_VERSION=$(python -m setuptools_scm) + BRAINIAK_VERSION=$(uvx --from setuptools-scm python -m setuptools_scm) echo "BRAINIAK_VERSION=${BRAINIAK_VERSION}" >> "$GITHUB_ENV" - - name: Run setuptools_scm to get package version and store in environment variable BRAINIAK_VERSION (Windows) - if: ${{ contains(matrix.os, 'windows') }} - run: | - python -m pip install setuptools_scm - set BRAINIAK_VERSION=$(python -m setuptools_scm) - echo "BRAINIAK_VERSION=${BRAINIAK_VERSION}" >> "$GITHUB_ENV" - - name: Setup micromamba and boa uses: mamba-org/setup-micromamba@v1 with: @@ -213,4 +205,3 @@ jobs: echo "Uploading $file" anaconda upload "$file" done - diff --git a/pyproject.toml b/pyproject.toml index c32bcbab..9cfd856f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ before-all = [ [tool.cibuildwheel.linux.environment] PATH = "/usr/lib64/mpich/bin:$PATH" +LD_LIBRARY_PATH = "/usr/lib64/mpich/lib:$LD_LIBRARY_PATH" [tool.coverage.run] source = ["brainiak"] diff --git a/src/brainiak/factoranalysis/htfa.py b/src/brainiak/factoranalysis/htfa.py index a177cc48..c4706791 100644 --- a/src/brainiak/factoranalysis/htfa.py +++ b/src/brainiak/factoranalysis/htfa.py @@ -328,15 +328,17 @@ def _map_update_posterior(self): from_sym_2_tri(posterior_cov) # widths + prior_width_mean_var = prior_widths_mean_var[k].item() + prior_width = prior_widths[k].item() common = 1.0 /\ - (prior_widths_mean_var[k] + self.global_widths_var_scaled) + (prior_width_mean_var + self.global_widths_var_scaled) observation_mean = np.mean(next_widths) tmp = common * self.global_widths_var_scaled self.global_posterior_[self.map_offset[1].item() + k] = \ - prior_widths_mean_var[k] * common * observation_mean +\ - tmp * prior_widths[k] + prior_width_mean_var * common * observation_mean +\ + tmp * prior_width self.global_posterior_[self.map_offset[3].item() + k] = \ - prior_widths_mean_var[k] * tmp + prior_width_mean_var * tmp return self diff --git a/tests/pytest_mpiexec_plugin.py b/tests/pytest_mpiexec_plugin.py index 6ee27541..b82f6043 100644 --- a/tests/pytest_mpiexec_plugin.py +++ b/tests/pytest_mpiexec_plugin.py @@ -219,6 +219,44 @@ def consolidate_reports(nodeid, reports, style=ReportStyle.first_failure): return reports +def _as_text(output): + if isinstance(output, bytes): + return output.decode("utf8", "replace") + return output + + +def _collect_reports(reportlog_dir, n): + reports = {} + for rank in range(n): + reportlog_file = os.path.join(reportlog_dir, f"reportlog-{rank}.jsonl") + if os.path.exists(reportlog_file): + with open(reportlog_file) as f: + for line in f: + report = json.loads(line) + if report["$report_type"] != "TestReport": + continue + report["_mpi_rank"] = rank + nodeid = report["nodeid"] + reports.setdefault(nodeid, []).append(report) + + for nodeid, report_list in reports.items(): + reports[nodeid] = consolidate_reports( + nodeid, report_list, REPORT_STYLE) + + return reports + + +def _replay_reports(item, reports): + for report in chain(*reports.values()): + if report["$report_type"] == "TestReport": + # reconstruct and redisplay the report + r = item.config.hook.pytest_report_from_serializable( + config=item.config, data=report + ) + item.config.hook.pytest_runtest_logreport( + config=item.config, report=r) + + def mpi_runtest(item): """Replacement for runtest @@ -261,15 +299,18 @@ def mpi_runtest(item): timeout=timeout, ) except subprocess.TimeoutExpired as e: + reports = _collect_reports(reportlog_dir, n) + if reports: + _replay_reports(item, reports) if e.stdout: item.add_report_section( "mpiexec pytest", "stdout", - e.stdout.decode("utf8", "replace") + _as_text(e.stdout) ) if e.stderr: item.add_report_section( "mpiexec pytest", "stderr", - e.stderr.decode("utf8", "replace") + _as_text(e.stderr) ) pytest.fail( f"mpi test did not complete in {timeout} seconds", @@ -277,34 +318,8 @@ def mpi_runtest(item): ) # Collect logs from all ranks - reports = {} - for rank in range(n): - reportlog_file = os.path.join(reportlog_dir, - f"reportlog-{rank}.jsonl") - if os.path.exists(reportlog_file): - with open(reportlog_file) as f: - for line in f: - report = json.loads(line) - if report["$report_type"] != "TestReport": - continue - report["_mpi_rank"] = rank - nodeid = report["nodeid"] - reports.setdefault(nodeid, []).append(report) - - for nodeid, report_list in reports.items(): - # consolidate reports according to config - reports[nodeid] = consolidate_reports( - nodeid, report_list, REPORT_STYLE) - - # collect report items for the test - for report in chain(*reports.values()): - if report["$report_type"] == "TestReport": - # reconstruct and redisplay the report - r = item.config.hook.pytest_report_from_serializable( - config=item.config, data=report - ) - item.config.hook.pytest_runtest_logreport( - config=item.config, report=r) + reports = _collect_reports(reportlog_dir, n) + _replay_reports(item, reports) if p.returncode or not reports: if p.stdout: