Skip to content

Commit

Permalink
Adds HasPropWithValue Pickler (#7692)
Browse files Browse the repository at this point in the history
* Adds HasPropWithValue Pickler

* Revert changes

* Resolve review comments

* more comprehensive testing

---------

Co-authored-by: Greg Landrum <greg.landrum@gmail.com>
  • Loading branch information
bp-kelley and greglandrum authored Aug 26, 2024
1 parent ce4a288 commit fa0463a
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 16 deletions.
51 changes: 51 additions & 0 deletions Code/GraphMol/MolPickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ QueryDetails getQueryDetails(const Query<int, T const *, true> *query) {
static_cast<const SetQuery<int, T const *, true> *>(query)->endSet());
return QueryDetails(
std::make_tuple(MolPickler::QUERY_SET, std::move(tset)));
} else if (auto q = dynamic_cast<const HasPropWithValueQueryBase*>(query)) {
return std::make_tuple(MolPickler::QUERY_PROPERTY_WITH_VALUE, q->getPair(), q->getTolerance());
} else {
throw MolPicklerException("do not know how to pickle part of the query.");
}
Expand Down Expand Up @@ -461,6 +463,14 @@ void pickleQuery(std::ostream &ss, const Query<int, T const *, true> *query) {
const auto &pval = std::get<1>(v);
streamWrite(ss, MolPickler::QUERY_VALUE, pval);
} break;
case 6: {
auto v =
boost::get<std::tuple<MolPickler::Tags, Dict::Pair, double>>(qdetails);
streamWrite(ss, std::get<0>(v));
// The tolerance is pickled first as we can't pickle the Dict::Pair with the QUERY_VALUE tag
streamWrite(ss, MolPickler::QUERY_VALUE, std::get<2>(v));
streamWriteProp(ss, std::get<1>(v), MolPickler::getCustomPropHandlers());
} break;
default:
throw MolPicklerException(
"do not know how to pickle part of the query.");
Expand Down Expand Up @@ -609,6 +619,47 @@ Query<int, T const *, true> *buildBaseQuery(std::istream &ss, T const *owner,
streamRead(ss, propName, version);
res = makeHasPropQuery<T>(propName);
} break;
case MolPickler::QUERY_PROPERTY_WITH_VALUE: {
streamRead(ss, tag, version);
if (tag != MolPickler::QUERY_VALUE) {
throw MolPicklerException(
"Bad pickle format: QUERY_VALUE tag not found.");
}
double tolerance{0.0};
streamRead(ss, tolerance, version);
Dict::Pair pair;
bool hasNonPod = false;
streamReadProp(ss, pair, hasNonPod, MolPickler::getCustomPropHandlers());
switch (pair.val.getTag()) {
case RDTypeTag::IntTag:
res = makePropQuery<T, int>(pair.key, rdvalue_cast<int>(pair.val), tolerance);
break;
case RDTypeTag::UnsignedIntTag:
res = makePropQuery<T, unsigned int>(pair.key, rdvalue_cast<unsigned int>(pair.val), tolerance);
break;
case RDTypeTag::BoolTag:
res = makePropQuery<T, bool>(pair.key, rdvalue_cast<bool>(pair.val), tolerance);
break;
case RDTypeTag::FloatTag:
res = makePropQuery<T, float>(pair.key, rdvalue_cast<float>(pair.val), tolerance);
break;
case RDTypeTag::DoubleTag:
res = makePropQuery<T, double>(pair.key, rdvalue_cast<double>(pair.val), tolerance);
break;
case RDTypeTag::StringTag:
res = makePropQuery<T, std::string>(pair.key, rdvalue_cast<std::string>(pair.val), tolerance);
break;
case RDTypeTag::AnyTag: {
if(rdvalue_is<ExplicitBitVect>(pair.val)) {
res = makePropQuery<T, ExplicitBitVect>(pair.key, rdvalue_cast<ExplicitBitVect>(pair.val), tolerance);
} else {
throw MolPicklerException("unknown query-type tag encountered");
}
} break;
}
// hasNonPod should be false for now...

} break;
default:
throw MolPicklerException("unknown query-type tag encountered");
}
Expand Down
4 changes: 3 additions & 1 deletion Code/GraphMol/MolPickler.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class RDKIT_GRAPHMOL_EXPORT MolPickler {
BEGINFASTFIND,
BEGINFINDOTHERORUNKNOWN,
QUERY_PROPERTY,
QUERY_PROPERTY_WITH_VALUE,
// add new entries above here
INVALID_TAG = 255
} Tags;
Expand Down Expand Up @@ -312,7 +313,8 @@ using QueryDetails = boost::variant<
std::tuple<MolPickler::Tags, int32_t, int32_t>,
std::tuple<MolPickler::Tags, int32_t, int32_t, int32_t, char>,
std::tuple<MolPickler::Tags, std::set<int32_t>>,
std::tuple<MolPickler::Tags, std::string>>;
std::tuple<MolPickler::Tags, std::string>,
std::tuple<MolPickler::Tags, Dict::Pair, double>>;
// clang-format on
template <class T>
QueryDetails getQueryDetails(const Queries::Query<int, T const *, true> *query);
Expand Down
4 changes: 2 additions & 2 deletions Code/GraphMol/QueryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ void finalizeQueryFromDescription(
// don't need to do anything here because the classes
// automatically have everything set
} else if (descr == "AtomAnd" || descr == "AtomOr" || descr == "AtomXor" ||
descr == "HasProp") {
descr == "HasProp" || descr == "HasPropWithValue") {
// don't need to do anything here because the classes
// automatically have everything set
} else {
Expand Down Expand Up @@ -1162,7 +1162,7 @@ void finalizeQueryFromDescription(
query->setDataFunc(nullDataFun);
query->setMatchFunc(nullQueryFun);
} else if (descr == "BondAnd" || descr == "BondOr" || descr == "BondXor" ||
descr == "HasProp") {
descr == "HasProp" || descr == "HasPropWithValue") {
// don't need to do anything here because the classes
// automatically have everything set
} else {
Expand Down
57 changes: 46 additions & 11 deletions Code/GraphMol/QueryOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
\brief Includes a bunch of functionality for handling Atom and Bond queries.
*/
#include <RDGeneral/export.h>
#include <RDGeneral/Dict.h>
#ifndef RD_QUERY_OPS_H
#define RD_QUERY_OPS_H

Expand Down Expand Up @@ -860,13 +861,21 @@ Queries::EqualityQuery<int, const Target *, true> *makeHasPropQuery(
return new HasPropQuery<const Target *>(property);
}


//! Type-independent interface to a query on a property having a given value
/*!
  The HasPropWithValueQuery class templates derive from this interface in
  addition to their Queries::EqualityQuery base. Code that only holds a
  generic query pointer (for example the pickler) can dynamic_cast to this
  class to read the property name/value pair and the tolerance without
  knowing the value type the query was instantiated with.
*/
class HasPropWithValueQueryBase {
 public:
  // This is a polymorphic base class: the virtual destructor keeps
  // destruction through a pointer to this interface well defined.
  virtual ~HasPropWithValueQueryBase() = default;

  //! returns the property name and the value being queried for
  virtual Dict::Pair getPair() const = 0;

  //! returns the tolerance reported by the query
  virtual double getTolerance() const = 0;
};


template <class TargetPtr, class T>
class HasPropWithValueQuery
: public Queries::EqualityQuery<int, TargetPtr, true> {
: public HasPropWithValueQueryBase, public Queries::EqualityQuery<int, TargetPtr, true> {
std::string propname;
T val;
T tolerance;
double tolerance{0.0};

public:
HasPropWithValueQuery()
Expand All @@ -875,6 +884,15 @@ class HasPropWithValueQuery
this->setDescription("HasPropWithValue");
this->setDataFunc(0);
}

Dict::Pair getPair() const override {
return Dict::Pair(propname, val);
}

double getTolerance() const override {
return tolerance;
}

explicit HasPropWithValueQuery(std::string prop, const T &v,
const T &tol = 0.0)
: Queries::EqualityQuery<int, TargetPtr, true>(),
Expand All @@ -891,7 +909,7 @@ class HasPropWithValueQuery
if (res) {
try {
T atom_val = what->template getProp<T>(propname);
res = Queries::queryCmp(atom_val, this->val, this->tolerance) == 0;
res = Queries::queryCmp(atom_val, this->val, static_cast<T>(this->tolerance)) == 0;
} catch (KeyErrorException &) {
res = false;
} catch (std::bad_any_cast &) {
Expand Down Expand Up @@ -928,9 +946,10 @@ class HasPropWithValueQuery
}
};


template <class TargetPtr>
class HasPropWithValueQuery<TargetPtr, std::string>
: public Queries::EqualityQuery<int, TargetPtr, true> {
: public HasPropWithValueQueryBase, public Queries::EqualityQuery<int, TargetPtr, true> {
std::string propname;
std::string val;

Expand All @@ -942,16 +961,23 @@ class HasPropWithValueQuery<TargetPtr, std::string>
this->setDataFunc(0);
}
explicit HasPropWithValueQuery(std::string prop, std::string v,
const std::string &tol = "")
const double /*tol*/ = 0.0)
: Queries::EqualityQuery<int, TargetPtr, true>(),
propname(std::move(prop)),
val(std::move(v)) {
RDUNUSED_PARAM(tol);
// tag the query with its description; no data function is needed:
this->setDescription("HasPropWithValue");
this->setDataFunc(nullptr);
}

Dict::Pair getPair() const override {
return Dict::Pair(propname, val);
}

double getTolerance() const override {
return 0.0;
}

bool Match(const TargetPtr what) const override {
bool res = what->hasProp(propname);
if (res) {
Expand Down Expand Up @@ -995,12 +1021,13 @@ class HasPropWithValueQuery<TargetPtr, std::string>
}
};


template <class TargetPtr>
class HasPropWithValueQuery<TargetPtr, ExplicitBitVect>
: public Queries::EqualityQuery<int, TargetPtr, true> {
: public HasPropWithValueQueryBase, public Queries::EqualityQuery<int, TargetPtr, true> {
std::string propname;
ExplicitBitVect val;
float tol{0.0};
double tol{0.0};

public:
HasPropWithValueQuery()
Expand All @@ -1010,7 +1037,7 @@ class HasPropWithValueQuery<TargetPtr, ExplicitBitVect>
}

explicit HasPropWithValueQuery(std::string prop, const ExplicitBitVect &v,
float tol = 0.0)
double tol = 0.0)
: Queries::EqualityQuery<int, TargetPtr, true>(),
propname(std::move(prop)),
val(v),
Expand All @@ -1019,6 +1046,14 @@ class HasPropWithValueQuery<TargetPtr, ExplicitBitVect>
this->setDataFunc(nullptr);
}

Dict::Pair getPair() const override {
return Dict::Pair(propname, val);
}

double getTolerance() const override {
return tol;
}

bool Match(const TargetPtr what) const override {
bool res = what->hasProp(propname);
if (res) {
Expand Down Expand Up @@ -1066,14 +1101,14 @@ class HasPropWithValueQuery<TargetPtr, ExplicitBitVect>

template <class Target, class T>
Queries::EqualityQuery<int, const Target *, true> *makePropQuery(
const std::string &propname, const T &val, const T &tolerance = T()) {
const std::string &propname, const T &val, double tolerance = 0.0) {
return new HasPropWithValueQuery<const Target *, T>(propname, val, tolerance);
}

template <class Target>
Queries::EqualityQuery<int, const Target *, true> *makePropQuery(
const std::string &propname, const ExplicitBitVect &val,
float tolerance = 0.0) {
double tolerance = 0.0) {
return new HasPropWithValueQuery<const Target *, ExplicitBitVect>(
propname, val, tolerance);
}
Expand Down
Loading

0 comments on commit fa0463a

Please sign in to comment.