from polyglotdb.query.base.func import Count
from polyglotdb.query.base.helper import key_for_cypher, value_for_cypher
from polyglotdb.query.base.results import BaseQueryResults
class BaseQuery(object):
    """
    Base class for building Cypher queries over a corpus graph and running them.

    Queries are assembled fluently (``filter``, ``columns``, ``group_by``,
    ``order_by``, ``limit``, ...) and rendered to Cypher by :meth:`cypher`.
    The terminal operations ``count``, ``aggregate``, ``create_subset``,
    ``remove_subset``, ``delete`` and ``set_properties`` switch the query into
    a specific return mode, execute it through ``corpus.execute_cypher`` and
    switch the mode back afterwards (even if execution fails), so the same
    query object can be reused.
    """

    # Skeleton of every generated statement; cypher() fills in the pieces.
    query_template = """{match}
{where}
{optional_match}
{with}
{return}"""
    # Final clauses, chosen by generate_return().
    delete_template = """DETACH DELETE {alias}"""
    aggregate_template = """RETURN {aggregates}{order_by}"""
    distinct_template = """RETURN {columns}{order_by}{offset}{limit}"""
    set_label_template = """{alias} {value}"""
    remove_label_template = """{alias}{value}"""
    set_property_template = """{alias}.{attribute} = {value}"""

    def __init__(self, corpus, to_find):
        """
        Parameters
        ----------
        corpus
            Object the query runs against; must provide
            ``execute_cypher(cypher, **params)`` and ``corpus_name``.
        to_find
            Node the query is about; its ``alias`` and ``withs`` are used in
            the generated RETURN / SET / DELETE clauses.
        """
        self.corpus = corpus
        self.to_find = to_find
        self._criterion = []  # filter clauses rendered into WHERE
        self._columns = []  # attributes returned as result columns
        self._hidden_columns = []  # extra returned attributes not managed by columns()
        self._order_by = []  # (attribute, descending) pairs
        self._group_by = []  # grouping attributes for aggregates
        self._aggregate = []  # aggregate functions rendered into RETURN
        self._preload = []  # nodes whose withs are returned with to_find when no columns are set
        self._cache = []  # attributes written back with SET by _generate_cache_return
        self._delete = False  # render DETACH DELETE instead of RETURN
        self._set_labels = []  # labels to add with SET
        self._remove_labels = []  # labels to drop with REMOVE
        self._set_properties = {}  # property name -> value to SET on to_find
        self._limit = None
        self._offset = None
        self.call_back = None  # callback hook; not used by this class itself
        self.stop_check = None  # optional callable; a truthy result makes to_csv bail out

    def cache(self):
        """Cache computed attributes on the graph; subclasses must implement this."""
        raise NotImplementedError

    def _attribute_sources(self):
        """
        Return every non-filter attribute the query references: columns,
        hidden columns, aggregates, preloads, cached attributes and order-by
        fields, in that order.
        """
        sources = (
            self._columns + self._hidden_columns + self._aggregate + self._preload + self._cache
        )
        sources.extend(field for field, _ in self._order_by)
        return sources

    def required_nodes(self):
        """
        Return the set of nodes that must appear in the (non-optional) MATCH.

        The set always contains ``to_find``. It also contains:

        * every node referenced by a filter;
        * every ``non_optional`` node referenced by a column, hidden column,
          aggregate, preload, cache or order-by field.

        Nodes of exactly ``to_find``'s type are skipped. As a side effect,
        every HierarchicalAnnotation node encountered is re-anchored on
        ``to_find``.
        """
        from polyglotdb.query.annotations.attributes.hierarchical import HierarchicalAnnotation

        tf_type = type(self.to_find)

        def reanchor(clause):
            for node in clause.nodes:
                if isinstance(node, HierarchicalAnnotation):
                    node.reset_anchor_node(self.to_find)

        ns = {self.to_find}
        for clause in self._criterion:
            reanchor(clause)
            # Filters force all of their nodes into the MATCH ...
            ns.update(x for x in clause.nodes if type(x) is not tf_type)
        for attribute in self._attribute_sources():
            reanchor(attribute)
            # ... selected / ordering attributes only force non-optional ones.
            ns.update(x for x in attribute.nodes if type(x) is not tf_type and x.non_optional)
        return ns

    def optional_nodes(self):
        """
        Return the nodes to OPTIONAL MATCH, as a sorted list.

        These are the nodes referenced by columns, hidden columns, aggregates,
        preloads, cached attributes or order-by fields that are neither of
        ``to_find``'s exact type nor already required.
        """
        required = self.required_nodes()
        tf_type = type(self.to_find)
        # Hidden columns are included, mirroring required_nodes(); otherwise an
        # optional node used only by a hidden column would be returned in the
        # RETURN clause without ever being matched.
        ns = {
            x
            for attribute in self._attribute_sources()
            for x in attribute.nodes
            if type(x) is not tf_type and x not in required
        }
        return sorted(ns)

    def clear_columns(self):
        """
        Remove all columns specified so far, including any default columns
        a subclass may have installed.
        """
        self._columns = []
        return self

    def offset(self, number):
        """Skip the first ``number`` results (rendered as ``SKIP``)."""
        self._offset = number
        return self

    def filter(self, *args):
        """
        Apply one or more filters to a query.

        An equality filter on the same node and attribute label as an
        existing equality filter replaces that filter's value instead of
        adding a second clause.
        """
        from .elements import EqualClauseElement

        for a in args:
            for c in self._criterion:
                if (
                    isinstance(c, EqualClauseElement)
                    and isinstance(a, EqualClauseElement)
                    and c.attribute.node == a.attribute.node
                    and c.attribute.label == a.attribute.label
                ):
                    c.value = a.value
                    break
            else:
                self._criterion.append(a)
        return self

    def columns(self, *args):
        """
        Add one or more additional columns to the results.

        Columns should be :class:`~polyglotdb.query.base.Attribute` objects.
        Columns already present before this call are not added again.
        """
        column_set = set(self._columns)
        for c in args:
            # NOTE: column_set is deliberately not updated inside the loop.
            # Doing so was previously found to break existing tests (FIXME in
            # the original code), so duplicates within one call are kept.
            if c not in column_set:
                self._columns.append(c)
        return self

    def group_by(self, *args):
        """
        Specify one or more fields for how aggregates should be grouped.
        """
        self._group_by.extend(args)
        return self

    def order_by(self, field, descending=False):
        """
        Specify how the results of the query should be ordered.

        Parameters
        ----------
        field : Attribute
            Determines what the ordering should be based on
        descending : bool, defaults to False
            Whether the order should be descending
        """
        self._order_by.append((field, descending))
        return self

    def to_csv(self, path):
        """
        Same as ``all``, but the results of the query are output to the
        specified path as a CSV file.

        Nothing is written if ``stop_check`` is set and returns a truthy
        value once the results object has been created.
        """
        results = self.all()
        if self.stop_check is not None and self.stop_check():
            return
        results.to_csv(path)

    def count(self):
        """
        Return the number of rows in the query.

        The query's aggregates are temporarily replaced by a single
        ``Count()`` and restored afterwards, even if execution fails.

        NOTE(review): group-by fields are rendered before the aggregates, so
        with ``group_by`` set the first value of the first record is a
        grouping value, not the count.
        """
        previous = self._aggregate
        self._aggregate = [Count()]
        try:
            rows = list(self._execute())
        finally:
            self._aggregate = previous
        return next(iter(rows[0].values()))

    def aggregate(self, *args):
        """
        Aggregate the results of the query by a grouping factor or overall.

        Not specifying a ``group_by`` in the query will result in a single
        result for the aggregate from the whole query.

        The given aggregates are added to any already on the query for this
        execution only; the query's aggregate list is restored afterwards.

        Returns
        -------
        list
            All records, if there is a ``group_by`` or any non-collapsing
            aggregate.
        record
            The first record, if there are several collapsing aggregates.
        object
            The single value, if there is exactly one collapsing aggregate.
        """
        previous = self._aggregate
        self._aggregate = previous + list(args)
        try:
            rows = list(self._execute())
            grouped = bool(self._group_by) or any(not x.collapsing for x in self._aggregate)
            several = len(self._aggregate) > 1
        finally:
            self._aggregate = previous
        if grouped:
            return rows
        if several:
            return rows[0]
        return next(iter(rows[0].values()))

    def preload(self, *args):
        """Return the given nodes together with ``to_find`` when no columns are selected."""
        self._preload.extend(args)
        return self

    def limit(self, limit):
        """Return at most ``limit`` results (rendered as ``LIMIT``)."""
        self._limit = limit
        return self

    def to_json(self):
        """Return a JSON-serialisable dict of the corpus name, filters and columns."""
        return {
            "corpus_name": self.corpus.corpus_name,
            "filters": [x.for_json() for x in self._criterion],
            "columns": [x.for_json() for x in self._columns],
        }

    def cypher(self):
        """
        Generate a Cypher statement based on the query.

        Required nodes become MATCH patterns and optional nodes become
        OPTIONAL MATCH patterns. Nodes with subqueries are rendered after the
        initial WITH instead. The final clause comes from
        :meth:`generate_return`.
        """
        kwargs = {
            "match": "",
            "optional_match": "",
            "where": "",
            "with": "",
            "return": "",
        }
        withs = set()

        # Required nodes without subqueries become comma-separated MATCH patterns.
        nodes = self.required_nodes()
        match_strings = set()
        for node in nodes:
            if node.has_subquery:
                continue
            match_strings.add(node.for_match())
            withs.update(node.withs)
        # Sorted so the statement text is reproducible: iteration order of a
        # set of str depends on hash randomization (PYTHONHASHSEED).
        kwargs["match"] = "MATCH " + ",\n".join(sorted(match_strings))

        # Filters not already handled inside a subquery.
        properties = [c.for_cypher() for c in self._criterion if not c.in_subquery]
        if properties:
            kwargs["where"] = "WHERE " + "\nAND ".join(properties)

        optional_nodes = self.optional_nodes()
        optional_matches = []
        for node in optional_nodes:
            if node.has_subquery:
                continue
            optional_matches.append("OPTIONAL MATCH " + node.for_match() + "\n")
            withs.update(node.withs)
        kwargs["optional_match"] = "".join(optional_matches)

        # Subqueries; each one sees the variables carried by the previous ones.
        with_statements = ["WITH " + ", ".join(sorted(withs))]
        for node in nodes:
            if node.has_subquery:
                with_statements.append(node.subquery(withs, self._criterion))
                withs.update(node.withs)
        for node in optional_nodes:
            if node.has_subquery:
                with_statements.append(node.subquery(withs, self._criterion, optional=True))
                withs.update(node.withs)
        kwargs["with"] = "\n".join(with_statements)

        kwargs["return"] = self.generate_return()
        return self.query_template.format(**kwargs)

    def _execute(self):
        """Render the query in its current mode and run it on the corpus."""
        return self.corpus.execute_cypher(self.cypher(), **self.cypher_params())

    def create_subset(self, label):
        """Add ``label`` to every node matched by the query."""
        self._set_labels.append(label)
        try:
            self._execute()
        finally:
            # Never leave the query in "SET label" mode, even on failure.
            self._set_labels = []

    def remove_subset(self, label):
        """Remove ``label`` from every node matched by the query."""
        self._remove_labels.append(label)
        try:
            self._execute()
        finally:
            self._remove_labels = []

    def delete(self):
        """
        Remove the results of a query from the graph. CAUTION: this is
        irreversible.
        """
        self._delete = True
        try:
            self._execute()
        finally:
            # Without this reset every later use of the query would delete again.
            self._delete = False

    def set_properties(self, **kwargs):
        """Set the given properties on every matched ``to_find`` node; ``None`` becomes NULL."""
        self._set_properties = dict(kwargs)
        try:
            self._execute()
        finally:
            self._set_properties = {}

    def all(self):
        """Return a :class:`BaseQueryResults` for this query."""
        return BaseQueryResults(self)

    def get(self):
        """
        Return the single result of the query.

        Raises
        ------
        Exception
            If the query has more than one result. (An empty result is still
            indexed, so what happens then depends on BaseQueryResults.)
        """
        r = BaseQueryResults(self)
        if len(r) > 1:
            raise Exception("Can't use get on query with more than one result.")
        return r[0]

    def cypher_params(self):
        """
        Collect the Cypher parameters referenced by the query's filters.

        Returns
        -------
        dict
            Parameter name -> value.
        """
        from ..base.attributes import NodeAttribute
        from ..base.complex import ComplexClause
        from ..base.elements import NotSubsetClauseElement, SubsetClauseElement

        params = {}
        for c in self._criterion:
            if isinstance(c, ComplexClause):
                params.update(c.generate_params())
            elif isinstance(c, (SubsetClauseElement, NotSubsetClauseElement)):
                # Subset clauses contribute no parameters.
                continue
            else:
                # Best effort: clauses lacking a value (or cypher_value_string)
                # contribute nothing. Comparisons against another node
                # attribute are not parameters.
                try:
                    if not isinstance(c.value, NodeAttribute):
                        # Assumes cypher_value_string() wraps the (possibly
                        # backticked) name in one delimiter char per side.
                        params[c.cypher_value_string()[1:-1].replace("`", "")] = c.value
                except AttributeError:
                    pass
        return params

    def generate_return(self):
        """
        Generate the final clause of the statement for the query's current mode.

        Modes are checked in priority order: delete, cache, set properties,
        set labels, remove labels, aggregate. The fallback is a plain RETURN
        of the columns.

        Returns
        -------
        str
            Cypher formatted string
        """
        if self._delete:
            statement = self._generate_delete_return()
        elif self._cache:
            statement = self._generate_cache_return()
        elif self._set_properties:
            statement = self._generate_set_properties_return()
        elif self._set_labels:
            statement = self._generate_set_labels_return()
        elif self._remove_labels:
            statement = self._generate_remove_labels_return()
        elif self._aggregate:
            statement = self._generate_aggregate_return()
        else:
            statement = self._generate_distinct_return()
        return statement

    def _generate_delete_return(self):
        """Render ``DETACH DELETE`` for ``to_find``."""
        return self.delete_template.format(alias=self.to_find.alias)

    def _generate_cache_return(self):
        """Render ``SET <alias>.<output_alias> = <value>`` for every cached attribute."""
        assignments = []
        for c in self._cache:
            # "position" is written on to_find itself rather than on the
            # attribute's own node.
            alias = self.to_find.alias if c.label == "position" else c.node.cache_alias
            assignments.append(
                self.set_property_template.format(
                    alias=alias, attribute=c.output_alias, value=c.for_cypher()
                )
            )
        return "SET {}".format(", ".join(assignments))

    def _generate_remove_labels_return(self):
        """Render ``REMOVE alias:label1:label2...`` (with a leading newline)."""
        value = ":" + ":".join(map(key_for_cypher, self._remove_labels))
        return "\nREMOVE " + self.remove_label_template.format(
            alias=self.to_find.alias, value=value
        )

    def _generate_set_properties_return(self):
        """Render ``SET alias.key = value, ...``, writing ``None`` as NULL."""
        set_strings = []
        for k, v in self._set_properties.items():
            value = "NULL" if v is None else value_for_cypher(v)
            set_strings.append(
                self.set_property_template.format(alias=self.to_find.alias, attribute=k, value=value)
            )
        return "SET " + ", ".join(set_strings)

    def _generate_set_labels_return(self):
        """Render ``SET alias :label1:label2...``."""
        value = ":" + ":".join(map(key_for_cypher, self._set_labels))
        return "SET " + self.set_label_template.format(alias=self.to_find.alias, value=value)

    def _generate_aggregate_return(self):
        """
        Render the RETURN clause for an aggregate query.

        Group-by fields come first. Columns follow only if some aggregate is
        non-collapsing, and the aggregates come last. Without an explicit
        ordering, results are ordered by the first group-by field.
        """
        order_by = self._order_by
        if not order_by and self._group_by:
            # Computed locally and before rendering. The previous code
            # appended to self._order_by after ORDER BY had been rendered, so
            # the first statement lacked it and later queries inherited it.
            order_by = [(self._group_by[0], False)]
        kwargs = {
            "order_by": self._generate_order_by(order_by),
            # Unused by the default template; kept for templates that use {limit}.
            "limit": self._generate_limit(),
        }
        properties = [g.aliased_for_output() for g in self._group_by]
        if any(not x.collapsing for x in self._aggregate):
            properties.extend(c.aliased_for_output() for c in self._columns)
        properties.extend(a.aliased_for_output() for a in self._aggregate)
        kwargs["aggregates"] = ", ".join(properties)
        return self.aggregate_template.format(**kwargs)

    def _generate_distinct_return(self):
        """
        Render the plain RETURN clause.

        It returns the columns and hidden columns, or, when there are none,
        the withs of ``to_find`` followed by those of every preloaded node.
        """
        kwargs = {
            "order_by": self._generate_order_by(),
            "limit": self._generate_limit(),
            "offset": self._generate_offset(),
        }
        properties = [c.aliased_for_output() for c in self._columns + self._hidden_columns]
        if not properties:
            # Copy so extending it can never mutate to_find.withs in place
            # (in case withs is a stored list rather than a fresh one).
            properties = list(self.to_find.withs)
            for a in self._preload:
                properties.extend(a.withs)
        kwargs["columns"] = ", ".join(properties)
        return self.distinct_template.format(**kwargs)

    def _generate_limit(self):
        """Render ``LIMIT`` (with a leading newline), or ``""`` if no limit is set."""
        if self._limit is not None:
            return "\nLIMIT {}".format(self._limit)
        return ""

    def _generate_offset(self):
        """Render ``SKIP`` (with a leading newline), or ``""`` if no offset is set."""
        if self._offset is not None:
            return "\nSKIP {}".format(self._offset)
        return ""

    def _generate_order_by(self, order_by=None):
        """
        Render the ORDER BY clause (with a leading newline), or ``""``.

        Parameters
        ----------
        order_by : list of (attribute, bool), optional
            Ordering to render; defaults to the query's own order-by list.

        An ordering field whose hash equals that of a selected column (or,
        failing that, a group-by field) is rendered through that column's
        ``for_cypher()``; otherwise the field renders itself.
        """
        if order_by is None:
            order_by = self._order_by
        if not order_by:
            return ""
        # Built once, searched in the same order as before: columns, then group-by fields.
        candidates = list(set(self._columns)) + list(set(self._group_by))
        properties = []
        for field, descending in order_by:
            target = hash(field)
            source = next((col for col in candidates if hash(col) == target), field)
            element = source.for_cypher()
            if descending:
                element += " DESC"
            properties.append(element)
        return "\nORDER BY " + ", ".join(properties)