@@ -354,10 +354,7 @@ def unpivot(
354354 * ,
355355 passthrough_columns : typing .Sequence [str ] = (),
356356 index_col_ids : typing .Sequence [str ] = ["index" ],
357- dtype : typing .Union [
358- bigframes .dtypes .Dtype , typing .Tuple [bigframes .dtypes .Dtype , ...]
359- ] = pandas .Float64Dtype (),
360- how : typing .Literal ["left" , "right" ] = "left" ,
357+ join_side : typing .Literal ["left" , "right" ] = "left" ,
361358 ) -> ArrayValue :
362359 """
363360 Unpivot ArrayValue columns.
@@ -367,23 +364,88 @@ def unpivot(
367364 unpivot_columns: Mapping of column id to list of input column ids. Lists of input columns may use None.
368365 passthrough_columns: Columns that will not be unpivoted. Column id will be preserved.
369366 index_col_id (str): The column id to be used for the row labels.
370- dtype (dtype or list of dtype): Dtype to use for the unpivot columns. If list, must be equal in number to unpivot_columns.
371367
372368 Returns:
373369 ArrayValue: The unpivoted ArrayValue
374370 """
371+ # There will be N labels, used to disambiguate which of N source columns produced each output row
372+ explode_offsets_id = bigframes .core .guid .generate_guid ("unpivot_offsets_" )
373+ labels_array = self ._create_unpivot_labels_array (row_labels , index_col_ids )
374+ labels_array = labels_array .promote_offsets (explode_offsets_id )
375+
376+ # Unpivot creates N output rows for each input row, labels disambiguate these N rows
377+ joined_array = self ._cross_join_w_labels (labels_array , join_side )
378+
379+ # Build the output rows as a case statment that selects between the N input columns
380+ unpivot_exprs = []
381+ # Supports producing multiple stacked ouput columns for stacking only part of hierarchical index
382+ for col_id , input_ids in unpivot_columns :
383+ # row explode offset used to choose the input column
384+ # we use offset instead of label as labels are not necessarily unique
385+ cases = tuple (
386+ (
387+ ops .eq_op .as_expr (explode_offsets_id , ex .const (i )),
388+ ex .free_var (id_or_null )
389+ if (id_or_null is not None )
390+ else ex .const (None ),
391+ )
392+ for i , id_or_null in enumerate (input_ids )
393+ )
394+ col_expr = ops .case_when_op .as_expr (* cases )
395+ unpivot_exprs .append ((col_expr , col_id ))
396+
397+ label_exprs = ((ex .free_var (id ), id ) for id in index_col_ids )
398+ # passthrough columns are unchanged, just repeated N times each
399+ passthrough_exprs = ((ex .free_var (id ), id ) for id in passthrough_columns )
375400 return ArrayValue (
376- nodes .UnpivotNode (
377- child = self .node ,
378- row_labels = tuple (row_labels ),
379- unpivot_columns = tuple (unpivot_columns ),
380- passthrough_columns = tuple (passthrough_columns ),
381- index_col_ids = tuple (index_col_ids ),
382- dtype = dtype ,
383- how = how ,
401+ nodes .ProjectionNode (
402+ child = joined_array .node ,
403+ assignments = (* label_exprs , * unpivot_exprs , * passthrough_exprs ),
384404 )
385405 )
386406
407+ def _cross_join_w_labels (
408+ self , labels_array : ArrayValue , join_side : typing .Literal ["left" , "right" ]
409+ ) -> ArrayValue :
410+ """
411+ Convert each row in self to N rows, one for each label in labels array.
412+ """
413+ table_join_side = (
414+ join_def .JoinSide .LEFT if join_side == "left" else join_def .JoinSide .RIGHT
415+ )
416+ labels_join_side = table_join_side .inverse ()
417+ labels_mappings = tuple (
418+ join_def .JoinColumnMapping (labels_join_side , id , id )
419+ for id in labels_array .schema .names
420+ )
421+ table_mappings = tuple (
422+ join_def .JoinColumnMapping (table_join_side , id , id )
423+ for id in self .schema .names
424+ )
425+ join = join_def .JoinDefinition (
426+ conditions = (), mappings = (* labels_mappings , * table_mappings ), type = "cross"
427+ )
428+ if join_side == "left" :
429+ joined_array = self .join (labels_array , join_def = join )
430+ else :
431+ joined_array = labels_array .join (self , join_def = join )
432+ return joined_array
433+
434+ def _create_unpivot_labels_array (
435+ self ,
436+ former_column_labels : typing .Sequence [typing .Hashable ],
437+ col_ids : typing .Sequence [str ],
438+ ) -> ArrayValue :
439+ """Create an ArrayValue from a list of label tuples."""
440+ rows = []
441+ for row_offset in range (len (former_column_labels )):
442+ row_label = former_column_labels [row_offset ]
443+ row_label = (row_label ,) if not isinstance (row_label , tuple ) else row_label
444+ row = {col_ids [i ]: row_label [i ] for i in range (len (col_ids ))}
445+ rows .append (row )
446+
447+ return ArrayValue .from_pyarrow (pa .Table .from_pylist (rows ), session = self .session )
448+
387449 def join (
388450 self ,
389451 other : ArrayValue ,
0 commit comments