@@ -514,6 +514,50 @@ def test_where_dataframe_cond_dataframe_other(
514514 pandas .testing .assert_frame_equal (bf_result , pd_result )
515515
516516
517+ def test_where_callable_cond_constant_other (scalars_df_index , scalars_pandas_df_index ):
518+ # Condition is callable, other is a constant.
519+ columns = ["int64_col" , "float64_col" ]
520+ dataframe_bf = scalars_df_index [columns ]
521+ dataframe_pd = scalars_pandas_df_index [columns ]
522+
523+ other = 10
524+
525+ bf_result = dataframe_bf .where (lambda x : x > 0 , other ).to_pandas ()
526+ pd_result = dataframe_pd .where (lambda x : x > 0 , other )
527+ pandas .testing .assert_frame_equal (bf_result , pd_result )
528+
529+
530+ def test_where_dataframe_cond_callable_other (scalars_df_index , scalars_pandas_df_index ):
531+ # Condition is a dataframe, other is callable.
532+ columns = ["int64_col" , "float64_col" ]
533+ dataframe_bf = scalars_df_index [columns ]
534+ dataframe_pd = scalars_pandas_df_index [columns ]
535+
536+ cond_bf = dataframe_bf > 0
537+ cond_pd = dataframe_pd > 0
538+
539+ def func (x ):
540+ return x * 2
541+
542+ bf_result = dataframe_bf .where (cond_bf , func ).to_pandas ()
543+ pd_result = dataframe_pd .where (cond_pd , func )
544+ pandas .testing .assert_frame_equal (bf_result , pd_result )
545+
546+
547+ def test_where_callable_cond_callable_other (scalars_df_index , scalars_pandas_df_index ):
548+ # Condition is callable, other is callable too.
549+ columns = ["int64_col" , "float64_col" ]
550+ dataframe_bf = scalars_df_index [columns ]
551+ dataframe_pd = scalars_pandas_df_index [columns ]
552+
553+ def func (x ):
554+ return x ["int64_col" ] > 0
555+
556+ bf_result = dataframe_bf .where (func , lambda x : x * 2 ).to_pandas ()
557+ pd_result = dataframe_pd .where (func , lambda x : x * 2 )
558+ pandas .testing .assert_frame_equal (bf_result , pd_result )
559+
560+
517561def test_drop_column (scalars_dfs ):
518562 scalars_df , scalars_pandas_df = scalars_dfs
519563 col_name = "int64_col"
0 commit comments