数据切片生成器

数据切片生成器是用于为标注函数生成数据切片的基础函数。如果在搜索过程中标签生成器出现错误或输出的标签不正确,则需要检查标注函数中的逻辑或检查数据是否存在固有错误。数据切片生成器可以帮助识别这些问题。在开发标注函数期间,最好也使用生成器,这是一种最佳实践。但这只是一个可选步骤,并非生成标签的必要条件。

在本指南中,我们将使用数据切片生成器来检查数据切片并应用标注函数。首先,加载一个模拟交易数据集并采样数据以查看交易情况。

[1]:
import composeml as cp
[2]:
df = cp.demos.load_transactions()
df = df[df.columns[:7]]
df.sample(n=5, random_state=0)
[2]:
transaction_id session_id transaction_time product_id amount customer_id device
26 94 24 2014-01-01 05:55:20 5 100.42 5 tablet
86 274 7 2014-01-01 01:46:10 5 14.45 3 tablet
2 495 1 2014-01-01 00:14:05 5 69.45 2 desktop
55 275 4 2014-01-01 00:45:30 5 108.11 1 mobile
75 368 27 2014-01-01 06:36:30 5 139.43 1 mobile

标注函数

定义一个标注函数,该函数根据给定的交易切片返回客户花费的总金额。

[3]:
def total_spent(df):
    return df['amount'].sum()

数据切片

LabelMaker.slice() 方法创建数据切片生成器。此方法的参数可以直接传递给 LabelMaker.search() 来生成标签。在下面的部分中,我们将解释如何使用数据切片生成器使数据切片连续、重叠或分散。

另请参阅

有关此过程的概念性解释,请参阅 主要概念

连续

当间隔大小等于窗口大小时,数据切片是连续的。换句话说,数据切片不重叠且不分散(即不跳过任何数据)。这是间隔大小的默认值。为了演示此示例,请使用这些参数生成数据切片。

首先,创建一个窗口大小为 2 小时的标签生成器。

[4]:
lm = cp.LabelMaker(
    target_dataframe_index="customer_id",
    time_index="transaction_time",
    labeling_function=total_spent,
    window_size="2h",
)

接下来,创建一个间隔大小为 2 小时的数据切片生成器。间隔大小的默认值是窗口大小。

提示

您可以将 minimum_data 直接设置为第一个截止时间。

[5]:
slices = lm.slice(
    df.sort_values('transaction_time'),
    num_examples_per_instance=-1,
    minimum_data='2014-01-01',
)

连续 - 数据切片 #1

通过打印此数据切片,您可以看到它是客户 1 的第一个交易切片(由 slice_number 表示)。此数据切片包含客户在 2014-01-01 00:00:002014-01-01 02:00:00 之间的 2 小时窗口内发生的所有交易。您还可以看到,2 小时的间隔将截止时间与窗口对齐,因此下一个数据切片从当前数据切片的末尾开始。

[6]:
ds = next(slices)
print(ds.context)
ds
customer_id                       1
slice_number                      1
slice_start     2014-01-01 00:00:00
slice_stop      2014-01-01 02:00:00
next_start      2014-01-01 02:00:00
[6]:
transaction_id session_id product_id amount customer_id device
transaction_time
2014-01-01 00:45:30 275 4 5 108.11 1 mobile
2014-01-01 00:46:35 101 4 5 112.53 1 mobile
2014-01-01 00:47:40 80 4 5 6.29 1 mobile
2014-01-01 00:52:00 163 4 5 31.37 1 mobile
2014-01-01 00:53:05 293 4 5 82.88 1 mobile
2014-01-01 00:57:25 103 4 5 20.79 1 mobile
2014-01-01 01:03:55 488 4 5 129.00 1 mobile
2014-01-01 01:05:00 413 4 5 119.98 1 mobile
2014-01-01 01:31:00 191 6 5 139.23 1 tablet
2014-01-01 01:37:30 372 6 5 114.84 1 tablet
2014-01-01 01:38:35 387 6 5 49.71 1 tablet

对此数据切片应用标注函数,计算总花费金额。

