"""Source code for the ``polyglotdb.query.base.query`` module (``BaseQuery``)."""

from .results import BaseQueryResults

from .func import Count
from ..base.helper import key_for_cypher, value_for_cypher


class BaseQuery(object):
    """Base class for building Cypher statements and running them on a corpus.

    Builder methods (``filter``, ``columns``, ``order_by``, ``limit`` ...)
    record state on the instance and return ``self`` so calls can be chained.
    :meth:`cypher` turns that state into a Cypher statement, and the execution
    methods (``count``, ``aggregate``, ``delete``, ``set_properties`` ...) pass
    it to ``corpus.execute_cypher`` together with :meth:`cypher_params`.
    """

    # Templates filled with str.format() when assembling the statement.
    query_template = '''{match} {where} {optional_match} {with} {return}'''
    delete_template = '''DETACH DELETE {alias}'''
    aggregate_template = '''RETURN {aggregates}{order_by}'''
    distinct_template = '''RETURN {columns}{order_by}{offset}{limit}'''
    set_label_template = '''{alias} {value}'''
    remove_label_template = '''{alias}{value}'''
    set_property_template = '''{alias}.{attribute} = {value}'''

    def __init__(self, corpus, to_find):
        """
        Parameters
        ----------
        corpus
            Object providing ``execute_cypher(cypher, **params)``; its
            ``corpus_name`` is also read by :meth:`to_json`.
        to_find
            The node the query is about. Its ``alias`` is the target of the
            DELETE / SET / REMOVE statements generated by this class.
        """
        self.corpus = corpus
        self.to_find = to_find
        self._criterion = []
        self._columns = []
        self._hidden_columns = []
        self._order_by = []  # list of (attribute, descending) pairs
        self._group_by = []
        self._aggregate = []
        self._preload = []
        self._cache = []
        self._delete = False
        self._set_labels = []
        self._remove_labels = []
        self._set_properties = {}
        self._limit = None
        self._offset = None
        self.call_back = None
        self.stop_check = None

    def cache(self):
        """Not implemented on the base class; always raises NotImplementedError."""
        raise NotImplementedError

    def required_nodes(self):
        """Return the set of nodes that must be matched by the main MATCH clause.

        This is ``to_find`` plus every node referenced by a filter, and every
        ``non_optional`` node referenced by columns, hidden columns,
        aggregates, preloads, cache entries or orderings. Nodes whose type is
        exactly the type of ``to_find`` are skipped (``to_find`` itself is
        always included).
        """
        ns = {self.to_find}
        tf_type = type(self.to_find)
        for c in self._criterion:
            ns.update(x for x in c.nodes if type(x) is not tf_type)
        for c in self._columns + self._hidden_columns + self._aggregate + self._preload + self._cache:
            ns.update(x for x in c.nodes if type(x) is not tf_type and x.non_optional)
        for c, _ in self._order_by:
            ns.update(x for x in c.nodes if type(x) is not tf_type and x.non_optional)
        return ns

    def optional_nodes(self):
        """Return, sorted, the referenced nodes that are not required.

        These are nodes used by columns, aggregates, preloads, cache entries
        or orderings that do not appear in :meth:`required_nodes`; they are
        emitted as OPTIONAL MATCH clauses by :meth:`cypher`.
        """
        required_nodes = self.required_nodes()
        ns = set()
        tf_type = type(self.to_find)
        for c in self._columns + self._aggregate + self._preload + self._cache:
            ns.update(x for x in c.nodes if type(x) is not tf_type and x not in required_nodes)
        for c, _ in self._order_by:
            ns.update(x for x in c.nodes if type(x) is not tf_type and x not in required_nodes)
        return sorted(ns)

    def clear_columns(self):
        """Remove all columns added with :meth:`columns`.

        With no columns, the RETURN clause falls back to ``to_find.withs``
        (see ``_generate_distinct_return``).
        """
        self._columns = []
        return self

    def offset(self, number):
        """Skip the first ``number`` rows of the result (Cypher ``SKIP``)."""
        self._offset = number
        return self

    def filter(self, *args):
        """Apply one or more filters to the query.

        If a new filter is an ``EqualClauseElement`` on the same node and
        label as an existing ``EqualClauseElement``, the existing filter's
        value is replaced instead of adding a second filter.
        """
        from .elements import EqualClauseElement
        for a in args:
            for c in self._criterion:
                if isinstance(c, EqualClauseElement) and isinstance(a, EqualClauseElement) and \
                        c.attribute.node == a.attribute.node and c.attribute.label == a.attribute.label:
                    c.value = a.value
                    break
            else:
                self._criterion.append(a)
        return self

    def columns(self, *args):
        """Add one or more columns to the results.

        Columns should be :class:`~polyglotdb.query.base.Attribute` objects.
        Columns already present before this call are not added again.
        """
        column_set = set(self._columns)
        for c in args:
            if c in column_set:
                continue
            self._columns.append(c)
            # NOTE: column_set is deliberately not updated, so duplicates
            # within a single call are kept (updating it was marked
            # "FIXME failing tests").
        return self

    def group_by(self, *args):
        """Specify one or more fields by which aggregates are grouped."""
        self._group_by.extend(args)
        return self

    def order_by(self, field, descending=False):
        """Specify how the results of the query should be ordered.

        Parameters
        ----------
        field : Attribute
            Determines what the ordering should be based on
        descending : bool, defaults to False
            Whether the order should be descending
        """
        self._order_by.append((field, descending))
        return self

    def to_csv(self, path):
        """Same as ``all``, but the results are written to ``path`` as CSV.

        Nothing is written if ``stop_check`` is set and returns True after
        the query has run.
        """
        results = self.all()
        if self.stop_check is not None and self.stop_check():
            return
        results.to_csv(path)

    def count(self):
        """Return the number of rows in the query.

        The aggregates are temporarily replaced with a single ``Count()`` and
        cleared again afterwards, even if execution raises.
        """
        self._aggregate = [Count()]
        try:
            value = self.corpus.execute_cypher(self.cypher(), **self.cypher_params())
        finally:
            self._aggregate = []
        return list(value[0].values())[0]

    def aggregate(self, *args):
        """Aggregate the results of the query by a grouping factor or overall.

        Not specifying a ``group_by`` in the query will result in a single
        result for the aggregate from the whole query.

        Returns all rows when grouping (or when any aggregate is not
        collapsing), the single row when several aggregates were given, and
        otherwise the single aggregated value. The aggregates stay recorded
        on the query after the call.
        """
        self._aggregate.extend(args)
        cypher = self.cypher()
        value = self.corpus.execute_cypher(cypher, **self.cypher_params())
        if self._group_by or any(not x.collapsing for x in self._aggregate):
            return list(value)
        elif len(self._aggregate) > 1:
            return list(value)[0]
        else:
            return list(list(value)[0].values())[0]

    def preload(self, *args):
        """Record attributes whose ``withs`` are returned alongside ``to_find``
        when no columns are selected."""
        self._preload.extend(args)
        return self

    def limit(self, limit):
        """Limit the number of returned rows (Cypher ``LIMIT``)."""
        self._limit = limit
        return self

    def to_json(self):
        """Return a JSON-serialisable dict with the corpus name, filters and columns."""
        data = {'corpus_name': self.corpus.corpus_name,
                'filters': [x.for_json() for x in self._criterion],
                'columns': [x.for_json() for x in self._columns]}
        return data

    def cypher(self):
        """Generate the Cypher statement for the query.

        The statement is made of a MATCH over the required nodes, a WHERE
        built from the filters, OPTIONAL MATCH clauses for optional nodes,
        WITH/subquery statements for nodes that need a subquery, and the
        final clause from :meth:`generate_return`.
        """
        kwargs = {'match': '', 'optional_match': '', 'where': '', 'with': '', 'return': ''}

        # Main MATCH over required nodes that are not resolved by a subquery.
        match_strings = set()
        withs = set()
        nodes = self.required_nodes()
        for node in nodes:
            if node.has_subquery:
                continue
            match_strings.add(node.for_match())
            withs.update(node.withs)
        # Sorted so the generated text does not depend on the hash seed.
        kwargs['match'] = 'MATCH ' + ',\n'.join(sorted(match_strings))

        # Main filters (filters handled inside subqueries are skipped here).
        properties = [c.for_cypher() for c in self._criterion if not c.in_subquery]
        if properties:
            kwargs['where'] = 'WHERE ' + '\nAND '.join(properties)

        optional_nodes = self.optional_nodes()
        optional_match_strings = []
        for node in optional_nodes:
            if node.has_subquery:
                continue
            optional_match_strings.append(node.for_match())
            withs.update(node.withs)
        kwargs['optional_match'] = ''.join('OPTIONAL MATCH ' + o + '\n' for o in optional_match_strings)

        # Subqueries; each one may add further names to carry forward.
        with_statements = ['WITH ' + ', '.join(sorted(withs))]
        for node in nodes:
            if not node.has_subquery:
                continue
            with_statements.append(node.subquery(withs, self._criterion))
            withs.update(node.withs)
        for node in optional_nodes:
            if not node.has_subquery:
                continue
            with_statements.append(node.subquery(withs, self._criterion, optional=True))
            withs.update(node.withs)
        kwargs['with'] = '\n'.join(with_statements)

        kwargs['return'] = self.generate_return()
        return self.query_template.format(**kwargs)

    def create_subset(self, label):
        """Add ``label`` as a node label on the ``to_find`` nodes matched by the query."""
        self._set_labels.append(label)
        try:
            self.corpus.execute_cypher(self.cypher(), **self.cypher_params())
        finally:
            self._set_labels = []

    def remove_subset(self, label):
        """Remove the node label ``label`` from the ``to_find`` nodes matched by the query."""
        self._remove_labels.append(label)
        try:
            self.corpus.execute_cypher(self.cypher(), **self.cypher_params())
        finally:
            self._remove_labels = []

    def delete(self):
        """Remove the results of a query from the graph.

        CAUTION: this is irreversible.
        """
        self._delete = True
        try:
            self.corpus.execute_cypher(self.cypher(), **self.cypher_params())
        finally:
            # generate_return() checks this flag first; leaving it set would
            # turn every later cypher() call on this query into another
            # DETACH DELETE.
            self._delete = False

    def set_properties(self, **kwargs):
        """Set properties on the ``to_find`` nodes matched by the query.

        A value of None is written as Cypher ``NULL``.
        """
        self._set_properties = dict(kwargs)
        try:
            self.corpus.execute_cypher(self.cypher(), **self.cypher_params())
        finally:
            self._set_properties = {}

    def all(self):
        """Return a ``BaseQueryResults`` for this query."""
        return BaseQueryResults(self)

    def get(self):
        """Return the single result of the query.

        Raises
        ------
        Exception
            If the query has more than one result.
        """
        r = BaseQueryResults(self)
        if len(r) > 1:
            raise Exception("Can't use get on query with more than one result.")
        return r[0]

    def cypher_params(self):
        """Return the parameter dict passed to ``execute_cypher`` alongside the statement.

        Complex clauses contribute their own parameters; subset clauses and
        clauses whose value is a ``NodeAttribute`` contribute none. Clauses
        without the expected attributes are skipped (AttributeError).
        """
        from ..base.complex import ComplexClause
        from ..base.elements import SubsetClauseElement, NotSubsetClauseElement
        from ..base.attributes import NodeAttribute
        params = {}
        for c in self._criterion:
            if isinstance(c, ComplexClause):
                params.update(c.generate_params())
            elif isinstance(c, (SubsetClauseElement, NotSubsetClauseElement)):
                pass
            else:
                try:
                    if not isinstance(c.value, NodeAttribute):
                        # Strip the surrounding delimiters and backticks from
                        # the placeholder to get the parameter name.
                        params[c.cypher_value_string()[1:-1].replace('`', '')] = c.value
                except AttributeError:
                    pass
        return params

    def generate_return(self):
        """Generate the final clause of the statement.

        Exactly one kind of clause is produced, chosen in this priority
        order: delete, cache, set properties, set labels, remove labels,
        aggregate, and otherwise a plain RETURN of columns.

        Returns
        -------
        str
            cypher formatted string
        """
        if self._delete:
            statement = self._generate_delete_return()
        elif self._cache:
            statement = self._generate_cache_return()
        elif self._set_properties:
            statement = self._generate_set_properties_return()
        elif self._set_labels:
            statement = self._generate_set_labels_return()
        elif self._remove_labels:
            statement = self._generate_remove_labels_return()
        elif self._aggregate:
            statement = self._generate_aggregate_return()
        else:
            statement = self._generate_distinct_return()
        return statement

    def _generate_delete_return(self):
        """Return ``DETACH DELETE`` for the ``to_find`` alias."""
        return self.delete_template.format(alias=self.to_find.alias)

    def _generate_cache_return(self):
        """Return a SET clause storing each cache entry as a property.

        Entries labelled ``position`` are stored on ``to_find``; all others
        are stored on their node's ``cache_alias``.
        """
        properties = []
        for c in self._cache:
            kwargs = {'alias': c.node.cache_alias,
                      'attribute': c.output_alias,
                      'value': c.for_cypher()
                      }
            if c.label == 'position':
                kwargs['alias'] = self.to_find.alias
            properties.append(self.set_property_template.format(**kwargs))
        return 'SET {}'.format(', '.join(properties))

    def _generate_remove_labels_return(self):
        """Return a REMOVE clause dropping the queued labels from ``to_find``."""
        value = ':' + ':'.join(map(key_for_cypher, self._remove_labels))
        remove_string = self.remove_label_template.format(alias=self.to_find.alias, value=value)
        return '\nREMOVE ' + remove_string

    def _generate_set_properties_return(self):
        """Return a SET clause for the queued properties (None becomes NULL)."""
        set_strings = []
        for k, v in self._set_properties.items():
            if v is None:
                v = 'NULL'
            else:
                v = value_for_cypher(v)
            set_strings.append(self.set_property_template.format(alias=self.to_find.alias, attribute=k, value=v))
        return 'SET ' + ', '.join(set_strings)

    def _generate_set_labels_return(self):
        """Return a SET clause adding the queued labels to ``to_find``."""
        value = ':' + ':'.join(map(key_for_cypher, self._set_labels))
        return 'SET ' + self.set_label_template.format(alias=self.to_find.alias, value=value)

    def _generate_aggregate_return(self):
        """Return the RETURN clause for aggregate queries.

        Group-by fields come first, then the selected columns (only when some
        aggregate is not collapsing), then the aggregates themselves.
        """
        kwargs = {'order_by': self._generate_order_by(),
                  'limit': self._generate_limit()}
        properties = []
        for g in self._group_by:
            properties.append(g.aliased_for_output())
        if any(not x.collapsing for x in self._aggregate):
            for c in self._columns:
                properties.append(c.aliased_for_output())
        # NOTE(review): ORDER BY was already rendered above, so this default
        # ordering only takes effect on the *next* cypher() call. It also
        # mutates the query state. Kept as-is because changing it would alter
        # which nodes required_nodes()/optional_nodes() match.
        if len(self._order_by) == 0 and len(self._group_by) > 0:
            self._order_by.append((self._group_by[0], False))
        for a in self._aggregate:
            properties.append(a.aliased_for_output())
        kwargs['aggregates'] = ', '.join(properties)
        return self.aggregate_template.format(**kwargs)

    def _generate_distinct_return(self):
        """Return the RETURN clause for non-aggregate queries.

        Selected (and hidden) columns are returned; if there are none,
        ``to_find.withs`` plus the ``withs`` of every preloaded attribute are
        returned instead.
        """
        kwargs = {'order_by': self._generate_order_by(),
                  'limit': self._generate_limit(),
                  'offset': self._generate_offset()}
        properties = [c.aliased_for_output() for c in self._columns + self._hidden_columns]
        if not properties:
            # Copy, so extending with preload names never mutates the node's own list.
            properties = list(self.to_find.withs)
            for a in self._preload:
                properties.extend(a.withs)
        kwargs['columns'] = ', '.join(properties)
        return self.distinct_template.format(**kwargs)

    def _generate_limit(self):
        """Return the LIMIT fragment, or '' when no limit is set."""
        if self._limit is not None:
            return '\nLIMIT {}'.format(self._limit)
        return ''

    def _generate_offset(self):
        """Return the SKIP fragment, or '' when no offset is set."""
        if self._offset is not None:
            return '\nSKIP {}'.format(self._offset)
        return ''

    def _generate_order_by(self):
        """Return the ORDER BY fragment, or '' when no ordering is set.

        An ordering field is matched against the selected columns, then the
        group-by fields, by comparing hashes. A matching column or group-by
        entry is rendered in place of the field itself.
        """
        column_set = set(self._columns)
        group_by_set = set(self._group_by)
        properties = []
        for field, descending in self._order_by:
            field_hash = hash(field)
            source = next((col for col in column_set if hash(col) == field_hash), None)
            if source is None:
                source = next((col for col in group_by_set if hash(col) == field_hash), field)
            element = source.for_cypher()
            if descending:
                element += ' DESC'
            properties.append(element)
        if properties:
            return '\nORDER BY ' + ', '.join(properties)
        return ''