19 changes: 14 additions & 5 deletions extension/core_functions/aggregate/holistic/mad.cpp
@@ -190,12 +190,21 @@ struct MedianAbsoluteDeviationOperation : QuantileOperation {
}

template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask,
AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result,
idx_t ridx, const STATE *gstate) {
static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
idx_t ridx) {
auto &state = *reinterpret_cast<STATE *>(l_state);
auto gstate = reinterpret_cast<const STATE *>(g_state);

D_ASSERT(partition.inputs);
const auto &inputs = *partition.inputs;
D_ASSERT(inputs.ColumnCount() == 1);
auto &data = state.GetOrCreateWindowCursor(inputs, partition.all_valid);
const auto &fmask = partition.filter_mask;

auto rdata = FlatVector::GetData<RESULT_TYPE>(result);

QuantileIncluded included(fmask, dmask);
QuantileIncluded<INPUT_TYPE> included(fmask, data);
const auto n = FrameSize(included, frames);

if (!n) {
@@ -263,7 +272,7 @@ AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const Logical
fun.bind = BindMAD;
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
fun.window = AggregateFunction::UnaryWindow<STATE, INPUT_TYPE, TARGET_TYPE, OP>;
fun.window = OP::template Window<STATE, INPUT_TYPE, TARGET_TYPE>;
fun.window_init = OP::template WindowInit<STATE, INPUT_TYPE>;
#endif
return fun;
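
For reference, a minimal standalone sketch of the counting pattern the rewritten Window() body relies on: an inclusion predicate combining the partition's filter mask with value validity, summed over the sub-frames. The names below (SubFrame, the lambda, FrameSize's exact shape) are invented for illustration and are not DuckDB API.

// Sketch (plain C++, hypothetical names) of predicate-filtered frame counting.
#include <cstddef>
#include <utility>
#include <vector>

using idx_t = std::size_t;
// A sub-frame is a half-open [start, end) row range.
using SubFrame = std::pair<idx_t, idx_t>;

template <class INCLUDED>
static idx_t FrameSize(INCLUDED &included, const std::vector<SubFrame> &frames) {
	idx_t n = 0;
	for (const auto &frame : frames) {
		for (idx_t i = frame.first; i < frame.second; ++i) {
			n += included(i); // count rows that pass both the filter mask and validity
		}
	}
	return n;
}

int main() {
	// Rows 0..9; pretend rows 3 and 7 are filtered out and row 5 is NULL.
	auto included = [](idx_t i) { return i != 3 && i != 7 && i != 5; };
	std::vector<SubFrame> frames {{0, 4}, {6, 10}};
	return FrameSize(included, frames) == 6 ? 0 : 1;
}
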
95 changes: 76 additions & 19 deletions extension/core_functions/aggregate/holistic/mode.cpp
@@ -2,6 +2,7 @@
#include "duckdb/common/uhugeint.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/types/column/column_data_collection.hpp"
#include "core_functions/aggregate/distributive_functions.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
@@ -100,6 +101,17 @@ struct ModeState {
bool valid = false;
size_t count = 0;

//! The collection being read
const ColumnDataCollection *inputs;
//! The state used for reading the collection on this thread
ColumnDataScanState scan;
//! The data chunk being paged into
DataChunk page;
//! The data pointer
const KEY_TYPE *data = nullptr;
//! The validity mask
const ValidityMask *validity = nullptr;

~ModeState() {
if (frequency_map) {
delete frequency_map;
@@ -109,6 +121,43 @@
}
}

void InitializePage(const ColumnDataCollection &inputs) {
if (page.ColumnCount() == 0) {
this->inputs = &inputs;
inputs.InitializeScan(scan);
inputs.InitializeScanChunk(scan, page);
}
}

inline sel_t RowOffset(idx_t row_idx) const {
D_ASSERT(RowIsVisible(row_idx));
return UnsafeNumericCast<sel_t>(row_idx - scan.current_row_index);
}

inline bool RowIsVisible(idx_t row_idx) const {
return (row_idx < scan.next_row_index && scan.current_row_index <= row_idx);
}

inline idx_t Seek(idx_t row_idx) {
if (!RowIsVisible(row_idx)) {
D_ASSERT(inputs);
inputs->Seek(row_idx, scan, page);
data = FlatVector::GetData<KEY_TYPE>(page.data[0]);
validity = &FlatVector::Validity(page.data[0]);
}
return RowOffset(row_idx);
}

inline const KEY_TYPE &GetCell(idx_t row_idx) {
const auto offset = Seek(row_idx);
return data[offset];
}

inline bool RowIsValid(idx_t row_idx) {
const auto offset = Seek(row_idx);
return validity->RowIsValid(offset);
}

void Reset() {
if (frequency_map) {
frequency_map->clear();
@@ -118,7 +167,8 @@ struct ModeState {
valid = false;
}

void ModeAdd(const KEY_TYPE &key, idx_t row) {
void ModeAdd(idx_t row) {
const auto &key = GetCell(row);
auto &attr = (*frequency_map)[key];
auto new_count = (attr.count += 1);
if (new_count == 1) {
@@ -138,7 +188,8 @@
}
}

void ModeRm(const KEY_TYPE &key, idx_t frame) {
void ModeRm(idx_t frame) {
const auto &key = GetCell(frame);
auto &attr = (*frequency_map)[key];
auto old_count = attr.count;
nonzero -= size_t(old_count == 1);
@@ -164,16 +215,16 @@ struct ModeState {
}
};

template <typename STATE>
struct ModeIncluded {
inline explicit ModeIncluded(const ValidityMask &fmask_p, const ValidityMask &dmask_p)
: fmask(fmask_p), dmask(dmask_p) {
inline explicit ModeIncluded(const ValidityMask &fmask_p, STATE &state) : fmask(fmask_p), state(state) {
}

inline bool operator()(const idx_t &idx) const {
return fmask.RowIsValid(idx) && dmask.RowIsValid(idx);
return fmask.RowIsValid(idx) && state.RowIsValid(idx);
}
const ValidityMask &fmask;
const ValidityMask &dmask;
STATE &state;
};

template <typename TYPE_OP>
@@ -261,11 +312,9 @@ struct ModeFunction : TypedModeFunction<TYPE_OP> {
template <typename STATE, typename INPUT_TYPE>
struct UpdateWindowState {
STATE &state;
const INPUT_TYPE *data;
ModeIncluded &included;
ModeIncluded<STATE> &included;

inline UpdateWindowState(STATE &state, const INPUT_TYPE *data, ModeIncluded &included)
: state(state), data(data), included(included) {
inline UpdateWindowState(STATE &state, ModeIncluded<STATE> &included) : state(state), included(included) {
}

inline void Neither(idx_t begin, idx_t end) {
@@ -274,15 +323,15 @@ struct ModeFunction : TypedModeFunction<TYPE_OP> {
inline void Left(idx_t begin, idx_t end) {
for (; begin < end; ++begin) {
if (included(begin)) {
state.ModeRm(data[begin], begin);
state.ModeRm(begin);
}
}
}

inline void Right(idx_t begin, idx_t end) {
for (; begin < end; ++begin) {
if (included(begin)) {
state.ModeAdd(data[begin], begin);
state.ModeAdd(begin);
}
}
}
@@ -292,17 +341,25 @@ struct ModeFunction : TypedModeFunction<TYPE_OP> {
};

template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask,
AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result,
idx_t rid, const STATE *gstate) {
static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
idx_t rid) {
auto &state = *reinterpret_cast<STATE *>(l_state);

D_ASSERT(partition.inputs);
const auto &inputs = *partition.inputs;
D_ASSERT(inputs.ColumnCount() == 1);
state.InitializePage(inputs);
const auto &fmask = partition.filter_mask;

auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
auto &rmask = FlatVector::Validity(result);
auto &prevs = state.prevs;
if (prevs.empty()) {
prevs.resize(1);
}

ModeIncluded included(fmask, dmask);
ModeIncluded<STATE> included(fmask, state);

if (!state.frequency_map) {
state.frequency_map = TYPE_OP::CreateEmpty(Allocator::DefaultAllocator());
@@ -315,13 +372,13 @@ struct ModeFunction : TypedModeFunction<TYPE_OP> {
for (const auto &frame : frames) {
for (auto i = frame.start; i < frame.end; ++i) {
if (included(i)) {
state.ModeAdd(data[i], i);
state.ModeAdd(i);
}
}
}
} else {
using Updater = UpdateWindowState<STATE, INPUT_TYPE>;
Updater updater(state, data, included);
Updater updater(state, included);
AggregateExecutor::IntersectFrames(prevs, frames, updater);
}

@@ -380,7 +437,7 @@ AggregateFunction GetTypedModeFunction(const LogicalType &type) {
using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
using OP = ModeFunction<TYPE_OP>;
auto func = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP>(type, type);
func.window = AggregateFunction::UnaryWindow<STATE, INPUT_TYPE, INPUT_TYPE, OP>;
func.window = OP::template Window<STATE, INPUT_TYPE, INPUT_TYPE>;
return func;
}

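
For reference, a standalone sketch of the paged-cursor idea behind the new ModeState members (RowIsVisible/Seek/GetCell): keep one chunk of the backing collection resident and only re-seek when a requested row falls outside it. The container and class names below are invented for illustration; the real code scans a ColumnDataCollection into a DataChunk via ColumnDataScanState.

// Sketch (plain C++, invented names) of paged random access with a cached window.
#include <cassert>
#include <cstddef>
#include <vector>

using idx_t = std::size_t;

template <class T>
class PagedCursor {
public:
	explicit PagedCursor(const std::vector<std::vector<T>> &chunks) : chunks(chunks) {}

	bool RowIsVisible(idx_t row) const {
		return current_row <= row && row < next_row;
	}

	// Ensure the chunk containing `row` is resident, return the offset inside it.
	idx_t Seek(idx_t row) {
		if (!RowIsVisible(row)) {
			idx_t begin = 0;
			for (const auto &chunk : chunks) {
				if (row < begin + chunk.size()) {
					page = &chunk;
					current_row = begin;
					next_row = begin + chunk.size();
					break;
				}
				begin += chunk.size();
			}
		}
		assert(page && RowIsVisible(row));
		return row - current_row;
	}

	const T &GetCell(idx_t row) {
		return (*page)[Seek(row)];
	}

private:
	const std::vector<std::vector<T>> &chunks;
	const std::vector<T> *page = nullptr;
	idx_t current_row = 0;
	idx_t next_row = 0; // exclusive; nothing is visible until the first Seek
};

int main() {
	std::vector<std::vector<int>> chunks {{10, 11, 12}, {13, 14}};
	PagedCursor<int> cursor(chunks);
	return cursor.GetCell(4) == 14 && cursor.GetCell(1) == 11 ? 0 : 1;
}
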
42 changes: 30 additions & 12 deletions extension/core_functions/aggregate/holistic/quantile.cpp
@@ -208,10 +208,19 @@ struct QuantileScalarOperation : public QuantileOperation {
}

template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask,
AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result,
idx_t ridx, const STATE *gstate) {
QuantileIncluded included(fmask, dmask);
static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
idx_t ridx) {
auto &state = *reinterpret_cast<STATE *>(l_state);
auto gstate = reinterpret_cast<const STATE *>(g_state);

D_ASSERT(partition.inputs);
const auto &inputs = *partition.inputs;
D_ASSERT(inputs.ColumnCount() == 1);
auto &data = state.GetOrCreateWindowCursor(inputs, partition.all_valid);
const auto &fmask = partition.filter_mask;

QuantileIncluded<INPUT_TYPE> included(fmask, data);
const auto n = FrameSize(included, frames);

D_ASSERT(aggr_input_data.bind_data);
@@ -305,13 +314,22 @@ struct QuantileListOperation : QuantileOperation {
}

template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask,
AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &list,
idx_t lidx, const STATE *gstate) {
static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &list,
idx_t lidx) {
auto &state = *reinterpret_cast<STATE *>(l_state);
auto gstate = reinterpret_cast<const STATE *>(g_state);

D_ASSERT(partition.inputs);
const auto &inputs = *partition.inputs;
D_ASSERT(inputs.ColumnCount() == 1);
auto &data = state.GetOrCreateWindowCursor(inputs, partition.all_valid);
const auto &fmask = partition.filter_mask;

D_ASSERT(aggr_input_data.bind_data);
auto &bind_data = aggr_input_data.bind_data->Cast<QuantileBindData>();

QuantileIncluded included(fmask, dmask);
QuantileIncluded<INPUT_TYPE> included(fmask, data);
const auto n = FrameSize(included, frames);

// Result is a constant LIST<RESULT_TYPE> with a fixed length
@@ -410,7 +428,7 @@ struct ScalarDiscreteQuantile {
using OP = QuantileScalarOperation<true>;
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP>(type, type);
#ifndef DUCKDB_SMALLER_BINARY
fun.window = AggregateFunction::UnaryWindow<STATE, INPUT_TYPE, INPUT_TYPE, OP>;
fun.window = OP::Window<STATE, INPUT_TYPE, INPUT_TYPE>;
fun.window_init = OP::WindowInit<STATE, INPUT_TYPE>;
#endif
return fun;
@@ -447,7 +465,7 @@ struct ListDiscreteQuantile {
auto fun = QuantileListAggregate<STATE, INPUT_TYPE, list_entry_t, OP>(type, type);
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
fun.window = AggregateFunction::UnaryWindow<STATE, INPUT_TYPE, list_entry_t, OP>;
fun.window = OP::template Window<STATE, INPUT_TYPE, list_entry_t>;
fun.window_init = OP::template WindowInit<STATE, INPUT_TYPE>;
#endif
return fun;
@@ -538,7 +556,7 @@ struct ScalarContinuousQuantile {
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP>(input_type, target_type);
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
fun.window = AggregateFunction::UnaryWindow<STATE, INPUT_TYPE, TARGET_TYPE, OP>;
fun.window = OP::template Window<STATE, INPUT_TYPE, TARGET_TYPE>;
fun.window_init = OP::template WindowInit<STATE, INPUT_TYPE>;
#endif
return fun;
@@ -553,7 +571,7 @@ struct ListContinuousQuantile {
auto fun = QuantileListAggregate<STATE, INPUT_TYPE, list_entry_t, OP>(input_type, target_type);
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
fun.window = AggregateFunction::UnaryWindow<STATE, INPUT_TYPE, list_entry_t, OP>;
fun.window = OP::template Window<STATE, INPUT_TYPE, list_entry_t>;
fun.window_init = OP::template WindowInit<STATE, INPUT_TYPE>;
#endif
return fun;
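
For reference, a minimal sketch of the registration pattern these hunks repeat: each operation now exposes its own static Window member template, and the dependent-type spelling OP::template Window<...> is what gets stored in the aggregate's function-pointer slot. The types below are invented stand-ins, not DuckDB's AggregateFunction.

// Sketch (plain C++, hypothetical names) of storing a static member function
// template in a function-pointer field from inside another template.
#include <cstddef>

struct AggregateFunction {
	using window_t = void (*)(std::size_t frame_count);
	window_t window = nullptr;
};

struct QuantileOp {
	template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
	static void Window(std::size_t /*frame_count*/) {
		// the real callback would scan the partition cursor here
	}
};

template <class OP, class STATE, class INPUT_TYPE, class RESULT_TYPE>
AggregateFunction GetFunction() {
	AggregateFunction fun;
	// `template` keyword: OP is a dependent type here, so the parser needs the hint.
	fun.window = OP::template Window<STATE, INPUT_TYPE, RESULT_TYPE>;
	return fun;
}

int main() {
	auto fun = GetFunction<QuantileOp, int, double, double>();
	fun.window(0);
	return 0;
}
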