[7]:
total_spent(ds)
[7]:
914.7300000000001

连续 - 数据切片 #2

在第二个数据切片中,您可以看到 2014-01-01 02:00:002014-01-01 04:00:00 之间的接下来连续 2 小时的交易。这对于生成仅连续处理数据一次的标签非常有用。

[8]:
ds = next(slices)
print(ds.context)
ds
customer_id                       1
slice_number                      2
slice_start     2014-01-01 02:00:00
slice_stop      2014-01-01 04:00:00
next_start      2014-01-01 04:00:00
[8]:
transaction_id session_id product_id amount customer_id device
transaction_time
2014-01-01 02:28:25 287 9 5 50.94 1 desktop
2014-01-01 03:29:05 190 14 5 110.52 1 tablet
2014-01-01 03:39:55 7 14 5 107.42 1 tablet

对此数据切片应用我们的标注函数,计算总花费金额。

[9]:
total_spent(ds)
[9]:
268.88

重叠

当间隔大小小于窗口大小时,数据切片会重叠。这可用于基于滚动窗口的标注过程。重叠量是窗口大小和间隔大小之间的差值。例如,如果窗口大小为 3 小时,间隔大小为 1 小时,则每个数据切片将重叠 2 小时。为了演示此示例,请使用这些参数生成数据切片。

首先,创建一个窗口大小为 3 小时的标签生成器。

[10]:
lm = cp.LabelMaker(
    target_dataframe_index="customer_id",
    time_index="transaction_time",
    labeling_function=total_spent,
    window_size="3h",
)

接下来,创建一个间隔大小为 1 小时的数据切片生成器。

[11]:
slices = lm.slice(
    df.sort_values('transaction_time'),
    num_examples_per_instance=-1,
    minimum_data='2014-01-01',
    gap="1h",
)

重叠 - 数据切片 #1

第一个数据切片包含客户在 2014-01-01 00:00:002014-01-01 03:00:00 之间的 3 小时窗口内发生的所有交易。1 小时的间隔将此数据切片的截止时间(2014-01-01 00:00:00)与下一个数据切片的截止时间(2014-01-01 01:00:00)分隔开。

[12]:
ds = next(slices)
print(ds.context)
ds
customer_id                       1
slice_number                      1
slice_start     2014-01-01 00:00:00
slice_stop      2014-01-01 03:00:00
next_start      2014-01-01 01:00:00
[12]:
transaction_id session_id product_id amount customer_id device
transaction_time
2014-01-01 00:45:30 275 4 5 108.11 1 mobile
2014-01-01 00:46:35 101 4 5 112.53 1 mobile
2014-01-01 00:47:40 80 4 5 6.29 1 mobile
2014-01-01 00:52:00 163 4 5 31.37 1 mobile
2014-01-01 00:53:05 293 4 5 82.88 1 mobile
2014-01-01 00:57:25 103 4 5 20.79 1 mobile
2014-01-01 01:03:55 488 4 5 129.00 1 mobile
2014-01-01 01:05:00 413 4 5 119.98 1 mobile
2014-01-01 01:31:00 191 6 5 139.23 1 tablet
2014-01-01 01:37:30 372 6 5 114.84 1 tablet
2014-01-01 01:38:35 387 6 5 49.71 1 tablet
2014-01-01 02:28:25 287 9 5 50.94 1 desktop

对此数据切片应用我们的标注函数,计算总花费金额。

[13]:
total_spent(ds)
[13]:
965.6700000000001

重叠 - 数据切片 #2

在第二个数据切片中,2014-01-01 01:00:002014-01-01 03:00:00 之间发生的交易有 2 小时的重叠。通过调整间隔大小,您可以精确设置数据切片中的重叠量。这对于生成具有特定重叠的标签非常有用。

