@@ -980,24 +980,97 @@ def test_comparison_invalid(self):
980980 self .assertRaises (TypeError , lambda : x <= y )
981981
982982 def test_more_na_comparisons (self ):
983- left = Series (['a' , np .nan , 'c' ])
984- right = Series (['a' , np .nan , 'd' ])
983+ for dtype in [None , object ]:
984+ left = Series (['a' , np .nan , 'c' ], dtype = dtype )
985+ right = Series (['a' , np .nan , 'd' ], dtype = dtype )
985986
986- result = left == right
987- expected = Series ([True , False , False ])
988- assert_series_equal (result , expected )
987+ result = left == right
988+ expected = Series ([True , False , False ])
989+ assert_series_equal (result , expected )
989990
990- result = left != right
991- expected = Series ([False , True , True ])
992- assert_series_equal (result , expected )
991+ result = left != right
992+ expected = Series ([False , True , True ])
993+ assert_series_equal (result , expected )
993994
994- result = left == np .nan
995- expected = Series ([False , False , False ])
996- assert_series_equal (result , expected )
995+ result = left == np .nan
996+ expected = Series ([False , False , False ])
997+ assert_series_equal (result , expected )
997998
998- result = left != np .nan
999- expected = Series ([True , True , True ])
1000- assert_series_equal (result , expected )
999+ result = left != np .nan
1000+ expected = Series ([True , True , True ])
1001+ assert_series_equal (result , expected )
1002+
1003+ def test_nat_comparisons (self ):
1004+ data = [([pd .Timestamp ('2011-01-01' ), pd .NaT ,
1005+ pd .Timestamp ('2011-01-03' )],
1006+ [pd .NaT , pd .NaT , pd .Timestamp ('2011-01-03' )]),
1007+
1008+ ([pd .Timedelta ('1 days' ), pd .NaT ,
1009+ pd .Timedelta ('3 days' )],
1010+ [pd .NaT , pd .NaT , pd .Timedelta ('3 days' )]),
1011+
1012+ ([pd .Period ('2011-01' , freq = 'M' ), pd .NaT ,
1013+ pd .Period ('2011-03' , freq = 'M' )],
1014+ [pd .NaT , pd .NaT , pd .Period ('2011-03' , freq = 'M' )])]
1015+
1016+ # add lhs / rhs switched data
1017+ data = data + [(r , l ) for l , r in data ]
1018+
1019+ for l , r in data :
1020+ for dtype in [None , object ]:
1021+ left = Series (l , dtype = dtype )
1022+
1023+ # Series, Index
1024+ for right in [Series (r , dtype = dtype ), Index (r , dtype = dtype )]:
1025+ expected = Series ([False , False , True ])
1026+ assert_series_equal (left == right , expected )
1027+
1028+ expected = Series ([True , True , False ])
1029+ assert_series_equal (left != right , expected )
1030+
1031+ expected = Series ([False , False , False ])
1032+ assert_series_equal (left < right , expected )
1033+
1034+ expected = Series ([False , False , False ])
1035+ assert_series_equal (left > right , expected )
1036+
1037+ expected = Series ([False , False , True ])
1038+ assert_series_equal (left >= right , expected )
1039+
1040+ expected = Series ([False , False , True ])
1041+ assert_series_equal (left <= right , expected )
1042+
1043+ def test_nat_comparisons_scalar (self ):
1044+ data = [[pd .Timestamp ('2011-01-01' ), pd .NaT ,
1045+ pd .Timestamp ('2011-01-03' )],
1046+
1047+ [pd .Timedelta ('1 days' ), pd .NaT , pd .Timedelta ('3 days' )],
1048+
1049+ [pd .Period ('2011-01' , freq = 'M' ), pd .NaT ,
1050+ pd .Period ('2011-03' , freq = 'M' )]]
1051+
1052+ for l in data :
1053+ for dtype in [None , object ]:
1054+ left = Series (l , dtype = dtype )
1055+
1056+ expected = Series ([False , False , False ])
1057+ assert_series_equal (left == pd .NaT , expected )
1058+ assert_series_equal (pd .NaT == left , expected )
1059+
1060+ expected = Series ([True , True , True ])
1061+ assert_series_equal (left != pd .NaT , expected )
1062+ assert_series_equal (pd .NaT != left , expected )
1063+
1064+ expected = Series ([False , False , False ])
1065+ assert_series_equal (left < pd .NaT , expected )
1066+ assert_series_equal (pd .NaT > left , expected )
1067+ assert_series_equal (left <= pd .NaT , expected )
1068+ assert_series_equal (pd .NaT >= left , expected )
1069+
1070+ assert_series_equal (left > pd .NaT , expected )
1071+ assert_series_equal (pd .NaT < left , expected )
1072+ assert_series_equal (left >= pd .NaT , expected )
1073+ assert_series_equal (pd .NaT <= left , expected )
10011074
10021075 def test_comparison_different_length (self ):
10031076 a = Series (['a' , 'b' , 'c' ])
0 commit comments