create_type_2

```python
airt.keras.layers.create_type_2(
    inputs: Union[TensorLike, Dict[str, TensorLike], List[TensorLike]],
    *,
    input_units: Optional[int] = None,
    units: int,
    final_units: int,
    activation: Union[str, Callable[[TensorLike], TensorLike]],
    n_layers: int,
    final_activation: Optional[Union[str, Callable[[TensorLike], TensorLike]]] = None,
    monotonicity_indicator: Union[int, Dict[str, int], List[int]] = 1,
    is_convex: Union[bool, Dict[str, bool], List[bool]] = False,
    is_concave: Union[bool, Dict[str, bool], List[bool]] = False,
    dropout: Optional[float] = None,
) -> TensorLike
```

Builds Type-2 monotonic network

Type-2 architecture is another example of a neural network architecture that can be built using the proposed monotonic dense blocks. It differs from the architecture described above in the way input features are fed into the hidden layers. Instead of concatenating the features directly, this architecture allows any form of complex feature extractor to be applied to the non-monotonic features, with the extracted feature vectors used as inputs. Another difference is that each monotonic input is passed through its own monotonic dense unit. This is advantageous because, depending on whether the input is convex, concave, or both, the activation selection vector \(\mathbf{s}\) can be adjusted appropriately, together with an appropriate value of the indicator vector \(\mathbf{t}\). Each monotonic input feature therefore has a separate monotonic dense layer associated with it, and, as the main difference from the above-mentioned architecture, the feature vectors are concatenated rather than the raw inputs. The rest of the network is the same as in the architecture described above: for the remaining hidden monotonic dense units, the indicator vector \(\mathbf{t}\) is always set to \(1\) to preserve monotonicity.

![mono-dense-layer-diagram.png](../../../images/nbs/images/type-2.png)
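A minimal usage sketch, assuming TensorFlow/Keras is available. The feature count, input shapes, and hyperparameter values below are illustrative assumptions, not library defaults:

```python
import tensorflow as tf
from airt.keras.layers import create_type_2

# Illustrative: three scalar features, each fed as a separate input tensor
inputs = [tf.keras.Input(shape=(1,), name=f"feature_{i}") for i in range(3)]

# All features monotonically increasing (monotonicity_indicator=1 is the default)
outputs = create_type_2(
    inputs,
    units=64,
    final_units=1,
    activation="elu",
    n_layers=3,
    final_activation=None,  # linear output
)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mse")
```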

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `inputs` | `Union[TensorLike, Dict[str, TensorLike], List[TensorLike]]` | input tensor, or a dictionary or list of input tensors | *required* |
| `input_units` | `Optional[int]` | number of units in the per-feature layers used to preprocess features before they enter the common mono block; if `None`, defaults to `max(units // 4, 1)` | `None` |
| `units` | `int` | number of units in the hidden layers | *required* |
| `final_units` | `int` | number of units in the output layer | *required* |
| `activation` | `Union[str, Callable[[TensorLike], TensorLike]]` | the base activation function | *required* |
| `n_layers` | `int` | total number of layers (hidden layers plus the output layer) | *required* |
| `final_activation` | `Optional[Union[str, Callable[[TensorLike], TensorLike]]]` | the activation function of the final layer (typically softmax, sigmoid, or linear); if `None` (the default), the linear activation is used | `None` |
| `monotonicity_indicator` | `Union[int, Dict[str, int], List[int]]` | if a dictionary, maps names of input features to their monotonicity indicators (-1 for monotonically decreasing, 1 for monotonically increasing, 0 otherwise); if an `int`, the same indicator is used for all input features (see the sketch below this table) | `1` |
| `is_convex` | `Union[bool, Dict[str, bool], List[bool]]` | set to `True` if a particular input feature is convex | `False` |
| `is_concave` | `Union[bool, Dict[str, bool], List[bool]]` | set to `True` if a particular input feature is concave | `False` |
| `dropout` | `Optional[float]` | dropout rate; if set to a float greater than 0, `Dropout` layers are inserted after the hidden layers | `None` |
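When `inputs` is a dictionary of tensors, the per-feature parameters can also be passed as dictionaries keyed by feature name. A sketch of this, where the feature names (`price`, `age`, `noise`) and all values are illustrative assumptions:

```python
import tensorflow as tf
from airt.keras.layers import create_type_2

# Hypothetical named features: price is increasing and convex,
# age is decreasing, noise is unconstrained
inputs = {
    name: tf.keras.Input(shape=(1,), name=name)
    for name in ["price", "age", "noise"]
}

outputs = create_type_2(
    inputs,
    units=32,
    final_units=1,
    activation="elu",
    n_layers=2,
    monotonicity_indicator={"price": 1, "age": -1, "noise": 0},
    is_convex={"price": True, "age": False, "noise": False},
    dropout=0.1,
)
```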

Returns:

| Type | Description |
| ---- | ----------- |
| `TensorLike` | output tensor |

Source code in airt/_components/mono_dense_layer.py
```python
@export
def create_type_2(
    inputs: Union[TensorLike, Dict[str, TensorLike], List[TensorLike]],
    *,
    input_units: Optional[int] = None,
    units: int,
    final_units: int,
    activation: Union[str, Callable[[TensorLike], TensorLike]],
    n_layers: int,
    final_activation: Optional[Union[str, Callable[[TensorLike], TensorLike]]] = None,
    monotonicity_indicator: Union[int, Dict[str, int], List[int]] = 1,
    is_convex: Union[bool, Dict[str, bool], List[bool]] = False,
    is_concave: Union[bool, Dict[str, bool], List[bool]] = False,
    dropout: Optional[float] = None,
) -> TensorLike:
    """Builds Type-2 monotonic network

    Type-2 architecture is another example of a neural network architecture that can be built using the proposed
    monotonic dense blocks. It differs from the architecture described above in the way input features are fed
    into the hidden layers. Instead of concatenating the features directly, this architecture allows any form of
    complex feature extractor to be applied to the non-monotonic features, with the extracted feature vectors used
    as inputs. Another difference is that each monotonic input is passed through its own monotonic dense unit.
    This is advantageous because, depending on whether the input is convex, concave, or both, the activation
    selection vector $\mathbf{s}$ can be adjusted appropriately, together with an appropriate value of the
    indicator vector $\mathbf{t}$. Each monotonic input feature therefore has a separate monotonic dense layer
    associated with it, and, as the main difference from the above-mentioned architecture, the feature vectors are
    concatenated rather than the raw inputs. The rest of the network is the same as in the architecture described
    above: for the remaining hidden monotonic dense units, the indicator vector $\mathbf{t}$ is always set to $1$
    to preserve monotonicity.

    ![mono-dense-layer-diagram.png](../../../images/nbs/images/type-2.png)

    Args:
        inputs: input tensor, or a dictionary or list of input tensors
        input_units: number of units in the per-feature layers used to preprocess features before they enter
            the common mono block; if None, defaults to max(units // 4, 1)
        units: number of units in hidden layers
        final_units: number of units in the output layer
        activation: the base activation function
        n_layers: total number of layers (hidden layers plus the output layer)
        final_activation: the activation function of the final layer (typically softmax, sigmoid or linear).
            If set to None (default value), then the linear activation is used.
        monotonicity_indicator: if an instance of dictionary, then maps names of input features to their monotonicity
            indicator (-1 for monotonically decreasing, 1 for monotonically increasing and 0 otherwise). If int,
            then all input features are set to the same monotonicity indicator.
        is_convex: set to True if a particular input feature is convex
        is_concave: set to True if a particular input feature is concave
        dropout: dropout rate. If set to a float greater than 0, Dropout layers are inserted after hidden layers.

    Returns:
        Output tensor

    """
    # Normalize per-feature parameters (scalar/dict/list) into aligned lists
    _, is_convex, _ = _prepare_mono_input_n_param(inputs, is_convex)
    _, is_concave, _ = _prepare_mono_input_n_param(inputs, is_concave)
    x, monotonicity_indicator, names = _prepare_mono_input_n_param(
        inputs, monotonicity_indicator
    )
    has_convex, has_concave = _check_convexity_params(
        monotonicity_indicator, is_convex, is_concave, names
    )

    # Default width of the per-feature preprocessing layers
    if input_units is None:
        input_units = max(units // 4, 1)

    # Per-feature preprocessing: each monotonic feature gets its own MonoDense
    # layer; non-monotonic features (indicator 0) get a plain Dense layer
    y = [
        (
            MonoDense(
                units=input_units,
                activation=activation,
                monotonicity_indicator=monotonicity_indicator[i],
                is_convex=is_convex[i],
                is_concave=is_concave[i],
                name=f"mono_dense_{names[i]}"
                + ("_increasing" if monotonicity_indicator[i] == 1 else "_decreasing")
                + ("_convex" if is_convex[i] else "")
                + ("_concave" if is_concave[i] else ""),
            )
            if monotonicity_indicator[i] != 0
            else (
                Dense(
                    units=input_units, activation=activation, name=f"dense_{names[i]}"
                )
            )
        )(x[i])
        for i in range(len(inputs))
    ]

    # Concatenate the extracted feature vectors (not the raw inputs) and expand
    # the per-feature indicators to one entry per preprocessed unit; abs() is
    # used because decreasing directions are already realized by the per-feature
    # layers, so the shared block only needs 1 (monotonic) or 0 (unconstrained)
    y = Concatenate(name="preprocessed_features")(y)
    monotonicity_indicator_block: List[int] = sum(
        [[abs(x)] * input_units for x in monotonicity_indicator], []
    )

    # Shared mono block: (n_layers - 1) hidden layers plus the output layer
    y = _create_mono_block(
        units=[units] * (n_layers - 1) + [final_units],
        activation=activation,
        monotonicity_indicator=monotonicity_indicator_block,
        is_convex=has_convex,
        is_concave=has_concave and not has_convex,
        dropout=dropout,
    )(y)

    # Apply the final activation if given; otherwise the output stays linear
    if final_activation is not None:
        y = tf.keras.activations.get(final_activation)(y)

    return y
```
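For intuition, a small worked sketch (values are illustrative) of how the per-feature monotonicity indicators are expanded into the mono block's indicator in the code above:

```python
monotonicity_indicator = [1, -1, 0]
input_units = 2

# Each per-feature indicator is repeated input_units times; abs() is taken
# because the decreasing direction (-1) is already realized by that feature's
# own MonoDense layer, so the shared block only needs to keep the corresponding
# units monotonic (1) or unconstrained (0)
block = sum([[abs(m)] * input_units for m in monotonicity_indicator], [])
assert block == [1, 1, 1, 1, 0, 0]
```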