 from __future__ import annotations
 
 import typing
-from typing import List, Tuple, Union
+from typing import Tuple, Union
 
 import ibis
 import pandas as pd
@@ -147,19 +147,22 @@ def __getitem__(
         ...
 
     def __getitem__(self, key):
-        # TODO(swast): If the DataFrame has a MultiIndex, we'll need to
-        # disambiguate this from a single row selection.
+        # TODO(tbergeron): Pandas will try a 2-tuple both as (row key, column key) and
+        # as a 2-part row key. We must choose one, so bias towards a multi-part row label.
         if isinstance(key, tuple) and len(key) == 2:
-            df = typing.cast(
-                bigframes.dataframe.DataFrame,
-                _loc_getitem_series_or_dataframe(self._dataframe, key[0]),
-            )
+            is_row_multi_index = self._dataframe.index.nlevels > 1
+            is_first_item_tuple = isinstance(key[0], tuple)
+            if not is_row_multi_index or is_first_item_tuple:
+                df = typing.cast(
+                    bigframes.dataframe.DataFrame,
+                    _loc_getitem_series_or_dataframe(self._dataframe, key[0]),
+                )
 
-            columns = key[1]
-            if isinstance(columns, pd.Series) and columns.dtype == "bool":
-                columns = df.columns[columns]
+                columns = key[1]
+                if isinstance(columns, pd.Series) and columns.dtype == "bool":
+                    columns = df.columns[columns]
 
-            return df[columns]
+                return df[columns]
 
         return typing.cast(
             bigframes.dataframe.DataFrame,
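
To make the new 2-tuple handling concrete, here is a rough sketch of the intended dispatch; `df_single_index` and `df_multi_index` are hypothetical BigQuery DataFrames invented for illustration, not objects from this change.

# Single-level index: the 2-tuple is split into (row key, column key).
value_frame = df_single_index.loc[("a", "col1")]

# MultiIndex with a non-tuple first element: the whole 2-tuple is biased towards
# a 2-part row label, so it selects the row ("a", 1) rather than a column.
row = df_multi_index.loc[("a", 1)]

# To also pick a column on a MultiIndex frame, nest the row part of the key.
cell = df_multi_index.loc[(("a", 1), "col1")]
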
@@ -283,94 +286,40 @@ def _loc_getitem_series_or_dataframe(
     pd.Series,
     bigframes.core.scalar.Scalar,
 ]:
-    if isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
-        return series_or_dataframe[key]
-    elif isinstance(key, bigframes.series.Series):
-        temp_name = guid.generate_guid(prefix="temp_series_name_")
-        if len(series_or_dataframe.index.names) > 1:
-            temp_name = series_or_dataframe.index.names[0]
-        key = key.rename(temp_name)
-        keys_df = key.to_frame()
-        keys_df = keys_df.set_index(temp_name, drop=True)
-        return _perform_loc_list_join(series_or_dataframe, keys_df)
-    elif isinstance(key, bigframes.core.indexes.Index):
-        block = key._block
-        block = block.select_columns(())
-        keys_df = bigframes.dataframe.DataFrame(block)
-        return _perform_loc_list_join(series_or_dataframe, keys_df)
-    elif pd.api.types.is_list_like(key):
-        key = typing.cast(List, key)
-        if len(key) == 0:
-            return typing.cast(
-                Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
-                series_or_dataframe.iloc[0:0],
-            )
-        if pd.api.types.is_list_like(key[0]):
-            original_index_names = series_or_dataframe.index.names
-            num_index_cols = len(original_index_names)
-
-            entry_col_count_correct = [len(entry) == num_index_cols for entry in key]
-            if not all(entry_col_count_correct):
-                # pandas usually throws TypeError in these cases; a tuple causes IndexError,
-                # but that seems like unintended behavior
-                raise TypeError(
-                    "All entries must be of equal length when indexing by list of listlikes"
-                )
-            temporary_index_names = [
-                guid.generate_guid(prefix="temp_loc_index_")
-                for _ in range(len(original_index_names))
-            ]
-            index_cols_dict = {}
-            for i in range(num_index_cols):
-                index_name = temporary_index_names[i]
-                values = [entry[i] for entry in key]
-                index_cols_dict[index_name] = values
-            keys_df = bigframes.dataframe.DataFrame(
-                index_cols_dict, session=series_or_dataframe._get_block().expr.session
-            )
-            keys_df = keys_df.set_index(temporary_index_names, drop=True)
-            keys_df = keys_df.rename_axis(original_index_names)
-        else:
-            # We can't upload a DataFrame with None as the column name, so set it
-            # to an arbitrary string.
-            index_name = series_or_dataframe.index.name
-            index_name_is_none = index_name is None
-            if index_name_is_none:
-                index_name = "unnamed_col"
-            keys_df = bigframes.dataframe.DataFrame(
-                {index_name: key},
-                session=series_or_dataframe._get_block().expr.session,
-            )
-            keys_df = keys_df.set_index(index_name, drop=True)
-            if index_name_is_none:
-                keys_df.index.name = None
-        return _perform_loc_list_join(series_or_dataframe, keys_df)
-    elif isinstance(key, slice):
+    if isinstance(key, slice):
         if (key.start is None) and (key.stop is None) and (key.step is None):
             return series_or_dataframe.copy()
         raise NotImplementedError(
             f"loc does not yet support indexing with a slice. {constants.FEEDBACK_LINK}"
         )
-    elif callable(key):
+    if callable(key):
         raise NotImplementedError(
             f"loc does not yet support indexing with a callable. {constants.FEEDBACK_LINK}"
         )
-    elif pd.api.types.is_scalar(key):
-        index_name = "unnamed_col"
-        keys_df = bigframes.dataframe.DataFrame(
-            {index_name: [key]}, session=series_or_dataframe._get_block().expr.session
-        )
-        keys_df = keys_df.set_index(index_name, drop=True)
-        keys_df.index.name = None
-        result = _perform_loc_list_join(series_or_dataframe, keys_df)
-        pandas_result = result.to_pandas()
-        # although loc[scalar_key] returns multiple results when scalar_key
-        # is not unique, we download the results here and return the computed
-        # individual result (as a scalar or pandas series) when the key is unique,
-        # since we expect unique index keys to be more common. loc[[scalar_key]]
-        # can be used to retrieve one-item DataFrames or Series.
-        if len(pandas_result) == 1:
-            return pandas_result.iloc[0]
+    elif isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
+        return series_or_dataframe[key]
+    elif (
+        isinstance(key, bigframes.series.Series)
+        or isinstance(key, indexes.Index)
+        or (pd.api.types.is_list_like(key) and not isinstance(key, tuple))
+    ):
+        index = indexes.Index(key, session=series_or_dataframe._session)
+        index.names = series_or_dataframe.index.names[: index.nlevels]
+        return _perform_loc_list_join(series_or_dataframe, index)
+    elif pd.api.types.is_scalar(key) or isinstance(key, tuple):
+        index = indexes.Index([key], session=series_or_dataframe._session)
+        index.names = series_or_dataframe.index.names[: index.nlevels]
+        result = _perform_loc_list_join(series_or_dataframe, index, drop_levels=True)
+
+        if index.nlevels == series_or_dataframe.index.nlevels:
+            pandas_result = result.to_pandas()
+            # although loc[scalar_key] returns multiple results when scalar_key
+            # is not unique, we download the results here and return the computed
+            # individual result (as a scalar or pandas series) when the key is unique,
+            # since we expect unique index keys to be more common. loc[[scalar_key]]
+            # can be used to retrieve one-item DataFrames or Series.
+            if len(pandas_result) == 1:
+                return pandas_result.iloc[0]
         # when the key is not unique, we return a bigframes data type
         # as usual for methods that return dataframes/series
         return result
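
For orientation, the rewritten branches above dispatch on the key type roughly as follows. This is an illustrative sketch only, assuming a configured BigQuery DataFrames session; the frame, its values, and the `bpd` alias are invented for the example.

import bigframes.pandas as bpd

df = bpd.DataFrame({"val": [10, 20, 30]}, index=["a", "b", "c"])

df.loc[df["val"] > 15]   # boolean Series key: filters rows
df.loc[["a", "c"]]       # list-like key: materialized as a temporary Index and right-joined
df.loc["b"]              # scalar key with a unique match: downloaded and returned as pandas data
df.loc[["b"]]            # wrapping the scalar keeps the result as a BigQuery DataFrames object
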
@@ -385,39 +334,47 @@ def _loc_getitem_series_or_dataframe(
 @typing.overload
 def _perform_loc_list_join(
     series_or_dataframe: bigframes.series.Series,
-    keys_df: bigframes.dataframe.DataFrame,
+    keys_index: indexes.Index,
+    drop_levels: bool = False,
 ) -> bigframes.series.Series:
     ...
 
 
 @typing.overload
 def _perform_loc_list_join(
     series_or_dataframe: bigframes.dataframe.DataFrame,
-    keys_df: bigframes.dataframe.DataFrame,
+    keys_index: indexes.Index,
+    drop_levels: bool = False,
 ) -> bigframes.dataframe.DataFrame:
     ...
 
 
 def _perform_loc_list_join(
     series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
-    keys_df: bigframes.dataframe.DataFrame,
+    keys_index: indexes.Index,
+    drop_levels: bool = False,
 ) -> Union[bigframes.series.Series, bigframes.dataframe.DataFrame]:
     # right join based on the old index so that the matching rows from the user's
     # original dataframe will be duplicated and reordered appropriately
-    original_index_names = series_or_dataframe.index.names
     if isinstance(series_or_dataframe, bigframes.series.Series):
         original_name = series_or_dataframe.name
         name = series_or_dataframe.name if series_or_dataframe.name is not None else "0"
         result = typing.cast(
             bigframes.series.Series,
-            series_or_dataframe.to_frame()._perform_join_by_index(keys_df, how="right")[
-                name
-            ],
+            series_or_dataframe.to_frame()._perform_join_by_index(
+                keys_index, how="right"
+            )[name],
         )
         result = result.rename(original_name)
     else:
-        result = series_or_dataframe._perform_join_by_index(keys_df, how="right")  # type: ignore
-        result = result.rename_axis(original_index_names)
+        result = series_or_dataframe._perform_join_by_index(keys_index, how="right")  # type: ignore
+
+    if drop_levels and series_or_dataframe.index.nlevels > keys_index.nlevels:
+        # drop common levels
+        levels_to_drop = [
+            name for name in series_or_dataframe.index.names if name in keys_index.names
+        ]
+        result = result.droplevel(levels_to_drop)  # type: ignore
     return result
 
 
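As background for the new drop_levels flag, the intended result shape mirrors pandas partial MultiIndex lookup, where the index levels consumed by the key are dropped from the result. A small plain-pandas illustration with invented data:

import pandas as pd

idx = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)], names=["lvl0", "lvl1"])
pandas_df = pd.DataFrame({"val": [10, 20, 30]}, index=idx)

pandas_df.loc["a"]       # partial key: "lvl0" is dropped, rows indexed only by "lvl1"
pandas_df.loc[("a", 1)]  # full key: all levels consumed, returns the row as a Series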