[14]:
ds = next(slices)
print(ds.context)
ds
customer_id                       1
slice_number                      2
slice_start     2014-01-01 01:00:00
slice_stop      2014-01-01 04:00:00
next_start      2014-01-01 02:00:00
[14]:
transaction_id session_id product_id amount customer_id device
transaction_time
2014-01-01 01:03:55 488 4 5 129.00 1 mobile
2014-01-01 01:05:00 413 4 5 119.98 1 mobile
2014-01-01 01:31:00 191 6 5 139.23 1 tablet
2014-01-01 01:37:30 372 6 5 114.84 1 tablet
2014-01-01 01:38:35 387 6 5 49.71 1 tablet
2014-01-01 02:28:25 287 9 5 50.94 1 desktop
2014-01-01 03:29:05 190 14 5 110.52 1 tablet
2014-01-01 03:39:55 7 14 5 107.42 1 tablet

对此数据切片应用我们的标注函数,计算总花费金额。

[15]:
total_spent(ds)
[15]:
821.6400000000001

分散

当间隔大小大于窗口大小时,数据切片之间会跳过一些数据。这可用于按特定时间间隔对数据进行标注。跳过的数据量是间隔大小和窗口大小之间的差值。例如,如果间隔大小为 3 小时,窗口大小为 1 小时,则数据切片之间将跳过 2 小时的数据。为了演示此示例,请使用这些参数生成数据切片。

首先,创建一个窗口大小为 1 小时的标签生成器。

[16]:
lm = cp.LabelMaker(
    target_dataframe_index="customer_id",
    time_index="transaction_time",
    labeling_function=total_spent,
    window_size="1h",
)

接下来,创建一个间隔大小为 3 小时的数据切片生成器。

[17]:
slices = lm.slice(
    df.sort_values('transaction_time'),
    num_examples_per_instance=-1,
    minimum_data='2014-01-01',
    gap="3h",
)

分散 - 数据切片 #1

第一个数据切片包含客户在 2014-01-01 00:00:002014-01-01 01:00:00 之间的 1 小时窗口内发生的所有交易。3 小时的间隔将此数据切片的截止时间(2014-01-01 00:00:00)与下一个数据切片的截止时间(2014-01-01 03:00:00)分隔开。

[18]:
ds = next(slices)
print(ds.context)
ds
customer_id                       1
slice_number                      1
slice_start     2014-01-01 00:00:00
slice_stop      2014-01-01 01:00:00
next_start      2014-01-01 03:00:00
[18]:
transaction_id session_id product_id amount customer_id device
transaction_time
2014-01-01 00:45:30 275 4 5 108.11 1 mobile
2014-01-01 00:46:35 101 4 5 112.53 1 mobile
2014-01-01 00:47:40 80 4 5 6.29 1 mobile
2014-01-01 00:52:00 163 4 5 31.37 1 mobile
2014-01-01 00:53:05 293 4 5 82.88 1 mobile
2014-01-01 00:57:25 103 4 5 20.79 1 mobile

对此数据切片应用我们的标注函数,计算总花费金额。

[19]:
total_spent(ds)
[19]:
361.96999999999997

分散 - 数据切片 #2

在第二个数据切片中,您可以看到 2014-01-01 01:00:002014-01-01 03:00:00 之间跳过了 2 小时的交易。通过调整间隔大小,您可以精确设置数据切片之间要跳过的数据量。这对于生成针对数据集特定部分的标签非常有用。

[20]:
ds = next(slices)
print(ds.context)
ds
customer_id                       1
slice_number                      2
slice_start     2014-01-01 03:00:00
slice_stop      2014-01-01 04:00:00
next_start      2014-01-01 06:00:00
[20]:
transaction_id session_id product_id amount customer_id device
transaction_time
2014-01-01 03:29:05 190 14 5 110.52 1 tablet
2014-01-01 03:39:55 7 14 5 107.42 1 tablet

对此数据切片应用标注函数,计算总花费金额。

[21]:
total_spent(ds)
[21]:
217.94

数据切片上下文

每个数据切片都有一个 context 属性,用于访问其元数据。这对于将上下文与标注函数中的逻辑集成非常有用。

[22]:
vars(ds.context)
[22]:
{'next_start': Timestamp('2014-01-01 06:00:00'),
 'slice_stop': Timestamp('2014-01-01 04:00:00'),
 'slice_start': Timestamp('2014-01-01 03:00:00'),
 'slice_number': 2,
 'customer_id': 1}