数据切片生成器¶
数据切片生成器是用于为标注函数生成数据切片的基础函数。如果在搜索过程中标签生成器出现错误或输出的标签不正确,则需要检查标注函数中的逻辑或检查数据是否存在固有错误。数据切片生成器可以帮助识别这些问题。在开发标注函数期间,最好也使用生成器,这是一种最佳实践。但这只是一个可选步骤,并非生成标签的必要条件。
在本指南中,我们将使用数据切片生成器来检查数据切片并应用标注函数。首先,加载一个模拟交易数据集并采样数据以查看交易情况。
[1]:
import composeml as cp
[2]:
df = cp.demos.load_transactions()
df = df[df.columns[:7]]
df.sample(n=5, random_state=0)
[2]:
transaction_id | session_id | transaction_time | product_id | amount | customer_id | device | |
---|---|---|---|---|---|---|---|
26 | 94 | 24 | 2014-01-01 05:55:20 | 5 | 100.42 | 5 | tablet |
86 | 274 | 7 | 2014-01-01 01:46:10 | 5 | 14.45 | 3 | tablet |
2 | 495 | 1 | 2014-01-01 00:14:05 | 5 | 69.45 | 2 | desktop |
55 | 275 | 4 | 2014-01-01 00:45:30 | 5 | 108.11 | 1 | mobile |
75 | 368 | 27 | 2014-01-01 06:36:30 | 5 | 139.43 | 1 | mobile |
标注函数¶
定义一个标注函数,该函数根据给定的交易切片返回客户花费的总金额。
[3]:
def total_spent(df):
return df['amount'].sum()
数据切片¶
LabelMaker.slice()
方法创建数据切片生成器。此方法的参数可以直接传递给 LabelMaker.search()
来生成标签。在下面的部分中,我们将解释如何使用数据切片生成器使数据切片连续、重叠或分散。
另请参阅
有关此过程的概念性解释,请参阅 主要概念。
连续¶
当间隔大小等于窗口大小时,数据切片是连续的。换句话说,数据切片不重叠且不分散(即不跳过任何数据)。这是间隔大小的默认值。为了演示此示例,请使用这些参数生成数据切片。
首先,创建一个窗口大小为 2 小时的标签生成器。
[4]:
lm = cp.LabelMaker(
target_dataframe_index="customer_id",
time_index="transaction_time",
labeling_function=total_spent,
window_size="2h",
)
接下来,创建一个间隔大小为 2 小时的数据切片生成器。间隔大小的默认值是窗口大小。
提示
您可以将 minimum_data
直接设置为第一个截止时间。
[5]:
slices = lm.slice(
df.sort_values('transaction_time'),
num_examples_per_instance=-1,
minimum_data='2014-01-01',
)
连续 - 数据切片 #1¶
通过打印此数据切片,您可以看到它是客户 1 的第一个交易切片(由 slice_number
表示)。此数据切片包含客户在 2014-01-01 00:00:00
和 2014-01-01 02:00:00
之间的 2 小时窗口内发生的所有交易。您还可以看到,2 小时的间隔将截止时间与窗口对齐,因此下一个数据切片从当前数据切片的末尾开始。
[6]:
ds = next(slices)
print(ds.context)
ds
customer_id 1
slice_number 1
slice_start 2014-01-01 00:00:00
slice_stop 2014-01-01 02:00:00
next_start 2014-01-01 02:00:00
[6]:
transaction_id | session_id | product_id | amount | customer_id | device | |
---|---|---|---|---|---|---|
transaction_time | ||||||
2014-01-01 00:45:30 | 275 | 4 | 5 | 108.11 | 1 | mobile |
2014-01-01 00:46:35 | 101 | 4 | 5 | 112.53 | 1 | mobile |
2014-01-01 00:47:40 | 80 | 4 | 5 | 6.29 | 1 | mobile |
2014-01-01 00:52:00 | 163 | 4 | 5 | 31.37 | 1 | mobile |
2014-01-01 00:53:05 | 293 | 4 | 5 | 82.88 | 1 | mobile |
2014-01-01 00:57:25 | 103 | 4 | 5 | 20.79 | 1 | mobile |
2014-01-01 01:03:55 | 488 | 4 | 5 | 129.00 | 1 | mobile |
2014-01-01 01:05:00 | 413 | 4 | 5 | 119.98 | 1 | mobile |
2014-01-01 01:31:00 | 191 | 6 | 5 | 139.23 | 1 | tablet |
2014-01-01 01:37:30 | 372 | 6 | 5 | 114.84 | 1 | tablet |
2014-01-01 01:38:35 | 387 | 6 | 5 | 49.71 | 1 | tablet |
对此数据切片应用标注函数,计算总花费金额。
[7]:
total_spent(ds)
[7]:
914.7300000000001
连续 - 数据切片 #2¶
在第二个数据切片中,您可以看到 2014-01-01 02:00:00
和 2014-01-01 04:00:00
之间的接下来连续 2 小时的交易。这对于生成仅连续处理数据一次的标签非常有用。
[8]:
ds = next(slices)
print(ds.context)
ds
customer_id 1
slice_number 2
slice_start 2014-01-01 02:00:00
slice_stop 2014-01-01 04:00:00
next_start 2014-01-01 04:00:00
[8]:
transaction_id | session_id | product_id | amount | customer_id | device | |
---|---|---|---|---|---|---|
transaction_time | ||||||
2014-01-01 02:28:25 | 287 | 9 | 5 | 50.94 | 1 | desktop |
2014-01-01 03:29:05 | 190 | 14 | 5 | 110.52 | 1 | tablet |
2014-01-01 03:39:55 | 7 | 14 | 5 | 107.42 | 1 | tablet |
对此数据切片应用我们的标注函数,计算总花费金额。
[9]:
total_spent(ds)
[9]:
268.88
重叠¶
当间隔大小小于窗口大小时,数据切片会重叠。这可用于基于滚动窗口的标注过程。重叠量是窗口大小和间隔大小之间的差值。例如,如果窗口大小为 3 小时,间隔大小为 1 小时,则每个数据切片将重叠 2 小时。为了演示此示例,请使用这些参数生成数据切片。
首先,创建一个窗口大小为 3 小时的标签生成器。
[10]:
lm = cp.LabelMaker(
target_dataframe_index="customer_id",
time_index="transaction_time",
labeling_function=total_spent,
window_size="3h",
)
接下来,创建一个间隔大小为 1 小时的数据切片生成器。
[11]:
slices = lm.slice(
df.sort_values('transaction_time'),
num_examples_per_instance=-1,
minimum_data='2014-01-01',
gap="1h",
)
重叠 - 数据切片 #1¶
第一个数据切片包含客户在 2014-01-01 00:00:00
和 2014-01-01 03:00:00
之间的 3 小时窗口内发生的所有交易。1 小时的间隔将此数据切片的截止时间(2014-01-01 00:00:00
)与下一个数据切片的截止时间(2014-01-01 01:00:00
)分隔开。
[12]:
ds = next(slices)
print(ds.context)
ds
customer_id 1
slice_number 1
slice_start 2014-01-01 00:00:00
slice_stop 2014-01-01 03:00:00
next_start 2014-01-01 01:00:00
[12]:
transaction_id | session_id | product_id | amount | customer_id | device | |
---|---|---|---|---|---|---|
transaction_time | ||||||
2014-01-01 00:45:30 | 275 | 4 | 5 | 108.11 | 1 | mobile |
2014-01-01 00:46:35 | 101 | 4 | 5 | 112.53 | 1 | mobile |
2014-01-01 00:47:40 | 80 | 4 | 5 | 6.29 | 1 | mobile |
2014-01-01 00:52:00 | 163 | 4 | 5 | 31.37 | 1 | mobile |
2014-01-01 00:53:05 | 293 | 4 | 5 | 82.88 | 1 | mobile |
2014-01-01 00:57:25 | 103 | 4 | 5 | 20.79 | 1 | mobile |
2014-01-01 01:03:55 | 488 | 4 | 5 | 129.00 | 1 | mobile |
2014-01-01 01:05:00 | 413 | 4 | 5 | 119.98 | 1 | mobile |
2014-01-01 01:31:00 | 191 | 6 | 5 | 139.23 | 1 | tablet |
2014-01-01 01:37:30 | 372 | 6 | 5 | 114.84 | 1 | tablet |
2014-01-01 01:38:35 | 387 | 6 | 5 | 49.71 | 1 | tablet |
2014-01-01 02:28:25 | 287 | 9 | 5 | 50.94 | 1 | desktop |
对此数据切片应用我们的标注函数,计算总花费金额。
[13]:
total_spent(ds)
[13]:
965.6700000000001
重叠 - 数据切片 #2¶
在第二个数据切片中,2014-01-01 01:00:00
和 2014-01-01 03:00:00
之间发生的交易有 2 小时的重叠。通过调整间隔大小,您可以精确设置数据切片中的重叠量。这对于生成具有特定重叠的标签非常有用。
[14]:
ds = next(slices)
print(ds.context)
ds
customer_id 1
slice_number 2
slice_start 2014-01-01 01:00:00
slice_stop 2014-01-01 04:00:00
next_start 2014-01-01 02:00:00
[14]:
transaction_id | session_id | product_id | amount | customer_id | device | |
---|---|---|---|---|---|---|
transaction_time | ||||||
2014-01-01 01:03:55 | 488 | 4 | 5 | 129.00 | 1 | mobile |
2014-01-01 01:05:00 | 413 | 4 | 5 | 119.98 | 1 | mobile |
2014-01-01 01:31:00 | 191 | 6 | 5 | 139.23 | 1 | tablet |
2014-01-01 01:37:30 | 372 | 6 | 5 | 114.84 | 1 | tablet |
2014-01-01 01:38:35 | 387 | 6 | 5 | 49.71 | 1 | tablet |
2014-01-01 02:28:25 | 287 | 9 | 5 | 50.94 | 1 | desktop |
2014-01-01 03:29:05 | 190 | 14 | 5 | 110.52 | 1 | tablet |
2014-01-01 03:39:55 | 7 | 14 | 5 | 107.42 | 1 | tablet |
对此数据切片应用我们的标注函数,计算总花费金额。
[15]:
total_spent(ds)
[15]:
821.6400000000001
分散¶
当间隔大小大于窗口大小时,数据切片之间会跳过一些数据。这可用于按特定时间间隔对数据进行标注。跳过的数据量是间隔大小和窗口大小之间的差值。例如,如果间隔大小为 3 小时,窗口大小为 1 小时,则数据切片之间将跳过 2 小时的数据。为了演示此示例,请使用这些参数生成数据切片。
首先,创建一个窗口大小为 1 小时的标签生成器。
[16]:
lm = cp.LabelMaker(
target_dataframe_index="customer_id",
time_index="transaction_time",
labeling_function=total_spent,
window_size="1h",
)
接下来,创建一个间隔大小为 3 小时的数据切片生成器。
[17]:
slices = lm.slice(
df.sort_values('transaction_time'),
num_examples_per_instance=-1,
minimum_data='2014-01-01',
gap="3h",
)
分散 - 数据切片 #1¶
第一个数据切片包含客户在 2014-01-01 00:00:00
和 2014-01-01 01:00:00
之间的 1 小时窗口内发生的所有交易。3 小时的间隔将此数据切片的截止时间(2014-01-01 00:00:00
)与下一个数据切片的截止时间(2014-01-01 03:00:00
)分隔开。
[18]:
ds = next(slices)
print(ds.context)
ds
customer_id 1
slice_number 1
slice_start 2014-01-01 00:00:00
slice_stop 2014-01-01 01:00:00
next_start 2014-01-01 03:00:00
[18]:
transaction_id | session_id | product_id | amount | customer_id | device | |
---|---|---|---|---|---|---|
transaction_time | ||||||
2014-01-01 00:45:30 | 275 | 4 | 5 | 108.11 | 1 | mobile |
2014-01-01 00:46:35 | 101 | 4 | 5 | 112.53 | 1 | mobile |
2014-01-01 00:47:40 | 80 | 4 | 5 | 6.29 | 1 | mobile |
2014-01-01 00:52:00 | 163 | 4 | 5 | 31.37 | 1 | mobile |
2014-01-01 00:53:05 | 293 | 4 | 5 | 82.88 | 1 | mobile |
2014-01-01 00:57:25 | 103 | 4 | 5 | 20.79 | 1 | mobile |
对此数据切片应用我们的标注函数,计算总花费金额。
[19]:
total_spent(ds)
[19]:
361.96999999999997
分散 - 数据切片 #2¶
在第二个数据切片中,您可以看到 2014-01-01 01:00:00
和 2014-01-01 03:00:00
之间跳过了 2 小时的交易。通过调整间隔大小,您可以精确设置数据切片之间要跳过的数据量。这对于生成针对数据集特定部分的标签非常有用。
[20]:
ds = next(slices)
print(ds.context)
ds
customer_id 1
slice_number 2
slice_start 2014-01-01 03:00:00
slice_stop 2014-01-01 04:00:00
next_start 2014-01-01 06:00:00
[20]:
transaction_id | session_id | product_id | amount | customer_id | device | |
---|---|---|---|---|---|---|
transaction_time | ||||||
2014-01-01 03:29:05 | 190 | 14 | 5 | 110.52 | 1 | tablet |
2014-01-01 03:39:55 | 7 | 14 | 5 | 107.42 | 1 | tablet |
对此数据切片应用标注函数,计算总花费金额。
[21]:
total_spent(ds)
[21]:
217.94
数据切片上下文¶
每个数据切片都有一个 context
属性,用于访问其元数据。这对于将上下文与标注函数中的逻辑集成非常有用。
[22]:
vars(ds.context)
[22]:
{'next_start': Timestamp('2014-01-01 06:00:00'),
'slice_stop': Timestamp('2014-01-01 04:00:00'),
'slice_start': Timestamp('2014-01-01 03:00:00'),
'slice_number': 2,
'customer_id': 1}