Source: R/ds_generalization.R, ds_generalization.Rd
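This datasource is useful for assessing whether information is invariant to (i.e., abstracted from) particular conditions. A classifier is trained on data from one set of label levels and tested on data from a different but related set of label levels.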
ds_generalization(
  binned_data,
  labels,
  num_cv_splits,
  train_label_levels,
  test_label_levels,
  use_count_data = FALSE,
  num_label_repeats_per_cv_split = 1,
  num_resample_sites = NULL,
  site_IDs_to_use = NULL,
  site_IDs_to_exclude = NULL,
  randomly_shuffled_labels = FALSE,
  create_simultaneous_populations = 0
)
binned_data
A string that gives the path to a file containing data in binned format, or a data frame of data in binned format.
labels
A string specifying the name of the label variable that should be decoded. It must correspond to one of the columns in the binned data whose name starts with 'label.'
num_cv_splits
A number specifying how many cross-validation splits should be used.
train_label_levels
A list that contains vectors specifying which label levels belong to which training class. Each element in the list corresponds to a class that the specified training labels will be mapped to; for example, the label levels in the first element of the list are mapped onto the first training class, and so on.
test_label_levels
A list that contains vectors specifying which label levels belong to which test class. Each element in the list corresponds to a class that the specified test labels will be mapped to; for example, the label levels in the first element of the list are mapped onto the first test class, and so on. This list must have the same number of elements as train_label_levels.
use_count_data
If the binned data come from neural spiking activity, setting use_count_data = TRUE converts the data into spike counts. This is useful for classifiers that operate on spike count data, e.g., the poisson_naive_bayes_CL.
num_label_repeats_per_cv_split
A number specifying how many times each label level should be repeated in each cross-validation split.
num_resample_sites
The number of sites that should be randomly selected when constructing training and test vectors. This number must be less than or equal to the number of available sites that have at least num_cv_splits * num_label_repeats_per_cv_split repeats.
site_IDs_to_use
A vector of integers specifying which sites should be used.
site_IDs_to_exclude
A vector of integers specifying which sites should be excluded.
randomly_shuffled_labels
A Boolean specifying whether the labels should be shuffled before running an analysis (i.e., before the first call to the get_data() method). This is used to create a null distribution for assessing whether decoding results are above chance.
create_simultaneous_populations
If the data from all sites were recorded simultaneously, setting this argument to 1 causes the get_data() method to return simultaneous populations rather than pseudo-populations.
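To make the mapping described by train_label_levels and test_label_levels concrete, here is a minimal base-R sketch. The helper map_levels_to_classes() and the toy labels are hypothetical illustrations, not functions or data from NeuroDecodeR. The sketch only shows how a level listed in the k-th element of either list ends up in class k, while levels that appear in neither list are left unassigned (NA).

```r
# Hypothetical helper (not part of NeuroDecodeR): map raw label levels to the
# class index given by their position in a train/test label-levels list.
map_levels_to_classes <- function(labels, label_levels) {
  class_idx <- rep(NA_integer_, length(labels))  # NA = level not in any class
  for (k in seq_along(label_levels)) {
    # every level listed in element k is assigned to class k;
    # unlist() accepts elements given either as vectors or as lists
    class_idx[labels %in% unlist(label_levels[[k]])] <- k
  }
  class_idx
}

# Toy train/test levels: train on 'upper' and 'middle', test on 'lower'
toy_train_levels <- list(c("hand_upper", "hand_middle"), c("car_upper", "car_middle"))
toy_test_levels  <- list("hand_lower", "car_lower")

trial_labels <- c("hand_upper", "car_middle", "hand_lower", "car_lower", "car_upper")
map_levels_to_classes(trial_labels, toy_train_levels)
#> [1]  1  2 NA NA  2
map_levels_to_classes(trial_labels, toy_test_levels)
#> [1] NA NA  1  2 NA
```

Because class k on the training side and class k on the test side refer to the same underlying object, above-chance test accuracy indicates that the information generalizes across the conditions that differ between the two lists.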
This constructor creates an NDR datasource object with the class ds_generalization. Like all NDR datasource objects, this datasource will be used by the cross-validator to generate training and test data sets.
Like all datasources, this datasource takes binned format data and has a get_data() method that is called by a cross-validation object to get training and testing splits of data that can be passed to a classifier.
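The cross-validation idea behind get_data() can be illustrated with a simplified, single-site sketch in base R. The functions assign_cv_splits() and generalization_split() are hypothetical and are not NeuroDecodeR's internal code. They assume each label level has at least num_cv_splits * num_label_repeats_per_cv_split trials, and they skip pseudo-population construction across sites and the mapping of levels onto classes. Under this assumed scheme, when split k is held out, training trials come from the other splits and carry train label levels, while test trials come from split k and carry test label levels.

```r
# Hypothetical sketch, not NeuroDecodeR internals: assign each trial of one
# site to a cross-validation split so that every label level appears
# num_label_repeats_per_cv_split times in each split.
assign_cv_splits <- function(labels, num_cv_splits,
                             num_label_repeats_per_cv_split = 1) {
  n_needed <- num_cv_splits * num_label_repeats_per_cv_split
  split_id <- rep(NA_integer_, length(labels))  # NA = trial not used
  for (lev in unique(labels)) {
    idx <- which(labels == lev)
    stopifnot(length(idx) >= n_needed)
    # randomly pick n_needed trials of this level and spread them over the splits
    chosen <- idx[sample.int(length(idx), n_needed)]
    split_id[chosen] <- rep(seq_len(num_cv_splits),
                            each = num_label_repeats_per_cv_split)
  }
  split_id
}

# Trial indices used when split k is held out: train on train levels from the
# remaining splits, test on test levels from split k.
generalization_split <- function(labels, split_id, k, train_levels, test_levels) {
  used <- !is.na(split_id)
  list(
    train = which(used & split_id != k & labels %in% train_levels),
    test  = which(used & split_id == k & labels %in% test_levels)
  )
}

set.seed(1)
labels <- rep(c("hand_upper", "hand_middle", "hand_lower"), each = 3)
split_id <- assign_cv_splits(labels, num_cv_splits = 3)
idx <- generalization_split(labels, split_id, k = 1,
                            train_levels = c("hand_upper", "hand_middle"),
                            test_levels = "hand_lower")
labels[idx$train]  # the four upper/middle trials that fall in splits 2 and 3
labels[idx$test]   # the one lower trial that falls in split 1
```

Training and test trials never share a split, so the test data are always held out from training even though they also come from different conditions.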
Other datasource: ds_basic()
# One can test if a neural population contains information that is position
# invariant by generating training data for objects presented at 'upper' and 'middle'
# locations, and generating test data at a 'lower' location.
id_levels <- c("hand", "flower", "guitar", "face", "kiwi", "couch", "car")
train_label_levels <- NULL
test_label_levels <- NULL
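for (i in seq_along(id_levels)) {
  # training class i: object i shown at the 'upper' and 'middle' locations
  train_label_levels[[i]] <- c(
    paste(id_levels[i], "upper", sep = "_"),
    paste(id_levels[i], "middle", sep = "_")
  )
  # test class i: the same object shown at the 'lower' location
  test_label_levels[[i]] <- list(paste(id_levels[i], "lower", sep = "_"))
}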
data_file <- system.file("extdata/ZD_150bins_50sampled.Rda", package = "NeuroDecodeR")
ds <- ds_generalization(
  binned_data = data_file,
  labels = "combined_ID_position",
  num_cv_splits = 18,
  train_label_levels = train_label_levels,
  test_label_levels = test_label_levels
)
#> Automatically selecting sites_IDs_to_use. Since num_cv_splits = 18 and num_label_repeats_per_cv_split = 1, all sites that have 18 repetitions have been selected. This yields 132 sites that will be used for decoding (out of 132 total).
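The message above reports automatic site selection, which happens here because site_IDs_to_use was left at its default of NULL. The following base-R sketch illustrates the counting rule the message describes, assuming a site is kept only if every label level has at least num_cv_splits * num_label_repeats_per_cv_split trials. The helper select_sites() and the toy data frame (with hypothetical columns siteID and label) are illustrations only and do not reproduce the binned-format column names or NeuroDecodeR's implementation.

```r
# Toy data (hypothetical): one row per trial, with the site it came from and its label.
trials <- data.frame(
  siteID = c(rep(1, 6), rep(2, 5)),
  label  = c(rep(c("a", "b"), each = 3), rep("a", 3), rep("b", 2))
)

# Hypothetical helper: keep sites where every label level has enough repeats.
select_sites <- function(trials, num_cv_splits, num_label_repeats_per_cv_split = 1) {
  n_needed <- num_cv_splits * num_label_repeats_per_cv_split
  # trial counts: rows = sites, columns = label levels
  counts <- table(trials$siteID, trials$label)
  keep <- apply(counts, 1, function(row) all(row >= n_needed))
  as.numeric(rownames(counts)[keep])
}

select_sites(trials, num_cv_splits = 3)
#> [1] 1
```

Site 2 has only two 'b' trials, so it cannot contribute one trial of every level to each of three splits and is dropped; with num_cv_splits = 2 both sites would qualify.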